Intrepid2
Intrepid2_DataCombiners.hpp
Go to the documentation of this file.
1// @HEADER
2// *****************************************************************************
3// Intrepid2 Package
4//
5// Copyright 2007 NTESS and the Intrepid2 contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9//
10// Intrepid2_DataCombiners.hpp
11// Trilinos
12//
13// Created by Roberts, Nathan V on 5/31/23.
14//
15
16#ifndef Intrepid2_DataCombiners_hpp
17#define Intrepid2_DataCombiners_hpp
18
25#include "Intrepid2_Data.hpp"
28#include "Intrepid2_ScalarView.hpp"
29
30namespace Intrepid2 {
31 template<class DataScalar,typename DeviceType>
32 class Data;
33
34 template<class BinaryOperator, class ThisUnderlyingViewType, class AUnderlyingViewType, class BUnderlyingViewType,
35 class ArgExtractorThis, class ArgExtractorA, class ArgExtractorB, bool includeInnerLoop=false>
37 {
38 private:
39 ThisUnderlyingViewType this_underlying_;
40 AUnderlyingViewType A_underlying_;
41 BUnderlyingViewType B_underlying_;
42 BinaryOperator binaryOperator_;
43 int innerLoopSize_;
44 public:
45 InPlaceCombinationFunctor(ThisUnderlyingViewType this_underlying, AUnderlyingViewType A_underlying, BUnderlyingViewType B_underlying,
46 BinaryOperator binaryOperator)
47 :
48 this_underlying_(this_underlying),
49 A_underlying_(A_underlying),
50 B_underlying_(B_underlying),
51 binaryOperator_(binaryOperator),
52 innerLoopSize_(-1)
53 {
54 INTREPID2_TEST_FOR_EXCEPTION(includeInnerLoop,std::invalid_argument,"If includeInnerLoop is true, must specify the size of the inner loop");
55 }
56
57 InPlaceCombinationFunctor(ThisUnderlyingViewType this_underlying, AUnderlyingViewType A_underlying, BUnderlyingViewType B_underlying,
58 BinaryOperator binaryOperator, int innerLoopSize)
59 :
60 this_underlying_(this_underlying),
61 A_underlying_(A_underlying),
62 B_underlying_(B_underlying),
63 binaryOperator_(binaryOperator),
64 innerLoopSize_(innerLoopSize)
65 {
66 INTREPID2_TEST_FOR_EXCEPTION(includeInnerLoop,std::invalid_argument,"If includeInnerLoop is true, must specify the size of the inner loop");
67 }
68
69 template<class ...IntArgs, bool M=includeInnerLoop>
70 KOKKOS_INLINE_FUNCTION
71 enable_if_t<!M, void>
72 operator()(const IntArgs&... args) const
73 {
74 auto & result = ArgExtractorThis::get( this_underlying_, args... );
75 const auto & A_val = ArgExtractorA::get( A_underlying_, args... );
76 const auto & B_val = ArgExtractorB::get( B_underlying_, args... );
77
78 result = binaryOperator_(A_val,B_val);
79 }
80
81 template<class ...IntArgs, bool M=includeInnerLoop>
82 KOKKOS_INLINE_FUNCTION
83 enable_if_t<M, void>
84 operator()(const IntArgs&... args) const
85 {
86 using int_type = std::tuple_element_t<0, std::tuple<IntArgs...>>;
87 for (int_type iFinal=0; iFinal<static_cast<int_type>(innerLoopSize_); iFinal++)
88 {
89 auto & result = ArgExtractorThis::get( this_underlying_, args..., iFinal );
90 const auto & A_val = ArgExtractorA::get( A_underlying_, args..., iFinal );
91 const auto & B_val = ArgExtractorB::get( B_underlying_, args..., iFinal );
92
93 result = binaryOperator_(A_val,B_val);
94 }
95 }
96 };
97
99 template<class BinaryOperator, class ThisUnderlyingViewType, class AUnderlyingViewType, class BUnderlyingViewType>
101 {
102 private:
103 ThisUnderlyingViewType this_underlying_;
104 AUnderlyingViewType A_underlying_;
105 BUnderlyingViewType B_underlying_;
106 BinaryOperator binaryOperator_;
107 public:
108 InPlaceCombinationFunctorConstantCase(ThisUnderlyingViewType this_underlying,
109 AUnderlyingViewType A_underlying,
110 BUnderlyingViewType B_underlying,
111 BinaryOperator binaryOperator)
112 :
113 this_underlying_(this_underlying),
114 A_underlying_(A_underlying),
115 B_underlying_(B_underlying),
116 binaryOperator_(binaryOperator)
117 {
118 INTREPID2_TEST_FOR_EXCEPTION(this_underlying.extent(0) != 1,std::invalid_argument,"all views for InPlaceCombinationFunctorConstantCase should have rank 1 and extent 1");
119 INTREPID2_TEST_FOR_EXCEPTION(A_underlying.extent(0) != 1,std::invalid_argument,"all views for InPlaceCombinationFunctorConstantCase should have rank 1 and extent 1");
120 INTREPID2_TEST_FOR_EXCEPTION(B_underlying.extent(0) != 1,std::invalid_argument,"all views for InPlaceCombinationFunctorConstantCase should have rank 1 and extent 1");
121 }
122
123 KOKKOS_INLINE_FUNCTION
124 void operator()(const int arg0) const
125 {
126 auto & result = this_underlying_(0);
127 const auto & A_val = A_underlying_(0);
128 const auto & B_val = B_underlying_(0);
129
130 result = binaryOperator_(A_val,B_val);
131 }
132 };
133
135 template<bool passThroughBlockDiagonalArgs>
137 {
138 template<class ViewType, class ...IntArgs>
139 static KOKKOS_INLINE_FUNCTION typename ViewType::reference_type get(const ViewType &view, const IntArgs&... intArgs)
140 {
141 return view.getWritableEntryWithPassThroughOption(passThroughBlockDiagonalArgs, intArgs...);
142 }
143 };
144
146 template<bool passThroughBlockDiagonalArgs>
148 {
149 template<class ViewType, class ...IntArgs>
150 static KOKKOS_INLINE_FUNCTION typename ViewType::const_reference_type get(const ViewType &view, const IntArgs&... intArgs)
151 {
152 return view.getEntryWithPassThroughOption(passThroughBlockDiagonalArgs, intArgs...);
153 }
154 };
155
156// static class for combining two Data objects using a specified binary operator
157 template <class DataScalar,typename DeviceType, class BinaryOperator>
159{
160 using reference_type = typename ScalarView<DataScalar,DeviceType>::reference_type;
161 using const_reference_type = typename ScalarView<const DataScalar,DeviceType>::reference_type;
162public:
164 template<class PolicyType, class ThisUnderlyingViewType, class AUnderlyingViewType, class BUnderlyingViewType,
165 class ArgExtractorThis, class ArgExtractorA, class ArgExtractorB>
166 static void storeInPlaceCombination(PolicyType &policy, ThisUnderlyingViewType &this_underlying,
167 AUnderlyingViewType &A_underlying, BUnderlyingViewType &B_underlying,
168 BinaryOperator &binaryOperator, ArgExtractorThis argThis, ArgExtractorA argA, ArgExtractorB argB)
169 {
171 Functor functor(this_underlying, A_underlying, B_underlying, binaryOperator);
172 Kokkos::parallel_for("compute in-place", policy, functor);
173 }
174
176 template<int rank>
177 static
178 enable_if_t<rank != 7, void>
180 {
181 auto policy = thisData.template dataExtentRangePolicy<rank>();
182
183 const bool A_1D = A.getUnderlyingViewRank() == 1;
184 const bool B_1D = B.getUnderlyingViewRank() == 1;
185 const bool this_1D = thisData.getUnderlyingViewRank() == 1;
186 const bool A_constant = A_1D && (A.getUnderlyingViewSize() == 1);
187 const bool B_constant = B_1D && (B.getUnderlyingViewSize() == 1);
188 const bool this_constant = this_1D && (thisData.getUnderlyingViewSize() == 1);
189 const bool A_full = A.underlyingMatchesLogical();
190 const bool B_full = B.underlyingMatchesLogical();
191 const bool this_full = thisData.underlyingMatchesLogical();
192
194
196 const FullArgExtractorData<true> fullArgsData; // true: pass through block diagonal args. This is due to the behavior of dataExtentRangePolicy() for block diagonal args.
197 const FullArgExtractorWritableData<true> fullArgsWritable; // true: pass through block diagonal args. This is due to the behavior of dataExtentRangePolicy() for block diagonal args.
198
205
206 // this lambda returns -1 if there is not a rank-1 underlying view whose data extent matches the logical extent in the corresponding dimension;
207 // otherwise, it returns the logical index of the corresponding dimension.
208 auto get1DArgIndex = [](const Data<DataScalar,DeviceType> &data) -> int
209 {
210 const auto & variationTypes = data.getVariationTypes();
211 for (int d=0; d<rank; d++)
212 {
213 if (variationTypes[d] == GENERAL)
214 {
215 return d;
216 }
217 }
218 return -1;
219 };
220 if (this_constant)
221 {
222 // then A, B are constant, too
223 auto thisAE = constArg;
224 auto AAE = constArg;
225 auto BAE = constArg;
226 auto & this_underlying = thisData.template getUnderlyingView<1>();
227 auto & A_underlying = A.template getUnderlyingView<1>();
228 auto & B_underlying = B.template getUnderlyingView<1>();
229 storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, BAE);
230 }
231 else if (this_full && A_full && B_full)
232 {
233 auto thisAE = fullArgs;
234 auto AAE = fullArgs;
235 auto BAE = fullArgs;
236
237 auto & this_underlying = thisData.template getUnderlyingView<rank>();
238 auto & A_underlying = A.template getUnderlyingView<rank>();
239 auto & B_underlying = B.template getUnderlyingView<rank>();
240
241 storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, BAE);
242 }
243 else if (A_constant)
244 {
245 auto AAE = constArg;
246 auto & A_underlying = A.template getUnderlyingView<1>();
247 if (this_full)
248 {
249 auto thisAE = fullArgs;
250 auto & this_underlying = thisData.template getUnderlyingView<rank>();
251
252 if (B_full)
253 {
254 auto BAE = fullArgs;
255 auto & B_underlying = B.template getUnderlyingView<rank>();
256 storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, BAE);
257 }
258 else // this_full, not B_full: B may have modular data, etc.
259 {
260 auto BAE = fullArgsData;
261 storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, thisAE, AAE, BAE);
262 }
263 }
264 else // this is not full
265 {
266 // below, we optimize for the case of 1D data in B, when A is constant. Still need to handle other cases…
267 if (B_1D && (get1DArgIndex(B) != -1) )
268 {
269 // since A is constant, that implies that this_1D is true, and has the same 1DArgIndex
270 const int argIndex = get1DArgIndex(B);
271 auto & B_underlying = B.template getUnderlyingView<1>();
272 auto & this_underlying = thisData.template getUnderlyingView<1>();
273 switch (argIndex)
274 {
275 case 0: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg0, AAE, arg0); break;
276 case 1: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg1, AAE, arg1); break;
277 case 2: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg2, AAE, arg2); break;
278 case 3: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg3, AAE, arg3); break;
279 case 4: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg4, AAE, arg4); break;
280 case 5: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg5, AAE, arg5); break;
281 default: INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Invalid/unexpected arg index");
282 }
283 }
284 else
285 {
286 // since storing to Data object requires a call to getWritableEntry(), we use FullArgExtractorWritableData
287 auto thisAE = fullArgsWritable;
288 auto BAE = fullArgsData;
289 storeInPlaceCombination(policy, thisData, A_underlying, B, binaryOperator, thisAE, AAE, BAE);
290 }
291 }
292 }
293 else if (B_constant)
294 {
295 auto BAE = constArg;
296 auto & B_underlying = B.template getUnderlyingView<1>();
297 if (this_full)
298 {
299 auto thisAE = fullArgs;
300 auto & this_underlying = thisData.template getUnderlyingView<rank>();
301 if (A_full)
302 {
303 auto AAE = fullArgs;
304 auto & A_underlying = A.template getUnderlyingView<rank>();
305
306 storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, BAE);
307 }
308 else // this_full, not A_full: A may have modular data, etc.
309 {
310 // use A (the Data object). This could be further optimized by using A's underlying View and an appropriately-defined ArgExtractor.
311 auto AAE = fullArgsData;
312 storeInPlaceCombination(policy, this_underlying, A, B_underlying, binaryOperator, thisAE, AAE, BAE);
313 }
314 }
315 else // this is not full
316 {
317 // below, we optimize for the case of 1D data in A, when B is constant. Still need to handle other cases…
318 if (A_1D && (get1DArgIndex(A) != -1) )
319 {
320 // since B is constant, that implies that this_1D is true, and has the same 1DArgIndex as A
321 const int argIndex = get1DArgIndex(A);
322 auto & A_underlying = A.template getUnderlyingView<1>();
323 auto & this_underlying = thisData.template getUnderlyingView<1>();
324 switch (argIndex)
325 {
326 case 0: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg0, arg0, BAE); break;
327 case 1: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg1, arg1, BAE); break;
328 case 2: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg2, arg2, BAE); break;
329 case 3: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg3, arg3, BAE); break;
330 case 4: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg4, arg4, BAE); break;
331 case 5: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg5, arg5, BAE); break;
332 default: INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Invalid/unexpected arg index");
333 }
334 }
335 else
336 {
337 // since storing to Data object requires a call to getWritableEntry(), we use FullArgExtractorWritableData
338 auto thisAE = fullArgsWritable;
339 auto AAE = fullArgsData;
340 storeInPlaceCombination(policy, thisData, A, B_underlying, binaryOperator, thisAE, AAE, BAE);
341 }
342 }
343 }
344 else // neither A nor B constant
345 {
346 if (this_1D && (get1DArgIndex(thisData) != -1))
347 {
348 // possible ways that "this" could have full-extent, 1D data
349 // 1. A constant, B 1D
350 // 2. A 1D, B constant
351 // 3. A 1D, B 1D
352 // The constant possibilities are already addressed above, leaving us with (3). Note that A and B don't have to be full-extent, however
353 const int argThis = get1DArgIndex(thisData);
354 const int argA = get1DArgIndex(A); // if not full-extent, will be -1
355 const int argB = get1DArgIndex(B); // ditto
356
357 auto & A_underlying = A.template getUnderlyingView<1>();
358 auto & B_underlying = B.template getUnderlyingView<1>();
359 auto & this_underlying = thisData.template getUnderlyingView<1>();
360 if ((argA != -1) && (argB != -1))
361 {
362#ifdef INTREPID2_HAVE_DEBUG
363 INTREPID2_TEST_FOR_EXCEPTION(argA != argThis, std::logic_error, "Unexpected 1D arg combination.");
364 INTREPID2_TEST_FOR_EXCEPTION(argB != argThis, std::logic_error, "Unexpected 1D arg combination.");
365#endif
366 switch (argThis)
367 {
368 case 0: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg0, arg0, arg0); break;
369 case 1: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg1, arg1, arg1); break;
370 case 2: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg2, arg2, arg2); break;
371 case 3: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg3, arg3, arg3); break;
372 case 4: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg4, arg4, arg4); break;
373 case 5: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg5, arg5, arg5); break;
374 default: INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Invalid/unexpected arg index");
375 }
376 }
377 else if (argA != -1)
378 {
379 // B is not full-extent in dimension argThis; use the Data object
380 switch (argThis)
381 {
382 case 0: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, arg0, arg0, fullArgsData); break;
383 case 1: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, arg1, arg1, fullArgsData); break;
384 case 2: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, arg2, arg2, fullArgsData); break;
385 case 3: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, arg3, arg3, fullArgsData); break;
386 case 4: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, arg4, arg4, fullArgsData); break;
387 case 5: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, arg5, arg5, fullArgsData); break;
388 default: INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Invalid/unexpected arg index");
389 }
390 }
391 else
392 {
393 // A is not full-extent in dimension argThis; use the Data object
394 switch (argThis)
395 {
396 case 0: storeInPlaceCombination(policy, this_underlying, A, B_underlying, binaryOperator, arg0, fullArgsData, arg0); break;
397 case 1: storeInPlaceCombination(policy, this_underlying, A, B_underlying, binaryOperator, arg1, fullArgsData, arg1); break;
398 case 2: storeInPlaceCombination(policy, this_underlying, A, B_underlying, binaryOperator, arg2, fullArgsData, arg2); break;
399 case 3: storeInPlaceCombination(policy, this_underlying, A, B_underlying, binaryOperator, arg3, fullArgsData, arg3); break;
400 case 4: storeInPlaceCombination(policy, this_underlying, A, B_underlying, binaryOperator, arg4, fullArgsData, arg4); break;
401 case 5: storeInPlaceCombination(policy, this_underlying, A, B_underlying, binaryOperator, arg5, fullArgsData, arg5); break;
402 default: INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Invalid/unexpected arg index");
403 }
404 }
405 }
406 else if (this_full)
407 {
408 // This case uses A,B Data objects; could be optimized by dividing into subcases and using underlying Views with appropriate ArgExtractors.
409 auto & this_underlying = thisData.template getUnderlyingView<rank>();
410 auto thisAE = fullArgs;
411
412 if (A_full)
413 {
414 auto & A_underlying = A.template getUnderlyingView<rank>();
415 auto AAE = fullArgs;
416
417 if (B_1D && (get1DArgIndex(B) != -1))
418 {
419 const int argIndex = get1DArgIndex(B);
420 auto & B_underlying = B.template getUnderlyingView<1>();
421 switch (argIndex)
422 {
423 case 0: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, arg0); break;
424 case 1: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, arg1); break;
425 case 2: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, arg2); break;
426 case 3: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, arg3); break;
427 case 4: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, arg4); break;
428 case 5: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, arg5); break;
429 default: INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Invalid/unexpected arg index");
430 }
431 }
432 else
433 {
434 // A is full; B is not full, but not constant or full-extent 1D
435 // unoptimized in B access:
437 storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, thisAE, AAE, BAE);
438 }
439 }
440 else // A is not full
441 {
442 if (A_1D && (get1DArgIndex(A) != -1))
443 {
444 const int argIndex = get1DArgIndex(A);
445 auto & A_underlying = A.template getUnderlyingView<1>();
446 if (B_full)
447 {
448 auto & B_underlying = B.template getUnderlyingView<rank>();
449 auto BAE = fullArgs;
450 switch (argIndex)
451 {
452 case 0: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, arg0, BAE); break;
453 case 1: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, arg1, BAE); break;
454 case 2: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, arg2, BAE); break;
455 case 3: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, arg3, BAE); break;
456 case 4: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, arg4, BAE); break;
457 case 5: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, arg5, BAE); break;
458 default: INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Invalid/unexpected arg index");
459 }
460 }
461 else
462 {
463 auto BAE = fullArgsData;
464 switch (argIndex)
465 {
466 case 0: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, thisAE, arg0, BAE); break;
467 case 1: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, thisAE, arg1, BAE); break;
468 case 2: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, thisAE, arg2, BAE); break;
469 case 3: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, thisAE, arg3, BAE); break;
470 case 4: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, thisAE, arg4, BAE); break;
471 case 5: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, thisAE, arg5, BAE); break;
472 default: INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Invalid/unexpected arg index");
473 }
474 }
475 }
476 else // A not full, and not full-extent 1D
477 {
478 // unoptimized in A, B accesses.
479 auto AAE = fullArgsData;
480 auto BAE = fullArgsData;
481 storeInPlaceCombination(policy, this_underlying, A, B, binaryOperator, thisAE, AAE, BAE);
482 }
483 }
484 }
485 else
486 {
487 // completely un-optimized case: we use Data objects for this, A, B.
488 auto thisAE = fullArgsWritable;
489 auto AAE = fullArgsData;
490 auto BAE = fullArgsData;
491 storeInPlaceCombination(policy, thisData, A, B, binaryOperator, thisAE, AAE, BAE);
492 }
493 }
494 }
495
497 template<int rank>
498 static
499 enable_if_t<rank == 7, void>
501 {
502 auto policy = thisData.template dataExtentRangePolicy<rank>();
503
504 using DataType = Data<DataScalar,DeviceType>;
508
509 const ordinal_type dim6 = thisData.getDataExtent(6);
510 const bool includeInnerLoop = true;
512 Functor functor(thisData, A, B, binaryOperator, dim6);
513 Kokkos::parallel_for("compute in-place", policy, functor);
514 }
515
516 static void storeInPlaceCombination(Data<DataScalar,DeviceType> &thisData, const Data<DataScalar,DeviceType> &A, const Data<DataScalar,DeviceType> &B, BinaryOperator binaryOperator)
517 {
518 using ExecutionSpace = typename DeviceType::execution_space;
519
520#ifdef INTREPID2_HAVE_DEBUG
521 // check logical extents
522 for (int d=0; d<rank_; d++)
523 {
524 INTREPID2_TEST_FOR_EXCEPTION(A.extent_int(d) != thisData.extent_int(d), std::invalid_argument, "A, B, and this must agree on all logical extents");
525 INTREPID2_TEST_FOR_EXCEPTION(B.extent_int(d) != thisData.extent_int(d), std::invalid_argument, "A, B, and this must agree on all logical extents");
526 }
527 // TODO: add some checks that data extent of this suffices to accept combined A + B data.
528#endif
529
530 const bool this_constant = (thisData.getUnderlyingViewRank() == 1) && (thisData.getUnderlyingViewSize() == 1);
531
532 // we special-case for constant output here; since the constant case is essentially all overhead, we want to avoid as much of the overhead of storeInPlaceCombination() as possible…
533 if (this_constant)
534 {
535 // constant data
536 Kokkos::RangePolicy<ExecutionSpace> policy(ExecutionSpace(),0,1); // just 1 entry
537
538 auto this_underlying = thisData.template getUnderlyingView<1>();
539 auto A_underlying = A.template getUnderlyingView<1>();
540 auto B_underlying = B.template getUnderlyingView<1>();
541
542 using ConstantCaseFunctor = InPlaceCombinationFunctorConstantCase<decltype(binaryOperator), decltype(this_underlying),
543 decltype(A_underlying), decltype(B_underlying)>;
544
545 ConstantCaseFunctor functor(this_underlying, A_underlying, B_underlying, binaryOperator);
546 Kokkos::parallel_for("compute in-place", policy,functor);
547 }
548 else
549 {
550 switch (thisData.rank())
551 {
552 case 1: storeInPlaceCombination<1>(thisData, A, B, binaryOperator); break;
553 case 2: storeInPlaceCombination<2>(thisData, A, B, binaryOperator); break;
554 case 3: storeInPlaceCombination<3>(thisData, A, B, binaryOperator); break;
555 case 4: storeInPlaceCombination<4>(thisData, A, B, binaryOperator); break;
556 case 5: storeInPlaceCombination<5>(thisData, A, B, binaryOperator); break;
557 case 6: storeInPlaceCombination<6>(thisData, A, B, binaryOperator); break;
558 case 7: storeInPlaceCombination<7>(thisData, A, B, binaryOperator); break;
559 default:
560 INTREPID2_TEST_FOR_EXCEPTION_DEVICE_SAFE(true, std::logic_error, "unhandled rank in switch");
561 }
562 }
563 }
564};
565
566} // end namespace Intrepid2
567
568// We do ETI for basic double arithmetic on default device.
569//template<class Scalar> struct ScalarSumFunctor;
570//template<class Scalar> struct ScalarDifferenceFunctor;
571//template<class Scalar> struct ScalarProductFunctor;
572//template<class Scalar> struct ScalarQuotientFunctor;
573
578
579#endif /* Intrepid2_DataCombiners_hpp */
Header file with various static argument-extractor classes. These are useful for writing efficient,...
Defines functors for use with Data objects: so far, we include simple arithmetical functors for sum,...
Defines DataVariationType enum that specifies the types of variation possible within a Data object.
@ GENERAL
arbitrary variation
Defines the Data class, a wrapper around a Kokkos::View that allows data that is constant or repeatin...
#define INTREPID2_TEST_FOR_EXCEPTION_DEVICE_SAFE(test, x, msg)
static enable_if_t< rank==7, void > storeInPlaceCombination(Data< DataScalar, DeviceType > &thisData, const Data< DataScalar, DeviceType > &A, const Data< DataScalar, DeviceType > &B, BinaryOperator binaryOperator)
storeInPlaceCombination with compile-time rank – implementation for rank of 7. (Not optimized; expect...
static void storeInPlaceCombination(PolicyType &policy, ThisUnderlyingViewType &this_underlying, AUnderlyingViewType &A_underlying, BUnderlyingViewType &B_underlying, BinaryOperator &binaryOperator, ArgExtractorThis argThis, ArgExtractorA argA, ArgExtractorB argB)
storeInPlaceCombination implementation for rank < 7, with compile-time underlying views and argument ...
static enable_if_t< rank !=7, void > storeInPlaceCombination(Data< DataScalar, DeviceType > &thisData, const Data< DataScalar, DeviceType > &A, const Data< DataScalar, DeviceType > &B, BinaryOperator binaryOperator)
storeInPlaceCombination with compile-time rank – implementation for rank < 7.
Wrapper around a Kokkos::View that allows data that is constant or repeating in various logical dimen...
KOKKOS_INLINE_FUNCTION int extent_int(const int &r) const
Returns the logical extent in the specified dimension.
KOKKOS_INLINE_FUNCTION ordinal_type getUnderlyingViewSize() const
returns the number of entries in the View that stores the unique data
KOKKOS_INLINE_FUNCTION int getDataExtent(const ordinal_type &d) const
returns the true extent of the data corresponding to the logical dimension provided; if the data does...
KOKKOS_INLINE_FUNCTION bool underlyingMatchesLogical() const
Returns true if the underlying container has exactly the same rank and extents as the logical contain...
KOKKOS_INLINE_FUNCTION ordinal_type getUnderlyingViewRank() const
returns the rank of the View that stores the unique data
KOKKOS_INLINE_FUNCTION unsigned rank() const
Returns the logical rank of the Data container.
Argument extractor class which ignores the input arguments in favor of passing a single 0 argument to...
For use with Data object into which a value will be stored. We use passThroughBlockDiagonalArgs = tru...
For use with Data object into which a value will be stored. We use passThroughBlockDiagonalArgs = tru...
Argument extractor class which passes all arguments to the provided container.
functor definition for the constant-data case.
Argument extractor class which passes a single argument, indicated by the template parameter whichArg...