10#ifndef TPETRA_APPLYDIRICHLETBOUNDARYCONDITION_HPP
11#define TPETRA_APPLYDIRICHLETBOUNDARYCONDITION_HPP
17#include "Tpetra_CrsMatrix.hpp"
18#include "Tpetra_Vector.hpp"
19#include "Tpetra_Map.hpp"
20#include "KokkosSparse_CrsMatrix.hpp"
21#if KOKKOS_VERSION >= 40799
22#include "KokkosKernels_ArithTraits.hpp"
24#include "Kokkos_ArithTraits.hpp"
38template <
class CrsMatrixType>
42 typename CrsMatrixType::local_ordinal_type*,
43 typename CrsMatrixType::device_type>& lclRowInds);
54template <
class CrsMatrixType>
57 typename CrsMatrixType::local_ordinal_type*,
58 typename CrsMatrixType::device_type>& lclRowInds);
69template <
class CrsMatrixType>
72 typename CrsMatrixType::local_ordinal_type*,
73 Kokkos::HostSpace>& lclRowInds);
84template <
class CrsMatrixType>
88 typename CrsMatrixType::local_ordinal_type*,
89 typename CrsMatrixType::device_type>& lclRowInds);
101template <
class CrsMatrixType>
104 typename CrsMatrixType::local_ordinal_type*,
105 Kokkos::HostSpace>& lclRowInds);
116template <
class CrsMatrixType>
119 typename CrsMatrixType::local_ordinal_type*,
120 typename CrsMatrixType::device_type>& lclRowInds);
124template <
class SC,
class LO,
class GO,
class NT>
125struct ApplyDirichletBoundaryConditionToLocalMatrixRows {
128 using local_row_indices_type =
129 Kokkos::View<const LO*, Kokkos::AnonymousSpace>;
132 run(
const execution_space& execSpace,
134 const local_row_indices_type& lclRowInds,
135 const bool runOnHost) {
142#if KOKKOS_VERSION >= 40799
143 using KAT = KokkosKernels::ArithTraits<IST>;
145 using KAT = Kokkos::ArithTraits<IST>;
148 const auto rowMap = A.getRowMap();
149 TEUCHOS_TEST_FOR_EXCEPTION(rowMap.get() ==
nullptr, std::invalid_argument,
150 "The matrix must have a row Map.");
151 const auto colMap = A.getColMap();
152 TEUCHOS_TEST_FOR_EXCEPTION(colMap.get() ==
nullptr, std::invalid_argument,
153 "The matrix must have a column Map.");
154 auto A_lcl = A.getLocalMatrixDevice();
156 const LO lclNumRows =
static_cast<LO
>(rowMap->getLocalNumElements());
157 TEUCHOS_TEST_FOR_EXCEPTION(lclNumRows != 0 &&
static_cast<LO
>(A_lcl.graph.numRows()) != lclNumRows,
158 std::invalid_argument,
159 "The matrix must have been either created "
160 "with a KokkosSparse::CrsMatrix, or must have been fill-completed "
163 auto lclRowMap = A.getRowMap()->getLocalMap();
164 auto lclColMap = A.getColMap()->getLocalMap();
165 auto rowptr = A_lcl.graph.row_map;
166 auto colind = A_lcl.graph.entries;
167 auto values = A_lcl.values;
169 const bool wasFillComplete = A.isFillComplete();
170 if (wasFillComplete) {
174 const LO numInputRows = lclRowInds.extent(0);
176 using range_type = Kokkos::RangePolicy<execution_space, LO>;
177 Kokkos::parallel_for(
178 "Tpetra::CrsMatrix apply Dirichlet: Device",
179 range_type(execSpace, 0, numInputRows),
180 KOKKOS_LAMBDA(
const LO i) {
181 LO row = lclRowInds(i);
182 const GO row_gid = lclRowMap.getGlobalElement(row);
183 for (
auto j = rowptr(row); j < rowptr(row + 1); ++j) {
185 lclColMap.getGlobalElement(colind(j)) == row_gid;
186 values(j) = diagEnt ? KAT::one() : KAT::zero();
191 Kokkos::RangePolicy<Kokkos::DefaultHostExecutionSpace, LO>;
192 Kokkos::parallel_for(
"Tpetra::CrsMatrix apply Dirichlet: Host",
193 range_type(0, numInputRows),
195 LO row = lclRowInds(i);
196 const GO row_gid = lclRowMap.getGlobalElement(row);
197 for (
auto j = rowptr(row); j < rowptr(row + 1); ++j) {
199 lclColMap.getGlobalElement(colind(j)) == row_gid;
200 values(j) = diagEnt ? KAT::one() : KAT::zero();
204 if (wasFillComplete) {
205 A.fillComplete(A.getDomainMap(), A.getRangeMap());
210template <
class SC,
class LO,
class GO,
class NT>
211struct ApplyDirichletBoundaryConditionToLocalMatrixColumns {
214 using local_col_flag_type =
215 Kokkos::View<bool*, Kokkos::AnonymousSpace>;
218 run(
const execution_space& execSpace,
220 const local_col_flag_type& lclColFlags,
221 const bool runOnHost) {
228#if KOKKOS_VERSION >= 40799
229 using KAT = KokkosKernels::ArithTraits<IST>;
231 using KAT = Kokkos::ArithTraits<IST>;
234 const auto rowMap = A.getRowMap();
235 TEUCHOS_TEST_FOR_EXCEPTION(rowMap.get() ==
nullptr, std::invalid_argument,
236 "The matrix must have a row Map.");
237 const auto colMap = A.getColMap();
238 TEUCHOS_TEST_FOR_EXCEPTION(colMap.get() ==
nullptr, std::invalid_argument,
239 "The matrix must have a column Map.");
240 auto A_lcl = A.getLocalMatrixDevice();
242 const LO lclNumRows =
static_cast<LO
>(rowMap->getLocalNumElements());
243 TEUCHOS_TEST_FOR_EXCEPTION(lclNumRows != 0 &&
static_cast<LO
>(A_lcl.graph.numRows()) != lclNumRows,
244 std::invalid_argument,
245 "The matrix must have been either created "
246 "with a KokkosSparse::CrsMatrix, or must have been fill-completed "
249 auto lclRowMap = A.getRowMap()->getLocalMap();
250 auto lclColMap = A.getColMap()->getLocalMap();
251 auto rowptr = A_lcl.graph.row_map;
252 auto colind = A_lcl.graph.entries;
253 auto values = A_lcl.values;
255 const bool wasFillComplete = A.isFillComplete();
256 if (wasFillComplete) {
260 const LO numRows = (LO)A.getRowMap()->getLocalNumElements();
262 using range_type = Kokkos::RangePolicy<execution_space, LO>;
263 Kokkos::parallel_for(
264 "Tpetra::CrsMatrix apply Dirichlet cols: Device",
265 range_type(execSpace, 0, numRows),
266 KOKKOS_LAMBDA(
const LO i) {
267 for (
auto j = rowptr(i); j < rowptr(i + 1); ++j) {
268 if (lclColFlags[colind[j]])
269 values(j) = KAT::zero();
274 Kokkos::RangePolicy<Kokkos::DefaultHostExecutionSpace, LO>;
275 Kokkos::parallel_for(
276 "Tpetra::CrsMatrix apply Dirichlet cols: Host",
277 range_type(0, numRows),
278 KOKKOS_LAMBDA(
const LO i) {
279 for (
auto j = rowptr(i); j < rowptr(i + 1); ++j) {
280 if (lclColFlags[colind[j]])
281 values(j) = KAT::zero();
285 if (wasFillComplete) {
286 A.fillComplete(A.getDomainMap(), A.getRangeMap());
291template <
class SC,
class LO,
class GO,
class NT>
292void localRowsToColumns(
const typename ::Tpetra::CrsMatrix<SC, LO, GO, NT>::execution_space& execSpace, const ::Tpetra::CrsMatrix<SC, LO, GO, NT>& A,
const Kokkos::View<const LO*, Kokkos::AnonymousSpace>& dirichletRowIds, Kokkos::View<bool*, Kokkos::AnonymousSpace>& dirichletColFlags) {
294 using execution_space =
typename crs_matrix_type::execution_space;
295 using memory_space =
typename crs_matrix_type::device_type::memory_space;
298 TEUCHOS_TEST_FOR_EXCEPTION(A.getColMap().get() ==
nullptr, std::invalid_argument,
"The matrix must have a column Map.");
302 TEUCHOS_TEST_FOR_EXCEPTION(!A.getRowMap()->isSameAs(*A.getDomainMap()), std::invalid_argument,
"localRowsToColumns: Row/Domain maps do not match");
305 TEUCHOS_TEST_FOR_EXCEPTION(A.getColMap()->getLocalNumElements() != dirichletColFlags.size(), std::invalid_argument,
"localRowsToColumns: dirichletColFlags must be the correct size");
307 LO numDirichletRows = (LO)dirichletRowIds.size();
308 LO LO_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
310 if (A.getCrsGraph()->getImporter().is_null()) {
312 using range_type = Kokkos::RangePolicy<execution_space, LO>;
313 auto lclRowMap = A.getRowMap()->getLocalMap();
314 auto lclColMap = A.getColMap()->getLocalMap();
316 Kokkos::deep_copy(execSpace, dirichletColFlags,
false);
317 using range_type = Kokkos::RangePolicy<execution_space, LO>;
318 Kokkos::parallel_for(
319 "Tpetra::CrsMatrix flag Dirichlet cols",
320 range_type(execSpace, 0, numDirichletRows),
321 KOKKOS_LAMBDA(
const LO i) {
322 GO row_gid = lclRowMap.getGlobalElement(dirichletRowIds[i]);
323 LO col_lid = lclColMap.getLocalElement(row_gid);
324 if (col_lid != LO_INVALID)
325 dirichletColFlags[col_lid] =
true;
329 auto Importer = A.getCrsGraph()->getImporter();
330 auto lclRowMap = A.getRowMap()->getLocalMap();
331 auto lclColMap = A.getColMap()->getLocalMap();
334 const LO one = Teuchos::OrdinalTraits<LO>::one();
335 using range_type = Kokkos::RangePolicy<execution_space, LO>;
337 auto domain_data = domainDirichlet.template getLocalView<memory_space>(Access::ReadWrite);
338 Kokkos::parallel_for(
339 "Tpetra::CrsMatrix flag Dirichlet domains",
340 range_type(execSpace, 0, numDirichletRows),
341 KOKKOS_LAMBDA(
const LO i) {
342 GO row_gid = lclRowMap.getGlobalElement(dirichletRowIds[i]);
343 LO col_lid = lclColMap.getLocalElement(row_gid);
344 if (col_lid != LO_INVALID)
345 domain_data(col_lid, 0) = one;
349 LO numCols = (LO)A.getColMap()->getLocalNumElements();
351 auto col_data = colDirichlet.template getLocalView<memory_space>(Access::ReadOnly);
352 Kokkos::parallel_for(
353 "Tpetra::CrsMatrix flag Dirichlet cols",
354 range_type(execSpace, 0, numCols),
355 KOKKOS_LAMBDA(
const LO i) {
356 dirichletColFlags[i] = (col_data(i, 0) == one) ?
true : false;
364template <
class CrsMatrixType>
368 typename CrsMatrixType::local_ordinal_type*,
369 typename CrsMatrixType::device_type>&
lclRowInds) {
370 using SC =
typename CrsMatrixType::scalar_type;
371 using LO =
typename CrsMatrixType::local_ordinal_type;
372 using GO =
typename CrsMatrixType::global_ordinal_type;
373 using NT =
typename CrsMatrixType::node_type;
375 using local_row_indices_type =
376 Kokkos::View<const LO*, Kokkos::AnonymousSpace>;
379 using Details::ApplyDirichletBoundaryConditionToLocalMatrixRows;
386template <
class CrsMatrixType>
389 typename CrsMatrixType::local_ordinal_type*,
390 typename CrsMatrixType::device_type>&
lclRowInds) {
391 using execution_space =
typename CrsMatrixType::execution_space;
395template <
class CrsMatrixType>
398 typename CrsMatrixType::local_ordinal_type*,
400 using SC =
typename CrsMatrixType::scalar_type;
401 using LO =
typename CrsMatrixType::local_ordinal_type;
402 using GO =
typename CrsMatrixType::global_ordinal_type;
403 using NT =
typename CrsMatrixType::node_type;
405 using execution_space =
typename crs_matrix_type::execution_space;
406 using memory_space =
typename crs_matrix_type::device_type::memory_space;
408 using Details::ApplyDirichletBoundaryConditionToLocalMatrixRows;
413 const bool runOnHost = Kokkos::SpaceAccessibility<Kokkos::Serial, memory_space>::accessible;
415 using local_row_indices_type = Kokkos::View<const LO*, Kokkos::AnonymousSpace>;
424template <
class CrsMatrixType>
427 typename CrsMatrixType::local_ordinal_type*,
429 using SC =
typename CrsMatrixType::scalar_type;
430 using LO =
typename CrsMatrixType::local_ordinal_type;
431 using GO =
typename CrsMatrixType::global_ordinal_type;
432 using NT =
typename CrsMatrixType::node_type;
434 using execution_space =
typename crs_matrix_type::execution_space;
435 using memory_space =
typename crs_matrix_type::device_type::memory_space;
442 Kokkos::View<bool*, memory_space>
dirichletColFlags(
"dirichletColFlags",
A.getColMap()->getLocalNumElements());
446 Details::ApplyDirichletBoundaryConditionToLocalMatrixColumns<SC, LO, GO, NT>::run(execution_space(),
A,
dirichletColFlags,
false);
447 Details::ApplyDirichletBoundaryConditionToLocalMatrixRows<SC, LO, GO, NT>::run(execution_space(),
A,
lclRowInds_d,
false);
450template <
class CrsMatrixType>
453 typename CrsMatrixType::local_ordinal_type*,
455 using SC =
typename CrsMatrixType::scalar_type;
456 using LO =
typename CrsMatrixType::local_ordinal_type;
457 using GO =
typename CrsMatrixType::global_ordinal_type;
458 using NT =
typename CrsMatrixType::node_type;
460 using execution_space =
typename crs_matrix_type::execution_space;
461 using memory_space =
typename crs_matrix_type::device_type::memory_space;
465 Kokkos::View<bool*, memory_space>
dirichletColFlags(
"dirichletColFlags",
A.getColMap()->getLocalNumElements());
469 Details::ApplyDirichletBoundaryConditionToLocalMatrixColumns<SC, LO, GO, NT>::run(execution_space(),
A,
dirichletColFlags,
false);
470 Details::ApplyDirichletBoundaryConditionToLocalMatrixRows<SC, LO, GO, NT>::run(execution_space(),
A,
lclRowInds_d,
false);
473template <
class CrsMatrixType>
477 typename CrsMatrixType::local_ordinal_type*,
479 using SC =
typename CrsMatrixType::scalar_type;
480 using LO =
typename CrsMatrixType::local_ordinal_type;
481 using GO =
typename CrsMatrixType::global_ordinal_type;
482 using NT =
typename CrsMatrixType::node_type;
485 using memory_space =
typename crs_matrix_type::device_type::memory_space;
489 Kokkos::View<bool*, memory_space>
dirichletColFlags(
"dirichletColFlags",
A.getColMap()->getLocalNumElements());
494 Details::ApplyDirichletBoundaryConditionToLocalMatrixRows<SC, LO, GO, NT>::run(
execSpace,
A,
lclRowInds_d,
false);
Struct that holds views of the contents of a CrsMatrix.
typename device_type::execution_space execution_space
The Kokkos execution space.
typename row_matrix_type::impl_scalar_type impl_scalar_type
The type used internally in place of Scalar.
Implementation details of Tpetra.
Namespace Tpetra contains the class and methods constituting the Tpetra library.
void applyDirichletBoundaryConditionToLocalMatrixRows(const typename CrsMatrixType::execution_space &execSpace, CrsMatrixType &A, const Kokkos::View< typename CrsMatrixType::local_ordinal_type *, typename CrsMatrixType::device_type > &lclRowInds)
For all k in [0, lclRowInds.extent(0)), set local row lclRowInds[k] of A to have 1 on the diagonal an...
@ INSERT
Insert new values that don't currently exist.
void applyDirichletBoundaryConditionToLocalMatrixRowsAndColumns(const typename CrsMatrixType::execution_space &execSpace, CrsMatrixType &A, const Kokkos::View< typename CrsMatrixType::local_ordinal_type *, typename CrsMatrixType::device_type > &lclRowInds)
For all k in [0, lclRowInds.extent(0)), set local row and column lclRowInds[k] of A to have 1 on the ...