Tpetra parallel linear algebra Version of the Day

Tpetra_Details_getDiagCopyWithoutOffsets_def.hpp
// @HEADER
// *****************************************************************************
// Tpetra: Templated Linear Algebra Services Package
//
// Copyright 2008 NTESS and the Tpetra contributors.
// SPDX-License-Identifier: BSD-3-Clause
// *****************************************************************************
// @HEADER

#ifndef TPETRA_DETAILS_GETDIAGCOPYWITHOUTOFFSETS_DEF_HPP
#define TPETRA_DETAILS_GETDIAGCOPYWITHOUTOFFSETS_DEF_HPP

#include "Tpetra_Details_getDiagCopyWithoutOffsets_decl.hpp"
#include "Tpetra_RowGraph.hpp"
#include "Tpetra_CrsGraph.hpp"
#include "Tpetra_RowMatrix.hpp"
#include "Tpetra_Vector.hpp"

namespace Tpetra {
namespace Details {

// Work-around for #499: Implementation of one-argument (no offsets)
// getLocalDiagCopy for the NOT fill-complete case.
//
// NOTE (mfh 18 Jul 2016) This calls functions that are NOT GPU device
// functions!  Thus, we do NOT use KOKKOS_INLINE_FUNCTION or
// KOKKOS_FUNCTION here, because those attempt to mark the functions
// they modify as CUDA device functions.  This functor is ONLY for
// non-CUDA execution spaces!
template <class SC, class LO, class GO, class NT>
class GetLocalDiagCopyWithoutOffsetsNotFillCompleteFunctor {
 public:
  using row_matrix_type = ::Tpetra::RowMatrix<SC, LO, GO, NT>;
  using vec_type = ::Tpetra::Vector<SC, LO, GO, NT>;

  using IST = typename vec_type::impl_scalar_type;

  // The output Vector determines the execution space.
  using host_execution_space = typename vec_type::dual_view_type::t_host::execution_space;

 private:
  using map_type = typename vec_type::map_type;

  static bool
  graphIsSorted(const row_matrix_type& A) {
    using Teuchos::RCP;
    using Teuchos::rcp_dynamic_cast;
    using crs_graph_type = Tpetra::CrsGraph<LO, GO, NT>;
    using row_graph_type = Tpetra::RowGraph<LO, GO, NT>;

    // We conservatively assume not sorted.  RowGraph lacks an
    // "isSorted" predicate, so we can't know for sure unless the cast
    // to CrsGraph succeeds.
    bool sorted = false;

    RCP<const row_graph_type> G_row = A.getGraph();
    if (!G_row.is_null()) {
      RCP<const crs_graph_type> G_crs =
          rcp_dynamic_cast<const crs_graph_type>(G_row);
      if (!G_crs.is_null()) {
        sorted = G_crs->isSorted();
      }
    }

    return sorted;
  }
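
  // Knowing that the graph is sorted lets findRelOffset (called in
  // operator() below) locate the diagonal entry in each row more
  // efficiently than an unconditional linear search.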

 public:
  // lclNumErrs [out] Total count of errors on this process.
  GetLocalDiagCopyWithoutOffsetsNotFillCompleteFunctor(LO& lclNumErrs,
                                                       vec_type& diag,
                                                       const row_matrix_type& A)
      : A_(A)
      , lclRowMap_(*A.getRowMap())
      , lclColMap_(*A.getColMap())
      , sorted_(graphIsSorted(A)) {
    const LO lclNumRows = static_cast<LO>(diag.getLocalLength());
    {
      const LO matLclNumRows =
          static_cast<LO>(lclRowMap_.getLocalNumElements());
      TEUCHOS_TEST_FOR_EXCEPTION(lclNumRows != matLclNumRows, std::invalid_argument,
                                 "diag.getLocalLength() = " << lclNumRows << " != "
                                 "A.getRowMap()->getLocalNumElements() = "
                                 << matLclNumRows << ".");
    }

    // Side effects start below this point.
    D_lcl_ = diag.getLocalViewHost(Access::OverwriteAll);
    D_lcl_1d_ = Kokkos::subview(D_lcl_, Kokkos::ALL(), 0);

    Kokkos::RangePolicy<host_execution_space, LO> range(0, lclNumRows);
    lclNumErrs = 0;
    Kokkos::parallel_reduce(range, *this, lclNumErrs);

    // Sync changes back to device, since the user doesn't know that
    // we had to run on host.
    // diag.template sync<typename device_type::memory_space> ();
  }

  void operator()(const LO& lclRowInd, LO& errCount) const {
    using KokkosSparse::findRelOffset;

    D_lcl_1d_(lclRowInd) = KokkosKernels::ArithTraits<IST>::zero();
    const GO gblInd = lclRowMap_.getGlobalElement(lclRowInd);
    const LO lclColInd = lclColMap_.getLocalElement(gblInd);

    if (lclColInd == Tpetra::Details::OrdinalTraits<LO>::invalid()) {
      errCount++;
    } else {  // row index is also in the column Map on this process
      typename row_matrix_type::local_inds_host_view_type lclColInds;
      typename row_matrix_type::values_host_view_type curVals;
      A_.getLocalRowView(lclRowInd, lclColInds, curVals);
      const LO numEnt = static_cast<LO>(lclColInds.extent(0));
      // The search hint is always zero, since we only call this
      // once per row of the matrix.
      const LO hint = 0;
      const LO offset =
          findRelOffset(lclColInds, numEnt, lclColInd, hint, sorted_);
      if (offset == numEnt) {  // didn't find the diagonal column index
        errCount++;
      } else {
        D_lcl_1d_(lclRowInd) = curVals[offset];
      }
    }
  }

 private:
  const row_matrix_type& A_;
  map_type lclRowMap_;
  map_type lclColMap_;
  typename vec_type::dual_view_type::t_host D_lcl_;
  decltype(Kokkos::subview(D_lcl_, Kokkos::ALL(), 0)) D_lcl_1d_;
  const bool sorted_;
};
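
// Usage sketch (a hypothetical caller, for illustration only): the
// functor above does all of its work in its constructor, so callers
// normally just use the free function defined below, e.g.:
//
//   Tpetra::Vector<SC, LO, GO, NT> diag (A.getRowMap ());
//   const LO lclNumErrs =
//     getLocalDiagCopyWithoutOffsetsNotFillComplete (diag, A, false);
//   // Nonzero lclNumErrs means some local rows had no diagonal entry.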

template <class SC, class LO, class GO, class NT>
LO getLocalDiagCopyWithoutOffsetsNotFillComplete(::Tpetra::Vector<SC, LO, GO, NT>& diag,
                                                 const ::Tpetra::RowMatrix<SC, LO, GO, NT>& A,
                                                 const bool debug) {
  using Teuchos::outArg;
  using Teuchos::REDUCE_MIN;
  using Teuchos::reduceAll;
  using ::Tpetra::Details::gathervPrint;
  using functor_type = GetLocalDiagCopyWithoutOffsetsNotFillCompleteFunctor<SC, LO, GO, NT>;

  // The functor's constructor does error checking and executes the
  // thread-parallel kernel.

  LO lclNumErrs = 0;

  if (debug) {
    int lclSuccess = 1;
    int gblSuccess = 0;
    std::ostringstream errStrm;
    Teuchos::RCP<const Teuchos::Comm<int> > commPtr = A.getComm();
    if (commPtr.is_null()) {
      return lclNumErrs;  // this process does not participate
    }
    const Teuchos::Comm<int>& comm = *commPtr;

    try {
      functor_type functor(lclNumErrs, diag, A);
    } catch (std::exception& e) {
      lclSuccess = -1;
      errStrm << "Process " << A.getComm()->getRank() << ": "
              << e.what() << std::endl;
    }
    if (lclNumErrs != 0) {
      lclSuccess = 0;
    }

    reduceAll<int, int>(comm, REDUCE_MIN, lclSuccess, outArg(gblSuccess));
    if (gblSuccess == -1) {
      if (comm.getRank() == 0) {
        // We gather into std::cerr, rather than using an
        // std::ostringstream, because there might be a lot of MPI
        // processes.  It could take too much memory to gather all the
        // messages to Process 0 before printing.  gathervPrint gathers
        // and prints one message at a time, thus saving memory.  I
        // don't want to run out of memory while trying to print an
        // error message; that would hide the real problem.
        std::cerr << "getLocalDiagCopyWithoutOffsetsNotFillComplete threw an "
                     "exception on one or more MPI processes in the matrix's "
                     "communicator." << std::endl;
      }
      gathervPrint(std::cerr, errStrm.str(), comm);
      // Don't need to print anything here, since we've already
      // printed to std::cerr above.
      TEUCHOS_TEST_FOR_EXCEPTION(true, std::runtime_error, "");
    } else if (gblSuccess == 0) {
      TEUCHOS_TEST_FOR_EXCEPTION(gblSuccess != 1, std::runtime_error,
                                 "getLocalDiagCopyWithoutOffsetsNotFillComplete failed on "
                                 "one or more MPI processes in the matrix's communicator.");
    }
  } else {  // ! debug
    functor_type functor(lclNumErrs, diag, A);
  }

  return lclNumErrs;
}

}  // namespace Details
}  // namespace Tpetra

// Explicit template instantiation macro for
// getLocalDiagCopyWithoutOffsetsNotFillComplete.  NOT FOR USERS!!!
// Must be used inside the Tpetra namespace.
#define TPETRA_DETAILS_GETDIAGCOPYWITHOUTOFFSETS_INSTANT(SCALAR, LO, GO, NODE) \
  template LO \
  Details::getLocalDiagCopyWithoutOffsetsNotFillComplete<SCALAR, LO, GO, NODE>(::Tpetra::Vector<SCALAR, LO, GO, NODE>& diag, \
                                                                               const ::Tpetra::RowMatrix<SCALAR, LO, GO, NODE>& A, \
                                                                               const bool debug);
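
// Example (hypothetical ETI usage): an explicit instantiation source
// file would invoke the macro inside namespace Tpetra, e.g. with
// SCALAR = double, LO = int, GO = long long, and some Kokkos Node type
// NODE:
//
//   namespace Tpetra {
//     TPETRA_DETAILS_GETDIAGCOPYWITHOUTOFFSETS_INSTANT(double, int, long long, NODE)
//   }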

#endif  // TPETRA_DETAILS_GETDIAGCOPYWITHOUTOFFSETS_DEF_HPP