Intrepid2
Intrepid2_DirectSumBasis.hpp
Go to the documentation of this file.
1// @HEADER
2// *****************************************************************************
3// Intrepid2 Package
4//
5// Copyright 2007 NTESS and the Intrepid2 contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9
15#ifndef Intrepid2_DirectSumBasis_h
16#define Intrepid2_DirectSumBasis_h
17
18#include <Kokkos_DynRankView.hpp>
19
20namespace Intrepid2
21{
32 template<typename BasisBaseClass>
33 class Basis_DirectSumBasis : public BasisBaseClass
34 {
35 public:
36 using BasisBase = BasisBaseClass;
37 using BasisPtr = Teuchos::RCP<BasisBase>;
38
39 using DeviceType = typename BasisBase::DeviceType;
40 using ExecutionSpace = typename BasisBase::ExecutionSpace;
41 using OutputValueType = typename BasisBase::OutputValueType;
42 using PointValueType = typename BasisBase::PointValueType;
43
44 using OrdinalTypeArray1DHost = typename BasisBase::OrdinalTypeArray1DHost;
45 using OrdinalTypeArray2DHost = typename BasisBase::OrdinalTypeArray2DHost;
46 using OutputViewType = typename BasisBase::OutputViewType;
47 using PointViewType = typename BasisBase::PointViewType;
48 using ScalarViewType = typename BasisBase::ScalarViewType;
49 protected:
50 BasisPtr basis1_;
51 BasisPtr basis2_;
52
53 std::string name_;
54 public:
59 Basis_DirectSumBasis(BasisPtr basis1, BasisPtr basis2)
60 :
61 basis1_(basis1),basis2_(basis2)
62 {
63 INTREPID2_TEST_FOR_EXCEPTION(basis1->getBasisType() != basis2->getBasisType(), std::invalid_argument, "basis1 and basis2 must agree in basis type");
64 INTREPID2_TEST_FOR_EXCEPTION(basis1->getBaseCellTopology().getKey() != basis2->getBaseCellTopology().getKey(),
65 std::invalid_argument, "basis1 and basis2 must agree in cell topology");
66 INTREPID2_TEST_FOR_EXCEPTION(basis1->getNumTensorialExtrusions() != basis2->getNumTensorialExtrusions(),
67 std::invalid_argument, "basis1 and basis2 must agree in number of tensorial extrusions");
68 INTREPID2_TEST_FOR_EXCEPTION(basis1->getCoordinateSystem() != basis2->getCoordinateSystem(),
69 std::invalid_argument, "basis1 and basis2 must agree in coordinate system");
70
71 this->basisCardinality_ = basis1->getCardinality() + basis2->getCardinality();
72 this->basisDegree_ = std::max(basis1->getDegree(), basis2->getDegree());
73
74 {
75 std::ostringstream basisName;
76 basisName << basis1->getName() << " + " << basis2->getName();
77 name_ = basisName.str();
78 }
79
80 this->basisCellTopologyKey_ = basis1->getBaseCellTopology().getKey();
81 this->basisType_ = basis1->getBasisType();
82 this->basisCoordinates_ = basis1->getCoordinateSystem();
83
84 if (this->basisType_ == BASIS_FEM_HIERARCHICAL)
85 {
86 int degreeLength = basis1_->getPolynomialDegreeLength();
87 INTREPID2_TEST_FOR_EXCEPTION(degreeLength != basis2_->getPolynomialDegreeLength(), std::invalid_argument, "Basis1 and Basis2 must agree on polynomial degree length");
88
89 this->fieldOrdinalPolynomialDegree_ = OrdinalTypeArray2DHost("DirectSumBasis degree lookup", this->basisCardinality_,degreeLength);
90 this->fieldOrdinalH1PolynomialDegree_ = OrdinalTypeArray2DHost("DirectSumBasis H^1 degree lookup",this->basisCardinality_,degreeLength);
91 // our field ordinals start with basis1_; basis2_ follows
92 for (int fieldOrdinal1=0; fieldOrdinal1<basis1_->getCardinality(); fieldOrdinal1++)
93 {
94 int fieldOrdinal = fieldOrdinal1;
95 auto polynomialDegree = basis1->getPolynomialDegreeOfField(fieldOrdinal1);
96 auto polynomialH1Degree = basis1->getH1PolynomialDegreeOfField(fieldOrdinal1);
97 for (int d=0; d<degreeLength; d++)
98 {
99 this->fieldOrdinalPolynomialDegree_ (fieldOrdinal,d) = polynomialDegree(d);
100 this->fieldOrdinalH1PolynomialDegree_(fieldOrdinal,d) = polynomialH1Degree(d);
101 }
102 }
103 for (int fieldOrdinal2=0; fieldOrdinal2<basis2_->getCardinality(); fieldOrdinal2++)
104 {
105 int fieldOrdinal = basis1->getCardinality() + fieldOrdinal2;
106
107 auto polynomialDegree = basis2->getPolynomialDegreeOfField(fieldOrdinal2);
108 auto polynomialH1Degree = basis2->getH1PolynomialDegreeOfField(fieldOrdinal2);
109 for (int d=0; d<degreeLength; d++)
110 {
111 this->fieldOrdinalPolynomialDegree_ (fieldOrdinal,d) = polynomialDegree(d);
112 this->fieldOrdinalH1PolynomialDegree_(fieldOrdinal,d) = polynomialH1Degree(d);
113 }
114 }
115 }
116
117 // initialize tags
118 {
119 const auto & cardinality = this->basisCardinality_;
120
121 // Basis-dependent initializations
122 const ordinal_type tagSize = 4; // size of DoF tag, i.e., number of fields in the tag
123 const ordinal_type posScDim = 0; // position in the tag, counting from 0, of the subcell dim
124 const ordinal_type posScOrd = 1; // position in the tag, counting from 0, of the subcell ordinal
125 const ordinal_type posDfOrd = 2; // position in the tag, counting from 0, of DoF ordinal relative to the subcell
126
127 OrdinalTypeArray1DHost tagView("tag view", cardinality*tagSize);
128
129 shards::CellTopology cellTopo(getCellTopologyData(this->basisCellTopologyKey_));
130
131 unsigned spaceDim = cellTopo.getDimension();
132
133 ordinal_type basis2Offset = basis1_->getCardinality();
134
135 for (unsigned d=0; d<=spaceDim; d++)
136 {
137 unsigned subcellCount = cellTopo.getSubcellCount(d);
138 for (unsigned subcellOrdinal=0; subcellOrdinal<subcellCount; subcellOrdinal++)
139 {
140 ordinal_type subcellDofCount1 = basis1->getDofCount(d, subcellOrdinal);
141 ordinal_type subcellDofCount2 = basis2->getDofCount(d, subcellOrdinal);
142
143 ordinal_type subcellDofCount = subcellDofCount1 + subcellDofCount2;
144 for (ordinal_type localDofID=0; localDofID<subcellDofCount; localDofID++)
145 {
146 ordinal_type fieldOrdinal;
147 if (localDofID < subcellDofCount1)
148 {
149 // first basis: field ordinal matches the basis1 ordinal
150 fieldOrdinal = basis1_->getDofOrdinal(d, subcellOrdinal, localDofID);
151 }
152 else
153 {
154 // second basis: field ordinal is offset by basis1 cardinality
155 fieldOrdinal = basis2Offset + basis2_->getDofOrdinal(d, subcellOrdinal, localDofID - subcellDofCount1);
156 }
157 tagView(fieldOrdinal*tagSize+0) = d; // subcell dimension
158 tagView(fieldOrdinal*tagSize+1) = subcellOrdinal;
159 tagView(fieldOrdinal*tagSize+2) = localDofID;
160 tagView(fieldOrdinal*tagSize+3) = subcellDofCount;
161 }
162 }
163 }
164 // // Basis-independent function sets tag and enum data in tagToOrdinal_ and ordinalToTag_ arrays:
165 // // tags are constructed on host
166 this->setOrdinalTagData(this->tagToOrdinal_,
167 this->ordinalToTag_,
168 tagView,
169 this->basisCardinality_,
170 tagSize,
171 posScDim,
172 posScOrd,
173 posDfOrd);
174 }
175 }
176
182 virtual BasisValues<OutputValueType,DeviceType> allocateBasisValues( TensorPoints<PointValueType,DeviceType> points, const EOperator operatorType = OPERATOR_VALUE) const override
183 {
184 BasisValues<OutputValueType,DeviceType> basisValues1 = basis1_->allocateBasisValues(points, operatorType);
185 BasisValues<OutputValueType,DeviceType> basisValues2 = basis2_->allocateBasisValues(points, operatorType);
186
187 const int numScalarFamilies1 = basisValues1.numTensorDataFamilies();
188 if (numScalarFamilies1 > 0)
189 {
190 // then both basis1 and basis2 should be scalar-valued; check that for basis2:
191 const int numScalarFamilies2 = basisValues2.numTensorDataFamilies();
192 INTREPID2_TEST_FOR_EXCEPTION(basisValues2.numTensorDataFamilies() <=0, std::invalid_argument, "When basis1 has scalar value, basis2 must also");
193 std::vector< TensorData<OutputValueType,DeviceType> > scalarFamilies(numScalarFamilies1 + numScalarFamilies2);
194 for (int i=0; i<numScalarFamilies1; i++)
195 {
196 scalarFamilies[i] = basisValues1.tensorData(i);
197 }
198 for (int i=0; i<numScalarFamilies2; i++)
199 {
200 scalarFamilies[i+numScalarFamilies1] = basisValues2.tensorData(i);
201 }
202 return BasisValues<OutputValueType,DeviceType>(scalarFamilies);
203 }
204 else
205 {
206 // then both basis1 and basis2 should be vector-valued; check that:
207 INTREPID2_TEST_FOR_EXCEPTION(!basisValues1.vectorData().isValid(), std::invalid_argument, "When basis1 does not have tensorData() defined, it must have a valid vectorData()");
208 INTREPID2_TEST_FOR_EXCEPTION(basisValues2.numTensorDataFamilies() > 0, std::invalid_argument, "When basis1 has vector value, basis2 must also");
209
210 const auto & vectorData1 = basisValues1.vectorData();
211 const auto & vectorData2 = basisValues2.vectorData();
212
213 const int numFamilies1 = vectorData1.numFamilies();
214 const int numComponents = vectorData1.numComponents();
215 INTREPID2_TEST_FOR_EXCEPTION(numComponents != vectorData2.numComponents(), std::invalid_argument, "basis1 and basis2 must agree on the number of components in each vector");
216 const int numFamilies2 = vectorData2.numFamilies();
217
218 const int numFamilies = numFamilies1 + numFamilies2;
219 std::vector< std::vector<TensorData<OutputValueType,DeviceType> > > vectorComponents(numFamilies, std::vector<TensorData<OutputValueType,DeviceType> >(numComponents));
220
221 for (int i=0; i<numFamilies1; i++)
222 {
223 for (int j=0; j<numComponents; j++)
224 {
225 vectorComponents[i][j] = vectorData1.getComponent(i,j);
226 }
227 }
228 for (int i=0; i<numFamilies2; i++)
229 {
230 for (int j=0; j<numComponents; j++)
231 {
232 vectorComponents[i+numFamilies1][j] = vectorData2.getComponent(i,j);
233 }
234 }
235 VectorData<OutputValueType,DeviceType> vectorData(vectorComponents);
237 }
238 }
239
248 virtual void getDofCoords( ScalarViewType dofCoords ) const override {
249 const int basisCardinality1 = basis1_->getCardinality();
250 const int basisCardinality2 = basis2_->getCardinality();
251 const int basisCardinality = basisCardinality1 + basisCardinality2;
252
253 auto dofCoords1 = Kokkos::subview(dofCoords, std::make_pair(0,basisCardinality1), Kokkos::ALL());
254 auto dofCoords2 = Kokkos::subview(dofCoords, std::make_pair(basisCardinality1,basisCardinality), Kokkos::ALL());
255
256 basis1_->getDofCoords(dofCoords1);
257 basis2_->getDofCoords(dofCoords2);
258 }
259
271 virtual void getDofCoeffs( ScalarViewType dofCoeffs ) const override {
272 const int basisCardinality1 = basis1_->getCardinality();
273 const int basisCardinality2 = basis2_->getCardinality();
274 const int basisCardinality = basisCardinality1 + basisCardinality2;
275
276 auto dofCoeffs1 = Kokkos::subview(dofCoeffs, std::make_pair(0,basisCardinality1), Kokkos::ALL());
277 auto dofCoeffs2 = Kokkos::subview(dofCoeffs, std::make_pair(basisCardinality1,basisCardinality), Kokkos::ALL());
278
279 basis1_->getDofCoeffs(dofCoeffs1);
280 basis2_->getDofCoeffs(dofCoeffs2);
281 }
282
283
288 virtual
289 const char*
290 getName() const override {
291 return name_.c_str();
292 }
293
294 // since the getValues() below only overrides the FEM variants, we specify that
295 // we use the base class's getValues(), which implements the FVD variant by throwing an exception.
296 // (It's an error to use the FVD variant on this basis.)
297 using BasisBase::getValues;
298
310 virtual
311 void
314 const EOperator operatorType = OPERATOR_VALUE ) const override
315 {
316 const int fieldStartOrdinal1 = 0;
317 const int numFields1 = basis1_->getCardinality();
318 const int fieldStartOrdinal2 = numFields1;
319 const int numFields2 = basis2_->getCardinality();
320
321 auto basisValues1 = outputValues.basisValuesForFields(fieldStartOrdinal1, numFields1);
322 auto basisValues2 = outputValues.basisValuesForFields(fieldStartOrdinal2, numFields2);
323
324 basis1_->getValues(basisValues1, inputPoints, operatorType);
325 basis2_->getValues(basisValues2, inputPoints, operatorType);
326 }
327
346 virtual void getValues( OutputViewType outputValues, const PointViewType inputPoints,
347 const EOperator operatorType = OPERATOR_VALUE ) const override
348 {
349 int cardinality1 = basis1_->getCardinality();
350 int cardinality2 = basis2_->getCardinality();
351
352 auto range1 = std::make_pair(0,cardinality1);
353 auto range2 = std::make_pair(cardinality1,cardinality1+cardinality2);
354 if (outputValues.rank() == 2) // F,P
355 {
356 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL());
357 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL());
358
359 basis1_->getValues(outputValues1, inputPoints, operatorType);
360 basis2_->getValues(outputValues2, inputPoints, operatorType);
361 }
362 else if (outputValues.rank() == 3) // F,P,D
363 {
364 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL());
365 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL());
366
367 basis1_->getValues(outputValues1, inputPoints, operatorType);
368 basis2_->getValues(outputValues2, inputPoints, operatorType);
369 }
370 else if (outputValues.rank() == 4) // F,P,D,D
371 {
372 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
373 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
374
375 basis1_->getValues(outputValues1, inputPoints, operatorType);
376 basis2_->getValues(outputValues2, inputPoints, operatorType);
377 }
378 else if (outputValues.rank() == 5) // F,P,D,D,D
379 {
380 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
381 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
382
383 basis1_->getValues(outputValues1, inputPoints, operatorType);
384 basis2_->getValues(outputValues2, inputPoints, operatorType);
385 }
386 else if (outputValues.rank() == 6) // F,P,D,D,D,D
387 {
388 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
389 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
390
391 basis1_->getValues(outputValues1, inputPoints, operatorType);
392 basis2_->getValues(outputValues2, inputPoints, operatorType);
393 }
394 else if (outputValues.rank() == 7) // F,P,D,D,D,D,D
395 {
396 auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
397 auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
398
399 basis1_->getValues(outputValues1, inputPoints, operatorType);
400 basis2_->getValues(outputValues2, inputPoints, operatorType);
401 }
402 else
403 {
404 INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Unsupported outputValues rank");
405 }
406 }
407
408 virtual int getNumTensorialExtrusions() const override
409 {
410 return basis1_->getNumTensorialExtrusions();
411 }
412 };
413} // end namespace Intrepid2
414
415#endif /* Intrepid2_DirectSumBasis_h */
The data containers in Intrepid2 that support sum factorization and other reduced-data optimizations ...
const VectorDataType & vectorData() const
VectorData accessor.
BasisValues< Scalar, DeviceType > basisValuesForFields(const int &fieldStartOrdinal, const int &numFields)
field start and length must align with families in vectorData_ or tensorDataFamilies_ (whichever is valid).
TensorDataType & tensorData()
TensorData accessor for single-family scalar data.
A basis that is the direct sum of two other bases.
virtual const char * getName() const override
Returns basis name.
virtual BasisValues< OutputValueType, DeviceType > allocateBasisValues(TensorPoints< PointValueType, DeviceType > points, const EOperator operatorType=OPERATOR_VALUE) const override
Allocate BasisValues container suitable for passing to the getValues() variant that takes a TensorPoints container as argument.
Basis_DirectSumBasis(BasisPtr basis1, BasisPtr basis2)
Constructor.
virtual void getValues(OutputViewType outputValues, const PointViewType inputPoints, const EOperator operatorType=OPERATOR_VALUE) const override
Evaluation of a FEM basis on a reference cell.
virtual void getDofCoords(ScalarViewType dofCoords) const override
Fills in spatial locations (coordinates) of degrees of freedom (nodes) on the reference cell.
virtual void getValues(BasisValues< OutputValueType, DeviceType > outputValues, const TensorPoints< PointValueType, DeviceType > inputPoints, const EOperator operatorType=OPERATOR_VALUE) const override
Evaluation of a FEM basis on a reference cell, using point and output value containers that allow preservation of tensor-product structure.
virtual void getDofCoeffs(ScalarViewType dofCoeffs) const override
Fills in coefficients of degrees of freedom for Lagrangian basis on the reference cell.
View-like interface to tensor data; tensor components are stored separately and multiplied together a...
View-like interface to tensor points; point components are stored separately; the appropriate coordin...
Reference-space field values for a basis, designed to support typical vector-valued bases.
KOKKOS_INLINE_FUNCTION constexpr bool isValid() const
returns true for containers that have data; false for those that don't (e.g., those that have been co...
KOKKOS_INLINE_FUNCTION int numFamilies() const
returns the number of families