10#ifndef TPETRA_MATRIXMATRIX_HIP_DEF_HPP
11#define TPETRA_MATRIXMATRIX_HIP_DEF_HPP
13#include "Tpetra_Details_IntRowPtrHelper.hpp"
15#ifdef HAVE_TPETRA_INST_HIP
// Partial specialization of KernelWrappers for Tpetra's HIP wrapper node.
// Declares the two sparse matrix-matrix multiply (C = A*B) kernels used on HIP:
//   mult_A_B_newmatrix_kernel_wrapper - builds C's graph and values with KokkosKernels
//                                       SpGEMM, then fill-completes C.
//   mult_A_B_reuse_kernel_wrapper     - recomputes C's values on host into C's existing
//                                       (static) graph, then calls C.fillComplete.
// Both take: views of A and B; four LocalOrdinalViewType index maps (as the reuse
// definitions index them: A-column -> local B row, A-column -> imported row,
// B-column -> C-column, imported-column -> C-column); the output matrix C; C's importer;
// an optional timer label; and an optional ParameterList.
// NOTE(review): this listing appears to be missing the LocalOrdinal/GlobalOrdinal template
// parameter lines and the struct's closing "};" -- confirm against the full file.
21template <
class Scalar,
24 class LocalOrdinalViewType>
25struct KernelWrappers<Scalar, LocalOrdinal, GlobalOrdinal,
Tpetra::KokkosCompat::KokkosHIPWrapperNode, LocalOrdinalViewType> {
// Computes C = A*B from scratch (structure + values) on device.
26 static inline void mult_A_B_newmatrix_kernel_wrapper(CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
27 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
28 const LocalOrdinalViewType& Acol2Brow,
29 const LocalOrdinalViewType& Acol2Irow,
30 const LocalOrdinalViewType& Bcol2Ccol,
31 const LocalOrdinalViewType& Icol2Ccol,
32 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
33 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
34 const std::string& label = std::string(),
35 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// Recomputes the values of C = A*B into C's already-existing graph.
37 static inline void mult_A_B_reuse_kernel_wrapper(CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
38 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
39 const LocalOrdinalViewType& Acol2Brow,
40 const LocalOrdinalViewType& Acol2Irow,
41 const LocalOrdinalViewType& Bcol2Ccol,
42 const LocalOrdinalViewType& Icol2Ccol,
43 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
44 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
45 const std::string& label = std::string(),
46 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// Partial specialization of KernelWrappers2 for Tpetra's HIP wrapper node.
// Declares the Jacobi-fused products. The reuse kernel in this file computes
// C = B - omega * Dinv * A * B row by row.
//   jacobi_A_B_newmatrix_kernel_wrapper - builds C from scratch; dispatches to an
//                                         algorithm chosen by "hip: jacobi algorithm".
//   jacobi_A_B_reuse_kernel_wrapper     - recomputes C's values into its existing graph.
//   jacobi_A_B_newmatrix_KokkosKernels  - KokkosKernels spgemm_jacobi implementation used
//                                         by the new-matrix wrapper for the "KK" algorithm.
// omega is the Jacobi damping factor and Dinv holds the per-row scaling values. The
// remaining arguments match KernelWrappers::mult_A_B_*.
// NOTE(review): the "class LocalOrdinal," template parameter line and the closing "};"
// appear to be missing from this listing -- confirm against the full file.
50template <
class Scalar,
52 class GlobalOrdinal,
class LocalOrdinalViewType>
53struct KernelWrappers2<Scalar, LocalOrdinal, GlobalOrdinal,
Tpetra::KokkosCompat::KokkosHIPWrapperNode, LocalOrdinalViewType> {
// Builds C = Jacobi-fused product of A and B as a new matrix.
54 static inline void jacobi_A_B_newmatrix_kernel_wrapper(Scalar omega,
55 const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Dinv,
56 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
57 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
58 const LocalOrdinalViewType& Acol2Brow,
59 const LocalOrdinalViewType& Acol2Irow,
60 const LocalOrdinalViewType& Bcol2Ccol,
61 const LocalOrdinalViewType& Icol2Ccol,
62 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
63 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
64 const std::string& label = std::string(),
65 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// Recomputes the Jacobi-fused product's values into C's existing graph.
67 static inline void jacobi_A_B_reuse_kernel_wrapper(Scalar omega,
68 const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Dinv,
69 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
70 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
71 const LocalOrdinalViewType& Acol2Brow,
72 const LocalOrdinalViewType& Acol2Irow,
73 const LocalOrdinalViewType& Bcol2Ccol,
74 const LocalOrdinalViewType& Icol2Ccol,
75 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
76 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
77 const std::string& label = std::string(),
78 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// KokkosKernels (spgemm_jacobi) implementation of the new-matrix Jacobi product.
80 static inline void jacobi_A_B_newmatrix_KokkosKernels(Scalar omega,
81 const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Dinv,
82 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
83 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
84 const LocalOrdinalViewType& Acol2Brow,
85 const LocalOrdinalViewType& Acol2Irow,
86 const LocalOrdinalViewType& Bcol2Ccol,
87 const LocalOrdinalViewType& Icol2Ccol,
88 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
89 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
90 const std::string& label = std::string(),
91 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// C = A*B, new-matrix path for the HIP node.
// Merges B (and its imported rows) via merge_matrices and runs KokkosKernels SpGEMM
// (symbolic, then numeric) on device. It then optionally sorts C's column indices,
// installs the arrays into C with setAllValues, and finishes with
// expertStaticFillComplete(B's domain map, A's range map).
// Parameters:
//   Aview, Bview - views of A and of B; Bview may carry an import matrix.
//   Acol2Brow, Acol2Irow, Bcol2Ccol, Icol2Ccol - index maps forwarded to merge_matrices.
//   C            - output matrix; its column map must already exist (getColMap() is read).
//   Cimport      - importer forwarded to expertStaticFillComplete.
//   label        - stored as "Timer Label" in the fill-complete parameter list.
//   params       - optional; keys read: "hip: algorithm", "hip: team work size",
//                  "HIP algorithm", "sort entries", "compute global constants".
// NOTE(review): several original lines are missing from this listing. These include the
// LocalOrdinal/GlobalOrdinal template lines, the KCRS typedef, the KernelHandle typedef name,
// the kh/irph declarations, the if/else braces and the closing brace.
96template <
class Scalar,
99 class LocalOrdinalViewType>
100void KernelWrappers<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode, LocalOrdinalViewType>::mult_A_B_newmatrix_kernel_wrapper(CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
101 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
102 const LocalOrdinalViewType& Acol2Brow,
103 const LocalOrdinalViewType& Acol2Irow,
104 const LocalOrdinalViewType& Bcol2Ccol,
105 const LocalOrdinalViewType& Icol2Ccol,
106 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
107 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
108 const std::string& label,
109 const Teuchos::RCP<Teuchos::ParameterList>& params) {
111 typedef Tpetra::KokkosCompat::KokkosHIPWrapperNode Node;
// nodename forms the "<nodename> algorithm" parameter key read below.
112 std::string nodename(
"HIP");
// NOTE(review): the KCRS typedef is not visible here; presumably it is the node's
// device local matrix type -- confirm.
// int_view_t mirrors the row-pointer view's layout/space/traits but stores int offsets.
118 typedef typename KCRS::device_type device_t;
119 typedef typename KCRS::StaticCrsGraphType graph_t;
120 typedef typename graph_t::row_map_type::non_const_type lno_view_t;
121 using int_view_t = Kokkos::View<int*, typename lno_view_t::array_layout, typename lno_view_t::memory_space, typename lno_view_t::memory_traits>;
122 typedef typename graph_t::row_map_type::const_type c_lno_view_t;
123 typedef typename graph_t::entries_type::non_const_type lno_nnz_view_t;
124 typedef typename KCRS::values_type::non_const_type scalar_view_t;
// Defaults, overridable via "hip: algorithm" and "hip: team work size".
130 int team_work_size = 16;
131 std::string myalg(
"SPGEMM_KK_MEMORY");
132 if (!params.is_null()) {
133 if (params->isParameter(
"hip: algorithm"))
134 myalg = params->get(
"hip: algorithm", myalg);
135 if (params->isParameter(
"hip: team work size"))
136 team_work_size = params->get(
"hip: team work size", team_work_size);
// KokkosKernels handle keyed on the native row-pointer type. The alias name
// (KernelHandle, used below) is on a line missing from this listing.
140 typedef KokkosKernels::Experimental::KokkosKernelsHandle<
141 typename lno_view_t::const_value_type,
typename lno_nnz_view_t::const_value_type,
typename scalar_view_t::const_value_type,
142 typename device_t::execution_space,
typename device_t::memory_space,
typename device_t::memory_space>
// Same handle, keyed on int row pointers (used when useIntRowptrs is true).
144 using IntKernelHandle = KokkosKernels::Experimental::KokkosKernelsHandle<
145 typename int_view_t::const_value_type,
typename lno_nnz_view_t::const_value_type,
typename scalar_view_t::const_value_type,
146 typename device_t::execution_space,
typename device_t::memory_space,
typename device_t::memory_space>;
// Device local matrices of A and B.
149 const KCRS& Amat = Aview.origMatrix->getLocalMatrixDevice();
150 const KCRS& Bmat = Bview.origMatrix->getLocalMatrixDevice();
152 c_lno_view_t Arowptr = Amat.graph.row_map,
153 Browptr = Bmat.graph.row_map;
154 const lno_nnz_view_t Acolind = Amat.graph.entries,
155 Bcolind = Bmat.graph.entries;
156 const scalar_view_t Avals = Amat.values,
// A "HIP algorithm" parameter, when present, overrides "hip: algorithm".
160 std::string alg = nodename + std::string(
" algorithm");
162 if (!params.is_null() && params->isParameter(alg)) myalg = params->get(alg, myalg);
163 KokkosSparse::SPGEMMAlgorithm alg_enum = KokkosSparse::StringToSPGEMMAlgorithm(myalg);
// Presumably combines B's local rows and imported rows into one matrix whose columns
// are already in C's column space -- see merge_matrices to confirm.
166 const KCRS Bmerged = Tpetra::MMdetails::merge_matrices(Aview, Bview, Acol2Brow, Acol2Irow, Bcol2Ccol, Icol2Ccol, C.getColMap()->getLocalNumElements());
// SpGEMM dimensions passed to spgemm_symbolic/spgemm_numeric.
172 typename KernelHandle::nnz_lno_t AnumRows = Amat.numRows();
173 typename KernelHandle::nnz_lno_t BnumRows = Bmerged.numRows();
174 typename KernelHandle::nnz_lno_t BnumCols = Bmerged.numCols();
// C's row pointers (AnumRows + 1). entriesC/valuesC are sized after the symbolic phase.
177 lno_view_t row_mapC(Kokkos::ViewAllocateWithoutInitializing(
"non_const_lno_row"), AnumRows + 1);
178 lno_nnz_view_t entriesC;
179 scalar_view_t valuesC;
// Use int row pointers only if both irph (declaration not visible here) and A's apply
// helper allow it.
182 const bool useIntRowptrs =
183 irph.shouldUseIntRowptrs() &&
184 Aview.
origMatrix->getApplyHelper()->shouldUseIntRowptrs();
// Int-row-pointer path. The enclosing "if (useIntRowptrs) {" and the IntKernelHandle kh
// declaration appear to be missing from this listing.
188 kh.create_spgemm_handle(alg_enum);
189 kh.set_team_work_size(team_work_size);
191 int_view_t int_row_mapC(Kokkos::ViewAllocateWithoutInitializing(
"non_const_int_row"), AnumRows + 1);
193 auto Aint = Aview.origMatrix->getApplyHelper()->getIntRowptrMatrix(Amat);
194 auto Bint = irph.getIntRowptrMatrix(Bmerged);
// Symbolic phase: computes C's int row pointers and nonzero count.
198 KokkosSparse::spgemm_symbolic(&kh, AnumRows, BnumRows, BnumCols, Aint.graph.row_map, Aint.graph.entries,
false, Bint.graph.row_map, Bint.graph.entries,
false, int_row_mapC);
202 size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
204 entriesC = lno_nnz_view_t(Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
205 valuesC = scalar_view_t(Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
// Numeric phase: fills entriesC/valuesC.
207 KokkosSparse::spgemm_numeric(&kh, AnumRows, BnumRows, BnumCols, Aint.graph.row_map, Aint.graph.entries, Aint.values,
false, Bint.graph.row_map, Bint.graph.entries, Bint.values,
false, int_row_mapC, entriesC, valuesC);
// Widen the int row pointers into C's native row-pointer view on device.
209 Kokkos::parallel_for(
210 int_row_mapC.size(), KOKKOS_LAMBDA(
int i) { row_mapC(i) = int_row_mapC(i); });
211 kh.destroy_spgemm_handle();
// Native-row-pointer path (the "} else {" and KernelHandle kh declaration appear to be
// missing from this listing).
214 kh.create_spgemm_handle(alg_enum);
215 kh.set_team_work_size(team_work_size);
219 KokkosSparse::spgemm_symbolic(&kh, AnumRows, BnumRows, BnumCols, Amat.graph.row_map, Amat.graph.entries,
false, Bmerged.graph.row_map, Bmerged.graph.entries,
false, row_mapC);
222 size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
224 entriesC = lno_nnz_view_t(Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
225 valuesC = scalar_view_t(Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
229 KokkosSparse::spgemm_numeric(&kh, AnumRows, BnumRows, BnumCols, Amat.graph.row_map, Amat.graph.entries, Amat.values,
false, Bmerged.graph.row_map, Bmerged.graph.entries, Bmerged.values,
false, row_mapC, entriesC, valuesC);
230 kh.destroy_spgemm_handle();
// Sort column indices within each row unless params sets "sort entries" to false.
237 if (params.is_null() || params->get(
"sort entries",
true))
238 Import_Util::sortCrsEntries(row_mapC, entriesC, valuesC);
239 C.setAllValues(row_mapC, entriesC, valuesC);
// Fill-complete C. "compute global constants" is forwarded only when params is given.
245 RCP<Teuchos::ParameterList> labelList = rcp(
new Teuchos::ParameterList);
246 labelList->set(
"Timer Label", label);
247 if (!params.is_null()) labelList->set(
"compute global constants", params->get(
"compute global constants",
true));
248 RCP<const Export<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > dummyExport;
249 C.expertStaticFillComplete(Bview.origMatrix->getDomainMap(), Aview.origMatrix->getRangeMap(), Cimport, dummyExport, labelList);
// C = A*B, reuse path for the HIP node: recomputes C's values in place, on host, using
// C's existing graph.
// It mirrors the device index maps to HostSpace and reads the host local matrices of A, B,
// C and (if present) B's import matrix. Each row of C is then accumulated using c_status,
// a dense lookup from C column index to position in C's value array.
// Throws std::runtime_error if a product term targets a column absent from C's row
// ("static graph"). Ends with C.fillComplete(C's domain map, C's range map).
// NOTE(review): label and params are not referenced in the visible body. Declarations
// of NO, SC, Ivals, Aik, Bkj, Ikj, the host Bcol2Ccol/Icol2Ccol mirrors, the OLD_ip update
// and several braces appear to be missing from this listing.
253template <
class Scalar,
256 class LocalOrdinalViewType>
257void KernelWrappers<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode, LocalOrdinalViewType>::mult_A_B_reuse_kernel_wrapper(CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
258 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
259 const LocalOrdinalViewType& targetMapToOrigRow_dev,
260 const LocalOrdinalViewType& targetMapToImportRow_dev,
261 const LocalOrdinalViewType& Bcol2Ccol_dev,
262 const LocalOrdinalViewType& Icol2Ccol_dev,
263 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
264 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
265 const std::string& label,
266 const Teuchos::RCP<Teuchos::ParameterList>& params) {
268 typedef Tpetra::KokkosCompat::KokkosHIPWrapperNode Node;
// Host-side local matrix types: this kernel runs on host.
275 typedef typename Tpetra::CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::local_matrix_host_type KCRS;
276 typedef typename KCRS::StaticCrsGraphType graph_t;
277 typedef typename graph_t::row_map_type::const_type c_lno_view_t;
278 typedef typename graph_t::entries_type::non_const_type lno_nnz_view_t;
279 typedef typename KCRS::values_type::non_const_type scalar_view_t;
282 typedef LocalOrdinal LO;
283 typedef GlobalOrdinal GO;
285 typedef Map<LO, GO, NO> map_type;
// ST_INVALID converts OrdinalTraits<LO>::invalid() to size_t. If that value is negative
// (presumably -1 for signed LO -- confirm) it wraps to a huge value, which the
// [OLD_ip, CSR_ip) range checks below always reject.
286 const size_t ST_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
287 const LO LO_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
288 const SC SC_ZERO = Teuchos::ScalarTraits<Scalar>::zero();
// Copy the device index maps to host so the host loop below can read them.
293 auto targetMapToOrigRow =
294 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
295 targetMapToOrigRow_dev);
296 auto targetMapToImportRow =
297 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
298 targetMapToImportRow_dev);
300 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
303 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
// m = local rows of A (and C); n = local columns of C.
307 RCP<const map_type> Ccolmap = C.getColMap();
308 size_t m = Aview.origMatrix->getLocalNumRows();
309 size_t n = Ccolmap->getLocalNumElements();
// Host local matrices of A, B and C; Cvals is written in place.
312 const KCRS& Amat = Aview.origMatrix->getLocalMatrixHost();
313 const KCRS& Bmat = Bview.origMatrix->getLocalMatrixHost();
314 const KCRS& Cmat = C.getLocalMatrixHost();
316 c_lno_view_t Arowptr = Amat.graph.row_map,
317 Browptr = Bmat.graph.row_map,
318 Crowptr = Cmat.graph.row_map;
319 const lno_nnz_view_t Acolind = Amat.graph.entries,
320 Bcolind = Bmat.graph.entries,
321 Ccolind = Cmat.graph.entries;
322 const scalar_view_t Avals = Amat.values, Bvals = Bmat.values;
323 scalar_view_t Cvals = Cmat.values;
// Row pointers/columns of B's import matrix; left empty when there is none.
325 c_lno_view_t Irowptr;
326 lno_nnz_view_t Icolind;
328 if (!Bview.importMatrix.is_null()) {
329 auto lclB = Bview.importMatrix->getLocalMatrixHost();
330 Irowptr = lclB.graph.row_map;
331 Icolind = lclB.graph.entries;
// c_status[col] = position of column col in Cvals for the current row (ST_INVALID
// if the column has not been seen).
346 std::vector<size_t> c_status(n, ST_INVALID);
// [OLD_ip, CSR_ip) is the span of C's current row in Ccolind/Cvals.
349 size_t CSR_ip = 0, OLD_ip = 0;
350 for (
size_t i = 0; i < m; i++) {
354 CSR_ip = Crowptr[i + 1];
// Record where each existing column of C's row i lives.
355 for (
size_t k = OLD_ip; k < CSR_ip; k++) {
356 c_status[Ccolind[k]] = k;
// For each A(i,k): if column k maps to a local B row, add A(i,k)*B(k,:);
// otherwise use the corresponding row of B's import matrix.
362 for (
size_t k = Arowptr[i]; k < Arowptr[i + 1]; k++) {
364 const SC Aval = Avals[k];
368 if (targetMapToOrigRow[Aik] != LO_INVALID) {
370 size_t Bk = Teuchos::as<size_t>(targetMapToOrigRow[Aik]);
372 for (
size_t j = Browptr[Bk]; j < Browptr[Bk + 1]; ++j) {
374 LO Cij = Bcol2Ccol[Bkj];
// The target column must already exist in C's row i; C's graph is fixed.
376 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
377 std::runtime_error,
"Trying to insert a new entry (" << i <<
"," << Cij <<
") into a static graph "
378 <<
"(c_status = " << c_status[Cij] <<
" of [" << OLD_ip <<
"," << CSR_ip <<
"))");
380 Cvals[c_status[Cij]] += Aval * Bvals[j];
// Imported-row branch (the enclosing "} else {" is not visible here).
385 size_t Ik = Teuchos::as<size_t>(targetMapToImportRow[Aik]);
386 for (
size_t j = Irowptr[Ik]; j < Irowptr[Ik + 1]; ++j) {
388 LO Cij = Icol2Ccol[Ikj];
390 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
391 std::runtime_error,
"Trying to insert a new entry (" << i <<
"," << Cij <<
") into a static graph "
392 <<
"(c_status = " << c_status[Cij] <<
" of [" << OLD_ip <<
"," << CSR_ip <<
"))");
394 Cvals[c_status[Cij]] += Aval * Ivals[j];
// Re-finalize C with its existing domain and range maps.
400 C.fillComplete(C.getDomainMap(), C.getRangeMap());
// Jacobi new-matrix entry point for the HIP node.
// Selects the kernel from params "hip: jacobi algorithm" (default "KK"):
//   "MSAK" -> ::Tpetra::MatrixMatrix::ExtraKernels::jacobi_A_B_newmatrix_MultiplyScaleAddKernel
//   "KK"   -> jacobi_A_B_newmatrix_KokkosKernels
//   other  -> throws std::runtime_error("... unknown kernel").
// Afterwards, if C is not yet fill-complete, calls expertStaticFillComplete(B's domain
// map, A's range map) with a "Timer Label" list. "compute global constants" is forwarded
// only when params is given.
// NOTE(review): the "} else {" before the throw, the LocalOrdinal/GlobalOrdinal
// template lines and closing braces appear to be missing from this listing.
404template <
class Scalar,
407 class LocalOrdinalViewType>
408void KernelWrappers2<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode, LocalOrdinalViewType>::jacobi_A_B_newmatrix_kernel_wrapper(Scalar omega,
409 const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Dinv,
410 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
411 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
412 const LocalOrdinalViewType& Acol2Brow,
413 const LocalOrdinalViewType& Acol2Irow,
414 const LocalOrdinalViewType& Bcol2Ccol,
415 const LocalOrdinalViewType& Icol2Ccol,
416 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
417 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
418 const std::string& label,
419 const Teuchos::RCP<Teuchos::ParameterList>& params) {
// Kernel selection; "KK" unless overridden.
425 std::string myalg(
"KK");
426 if (!params.is_null()) {
427 if (params->isParameter(
"hip: jacobi algorithm"))
428 myalg = params->get(
"hip: jacobi algorithm", myalg);
431 if (myalg ==
"MSAK") {
432 ::Tpetra::MatrixMatrix::ExtraKernels::jacobi_A_B_newmatrix_MultiplyScaleAddKernel(omega, Dinv, Aview, Bview, Acol2Brow, Acol2Irow, Bcol2Ccol, Icol2Ccol, C, Cimport, label, params);
433 }
else if (myalg ==
"KK") {
434 jacobi_A_B_newmatrix_KokkosKernels(omega, Dinv, Aview, Bview, Acol2Brow, Acol2Irow, Bcol2Ccol, Icol2Ccol, C, Cimport, label, params);
436 throw std::runtime_error(
"Tpetra::MatrixMatrix::Jacobi newmatrix unknown kernel");
// Parameter list for the final fill-complete.
442 RCP<Teuchos::ParameterList> labelList = rcp(
new Teuchos::ParameterList);
443 labelList->set(
"Timer Label", label);
444 if (!params.is_null()) labelList->set(
"compute global constants", params->get(
"compute global constants",
true));
// The selected kernel may already have fill-completed C; only do it here if not.
447 if (!C.isFillComplete()) {
448 RCP<const Export<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > dummyExport;
449 C.expertStaticFillComplete(Bview.origMatrix->getDomainMap(), Aview.origMatrix->getRangeMap(), Cimport, dummyExport, labelList);
// Jacobi reuse path for the HIP node. Recomputes, on host and in C's existing graph:
//   C(i,:) = B(i,:) - omega * Dinv(i) * sum_k A(i,k) * B(k,:)
// where row k of B comes from B itself or from B's import matrix.
// Throws std::runtime_error if any term targets a column absent from C's row.
// Ends with C.fillComplete(C's domain map, C's range map).
// NOTE(review): label and params are not referenced in the visible body. Declarations of
// NO, SC, Dvals (left-hand side), Ivals, Bij, Aik, Bkj, Ikj, the host Bcol2Ccol/Icol2Ccol
// mirrors, the OLD_ip update and several braces appear to be missing from this listing.
454template <
class Scalar,
457 class LocalOrdinalViewType>
458void KernelWrappers2<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode, LocalOrdinalViewType>::jacobi_A_B_reuse_kernel_wrapper(Scalar omega,
459 const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Dinv,
460 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
461 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
462 const LocalOrdinalViewType& targetMapToOrigRow_dev,
463 const LocalOrdinalViewType& targetMapToImportRow_dev,
464 const LocalOrdinalViewType& Bcol2Ccol_dev,
465 const LocalOrdinalViewType& Icol2Ccol_dev,
466 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
467 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
468 const std::string& label,
469 const Teuchos::RCP<Teuchos::ParameterList>& params) {
471 typedef Tpetra::KokkosCompat::KokkosHIPWrapperNode Node;
// Host-side local matrix types; scalar_memory_space selects Dinv's host view below.
478 typedef typename Tpetra::CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::local_matrix_host_type KCRS;
479 typedef typename KCRS::StaticCrsGraphType graph_t;
480 typedef typename graph_t::row_map_type::const_type c_lno_view_t;
481 typedef typename graph_t::entries_type::non_const_type lno_nnz_view_t;
482 typedef typename KCRS::values_type::non_const_type scalar_view_t;
483 typedef typename scalar_view_t::memory_space scalar_memory_space;
486 typedef LocalOrdinal LO;
487 typedef GlobalOrdinal GO;
489 typedef Map<LO, GO, NO> map_type;
// ST_INVALID converts OrdinalTraits<LO>::invalid() to size_t. If that value is negative
// (presumably -1 for signed LO -- confirm) it wraps to a huge value that the range
// checks below always reject.
490 const size_t ST_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
491 const LO LO_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
492 const SC SC_ZERO = Teuchos::ScalarTraits<Scalar>::zero();
// Copy the device index maps to host for the host loop below.
497 auto targetMapToOrigRow =
498 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
499 targetMapToOrigRow_dev);
500 auto targetMapToImportRow =
501 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
502 targetMapToImportRow_dev);
504 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
507 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
// m = local rows of A (and C); n = local columns of C.
511 RCP<const map_type> Ccolmap = C.getColMap();
512 size_t m = Aview.origMatrix->getLocalNumRows();
513 size_t n = Ccolmap->getLocalNumElements();
516 const KCRS& Amat = Aview.origMatrix->getLocalMatrixHost();
517 const KCRS& Bmat = Bview.origMatrix->getLocalMatrixHost();
518 const KCRS& Cmat = C.getLocalMatrixHost();
520 c_lno_view_t Arowptr = Amat.graph.row_map, Browptr = Bmat.graph.row_map, Crowptr = Cmat.graph.row_map;
521 const lno_nnz_view_t Acolind = Amat.graph.entries, Bcolind = Bmat.graph.entries, Ccolind = Cmat.graph.entries;
522 const scalar_view_t Avals = Amat.values, Bvals = Bmat.values;
523 scalar_view_t Cvals = Cmat.values;
// Row pointers/columns of B's import matrix; left empty when there is none.
525 c_lno_view_t Irowptr;
526 lno_nnz_view_t Icolind;
528 if (!Bview.importMatrix.is_null()) {
529 auto lclB = Bview.importMatrix->getLocalMatrixHost();
530 Irowptr = lclB.graph.row_map;
531 Icolind = lclB.graph.entries;
// Read-only view of Dinv in scalar_memory_space (presumably bound to Dvals, which is
// indexed as Dvals(i, 0) below).
537 Dinv.template getLocalView<scalar_memory_space>(Access::ReadOnly);
// c_status[col] = position of column col in Cvals for the current row.
546 std::vector<size_t> c_status(n, ST_INVALID);
549 size_t CSR_ip = 0, OLD_ip = 0;
550 for (
size_t i = 0; i < m; i++) {
554 CSR_ip = Crowptr[i + 1];
555 for (
size_t k = OLD_ip; k < CSR_ip; k++) {
556 c_status[Ccolind[k]] = k;
// Row scale factor: -omega * Dinv(i).
562 SC minusOmegaDval = -omega * Dvals(i, 0);
// Term B(i,:): assigned (not accumulated) into C's row.
// NOTE(review): Bval is not used in the visible lines; presumably a missing line skips zero
// entries -- confirm.
565 for (
size_t j = Browptr[i]; j < Browptr[i + 1]; j++) {
566 Scalar Bval = Bvals[j];
570 LO Cij = Bcol2Ccol[Bij];
572 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
573 std::runtime_error,
"Trying to insert a new entry into a static graph");
575 Cvals[c_status[Cij]] = Bvals[j];
// Term -omega*Dinv(i)*A(i,k)*B(k,:), with row k from B or from B's import matrix.
579 for (
size_t k = Arowptr[i]; k < Arowptr[i + 1]; k++) {
581 const SC Aval = Avals[k];
585 if (targetMapToOrigRow[Aik] != LO_INVALID) {
587 size_t Bk = Teuchos::as<size_t>(targetMapToOrigRow[Aik]);
589 for (
size_t j = Browptr[Bk]; j < Browptr[Bk + 1]; ++j) {
591 LO Cij = Bcol2Ccol[Bkj];
593 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
594 std::runtime_error,
"Trying to insert a new entry into a static graph");
596 Cvals[c_status[Cij]] += minusOmegaDval * Aval * Bvals[j];
// Imported-row branch (the enclosing "} else {" is not visible here).
601 size_t Ik = Teuchos::as<size_t>(targetMapToImportRow[Aik]);
602 for (
size_t j = Irowptr[Ik]; j < Irowptr[Ik + 1]; ++j) {
604 LO Cij = Icol2Ccol[Ikj];
606 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
607 std::runtime_error,
"Trying to insert a new entry into a static graph");
609 Cvals[c_status[Cij]] += minusOmegaDval * Aval * Ivals[j];
// Re-finalize C with its existing domain and range maps.
618 C.fillComplete(C.getDomainMap(), C.getRangeMap());
// Jacobi-fused SpGEMM for the HIP node via KokkosSparse::Experimental::spgemm_jacobi.
// Steps:
//   1. Throws std::runtime_error if any local diagonal entry of A is zero or missing.
//   2. Merges B with its imported rows (merge_matrices).
//   3. Runs spgemm_symbolic, then spgemm_jacobi(omega, Dinv) on device, using int or
//      native row pointers.
//   4. Optionally sorts C's column indices and installs the arrays via setAllValues.
//   5. Calls expertStaticFillComplete(B's domain map, A's range map).
// params keys read: "hip: algorithm", "hip: team work size", "HIP algorithm",
// "sort entries", "compute global constants".
// NOTE(review): declarations of diags, matrix_t, irph and kh, the tail of the
// spgemm_symbolic calls, the if/else braces and the closing brace appear to be missing from
// this listing.
622template <
class Scalar,
625 class LocalOrdinalViewType>
626void KernelWrappers2<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode, LocalOrdinalViewType>::jacobi_A_B_newmatrix_KokkosKernels(Scalar omega,
627 const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Dinv,
628 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
629 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
630 const LocalOrdinalViewType& Acol2Brow,
631 const LocalOrdinalViewType& Acol2Irow,
632 const LocalOrdinalViewType& Bcol2Ccol,
633 const LocalOrdinalViewType& Icol2Ccol,
634 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
635 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
636 const std::string& label,
637 const Teuchos::RCP<Teuchos::ParameterList>& params) {
// Diagonal check: copy A's local diagonal into diags (declared on a line not visible
// here), then into a host array, and reject any zero entry.
641 auto rowMap = Aview.origMatrix->getRowMap();
643 Aview.origMatrix->getLocalDiagCopy(diags);
644 size_t diagLength = rowMap->getLocalNumElements();
645 Teuchos::Array<Scalar> diagonal(diagLength);
646 diags.get1dCopy(diagonal());
648 for (
size_t i = 0; i < diagLength; ++i) {
649 TEUCHOS_TEST_FOR_EXCEPTION(diagonal[i] == Teuchos::ScalarTraits<Scalar>::zero(),
651 "Matrix A has a zero/missing diagonal: " << diagonal[i] << std::endl
652 <<
"KokkosKernels Jacobi-fused SpGEMM requires nonzero diagonal entries in A" << std::endl);
// Type aliases. matrix_t (presumably the node's device local matrix type) is declared on
// a line not visible here. int_view_t mirrors the row-pointer view but stores int offsets.
661 using device_t =
typename Tpetra::KokkosCompat::KokkosHIPWrapperNode::device_type;
663 using graph_t =
typename matrix_t::StaticCrsGraphType;
664 using lno_view_t =
typename graph_t::row_map_type::non_const_type;
665 using int_view_t = Kokkos::View<
int*,
666 typename lno_view_t::array_layout,
667 typename lno_view_t::memory_space,
668 typename lno_view_t::memory_traits>;
669 using c_lno_view_t =
typename graph_t::row_map_type::const_type;
670 using lno_nnz_view_t =
typename graph_t::entries_type::non_const_type;
671 using scalar_view_t =
typename matrix_t::values_type::non_const_type;
// KokkosKernels handles keyed on native (handle_t) and int (int_handle_t) row pointers.
674 using handle_t =
typename KokkosKernels::Experimental::KokkosKernelsHandle<
675 typename lno_view_t::const_value_type,
typename lno_nnz_view_t::const_value_type,
typename scalar_view_t::const_value_type,
676 typename device_t::execution_space,
typename device_t::memory_space,
typename device_t::memory_space>;
677 using int_handle_t =
typename KokkosKernels::Experimental::KokkosKernelsHandle<
678 typename int_view_t::const_value_type,
typename lno_nnz_view_t::const_value_type,
typename scalar_view_t::const_value_type,
679 typename device_t::execution_space,
typename device_t::memory_space,
typename device_t::memory_space>;
// Presumably combines B's local and imported rows into one matrix in C's column space --
// see merge_matrices to confirm.
682 const matrix_t Bmerged = Tpetra::MMdetails::merge_matrices(Aview, Bview, Acol2Brow, Acol2Irow, Bcol2Ccol, Icol2Ccol, C.getColMap()->getLocalNumElements());
685 const matrix_t& Amat = Aview.origMatrix->getLocalMatrixDevice();
686 const matrix_t& Bmat = Bview.origMatrix->getLocalMatrixDevice();
// SpGEMM dimensions passed to spgemm_symbolic/spgemm_jacobi.
688 typename handle_t::nnz_lno_t AnumRows = Amat.numRows();
689 typename handle_t::nnz_lno_t BnumRows = Bmerged.numRows();
690 typename handle_t::nnz_lno_t BnumCols = Bmerged.numCols();
// NOTE(review): Arowptr, Browptr, Acolind, Bcolind, Avals, Bvals (and Bmat above) are not
// referenced in the visible remainder of this function.
692 c_lno_view_t Arowptr = Amat.graph.row_map, Browptr = Bmerged.graph.row_map;
693 const lno_nnz_view_t Acolind = Amat.graph.entries, Bcolind = Bmerged.graph.entries;
694 const scalar_view_t Avals = Amat.values, Bvals = Bmerged.values;
// C's row pointers (AnumRows + 1). entriesC/valuesC are sized after the symbolic phase.
697 lno_view_t row_mapC(Kokkos::ViewAllocateWithoutInitializing(
"row_mapC"), AnumRows + 1);
698 lno_nnz_view_t entriesC;
699 scalar_view_t valuesC;
// Defaults, overridable via "hip: algorithm" and "hip: team work size".
702 int team_work_size = 16;
703 std::string myalg(
"SPGEMM_KK_MEMORY");
704 if (!params.is_null()) {
705 if (params->isParameter(
"hip: algorithm"))
706 myalg = params->get(
"hip: algorithm", myalg);
707 if (params->isParameter(
"hip: team work size"))
708 team_work_size = params->get(
"hip: team work size", team_work_size);
// A "HIP algorithm" parameter, when present, overrides "hip: algorithm".
712 std::string nodename(
"HIP");
713 std::string alg = nodename + std::string(
" algorithm");
714 if (!params.is_null() && params->isParameter(alg)) myalg = params->get(alg, myalg);
715 KokkosSparse::SPGEMMAlgorithm alg_enum = KokkosSparse::StringToSPGEMMAlgorithm(myalg);
// Use int row pointers only if both irph (declaration not visible here) and A's apply
// helper allow it.
718 const bool useIntRowptrs =
719 irph.shouldUseIntRowptrs() &&
720 Aview.
origMatrix->getApplyHelper()->shouldUseIntRowptrs();
// Int-row-pointer path (the enclosing "if (useIntRowptrs) {" and kh declaration are not
// visible here).
724 kh.create_spgemm_handle(alg_enum);
725 kh.set_team_work_size(team_work_size);
727 int_view_t int_row_mapC(Kokkos::ViewAllocateWithoutInitializing(
"int_row_mapC"), AnumRows + 1);
729 auto Aint = Aview.origMatrix->getApplyHelper()->getIntRowptrMatrix(Amat);
730 auto Bint = irph.getIntRowptrMatrix(Bmerged);
734 KokkosSparse::spgemm_symbolic(&kh, AnumRows, BnumRows, BnumCols,
735 Aint.graph.row_map, Aint.graph.entries,
false,
736 Bint.graph.row_map, Bint.graph.entries,
false,
// Nonzero count of C from the symbolic phase.
742 size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
744 entriesC = lno_nnz_view_t(Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
745 valuesC = scalar_view_t(Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
// Jacobi-fused numeric phase, scaled by omega and Dinv's device view.
// NOTE(review): Amat.values (not Aint.values) is passed alongside Aint's graph; this
// presumably relies on Aint sharing A's values array -- confirm.
750 KokkosSparse::Experimental::spgemm_jacobi(&kh, AnumRows, BnumRows, BnumCols,
751 Aint.graph.row_map, Aint.graph.entries, Amat.values,
false,
752 Bint.graph.row_map, Bint.graph.entries, Bint.values,
false,
753 int_row_mapC, entriesC, valuesC,
754 omega, Dinv.getLocalViewDevice(Access::ReadOnly));
// Widen the int row pointers into C's native row-pointer view on device.
756 Kokkos::parallel_for(
757 int_row_mapC.size(), KOKKOS_LAMBDA(
int i) { row_mapC(i) = int_row_mapC(i); });
758 kh.destroy_spgemm_handle();
// Native-row-pointer path (the "} else {" and kh declaration are not visible here).
761 kh.create_spgemm_handle(alg_enum);
762 kh.set_team_work_size(team_work_size);
767 KokkosSparse::spgemm_symbolic(&kh, AnumRows, BnumRows, BnumCols,
768 Amat.graph.row_map, Amat.graph.entries,
false,
769 Bmerged.graph.row_map, Bmerged.graph.entries,
false,
775 size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
777 entriesC = lno_nnz_view_t(Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
778 valuesC = scalar_view_t(Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
780 KokkosSparse::Experimental::spgemm_jacobi(&kh, AnumRows, BnumRows, BnumCols,
781 Amat.graph.row_map, Amat.graph.entries, Amat.values,
false,
782 Bmerged.graph.row_map, Bmerged.graph.entries, Bmerged.values,
false,
783 row_mapC, entriesC, valuesC,
784 omega, Dinv.getLocalViewDevice(Access::ReadOnly));
785 kh.destroy_spgemm_handle();
// Sort column indices within each row unless params sets "sort entries" to false.
792 if (params.is_null() || params->get(
"sort entries",
true))
793 Import_Util::sortCrsEntries(row_mapC, entriesC, valuesC);
794 C.setAllValues(row_mapC, entriesC, valuesC);
// Fill-complete C. "compute global constants" is forwarded only when params is given.
800 Teuchos::RCP<Teuchos::ParameterList> labelList = rcp(
new Teuchos::ParameterList);
801 labelList->set(
"Timer Label", label);
802 if (!params.is_null()) labelList->set(
"compute global constants", params->get(
"compute global constants",
true));
803 Teuchos::RCP<const Export<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > dummyExport;
804 C.expertStaticFillComplete(Bview.origMatrix->getDomainMap(), Aview.origMatrix->getRangeMap(), Cimport, dummyExport, labelList);
Struct that holds views of the contents of a CrsMatrix.
Teuchos::RCP< const CrsMatrix< Scalar, LocalOrdinal, GlobalOrdinal, Node > > origMatrix
The original matrix.
KokkosSparse::CrsMatrix< impl_scalar_type, local_ordinal_type, device_type, void, typename local_graph_device_type::size_type > local_matrix_device_type
The specialization of Kokkos::CrsMatrix that represents the part of the sparse matrix on each MPI process.
static bool debug()
Whether Tpetra is in debug mode.
Namespace Tpetra contains the class and methods constituting the Tpetra library.