MueLu Version of the Day
Loading...
Searching...
No Matches
MueLu_LowPrecisionFactory_def.hpp
Go to the documentation of this file.
1// @HEADER
2// *****************************************************************************
3// MueLu: A package for multigrid based preconditioning
4//
5// Copyright 2012 NTESS and the MueLu contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9
10#ifndef MUELU_LOWPRECISIONFACTORY_DEF_HPP
11#define MUELU_LOWPRECISIONFACTORY_DEF_HPP
12
13#include <Xpetra_Matrix.hpp>
14#include <Xpetra_Operator.hpp>
15#include <Xpetra_TpetraOperator.hpp>
16#include <Tpetra_CrsMatrixMultiplyOp.hpp>
17
19
20#include "MueLu_Level.hpp"
21#include "MueLu_Monitor.hpp"
22
23namespace MueLu {
24
// Generic (primary-template) parameter list for LowPrecisionFactory.
// Declares "matrix key" (default "A"), the level-variable name that the
// factory reads and overwrites. It also declares generating-factory entries
// for "R", "A" and "P", all defaulting to Teuchos::null.
// NOTE(review): the signature line (original line 26, presumably
// `RCP<const ParameterList> ...::GetValidParameterList() const {`) is missing
// from this extracted copy.
// NOTE(review): the "R" and "P" descriptions say "matrix A". This looks like a
// copy-paste slip; they presumably should name R and P respectively.
// NOTE(review): "matrix key" has an empty description string.
25template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
27 RCP<ParameterList> validParamList = rcp(new ParameterList());
28
29 validParamList->set<std::string>("matrix key", "A", "");
30 validParamList->set<RCP<const FactoryBase> >("R", Teuchos::null, "Generating factory of the matrix A to be converted to lower precision");
31 validParamList->set<RCP<const FactoryBase> >("A", Teuchos::null, "Generating factory of the matrix A to be converted to lower precision");
32 validParamList->set<RCP<const FactoryBase> >("P", Teuchos::null, "Generating factory of the matrix A to be converted to lower precision");
33
34 return validParamList;
35}
36
// Generic (primary-template) input declaration.
// Reads the "matrix key" parameter and declares that level variable as this
// factory's single input.
// NOTE(review): the signature line (original line 38, presumably
// `void ...::DeclareInput(Level& currentLevel) const {`) is missing from this
// extracted copy.
37template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
39 const ParameterList& pL = GetParameterList();
40 std::string matrixKey = pL.get<std::string>("matrix key");
41 Input(currentLevel, matrixKey);
42}
43
// Generic (primary-template) Build: the fallback for Scalar types that have no
// specialization in this file (the specializations below cover double and
// std::complex<double>, each under its own #if guard).
// No conversion is attempted here. The matrix named by "matrix key" is
// fetched, a warning is printed, and the same matrix is stored back under the
// same key unchanged.
// NOTE(review): the signature line (original line 45, presumably
// `void ...::Build(Level& currentLevel) const {`) is missing from this
// extracted copy.
44template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
46 using Teuchos::ParameterList;
47
48 const ParameterList& pL = GetParameterList();
49 std::string matrixKey = pL.get<std::string>("matrix key");
50
51 FactoryMonitor m(*this, "Converting " + matrixKey + " to half precision", currentLevel);
52
53 RCP<Matrix> A = Get<RCP<Matrix> >(currentLevel, matrixKey);
54
55 GetOStream(Warnings) << "Matrix not converted to half precision. This only works for Tpetra and when both Scalar and HalfScalar have been instantiated." << std::endl;
56 Set(currentLevel, matrixKey, A);
57}
58
59#if defined(HAVE_TPETRA_INST_DOUBLE) && defined(HAVE_TPETRA_INST_FLOAT)
// Parameter list for the Scalar = double specialization. It is identical to
// the generic version: "matrix key" (default "A") plus generating-factory
// entries for "R", "A" and "P", all defaulting to Teuchos::null.
// NOTE(review): the signature line (original line 61, presumably
// `RCP<const ParameterList> LowPrecisionFactory<double, ...>::GetValidParameterList() const {`)
// is missing from this extracted copy.
// NOTE(review): the "R" and "P" descriptions say "matrix A". This looks like a
// copy-paste slip; they presumably should name R and P respectively.
60template <class LocalOrdinal, class GlobalOrdinal, class Node>
62 RCP<ParameterList> validParamList = rcp(new ParameterList());
63
64 validParamList->set<std::string>("matrix key", "A", "");
65 validParamList->set<RCP<const FactoryBase> >("R", Teuchos::null, "Generating factory of the matrix A to be converted to lower precision");
66 validParamList->set<RCP<const FactoryBase> >("A", Teuchos::null, "Generating factory of the matrix A to be converted to lower precision");
67 validParamList->set<RCP<const FactoryBase> >("P", Teuchos::null, "Generating factory of the matrix A to be converted to lower precision");
68
69 return validParamList;
70}
71
// Input declaration for the Scalar = double specialization.
// Reads the "matrix key" parameter and declares that level variable as this
// factory's single input.
// NOTE(review): the signature line (original line 73) is missing from this
// extracted copy.
72template <class LocalOrdinal, class GlobalOrdinal, class Node>
74 const ParameterList& pL = GetParameterList();
75 std::string matrixKey = pL.get<std::string>("matrix key");
76 Input(currentLevel, matrixKey);
77}
78
// Build for the Scalar = double specialization (compiled only when both the
// double and float Tpetra instantiations exist).
// For a Tpetra-backed matrix, it builds a HalfScalar copy of the matrix,
// wraps it as an Xpetra Operator, and stores that Operator under the same key.
// Otherwise it warns and stores the original matrix back unchanged.
// NOTE(review): the signature line (original line 80) is missing from this
// extracted copy.
79template <class LocalOrdinal, class GlobalOrdinal, class Node>
81 using Teuchos::ParameterList;
// HalfScalar is ScalarTraits<Scalar>::halfPrecision (presumably float here).
82 using HalfScalar = typename Teuchos::ScalarTraits<Scalar>::halfPrecision;
83
84 const ParameterList& pL = GetParameterList();
85 std::string matrixKey = pL.get<std::string>("matrix key");
86
87 FactoryMonitor m(*this, "Converting " + matrixKey + " to half precision", currentLevel);
88
89 RCP<Matrix> A = Get<RCP<Matrix> >(currentLevel, matrixKey);
90
// Only Tpetra-backed matrices are converted.
// NOTE(review): inside this specialization, is_same<Scalar, double> is
// presumably always true; the effective test is the lib() check.
91 if ((A->getRowMap()->lib() == Xpetra::UseTpetra) && std::is_same<Scalar, double>::value) {
92 auto tpA = toTpetra(A);
93 auto tpLowA = tpA->template convert<HalfScalar>();
// Presumably applies the HalfScalar matrix to Scalar-typed vectors.
94 auto tpLowOpA = rcp(new Tpetra::CrsMatrixMultiplyOp<Scalar, HalfScalar, LocalOrdinal, GlobalOrdinal, Node>(tpLowA));
95 auto xpTpLowOpA = rcp(new TpetraOperator(tpLowOpA));
96 auto xpLowOpA = rcp_dynamic_cast<Operator>(xpTpLowOpA);
// Replaces the level entry: after this, matrixKey holds an Operator, not the Matrix.
97 Set(currentLevel, matrixKey, xpLowOpA);
98 return;
99 }
100
// Fallback: the matrix is not Tpetra-backed, so store it back unchanged.
101 GetOStream(Warnings) << "Matrix not converted to half precision. This only works for Tpetra and when both Scalar and HalfScalar have been instantiated." << std::endl;
102 Set(currentLevel, matrixKey, A);
103}
104#endif
105
106#if defined(HAVE_TPETRA_INST_COMPLEX_DOUBLE) && defined(HAVE_TPETRA_INST_COMPLEX_FLOAT)
107template <class LocalOrdinal, class GlobalOrdinal, class Node>
108RCP<const ParameterList> LowPrecisionFactory<std::complex<double>, LocalOrdinal, GlobalOrdinal, Node>::GetValidParameterList() const {
109 RCP<ParameterList> validParamList = rcp(new ParameterList());
110
111 validParamList->set<std::string>("matrix key", "A", "");
112 validParamList->set<RCP<const FactoryBase> >("R", Teuchos::null, "Generating factory of the matrix A to be converted to lower precision");
113 validParamList->set<RCP<const FactoryBase> >("A", Teuchos::null, "Generating factory of the matrix A to be converted to lower precision");
114 validParamList->set<RCP<const FactoryBase> >("P", Teuchos::null, "Generating factory of the matrix A to be converted to lower precision");
115
116 return validParamList;
117}
118
119template <class LocalOrdinal, class GlobalOrdinal, class Node>
120void LowPrecisionFactory<std::complex<double>, LocalOrdinal, GlobalOrdinal, Node>::DeclareInput(Level& currentLevel) const {
121 const ParameterList& pL = GetParameterList();
122 std::string matrixKey = pL.get<std::string>("matrix key");
123 Input(currentLevel, matrixKey);
124}
125
126template <class LocalOrdinal, class GlobalOrdinal, class Node>
127void LowPrecisionFactory<std::complex<double>, LocalOrdinal, GlobalOrdinal, Node>::Build(Level& currentLevel) const {
128 using Teuchos::ParameterList;
129 using HalfScalar = typename Teuchos::ScalarTraits<Scalar>::halfPrecision;
130
131 const ParameterList& pL = GetParameterList();
132 std::string matrixKey = pL.get<std::string>("matrix key");
133
134 FactoryMonitor m(*this, "Converting " + matrixKey + " to half precision", currentLevel);
135
136 RCP<Matrix> A = Get<RCP<Matrix> >(currentLevel, matrixKey);
137
138 if ((A->getRowMap()->lib() == Xpetra::UseTpetra) && std::is_same<Scalar, std::complex<double> >::value) {
139 auto tpA = toTpetra(A);
140 auto tpLowA = tpA->template convert<HalfScalar>();
141 auto tpLowOpA = rcp(new Tpetra::CrsMatrixMultiplyOp<Scalar, HalfScalar, LocalOrdinal, GlobalOrdinal, Node>(tpLowA));
142 auto xpTpLowOpA = rcp(new TpetraOperator(tpLowOpA));
143 auto xpLowOpA = rcp_dynamic_cast<Operator>(xpTpLowOpA);
144 Set(currentLevel, matrixKey, xpLowOpA);
145 return;
146 }
147
148 GetOStream(Warnings) << "Matrix not converted to half precision. This only works for Tpetra and when both Scalar and HalfScalar have been instantiated." << std::endl;
149 Set(currentLevel, matrixKey, A);
150}
151#endif
152
153} // namespace MueLu
154
155#endif // MUELU_LOWPRECISIONFACTORY_DEF_HPP
MueLu::DefaultLocalOrdinal LocalOrdinal
MueLu::DefaultGlobalOrdinal GlobalOrdinal
MueLu::DefaultNode Node
Timer to be used in factories. Similar to Monitor but with additional timers.
Class that holds all level-specific information.
RCP< const ParameterList > GetValidParameterList() const
Return a const parameter list of valid parameters that setParameterList() will accept.
void DeclareInput(Level &currentLevel) const
Input.
void Build(Level &currentLevel) const
Build method.
Namespace for MueLu classes and methods.
@ Warnings
Print all warning messages.