Ifpack2 Templated Preconditioning Package Version 1.0
Ifpack2_Details_ChebyshevKernel_def.hpp
// @HEADER
// *****************************************************************************
// Ifpack2: Templated Object-Oriented Algebraic Preconditioner Package
//
// Copyright 2009 NTESS and the Ifpack2 contributors.
// SPDX-License-Identifier: BSD-3-Clause
// *****************************************************************************
// @HEADER

#ifndef IFPACK2_DETAILS_CHEBYSHEVKERNEL_DEF_HPP
#define IFPACK2_DETAILS_CHEBYSHEVKERNEL_DEF_HPP

#include "Tpetra_CrsMatrix.hpp"
#include "Tpetra_MultiVector.hpp"
#include "Tpetra_Operator.hpp"
#include "Tpetra_Vector.hpp"
#include "Tpetra_Export_decl.hpp"
#include "Tpetra_Import_decl.hpp"
#if KOKKOS_VERSION >= 40799
#include "KokkosKernels_ArithTraits.hpp"
#else
#include "Kokkos_ArithTraits.hpp"
#endif
#include "Teuchos_Assert.hpp"
#include <type_traits>
#include "KokkosSparse_spmv_impl.hpp"

namespace Ifpack2 {
namespace Details {
namespace Impl {

/// \brief Functor for computing W := alpha * D * (B - A*X) + beta * W
///   and X := X + W.
template <class WVector,
          class DVector,
          class BVector,
          class AMatrix,
          class XVector_colMap,
          class XVector_domMap,
          class Scalar,
          bool use_beta,
          bool do_X_update>
struct ChebyshevKernelVectorFunctor {
  static_assert(static_cast<int>(WVector::rank) == 1,
                "WVector must be a rank 1 View.");
  static_assert(static_cast<int>(DVector::rank) == 1,
                "DVector must be a rank 1 View.");
  static_assert(static_cast<int>(BVector::rank) == 1,
                "BVector must be a rank 1 View.");
  static_assert(static_cast<int>(XVector_colMap::rank) == 1,
                "XVector_colMap must be a rank 1 View.");
  static_assert(static_cast<int>(XVector_domMap::rank) == 1,
                "XVector_domMap must be a rank 1 View.");

  using execution_space = typename AMatrix::execution_space;
  using LO              = typename AMatrix::non_const_ordinal_type;
  using value_type      = typename AMatrix::non_const_value_type;
  using team_policy     = typename Kokkos::TeamPolicy<execution_space>;
  using team_member     = typename team_policy::member_type;
#if KOKKOS_VERSION >= 40799
  using ATV = KokkosKernels::ArithTraits<value_type>;
#else
  using ATV = Kokkos::ArithTraits<value_type>;
#endif

  const Scalar alpha;
  WVector m_w;
  DVector m_d;
  BVector m_b;
  AMatrix m_A;
  XVector_colMap m_x_colMap;
  XVector_domMap m_x_domMap;
  const Scalar beta;

  const LO rows_per_team;

  ChebyshevKernelVectorFunctor(const Scalar& alpha_,
                               const WVector& m_w_,
                               const DVector& m_d_,
                               const BVector& m_b_,
                               const AMatrix& m_A_,
                               const XVector_colMap& m_x_colMap_,
                               const XVector_domMap& m_x_domMap_,
                               const Scalar& beta_,
                               const int rows_per_team_)
    : alpha(alpha_)
    , m_w(m_w_)
    , m_d(m_d_)
    , m_b(m_b_)
    , m_A(m_A_)
    , m_x_colMap(m_x_colMap_)
    , m_x_domMap(m_x_domMap_)
    , beta(beta_)
    , rows_per_team(rows_per_team_) {
    const size_t numRows = m_A.numRows();
    const size_t numCols = m_A.numCols();

    TEUCHOS_ASSERT(m_w.extent(0) == m_d.extent(0));
    TEUCHOS_ASSERT(m_w.extent(0) == m_b.extent(0));
    TEUCHOS_ASSERT(numRows == size_t(m_w.extent(0)));
    TEUCHOS_ASSERT(numCols <= size_t(m_x_colMap.extent(0)));
    TEUCHOS_ASSERT(numRows <= size_t(m_x_domMap.extent(0)));
  }

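  // Hierarchical parallelism: each team handles rows_per_team consecutive
  // local rows.  Per row, a ThreadVectorRange reduction accumulates the local
  // sparse dot product (A*x)(lclRow); a single lane then applies
  //   w(lclRow) = [beta * w(lclRow)] + alpha * d(lclRow) * (b(lclRow) - (A*x)(lclRow))
  // and, if do_X_update, x_domMap(lclRow) += w(lclRow).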
  KOKKOS_INLINE_FUNCTION
  void operator()(const team_member& dev) const {
    using residual_value_type = typename BVector::non_const_value_type;
#if KOKKOS_VERSION >= 40799
    using KAT = KokkosKernels::ArithTraits<residual_value_type>;
#else
    using KAT = Kokkos::ArithTraits<residual_value_type>;
#endif

    Kokkos::parallel_for(Kokkos::TeamThreadRange(dev, 0, rows_per_team),
                         [&](const LO& loop) {
                           const LO lclRow =
                               static_cast<LO>(dev.league_rank()) * rows_per_team + loop;
                           if (lclRow >= m_A.numRows()) {
                             return;
                           }
                           const KokkosSparse::SparseRowViewConst<AMatrix> A_row = m_A.rowConst(lclRow);
                           const LO row_length = static_cast<LO>(A_row.length);
                           residual_value_type A_x = KAT::zero();

                           Kokkos::parallel_reduce(
                               Kokkos::ThreadVectorRange(dev, row_length),
                               [&](const LO iEntry, residual_value_type& lsum) {
                                 const auto A_val = A_row.value(iEntry);
                                 lsum += A_val * m_x_colMap(A_row.colidx(iEntry));
                               },
                               A_x);

                           Kokkos::single(Kokkos::PerThread(dev),
                                          [&]() {
                                            const auto alpha_D_res =
                                                alpha * m_d(lclRow) * (m_b(lclRow) - A_x);
                                            if (use_beta) {
                                              m_w(lclRow) = beta * m_w(lclRow) + alpha_D_res;
                                            } else {
                                              m_w(lclRow) = alpha_D_res;
                                            }
                                            if (do_X_update)
                                              m_x_domMap(lclRow) += m_w(lclRow);
                                          });
                         });
  }
};

// W := alpha * D * (B - A*X) + beta * W.
template <class WVector,
          class DVector,
          class BVector,
          class AMatrix,
          class XVector_colMap,
          class XVector_domMap,
          class Scalar>
static void
chebyshev_kernel_vector(const Scalar& alpha,
                        const WVector& w,
                        const DVector& d,
                        const BVector& b,
                        const AMatrix& A,
                        const XVector_colMap& x_colMap,
                        const XVector_domMap& x_domMap,
                        const Scalar& beta,
                        const bool do_X_update) {
  using execution_space = typename AMatrix::execution_space;

  if (A.numRows() == 0) {
    return;
  }

  int team_size     = -1;
  int vector_length = -1;
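  // Negative values ask KokkosSparse's SpMV launch-parameter heuristics to
  // choose the team size, vector length, and rows per thread for this
  // execution space.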
  int64_t rows_per_thread = -1;

  const int64_t rows_per_team = KokkosSparse::Impl::spmv_launch_parameters<execution_space>(A.numRows(), A.nnz(), rows_per_thread, team_size, vector_length);
  int64_t worksets            = (b.extent(0) + rows_per_team - 1) / rows_per_team;

  using Kokkos::Dynamic;
  using Kokkos::Schedule;
  using Kokkos::Static;
  using Kokkos::TeamPolicy;
  using policy_type_dynamic = TeamPolicy<execution_space, Schedule<Dynamic>>;
  using policy_type_static  = TeamPolicy<execution_space, Schedule<Static>>;
  const char kernel_label[] = "chebyshev_kernel_vector";
  policy_type_dynamic policyDynamic(1, 1);
  policy_type_static policyStatic(1, 1);
  if (team_size < 0) {
    policyDynamic = policy_type_dynamic(worksets, Kokkos::AUTO, vector_length);
    policyStatic  = policy_type_static(worksets, Kokkos::AUTO, vector_length);
  } else {
    policyDynamic = policy_type_dynamic(worksets, team_size, vector_length);
    policyStatic  = policy_type_static(worksets, team_size, vector_length);
  }
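  // The Dynamic-schedule policy is selected below for matrices with more than
  // 10M stored entries; the Static-schedule policy is used otherwise.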

  // Canonicalize template arguments to avoid redundant instantiations.
  using w_vec_type        = typename WVector::non_const_type;
  using d_vec_type        = typename DVector::const_type;
  using b_vec_type        = typename BVector::const_type;
  using matrix_type       = AMatrix;
  using x_colMap_vec_type = typename XVector_colMap::const_type;
  using x_domMap_vec_type = typename XVector_domMap::non_const_type;
#if KOKKOS_VERSION >= 40799
  using scalar_type = typename KokkosKernels::ArithTraits<Scalar>::val_type;
#else
  using scalar_type = typename Kokkos::ArithTraits<Scalar>::val_type;
#endif

#if KOKKOS_VERSION >= 40799
  if (beta == KokkosKernels::ArithTraits<Scalar>::zero()) {
#else
  if (beta == Kokkos::ArithTraits<Scalar>::zero()) {
#endif
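    // use_beta and do_X_update are compile-time template parameters, so each
    // of the four functor instantiations below keeps the corresponding
    // branches out of the inner row loop.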
    constexpr bool use_beta = false;
    if (do_X_update) {
      using functor_type =
          ChebyshevKernelVectorFunctor<w_vec_type, d_vec_type,
                                       b_vec_type, matrix_type,
                                       x_colMap_vec_type, x_domMap_vec_type,
                                       scalar_type,
                                       use_beta,
                                       true>;
      functor_type func(alpha, w, d, b, A, x_colMap, x_domMap, beta, rows_per_team);
      if (A.nnz() > 10000000)
        Kokkos::parallel_for(kernel_label, policyDynamic, func);
      else
        Kokkos::parallel_for(kernel_label, policyStatic, func);
    } else {
      using functor_type =
          ChebyshevKernelVectorFunctor<w_vec_type, d_vec_type,
                                       b_vec_type, matrix_type,
                                       x_colMap_vec_type, x_domMap_vec_type,
                                       scalar_type,
                                       use_beta,
                                       false>;
      functor_type func(alpha, w, d, b, A, x_colMap, x_domMap, beta, rows_per_team);
      if (A.nnz() > 10000000)
        Kokkos::parallel_for(kernel_label, policyDynamic, func);
      else
        Kokkos::parallel_for(kernel_label, policyStatic, func);
    }
  } else {
    constexpr bool use_beta = true;
    if (do_X_update) {
      using functor_type =
          ChebyshevKernelVectorFunctor<w_vec_type, d_vec_type,
                                       b_vec_type, matrix_type,
                                       x_colMap_vec_type, x_domMap_vec_type,
                                       scalar_type,
                                       use_beta,
                                       true>;
      functor_type func(alpha, w, d, b, A, x_colMap, x_domMap, beta, rows_per_team);
      if (A.nnz() > 10000000)
        Kokkos::parallel_for(kernel_label, policyDynamic, func);
      else
        Kokkos::parallel_for(kernel_label, policyStatic, func);
    } else {
      using functor_type =
          ChebyshevKernelVectorFunctor<w_vec_type, d_vec_type,
                                       b_vec_type, matrix_type,
                                       x_colMap_vec_type, x_domMap_vec_type,
                                       scalar_type,
                                       use_beta,
                                       false>;
      functor_type func(alpha, w, d, b, A, x_colMap, x_domMap, beta, rows_per_team);
      if (A.nnz() > 10000000)
        Kokkos::parallel_for(kernel_label, policyDynamic, func);
      else
        Kokkos::parallel_for(kernel_label, policyStatic, func);
    }
  }
}

}  // namespace Impl

template <class TpetraOperatorType>
ChebyshevKernel<TpetraOperatorType>::
    ChebyshevKernel(const Teuchos::RCP<const operator_type>& A,
                    const bool useNativeSpMV)
  : useNativeSpMV_(useNativeSpMV) {
  setMatrix(A);
}

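// setMatrix caches the operator, attempts to downcast it to a fill-complete
// CrsMatrix (which enables the fused kernel path), and (re)allocates the
// column-map work vector X_colMap_ whenever the matrix's Importer changes.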
template <class TpetraOperatorType>
void ChebyshevKernel<TpetraOperatorType>::
    setMatrix(const Teuchos::RCP<const operator_type>& A) {
  if (A_op_.get() != A.get()) {
    A_op_ = A;

    // We'll (re)allocate these on demand.
    V1_ = std::unique_ptr<multivector_type>(nullptr);

    using Teuchos::rcp_dynamic_cast;
    Teuchos::RCP<const crs_matrix_type> A_crs =
        rcp_dynamic_cast<const crs_matrix_type>(A);
    if (A_crs.is_null()) {
      A_crs_    = Teuchos::null;
      imp_      = Teuchos::null;
      exp_      = Teuchos::null;
      X_colMap_ = nullptr;
    } else {
      TEUCHOS_ASSERT(A_crs->isFillComplete());
      A_crs_ = A_crs;
      auto G = A_crs->getCrsGraph();
      imp_   = G->getImporter();
      exp_   = G->getExporter();
      if (!imp_.is_null()) {
        if (X_colMap_.get() == nullptr ||
            !X_colMap_->getMap()->isSameAs(*(imp_->getTargetMap()))) {
          X_colMap_ = std::unique_ptr<multivector_type>(new multivector_type(imp_->getTargetMap(), 1));
        }
      } else
        X_colMap_ = nullptr;
    }
  }
}

template <class TpetraOperatorType>
void ChebyshevKernel<TpetraOperatorType>::
    setAuxiliaryVectors(size_t numVectors) {
  if ((V1_.get() == nullptr) || V1_->getNumVectors() != numVectors) {
    using MV = multivector_type;
    V1_      = std::unique_ptr<MV>(new MV(A_op_->getRangeMap(), numVectors));
  }
}

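// compute() performs W := alpha * D_inv * (B - A*X) + beta * W followed by
// X := X + W, using the fused single-vector kernel when canFuse() permits and
// the generic Tpetra operator path otherwise.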
template <class TpetraOperatorType>
void ChebyshevKernel<TpetraOperatorType>::
    compute(multivector_type& W,
            const SC& alpha,
            vector_type& D_inv,
            multivector_type& B,
            multivector_type& X,
            const SC& beta) {
  using Teuchos::RCP;
  using Teuchos::rcp;

  if (canFuse(B)) {
    TEUCHOS_ASSERT(!A_crs_.is_null());
    fusedCase(W, alpha, D_inv, B, *A_crs_, X, beta);
  } else {
    TEUCHOS_ASSERT(!A_op_.is_null());
    unfusedCase(W, alpha, D_inv, B, *A_op_, X, beta);
  }
}

template <class TpetraOperatorType>
typename ChebyshevKernel<TpetraOperatorType>::multivector_type&
ChebyshevKernel<TpetraOperatorType>::
    importVector(multivector_type& X_domMap) {
  if (imp_.is_null()) {
    return X_domMap;
  } else {
    X_colMap_->doImport(X_domMap, *imp_, Tpetra::REPLACE);
    return *X_colMap_;
  }
}

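// The fused path requires a single right-hand-side column, a CrsMatrix, and
// no Exporter on the matrix graph; enabling useNativeSpMV_ forces the unfused
// path.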
template <class TpetraOperatorType>
bool ChebyshevKernel<TpetraOperatorType>::
    canFuse(const multivector_type& B) const {
  // If the override to use the native SpMV is enabled, never fuse.
  if (useNativeSpMV_)
    return false;

  // Several criteria must be met for the fused kernel.
  return B.getNumVectors() == size_t(1) &&
         !A_crs_.is_null() &&
         exp_.is_null();
}

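// unfusedCase expresses the update with generic Tpetra operations (apply,
// elementWiseMultiply, update), so it works for any operator_type and any
// number of right-hand-side columns.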
template <class TpetraOperatorType>
void ChebyshevKernel<TpetraOperatorType>::
    unfusedCase(multivector_type& W,
                const SC& alpha,
                vector_type& D_inv,
                multivector_type& B,
                const operator_type& A,
                multivector_type& X,
                const SC& beta) {
  using STS = Teuchos::ScalarTraits<SC>;
  setAuxiliaryVectors(B.getNumVectors());

  const SC one = Teuchos::ScalarTraits<SC>::one();

  // V1 = B - A*X
  Tpetra::deep_copy(*V1_, B);
  A.apply(X, *V1_, Teuchos::NO_TRANS, -one, one);

  // W := alpha * D_inv * V1 + beta * W
  W.elementWiseMultiply(alpha, D_inv, *V1_, beta);

  // X := X + W
  X.update(STS::one(), W, STS::one());
}

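// fusedCase extracts rank-1 device views of the single-column vectors and
// calls Impl::chebyshev_kernel_vector.  When the matrix has an Importer, X's
// halo is imported into X_colMap_ first and the kernel updates X in place;
// when there is no Importer, X_colMap aliases X, so the X := X + W update is
// deferred until after the kernel.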
template <class TpetraOperatorType>
void ChebyshevKernel<TpetraOperatorType>::
    fusedCase(multivector_type& W,
              const SC& alpha,
              multivector_type& D_inv,
              multivector_type& B,
              const crs_matrix_type& A,
              multivector_type& X,
              const SC& beta) {
  multivector_type& X_colMap = importVector(X);

  using Impl::chebyshev_kernel_vector;
  using STS = Teuchos::ScalarTraits<SC>;

  auto A_lcl = A.getLocalMatrixDevice();
  // D_inv, B, X and W are all Vectors, so it's safe to take the first column only
  auto Dinv_lcl     = Kokkos::subview(D_inv.getLocalViewDevice(Tpetra::Access::ReadOnly), Kokkos::ALL(), 0);
  auto B_lcl        = Kokkos::subview(B.getLocalViewDevice(Tpetra::Access::ReadOnly), Kokkos::ALL(), 0);
  auto X_domMap_lcl = Kokkos::subview(X.getLocalViewDevice(Tpetra::Access::ReadWrite), Kokkos::ALL(), 0);
  auto X_colMap_lcl = Kokkos::subview(X_colMap.getLocalViewDevice(Tpetra::Access::ReadOnly), Kokkos::ALL(), 0);

  const bool do_X_update = !imp_.is_null();
  if (beta == STS::zero()) {
    auto W_lcl = Kokkos::subview(W.getLocalViewDevice(Tpetra::Access::OverwriteAll), Kokkos::ALL(), 0);
    chebyshev_kernel_vector(alpha, W_lcl, Dinv_lcl,
                            B_lcl, A_lcl,
                            X_colMap_lcl, X_domMap_lcl,
                            beta,
                            do_X_update);
  } else {  // need to read _and_ write W if beta != 0
    auto W_lcl = Kokkos::subview(W.getLocalViewDevice(Tpetra::Access::ReadWrite), Kokkos::ALL(), 0);
    chebyshev_kernel_vector(alpha, W_lcl, Dinv_lcl,
                            B_lcl, A_lcl,
                            X_colMap_lcl, X_domMap_lcl,
                            beta,
                            do_X_update);
  }
  if (!do_X_update)
    X.update(STS::one(), W, STS::one());
}

}  // namespace Details
}  // namespace Ifpack2

#define IFPACK2_DETAILS_CHEBYSHEVKERNEL_INSTANT(SC, LO, GO, NT) \
  template class Ifpack2::Details::ChebyshevKernel<Tpetra::Operator<SC, LO, GO, NT> >;

#endif  // IFPACK2_DETAILS_CHEBYSHEVKERNEL_DEF_HPP