Tpetra parallel linear algebra Version of the Day
Loading...
Searching...
No Matches
Tpetra_Details_CrsPadding.hpp
1// @HEADER
2// *****************************************************************************
3// Tpetra: Templated Linear Algebra Services Package
4//
5// Copyright 2008 NTESS and the Tpetra contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9
10#ifndef TPETRA_DETAILS_CRSPADDING_HPP
11#define TPETRA_DETAILS_CRSPADDING_HPP
12
#include "Tpetra_Details_Behavior.hpp"
#include "Tpetra_Util.hpp"
#include <algorithm>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
20
21namespace Tpetra {
22namespace Details {
23
27template <class LocalOrdinal, class GlobalOrdinal>
29 private:
30 using LO = LocalOrdinal;
31 using GO = GlobalOrdinal;
32
33 enum class Phase {
34 SAME,
35 PERMUTE,
36 IMPORT
37 };
38
39 public:
40 CrsPadding(const int myRank,
41 const size_t /* numSameIDs */,
42 const size_t /* numPermutes */)
43 : myRank_(myRank) {}
44
45 CrsPadding(const int myRank,
46 const size_t /* numImports */)
47 : myRank_(myRank) {}
48
49 void
50 update_same(
51 const LO targetLocalIndex,
52 GO tgtGblColInds[],
53 const size_t origNumTgtEnt,
54 const bool tgtIsUnique,
55 GO srcGblColInds[],
56 const size_t origNumSrcEnt,
57 const bool srcIsUnique) {
58 const LO whichSame = targetLocalIndex;
59 update_impl(Phase::SAME, whichSame, targetLocalIndex,
62 }
63
64 void
65 update_permute(
66 const LO whichPermute, // index in permuteFrom/To
67 const LO targetLocalIndex,
68 GO tgtGblColInds[],
69 const size_t origNumTgtEnt,
70 const bool tgtIsUnique,
71 GO srcGblColInds[],
72 const size_t origNumSrcEnt,
73 const bool srcIsUnique) {
74 update_impl(Phase::PERMUTE, whichPermute, targetLocalIndex,
77 }
78
79 void
80 update_import(
81 const LO whichImport,
82 const LO targetLocalIndex,
83 GO tgtGblColInds[],
84 const size_t origNumTgtEnt,
85 const bool tgtIsUnique,
86 GO srcGblColInds[],
87 const size_t origNumSrcEnt,
88 const bool srcIsUnique) {
89 update_impl(Phase::IMPORT, whichImport, targetLocalIndex,
92 }
93
94 void print(std::ostream& out) const {
95 const size_t maxNumToPrint =
97 const size_t size = entries_.size();
98 out << "entries: [";
99 size_t k = 0;
100 for (const auto& keyval : entries_) {
101 if (k > maxNumToPrint) {
102 out << "...";
103 break;
104 }
105 out << "(" << keyval.first << ", ";
107 "Global column indices", maxNumToPrint);
108 out << ")";
109 if (k + size_t(1) < size) {
110 out << ", ";
111 }
112 ++k;
113 }
114 out << "]";
115 }
116
/// Return type of get_result.
struct Result {
  /// Number of stored source column indices for the row that are not
  /// in the target row (0 when the row was not found).
  size_t numInSrcNotInTgt;
  /// Whether an entry for the requested row exists.
  bool found;
};
121
129 Result
130 get_result(const LO targetLocalIndex) const {
131 auto it = entries_.find(targetLocalIndex);
132 if (it == entries_.end()) {
133 return {0, false};
134 } else {
135 return {it->second.size(), true};
136 }
137 }
138
139 private:
140 void
141 update_impl(
142 const Phase phase,
143 const LO whichImport,
144 const LO targetLocalIndex,
145 GO tgtGblColInds[],
146 const size_t origNumTgtEnt,
147 const bool tgtIsUnique,
148 GO srcGblColInds[],
149 const size_t origNumSrcEnt,
150 const bool srcIsUnique) {
151 using std::endl;
152 std::unique_ptr<std::string> prefix;
153 if (verbose_) {
154 prefix = createPrefix("update_impl");
155 std::ostringstream os;
156 os << *prefix << "Start: "
157 << "targetLocalIndex=" << targetLocalIndex
158 << ", origNumTgtEnt=" << origNumTgtEnt
159 << ", origNumSrcEnt=" << origNumSrcEnt << endl;
160 std::cerr << os.str();
161 }
162
163 // FIXME (08 Feb 2020) We only need to sort and unique
164 // tgtGblColInds if we haven't already seen it before.
167 std::sort(tgtGblColInds, tgtEnd);
168 if (!tgtIsUnique) {
169 tgtEnd = std::unique(tgtGblColInds, tgtEnd);
172 }
173
174 if (verbose_) {
175 std::ostringstream os;
176 os << *prefix << "finished src; process tgt" << endl;
177 std::cerr << os.str();
178 }
179
180 size_t newNumSrcEnt = origNumSrcEnt;
181 auto srcEnd = srcGblColInds + origNumSrcEnt;
182 std::sort(srcGblColInds, srcEnd);
183 if (!srcIsUnique) {
184 srcEnd = std::unique(srcGblColInds, srcEnd);
185 newNumSrcEnt = size_t(srcEnd - srcGblColInds);
186 TEUCHOS_ASSERT(newNumSrcEnt <= origNumSrcEnt);
187 }
188
189 merge_with_current_state(phase, whichImport, targetLocalIndex,
190 tgtGblColInds, newNumTgtEnt,
191 srcGblColInds, newNumSrcEnt);
192 if (verbose_) {
193 std::ostringstream os;
194 os << *prefix << "Done" << endl;
195 std::cerr << os.str();
196 }
197 }
198
199 std::vector<GO>&
200 get_difference_col_inds(const Phase /* phase */,
201 const LO /* whichIndex */,
202 const LO tgtLclRowInd) {
203 return entries_[tgtLclRowInd];
204 }
205
206 void
207 merge_with_current_state(
208 const Phase phase,
209 const LO whichIndex,
210 const LO tgtLclRowInd,
211 const GO tgtColInds[], // sorted & merged
212 const size_t numTgtEnt,
213 const GO srcColInds[], // sorted & merged
214 const size_t numSrcEnt) {
215 using std::endl;
216 std::unique_ptr<std::string> prefix;
217 if (verbose_) {
218 prefix = createPrefix("merge_with_current_state");
219 std::ostringstream os;
220 os << *prefix << "Start: "
221 << "tgtLclRowInd=" << tgtLclRowInd
222 << ", numTgtEnt=" << numTgtEnt
223 << ", numSrcEnt=" << numSrcEnt << endl;
224 std::cerr << os.str();
225 }
226 // We only need to accumulate those source indices that are
227 // not already target indices. This is because we always have
228 // the target indices on input to this function, so there's no
229 // need to store them here again. That still could be a lot
230 // to store, but it's better than duplicating target matrix
231 // storage.
232 //
233 // This means that consumers of this data structure need to
234 // treat entries_[tgtLclRowInd].size() as an increment, not as
235 // the required new allocation size itself.
236 //
237 // We store
238 //
239 // difference(union(incoming source indices,
240 // already stored source indices),
241 // target indices)
242
243 auto tgtEnd = tgtColInds + numTgtEnt;
244 auto srcEnd = srcColInds + numSrcEnt;
245
246 // At least one input source index isn't in the target.
247 std::vector<GO>& diffColInds =
248 get_difference_col_inds(phase, whichIndex, tgtLclRowInd);
249 const size_t oldDiffNumEnt = diffColInds.size();
250
251 if (oldDiffNumEnt == 0) {
252 if (verbose_) {
253 std::ostringstream os;
254 os << *prefix << "oldDiffNumEnt=0; call "
255 "set_difference(src,tgt,diff)"
256 << endl;
257 std::cerr << os.str();
258 }
259 diffColInds.resize(numSrcEnt);
260 auto diffEnd = std::set_difference(srcColInds, srcEnd,
261 tgtColInds, tgtEnd,
262 diffColInds.begin());
263 const size_t newLen(diffEnd - diffColInds.begin());
264 TEUCHOS_ASSERT(newLen <= numSrcEnt);
265 diffColInds.resize(newLen);
266 } else {
267 // scratch = union(srcColInds, diffColInds);
268 // diffColInds = difference(scratch, tgtColInds);
269
270 const size_t maxUnionSize = numSrcEnt + oldDiffNumEnt;
271 if (verbose_) {
272 std::ostringstream os;
273 os << *prefix << "oldDiffNumEnt=" << oldDiffNumEnt
274 << ", maxUnionSize=" << maxUnionSize
275 << "; call set_union(src,diff,union)" << endl;
276 std::cerr << os.str();
277 }
278 if (scratchColInds_.size() < maxUnionSize) {
279 scratchColInds_.resize(maxUnionSize);
280 }
281 auto unionBeg = scratchColInds_.begin();
282 auto unionEnd = std::set_union(srcColInds, srcEnd,
283 diffColInds.begin(), diffColInds.end(),
284 unionBeg);
285 const size_t unionSize(unionEnd - unionBeg);
286 TEUCHOS_ASSERT(unionSize <= maxUnionSize);
287
288 if (verbose_) {
289 std::ostringstream os;
290 os << *prefix << "oldDiffNumEnt=" << oldDiffNumEnt
291 << ", unionSize=" << unionSize << "; call "
292 "set_difference(union,tgt,diff)"
293 << endl;
294 std::cerr << os.str();
295 }
296 diffColInds.resize(unionSize);
297 auto diffEnd = std::set_difference(unionBeg, unionEnd,
298 tgtColInds, tgtEnd,
299 diffColInds.begin());
300 const size_t diffLen(diffEnd - diffColInds.begin());
301 TEUCHOS_ASSERT(diffLen <= unionSize);
302 diffColInds.resize(diffLen);
303 }
304
305 if (verbose_) {
306 std::ostringstream os;
307 os << *prefix << "Done" << endl;
308 std::cerr << os.str();
309 }
310 }
311
312 std::unique_ptr<std::string>
313 createPrefix(const char funcName[]) {
314 std::ostringstream os;
315 os << "Proc " << myRank_ << ": CrsPadding::" << funcName
316 << ": ";
317 return std::unique_ptr<std::string>(new std::string(os.str()));
318 }
319
320 // imports may overlap with sames and/or permutes, so it makes
321 // sense to store them all in one map.
322 std::map<LO, std::vector<GO> > entries_;
323 std::vector<GO> scratchColInds_;
324 int myRank_ = -1;
325 bool verbose_ = Behavior::verbose("CrsPadding");
326};
327} // namespace Details
328} // namespace Tpetra
329
330#endif // TPETRA_DETAILS_CRSPADDING_HPP
Declaration of Tpetra::Details::Behavior, a class that describes Tpetra's behavior.
Stand-alone utility functions and macros.
Struct that holds views of the contents of a CrsMatrix.
static bool verbose()
Whether Tpetra is in verbose mode.
static size_t verbosePrintCountThreshold()
Number of entries below which arrays, lists, etc. will be printed in debug mode.
Keep track of how much more space a CrsGraph or CrsMatrix needs, when the graph or matrix is the target of a doExport or doImport.
Result get_result(const LO targetLocalIndex) const
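For a given target matrix local row index, return the number of unique source column indices to merge into that row encountered thus far that are not already in the row, and whether that row has been seen already.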
Implementation details of Tpetra.
void verbosePrintArray(std::ostream &out, const ArrayType &x, const char name[], const size_t maxNumToPrint)
Print min(x.size(), maxNumToPrint) entries of x.
Namespace Tpetra contains the class and methods constituting the Tpetra library.