Tpetra parallel linear algebra Version of the Day
Loading...
Searching...
No Matches
Tpetra_Details_extractBlockDiagonal.hpp
Go to the documentation of this file.
1// @HEADER
2// *****************************************************************************
3// Tpetra: Templated Linear Algebra Services Package
4//
5// Copyright 2008 NTESS and the Tpetra contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9
10#ifndef TPETRA_DETAILS_EXTRACTBLOCKDIAGONAL_HPP
11#define TPETRA_DETAILS_EXTRACTBLOCKDIAGONAL_HPP
12
13#include "TpetraCore_config.h"
14#include "Tpetra_CrsMatrix.hpp"
15#include "Teuchos_RCP.hpp"
17
24
25namespace Tpetra {
26namespace Details {
27
28template <class SparseMatrixType,
29 class MultiVectorType>
30void extractBlockDiagonal(const SparseMatrixType& A, MultiVectorType& diagonal) {
31 using local_map_type = typename SparseMatrixType::map_type::local_map_type;
32 using SC = typename MultiVectorType::scalar_type;
33 using LO = typename SparseMatrixType::local_ordinal_type;
34 using KCRS = typename SparseMatrixType::local_matrix_device_type;
35 using lno_view_t = typename KCRS::StaticCrsGraphType::row_map_type::const_type;
36 using lno_nnz_view_t = typename KCRS::StaticCrsGraphType::entries_type::const_type;
37 using scalar_view_t = typename KCRS::values_type::const_type;
38 using local_mv_type = typename MultiVectorType::dual_view_type::t_dev;
39 using range_type = Kokkos::RangePolicy<typename SparseMatrixType::node_type::execution_space, LO>;
40#if KOKKOS_VERSION >= 40799
41 using ATS = KokkosKernels::ArithTraits<SC>;
42#else
43 using ATS = Kokkos::ArithTraits<SC>;
44#endif
45#if KOKKOS_VERSION >= 40799
46 using impl_ATS = KokkosKernels::ArithTraits<typename ATS::val_type>;
47#else
48 using impl_ATS = Kokkos::ArithTraits<typename ATS::val_type>;
49#endif
50
51 // Sanity checking: Map Compatibility (A's rowmap matches diagonal's map)
53 TEUCHOS_TEST_FOR_EXCEPTION(!A.getRowMap()->isSameAs(*diagonal.getMap()),
54 std::runtime_error, "Tpetra::Details::extractBlockDiagonal was given incompatible maps");
55 }
56
57 LO numrows = diagonal.getLocalLength();
58 LO blocksize = diagonal.getNumVectors();
59
60 // Get Kokkos versions of objects
61 local_map_type rowmap = A.getRowMap()->getLocalMap();
62 local_map_type colmap = A.getRowMap()->getLocalMap();
63 local_mv_type diag = diagonal.getLocalViewDevice(Access::OverwriteAll);
64 const KCRS Amat = A.getLocalMatrixDevice();
65 lno_view_t Arowptr = Amat.graph.row_map;
66 lno_nnz_view_t Acolind = Amat.graph.entries;
67 scalar_view_t Avals = Amat.values;
68
69 Kokkos::parallel_for(
70 "Tpetra::extractBlockDiagonal", range_type(0, numrows), KOKKOS_LAMBDA(const LO i) {
71 LO diag_col = colmap.getLocalElement(rowmap.getGlobalElement(i));
72 LO blockStart = diag_col - (diag_col % blocksize);
73 LO blockStop = blockStart + blocksize;
74 for (LO k = 0; k < blocksize; k++)
75 diag(i, k) = impl_ATS::zero();
76
77 for (size_t k = Arowptr(i); k < Arowptr(i + 1); k++) {
78 LO col = Acolind(k);
79 if (blockStart <= col && col < blockStop) {
80 diag(i, col - blockStart) = Avals(k);
81 }
82 }
83 });
84}
85
86} // namespace Details
87} // namespace Tpetra
88
89#endif // TPETRA_DETAILS_EXTRACTBLOCKDIAGONAL_HPP
Declaration of Tpetra::Details::Behavior, a class that describes Tpetra's behavior.
static bool debug()
Whether Tpetra is in debug mode.
Implementation details of Tpetra.
Namespace Tpetra contains the class and methods constituting the Tpetra library.