Tpetra parallel linear algebra Version of the Day
Loading...
Searching...
No Matches
Tpetra_Details_rightScaleLocalCrsMatrix.hpp
Go to the documentation of this file.
1// @HEADER
2// *****************************************************************************
3// Tpetra: Templated Linear Algebra Services Package
4//
5// Copyright 2008 NTESS and the Tpetra contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9
10#ifndef TPETRA_DETAILS_RIGHTSCALELOCALCRSMATRIX_HPP
11#define TPETRA_DETAILS_RIGHTSCALELOCALCRSMATRIX_HPP
12
19
20#include "TpetraCore_config.h"
21#include "Kokkos_Core.hpp"
22#include "KokkosKernels_ArithTraits.hpp"
23#include <type_traits>
24
25namespace Tpetra {
26namespace Details {
27
36template <class LocalSparseMatrixType,
37 class ScalingFactorsViewType,
38 const bool divide>
40 public:
41 using val_type =
42 typename std::remove_const<typename LocalSparseMatrixType::value_type>::type;
43 using mag_type = typename ScalingFactorsViewType::non_const_value_type;
44 static_assert(ScalingFactorsViewType::rank == 1,
45 "scalingFactors must be a rank-1 Kokkos::View.");
46 using device_type = typename LocalSparseMatrixType::device_type;
47 using LO = typename LocalSparseMatrixType::ordinal_type;
48 using policy_type = Kokkos::TeamPolicy<typename device_type::execution_space, LO>;
49
62 const bool assumeSymmetric)
63 : A_lcl_(A_lcl)
64 , scalingFactors_(scalingFactors)
65 , assumeSymmetric_(assumeSymmetric) {}
66
68 operator()(const typename policy_type::member_type& team) const {
69 using KAM = KokkosKernels::ArithTraits<mag_type>;
70
71 const LO lclRow = team.league_rank();
72 auto curRow = A_lcl_.row(lclRow);
73 const LO numEnt = curRow.length;
74 Kokkos::parallel_for(Kokkos::TeamThreadRange(team, numEnt), [&](const LO k) {
75 const LO lclColInd = curRow.colidx(k);
76 const mag_type curColNorm = scalingFactors_(lclColInd);
77 // Users are responsible for any divisions or multiplications by
78 // zero.
79 const mag_type scalingFactor = assumeSymmetric_ ? KAM::sqrt(curColNorm) : curColNorm;
80 if (divide) { // constexpr, so should get compiled out
81 curRow.value(k) = curRow.value(k) / scalingFactor;
82 } else {
83 curRow.value(k) = curRow.value(k) * scalingFactor;
84 }
85 });
86 }
87
88 private:
89 LocalSparseMatrixType A_lcl_;
90 typename ScalingFactorsViewType::const_type scalingFactors_;
91 bool assumeSymmetric_;
92};
93
107template <class LocalSparseMatrixType, class ScalingFactorsViewType>
110 const bool assumeSymmetric,
111 const bool divide = true) {
112 using device_type = typename LocalSparseMatrixType::device_type;
113 using execution_space = typename device_type::execution_space;
114 using LO = typename LocalSparseMatrixType::ordinal_type;
115 using policy_type = Kokkos::TeamPolicy<execution_space, LO>;
116
117 const LO lclNumRows = A_lcl.numRows();
118 if (divide) {
119 using functor_type =
121 typename ScalingFactorsViewType::const_type, true>;
122 functor_type functor(A_lcl, scalingFactors, assumeSymmetric);
123 Kokkos::parallel_for("rightScaleLocalCrsMatrix",
124 policy_type(lclNumRows, Kokkos::AUTO), functor);
125 } else {
126 using functor_type =
128 typename ScalingFactorsViewType::const_type, false>;
129 functor_type functor(A_lcl, scalingFactors, assumeSymmetric);
130 Kokkos::parallel_for("rightScaleLocalCrsMatrix",
131 policy_type(lclNumRows, Kokkos::AUTO), functor);
132 }
133}
134
135} // namespace Details
136} // namespace Tpetra
137
138#endif // TPETRA_DETAILS_RIGHTSCALELOCALCRSMATRIX_HPP
Struct that holds views of the contents of a CrsMatrix.
Kokkos::parallel_for functor that right-scales a KokkosSparse::CrsMatrix.
RightScaleLocalCrsMatrix(const LocalSparseMatrixType &A_lcl, const ScalingFactorsViewType &scalingFactors, const bool assumeSymmetric)
Implementation details of Tpetra.
void rightScaleLocalCrsMatrix(const LocalSparseMatrixType &A_lcl, const ScalingFactorsViewType &scalingFactors, const bool assumeSymmetric, const bool divide=true)
Right-scale a KokkosSparse::CrsMatrix.
Namespace Tpetra contains the class and methods constituting the Tpetra library.