10#ifndef TPETRA_MATRIXMATRIX_HIP_DEF_HPP 
   11#define TPETRA_MATRIXMATRIX_HIP_DEF_HPP 
   13#include "Tpetra_Details_IntRowPtrHelper.hpp" 
   15#ifdef HAVE_TPETRA_INST_HIP 
   21template <
class Scalar,
 
   24          class LocalOrdinalViewType>
 
   25struct KernelWrappers<Scalar, LocalOrdinal, GlobalOrdinal, 
Tpetra::KokkosCompat::KokkosHIPWrapperNode, LocalOrdinalViewType> {
 
   26  static inline void mult_A_B_newmatrix_kernel_wrapper(CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
 
   27                                                       CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
 
   28                                                       const LocalOrdinalViewType& Acol2Brow,
 
   29                                                       const LocalOrdinalViewType& Acol2Irow,
 
   30                                                       const LocalOrdinalViewType& Bcol2Ccol,
 
   31                                                       const LocalOrdinalViewType& Icol2Ccol,
 
   32                                                       CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
 
   33                                                       Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
 
   34                                                       const std::string& label                           = std::string(),
 
   35                                                       const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
 
   37  static inline void mult_A_B_reuse_kernel_wrapper(CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
 
   38                                                   CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
 
   39                                                   const LocalOrdinalViewType& Acol2Brow,
 
   40                                                   const LocalOrdinalViewType& Acol2Irow,
 
   41                                                   const LocalOrdinalViewType& Bcol2Ccol,
 
   42                                                   const LocalOrdinalViewType& Icol2Ccol,
 
   43                                                   CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
 
   44                                                   Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
 
   45                                                   const std::string& label                           = std::string(),
 
   46                                                   const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
 
   50template <
class Scalar,
 
   52          class GlobalOrdinal, 
class LocalOrdinalViewType>
 
   53struct KernelWrappers2<Scalar, LocalOrdinal, GlobalOrdinal, 
Tpetra::KokkosCompat::KokkosHIPWrapperNode, LocalOrdinalViewType> {
 
   54  static inline void jacobi_A_B_newmatrix_kernel_wrapper(Scalar omega,
 
   55                                                         const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Dinv,
 
   56                                                         CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
 
   57                                                         CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
 
   58                                                         const LocalOrdinalViewType& Acol2Brow,
 
   59                                                         const LocalOrdinalViewType& Acol2Irow,
 
   60                                                         const LocalOrdinalViewType& Bcol2Ccol,
 
   61                                                         const LocalOrdinalViewType& Icol2Ccol,
 
   62                                                         CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
 
   63                                                         Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
 
   64                                                         const std::string& label                           = std::string(),
 
   65                                                         const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
 
   67  static inline void jacobi_A_B_reuse_kernel_wrapper(Scalar omega,
 
   68                                                     const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Dinv,
 
   69                                                     CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
 
   70                                                     CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
 
   71                                                     const LocalOrdinalViewType& Acol2Brow,
 
   72                                                     const LocalOrdinalViewType& Acol2Irow,
 
   73                                                     const LocalOrdinalViewType& Bcol2Ccol,
 
   74                                                     const LocalOrdinalViewType& Icol2Ccol,
 
   75                                                     CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
 
   76                                                     Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
 
   77                                                     const std::string& label                           = std::string(),
 
   78                                                     const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
 
   80  static inline void jacobi_A_B_newmatrix_KokkosKernels(Scalar omega,
 
   81                                                        const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Dinv,
 
   82                                                        CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
 
   83                                                        CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
 
   84                                                        const LocalOrdinalViewType& Acol2Brow,
 
   85                                                        const LocalOrdinalViewType& Acol2Irow,
 
   86                                                        const LocalOrdinalViewType& Bcol2Ccol,
 
   87                                                        const LocalOrdinalViewType& Icol2Ccol,
 
   88                                                        CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
 
   89                                                        Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
 
   90                                                        const std::string& label                           = std::string(),
 
   91                                                        const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);
 
   96template <
class Scalar,
 
   99          class LocalOrdinalViewType>
 
  100void KernelWrappers<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode, LocalOrdinalViewType>::mult_A_B_newmatrix_kernel_wrapper(CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
 
  101                                                                                                                                                              CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
 
  102                                                                                                                                                              const LocalOrdinalViewType& Acol2Brow,
 
  103                                                                                                                                                              const LocalOrdinalViewType& Acol2Irow,
 
  104                                                                                                                                                              const LocalOrdinalViewType& Bcol2Ccol,
 
  105                                                                                                                                                              const LocalOrdinalViewType& Icol2Ccol,
 
  106                                                                                                                                                              CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
 
  107                                                                                                                                                              Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
 
  108                                                                                                                                                              const std::string& label,
 
  109                                                                                                                                                              const Teuchos::RCP<Teuchos::ParameterList>& params) {
 
  111  typedef Tpetra::KokkosCompat::KokkosHIPWrapperNode Node;
 
  112  std::string nodename(
"HIP");
 
  118  typedef typename KCRS::device_type device_t;
 
  119  typedef typename KCRS::StaticCrsGraphType graph_t;
 
  120  typedef typename graph_t::row_map_type::non_const_type lno_view_t;
 
  121  using int_view_t = Kokkos::View<int*, typename lno_view_t::array_layout, typename lno_view_t::memory_space, typename lno_view_t::memory_traits>;
 
  122  typedef typename graph_t::row_map_type::const_type c_lno_view_t;
 
  123  typedef typename graph_t::entries_type::non_const_type lno_nnz_view_t;
 
  124  typedef typename KCRS::values_type::non_const_type scalar_view_t;
 
  130  int team_work_size = 16;  
 
  131  std::string myalg(
"SPGEMM_KK_MEMORY");
 
  132  if (!params.is_null()) {
 
  133    if (params->isParameter(
"hip: algorithm"))
 
  134      myalg = params->get(
"hip: algorithm", myalg);
 
  135    if (params->isParameter(
"hip: team work size"))
 
  136      team_work_size = params->get(
"hip: team work size", team_work_size);
 
  140  typedef KokkosKernels::Experimental::KokkosKernelsHandle<
 
  141      typename lno_view_t::const_value_type, 
typename lno_nnz_view_t::const_value_type, 
typename scalar_view_t::const_value_type,
 
  142      typename device_t::execution_space, 
typename device_t::memory_space, 
typename device_t::memory_space>
 
  144  using IntKernelHandle = KokkosKernels::Experimental::KokkosKernelsHandle<
 
  145      typename int_view_t::const_value_type, 
typename lno_nnz_view_t::const_value_type, 
typename scalar_view_t::const_value_type,
 
  146      typename device_t::execution_space, 
typename device_t::memory_space, 
typename device_t::memory_space>;
 
  149  const KCRS& Amat = Aview.origMatrix->getLocalMatrixDevice();
 
  150  const KCRS& Bmat = Bview.origMatrix->getLocalMatrixDevice();
 
  152  c_lno_view_t Arowptr         = Amat.graph.row_map,
 
  153               Browptr         = Bmat.graph.row_map;
 
  154  const lno_nnz_view_t Acolind = Amat.graph.entries,
 
  155                       Bcolind = Bmat.graph.entries;
 
  156  const scalar_view_t Avals    = Amat.values,
 
  160  std::string alg = nodename + std::string(
" algorithm");
 
  162  if (!params.is_null() && params->isParameter(alg)) myalg = params->get(alg, myalg);
 
  163  KokkosSparse::SPGEMMAlgorithm alg_enum = KokkosSparse::StringToSPGEMMAlgorithm(myalg);
 
  166  const KCRS Bmerged = Tpetra::MMdetails::merge_matrices(Aview, Bview, Acol2Brow, Acol2Irow, Bcol2Ccol, Icol2Ccol, C.getColMap()->getLocalNumElements());
 
  172  typename KernelHandle::nnz_lno_t AnumRows = Amat.numRows();
 
  173  typename KernelHandle::nnz_lno_t BnumRows = Bmerged.numRows();
 
  174  typename KernelHandle::nnz_lno_t BnumCols = Bmerged.numCols();
 
  177  lno_view_t row_mapC(Kokkos::ViewAllocateWithoutInitializing(
"non_const_lno_row"), AnumRows + 1);
 
  178  lno_nnz_view_t entriesC;
 
  179  scalar_view_t valuesC;
 
  182  const bool useIntRowptrs =
 
  183      irph.shouldUseIntRowptrs() &&
 
  184      Aview.
origMatrix->getApplyHelper()->shouldUseIntRowptrs();
 
  188    kh.create_spgemm_handle(alg_enum);
 
  189    kh.set_team_work_size(team_work_size);
 
  191    int_view_t int_row_mapC(Kokkos::ViewAllocateWithoutInitializing(
"non_const_int_row"), AnumRows + 1);
 
  193    auto Aint = Aview.origMatrix->getApplyHelper()->getIntRowptrMatrix(Amat);
 
  194    auto Bint = irph.getIntRowptrMatrix(Bmerged);
 
  198      KokkosSparse::spgemm_symbolic(&kh, AnumRows, BnumRows, BnumCols, Aint.graph.row_map, Aint.graph.entries, 
false, Bint.graph.row_map, Bint.graph.entries, 
false, int_row_mapC);
 
  202    size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
 
  204      entriesC = lno_nnz_view_t(Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
 
  205      valuesC  = scalar_view_t(Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
 
  207    KokkosSparse::spgemm_numeric(&kh, AnumRows, BnumRows, BnumCols, Aint.graph.row_map, Aint.graph.entries, Aint.values, 
false, Bint.graph.row_map, Bint.graph.entries, Bint.values, 
false, int_row_mapC, entriesC, valuesC);
 
  209    Kokkos::parallel_for(
 
  210        int_row_mapC.size(), KOKKOS_LAMBDA(
int i) { row_mapC(i) = int_row_mapC(i); });
 
  211    kh.destroy_spgemm_handle();
 
  214    kh.create_spgemm_handle(alg_enum);
 
  215    kh.set_team_work_size(team_work_size);
 
  219      KokkosSparse::spgemm_symbolic(&kh, AnumRows, BnumRows, BnumCols, Amat.graph.row_map, Amat.graph.entries, 
false, Bmerged.graph.row_map, Bmerged.graph.entries, 
false, row_mapC);
 
  222    size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
 
  224      entriesC = lno_nnz_view_t(Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
 
  225      valuesC  = scalar_view_t(Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
 
  229    KokkosSparse::spgemm_numeric(&kh, AnumRows, BnumRows, BnumCols, Amat.graph.row_map, Amat.graph.entries, Amat.values, 
false, Bmerged.graph.row_map, Bmerged.graph.entries, Bmerged.values, 
false, row_mapC, entriesC, valuesC);
 
  230    kh.destroy_spgemm_handle();
 
  237  if (params.is_null() || params->get(
"sort entries", 
true))
 
  238    Import_Util::sortCrsEntries(row_mapC, entriesC, valuesC);
 
  239  C.setAllValues(row_mapC, entriesC, valuesC);
 
  245  RCP<Teuchos::ParameterList> labelList = rcp(
new Teuchos::ParameterList);
 
  246  labelList->set(
"Timer Label", label);
 
  247  if (!params.is_null()) labelList->set(
"compute global constants", params->get(
"compute global constants", 
true));
 
  248  RCP<const Export<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > dummyExport;
 
  249  C.expertStaticFillComplete(Bview.origMatrix->getDomainMap(), Aview.origMatrix->getRangeMap(), Cimport, dummyExport, labelList);
 
  253template <
class Scalar,
 
  256          class LocalOrdinalViewType>
 
  257void KernelWrappers<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode, LocalOrdinalViewType>::mult_A_B_reuse_kernel_wrapper(CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
 
  258                                                                                                                                                          CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
 
  259                                                                                                                                                          const LocalOrdinalViewType& targetMapToOrigRow_dev,
 
  260                                                                                                                                                          const LocalOrdinalViewType& targetMapToImportRow_dev,
 
  261                                                                                                                                                          const LocalOrdinalViewType& Bcol2Ccol_dev,
 
  262                                                                                                                                                          const LocalOrdinalViewType& Icol2Ccol_dev,
 
  263                                                                                                                                                          CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
 
  264                                                                                                                                                          Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
 
  265                                                                                                                                                          const std::string& label,
 
  266                                                                                                                                                          const Teuchos::RCP<Teuchos::ParameterList>& params) {
 
  268  typedef Tpetra::KokkosCompat::KokkosHIPWrapperNode Node;
 
  275  typedef typename Tpetra::CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::local_matrix_host_type KCRS;
 
  276  typedef typename KCRS::StaticCrsGraphType graph_t;
 
  277  typedef typename graph_t::row_map_type::const_type c_lno_view_t;
 
  278  typedef typename graph_t::entries_type::non_const_type lno_nnz_view_t;
 
  279  typedef typename KCRS::values_type::non_const_type scalar_view_t;
 
  282  typedef LocalOrdinal LO;
 
  283  typedef GlobalOrdinal GO;
 
  285  typedef Map<LO, GO, NO> map_type;
 
  286  const size_t ST_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
 
  287  const LO LO_INVALID     = Teuchos::OrdinalTraits<LO>::invalid();
 
  288  const SC SC_ZERO        = Teuchos::ScalarTraits<Scalar>::zero();
 
  293  auto targetMapToOrigRow =
 
  294      Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
 
  295                                          targetMapToOrigRow_dev);
 
  296  auto targetMapToImportRow =
 
  297      Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
 
  298                                          targetMapToImportRow_dev);
 
  300      Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
 
  303      Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
 
  307  RCP<const map_type> Ccolmap = C.getColMap();
 
  308  size_t m                    = Aview.origMatrix->getLocalNumRows();
 
  309  size_t n                    = Ccolmap->getLocalNumElements();
 
  312  const KCRS& Amat = Aview.origMatrix->getLocalMatrixHost();
 
  313  const KCRS& Bmat = Bview.origMatrix->getLocalMatrixHost();
 
  314  const KCRS& Cmat = C.getLocalMatrixHost();
 
  316  c_lno_view_t Arowptr         = Amat.graph.row_map,
 
  317               Browptr         = Bmat.graph.row_map,
 
  318               Crowptr         = Cmat.graph.row_map;
 
  319  const lno_nnz_view_t Acolind = Amat.graph.entries,
 
  320                       Bcolind = Bmat.graph.entries,
 
  321                       Ccolind = Cmat.graph.entries;
 
  322  const scalar_view_t Avals = Amat.values, Bvals = Bmat.values;
 
  323  scalar_view_t Cvals = Cmat.values;
 
  325  c_lno_view_t Irowptr;
 
  326  lno_nnz_view_t Icolind;
 
  328  if (!Bview.importMatrix.is_null()) {
 
  329    auto lclB = Bview.importMatrix->getLocalMatrixHost();
 
  330    Irowptr   = lclB.graph.row_map;
 
  331    Icolind   = lclB.graph.entries;
 
  346  std::vector<size_t> c_status(n, ST_INVALID);
 
  349  size_t CSR_ip = 0, OLD_ip = 0;
 
  350  for (
size_t i = 0; i < m; i++) {
 
  354    CSR_ip = Crowptr[i + 1];
 
  355    for (
size_t k = OLD_ip; k < CSR_ip; k++) {
 
  356      c_status[Ccolind[k]] = k;
 
  362    for (
size_t k = Arowptr[i]; k < Arowptr[i + 1]; k++) {
 
  364      const SC Aval = Avals[k];
 
  368      if (targetMapToOrigRow[Aik] != LO_INVALID) {
 
  370        size_t Bk = Teuchos::as<size_t>(targetMapToOrigRow[Aik]);
 
  372        for (
size_t j = Browptr[Bk]; j < Browptr[Bk + 1]; ++j) {
 
  374          LO Cij = Bcol2Ccol[Bkj];
 
  376          TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
 
  377                                     std::runtime_error, 
"Trying to insert a new entry (" << i << 
"," << Cij << 
") into a static graph " 
  378                                                                                          << 
"(c_status = " << c_status[Cij] << 
" of [" << OLD_ip << 
"," << CSR_ip << 
"))");
 
  380          Cvals[c_status[Cij]] += Aval * Bvals[j];
 
  385        size_t Ik = Teuchos::as<size_t>(targetMapToImportRow[Aik]);
 
  386        for (
size_t j = Irowptr[Ik]; j < Irowptr[Ik + 1]; ++j) {
 
  388          LO Cij = Icol2Ccol[Ikj];
 
  390          TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
 
  391                                     std::runtime_error, 
"Trying to insert a new entry (" << i << 
"," << Cij << 
") into a static graph " 
  392                                                                                          << 
"(c_status = " << c_status[Cij] << 
" of [" << OLD_ip << 
"," << CSR_ip << 
"))");
 
  394          Cvals[c_status[Cij]] += Aval * Ivals[j];
 
  400  C.fillComplete(C.getDomainMap(), C.getRangeMap());
 
  404template <
class Scalar,
 
  407          class LocalOrdinalViewType>
 
  408void KernelWrappers2<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode, LocalOrdinalViewType>::jacobi_A_B_newmatrix_kernel_wrapper(Scalar omega,
 
  409                                                                                                                                                                 const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Dinv,
 
  410                                                                                                                                                                 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
 
  411                                                                                                                                                                 CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
 
  412                                                                                                                                                                 const LocalOrdinalViewType& Acol2Brow,
 
  413                                                                                                                                                                 const LocalOrdinalViewType& Acol2Irow,
 
  414                                                                                                                                                                 const LocalOrdinalViewType& Bcol2Ccol,
 
  415                                                                                                                                                                 const LocalOrdinalViewType& Icol2Ccol,
 
  416                                                                                                                                                                 CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
 
  417                                                                                                                                                                 Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
 
  418                                                                                                                                                                 const std::string& label,
 
  419                                                                                                                                                                 const Teuchos::RCP<Teuchos::ParameterList>& params) {
 
  425  std::string myalg(
"KK");
 
  426  if (!params.is_null()) {
 
  427    if (params->isParameter(
"hip: jacobi algorithm"))
 
  428      myalg = params->get(
"hip: jacobi algorithm", myalg);
 
  431  if (myalg == 
"MSAK") {
 
  432    ::Tpetra::MatrixMatrix::ExtraKernels::jacobi_A_B_newmatrix_MultiplyScaleAddKernel(omega, Dinv, Aview, Bview, Acol2Brow, Acol2Irow, Bcol2Ccol, Icol2Ccol, C, Cimport, label, params);
 
  433  } 
else if (myalg == 
"KK") {
 
  434    jacobi_A_B_newmatrix_KokkosKernels(omega, Dinv, Aview, Bview, Acol2Brow, Acol2Irow, Bcol2Ccol, Icol2Ccol, C, Cimport, label, params);
 
  436    throw std::runtime_error(
"Tpetra::MatrixMatrix::Jacobi newmatrix unknown kernel");
 
  442  RCP<Teuchos::ParameterList> labelList = rcp(
new Teuchos::ParameterList);
 
  443  labelList->set(
"Timer Label", label);
 
  444  if (!params.is_null()) labelList->set(
"compute global constants", params->get(
"compute global constants", 
true));
 
  447  if (!C.isFillComplete()) {
 
  448    RCP<const Export<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > dummyExport;
 
  449    C.expertStaticFillComplete(Bview.origMatrix->getDomainMap(), Aview.origMatrix->getRangeMap(), Cimport, dummyExport, labelList);
 
  454template <
class Scalar,
 
  457          class LocalOrdinalViewType>
 
  458void KernelWrappers2<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode, LocalOrdinalViewType>::jacobi_A_B_reuse_kernel_wrapper(Scalar omega,
 
  459                                                                                                                                                             const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Dinv,
 
  460                                                                                                                                                             CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
 
  461                                                                                                                                                             CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
 
  462                                                                                                                                                             const LocalOrdinalViewType& targetMapToOrigRow_dev,
 
  463                                                                                                                                                             const LocalOrdinalViewType& targetMapToImportRow_dev,
 
  464                                                                                                                                                             const LocalOrdinalViewType& Bcol2Ccol_dev,
 
  465                                                                                                                                                             const LocalOrdinalViewType& Icol2Ccol_dev,
 
  466                                                                                                                                                             CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
 
  467                                                                                                                                                             Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
 
  468                                                                                                                                                             const std::string& label,
 
  469                                                                                                                                                             const Teuchos::RCP<Teuchos::ParameterList>& params) {
 
  471  typedef Tpetra::KokkosCompat::KokkosHIPWrapperNode Node;
 
  478  typedef typename Tpetra::CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::local_matrix_host_type KCRS;
 
  479  typedef typename KCRS::StaticCrsGraphType graph_t;
 
  480  typedef typename graph_t::row_map_type::const_type c_lno_view_t;
 
  481  typedef typename graph_t::entries_type::non_const_type lno_nnz_view_t;
 
  482  typedef typename KCRS::values_type::non_const_type scalar_view_t;
 
  483  typedef typename scalar_view_t::memory_space scalar_memory_space;
 
  486  typedef LocalOrdinal LO;
 
  487  typedef GlobalOrdinal GO;
 
  489  typedef Map<LO, GO, NO> map_type;
 
  490  const size_t ST_INVALID = Teuchos::OrdinalTraits<LO>::invalid();
 
  491  const LO LO_INVALID     = Teuchos::OrdinalTraits<LO>::invalid();
 
  492  const SC SC_ZERO        = Teuchos::ScalarTraits<Scalar>::zero();
 
  497  auto targetMapToOrigRow =
 
  498      Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
 
  499                                          targetMapToOrigRow_dev);
 
  500  auto targetMapToImportRow =
 
  501      Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
 
  502                                          targetMapToImportRow_dev);
 
  504      Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
 
  507      Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(),
 
  511  RCP<const map_type> Ccolmap = C.getColMap();
 
  512  size_t m                    = Aview.origMatrix->getLocalNumRows();
 
  513  size_t n                    = Ccolmap->getLocalNumElements();
 
  516  const KCRS& Amat = Aview.origMatrix->getLocalMatrixHost();
 
  517  const KCRS& Bmat = Bview.origMatrix->getLocalMatrixHost();
 
  518  const KCRS& Cmat = C.getLocalMatrixHost();
 
  520  c_lno_view_t Arowptr = Amat.graph.row_map, Browptr = Bmat.graph.row_map, Crowptr = Cmat.graph.row_map;
 
  521  const lno_nnz_view_t Acolind = Amat.graph.entries, Bcolind = Bmat.graph.entries, Ccolind = Cmat.graph.entries;
 
  522  const scalar_view_t Avals = Amat.values, Bvals = Bmat.values;
 
  523  scalar_view_t Cvals = Cmat.values;
 
  525  c_lno_view_t Irowptr;
 
  526  lno_nnz_view_t Icolind;
 
  528  if (!Bview.importMatrix.is_null()) {
 
  529    auto lclB = Bview.importMatrix->getLocalMatrixHost();
 
  530    Irowptr   = lclB.graph.row_map;
 
  531    Icolind   = lclB.graph.entries;
 
  537      Dinv.template getLocalView<scalar_memory_space>(Access::ReadOnly);
 
  546  std::vector<size_t> c_status(n, ST_INVALID);
 
  549  size_t CSR_ip = 0, OLD_ip = 0;
 
  550  for (
size_t i = 0; i < m; i++) {
 
  554    CSR_ip = Crowptr[i + 1];
 
  555    for (
size_t k = OLD_ip; k < CSR_ip; k++) {
 
  556      c_status[Ccolind[k]] = k;
 
  562    SC minusOmegaDval = -omega * Dvals(i, 0);
 
  565    for (
size_t j = Browptr[i]; j < Browptr[i + 1]; j++) {
 
  566      Scalar Bval = Bvals[j];
 
  570      LO Cij = Bcol2Ccol[Bij];
 
  572      TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
 
  573                                 std::runtime_error, 
"Trying to insert a new entry into a static graph");
 
  575      Cvals[c_status[Cij]] = Bvals[j];
 
  579    for (
size_t k = Arowptr[i]; k < Arowptr[i + 1]; k++) {
 
  581      const SC Aval = Avals[k];
 
  585      if (targetMapToOrigRow[Aik] != LO_INVALID) {
 
  587        size_t Bk = Teuchos::as<size_t>(targetMapToOrigRow[Aik]);
 
  589        for (
size_t j = Browptr[Bk]; j < Browptr[Bk + 1]; ++j) {
 
  591          LO Cij = Bcol2Ccol[Bkj];
 
  593          TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
 
  594                                     std::runtime_error, 
"Trying to insert a new entry into a static graph");
 
  596          Cvals[c_status[Cij]] += minusOmegaDval * Aval * Bvals[j];
 
  601        size_t Ik = Teuchos::as<size_t>(targetMapToImportRow[Aik]);
 
  602        for (
size_t j = Irowptr[Ik]; j < Irowptr[Ik + 1]; ++j) {
 
  604          LO Cij = Icol2Ccol[Ikj];
 
  606          TEUCHOS_TEST_FOR_EXCEPTION(c_status[Cij] < OLD_ip || c_status[Cij] >= CSR_ip,
 
  607                                     std::runtime_error, 
"Trying to insert a new entry into a static graph");
 
  609          Cvals[c_status[Cij]] += minusOmegaDval * Aval * Ivals[j];
 
  618  C.fillComplete(C.getDomainMap(), C.getRangeMap());
 
  // Jacobi-fused sparse matrix-matrix product for the HIP node, delegated to
  // KokkosSparse::Experimental::spgemm_jacobi. The resulting row map, column
  // indices and values are handed to C via setAllValues(), and C is then
  // finalized with expertStaticFillComplete().
  //
  // Parameters (as used in the body below):
  //   omega, Dinv  - forwarded to spgemm_jacobi (Dinv through its read-only
  //                  device view).
  //   Aview, Bview - provide the local device matrices and the maps used for
  //                  the final fill-complete.
  //   Acol2Brow, Acol2Irow, Bcol2Ccol, Icol2Ccol
  //                - forwarded unchanged to Tpetra::MMdetails::merge_matrices.
  //   C            - output matrix; receives setAllValues() and
  //                  expertStaticFillComplete().
  //   Cimport      - passed to expertStaticFillComplete() as the importer.
  //   label        - stored as "Timer Label" in the fill-complete parameters.
  //   params       - optional; may be null. Keys read here: "hip: algorithm",
  //                  "hip: team work size", "HIP algorithm", "sort entries",
  //                  "compute global constants".
  //
  // Raises through TEUCHOS_TEST_FOR_EXCEPTION when a diagonal entry of A
  // compares equal to zero.
  // NOTE(review): several original lines (some declarations, branch headers,
  // closing braces) are not visible in this excerpt; comments below describe
  // only the statements that are shown.
  622template <
class Scalar,
 
  625          class LocalOrdinalViewType>
 
  626void KernelWrappers2<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode, LocalOrdinalViewType>::jacobi_A_B_newmatrix_KokkosKernels(Scalar omega,
 
  627                                                                                                                                                                const Vector<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Dinv,
 
  628                                                                                                                                                                CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Aview,
 
  629                                                                                                                                                                CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& Bview,
 
  630                                                                                                                                                                const LocalOrdinalViewType& Acol2Brow,
 
  631                                                                                                                                                                const LocalOrdinalViewType& Acol2Irow,
 
  632                                                                                                                                                                const LocalOrdinalViewType& Bcol2Ccol,
 
  633                                                                                                                                                                const LocalOrdinalViewType& Icol2Ccol,
 
  634                                                                                                                                                                CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode>& C,
 
  635                                                                                                                                                                Teuchos::RCP<
const Import<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > Cimport,
 
  636                                                                                                                                                                const std::string& label,
 
  637                                                                                                                                                                const Teuchos::RCP<Teuchos::ParameterList>& params) {
 
  // Diagonal sanity check: copy A's local diagonal into a Teuchos::Array and
  // reject any entry that equals zero.
  // NOTE(review): the guard enclosing this check and the declaration of
  // `diags` are not visible in this excerpt.
  641    auto rowMap = Aview.origMatrix->getRowMap();
 
  643    Aview.origMatrix->getLocalDiagCopy(diags);
 
  644    size_t diagLength = rowMap->getLocalNumElements();
 
  645    Teuchos::Array<Scalar> diagonal(diagLength);
 
  646    diags.get1dCopy(diagonal());
 
  648    for (
size_t i = 0; i < diagLength; ++i) {
 
  649      TEUCHOS_TEST_FOR_EXCEPTION(diagonal[i] == Teuchos::ScalarTraits<Scalar>::zero(),
 
  651                                 "Matrix A has a zero/missing diagonal: " << diagonal[i] << std::endl
 
  652                                                                          << 
"KokkosKernels Jacobi-fused SpGEMM requires nonzero diagonal entries in A" << std::endl);
 
  // Type aliases for the device-side CRS pieces. int_view_t mirrors the layout,
  // memory space and traits of lno_view_t but stores plain int values.
  // NOTE(review): the `matrix_t` alias is declared on a line not shown here.
  661  using device_t       = 
typename Tpetra::KokkosCompat::KokkosHIPWrapperNode::device_type;
 
  663  using graph_t        = 
typename matrix_t::StaticCrsGraphType;
 
  664  using lno_view_t     = 
typename graph_t::row_map_type::non_const_type;
 
  665  using int_view_t     = Kokkos::View<
int*,
 
  666                                  typename lno_view_t::array_layout,
 
  667                                  typename lno_view_t::memory_space,
 
  668                                  typename lno_view_t::memory_traits>;
 
  669  using c_lno_view_t   = 
typename graph_t::row_map_type::const_type;
 
  670  using lno_nnz_view_t = 
typename graph_t::entries_type::non_const_type;
 
  671  using scalar_view_t  = 
typename matrix_t::values_type::non_const_type;
 
  // Two KokkosKernels handle types that differ only in their first template
  // argument: the value type of lno_view_t versus that of int_view_t.
  674  using handle_t = 
typename KokkosKernels::Experimental::KokkosKernelsHandle<
 
  675      typename lno_view_t::const_value_type, 
typename lno_nnz_view_t::const_value_type, 
typename scalar_view_t::const_value_type,
 
  676      typename device_t::execution_space, 
typename device_t::memory_space, 
typename device_t::memory_space>;
 
  677  using int_handle_t = 
typename KokkosKernels::Experimental::KokkosKernelsHandle<
 
  678      typename int_view_t::const_value_type, 
typename lno_nnz_view_t::const_value_type, 
typename scalar_view_t::const_value_type,
 
  679      typename device_t::execution_space, 
typename device_t::memory_space, 
typename device_t::memory_space>;
 
  // Bmerged is produced by merge_matrices from both views and the four index
  // maps; the last argument is the local element count of C's column map.
  // Bmerged (not Bmat) supplies the B-side dimensions and arrays used below.
  682  const matrix_t Bmerged = Tpetra::MMdetails::merge_matrices(Aview, Bview, Acol2Brow, Acol2Irow, Bcol2Ccol, Icol2Ccol, C.getColMap()->getLocalNumElements());
 
  685  const matrix_t& Amat = Aview.origMatrix->getLocalMatrixDevice();
 
  686  const matrix_t& Bmat = Bview.origMatrix->getLocalMatrixDevice();
 
  688  typename handle_t::nnz_lno_t AnumRows = Amat.numRows();
 
  689  typename handle_t::nnz_lno_t BnumRows = Bmerged.numRows();
 
  690  typename handle_t::nnz_lno_t BnumCols = Bmerged.numCols();
 
  692  c_lno_view_t Arowptr = Amat.graph.row_map, Browptr = Bmerged.graph.row_map;
 
  693  const lno_nnz_view_t Acolind = Amat.graph.entries, Bcolind = Bmerged.graph.entries;
 
  694  const scalar_view_t Avals = Amat.values, Bvals = Bmerged.values;
 
  // Output arrays for C. row_mapC is allocated without initialization with
  // AnumRows + 1 entries; entriesC and valuesC are allocated later, once the
  // spgemm handle reports C's nonzero count.
  697  lno_view_t row_mapC(Kokkos::ViewAllocateWithoutInitializing(
"row_mapC"), AnumRows + 1);
 
  698  lno_nnz_view_t entriesC;
 
  699  scalar_view_t valuesC;
 
  // Defaults: team work size 16 and algorithm "SPGEMM_KK_MEMORY". Both can be
  // overridden through "hip: team work size" / "hip: algorithm" in params.
  702  int team_work_size = 16;
 
  703  std::string myalg(
"SPGEMM_KK_MEMORY");
 
  704  if (!params.is_null()) {
 
  705    if (params->isParameter(
"hip: algorithm"))
 
  706      myalg = params->get(
"hip: algorithm", myalg);
 
  707    if (params->isParameter(
"hip: team work size"))
 
  708      team_work_size = params->get(
"hip: team work size", team_work_size);
 
  // The key "HIP algorithm" (nodename + " algorithm") is read after
  // "hip: algorithm", so when both are present this one determines myalg.
  712  std::string nodename(
"HIP");
 
  713  std::string alg = nodename + std::string(
" algorithm");
 
  714  if (!params.is_null() && params->isParameter(alg)) myalg = params->get(alg, myalg);
 
  715  KokkosSparse::SPGEMMAlgorithm alg_enum = KokkosSparse::StringToSPGEMMAlgorithm(myalg);
 
  // useIntRowptrs is true only if both `irph` and A's apply helper report
  // shouldUseIntRowptrs().
  // NOTE(review): the declaration of `irph`, the declarations of `kh`, and the
  // branch headers selecting between the two variants below are not visible
  // in this excerpt.
  718  const bool useIntRowptrs =
 
  719      irph.shouldUseIntRowptrs() &&
 
  720      Aview.
origMatrix->getApplyHelper()->shouldUseIntRowptrs();
 
  // Variant operating on int row pointers: Aint/Bint are the int-rowptr forms
  // of Amat/Bmerged and the result row map is written to int_row_mapC.
  724    kh.create_spgemm_handle(alg_enum);
 
  725    kh.set_team_work_size(team_work_size);
 
  727    int_view_t int_row_mapC(Kokkos::ViewAllocateWithoutInitializing(
"int_row_mapC"), AnumRows + 1);
 
  729    auto Aint = Aview.origMatrix->getApplyHelper()->getIntRowptrMatrix(Amat);
 
  730    auto Bint = irph.getIntRowptrMatrix(Bmerged);
 
  734      KokkosSparse::spgemm_symbolic(&kh, AnumRows, BnumRows, BnumCols,
 
  735                                    Aint.graph.row_map, Aint.graph.entries, 
false,
 
  736                                    Bint.graph.row_map, Bint.graph.entries, 
false,
 
  // Size C's entry/value arrays from the nonzero count stored in the handle.
  742    size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
 
  744      entriesC = lno_nnz_view_t(Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
 
  745      valuesC  = scalar_view_t(Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
 
  // Note: A's values are taken from Amat while its graph comes from Aint.
  750    KokkosSparse::Experimental::spgemm_jacobi(&kh, AnumRows, BnumRows, BnumCols,
 
  751                                              Aint.graph.row_map, Aint.graph.entries, Amat.values, 
false,
 
  752                                              Bint.graph.row_map, Bint.graph.entries, Bint.values, 
false,
 
  753                                              int_row_mapC, entriesC, valuesC,
 
  754                                              omega, Dinv.getLocalViewDevice(Access::ReadOnly));
 
  // Copy the int row pointers element-wise into row_mapC, which is what
  // setAllValues() receives below.
  756    Kokkos::parallel_for(
 
  757        int_row_mapC.size(), KOKKOS_LAMBDA(
int i) { row_mapC(i) = int_row_mapC(i); });
 
  758    kh.destroy_spgemm_handle();
 
  // Variant operating directly on the row maps of Amat and Bmerged; the result
  // row map is written straight into row_mapC.
  761    kh.create_spgemm_handle(alg_enum);
 
  762    kh.set_team_work_size(team_work_size);
 
  767      KokkosSparse::spgemm_symbolic(&kh, AnumRows, BnumRows, BnumCols,
 
  768                                    Amat.graph.row_map, Amat.graph.entries, 
false,
 
  769                                    Bmerged.graph.row_map, Bmerged.graph.entries, 
false,
 
  775    size_t c_nnz_size = kh.get_spgemm_handle()->get_c_nnz();
 
  777      entriesC = lno_nnz_view_t(Kokkos::ViewAllocateWithoutInitializing(
"entriesC"), c_nnz_size);
 
  778      valuesC  = scalar_view_t(Kokkos::ViewAllocateWithoutInitializing(
"valuesC"), c_nnz_size);
 
  780    KokkosSparse::Experimental::spgemm_jacobi(&kh, AnumRows, BnumRows, BnumCols,
 
  781                                              Amat.graph.row_map, Amat.graph.entries, Amat.values, 
false,
 
  782                                              Bmerged.graph.row_map, Bmerged.graph.entries, Bmerged.values, 
false,
 
  783                                              row_mapC, entriesC, valuesC,
 
  784                                              omega, Dinv.getLocalViewDevice(Access::ReadOnly));
 
  785    kh.destroy_spgemm_handle();
 
  // Entries are sorted when params is null or when "sort entries" (default
  // true) is true; the arrays are then installed into C.
  792  if (params.is_null() || params->get(
"sort entries", 
true))
 
  793    Import_Util::sortCrsEntries(row_mapC, entriesC, valuesC);
 
  794  C.setAllValues(row_mapC, entriesC, valuesC);
 
  // Build the parameter list for the final fill-complete: always the timer
  // label, plus "compute global constants" (default true) when params is
  // given. dummyExport is a default-constructed RCP passed as the exporter.
  800  Teuchos::RCP<Teuchos::ParameterList> labelList = rcp(
new Teuchos::ParameterList);
 
  801  labelList->set(
"Timer Label", label);
 
  802  if (!params.is_null()) labelList->set(
"compute global constants", params->get(
"compute global constants", 
true));
 
  803  Teuchos::RCP<const Export<LocalOrdinal, GlobalOrdinal, Tpetra::KokkosCompat::KokkosHIPWrapperNode> > dummyExport;
 
  // Domain map comes from B's original matrix, range map from A's.
  804  C.expertStaticFillComplete(Bview.origMatrix->getDomainMap(), Aview.origMatrix->getRangeMap(), Cimport, dummyExport, labelList);
 
Struct that holds views of the contents of a CrsMatrix.
 
Teuchos::RCP< const CrsMatrix< Scalar, LocalOrdinal, GlobalOrdinal, Node > > origMatrix
The original matrix.
 
KokkosSparse::CrsMatrix< impl_scalar_type, local_ordinal_type, device_type, void, typename local_graph_device_type::size_type > local_matrix_device_type
The specialization of Kokkos::CrsMatrix that represents the part of the sparse matrix on each MPI process.
 
static bool debug()
Whether Tpetra is in debug mode.
 
Namespace Tpetra contains the class and methods constituting the Tpetra library.