10#ifndef TPETRA_MATRIXMATRIX_HIP_DEF_HPP
11#define TPETRA_MATRIXMATRIX_HIP_DEF_HPP
13#include "Tpetra_Details_IntRowPtrHelper.hpp"
15#ifdef HAVE_TPETRA_INST_HIP
// Partial specialization of KernelWrappers for
// Tpetra::KokkosCompat::KokkosHIPWrapperNode. It declares two static
// multiply wrappers; both are defined later in this file.
// NOTE(review): the leading integers on many lines look like line numbers
// carried over from a rendered listing, and some lines (for example the
// LocalOrdinal/GlobalOrdinal template parameters and the closing of the
// struct) are not visible in this view -- confirm against the upstream header.
21template <
class Scalar,
24 class LocalOrdinalViewType>
25struct KernelWrappers<Scalar, LocalOrdinal, GlobalOrdinal,
Tpetra::KokkosCompat::KokkosHIPWrapperNode, LocalOrdinalViewType> {
// Multiply wrapper for the case where C is filled as a new matrix.
// `label` defaults to an empty string and `params` to Teuchos::null.
26 static inline void mult_A_B_newmatrix_kernel_wrapper(CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
27 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
28 const LocalOrdinalViewType& Acol2Brow,
29 const LocalOrdinalViewType& Acol2Irow,
30 const LocalOrdinalViewType& Bcol2Ccol,
31 const LocalOrdinalViewType& Icol2Ccol,
32 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
33 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
34 const std::string& label = std::string(),
35 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// Multiply wrapper for the case where the values of an existing C are
// recomputed. Same parameter list and defaults as the wrapper above.
37 static inline void mult_A_B_reuse_kernel_wrapper(CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
38 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
39 const LocalOrdinalViewType& Acol2Brow,
40 const LocalOrdinalViewType& Acol2Irow,
41 const LocalOrdinalViewType& Bcol2Ccol,
42 const LocalOrdinalViewType& Icol2Ccol,
43 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
44 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
45 const std::string& label = std::string(),
46 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// Partial specialization of KernelWrappers2 for
// Tpetra::KokkosCompat::KokkosHIPWrapperNode. It declares three static
// Jacobi wrappers; all three are defined later in this file. Each takes the
// damping factor `omega` and the vector `Dinv` ahead of the arguments shared
// with the multiply wrappers.
// NOTE(review): the LocalOrdinal template parameter and the closing of the
// struct are not visible in this view -- confirm against the upstream header.
50template <
class Scalar,
52 class GlobalOrdinal,
class LocalOrdinalViewType>
53struct KernelWrappers2<Scalar, LocalOrdinal, GlobalOrdinal,
Tpetra::KokkosCompat::KokkosHIPWrapperNode, LocalOrdinalViewType> {
// Jacobi wrapper for the case where C is filled as a new matrix.
// `label` defaults to an empty string and `params` to Teuchos::null.
54 static inline void jacobi_A_B_newmatrix_kernel_wrapper(
typename Teuchos::ScalarTraits<Scalar>::magnitudeType omega,
55 const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Dinv,
56 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
57 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
58 const LocalOrdinalViewType& Acol2Brow,
59 const LocalOrdinalViewType& Acol2Irow,
60 const LocalOrdinalViewType& Bcol2Ccol,
61 const LocalOrdinalViewType& Icol2Ccol,
62 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
63 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
64 const std::string& label = std::string(),
65 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// Jacobi wrapper for the case where the values of an existing C are
// recomputed. Same parameter list and defaults as above.
67 static inline void jacobi_A_B_reuse_kernel_wrapper(
typename Teuchos::ScalarTraits<Scalar>::magnitudeType omega,
68 const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Dinv,
69 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
70 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
71 const LocalOrdinalViewType& Acol2Brow,
72 const LocalOrdinalViewType& Acol2Irow,
73 const LocalOrdinalViewType& Bcol2Ccol,
74 const LocalOrdinalViewType& Icol2Ccol,
75 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
76 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
77 const std::string& label = std::string(),
78 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// Jacobi implementation built on the KokkosSparse SpGEMM/Jacobi kernels.
// Same parameter list and defaults as above.
80 static inline void jacobi_A_B_newmatrix_KokkosKernels(
typename Teuchos::ScalarTraits<Scalar>::magnitudeType omega,
81 const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Dinv,
82 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
83 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
84 const LocalOrdinalViewType& Acol2Brow,
85 const LocalOrdinalViewType& Acol2Irow,
86 const LocalOrdinalViewType& Bcol2Ccol,
87 const LocalOrdinalViewType& Icol2Ccol,
88 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
89 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
90 const std::string& label = std::string(),
91 const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
// HIP multiply wrapper that fills C from the product of A's local device
// matrix and a merged form of B, using KokkosSparse::spgemm_symbolic and
// KokkosSparse::spgemm_numeric, then calls C.setAllValues and
// C.expertStaticFillComplete.
//
// Parameters read from `params` (all optional):
//   "hip: algorithm"            SpGEMM algorithm name (default "SPGEMM_KK_MEMORY")
//   "hip: team work size"       team work size passed to the handle (default 16)
//   "HIP algorithm"             algorithm name; applied after the one above
//   "sort entries"              whether to sort C's entries (default true)
//   "compute global constants"  forwarded to expertStaticFillComplete (default true)
//
// NOTE(review): several lines of this definition are not visible in this
// view (the KCRS and KernelHandle alias names, the declarations of `kh` and
// `irph`, the branch on useIntRowptrs, closing braces) -- confirm against the
// upstream file before relying on the structure described below.
96template <
class Scalar,
99 class LocalOrdinalViewType>
100void KernelWrappers<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode, LocalOrdinalViewType>::mult_A_B_newmatrix_kernel_wrapper(CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
101 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
102 const LocalOrdinalViewType& Acol2Brow,
103 const LocalOrdinalViewType& Acol2Irow,
104 const LocalOrdinalViewType& Bcol2Ccol,
105 const LocalOrdinalViewType& Icol2Ccol,
106 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
107 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
108 const std::string& label,
109 const Teuchos::RCP<Teuchos::ParameterList>& params) {
111 typedef Tpetra::KokkosCompat::KokkosHIPWrapperNode Node;
// Prefix used below to build the "HIP algorithm" parameter name.
112 std::string nodename(
"HIP");
// View types taken from the local matrix type. int_view_t mirrors the
// row-map view's layout, memory space and traits but stores `int`.
118 typedef typename KCRS::device_type device_t;
119 typedef typename KCRS::StaticCrsGraphType graph_t;
120 typedef typename graph_t::row_map_type::non_const_type lno_view_t;
121 using int_view_t = Kokkos::View<int*, typename lno_view_t::array_layout, typename lno_view_t::memory_space, typename lno_view_t::memory_traits>;
122 typedef typename graph_t::row_map_type::const_type c_lno_view_t;
123 typedef typename graph_t::entries_type::non_const_type lno_nnz_view_t;
124 typedef typename KCRS::values_type::non_const_type scalar_view_t;
// Defaults for the SpGEMM handle; each is replaced by the matching
// "hip: ..." entry of params when that entry exists.
130 int team_work_size = 16;
131 std::string myalg(
"SPGEMM_KK_MEMORY");
132 if (!params.is_null()) {
133 if (params->isParameter(
"hip: algorithm"))
134 myalg = params->get(
"hip: algorithm", myalg);
135 if (params->isParameter(
"hip: team work size"))
136 team_work_size = params->get(
"hip: team work size", team_work_size);
// Kernel handle type built from the value types of the row-map, entries and
// values views. NOTE(review): the alias name is not visible here; later lines
// refer to it as KernelHandle.
140 typedef KokkosKernels::Experimental::KokkosKernelsHandle<
141 typename lno_view_t::const_value_type,
typename lno_nnz_view_t::const_value_type,
typename scalar_view_t::const_value_type,
142 typename device_t::execution_space,
typename device_t::memory_space,
typename device_t::memory_space>
// Same handle shape, with int_view_t's value type in place of lno_view_t's.
144 using IntKernelHandle = KokkosKernels::Experimental::KokkosKernelsHandle<
145 typename int_view_t::const_value_type,
typename lno_nnz_view_t::const_value_type,
typename scalar_view_t::const_value_type,
146 typename device_t::execution_space,
typename device_t::memory_space,
typename device_t::memory_space>;
// Device-side local matrices of A and B and views onto their CRS arrays.
149 const KCRS Amat = Aview.origMatrix->getLocalMatrixDevice();
150 const KCRS Bmat = Bview.origMatrix->getLocalMatrixDevice();
152 c_lno_view_t Arowptr = Amat.graph.row_map,
153 Browptr = Bmat.graph.row_map;
154 const lno_nnz_view_t Acolind = Amat.graph.entries,
155 Bcolind = Bmat.graph.entries;
156 const scalar_view_t Avals = Amat.values,
// A "HIP algorithm" entry in params, when present, replaces the algorithm
// name chosen above; the final name is converted to the KokkosSparse enum.
160 std::string alg = nodename + std::string(
" algorithm");
162 if (!params.is_null() && params->isParameter(alg)) myalg = params->get(alg, myalg);
163 KokkosSparse::SPGEMMAlgorithm alg_enum = KokkosSparse::StringToSPGEMMAlgorithm(myalg);
// Presumably combines B's original and imported rows into one local matrix
// indexed by C's columns -- confirm against merge_matrices. The last
// argument is the local size of C's column map.
166 const KCRS Bmerged = Tpetra::MMdetails::merge_matrices(Aview, Bview, Acol2Brow, Acol2Irow, Bcol2Ccol, Icol2Ccol, C.getColMap()->getLocalNumElements());
172 typename KernelHandle::nnz_lno_t AnumRows = Amat.numRows();
173 typename KernelHandle::nnz_lno_t BnumRows = Bmerged.numRows();
174 typename KernelHandle::nnz_lno_t BnumCols = Bmerged.numCols();
// C's row map is allocated without initialization with AnumRows + 1 entries;
// entries and values are allocated later, once C's nnz count is known.
177 lno_view_t row_mapC(Kokkos::ViewAllocateWithoutInitializing(
"non_const_lno_row"), AnumRows + 1);
178 lno_nnz_view_t entriesC;
179 scalar_view_t valuesC;
// Int row pointers are used only when both `irph` and A's apply helper
// report that they should be.
182 const bool useIntRowptrs =
183 irph.shouldUseIntRowptrs() &&
184 Aview.
origMatrix->getApplyHelper()->shouldUseIntRowptrs();
// Code path operating on int row-pointer versions of A and Bmerged
// (Aint/Bint), presumably taken when useIntRowptrs is true -- TODO confirm.
188 kh.create_spgemm_handle(alg_enum);
189 kh.set_team_work_size(team_work_size);
191 int_view_t int_row_mapC(Kokkos::ViewAllocateWithoutInitializing(
"non_const_int_row"), AnumRows + 1);
193 auto Aint = Aview.origMatrix->getApplyHelper()->getIntRowptrMatrix(Amat);
194 auto Bint = irph.getIntRowptrMatrix(Bmerged);
// Symbolic phase; C's nnz count is read back from the handle afterwards and
// used to size entriesC and valuesC.
198 KokkosSparse::spgemm_symbolic(&kh, AnumRows, BnumRows, BnumCols, Aint.graph.row_map, Aint.graph.entries,
false, Bint.graph.row_map, Bint.graph.entries,
false, int_row_mapC);
202 size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
204 entriesC = lno_nnz_view_t(Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
205 valuesC = scalar_view_t(Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
207 KokkosSparse::spgemm_numeric(&kh, AnumRows, BnumRows, BnumCols, Aint.graph.row_map, Aint.graph.entries, Aint.values,
false, Bint.graph.row_map, Bint.graph.entries, Bint.values,
false, int_row_mapC, entriesC, valuesC);
// Copy the int row map into row_mapC element by element.
209 Kokkos::parallel_for(
210 int_row_mapC.size(), KOKKOS_LAMBDA(
int i) { row_mapC(i) = int_row_mapC(i); });
211 kh.destroy_spgemm_handle();
// Code path operating directly on Amat and Bmerged and writing row_mapC.
214 kh.create_spgemm_handle(alg_enum);
215 kh.set_team_work_size(team_work_size);
219 KokkosSparse::spgemm_symbolic(&kh, AnumRows, BnumRows, BnumCols, Amat.graph.row_map, Amat.graph.entries,
false, Bmerged.graph.row_map, Bmerged.graph.entries,
false, row_mapC);
222 size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
224 entriesC = lno_nnz_view_t(Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
225 valuesC = scalar_view_t(Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
229 KokkosSparse::spgemm_numeric(&kh, AnumRows, BnumRows, BnumCols, Amat.graph.row_map, Amat.graph.entries, Amat.values,
false, Bmerged.graph.row_map, Bmerged.graph.entries, Bmerged.values,
false, row_mapC, entriesC, valuesC);
230 kh.destroy_spgemm_handle();
// Entries are sorted unless params holds "sort entries" set to false.
237 if (params.is_null() || params->get(
"sort entries",
true))
238 Import_Util::sortCrsEntries(row_mapC, entriesC, valuesC);
239 C.setAllValues(row_mapC, entriesC, valuesC);
// Pass the timer label and, when params is given, its
// "compute global constants" value (default true) on to the fill-complete call.
245 RCP<Teuchos::ParameterList> labelList = rcp(
new Teuchos::ParameterList);
246 labelList->set(
"Timer Label", label);
247 if (!params.is_null()) labelList->set(
"compute global constants", params->get(
"compute global constants",
true));
// Default-constructed Export handle: no exporter is supplied.
248 RCP<const Export<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > dummyExport;
249 C.expertStaticFillComplete(Bview.origMatrix->getDomainMap(), Aview.origMatrix->getRangeMap(), Cimport, dummyExport, labelList);
// HIP multiply wrapper that recomputes the values of an existing C on the
// host: it walks the rows of A's host matrix, accumulates Aval * Bvals[j]
// (or Aval * Ivals[j] for imported rows) into Cvals at the position recorded
// for each target column, and finishes with C.fillComplete.
//
// The bool parameter "MM Throw For Non-Existent Entries" (default true)
// controls whether a std::runtime_error is raised when a product entry has no
// slot in the current row of C.
//
// NOTE(review): several lines of this definition are not visible in this
// view (declarations of SC/NO, Aik, Bkj, Ikj, Ivals, Bcol2Ccol, Icol2Ccol, the
// `if` tests that guard the accumulations, closing braces) -- confirm against
// the upstream file.
253template <
class Scalar,
256 class LocalOrdinalViewType>
257void KernelWrappers<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode, LocalOrdinalViewType>::mult_A_B_reuse_kernel_wrapper(CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
258 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
259 const LocalOrdinalViewType& targetMapToOrigRow_dev,
260 const LocalOrdinalViewType& targetMapToImportRow_dev,
261 const LocalOrdinalViewType& Bcol2Ccol_dev,
262 const LocalOrdinalViewType& Icol2Ccol_dev,
263 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
264 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
265 const std::string& label,
266 const Teuchos::RCP<Teuchos::ParameterList>& params) {
268 typedef Tpetra::KokkosCompat::KokkosHIPWrapperNode Node;
// Throw on entries missing from C's pattern unless params overrides it.
274 bool throwOnInsert =
true;
275 if (!params.is_null() && params->isType<
bool>(
"MM Throw For Non-Existent Entries"))
276 throwOnInsert = params->get<
bool>(
"MM Throw For Non-Existent Entries");
// Host-side local matrix type and the view types of its CRS arrays.
281 typedef typename Tpetra::CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::local_matrix_host_type KCRS;
282 typedef typename KCRS::StaticCrsGraphType graph_t;
283 typedef typename graph_t::row_map_type::const_type c_lno_view_t;
284 typedef typename graph_t::entries_type::non_const_type lno_nnz_view_t;
285 typedef typename KCRS::values_type::non_const_type scalar_view_t;
288 typedef LocalOrdinal LO;
289 typedef GlobalOrdinal GO;
291 typedef Map<LO, GO, NO> map_type;
// Sentinels. ST_INVALID is LO's invalid ordinal converted to size_t.
292 const size_t ST_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
293 const LO LO_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
294 const SC SC_ZERO = Teuchos::ScalarTraits<Scalar>::zero();
// Host copies of the lookup arrays passed in as *_dev views.
299 auto targetMapToOrigRow =
300 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
301 targetMapToOrigRow_dev);
302 auto targetMapToImportRow =
303 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
304 targetMapToImportRow_dev);
306 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
309 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
// m: local row count of A; n: local size of C's column map.
313 RCP<const map_type> Ccolmap = C.getColMap();
314 size_t m = Aview.origMatrix->getLocalNumRows();
315 size_t n = Ccolmap->getLocalNumElements();
// Host-side local matrices and views onto their CRS arrays. Cvals is the
// only view written to below.
318 const KCRS Amat = Aview.origMatrix->getLocalMatrixHost();
319 const KCRS Bmat = Bview.origMatrix->getLocalMatrixHost();
320 const KCRS Cmat = C.getLocalMatrixHost();
322 c_lno_view_t Arowptr = Amat.graph.row_map,
323 Browptr = Bmat.graph.row_map,
324 Crowptr = Cmat.graph.row_map;
325 const lno_nnz_view_t Acolind = Amat.graph.entries,
326 Bcolind = Bmat.graph.entries,
327 Ccolind = Cmat.graph.entries;
328 const scalar_view_t Avals = Amat.values, Bvals = Bmat.values;
329 scalar_view_t Cvals = Cmat.values;
// CRS arrays of B's imported rows; left empty when there is no import matrix.
331 c_lno_view_t Irowptr;
332 lno_nnz_view_t Icolind;
334 if (!Bview.importMatrix.is_null()) {
335 auto lclB = Bview.importMatrix->getLocalMatrixHost();
336 Irowptr = lclB.graph.row_map;
337 Icolind = lclB.graph.entries;
// c_status[c] is the index into Ccolind/Cvals at which column c was last
// seen. A value outside [OLD_ip, CSR_ip) means column c is not part of the
// row currently being processed.
352 std::vector<size_t> c_status(n, ST_INVALID);
355 size_t CSR_ip = 0, OLD_ip = 0;
356 for (
size_t i = 0; i < m; i++) {
// Record the position of every column present in row i of C.
360 CSR_ip = Crowptr[i + 1];
361 for (
size_t k = OLD_ip; k < CSR_ip; k++) {
362 c_status[Ccolind[k]] = k;
// Walk the entries of row i of A.
368 for (
size_t k = Arowptr[i]; k < Arowptr[i + 1]; k++) {
370 const SC Aval = Avals[k];
// The A column maps to a local row of B.
374 if (targetMapToOrigRow[Aik] != LO_INVALID) {
376 size_t Bk = Teuchos::as<size_t>(targetMapToOrigRow[Aik]);
378 for (
size_t j = Browptr[Bk]; j < Browptr[Bk + 1]; ++j) {
380 LO Cij = Bcol2Ccol[Bkj];
// badInsert: the target column has no slot in row i of C.
382 const bool badInsert = c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip;
384 Cvals[c_status[Cij]] += Aval * Bvals[j];
385 else if (throwOnInsert)
386 TEUCHOS_TEST_FOR_EXCEPTION(badInsert,
387 std::runtime_error,
"Trying to insert a new entry (" << i <<
"," << Cij <<
") into a static graph "
388 <<
"(c_status = " << c_status[Cij] <<
" of [" << OLD_ip <<
"," << CSR_ip <<
"))");
// Same accumulation using a row of the imported matrix (presumably the
// else-branch of the test above -- TODO confirm).
393 size_t Ik = Teuchos::as<size_t>(targetMapToImportRow[Aik]);
394 for (
size_t j = Irowptr[Ik]; j < Irowptr[Ik + 1]; ++j) {
396 LO Cij = Icol2Ccol[Ikj];
398 const bool badInsert = c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip;
400 Cvals[c_status[Cij]] += Aval * Ivals[j];
401 else if (throwOnInsert)
402 TEUCHOS_TEST_FOR_EXCEPTION(badInsert,
403 std::runtime_error,
"Trying to insert a new entry (" << i <<
"," << Cij <<
") into a static graph "
404 <<
"(c_status = " << c_status[Cij] <<
" of [" << OLD_ip <<
"," << CSR_ip <<
"))");
// Finish using C's own domain and range maps.
410 C.fillComplete(C.getDomainMap(), C.getRangeMap());
// HIP Jacobi wrapper for a new C. Selects an implementation by name:
//   "MSAK" -> ExtraKernels::jacobi_A_B_newmatrix_MultiplyScaleAddKernel
//   "KK"   -> jacobi_A_B_newmatrix_KokkosKernels (the default)
// The name comes from the "hip: jacobi algorithm" entry of params when that
// entry exists. A std::runtime_error is thrown for an unknown name
// (NOTE(review): the `else` introducing the throw is not visible in this view).
// Afterwards, if C is not yet fill-complete, expertStaticFillComplete is
// called with B's domain map and A's range map.
414template <
class Scalar,
417 class LocalOrdinalViewType>
418void KernelWrappers2<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode, LocalOrdinalViewType>::jacobi_A_B_newmatrix_kernel_wrapper(
typename Teuchos::ScalarTraits<Scalar>::magnitudeType omega,
419 const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Dinv,
420 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
421 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
422 const LocalOrdinalViewType& Acol2Brow,
423 const LocalOrdinalViewType& Acol2Irow,
424 const LocalOrdinalViewType& Bcol2Ccol,
425 const LocalOrdinalViewType& Icol2Ccol,
426 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
427 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
428 const std::string& label,
429 const Teuchos::RCP<Teuchos::ParameterList>& params) {
// Default implementation name, replaced by params when it names one.
435 std::string myalg(
"KK");
436 if (!params.is_null()) {
437 if (params->isParameter(
"hip: jacobi algorithm"))
438 myalg = params->get(
"hip: jacobi algorithm", myalg);
// Dispatch on the implementation name; all arguments are forwarded unchanged.
441 if (myalg ==
"MSAK") {
442 ::Tpetra::MatrixMatrix::ExtraKernels::jacobi_A_B_newmatrix_MultiplyScaleAddKernel(omega, Dinv, Aview, Bview, Acol2Brow, Acol2Irow, Bcol2Ccol, Icol2Ccol, C, Cimport, label, params);
443 }
else if (myalg ==
"KK") {
444 jacobi_A_B_newmatrix_KokkosKernels(omega, Dinv, Aview, Bview, Acol2Brow, Acol2Irow, Bcol2Ccol, Icol2Ccol, C, Cimport, label, params);
446 throw std::runtime_error(
"Tpetra::MatrixMatrix::Jacobi newmatrix unknown kernel");
// Pass the timer label and, when params is given, its
// "compute global constants" value (default true) on to the fill-complete call.
452 RCP<Teuchos::ParameterList> labelList = rcp(
new Teuchos::ParameterList);
453 labelList->set(
"Timer Label", label);
454 if (!params.is_null()) labelList->set(
"compute global constants", params->get(
"compute global constants",
true));
// Only complete C here when the selected implementation has not already done so.
457 if (!C.isFillComplete()) {
458 RCP<const Export<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > dummyExport;
459 C.expertStaticFillComplete(Bview.origMatrix->getDomainMap(), Aview.origMatrix->getRangeMap(), Cimport, dummyExport, labelList);
// HIP Jacobi wrapper that recomputes the values of an existing C on the host.
// For each row i it first writes the entries of row i of B into C, then adds
// (-omega * Dvals(i, 0)) * Aval * Bvals[j] (or ... * Ivals[j] for imported
// rows) for every entry of row i of A, and finishes with C.fillComplete.
//
// Throws std::runtime_error when a target column has no slot in the current
// row of C.
//
// NOTE(review): several lines of this definition are not visible in this
// view (declarations of SC/NO, Dvals, Aik, Bij, Bkj, Ikj, Ivals, Bcol2Ccol,
// Icol2Ccol, the `else` before the import loop, closing braces) -- confirm
// against the upstream file.
464template <
class Scalar,
467 class LocalOrdinalViewType>
468void KernelWrappers2<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode, LocalOrdinalViewType>::jacobi_A_B_reuse_kernel_wrapper(
typename Teuchos::ScalarTraits<Scalar>::magnitudeType omega,
469 const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Dinv,
470 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
471 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
472 const LocalOrdinalViewType& targetMapToOrigRow_dev,
473 const LocalOrdinalViewType& targetMapToImportRow_dev,
474 const LocalOrdinalViewType& Bcol2Ccol_dev,
475 const LocalOrdinalViewType& Icol2Ccol_dev,
476 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
477 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
478 const std::string& label,
479 const Teuchos::RCP<Teuchos::ParameterList>& params) {
481 typedef Tpetra::KokkosCompat::KokkosHIPWrapperNode Node;
// Host-side local matrix type and the view types of its CRS arrays.
488 typedef typename Tpetra::CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::local_matrix_host_type KCRS;
489 typedef typename KCRS::StaticCrsGraphType graph_t;
490 typedef typename graph_t::row_map_type::const_type c_lno_view_t;
491 typedef typename graph_t::entries_type::non_const_type lno_nnz_view_t;
492 typedef typename KCRS::values_type::non_const_type scalar_view_t;
493 typedef typename scalar_view_t::memory_space scalar_memory_space;
496 typedef LocalOrdinal LO;
497 typedef GlobalOrdinal GO;
499 typedef Map<LO, GO, NO> map_type;
// Sentinels. ST_INVALID is LO's invalid ordinal converted to size_t.
500 const size_t ST_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
501 const LO LO_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
502 const SC SC_ZERO = Teuchos::ScalarTraits<Scalar>::zero();
// Host copies of the lookup arrays passed in as *_dev views.
507 auto targetMapToOrigRow =
508 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
509 targetMapToOrigRow_dev);
510 auto targetMapToImportRow =
511 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
512 targetMapToImportRow_dev);
514 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
517 Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
// m: local row count of A; n: local size of C's column map.
521 RCP<const map_type> Ccolmap = C.getColMap();
522 size_t m = Aview.origMatrix->getLocalNumRows();
523 size_t n = Ccolmap->getLocalNumElements();
// Host-side local matrices and views onto their CRS arrays. Cvals is the
// only view written to below.
526 const KCRS Amat = Aview.origMatrix->getLocalMatrixHost();
527 const KCRS Bmat = Bview.origMatrix->getLocalMatrixHost();
528 const KCRS Cmat = C.getLocalMatrixHost();
530 c_lno_view_t Arowptr = Amat.graph.row_map, Browptr = Bmat.graph.row_map, Crowptr = Cmat.graph.row_map;
531 const lno_nnz_view_t Acolind = Amat.graph.entries, Bcolind = Bmat.graph.entries, Ccolind = Cmat.graph.entries;
532 const scalar_view_t Avals = Amat.values, Bvals = Bmat.values;
533 scalar_view_t Cvals = Cmat.values;
// CRS arrays of B's imported rows; left empty when there is no import matrix.
535 c_lno_view_t Irowptr;
536 lno_nnz_view_t Icolind;
538 if (!Bview.importMatrix.is_null()) {
539 auto lclB = Bview.importMatrix->getLocalMatrixHost();
540 Irowptr = lclB.graph.row_map;
541 Icolind = lclB.graph.entries;
// Read-only local view of Dinv in the memory space of the host value views
// (presumably the initializer of Dvals used below -- TODO confirm).
547 Dinv.template getLocalView<scalar_memory_space>(Access::ReadOnly);
// c_status[c] is the index into Ccolind/Cvals at which column c was last
// seen. A value outside [OLD_ip, CSR_ip) means column c is not part of the
// row currently being processed.
556 std::vector<size_t> c_status(n, ST_INVALID);
559 size_t CSR_ip = 0, OLD_ip = 0;
560 for (
size_t i = 0; i < m; i++) {
// Record the position of every column present in row i of C.
564 CSR_ip = Crowptr[i + 1];
565 for (
size_t k = OLD_ip; k < CSR_ip; k++) {
566 c_status[Ccolind[k]] = k;
// Per-row scaling factor applied to every A*B contribution of this row.
572 SC minusOmegaDval = -omega * Dvals(i, 0);
// First pass: overwrite C's row i with the entries of B's row i.
575 for (
size_t j = Browptr[i]; j < Browptr[i + 1]; j++) {
576 Scalar Bval = Bvals[j];
580 LO Cij = Bcol2Ccol[Bij];
582 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
583 std::runtime_error,
"Trying to insert a new entry into a static graph");
585 Cvals[c_status[Cij]] = Bvals[j];
// Second pass: walk the entries of row i of A and add the scaled products.
589 for (
size_t k = Arowptr[i]; k < Arowptr[i + 1]; k++) {
591 const SC Aval = Avals[k];
// The A column maps to a local row of B.
595 if (targetMapToOrigRow[Aik] != LO_INVALID) {
597 size_t Bk = Teuchos::as<size_t>(targetMapToOrigRow[Aik]);
599 for (
size_t j = Browptr[Bk]; j < Browptr[Bk + 1]; ++j) {
601 LO Cij = Bcol2Ccol[Bkj];
603 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
604 std::runtime_error,
"Trying to insert a new entry into a static graph");
606 Cvals[c_status[Cij]] += minusOmegaDval * Aval * Bvals[j];
// Same accumulation using a row of the imported matrix (presumably the
// else-branch of the test above -- TODO confirm).
611 size_t Ik = Teuchos::as<size_t>(targetMapToImportRow[Aik]);
612 for (
size_t j = Irowptr[Ik]; j < Irowptr[Ik + 1]; ++j) {
614 LO Cij = Icol2Ccol[Ikj];
616 TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
617 std::runtime_error,
"Trying to insert a new entry into a static graph");
619 Cvals[c_status[Cij]] += minusOmegaDval * Aval * Ivals[j];
// Finish using C's own domain and range maps.
628 C.fillComplete(C.getDomainMap(), C.getRangeMap());
// HIP Jacobi implementation for a new C built on KokkosSparse: after checking
// that A's local diagonal has no zero entry, it runs
// KokkosSparse::spgemm_symbolic followed by
// KokkosSparse::Experimental::spgemm_jacobi on A's local device matrix and a
// merged form of B, then calls C.setAllValues and C.expertStaticFillComplete.
//
// Parameters read from `params` (all optional):
//   "hip: algorithm"            SpGEMM algorithm name (default "SPGEMM_KK_MEMORY")
//   "hip: team work size"       team work size passed to the handle (default 16)
//   "HIP algorithm"             algorithm name; applied after the one above
//   "sort entries"              whether to sort C's entries (default true)
//   "compute global constants"  forwarded to expertStaticFillComplete (default true)
//
// NOTE(review): several lines of this definition are not visible in this
// view (declarations of `diags`, `matrix_t`, `kh` and `irph`, the exception
// type of the diagonal check, the branch on useIntRowptrs, the tail of both
// spgemm_symbolic calls, closing braces) -- confirm against the upstream file.
632template <
class Scalar,
635 class LocalOrdinalViewType>
636void KernelWrappers2<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode, LocalOrdinalViewType>::jacobi_A_B_newmatrix_KokkosKernels(
typename Teuchos::ScalarTraits<Scalar>::magnitudeType omega,
637 const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Dinv,
638 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
639 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
640 const LocalOrdinalViewType& Acol2Brow,
641 const LocalOrdinalViewType& Acol2Irow,
642 const LocalOrdinalViewType& Bcol2Ccol,
643 const LocalOrdinalViewType& Icol2Ccol,
644 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
645 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
646 const std::string& label,
647 const Teuchos::RCP<Teuchos::ParameterList>& params) {
// Copy A's local diagonal to a host array and reject any entry equal to zero.
651 auto rowMap = Aview.origMatrix->getRowMap();
653 Aview.origMatrix->getLocalDiagCopy(diags);
654 size_t diagLength = rowMap->getLocalNumElements();
655 Teuchos::Array<Scalar> diagonal(diagLength);
656 diags.get1dCopy(diagonal());
658 for (
size_t i = 0; i < diagLength; ++i) {
659 TEUCHOS_TEST_FOR_EXCEPTION(diagonal[i] == Teuchos::ScalarTraits<Scalar>::zero(),
661 "Matrix A has a zero/missing diagonal: " << diagonal[i] << std::endl
662 <<
"KokkosKernels Jacobi-fused SpGEMM requires nonzero diagonal entries in A" << std::endl);
// View types taken from the local matrix type. int_view_t mirrors the
// row-map view's layout, memory space and traits but stores `int`.
671 using device_t =
typename Tpetra::KokkosCompat::KokkosHIPWrapperNode::device_type;
673 using graph_t =
typename matrix_t::StaticCrsGraphType;
674 using lno_view_t =
typename graph_t::row_map_type::non_const_type;
675 using int_view_t = Kokkos::View<
int*,
676 typename lno_view_t::array_layout,
677 typename lno_view_t::memory_space,
678 typename lno_view_t::memory_traits>;
679 using c_lno_view_t =
typename graph_t::row_map_type::const_type;
680 using lno_nnz_view_t =
typename graph_t::entries_type::non_const_type;
681 using scalar_view_t =
typename matrix_t::values_type::non_const_type;
// Kernel handle types: one keyed on lno_view_t's value type, one on
// int_view_t's value type.
684 using handle_t =
typename KokkosKernels::Experimental::KokkosKernelsHandle<
685 typename lno_view_t::const_value_type,
typename lno_nnz_view_t::const_value_type,
typename scalar_view_t::const_value_type,
686 typename device_t::execution_space,
typename device_t::memory_space,
typename device_t::memory_space>;
687 using int_handle_t =
typename KokkosKernels::Experimental::KokkosKernelsHandle<
688 typename int_view_t::const_value_type,
typename lno_nnz_view_t::const_value_type,
typename scalar_view_t::const_value_type,
689 typename device_t::execution_space,
typename device_t::memory_space,
typename device_t::memory_space>;
// Presumably combines B's original and imported rows into one local matrix
// indexed by C's columns -- confirm against merge_matrices. The last
// argument is the local size of C's column map.
692 const matrix_t Bmerged = Tpetra::MMdetails::merge_matrices(Aview, Bview, Acol2Brow, Acol2Irow, Bcol2Ccol, Icol2Ccol, C.getColMap()->getLocalNumElements());
// Device-side local matrices of A and B.
695 const matrix_t Amat = Aview.origMatrix->getLocalMatrixDevice();
696 const matrix_t Bmat = Bview.origMatrix->getLocalMatrixDevice();
698 typename handle_t::nnz_lno_t AnumRows = Amat.numRows();
699 typename handle_t::nnz_lno_t BnumRows = Bmerged.numRows();
700 typename handle_t::nnz_lno_t BnumCols = Bmerged.numCols();
// Views onto the CRS arrays of A and of the merged B.
702 c_lno_view_t Arowptr = Amat.graph.row_map, Browptr = Bmerged.graph.row_map;
703 const lno_nnz_view_t Acolind = Amat.graph.entries, Bcolind = Bmerged.graph.entries;
704 const scalar_view_t Avals = Amat.values, Bvals = Bmerged.values;
// C's row map is allocated without initialization with AnumRows + 1 entries;
// entries and values are allocated later, once C's nnz count is known.
707 lno_view_t row_mapC(Kokkos::ViewAllocateWithoutInitializing(
"row_mapC"), AnumRows + 1);
708 lno_nnz_view_t entriesC;
709 scalar_view_t valuesC;
// Defaults for the SpGEMM handle; each is replaced by the matching
// "hip: ..." entry of params when that entry exists.
712 int team_work_size = 16;
713 std::string myalg(
"SPGEMM_KK_MEMORY");
714 if (!params.is_null()) {
715 if (params->isParameter(
"hip: algorithm"))
716 myalg = params->get(
"hip: algorithm", myalg);
717 if (params->isParameter(
"hip: team work size"))
718 team_work_size = params->get(
"hip: team work size", team_work_size);
// A "HIP algorithm" entry in params, when present, replaces the algorithm
// name chosen above; the final name is converted to the KokkosSparse enum.
722 std::string nodename(
"HIP");
723 std::string alg = nodename + std::string(
" algorithm");
724 if (!params.is_null() && params->isParameter(alg)) myalg = params->get(alg, myalg);
725 KokkosSparse::SPGEMMAlgorithm alg_enum = KokkosSparse::StringToSPGEMMAlgorithm(myalg);
// Int row pointers are used only when both `irph` and A's apply helper
// report that they should be.
728 const bool useIntRowptrs =
729 irph.shouldUseIntRowptrs() &&
730 Aview.
origMatrix->getApplyHelper()->shouldUseIntRowptrs();
// omega converted to Scalar for the kernel call.
732 const Scalar jacobiOmega = omega * Teuchos::ScalarTraits<Scalar>::one();
// Code path operating on int row-pointer versions of A and Bmerged
// (Aint/Bint), presumably taken when useIntRowptrs is true -- TODO confirm.
736 kh.create_spgemm_handle(alg_enum);
737 kh.set_team_work_size(team_work_size);
739 int_view_t int_row_mapC(Kokkos::ViewAllocateWithoutInitializing(
"int_row_mapC"), AnumRows + 1);
741 auto Aint = Aview.origMatrix->getApplyHelper()->getIntRowptrMatrix(Amat);
742 auto Bint = irph.getIntRowptrMatrix(Bmerged);
// Symbolic phase; C's nnz count is read back from the handle afterwards and
// used to size entriesC and valuesC.
746 KokkosSparse::spgemm_symbolic(&kh, AnumRows, BnumRows, BnumCols,
747 Aint.graph.row_map, Aint.graph.entries,
false,
748 Bint.graph.row_map, Bint.graph.entries,
false,
754 size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
756 entriesC = lno_nnz_view_t(Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
757 valuesC = scalar_view_t(Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
// Numeric Jacobi phase. A's values are passed from Amat rather than Aint.
762 KokkosSparse::Experimental::spgemm_jacobi(&kh, AnumRows, BnumRows, BnumCols,
763 Aint.graph.row_map, Aint.graph.entries, Amat.values,
false,
764 Bint.graph.row_map, Bint.graph.entries, Bint.values,
false,
765 int_row_mapC, entriesC, valuesC,
766 jacobiOmega, Dinv.getLocalViewDevice(Access::ReadOnly));
// Copy the int row map into row_mapC element by element.
768 Kokkos::parallel_for(
769 int_row_mapC.size(), KOKKOS_LAMBDA(
int i) { row_mapC(i) = int_row_mapC(i); });
770 kh.destroy_spgemm_handle();
// Code path operating directly on Amat and Bmerged and writing row_mapC.
773 kh.create_spgemm_handle(alg_enum);
774 kh.set_team_work_size(team_work_size);
779 KokkosSparse::spgemm_symbolic(&kh, AnumRows, BnumRows, BnumCols,
780 Amat.graph.row_map, Amat.graph.entries,
false,
781 Bmerged.graph.row_map, Bmerged.graph.entries,
false,
787 size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
789 entriesC = lno_nnz_view_t(Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
790 valuesC = scalar_view_t(Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
792 KokkosSparse::Experimental::spgemm_jacobi(&kh, AnumRows, BnumRows, BnumCols,
793 Amat.graph.row_map, Amat.graph.entries, Amat.values,
false,
794 Bmerged.graph.row_map, Bmerged.graph.entries, Bmerged.values,
false,
795 row_mapC, entriesC, valuesC,
796 jacobiOmega, Dinv.getLocalViewDevice(Access::ReadOnly));
797 kh.destroy_spgemm_handle();
// Entries are sorted unless params holds "sort entries" set to false.
804 if (params.is_null() || params->get(
"sort entries",
true))
805 Import_Util::sortCrsEntries(row_mapC, entriesC, valuesC);
806 C.setAllValues(row_mapC, entriesC, valuesC);
// Pass the timer label and, when params is given, its
// "compute global constants" value (default true) on to the fill-complete call.
812 Teuchos::RCP<Teuchos::ParameterList> labelList = rcp(
new Teuchos::ParameterList);
813 labelList->set(
"Timer Label", label);
814 if (!params.is_null()) labelList->set(
"compute global constants", params->get(
"compute global constants",
true));
// Default-constructed Export handle: no exporter is supplied.
815 Teuchos::RCP<const Export<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > dummyExport;
816 C.expertStaticFillComplete(Bview.origMatrix->getDomainMap(), Aview.origMatrix->getRangeMap(), Cimport, dummyExport, labelList);
KokkosSparse::CrsMatrix< impl_scalar_type, local_ordinal_type, device_type, void, typename local_graph_device_type::size_type > local_matrix_device_type
The specialization of Kokkos::CrsMatrix that represents the part of the sparse matrix on each MPI process.
Struct that holds views of the contents of a CrsMatrix.
Teuchos::RCP< const CrsMatrix< Scalar, LocalOrdinal, GlobalOrdinal, Node > > origMatrix
The original matrix.
static bool debug()
Whether Tpetra is in debug mode.
Namespace Tpetra contains the class and methods constituting the Tpetra library.