Tpetra parallel linear algebra Version of the Day
Loading...
Searching...
No Matches
Tpetra_LocalCrsMatrixOperator_def.hpp
1// @HEADER
2// *****************************************************************************
3// Tpetra: Templated Linear Algebra Services Package
4//
5// Copyright 2008 NTESS and the Tpetra contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9
10#ifndef TPETRA_LOCALCRSMATRIXOPERATOR_DEF_HPP
11#define TPETRA_LOCALCRSMATRIXOPERATOR_DEF_HPP
12
13#include "Tpetra_LocalOperator.hpp"
15#include "KokkosSparse.hpp"
16#include "Teuchos_TestForException.hpp"
17#include "Teuchos_OrdinalTraits.hpp"
18
19namespace Tpetra {
20
21template <class MultiVectorScalar, class MatrixScalar, class Device>
22LocalCrsMatrixOperator<MultiVectorScalar, MatrixScalar, Device>::
23 LocalCrsMatrixOperator(const std::shared_ptr<local_matrix_device_type>& A)
24 : A_(A)
25 , have_A_cusparse(false) {
26 const char tfecfFuncName[] = "LocalCrsMatrixOperator: ";
27 TEUCHOS_TEST_FOR_EXCEPTION_CLASS_FUNC(A_.get() == nullptr, std::invalid_argument,
28 "Input matrix A is null.");
29}
30
31template <class MultiVectorScalar, class MatrixScalar, class Device>
32LocalCrsMatrixOperator<MultiVectorScalar, MatrixScalar, Device>::
33 LocalCrsMatrixOperator(const std::shared_ptr<local_matrix_device_type>& A, const ordinal_view_type& A_ordinal_rowptrs)
34 : A_(A)
35 , A_cusparse("LocalCrsMatrixOperator_cuSPARSE", A->numRows(), A->numCols(), A->nnz(),
36 A->values, A_ordinal_rowptrs, A->graph.entries)
37 , have_A_cusparse(true) {
38 const char tfecfFuncName[] = "LocalCrsMatrixOperator: ";
39 TEUCHOS_TEST_FOR_EXCEPTION_CLASS_FUNC(A_.get() == nullptr, std::invalid_argument,
40 "Input matrix A is null.");
41}
42
43template <class MultiVectorScalar, class MatrixScalar, class Device>
44bool LocalCrsMatrixOperator<MultiVectorScalar, MatrixScalar, Device>::
45 hasTransposeApply() const {
46 return true;
47}
48
49template <class MultiVectorScalar, class MatrixScalar, class Device>
50void LocalCrsMatrixOperator<MultiVectorScalar, MatrixScalar, Device>::
51 apply(Kokkos::View<const mv_scalar_type**, array_layout,
52 device_type, Kokkos::MemoryTraits<Kokkos::Unmanaged> >
53 X,
54 Kokkos::View<mv_scalar_type**, array_layout,
55 device_type, Kokkos::MemoryTraits<Kokkos::Unmanaged> >
56 Y,
57 const Teuchos::ETransp mode,
58 const mv_scalar_type alpha,
59 const mv_scalar_type beta) const {
60 const bool conjugate = (mode == Teuchos::CONJ_TRANS);
61 const bool transpose = (mode != Teuchos::NO_TRANS);
62
63#ifdef HAVE_TPETRA_DEBUG
64 const char tfecfFuncName[] = "apply: ";
65
66 TEUCHOS_TEST_FOR_EXCEPTION_CLASS_FUNC(X.extent(1) != Y.extent(1), std::runtime_error,
67 "X.extent(1) = " << X.extent(1) << " != Y.extent(1) = "
68 << Y.extent(1) << ".");
69 // If the two pointers are NULL, then they don't alias one
70 // another, even though they are equal.
71 TEUCHOS_TEST_FOR_EXCEPTION_CLASS_FUNC(X.data() == Y.data() && X.data() != nullptr,
72 std::runtime_error, "X and Y may not alias one another.");
73#endif // HAVE_TPETRA_DEBUG
74
75 const auto op = transpose ? (conjugate ? KokkosSparse::ConjugateTranspose : KokkosSparse::Transpose) : KokkosSparse::NoTranspose;
76 if (have_A_cusparse) {
77 KokkosSparse::spmv(op, alpha, A_cusparse, X, beta, Y);
78 } else {
79 KokkosSparse::spmv(op, alpha, *A_, X, beta, Y);
80 }
82
// NOTE(review): the lines that should sit between the template header below
// and the parameter list (return type, class qualifier and function name) are
// not present in this listing — restore them from the upstream file before
// compiling; they are deliberately not guessed here.
//
// What the visible part shows: a const member-style definition taking the same
// five parameters as apply() (X, Y, mode, alpha, beta) whose body forwards all
// of them, unchanged, to apply().
85template <class MultiVectorScalar, class MatrixScalar, class Device>
88 Kokkos::View<const mv_scalar_type**, array_layout,
89 device_type, Kokkos::MemoryTraits<Kokkos::Unmanaged> >
90 X,
91 Kokkos::View<mv_scalar_type**, array_layout,
92 device_type, Kokkos::MemoryTraits<Kokkos::Unmanaged> >
93 Y,
94 const Teuchos::ETransp mode,
95 const mv_scalar_type alpha,
96 const mv_scalar_type beta) const {
97 apply(X, Y, mode, alpha, beta);
98}
99
100template <class MultiVectorScalar, class MatrixScalar, class Device>
101const typename LocalCrsMatrixOperator<MultiVectorScalar, MatrixScalar, Device>::local_matrix_device_type&
103 getLocalMatrixDevice() const {
104 return *A_;
105}
106
107} // namespace Tpetra
108
109//
110// Explicit instantiation macro
111//
112// Must be expanded from within the Tpetra namespace!
113//
114
115// We only explicitly instantiate for MultiVectorScalar ==
116// MatrixScalar, which is what CrsMatrix needs.
117
// Explicitly instantiate LocalCrsMatrixOperator with the same scalar type SC
// for both the multivector and the matrix, on NT::device_type.
#define TPETRA_LOCALCRSMATRIXOPERATOR_INSTANT(SC, NT) template class LocalCrsMatrixOperator<SC, SC, NT::device_type>;
120
121// If we want mixed versions, we use this macro.
122
// Explicitly instantiate LocalCrsMatrixOperator with multivector scalar SC and
// a different matrix scalar MATSC, on NT::device_type. LO and GO are accepted
// but do not appear in the expansion.
#define TPETRA_LOCALCRSMATRIXOPERATOR_MIXED_INSTANT(SC, MATSC, LO, GO, NT) template class LocalCrsMatrixOperator<SC, MATSC, NT::device_type>;
125
126#endif // TPETRA_LOCALCRSMATRIXOPERATOR_DEF_HPP
Declaration of Tpetra::Details::Behavior, a class that describes Tpetra's behavior.
Struct that holds views of the contents of a CrsMatrix.
Abstract interface for local operators (e.g., matrices and preconditioners).
Namespace Tpetra contains the class and methods constituting the Tpetra library.