Tpetra parallel linear algebra Version of the Day
Loading...
Searching...
No Matches
Tpetra_LocalCrsMatrixOperator_decl.hpp
1// @HEADER
2// *****************************************************************************
3// Tpetra: Templated Linear Algebra Services Package
4//
5// Copyright 2008 NTESS and the Tpetra contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9
10#ifndef TPETRA_LOCALCRSMATRIXOPERATOR_DECL_HPP
11#define TPETRA_LOCALCRSMATRIXOPERATOR_DECL_HPP
12
13#include "Tpetra_LocalCrsMatrixOperator_fwd.hpp"
14#include "Tpetra_LocalOperator.hpp"
15#include "KokkosSparse_CrsMatrix.hpp"
16#include <memory> // std::shared_ptr
17
18namespace Tpetra {
19
29template <class MultiVectorScalar, class MatrixScalar, class Device>
30class LocalCrsMatrixOperator : public LocalOperator<MultiVectorScalar, Device> {
31 private:
32 using mv_scalar_type =
33 typename LocalOperator<MultiVectorScalar, Device>::scalar_type;
34 using matrix_scalar_type =
35 typename LocalOperator<MatrixScalar, Device>::scalar_type;
36 using array_layout =
37 typename LocalOperator<MultiVectorScalar, Device>::array_layout;
38 using device_type =
39 typename LocalOperator<MultiVectorScalar, Device>::device_type;
40 using local_ordinal_type =
42 using execution_space = typename Device::execution_space;
43
44 public:
45 using local_matrix_device_type =
46 KokkosSparse::CrsMatrix<matrix_scalar_type,
48 device_type,
49 void,
50 size_t>;
51
52 private:
53 // The type of a matrix with offset=ordinal, but otherwise the same as local_matrix_device_type
54 using local_cusparse_matrix_type =
55 KokkosSparse::CrsMatrix<matrix_scalar_type,
57 device_type,
58 void,
60 using local_graph_device_type = typename local_matrix_device_type::StaticCrsGraphType;
61
62 public:
63 using ordinal_view_type = typename local_graph_device_type::entries_type::non_const_type;
64
65 LocalCrsMatrixOperator(const std::shared_ptr<local_matrix_device_type>& A);
66 LocalCrsMatrixOperator(const std::shared_ptr<local_matrix_device_type>& A, const ordinal_view_type& A_ordinal_rowptrs);
67 ~LocalCrsMatrixOperator() override = default;
68
77 void
78 apply(Kokkos::View<const mv_scalar_type**, array_layout,
79 device_type, Kokkos::MemoryTraits<Kokkos::Unmanaged> >
80 X,
81 Kokkos::View<mv_scalar_type**, array_layout,
82 device_type, Kokkos::MemoryTraits<Kokkos::Unmanaged> >
83 Y,
84 const Teuchos::ETransp mode,
85 const mv_scalar_type alpha,
86 const mv_scalar_type beta) const override;
87
97 void
99 Kokkos::View<const mv_scalar_type**, array_layout,
100 device_type, Kokkos::MemoryTraits<Kokkos::Unmanaged> >
101 X,
102 Kokkos::View<mv_scalar_type**, array_layout,
103 device_type, Kokkos::MemoryTraits<Kokkos::Unmanaged> >
104 Y,
105 const Teuchos::ETransp mode,
106 const mv_scalar_type alpha,
107 const mv_scalar_type beta) const;
108
109 bool hasTransposeApply() const override;
110
111 const local_matrix_device_type& getLocalMatrixDevice() const;
112
113 private:
114 std::shared_ptr<local_matrix_device_type> A_;
115 local_cusparse_matrix_type A_cusparse;
116 const bool have_A_cusparse;
117};
118
119} // namespace Tpetra
120
121#endif // TPETRA_LOCALCRSMATRIXOPERATOR_DECL_HPP
Struct that holds views of the contents of a CrsMatrix.
Abstract interface for local operators (e.g., matrices and preconditioners).
void apply(Kokkos::View< const mv_scalar_type **, array_layout, device_type, Kokkos::MemoryTraits< Kokkos::Unmanaged > > X, Kokkos::View< mv_scalar_type **, array_layout, device_type, Kokkos::MemoryTraits< Kokkos::Unmanaged > > Y, const Teuchos::ETransp mode, const mv_scalar_type alpha, const mv_scalar_type beta) const override
Compute Y := beta*Y + alpha*Op(A)*X.
void applyImbalancedRows(Kokkos::View< const mv_scalar_type **, array_layout, device_type, Kokkos::MemoryTraits< Kokkos::Unmanaged > > X, Kokkos::View< mv_scalar_type **, array_layout, device_type, Kokkos::MemoryTraits< Kokkos::Unmanaged > > Y, const Teuchos::ETransp mode, const mv_scalar_type alpha, const mv_scalar_type beta) const
Compute Y := beta*Y + alpha*Op(A)*X, with a hint to use an SPMV algorithm for imbalanced rows.
Abstract interface for local operators (e.g., matrices and preconditioners).
int local_ordinal_type
Default value of LocalOrdinal template parameter.
Namespace Tpetra contains the class and methods constituting the Tpetra library.