10#ifndef TPETRA_APPLYDIRICHLETBOUNDARYCONDITION_HPP
11#define TPETRA_APPLYDIRICHLETBOUNDARYCONDITION_HPP
17#include "Tpetra_CrsMatrix.hpp"
18#include "Tpetra_Vector.hpp"
19#include "Tpetra_Map.hpp"
20#include "KokkosSparse_CrsMatrix.hpp"
21#include "KokkosKernels_ArithTraits.hpp"
// Forward-declaration fragments. In each one, only the template header and
// the tail of the parameter list are visible in this listing; the function
// name and any leading parameters are on lines that are not shown.

// Fragment 1: final parameter is a Kokkos::View of the matrix's local
// ordinals in the matrix's device_type, named lclRowInds.
34template <
class CrsMatrixType>
38 typename CrsMatrixType::local_ordinal_type*,
39 typename CrsMatrixType::device_type>& lclRowInds);
// Fragment 2: same trailing device-space lclRowInds parameter.
50template <
class CrsMatrixType>
53 typename CrsMatrixType::local_ordinal_type*,
54 typename CrsMatrixType::device_type>& lclRowInds);
// Fragment 3: lclRowInds is a view in Kokkos::HostSpace.
65template <
class CrsMatrixType>
68 typename CrsMatrixType::local_ordinal_type*,
69 Kokkos::HostSpace>& lclRowInds);
// Fragment 4: trailing device-space lclRowInds parameter.
81template <
class CrsMatrixType>
85 typename CrsMatrixType::local_ordinal_type*,
86 typename CrsMatrixType::device_type>& lclRowInds);
// Fragment 5: lclRowInds is a view in Kokkos::HostSpace.
98template <
class CrsMatrixType>
101 typename CrsMatrixType::local_ordinal_type*,
102 Kokkos::HostSpace>& lclRowInds);
// Fragment 6: trailing device-space lclRowInds parameter.
113template <
class CrsMatrixType>
116 typename CrsMatrixType::local_ordinal_type*,
117 typename CrsMatrixType::device_type>& lclRowInds);
// Helper struct whose run() overwrites the stored entries of selected local
// rows of a matrix A: an entry whose column global index equals the row's
// global index becomes KAT::one(), every other stored entry in that row
// becomes KAT::zero().
// Template parameters: SC, LO, GO, NT (the aliases built from them are on
// lines not shown in this listing).
121template <
class SC,
class LO,
class GO,
class NT>
122struct ApplyDirichletBoundaryConditionToLocalMatrixRows {
// Read-only list of local row indices; AnonymousSpace leaves the memory
// space unspecified in the type.
125 using local_row_indices_type =
126 Kokkos::View<const LO*, Kokkos::AnonymousSpace>;
// run(execSpace, <parameter not shown>, lclRowInds, runOnHost).
// The matrix `A` used below is declared on a line missing from this listing.
129 run(
const execution_space& execSpace,
131 const local_row_indices_type& lclRowInds,
132 const bool runOnHost) {
139 using KAT = KokkosKernels::ArithTraits<IST>;
// Precondition checks: both the row Map and the column Map must be non-null;
// otherwise std::invalid_argument is raised via TEUCHOS_TEST_FOR_EXCEPTION.
141 const auto rowMap = A.getRowMap();
142 TEUCHOS_TEST_FOR_EXCEPTION(rowMap.get() ==
nullptr, std::invalid_argument,
143 "The matrix must have a row Map.");
144 const auto colMap = A.getColMap();
145 TEUCHOS_TEST_FOR_EXCEPTION(colMap.get() ==
nullptr, std::invalid_argument,
146 "The matrix must have a column Map.");
147 auto A_lcl = A.getLocalMatrixDevice();
// Unless the row Map is locally empty, the local matrix must have exactly as
// many rows as the row Map has local elements.
149 const LO lclNumRows =
static_cast<LO
>(rowMap->getLocalNumElements());
150 TEUCHOS_TEST_FOR_EXCEPTION(lclNumRows != 0 &&
static_cast<LO
>(A_lcl.graph.numRows()) != lclNumRows,
151 std::invalid_argument,
152 "The matrix must have been either created "
153 "with a KokkosSparse::CrsMatrix, or must have been fill-completed "
// Local (device-usable) maps and the CRS arrays captured by the kernels.
156 auto lclRowMap = A.getRowMap()->getLocalMap();
157 auto lclColMap = A.getColMap()->getLocalMap();
158 auto rowptr = A_lcl.graph.row_map;
159 auto colind = A_lcl.graph.entries;
160 auto values = A_lcl.values;
// Remember the fill state so it can be restored at the end. The body of this
// first `if` is not visible here (presumably it reopens the matrix for
// modification — TODO confirm).
162 const bool wasFillComplete = A.isFillComplete();
163 if (wasFillComplete) {
167 const LO numInputRows = lclRowInds.extent(0);
// Kernel variant launched on the caller-supplied execution space instance.
// The condition selecting between the two variants is not visible
// (presumably runOnHost — TODO confirm).
169 using range_type = Kokkos::RangePolicy<execution_space, LO>;
170 Kokkos::parallel_for(
171 "Tpetra::CrsMatrix apply Dirichlet: Device",
172 range_type(execSpace, 0, numInputRows),
173 KOKKOS_LAMBDA(
const LO i) {
174 LO row = lclRowInds(i);
175 const GO row_gid = lclRowMap.getGlobalElement(row);
// Walk the stored entries of this row.
176 for (
auto j = rowptr(row); j < rowptr(row + 1); ++j) {
// diagEnt (declared on a line not shown) is the result of this comparison:
// true when the entry's column global index equals the row's global index.
178 lclColMap.getGlobalElement(colind(j)) == row_gid;
179 values(j) = diagEnt ? KAT::one() : KAT::zero();
// Kernel variant launched on Kokkos::DefaultHostExecutionSpace (default
// instance); the loop body is the same as the variant above.
184 Kokkos::RangePolicy<Kokkos::DefaultHostExecutionSpace, LO>;
185 Kokkos::parallel_for(
"Tpetra::CrsMatrix apply Dirichlet: Host",
186 range_type(0, numInputRows),
188 LO row = lclRowInds(i);
189 const GO row_gid = lclRowMap.getGlobalElement(row);
190 for (
auto j = rowptr(row); j < rowptr(row + 1); ++j) {
192 lclColMap.getGlobalElement(colind(j)) == row_gid;
193 values(j) = diagEnt ? KAT::one() : KAT::zero();
// Restore the fill-complete state using the matrix's existing domain and
// range Maps.
197 if (wasFillComplete) {
198 A.fillComplete(A.getDomainMap(), A.getRangeMap());
// Helper struct whose run() sets to KAT::zero() every stored entry of A whose
// local column index is flagged in lclColFlags. All local rows are visited.
// Template parameters: SC, LO, GO, NT (the aliases built from them are on
// lines not shown in this listing).
203template <
class SC,
class LO,
class GO,
class NT>
204struct ApplyDirichletBoundaryConditionToLocalMatrixColumns {
// Per-column boolean flags, indexed by local column index in the kernels
// below; AnonymousSpace leaves the memory space unspecified in the type.
207 using local_col_flag_type =
208 Kokkos::View<bool*, Kokkos::AnonymousSpace>;
// run(execSpace, <parameter not shown>, lclColFlags, runOnHost).
// The matrix `A` used below is declared on a line missing from this listing.
211 run(
const execution_space& execSpace,
213 const local_col_flag_type& lclColFlags,
214 const bool runOnHost) {
221 using KAT = KokkosKernels::ArithTraits<IST>;
// Precondition checks: both the row Map and the column Map must be non-null;
// otherwise std::invalid_argument is raised via TEUCHOS_TEST_FOR_EXCEPTION.
223 const auto rowMap = A.getRowMap();
224 TEUCHOS_TEST_FOR_EXCEPTION(rowMap.get() ==
nullptr, std::invalid_argument,
225 "The matrix must have a row Map.");
226 const auto colMap = A.getColMap();
227 TEUCHOS_TEST_FOR_EXCEPTION(colMap.get() ==
nullptr, std::invalid_argument,
228 "The matrix must have a column Map.");
229 auto A_lcl = A.getLocalMatrixDevice();
// Unless the row Map is locally empty, the local matrix must have exactly as
// many rows as the row Map has local elements.
231 const LO lclNumRows =
static_cast<LO
>(rowMap->getLocalNumElements());
232 TEUCHOS_TEST_FOR_EXCEPTION(lclNumRows != 0 &&
static_cast<LO
>(A_lcl.graph.numRows()) != lclNumRows,
233 std::invalid_argument,
234 "The matrix must have been either created "
235 "with a KokkosSparse::CrsMatrix, or must have been fill-completed "
// NOTE(review): lclRowMap and lclColMap are not referenced by the kernels
// visible in this listing — confirm whether they are needed.
238 auto lclRowMap = A.getRowMap()->getLocalMap();
239 auto lclColMap = A.getColMap()->getLocalMap();
240 auto rowptr = A_lcl.graph.row_map;
241 auto colind = A_lcl.graph.entries;
242 auto values = A_lcl.values;
// Remember the fill state so it can be restored at the end. The body of this
// first `if` is not visible here (presumably it reopens the matrix for
// modification — TODO confirm).
244 const bool wasFillComplete = A.isFillComplete();
245 if (wasFillComplete) {
// NOTE(review): C-style cast; the equivalent value is computed above as
// lclNumRows with static_cast.
249 const LO numRows = (LO)A.getRowMap()->getLocalNumElements();
// Kernel variant launched on the caller-supplied execution space instance.
// The condition selecting between the two variants is not visible
// (presumably runOnHost — TODO confirm).
251 using range_type = Kokkos::RangePolicy<execution_space, LO>;
252 Kokkos::parallel_for(
253 "Tpetra::CrsMatrix apply Dirichlet cols: Device",
254 range_type(execSpace, 0, numRows),
255 KOKKOS_LAMBDA(
const LO i) {
// For each stored entry of row i, zero it when its column is flagged.
256 for (
auto j = rowptr(i); j < rowptr(i + 1); ++j) {
257 if (lclColFlags[colind[j]])
258 values(j) = KAT::zero();
// Kernel variant launched on Kokkos::DefaultHostExecutionSpace (default
// instance); the loop body is the same as the variant above.
263 Kokkos::RangePolicy<Kokkos::DefaultHostExecutionSpace, LO>;
264 Kokkos::parallel_for(
265 "Tpetra::CrsMatrix apply Dirichlet cols: Host",
266 range_type(0, numRows),
267 KOKKOS_LAMBDA(
const LO i) {
268 for (
auto j = rowptr(i); j < rowptr(i + 1); ++j) {
269 if (lclColFlags[colind[j]])
270 values(j) = KAT::zero();
// Restore the fill-complete state using the matrix's existing domain and
// range Maps.
274 if (wasFillComplete) {
275 A.fillComplete(A.getDomainMap(), A.getRangeMap());
// Translate a list of local Dirichlet row indices into per-column boolean
// flags over the matrix's column Map.
//
// execSpace          execution space instance used for the kernels/deep_copy.
// A                  matrix supplying the row, column and domain Maps.
// dirichletRowIds    local row indices to translate.
// dirichletColFlags  output; one flag per local column-Map element.
//
// Raises std::invalid_argument (via TEUCHOS_TEST_FOR_EXCEPTION) when A has no
// column Map, when the row Map is not the same as the domain Map, or when
// dirichletColFlags does not have one entry per local column-Map element.
280template <
class SC,
class LO,
class GO,
class NT>
281void localRowsToColumns(
const typename ::Tpetra::CrsMatrix<SC, LO, GO, NT>::execution_space& execSpace, const ::Tpetra::CrsMatrix<SC, LO, GO, NT>& A,
const Kokkos::View<const LO*, Kokkos::AnonymousSpace>& dirichletRowIds, Kokkos::View<bool*, Kokkos::AnonymousSpace>& dirichletColFlags) {
283 using execution_space =
typename crs_matrix_type::execution_space;
284 using memory_space =
typename crs_matrix_type::device_type::memory_space;
287 TEUCHOS_TEST_FOR_EXCEPTION(A.getColMap().get() ==
nullptr, std::invalid_argument,
"The matrix must have a column Map.");
291 TEUCHOS_TEST_FOR_EXCEPTION(!A.getRowMap()->isSameAs(*A.getDomainMap()), std::invalid_argument,
"localRowsToColumns: Row/Domain maps do not match");
294 TEUCHOS_TEST_FOR_EXCEPTION(A.getColMap()->getLocalNumElements() != dirichletColFlags.size(), std::invalid_argument,
"localRowsToColumns: dirichletColFlags must be the correct size");
296 LO numDirichletRows = (LO)dirichletRowIds.size();
// Sentinel compared against getLocalElement() results below.
297 LO LO_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
// Branch taken when the graph has no importer: flags are computed directly
// from the local maps.
299 if (A.getCrsGraph()->getImporter().is_null()) {
// NOTE(review): range_type is aliased twice in this branch (here and just
// before the parallel_for); the second alias is redundant.
301 using range_type = Kokkos::RangePolicy<execution_space, LO>;
302 auto lclRowMap = A.getRowMap()->getLocalMap();
303 auto lclColMap = A.getColMap()->getLocalMap();
// Clear all flags first; the kernel below only ever sets flags to true.
305 Kokkos::deep_copy(execSpace, dirichletColFlags,
false);
306 using range_type = Kokkos::RangePolicy<execution_space, LO>;
307 Kokkos::parallel_for(
308 "Tpetra::CrsMatrix flag Dirichlet cols",
309 range_type(execSpace, 0, numDirichletRows),
310 KOKKOS_LAMBDA(
const LO i) {
// row local index -> global index -> column local index; flag the column
// only if that global index is present in the local column Map.
311 GO row_gid = lclRowMap.getGlobalElement(dirichletRowIds[i]);
312 LO col_lid = lclColMap.getLocalElement(row_gid);
313 if (col_lid != LO_INVALID)
314 dirichletColFlags[col_lid] =
true;
// Branch for a graph that has an importer. domainDirichlet and colDirichlet
// are declared on lines not shown in this listing.
318 auto Importer = A.getCrsGraph()->getImporter();
319 auto lclRowMap = A.getRowMap()->getLocalMap();
320 auto lclColMap = A.getColMap()->getLocalMap();
323 const LO one = Teuchos::OrdinalTraits<LO>::one();
324 using range_type = Kokkos::RangePolicy<execution_space, LO>;
// Step 1: mark each Dirichlet row in domainDirichlet's local data with `one`.
// NOTE(review): domain_data is indexed with a column-Map local index —
// confirm this is valid for the Map domainDirichlet is built on.
326 auto domain_data = domainDirichlet.template getLocalView<memory_space>(Access::ReadWrite);
327 Kokkos::parallel_for(
328 "Tpetra::CrsMatrix flag Dirichlet domains",
329 range_type(execSpace, 0, numDirichletRows),
330 KOKKOS_LAMBDA(
const LO i) {
331 GO row_gid = lclRowMap.getGlobalElement(dirichletRowIds[i]);
332 LO col_lid = lclColMap.getLocalElement(row_gid);
333 if (col_lid != LO_INVALID)
334 domain_data(col_lid, 0) = one;
// The step that fills colDirichlet is not visible here (presumably an import
// from domainDirichlet through Importer — TODO confirm).
// Step 2: convert colDirichlet's local data to booleans. Every index in
// [0, numCols) is written, so no prior clearing of the flags is needed.
338 LO numCols = (LO)A.getColMap()->getLocalNumElements();
340 auto col_data = colDirichlet.template getLocalView<memory_space>(Access::ReadOnly);
341 Kokkos::parallel_for(
342 "Tpetra::CrsMatrix flag Dirichlet cols",
343 range_type(execSpace, 0, numCols),
344 KOKKOS_LAMBDA(
const LO i) {
345 dirichletColFlags[i] = (col_data(i, 0) == one) ?
true : false;
// Function-definition fragment; the function name and leading parameters are
// on lines not shown. The visible final parameter, lclRowInds, is a view of
// local ordinals in the matrix's device_type.
353template <
class CrsMatrixType>
357 typename CrsMatrixType::local_ordinal_type*,
358 typename CrsMatrixType::device_type>&
lclRowInds) {
// Unpack the matrix's scalar / ordinal / node types.
359 using SC =
typename CrsMatrixType::scalar_type;
360 using LO =
typename CrsMatrixType::local_ordinal_type;
361 using GO =
typename CrsMatrixType::global_ordinal_type;
362 using NT =
typename CrsMatrixType::node_type;
// Same view type that Details::ApplyDirichletBoundaryConditionToLocalMatrixRows
// declares as local_row_indices_type.
364 using local_row_indices_type =
365 Kokkos::View<const LO*, Kokkos::AnonymousSpace>;
// The call that uses this helper is on lines not shown in this listing.
368 using Details::ApplyDirichletBoundaryConditionToLocalMatrixRows;
// Function-definition fragment; the function name and leading parameters are
// on lines not shown. The visible final parameter, lclRowInds, is a view of
// local ordinals in the matrix's device_type. The remainder of the body
// (after the execution_space alias) is not visible.
375template <
class CrsMatrixType>
378 typename CrsMatrixType::local_ordinal_type*,
379 typename CrsMatrixType::device_type>&
lclRowInds) {
380 using execution_space =
typename CrsMatrixType::execution_space;
// Function-definition fragment; the function name, most of the parameter
// list and the statements between the aliases are on lines not shown.
384template <
class CrsMatrixType>
387 typename CrsMatrixType::local_ordinal_type*,
// Unpack the matrix's scalar / ordinal / node types.
389 using SC =
typename CrsMatrixType::scalar_type;
390 using LO =
typename CrsMatrixType::local_ordinal_type;
391 using GO =
typename CrsMatrixType::global_ordinal_type;
392 using NT =
typename CrsMatrixType::node_type;
// crs_matrix_type is declared on a line not shown.
394 using execution_space =
typename crs_matrix_type::execution_space;
395 using memory_space =
typename crs_matrix_type::device_type::memory_space;
397 using Details::ApplyDirichletBoundaryConditionToLocalMatrixRows;
// runOnHost is true exactly when the matrix's memory space is accessible
// from Kokkos::Serial.
// NOTE(review): this names Kokkos::Serial explicitly, whereas the host kernel
// in the Details helper runs on Kokkos::DefaultHostExecutionSpace — confirm
// Serial is enabled in every supported build.
402 const bool runOnHost = Kokkos::SpaceAccessibility<Kokkos::Serial, memory_space>::accessible;
404 using local_row_indices_type = Kokkos::View<const LO*, Kokkos::AnonymousSpace>;
// Function-definition fragment that applies the column helper and then the
// row helper from Details to A. The function name, most of the parameter
// list, and the code that defines lclRowInds_d and fills dirichletColFlags
// are on lines not shown.
413template <
class CrsMatrixType>
416 typename CrsMatrixType::local_ordinal_type*,
// Unpack the matrix's scalar / ordinal / node types.
418 using SC =
typename CrsMatrixType::scalar_type;
419 using LO =
typename CrsMatrixType::local_ordinal_type;
420 using GO =
typename CrsMatrixType::global_ordinal_type;
421 using NT =
typename CrsMatrixType::node_type;
// crs_matrix_type is declared on a line not shown.
423 using execution_space =
typename crs_matrix_type::execution_space;
424 using memory_space =
typename crs_matrix_type::device_type::memory_space;
// One flag per local column-Map element, allocated in the matrix's memory
// space.
431 Kokkos::View<bool*, memory_space>
dirichletColFlags(
"dirichletColFlags",
A.getColMap()->getLocalNumElements());
// Columns first, then rows; both on a default-constructed execution space
// instance with runOnHost = false.
435 Details::ApplyDirichletBoundaryConditionToLocalMatrixColumns<SC, LO, GO, NT>::run(execution_space(),
A,
dirichletColFlags,
false);
436 Details::ApplyDirichletBoundaryConditionToLocalMatrixRows<SC, LO, GO, NT>::run(execution_space(),
A,
lclRowInds_d,
false);
// Function-definition fragment that applies the column helper and then the
// row helper from Details to A. The function name, most of the parameter
// list, and the code that defines lclRowInds_d and fills dirichletColFlags
// are on lines not shown.
439template <
class CrsMatrixType>
442 typename CrsMatrixType::local_ordinal_type*,
// Unpack the matrix's scalar / ordinal / node types.
444 using SC =
typename CrsMatrixType::scalar_type;
445 using LO =
typename CrsMatrixType::local_ordinal_type;
446 using GO =
typename CrsMatrixType::global_ordinal_type;
447 using NT =
typename CrsMatrixType::node_type;
// crs_matrix_type is declared on a line not shown.
449 using execution_space =
typename crs_matrix_type::execution_space;
450 using memory_space =
typename crs_matrix_type::device_type::memory_space;
// One flag per local column-Map element, allocated in the matrix's memory
// space.
454 Kokkos::View<bool*, memory_space>
dirichletColFlags(
"dirichletColFlags",
A.getColMap()->getLocalNumElements());
// Columns first, then rows; both on a default-constructed execution space
// instance with runOnHost = false.
458 Details::ApplyDirichletBoundaryConditionToLocalMatrixColumns<SC, LO, GO, NT>::run(execution_space(),
A,
dirichletColFlags,
false);
459 Details::ApplyDirichletBoundaryConditionToLocalMatrixRows<SC, LO, GO, NT>::run(execution_space(),
A,
lclRowInds_d,
false);
// Function-definition fragment that uses a caller-supplied execSpace. The
// function name, most of the parameter list, the definition of lclRowInds_d,
// and any use of dirichletColFlags are on lines not shown.
462template <
class CrsMatrixType>
466 typename CrsMatrixType::local_ordinal_type*,
// Unpack the matrix's scalar / ordinal / node types.
468 using SC =
typename CrsMatrixType::scalar_type;
469 using LO =
typename CrsMatrixType::local_ordinal_type;
470 using GO =
typename CrsMatrixType::global_ordinal_type;
471 using NT =
typename CrsMatrixType::node_type;
// crs_matrix_type is declared on a line not shown.
474 using memory_space =
typename crs_matrix_type::device_type::memory_space;
// One flag per local column-Map element, allocated in the matrix's memory
// space.
478 Kokkos::View<bool*, memory_space>
dirichletColFlags(
"dirichletColFlags",
A.getColMap()->getLocalNumElements());
// Row pass on the caller's execution space instance with runOnHost = false.
// The column pass is not visible in this listing (presumably it precedes
// this call — TODO confirm).
483 Details::ApplyDirichletBoundaryConditionToLocalMatrixRows<SC, LO, GO, NT>::run(
execSpace,
A,
lclRowInds_d,
false);
typename device_type::execution_space execution_space
The Kokkos execution space.
typename row_matrix_type::impl_scalar_type impl_scalar_type
The type used internally in place of Scalar.
Struct that holds views of the contents of a CrsMatrix.
Implementation details of Tpetra.
Namespace Tpetra contains the class and methods constituting the Tpetra library.
void applyDirichletBoundaryConditionToLocalMatrixRows(const typename CrsMatrixType::execution_space &execSpace, CrsMatrixType &A, const Kokkos::View< typename CrsMatrixType::local_ordinal_type *, typename CrsMatrixType::device_type > &lclRowInds)
For all k in [0, lclRowInds.extent(0)), set local row lclRowInds[k] of A to have 1 on the diagonal and 0 elsewhere.
@ INSERT
Insert new values that don't currently exist.
void applyDirichletBoundaryConditionToLocalMatrixRowsAndColumns(const typename CrsMatrixType::execution_space &execSpace, CrsMatrixType &A, const Kokkos::View< typename CrsMatrixType::local_ordinal_type *, typename CrsMatrixType::device_type > &lclRowInds)
For all k in [0, lclRowInds.extent(0)), set local row and column lclRowInds[k] of A to have 1 on the diagonal and 0 elsewhere.