10#ifndef TPETRA_MATRIXMATRIX_CUDA_DEF_HPP
11#define TPETRA_MATRIXMATRIX_CUDA_DEF_HPP
13#include "Tpetra_Details_IntRowPtrHelper.hpp"
15#ifdef HAVE_TPETRA_INST_CUDA
// Partial specialization of KernelWrappers for Tpetra::KokkosCompat::KokkosCudaWrapperNode.
// Declares the CUDA-node sparse matrix-matrix multiply (C = A*B) kernel wrappers; their
// out-of-class definitions appear later in this file.
// NOTE(review): this listing is missing source lines. The template parameter list below
// names only Scalar and LocalOrdinalViewType although LocalOrdinal and GlobalOrdinal are
// used in the specialization, and the struct's closing "};" is absent. Restore these
// from the upstream header before building.
21template <
class Scalar,
24 class LocalOrdinalViewType>
25struct KernelWrappers<Scalar, LocalOrdinal, GlobalOrdinal,
Tpetra::KokkosCompat::KokkosCudaWrapperNode, LocalOrdinalViewType> {
  // Builds C = A*B as a new matrix.
  // - B's local rows and imported rows are merged via merge_matrices.
  // - KokkosKernels spgemm_symbolic/spgemm_numeric then run on device.
  // - C is filled with setAllValues and completed with expertStaticFillComplete.
  // Index maps: Acol2Brow/Acol2Irow map A's local columns to rows of B and of the
  // imported rows; Bcol2Ccol/Icol2Ccol map B / imported columns to C's local columns.
  // label is used as the fillComplete "Timer Label"; params may be Teuchos::null.
26 static inline void mult_A_B_newmatrix_kernel_wrapper(CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
27 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
28 const LocalOrdinalViewType& Acol2Brow,
29 const LocalOrdinalViewType& Acol2Irow,
30 const LocalOrdinalViewType& Bcol2Ccol,
31 const LocalOrdinalViewType& Icol2Ccol,
32 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
33 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
34 const std::string& label = std::string(),
35 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
  // Recomputes the values of C = A*B into C's existing sparsity pattern.
  // The definition works on host copies of the matrices and index maps, and throws
  // std::runtime_error if a product falls outside C's existing graph.
37 static inline void mult_A_B_reuse_kernel_wrapper(CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
38 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
39 const LocalOrdinalViewType& Acol2Brow,
40 const LocalOrdinalViewType& Acol2Irow,
41 const LocalOrdinalViewType& Bcol2Ccol,
42 const LocalOrdinalViewType& Icol2Ccol,
43 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
44 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
45 const std::string& label = std::string(),
46 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// Partial specialization of KernelWrappers2 for Tpetra::KokkosCompat::KokkosCudaWrapperNode.
// Declares the CUDA-node Jacobi-fused products; the definitions appear later in this file.
// The reuse definition computes C = B - omega * Dinv * A * B.
// NOTE(review): this listing is missing source lines. The template parameter list below
// omits LocalOrdinal although the specialization uses it, and the struct's closing "};"
// is absent. Restore these from the upstream header before building.
50template <
class Scalar,
52 class GlobalOrdinal,
class LocalOrdinalViewType>
53struct KernelWrappers2<Scalar, LocalOrdinal, GlobalOrdinal,
Tpetra::KokkosCompat::KokkosCudaWrapperNode, LocalOrdinalViewType> {
  // Builds the Jacobi product as a new matrix.
  // Dispatches on params "cuda: jacobi algorithm" (default "KK"):
  //   "MSAK" -> ExtraKernels::jacobi_A_B_newmatrix_MultiplyScaleAddKernel
  //   "KK"   -> jacobi_A_B_newmatrix_KokkosKernels
  // Any other value throws std::runtime_error.
54 static inline void jacobi_A_B_newmatrix_kernel_wrapper(Scalar omega,
55 const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Dinv,
56 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
57 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
58 const LocalOrdinalViewType& Acol2Brow,
59 const LocalOrdinalViewType& Acol2Irow,
60 const LocalOrdinalViewType& Bcol2Ccol,
61 const LocalOrdinalViewType& Icol2Ccol,
62 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
63 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
64 const std::string& label = std::string(),
65 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
  // Recomputes the Jacobi product's values into C's existing sparsity pattern, on host.
67 static inline void jacobi_A_B_reuse_kernel_wrapper(Scalar omega,
68 const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Dinv,
69 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
70 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
71 const LocalOrdinalViewType& Acol2Brow,
72 const LocalOrdinalViewType& Acol2Irow,
73 const LocalOrdinalViewType& Bcol2Ccol,
74 const LocalOrdinalViewType& Icol2Ccol,
75 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
76 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
77 const std::string& label = std::string(),
78 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
  // Builds the Jacobi product as a new matrix using KokkosKernels' fused
  // KokkosSparse::Experimental::spgemm_jacobi. Requires A to have nonzero diagonal entries.
80 static inline void jacobi_A_B_newmatrix_KokkosKernels(Scalar omega,
81 const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Dinv,
82 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
83 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
84 const LocalOrdinalViewType& Acol2Brow,
85 const LocalOrdinalViewType& Acol2Irow,
86 const LocalOrdinalViewType& Bcol2Ccol,
87 const LocalOrdinalViewType& Icol2Ccol,
88 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
89 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
90 const std::string& label = std::string(),
91 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// CUDA-node implementation of C = A*B producing a brand-new matrix C.
// Steps:
//   1. Merge B's local rows and imported rows into one local matrix (Bmerged), indexed
//      by C's local columns.
//   2. Run KokkosSparse::spgemm_symbolic to get C's row offsets and nnz.
//   3. Allocate C's column indices and values, then run spgemm_numeric.
//   4. Optionally sort C's entries, call C.setAllValues, then C.expertStaticFillComplete
//      with B's domain map and A's range map.
// params keys read here (params may be null):
//   "cuda: algorithm"          SpGEMM algorithm name (default "SPGEMM_KK_MEMORY")
//   "cuda: team work size"     team work size (default 16)
//   "Cuda algorithm"           overrides the algorithm name
//   "sort entries"             sort C's entries (default true)
//   "compute global constants" forwarded to fillComplete (default true)
// NOTE(review): several source lines of this function are missing from this listing,
// including the if/else around the two SpGEMM paths and the declarations of KCRS, irph
// and kh.
96template <
class Scalar,
99 class LocalOrdinalViewType>
100void KernelWrappers<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode, LocalOrdinalViewType>::mult_A_B_newmatrix_kernel_wrapper(CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
101 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
102 const LocalOrdinalViewType& Acol2Brow,
103 const LocalOrdinalViewType& Acol2Irow,
104 const LocalOrdinalViewType& Bcol2Ccol,
105 const LocalOrdinalViewType& Icol2Ccol,
106 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
107 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
108 const std::string& label,
109 const Teuchos::RCP<Teuchos::ParameterList>& params) {
115 typedef Tpetra::KokkosCompat::KokkosCudaWrapperNode Node;
  // nodename builds the node-specific parameter key "Cuda algorithm" below.
116 std::string nodename(
"Cuda");
120 typedef typename KCRS::device_type device_t;
121 typedef typename KCRS::StaticCrsGraphType graph_t;
122 typedef typename graph_t::row_map_type::non_const_type lno_view_t;
  // Same layout/space/traits as the row-offset view, but with int entries. Used for the
  // 32-bit row-pointer SpGEMM path.
123 using int_view_t = Kokkos::View<int*, typename lno_view_t::array_layout, typename lno_view_t::memory_space, typename lno_view_t::memory_traits>;
124 typedef typename graph_t::row_map_type::const_type c_lno_view_t;
125 typedef typename graph_t::entries_type::non_const_type lno_nnz_view_t;
126 typedef typename KCRS::values_type::non_const_type scalar_view_t;
  // Tunables read from params: SpGEMM algorithm name and team work size.
130 int team_work_size = 16;
131 std::string myalg(
"SPGEMM_KK_MEMORY");
132 if (!params.is_null()) {
133 if (params->isParameter(
"cuda: algorithm"))
134 myalg = params->get(
"cuda: algorithm", myalg);
135 if (params->isParameter(
"cuda: team work size"))
136 team_work_size = params->get(
"cuda: team work size", team_work_size);
  // KokkosKernels handle keyed on the matrix's native row-offset type. The typedef name
  // (KernelHandle, used below) is on a line missing from this listing. IntKernelHandle is
  // the same handle keyed on int row offsets.
140 typedef KokkosKernels::Experimental::KokkosKernelsHandle<
141 typename lno_view_t::const_value_type,
typename lno_nnz_view_t::const_value_type,
typename scalar_view_t::const_value_type,
142 typename device_t::execution_space,
typename device_t::memory_space,
typename device_t::memory_space>
144 using IntKernelHandle = KokkosKernels::Experimental::KokkosKernelsHandle<
145 typename int_view_t::const_value_type,
typename lno_nnz_view_t::const_value_type,
typename scalar_view_t::const_value_type,
146 typename device_t::execution_space,
typename device_t::memory_space,
typename device_t::memory_space>;
  // Device-resident local matrices of A and B.
149 const KCRS& Amat = Aview.origMatrix->getLocalMatrixDevice();
150 const KCRS& Bmat = Bview.origMatrix->getLocalMatrixDevice();
  // NOTE(review): these unpacked views are not referenced in the lines visible in this
  // listing.
152 c_lno_view_t Arowptr = Amat.graph.row_map,
153 Browptr = Bmat.graph.row_map;
154 const lno_nnz_view_t Acolind = Amat.graph.entries,
155 Bcolind = Bmat.graph.entries;
156 const scalar_view_t Avals = Amat.values,
  // The node-specific key "<nodename> algorithm" ("Cuda algorithm") is read after
  // "cuda: algorithm", so it wins when both are set.
160 std::string alg = nodename + std::string(
" algorithm");
162 if (!params.is_null() && params->isParameter(alg)) myalg = params->get(alg, myalg);
163 KokkosSparse::SPGEMMAlgorithm alg_enum = KokkosSparse::StringToSPGEMMAlgorithm(myalg);
  // Merge B's local rows and imported rows into a single local matrix. Its columns are
  // C's local column indices (sized by C's column map), so SpGEMM output can be handed
  // straight to C.setAllValues.
166 KCRS Bmerged = Tpetra::MMdetails::merge_matrices(Aview, Bview, Acol2Brow, Acol2Irow, Bcol2Ccol, Icol2Ccol, C.getColMap()->getLocalNumElements());
  // With the cuSPARSE TPL enabled on Cuda, sort Bmerged's column indices if they are not
  // already sorted. Presumably the cuSPARSE SpGEMM path needs sorted input -- confirm.
173#if defined(KOKKOS_ENABLE_CUDA) && defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE) && ((CUDA_VERSION < 11000) || (CUDA_VERSION >= 11040))
174 if constexpr (std::is_same_v<typename device_t::execution_space, Kokkos::Cuda>) {
175 if (!KokkosSparse::isCrsGraphSorted(Bmerged.graph.row_map, Bmerged.graph.entries)) {
176 KokkosSparse::sort_crs_matrix(Bmerged);
  // SpGEMM dimensions: C has AnumRows rows and BnumCols columns.
185 typename KernelHandle::nnz_lno_t AnumRows = Amat.numRows();
186 typename KernelHandle::nnz_lno_t BnumRows = Bmerged.numRows();
187 typename KernelHandle::nnz_lno_t BnumCols = Bmerged.numCols();
  // CSR row offsets of C (AnumRows + 1 entries). Entries and values are allocated after
  // the symbolic phase.
190 lno_view_t row_mapC(Kokkos::ViewAllocateWithoutInitializing(
"non_const_lno_row"), AnumRows + 1);
191 lno_nnz_view_t entriesC;
192 scalar_view_t valuesC;
  // Use 32-bit int row pointers only when both irph and A's apply helper allow it.
  // (irph is declared on a line missing from this listing.)
195 const bool useIntRowptrs =
196 irph.shouldUseIntRowptrs() &&
197 Aview.
origMatrix->getApplyHelper()->shouldUseIntRowptrs();
  // Int row-pointer path (presumably the useIntRowptrs branch; the if/else lines are
  // missing). SpGEMM runs on int-rowptr versions of A and Bmerged into int_row_mapC.
201 kh.create_spgemm_handle(alg_enum);
202 kh.set_team_work_size(team_work_size);
204 int_view_t int_row_mapC(Kokkos::ViewAllocateWithoutInitializing(
"non_const_int_row"), AnumRows + 1);
206 auto Aint = Aview.origMatrix->getApplyHelper()->getIntRowptrMatrix(Amat);
207 auto Bint = irph.getIntRowptrMatrix(Bmerged);
211 KokkosSparse::spgemm_symbolic(&kh, AnumRows, BnumRows, BnumCols, Aint.graph.row_map, Aint.graph.entries,
false, Bint.graph.row_map, Bint.graph.entries,
false, int_row_mapC);
  // The symbolic phase fixes nnz(C); size C's column indices and values accordingly.
216 size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
218 entriesC = lno_nnz_view_t(Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
219 valuesC = scalar_view_t(Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
221 KokkosSparse::spgemm_numeric(&kh, AnumRows, BnumRows, BnumCols, Aint.graph.row_map, Aint.graph.entries, Aint.values,
false, Bint.graph.row_map, Bint.graph.entries, Bint.values,
false, int_row_mapC, entriesC, valuesC);
  // Copy the int row offsets element-wise into the native-typed row_mapC.
223 Kokkos::parallel_for(
224 int_row_mapC.size(), KOKKOS_LAMBDA(
int i) { row_mapC(i) = int_row_mapC(i); });
225 kh.destroy_spgemm_handle();
  // Native row-offset path (presumably the else branch): SpGEMM directly on Amat and
  // Bmerged into row_mapC.
229 kh.create_spgemm_handle(alg_enum);
230 kh.set_team_work_size(team_work_size);
234 KokkosSparse::spgemm_symbolic(&kh, AnumRows, BnumRows, BnumCols, Amat.graph.row_map, Amat.graph.entries,
false, Bmerged.graph.row_map, Bmerged.graph.entries,
false, row_mapC);
238 size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
240 entriesC = lno_nnz_view_t(Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
241 valuesC = scalar_view_t(Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
244 KokkosSparse::spgemm_numeric(&kh, AnumRows, BnumRows, BnumCols, Amat.graph.row_map, Amat.graph.entries, Amat.values,
false, Bmerged.graph.row_map, Bmerged.graph.entries, Bmerged.values,
false, row_mapC, entriesC, valuesC);
246 kh.destroy_spgemm_handle();
  // Sort column indices within each row unless params sets "sort entries" to false.
253 if (params.is_null() || params->get(
"sort entries",
true))
254 Import_Util::sortCrsEntries(row_mapC, entriesC, valuesC);
255 C.setAllValues(row_mapC, entriesC, valuesC);
  // Forward the timer label and "compute global constants" to fillComplete. No Export is
  // supplied (dummyExport is a null RCP).
261 RCP<Teuchos::ParameterList> labelList = rcp(
new Teuchos::ParameterList);
262 labelList->set(
"Timer Label", label);
263 if (!params.is_null()) labelList->set(
"compute global constants", params->get(
"compute global constants",
true));
264 RCP<const Export<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode> > dummyExport;
265 C.expertStaticFillComplete(Bview.origMatrix->getDomainMap(), Aview.origMatrix->getRangeMap(), Cimport, dummyExport, labelList);
// CUDA-node "reuse" path for C = A*B: recomputes values into C's existing sparsity
// pattern, entirely on host (host mirrors of the index maps, getLocalMatrixHost views).
// For each row i of A:
//   - c_status records where each column of C's row i lives in Cvals.
//   - Each product A(i,k)*B(k,:) is accumulated into those positions. The B row comes
//     from B's local rows, or from the imported rows when targetMapToOrigRow marks the
//     column invalid.
// Throws std::runtime_error if a product lands outside C's existing row pattern.
// Ends with C.fillComplete using C's own domain and range maps.
// NOTE(review): source lines are missing from this listing, e.g. the Bcol2Ccol/Icol2Ccol
// host-mirror declarations and the Aik, Bkj, Ikj, Ivals declarations. Any zeroing of
// Cvals before accumulation is not visible here -- confirm against upstream.
269template <
class Scalar,
272 class LocalOrdinalViewType>
273void KernelWrappers<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode, LocalOrdinalViewType>::mult_A_B_reuse_kernel_wrapper(
274 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
275 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
276 const LocalOrdinalViewType& targetMapToOrigRow_dev,
277 const LocalOrdinalViewType& targetMapToImportRow_dev,
278 const LocalOrdinalViewType& Bcol2Ccol_dev,
279 const LocalOrdinalViewType& Icol2Ccol_dev,
280 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
281 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
282 const std::string& label,
283 const Teuchos::RCP<Teuchos::ParameterList>& params) {
285 typedef Tpetra::KokkosCompat::KokkosCudaWrapperNode Node;
  // Host-side local matrix types: this kernel runs as a serial host loop.
292 typedef typename Tpetra::CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::local_matrix_host_type KCRS;
293 typedef typename KCRS::StaticCrsGraphType graph_t;
294 typedef typename graph_t::row_map_type::const_type c_lno_view_t;
295 typedef typename graph_t::entries_type::non_const_type lno_nnz_view_t;
296 typedef typename KCRS::values_type::non_const_type scalar_view_t;
299 typedef LocalOrdinal LO;
300 typedef GlobalOrdinal GO;
302 typedef Map<LO, GO, NO> map_type;
  // ST_INVALID is LO's invalid ordinal converted to size_t; it marks unused c_status slots.
303 const size_t ST_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
304 const LO LO_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
305 const SC SC_ZERO = Teuchos::ScalarTraits<Scalar>::zero();
  // Copy the device index maps to host so the loop below can read them.
313 auto targetMapToOrigRow =
314 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
315 targetMapToOrigRow_dev);
316 auto targetMapToImportRow =
317 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
318 targetMapToImportRow_dev);
320 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
323 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
  // m = local rows of A (and of C); n = local columns of C.
327 RCP<const map_type> Ccolmap = C.getColMap();
328 size_t m = Aview.origMatrix->getLocalNumRows();
329 size_t n = Ccolmap->getLocalNumElements();
332 const KCRS& Amat = Aview.origMatrix->getLocalMatrixHost();
333 const KCRS& Bmat = Bview.origMatrix->getLocalMatrixHost();
334 const KCRS& Cmat = C.getLocalMatrixHost();
336 c_lno_view_t Arowptr = Amat.graph.row_map,
337 Browptr = Bmat.graph.row_map,
338 Crowptr = Cmat.graph.row_map;
339 const lno_nnz_view_t Acolind = Amat.graph.entries,
340 Bcolind = Bmat.graph.entries,
341 Ccolind = Cmat.graph.entries;
342 const scalar_view_t Avals = Amat.values, Bvals = Bmat.values;
  // Cvals aliases C's host values, so updates below write directly into C.
343 scalar_view_t Cvals = Cmat.values;
  // Imported-row graph; populated only when B has an import matrix.
345 c_lno_view_t Irowptr;
346 lno_nnz_view_t Icolind;
348 if (!Bview.importMatrix.is_null()) {
349 auto lclB = Bview.importMatrix->getLocalMatrixHost();
350 Irowptr = lclB.graph.row_map;
351 Icolind = lclB.graph.entries;
  // c_status[col] = index into Cvals of column col in the current row of C.
  // Values left from earlier rows are < OLD_ip, so the range check below rejects them.
364 std::vector<size_t> c_status(n, ST_INVALID);
  // [OLD_ip, CSR_ip) is the Cvals range of the current row i.
367 size_t CSR_ip = 0, OLD_ip = 0;
368 for (
size_t i = 0; i < m; i++) {
372 CSR_ip = Crowptr[i + 1];
373 for (
size_t k = OLD_ip; k < CSR_ip; k++) {
374 c_status[Ccolind[k]] = k;
380 for (
size_t k = Arowptr[i]; k < Arowptr[i + 1]; k++) {
382 const SC Aval = Avals[k];
  // A's column maps to a locally owned row Bk of B.
386 if (targetMapToOrigRow[Aik] != LO_INVALID) {
388 size_t Bk = Teuchos::as<size_t>(targetMapToOrigRow[Aik]);
390 for (
size_t j = Browptr[Bk]; j < Browptr[Bk + 1]; ++j) {
392 LO Cij = Bcol2Ccol[Bkj];
394 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
395 std::runtime_error,
"Trying to insert a new entry (" << i <<
"," << Cij <<
") into a static graph "
396 <<
"(c_status = " << c_status[Cij] <<
" of [" << OLD_ip <<
"," << CSR_ip <<
"))");
398 Cvals[c_status[Cij]] += Aval * Bvals[j];
  // Otherwise A's column maps to imported row Ik (the "else" line is missing here).
403 size_t Ik = Teuchos::as<size_t>(targetMapToImportRow[Aik]);
404 for (
size_t j = Irowptr[Ik]; j < Irowptr[Ik + 1]; ++j) {
406 LO Cij = Icol2Ccol[Ikj];
408 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
409 std::runtime_error,
"Trying to insert a new entry (" << i <<
"," << Cij <<
") into a static graph "
410 <<
"(c_status = " << c_status[Cij] <<
" of [" << OLD_ip <<
"," << CSR_ip <<
"))");
412 Cvals[c_status[Cij]] += Aval * Ivals[j];
  // Re-complete C with its existing domain and range maps.
418 C.fillComplete(C.getDomainMap(), C.getRangeMap());
// CUDA-node entry point for the new-matrix Jacobi product.
// Selects the kernel from params "cuda: jacobi algorithm" (default "KK"):
//   "MSAK" -> ::Tpetra::MatrixMatrix::ExtraKernels::jacobi_A_B_newmatrix_MultiplyScaleAddKernel
//   "KK"   -> jacobi_A_B_newmatrix_KokkosKernels
// Any other value throws std::runtime_error.
// Afterwards, if C is not yet fill-complete, calls C.expertStaticFillComplete with B's
// domain map, A's range map, Cimport and a parameter list holding the timer label.
// NOTE(review): the "else" line before the throw is missing from this listing.
422template <
class Scalar,
425 class LocalOrdinalViewType>
426void KernelWrappers2<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode, LocalOrdinalViewType>::jacobi_A_B_newmatrix_kernel_wrapper(Scalar omega,
427 const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Dinv,
428 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
429 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
430 const LocalOrdinalViewType& Acol2Brow,
431 const LocalOrdinalViewType& Acol2Irow,
432 const LocalOrdinalViewType& Bcol2Ccol,
433 const LocalOrdinalViewType& Icol2Ccol,
434 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
435 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
436 const std::string& label,
437 const Teuchos::RCP<Teuchos::ParameterList>& params) {
  // Kernel choice; defaults to the KokkosKernels fused path.
445 std::string myalg(
"KK");
446 if (!params.is_null()) {
447 if (params->isParameter(
"cuda: jacobi algorithm"))
448 myalg = params->get(
"cuda: jacobi algorithm", myalg);
451 if (myalg ==
"MSAK") {
452 ::Tpetra::MatrixMatrix::ExtraKernels::jacobi_A_B_newmatrix_MultiplyScaleAddKernel(omega, Dinv, Aview, Bview, Acol2Brow, Acol2Irow, Bcol2Ccol, Icol2Ccol, C, Cimport, label, params);
453 }
else if (myalg ==
"KK") {
454 jacobi_A_B_newmatrix_KokkosKernels(omega, Dinv, Aview, Bview, Acol2Brow, Acol2Irow, Bcol2Ccol, Icol2Ccol, C, Cimport, label, params);
456 throw std::runtime_error(
"Tpetra::MatrixMatrix::Jacobi newmatrix unknown kernel");
  // Forward the timer label and "compute global constants" (default true) to fillComplete.
463 RCP<Teuchos::ParameterList> labelList = rcp(
new Teuchos::ParameterList);
464 labelList->set(
"Timer Label", label);
465 if (!params.is_null()) labelList->set(
"compute global constants", params->get(
"compute global constants",
true));
  // Only fill-complete here if the selected kernel has not already done so.
468 if (!C.isFillComplete()) {
469 RCP<const Export<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode> > dummyExport;
470 C.expertStaticFillComplete(Bview.origMatrix->getDomainMap(), Aview.origMatrix->getRangeMap(), Cimport, dummyExport, labelList);
// CUDA-node "reuse" path for the Jacobi product.
// Recomputes C = B - omega * Dinv * A * B into C's existing sparsity pattern, on host.
// For each row i:
//   - Row i of B is first written into C.
//   - Then -omega * Dinv(i) * A(i,k) * B(k,:) is accumulated for every entry A(i,k).
//     B(k,:) comes from B's local rows, or from the imported rows when
//     targetMapToOrigRow marks the column invalid.
// Throws std::runtime_error if any entry falls outside C's existing row pattern.
// Ends with C.fillComplete using C's own domain and range maps.
// NOTE(review): source lines are missing from this listing, e.g. the index-map mirror
// declarations, Dvals, Aik, Bij, Bkj, Ikj and Ivals.
475template <
class Scalar,
478 class LocalOrdinalViewType>
479void KernelWrappers2<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode, LocalOrdinalViewType>::jacobi_A_B_reuse_kernel_wrapper(Scalar omega,
480 const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Dinv,
481 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
482 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
483 const LocalOrdinalViewType& targetMapToOrigRow_dev,
484 const LocalOrdinalViewType& targetMapToImportRow_dev,
485 const LocalOrdinalViewType& Bcol2Ccol_dev,
486 const LocalOrdinalViewType& Icol2Ccol_dev,
487 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
488 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
489 const std::string& label,
490 const Teuchos::RCP<Teuchos::ParameterList>& params) {
492 typedef Tpetra::KokkosCompat::KokkosCudaWrapperNode Node;
  // Host-side local matrix types: this kernel runs as a serial host loop.
499 typedef typename Tpetra::CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::local_matrix_host_type KCRS;
500 typedef typename KCRS::StaticCrsGraphType graph_t;
501 typedef typename graph_t::row_map_type::const_type c_lno_view_t;
502 typedef typename graph_t::entries_type::non_const_type lno_nnz_view_t;
503 typedef typename KCRS::values_type::non_const_type scalar_view_t;
  // Memory space of the host values; used to fetch Dinv's local view in the same space.
504 typedef typename scalar_view_t::memory_space scalar_memory_space;
507 typedef LocalOrdinal LO;
508 typedef GlobalOrdinal GO;
510 typedef Map<LO, GO, NO> map_type;
  // ST_INVALID is LO's invalid ordinal converted to size_t; it marks unused c_status slots.
511 const size_t ST_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
512 const LO LO_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
513 const SC SC_ZERO = Teuchos::ScalarTraits<Scalar>::zero();
  // Copy the device index maps to host so the loop below can read them.
521 auto targetMapToOrigRow =
522 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
523 targetMapToOrigRow_dev);
524 auto targetMapToImportRow =
525 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
526 targetMapToImportRow_dev);
528 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
531 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
  // m = local rows of A (and of C); n = local columns of C.
535 RCP<const map_type> Ccolmap = C.getColMap();
536 size_t m = Aview.origMatrix->getLocalNumRows();
537 size_t n = Ccolmap->getLocalNumElements();
540 const KCRS& Amat = Aview.origMatrix->getLocalMatrixHost();
541 const KCRS& Bmat = Bview.origMatrix->getLocalMatrixHost();
542 const KCRS& Cmat = C.getLocalMatrixHost();
544 c_lno_view_t Arowptr = Amat.graph.row_map, Browptr = Bmat.graph.row_map, Crowptr = Cmat.graph.row_map;
545 const lno_nnz_view_t Acolind = Amat.graph.entries, Bcolind = Bmat.graph.entries, Ccolind = Cmat.graph.entries;
546 const scalar_view_t Avals = Amat.values, Bvals = Bmat.values;
  // Cvals aliases C's host values, so updates below write directly into C.
547 scalar_view_t Cvals = Cmat.values;
  // Imported-row graph; populated only when B has an import matrix.
549 c_lno_view_t Irowptr;
550 lno_nnz_view_t Icolind;
552 if (!Bview.importMatrix.is_null()) {
553 auto lclB = Bview.importMatrix->getLocalMatrixHost();
554 Irowptr = lclB.graph.row_map;
555 Icolind = lclB.graph.entries;
  // Read-only local view of Dinv (the declaration's left-hand side is a missing line).
561 Dinv.template getLocalView<scalar_memory_space>(Access::ReadOnly);
  // c_status[col] = index into Cvals of column col in the current row of C.
568 std::vector<size_t> c_status(n, ST_INVALID);
  // [OLD_ip, CSR_ip) is the Cvals range of the current row i.
571 size_t CSR_ip = 0, OLD_ip = 0;
572 for (
size_t i = 0; i < m; i++) {
576 CSR_ip = Crowptr[i + 1];
577 for (
size_t k = OLD_ip; k < CSR_ip; k++) {
578 c_status[Ccolind[k]] = k;
  // Row scale applied to every A*B contribution in row i.
584 SC minusOmegaDval = -omega * Dvals(i, 0);
  // Write row i of B into C (assignment, not accumulation).
587 for (
size_t j = Browptr[i]; j < Browptr[i + 1]; j++) {
588 Scalar Bval = Bvals[j];
592 LO Cij = Bcol2Ccol[Bij];
594 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
595 std::runtime_error,
"Trying to insert a new entry into a static graph");
597 Cvals[c_status[Cij]] = Bvals[j];
  // Accumulate -omega * Dinv(i) * A(i,k) * B(k,:).
601 for (
size_t k = Arowptr[i]; k < Arowptr[i + 1]; k++) {
603 const SC Aval = Avals[k];
  // A's column maps to a locally owned row Bk of B.
607 if (targetMapToOrigRow[Aik] != LO_INVALID) {
609 size_t Bk = Teuchos::as<size_t>(targetMapToOrigRow[Aik]);
611 for (
size_t j = Browptr[Bk]; j < Browptr[Bk + 1]; ++j) {
613 LO Cij = Bcol2Ccol[Bkj];
615 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
616 std::runtime_error,
"Trying to insert a new entry into a static graph");
618 Cvals[c_status[Cij]] += minusOmegaDval * Aval * Bvals[j];
  // Otherwise A's column maps to imported row Ik (the "else" line is missing here).
623 size_t Ik = Teuchos::as<size_t>(targetMapToImportRow[Aik]);
624 for (
size_t j = Irowptr[Ik]; j < Irowptr[Ik + 1]; ++j) {
626 LO Cij = Icol2Ccol[Ikj];
628 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
629 std::runtime_error,
"Trying to insert a new entry into a static graph");
631 Cvals[c_status[Cij]] += minusOmegaDval * Aval * Ivals[j];
  // Re-complete C with its existing domain and range maps.
640 C.fillComplete(C.getDomainMap(), C.getRangeMap());
// Builds the Jacobi product as a new matrix C using KokkosKernels' fused
// KokkosSparse::Experimental::spgemm_jacobi on device.
// Steps:
//   1. Copy A's diagonal and throw if any entry is zero (the fused kernel needs nonzero
//      diagonal entries).
//   2. Merge B's local and imported rows (merge_matrices).
//   3. Run spgemm_symbolic, then spgemm_jacobi with omega and Dinv's device view.
//   4. Optionally sort C's entries, call C.setAllValues, then C.expertStaticFillComplete.
// params keys read here (params may be null):
//   "cuda: algorithm", "cuda: team work size", "Cuda algorithm",
//   "sort entries", "compute global constants"
// NOTE(review): source lines are missing from this listing, e.g. the declarations of
// diags, matrix_t, irph and kh, the exception type on the diagonal check, the
// output-row-map argument of both spgemm_symbolic calls, and the if/else around the two
// SpGEMM paths.
644template <
class Scalar,
647 class LocalOrdinalViewType>
648void KernelWrappers2<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode, LocalOrdinalViewType>::jacobi_A_B_newmatrix_KokkosKernels(Scalar omega,
649 const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Dinv,
650 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Aview,
651 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& Bview,
652 const LocalOrdinalViewType& Acol2Brow,
653 const LocalOrdinalViewType& Acol2Irow,
654 const LocalOrdinalViewType& Bcol2Ccol,
655 const LocalOrdinalViewType& Icol2Ccol,
656 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode>& C,
657 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode> > Cimport,
658 const std::string& label,
659 const Teuchos::RCP<Teuchos::ParameterList>& params) {
  // Copy A's local diagonal to a host array and reject any zero (or missing) entry.
663 auto rowMap = Aview.origMatrix->getRowMap();
665 Aview.origMatrix->getLocalDiagCopy(diags);
666 size_t diagLength = rowMap->getLocalNumElements();
667 Teuchos::Array<Scalar> diagonal(diagLength);
668 diags.get1dCopy(diagonal());
670 for (
size_t i = 0; i < diagLength; ++i) {
671 TEUCHOS_TEST_FOR_EXCEPTION(diagonal[i] == Teuchos::ScalarTraits<Scalar>::zero(),
673 "Matrix A has a zero/missing diagonal: " << diagonal[i] << std::endl
674 <<
"KokkosKernels Jacobi-fused SpGEMM requires nonzero diagonal entries in A" << std::endl);
683 using device_t =
typename Tpetra::KokkosCompat::KokkosCudaWrapperNode::device_type;
685 using graph_t =
typename matrix_t::StaticCrsGraphType;
686 using lno_view_t =
typename graph_t::row_map_type::non_const_type;
  // Row-offset view with int entries, for the 32-bit row-pointer SpGEMM path.
687 using int_view_t = Kokkos::View<
int*,
688 typename lno_view_t::array_layout,
689 typename lno_view_t::memory_space,
690 typename lno_view_t::memory_traits>;
691 using c_lno_view_t =
typename graph_t::row_map_type::const_type;
692 using lno_nnz_view_t =
typename graph_t::entries_type::non_const_type;
693 using scalar_view_t =
typename matrix_t::values_type::non_const_type;
  // KokkosKernels handles keyed on native (handle_t) and int (int_handle_t) row offsets.
696 using handle_t =
typename KokkosKernels::Experimental::KokkosKernelsHandle<
697 typename lno_view_t::const_value_type,
typename lno_nnz_view_t::const_value_type,
typename scalar_view_t::const_value_type,
698 typename device_t::execution_space,
typename device_t::memory_space,
typename device_t::memory_space>;
700 using int_handle_t =
typename KokkosKernels::Experimental::KokkosKernelsHandle<
701 typename int_view_t::const_value_type,
typename lno_nnz_view_t::const_value_type,
typename scalar_view_t::const_value_type,
702 typename device_t::execution_space,
typename device_t::memory_space,
typename device_t::memory_space>;
  // Merge B's local and imported rows into one local matrix indexed by C's local columns.
705 const matrix_t Bmerged = Tpetra::MMdetails::merge_matrices(Aview, Bview, Acol2Brow, Acol2Irow, Bcol2Ccol, Icol2Ccol, C.getColMap()->getLocalNumElements());
  // NOTE(review): Bmat is not referenced in the lines visible in this listing.
708 const matrix_t& Amat = Aview.origMatrix->getLocalMatrixDevice();
709 const matrix_t& Bmat = Bview.origMatrix->getLocalMatrixDevice();
711 typename handle_t::nnz_lno_t AnumRows = Amat.numRows();
712 typename handle_t::nnz_lno_t BnumRows = Bmerged.numRows();
713 typename handle_t::nnz_lno_t BnumCols = Bmerged.numCols();
  // CSR row offsets of C (AnumRows + 1 entries). Entries and values are allocated after
  // the symbolic phase.
716 lno_view_t row_mapC(Kokkos::ViewAllocateWithoutInitializing(
"row_mapC"), AnumRows + 1);
717 lno_nnz_view_t entriesC;
718 scalar_view_t valuesC;
  // Tunables: SpGEMM algorithm name (default "SPGEMM_KK_MEMORY") and team work size
  // (default 16).
721 int team_work_size = 16;
722 std::string myalg(
"SPGEMM_KK_MEMORY");
723 if (!params.is_null()) {
724 if (params->isParameter(
"cuda: algorithm"))
725 myalg = params->get(
"cuda: algorithm", myalg);
726 if (params->isParameter(
"cuda: team work size"))
727 team_work_size = params->get(
"cuda: team work size", team_work_size);
  // "Cuda algorithm" overrides "cuda: algorithm". This key is a literal here;
  // mult_A_B_newmatrix_kernel_wrapper builds the same key from nodename.
731 std::string alg(
"Cuda algorithm");
732 if (!params.is_null() && params->isParameter(alg)) myalg = params->get(alg, myalg);
733 KokkosSparse::SPGEMMAlgorithm alg_enum = KokkosSparse::StringToSPGEMMAlgorithm(myalg);
  // Use 32-bit int row pointers only when both irph and A's apply helper allow it.
737 const bool useIntRowptrs =
738 irph.shouldUseIntRowptrs() &&
739 Aview.
origMatrix->getApplyHelper()->shouldUseIntRowptrs();
  // Int row-pointer path (presumably the useIntRowptrs branch).
743 kh.create_spgemm_handle(alg_enum);
744 kh.set_team_work_size(team_work_size);
746 int_view_t int_row_mapC(Kokkos::ViewAllocateWithoutInitializing(
"int_row_mapC"), AnumRows + 1);
748 auto Aint = Aview.origMatrix->getApplyHelper()->getIntRowptrMatrix(Amat);
749 auto Bint = irph.getIntRowptrMatrix(Bmerged);
753 KokkosSparse::spgemm_symbolic(&kh, AnumRows, BnumRows, BnumCols,
754 Aint.graph.row_map, Aint.graph.entries,
false,
755 Bint.graph.row_map, Bint.graph.entries,
false,
  // The symbolic phase fixes nnz(C); size C's column indices and values accordingly.
759 size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
761 entriesC = lno_nnz_view_t(Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
762 valuesC = scalar_view_t(Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
  // Fused numeric phase with omega and Dinv's device view.
  // NOTE(review): A's values are passed as Amat.values alongside Aint's graph. Presumably
  // Aint shares A's values -- confirm.
768 KokkosSparse::Experimental::spgemm_jacobi(&kh, AnumRows, BnumRows, BnumCols,
769 Aint.graph.row_map, Aint.graph.entries, Amat.values,
false,
770 Bint.graph.row_map, Bint.graph.entries, Bint.values,
false,
771 int_row_mapC, entriesC, valuesC,
772 omega, Dinv.getLocalViewDevice(Access::ReadOnly));
  // Copy the int row offsets element-wise into the native-typed row_mapC.
774 Kokkos::parallel_for(
775 int_row_mapC.size(), KOKKOS_LAMBDA(
int i) { row_mapC(i) = int_row_mapC(i); });
776 kh.destroy_spgemm_handle();
  // Native row-offset path (presumably the else branch).
779 kh.create_spgemm_handle(alg_enum);
780 kh.set_team_work_size(team_work_size);
784 KokkosSparse::spgemm_symbolic(&kh, AnumRows, BnumRows, BnumCols,
785 Amat.graph.row_map, Amat.graph.entries,
false,
786 Bmerged.graph.row_map, Bmerged.graph.entries,
false,
790 size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
792 entriesC = lno_nnz_view_t(Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
793 valuesC = scalar_view_t(Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
797 KokkosSparse::Experimental::spgemm_jacobi(&kh, AnumRows, BnumRows, BnumCols,
798 Amat.graph.row_map, Amat.graph.entries, Amat.values,
false,
799 Bmerged.graph.row_map, Bmerged.graph.entries, Bmerged.values,
false,
800 row_mapC, entriesC, valuesC,
801 omega, Dinv.getLocalViewDevice(Access::ReadOnly));
802 kh.destroy_spgemm_handle();
  // Sort column indices within each row unless params sets "sort entries" to false.
809 if (params.is_null() || params->get(
"sort entries",
true))
810 Import_Util::sortCrsEntries(row_mapC, entriesC, valuesC);
811 C.setAllValues(row_mapC, entriesC, valuesC);
  // Forward the timer label and "compute global constants" (default true) to fillComplete.
817 Teuchos::RCP<Teuchos::ParameterList> labelList = rcp(
new Teuchos::ParameterList);
818 labelList->set(
"Timer Label", label);
819 if (!params.is_null()) labelList->set(
"compute global constants", params->get(
"compute global constants",
true));
820 Teuchos::RCP<const Export<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosCudaWrapperNode> > dummyExport;
821 C.expertStaticFillComplete(Bview.origMatrix->getDomainMap(), Aview.origMatrix->getRangeMap(), Cimport, dummyExport, labelList);
Struct that holds views of the contents of a CrsMatrix.
Teuchos::RCP< const CrsMatrix< Scalar, LocalOrdinal, GlobalOrdinal, Node > > origMatrix
The original matrix.
KokkosSparse::CrsMatrix< impl_scalar_type, local_ordinal_type, device_type, void, typename local_graph_device_type::size_type > local_matrix_device_type
The specialization of KokkosSparse::CrsMatrix that represents the part of the sparse matrix on each MPI process.
static bool debug()
Whether Tpetra is in debug mode.
Namespace Tpetra contains the class and methods constituting the Tpetra library.