10#ifndef TPETRA_APPLYDIRICHLETBOUNDARYCONDITION_HPP
11#define TPETRA_APPLYDIRICHLETBOUNDARYCONDITION_HPP
17#include "Tpetra_CrsMatrix.hpp"
18#include "Tpetra_Vector.hpp"
19#include "Tpetra_Map.hpp"
20#include "KokkosSparse_CrsMatrix.hpp"
21#include "KokkosKernels_ArithTraits.hpp"
// Forward declarations of the public Dirichlet boundary-condition entry points.
// NOTE(review): this listing is incomplete. The fused original line numbers
// jump (e.g. 34 -> 38), so each declaration's function name and leading
// parameters are missing; only the trailing `lclRowInds` View parameter is
// visible. Restore these declarations from the original header before
// relying on them.
//
// Declaration taking local row indices in a View on CrsMatrixType::device_type.
34template <
class CrsMatrixType>
38 typename CrsMatrixType::local_ordinal_type*,
39 typename CrsMatrixType::device_type>& lclRowInds);
// Declaration taking local row indices in a View on CrsMatrixType::device_type.
50template <
class CrsMatrixType>
53 typename CrsMatrixType::local_ordinal_type*,
54 typename CrsMatrixType::device_type>& lclRowInds);
// Declaration taking local row indices in a View on Kokkos::HostSpace.
65template <
class CrsMatrixType>
68 typename CrsMatrixType::local_ordinal_type*,
69 Kokkos::HostSpace>& lclRowInds);
// Declaration taking local row indices in a View on CrsMatrixType::device_type.
80template <
class CrsMatrixType>
84 typename CrsMatrixType::local_ordinal_type*,
85 typename CrsMatrixType::device_type>& lclRowInds);
// Declaration taking local row indices in a View on Kokkos::HostSpace.
97template <
class CrsMatrixType>
100 typename CrsMatrixType::local_ordinal_type*,
101 Kokkos::HostSpace>& lclRowInds);
// Declaration taking local row indices in a View on CrsMatrixType::device_type.
112template <
class CrsMatrixType>
115 typename CrsMatrixType::local_ordinal_type*,
116 typename CrsMatrixType::device_type>& lclRowInds);
// Implementation helper: for every local row index listed in `lclRowInds`,
// overwrite each stored value in that row with one on the diagonal entry
// (the entry whose column's global index equals the row's global index) and
// zero on every other stored entry. In the visible code only `values` is
// written; the row/column index arrays are only read.
// NOTE(review): this listing is incomplete. The fused original line numbers
// skip lines such as the execution_space/IST aliases, the matrix parameter
// `A`, the `diagEnt` declaration, the runOnHost branch and the closing
// braces. Restore from the original source before editing code here.
120template <
class SC,
class LO,
class GO,
class NT>
121struct ApplyDirichletBoundaryConditionToLocalMatrixRows {
// Row indices are taken as a View of const LO in Kokkos::AnonymousSpace,
// so Views from any memory space can be bound to this parameter type.
124 using local_row_indices_type =
125 Kokkos::View<const LO*, Kokkos::AnonymousSpace>;
// run: execSpace is the instance the device kernel below launches on;
// lclRowInds lists the local rows to modify; runOnHost presumably selects
// the host kernel instead of the device kernel (the branch line is missing
// from this listing -- TODO confirm).
128 run(
const execution_space& execSpace,
130 const local_row_indices_type& lclRowInds,
131 const bool runOnHost) {
138 using KAT = KokkosKernels::ArithTraits<IST>;
// A must have both a row Map and a column Map; otherwise throw
// std::invalid_argument.
140 const auto rowMap = A.getRowMap();
141 TEUCHOS_TEST_FOR_EXCEPTION(rowMap.get() ==
nullptr, std::invalid_argument,
142 "The matrix must have a row Map.");
143 const auto colMap = A.getColMap();
144 TEUCHOS_TEST_FOR_EXCEPTION(colMap.get() ==
nullptr, std::invalid_argument,
145 "The matrix must have a column Map.");
146 auto A_lcl = A.getLocalMatrixDevice();
// Unless this process owns no rows, the local matrix must have exactly as
// many rows as the row Map has local elements; otherwise throw.
148 const LO lclNumRows =
static_cast<LO
>(rowMap->getLocalNumElements());
149 TEUCHOS_TEST_FOR_EXCEPTION(lclNumRows != 0 &&
static_cast<LO
>(A_lcl.graph.numRows()) != lclNumRows,
150 std::invalid_argument,
151 "The matrix must have been either created "
152 "with a KokkosSparse::CrsMatrix, or must have been fill-completed "
// Local Maps and CRS arrays; these handles are captured by value in the
// kernels below.
155 auto lclRowMap = A.getRowMap()->getLocalMap();
156 auto lclColMap = A.getColMap()->getLocalMap();
157 auto rowptr = A_lcl.graph.row_map;
158 auto colind = A_lcl.graph.entries;
159 auto values = A_lcl.values;
// Remember the fill state so it can be restored after the values change.
// NOTE(review): the body of this branch is missing from the listing;
// presumably it calls A.resumeFill() -- TODO confirm.
161 const bool wasFillComplete = A.isFillComplete();
162 if (wasFillComplete) {
166 const LO numInputRows = lclRowInds.extent(0);
// Device path: one parallel iteration per input row, launched on execSpace.
168 using range_type = Kokkos::RangePolicy<execution_space, LO>;
169 Kokkos::parallel_for(
170 "Tpetra::CrsMatrix apply Dirichlet: Device",
171 range_type(execSpace, 0, numInputRows),
172 KOKKOS_LAMBDA(
const LO i) {
173 LO row = lclRowInds(i);
174 const GO row_gid = lclRowMap.getGlobalElement(row);
// NOTE(review): the line declaring `diagEnt` is missing from the listing;
// visibly it is true when entry j's column global index equals row_gid.
175 for (
auto j = rowptr(row); j < rowptr(row + 1); ++j) {
177 lclColMap.getGlobalElement(colind(j)) == row_gid;
178 values(j) = diagEnt ? KAT::one() : KAT::zero();
// Host path (Kokkos::DefaultHostExecutionSpace): same per-row logic as the
// device path. NOTE(review): the start of this range_type alias and the
// lambda header are missing from the listing.
183 Kokkos::RangePolicy<Kokkos::DefaultHostExecutionSpace, LO>;
184 Kokkos::parallel_for(
"Tpetra::CrsMatrix apply Dirichlet: Host",
185 range_type(0, numInputRows),
187 LO row = lclRowInds(i);
188 const GO row_gid = lclRowMap.getGlobalElement(row);
189 for (
auto j = rowptr(row); j < rowptr(row + 1); ++j) {
191 lclColMap.getGlobalElement(colind(j)) == row_gid;
192 values(j) = diagEnt ? KAT::one() : KAT::zero();
// Restore the fill-complete state, reusing A's existing domain and range Maps.
196 if (wasFillComplete) {
197 A.fillComplete(A.getDomainMap(), A.getRangeMap());
// Implementation helper: for every local row of A, set to zero each stored
// value whose local column index is flagged true in `lclColFlags`. Unflagged
// entries and the sparsity pattern are left unchanged in the visible code.
// NOTE(review): this listing is incomplete. The fused original line numbers
// skip lines such as the execution_space/IST aliases, the matrix parameter
// `A`, the runOnHost branch and the closing braces. Restore from the original
// source before editing code here.
202template <
class SC,
class LO,
class GO,
class NT>
203struct ApplyDirichletBoundaryConditionToLocalMatrixColumns {
// One bool per local column, in Kokkos::AnonymousSpace so Views from any
// memory space can be bound to this parameter type.
206 using local_col_flag_type =
207 Kokkos::View<bool*, Kokkos::AnonymousSpace>;
// run: execSpace is the instance the device kernel below launches on;
// lclColFlags[c] == true means local column c is zeroed; runOnHost
// presumably selects the host kernel (branch line missing -- TODO confirm).
210 run(
const execution_space& execSpace,
212 const local_col_flag_type& lclColFlags,
213 const bool runOnHost) {
220 using KAT = KokkosKernels::ArithTraits<IST>;
// A must have both a row Map and a column Map; otherwise throw
// std::invalid_argument.
222 const auto rowMap = A.getRowMap();
223 TEUCHOS_TEST_FOR_EXCEPTION(rowMap.get() ==
nullptr, std::invalid_argument,
224 "The matrix must have a row Map.");
225 const auto colMap = A.getColMap();
226 TEUCHOS_TEST_FOR_EXCEPTION(colMap.get() ==
nullptr, std::invalid_argument,
227 "The matrix must have a column Map.");
228 auto A_lcl = A.getLocalMatrixDevice();
// Unless this process owns no rows, the local matrix must have exactly as
// many rows as the row Map has local elements; otherwise throw.
230 const LO lclNumRows =
static_cast<LO
>(rowMap->getLocalNumElements());
231 TEUCHOS_TEST_FOR_EXCEPTION(lclNumRows != 0 &&
static_cast<LO
>(A_lcl.graph.numRows()) != lclNumRows,
232 std::invalid_argument,
233 "The matrix must have been either created "
234 "with a KokkosSparse::CrsMatrix, or must have been fill-completed "
// NOTE(review): lclRowMap and lclColMap are not referenced by the visible
// column kernels below; they may be removable -- confirm against the full source.
237 auto lclRowMap = A.getRowMap()->getLocalMap();
238 auto lclColMap = A.getColMap()->getLocalMap();
239 auto rowptr = A_lcl.graph.row_map;
240 auto colind = A_lcl.graph.entries;
241 auto values = A_lcl.values;
// Remember the fill state so it can be restored after the values change.
// NOTE(review): the body of this branch is missing from the listing;
// presumably it calls A.resumeFill() -- TODO confirm.
243 const bool wasFillComplete = A.isFillComplete();
244 if (wasFillComplete) {
// NOTE(review): this recomputes lclNumRows with a C-style cast.
248 const LO numRows = (LO)A.getRowMap()->getLocalNumElements();
// Device path: one parallel iteration per local row, launched on execSpace.
250 using range_type = Kokkos::RangePolicy<execution_space, LO>;
251 Kokkos::parallel_for(
252 "Tpetra::CrsMatrix apply Dirichlet cols: Device",
253 range_type(execSpace, 0, numRows),
254 KOKKOS_LAMBDA(
const LO i) {
255 for (
auto j = rowptr(i); j < rowptr(i + 1); ++j) {
256 if (lclColFlags[colind[j]])
257 values(j) = KAT::zero();
// Host path (Kokkos::DefaultHostExecutionSpace): same per-row logic.
// NOTE(review): the start of this range_type alias is missing from the listing.
262 Kokkos::RangePolicy<Kokkos::DefaultHostExecutionSpace, LO>;
263 Kokkos::parallel_for(
264 "Tpetra::CrsMatrix apply Dirichlet cols: Host",
265 range_type(0, numRows),
266 KOKKOS_LAMBDA(
const LO i) {
267 for (
auto j = rowptr(i); j < rowptr(i + 1); ++j) {
268 if (lclColFlags[colind[j]])
269 values(j) = KAT::zero();
// Restore the fill-complete state, reusing A's existing domain and range Maps.
273 if (wasFillComplete) {
274 A.fillComplete(A.getDomainMap(), A.getRangeMap());
// Translate a list of Dirichlet local row indices of A into per-local-column
// flags. On return, dirichletColFlags has one entry per local column of A,
// and an entry is true when that column's global index matches one of the
// listed rows (directly when A has no Importer, or via the column-Map vector
// otherwise).
//
// Throws std::invalid_argument if A has no column Map, if A's row Map is not
// the same as its domain Map, or if dirichletColFlags.size() differs from the
// number of local column-Map elements.
// NOTE(review): this listing is incomplete. Line 281 (presumably the
// `crs_matrix_type` alias used below), the `else` line, the declarations of
// domainDirichlet/colDirichlet and the import step are missing -- TODO confirm.
279template <
class SC,
class LO,
class GO,
class NT>
280void localRowsToColumns(
const typename ::Tpetra::CrsMatrix<SC, LO, GO, NT>::execution_space& execSpace, const ::Tpetra::CrsMatrix<SC, LO, GO, NT>& A,
const Kokkos::View<const LO*, Kokkos::AnonymousSpace>& dirichletRowIds, Kokkos::View<bool*, Kokkos::AnonymousSpace>& dirichletColFlags) {
282 using execution_space =
typename crs_matrix_type::execution_space;
283 using memory_space =
typename crs_matrix_type::device_type::memory_space;
286 TEUCHOS_TEST_FOR_EXCEPTION(A.getColMap().get() ==
nullptr, std::invalid_argument,
"The matrix must have a column Map.");
// Row Map == domain Map is required, so a row's global index also names an
// owned domain entry.
290 TEUCHOS_TEST_FOR_EXCEPTION(!A.getRowMap()->isSameAs(*A.getDomainMap()), std::invalid_argument,
"localRowsToColumns: Row/Domain maps do not match");
293 TEUCHOS_TEST_FOR_EXCEPTION(A.getColMap()->getLocalNumElements() != dirichletColFlags.size(), std::invalid_argument,
"localRowsToColumns: dirichletColFlags must be the correct size");
295 LO numDirichletRows = (LO)dirichletRowIds.size();
296 LO LO_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
// Without an Importer, map each Dirichlet row's global index straight into
// the column Map on this process. Rows whose global index is not in the
// column Map set no flag.
298 if (A.getCrsGraph()->getImporter().is_null()) {
300 using range_type = Kokkos::RangePolicy<execution_space, LO>;
301 auto lclRowMap = A.getRowMap()->getLocalMap();
302 auto lclColMap = A.getColMap()->getLocalMap();
// Clear all flags first, on the same execution space instance as the kernel.
304 Kokkos::deep_copy(execSpace, dirichletColFlags,
false);
305 using range_type = Kokkos::RangePolicy<execution_space, LO>;
306 Kokkos::parallel_for(
307 "Tpetra::CrsMatrix flag Dirichlet cols",
308 range_type(execSpace, 0, numDirichletRows),
309 KOKKOS_LAMBDA(
const LO i) {
310 GO row_gid = lclRowMap.getGlobalElement(dirichletRowIds[i]);
311 LO col_lid = lclColMap.getLocalElement(row_gid);
312 if (col_lid != LO_INVALID)
313 dirichletColFlags[col_lid] =
true;
// With an Importer: mark Dirichlet entries with `one` in a domain-Map
// vector, then read column flags from a column-Map vector. The import between
// the two vectors is presumably done in lines missing from this listing
// -- TODO confirm.
317 auto Importer = A.getCrsGraph()->getImporter();
318 auto lclRowMap = A.getRowMap()->getLocalMap();
319 auto lclColMap = A.getColMap()->getLocalMap();
322 const LO one = Teuchos::OrdinalTraits<LO>::one();
323 using range_type = Kokkos::RangePolicy<execution_space, LO>;
325 auto domain_data = domainDirichlet.template getLocalView<memory_space>(Access::ReadWrite);
326 Kokkos::parallel_for(
327 "Tpetra::CrsMatrix flag Dirichlet domains",
328 range_type(execSpace, 0, numDirichletRows),
329 KOKKOS_LAMBDA(
const LO i) {
330 GO row_gid = lclRowMap.getGlobalElement(dirichletRowIds[i]);
331 LO col_lid = lclColMap.getLocalElement(row_gid);
// NOTE(review): domain_data is a domain-Map vector, but here it is indexed
// by a *column*-Map local index. Because the row Map equals the domain Map
// (checked above), this row's domain-Map local index is dirichletRowIds[i]
// itself. Indexing by col_lid is correct only if column-Map and domain-Map
// local indices coincide for owned entries -- verify.
332 if (col_lid != LO_INVALID)
333 domain_data(col_lid, 0) = one;
// Every column flag is overwritten here: true iff the column-Map vector
// holds `one`.
337 LO numCols = (LO)A.getColMap()->getLocalNumElements();
339 auto col_data = colDirichlet.template getLocalView<memory_space>(Access::ReadOnly);
340 Kokkos::parallel_for(
341 "Tpetra::CrsMatrix flag Dirichlet cols",
342 range_type(execSpace, 0, numCols),
343 KOKKOS_LAMBDA(
const LO i) {
344 dirichletColFlags[i] = (col_data(i, 0) == one) ?
true : false;
// Public overload taking local row indices in a View on
// CrsMatrixType::device_type.
// NOTE(review): the function name/leading-parameter lines and the body after
// the using-declaration (original lines 368+) are missing from this listing.
// Presumably lclRowInds is converted to local_row_indices_type and passed to
// Details::ApplyDirichletBoundaryConditionToLocalMatrixRows<SC, LO, GO, NT>::run
// -- TODO confirm.
352template <
class CrsMatrixType>
356 typename CrsMatrixType::local_ordinal_type*,
357 typename CrsMatrixType::device_type>&
lclRowInds) {
// Template parameters of the underlying Tpetra::CrsMatrix.
358 using SC =
typename CrsMatrixType::scalar_type;
359 using LO =
typename CrsMatrixType::local_ordinal_type;
360 using GO =
typename CrsMatrixType::global_ordinal_type;
361 using NT =
typename CrsMatrixType::node_type;
// Same View type that the Details helper's run() accepts.
363 using local_row_indices_type =
364 Kokkos::View<const LO*, Kokkos::AnonymousSpace>;
367 using Details::ApplyDirichletBoundaryConditionToLocalMatrixRows;
// Public overload taking local row indices in a View on
// CrsMatrixType::device_type, without an explicit execution space argument
// in the visible lines.
// NOTE(review): only the execution_space alias of the body is visible.
// Presumably this forwards to the execution-space overload with a default
// execution_space instance -- TODO confirm.
374template <
class CrsMatrixType>
377 typename CrsMatrixType::local_ordinal_type*,
378 typename CrsMatrixType::device_type>&
lclRowInds) {
379 using execution_space =
typename CrsMatrixType::execution_space;
// Public overload of the row-wise Dirichlet operation. The memory-space line
// of its View parameter (original line 387) is missing; it is presumably the
// Kokkos::HostSpace overload -- TODO confirm.
// NOTE(review): `crs_matrix_type` is used below but its alias (presumably
// original line 392) is not visible, and the call to
// Details::ApplyDirichletBoundaryConditionToLocalMatrixRows after line 403 is
// missing from this listing.
383template <
class CrsMatrixType>
386 typename CrsMatrixType::local_ordinal_type*,
388 using SC =
typename CrsMatrixType::scalar_type;
389 using LO =
typename CrsMatrixType::local_ordinal_type;
390 using GO =
typename CrsMatrixType::global_ordinal_type;
391 using NT =
typename CrsMatrixType::node_type;
393 using execution_space =
typename crs_matrix_type::execution_space;
394 using memory_space =
typename crs_matrix_type::device_type::memory_space;
396 using Details::ApplyDirichletBoundaryConditionToLocalMatrixRows;
// runOnHost is true when the matrix's memory space is accessible from
// Kokkos::Serial.
// NOTE(review): Kokkos::Serial only exists when the Serial backend is enabled;
// Kokkos::DefaultHostExecutionSpace would not depend on that -- confirm the
// supported build configurations.
401 const bool runOnHost = Kokkos::SpaceAccessibility<Kokkos::Serial, memory_space>::accessible;
403 using local_row_indices_type = Kokkos::View<const LO*, Kokkos::AnonymousSpace>;
// Public overload of the rows-and-columns Dirichlet operation, run on a
// default-constructed execution_space instance. It first zeroes the flagged
// columns, then rewrites the listed rows to one on the diagonal and zero
// elsewhere.
// NOTE(review): original lines 416 and 424-433 are missing from this listing.
// Presumably they copy lclRowInds into `lclRowInds_d` and fill
// dirichletColFlags (e.g. via localRowsToColumns) -- TODO confirm.
412template <
class CrsMatrixType>
415 typename CrsMatrixType::local_ordinal_type*,
417 using SC =
typename CrsMatrixType::scalar_type;
418 using LO =
typename CrsMatrixType::local_ordinal_type;
419 using GO =
typename CrsMatrixType::global_ordinal_type;
420 using NT =
typename CrsMatrixType::node_type;
422 using execution_space =
typename crs_matrix_type::execution_space;
423 using memory_space =
typename crs_matrix_type::device_type::memory_space;
// One flag per local column of A, allocated in the matrix's memory space.
430 Kokkos::View<bool*, memory_space>
dirichletColFlags(
"dirichletColFlags",
A.getColMap()->getLocalNumElements());
// Order matters. The columns pass also zeroes a Dirichlet row's own diagonal
// entry (its column is flagged), and the rows pass that follows sets that
// entry back to one. Both passes use runOnHost = false.
434 Details::ApplyDirichletBoundaryConditionToLocalMatrixColumns<SC, LO, GO, NT>::run(execution_space(),
A,
dirichletColFlags,
false);
435 Details::ApplyDirichletBoundaryConditionToLocalMatrixRows<SC, LO, GO, NT>::run(execution_space(),
A,
lclRowInds_d,
false);
// Another public overload of the rows-and-columns Dirichlet operation, run on
// a default-constructed execution_space instance. It zeroes the flagged
// columns first, then rewrites the listed rows to one on the diagonal and
// zero elsewhere.
// NOTE(review): original lines 442, 447, 450-452 and 454-456 are missing from
// this listing, including the View parameter's memory space and the code that
// produces `lclRowInds_d` and fills dirichletColFlags -- TODO confirm.
438template <
class CrsMatrixType>
441 typename CrsMatrixType::local_ordinal_type*,
443 using SC =
typename CrsMatrixType::scalar_type;
444 using LO =
typename CrsMatrixType::local_ordinal_type;
445 using GO =
typename CrsMatrixType::global_ordinal_type;
446 using NT =
typename CrsMatrixType::node_type;
448 using execution_space =
typename crs_matrix_type::execution_space;
449 using memory_space =
typename crs_matrix_type::device_type::memory_space;
// One flag per local column of A, allocated in the matrix's memory space.
453 Kokkos::View<bool*, memory_space>
dirichletColFlags(
"dirichletColFlags",
A.getColMap()->getLocalNumElements());
// Columns pass before rows pass, so each Dirichlet row's diagonal entry
// (zeroed by the columns pass) ends up as one. Both use runOnHost = false.
457 Details::ApplyDirichletBoundaryConditionToLocalMatrixColumns<SC, LO, GO, NT>::run(execution_space(),
A,
dirichletColFlags,
false);
458 Details::ApplyDirichletBoundaryConditionToLocalMatrixRows<SC, LO, GO, NT>::run(execution_space(),
A,
lclRowInds_d,
false);
// Public overload of the rows-and-columns Dirichlet operation that runs on a
// caller-supplied execution space instance `execSpace` (declared in a missing
// parameter line).
// NOTE(review): original lines 462-464, 466, 471-472, 474-476 and 478-481 are
// missing from this listing. They presumably include the execSpace parameter,
// the code producing `lclRowInds_d` and dirichletColFlags, and the columns
// pass that should precede the rows pass below -- TODO confirm.
461template <
class CrsMatrixType>
465 typename CrsMatrixType::local_ordinal_type*,
467 using SC =
typename CrsMatrixType::scalar_type;
468 using LO =
typename CrsMatrixType::local_ordinal_type;
469 using GO =
typename CrsMatrixType::global_ordinal_type;
470 using NT =
typename CrsMatrixType::node_type;
473 using memory_space =
typename crs_matrix_type::device_type::memory_space;
// One flag per local column of A, allocated in the matrix's memory space.
477 Kokkos::View<bool*, memory_space>
dirichletColFlags(
"dirichletColFlags",
A.getColMap()->getLocalNumElements());
// Rows pass on the caller's execution space, with runOnHost = false.
482 Details::ApplyDirichletBoundaryConditionToLocalMatrixRows<SC, LO, GO, NT>::run(
execSpace,
A,
lclRowInds_d,
false);
typename device_type::execution_space execution_space
The Kokkos execution space.
typename row_matrix_type::impl_scalar_type impl_scalar_type
The type used internally in place of Scalar.
Struct that holds views of the contents of a CrsMatrix.
Implementation details of Tpetra.
Namespace Tpetra contains the class and methods constituting the Tpetra library.
void applyDirichletBoundaryConditionToLocalMatrixRows(const typename CrsMatrixType::execution_space &execSpace, CrsMatrixType &A, const Kokkos::View< typename CrsMatrixType::local_ordinal_type *, typename CrsMatrixType::device_type > &lclRowInds)
For all k in [0, lclRowInds.extent(0)), set local row lclRowInds[k] of A to have 1 on the diagonal and 0 elsewhere.
@ INSERT
Insert new values that don't currently exist.
void applyDirichletBoundaryConditionToLocalMatrixRowsAndColumns(const typename CrsMatrixType::execution_space &execSpace, CrsMatrixType &A, const Kokkos::View< typename CrsMatrixType::local_ordinal_type *, typename CrsMatrixType::device_type > &lclRowInds)
For all k in [0, lclRowInds.extent(0)), set local row and column lclRowInds[k] of A to have 1 on the diagonal and 0 elsewhere.