Tpetra parallel linear algebra Version of the Day
Loading...
Searching...
No Matches
Tpetra_Details_leftScaleLocalCrsMatrix.hpp
Go to the documentation of this file.
1// @HEADER
2// *****************************************************************************
3// Tpetra: Templated Linear Algebra Services Package
4//
5// Copyright 2008 NTESS and the Tpetra contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9
10#ifndef TPETRA_DETAILS_LEFTSCALELOCALCRSMATRIX_HPP
11#define TPETRA_DETAILS_LEFTSCALELOCALCRSMATRIX_HPP
12
19
20#include "TpetraCore_config.h"
21#include "Kokkos_Core.hpp"
22#if KOKKOS_VERSION >= 40799
23#include "KokkosKernels_ArithTraits.hpp"
24#else
25#include "Kokkos_ArithTraits.hpp"
26#endif
27#include <type_traits>
28
29namespace Tpetra {
30namespace Details {
31
40template <class LocalSparseMatrixType,
41 class ScalingFactorsViewType,
42 const bool divide>
44 public:
45 using val_type =
46 typename std::remove_const<typename LocalSparseMatrixType::value_type>::type;
47 using mag_type = typename ScalingFactorsViewType::non_const_value_type;
48 static_assert(ScalingFactorsViewType::rank == 1,
49 "scalingFactors must be a rank-1 Kokkos::View.");
50 using device_type = typename LocalSparseMatrixType::device_type;
51 using LO = typename LocalSparseMatrixType::ordinal_type;
52 using policy_type = Kokkos::TeamPolicy<typename device_type::execution_space, LO>;
53
64 const bool assumeSymmetric)
65 : A_lcl_(A_lcl)
66 , scalingFactors_(scalingFactors)
67 , assumeSymmetric_(assumeSymmetric) {}
68
70 operator()(const typename policy_type::member_type& team) const {
71#if KOKKOS_VERSION >= 40799
72 using KAM = KokkosKernels::ArithTraits<mag_type>;
73#else
74 using KAM = Kokkos::ArithTraits<mag_type>;
75#endif
76
77 const LO lclRow = team.league_rank();
78 const mag_type curRowNorm = scalingFactors_(lclRow);
79 // Users are responsible for any divisions or multiplications by
80 // zero.
81 const mag_type scalingFactor = assumeSymmetric_ ? KAM::sqrt(curRowNorm) : curRowNorm;
82 auto curRow = A_lcl_.row(lclRow);
83 const LO numEnt = curRow.length;
84 Kokkos::parallel_for(Kokkos::TeamThreadRange(team, numEnt), [&](const LO k) {
85 if (divide) { // constexpr, so should get compiled out
86 curRow.value(k) = curRow.value(k) / scalingFactor;
87 } else {
88 curRow.value(k) = curRow.value(k) * scalingFactor;
89 }
90 });
91 }
92
93 private:
94 LocalSparseMatrixType A_lcl_;
95 typename ScalingFactorsViewType::const_type scalingFactors_;
96 bool assumeSymmetric_;
97};
98
112template <class LocalSparseMatrixType, class ScalingFactorsViewType>
115 const bool assumeSymmetric,
116 const bool divide = true) {
117 using device_type = typename LocalSparseMatrixType::device_type;
118 using execution_space = typename device_type::execution_space;
119 using LO = typename LocalSparseMatrixType::ordinal_type;
120 using policy_type = Kokkos::TeamPolicy<execution_space, LO>;
121
122 const LO lclNumRows = A_lcl.numRows();
123 if (divide) {
124 using functor_type =
126 typename ScalingFactorsViewType::const_type, true>;
127 functor_type functor(A_lcl, scalingFactors, assumeSymmetric);
128 Kokkos::parallel_for("leftScaleLocalCrsMatrix",
129 policy_type(lclNumRows, Kokkos::AUTO), functor);
130 } else {
131 using functor_type =
133 typename ScalingFactorsViewType::const_type, false>;
134 functor_type functor(A_lcl, scalingFactors, assumeSymmetric);
135 Kokkos::parallel_for("leftScaleLocalCrsMatrix",
136 policy_type(lclNumRows, Kokkos::AUTO), functor);
137 }
138}
139
140} // namespace Details
141} // namespace Tpetra
142
143#endif // TPETRA_DETAILS_LEFTSCALELOCALCRSMATRIX_HPP
Struct that holds views of the contents of a CrsMatrix.
Kokkos::parallel_for functor that left-scales a KokkosSparse::CrsMatrix.
LeftScaleLocalCrsMatrix(const LocalSparseMatrixType &A_lcl, const ScalingFactorsViewType &scalingFactors, const bool assumeSymmetric)
Implementation details of Tpetra.
void leftScaleLocalCrsMatrix(const LocalSparseMatrixType &A_lcl, const ScalingFactorsViewType &scalingFactors, const bool assumeSymmetric, const bool divide=true)
Left-scale a KokkosSparse::CrsMatrix.
Namespace Tpetra contains the class and methods constituting the Tpetra library.