10#ifndef MUELU_LOWPRECISIONFACTORY_DEF_HPP
11#define MUELU_LOWPRECISIONFACTORY_DEF_HPP
13#include <Xpetra_Matrix.hpp>
14#include <Xpetra_Operator.hpp>
15#include <Xpetra_TpetraOperator.hpp>
16#include <Tpetra_CrsMatrixMultiplyOp.hpp>
// GetValidParameterList (generic Scalar): builds the parameters this factory
// accepts:
//   "matrix key"    - name of the level entry to convert (default "A");
//   "R", "A", "P"   - generating factories of the matrices stored under those
//                     keys (default Teuchos::null).
// NOTE(review): the signature line and closing brace are absent from this
// listing; only the body statements are shown.
25template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
27 RCP<ParameterList> validParamList = rcp(
new ParameterList());
29 validParamList->set<std::string>(
"matrix key",
"A",
"");
// Each description names the matrix registered under its own key.
30 validParamList->set<RCP<const FactoryBase> >(
"R", Teuchos::null,
"Generating factory of the matrix R to be converted to lower precision");
31 validParamList->set<RCP<const FactoryBase> >(
"A", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
32 validParamList->set<RCP<const FactoryBase> >(
"P", Teuchos::null,
"Generating factory of the matrix P to be converted to lower precision");
34 return validParamList;
// Presumably DeclareInput (generic Scalar). The signature line and closing
// brace are missing from this listing.
// Reads the "matrix key" parameter and registers that level entry as an
// input of this factory.
37template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
39 const ParameterList& pL = GetParameterList();
40 std::string matrixKey = pL.get<std::string>(
"matrix key");
// Request the matrix named by "matrix key" from the current level.
41 Input(currentLevel, matrixKey);
// Presumably Build (generic Scalar). The signature line and closing brace
// are missing from this listing.
// This generic version performs no conversion. It emits a warning and stores
// the matrix back, unchanged, under the same key.
44template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
46 using Teuchos::ParameterList;
48 const ParameterList& pL = GetParameterList();
49 std::string matrixKey = pL.get<std::string>(
"matrix key");
// Factory timer/monitor labelled with the key being converted.
51 FactoryMonitor m(*
this,
"Converting " + matrixKey +
" to half precision", currentLevel);
53 RCP<Matrix> A = Get<RCP<Matrix> >(currentLevel, matrixKey);
// Pass-through: warn, then re-store the original matrix under the same key.
55 GetOStream(
Warnings) <<
"Matrix not converted to half precision. This only works for Tpetra and when both Scalar and HalfScalar have been instantiated." << std::endl;
56 Set(currentLevel, matrixKey, A);
59#if defined(HAVE_TPETRA_INST_DOUBLE) && defined(HAVE_TPETRA_INST_FLOAT)
// GetValidParameterList for the specialization guarded by
// HAVE_TPETRA_INST_DOUBLE && HAVE_TPETRA_INST_FLOAT (Scalar presumably
// double). It declares the same parameters as the generic version:
//   "matrix key"    - name of the level entry to convert (default "A");
//   "R", "A", "P"   - generating factories of the matrices stored under those
//                     keys (default Teuchos::null).
// NOTE(review): the signature line and closing brace are absent from this
// listing.
60template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
62 RCP<ParameterList> validParamList = rcp(
new ParameterList());
64 validParamList->set<std::string>(
"matrix key",
"A",
"");
// Each description names the matrix registered under its own key.
65 validParamList->set<RCP<const FactoryBase> >(
"R", Teuchos::null,
"Generating factory of the matrix R to be converted to lower precision");
66 validParamList->set<RCP<const FactoryBase> >(
"A", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
67 validParamList->set<RCP<const FactoryBase> >(
"P", Teuchos::null,
"Generating factory of the matrix P to be converted to lower precision");
69 return validParamList;
// Presumably DeclareInput for the double/float-guarded specialization. The
// signature line and closing brace are missing from this listing.
// Reads the "matrix key" parameter and registers that level entry as an
// input of this factory.
72template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
74 const ParameterList& pL = GetParameterList();
75 std::string matrixKey = pL.get<std::string>(
"matrix key");
// Request the matrix named by "matrix key" from the current level.
76 Input(currentLevel, matrixKey);
// Presumably Build for the specialization guarded by
// HAVE_TPETRA_INST_DOUBLE && HAVE_TPETRA_INST_FLOAT (Scalar presumably
// double).
// For a Tpetra matrix, it converts the entry under "matrix key" to HalfScalar
// and stores it back as an Operator. Otherwise it warns and re-stores the
// original matrix.
// NOTE(review): this listing omits the signature line, source lines 98-100
// (presumably closing the Tpetra branch and opening the else branch), and
// the closing braces. Confirm against the real file.
79template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
81 using Teuchos::ParameterList;
// Lower-precision counterpart of Scalar, as defined by Teuchos::ScalarTraits.
82 using HalfScalar =
typename Teuchos::ScalarTraits<Scalar>::halfPrecision;
84 const ParameterList& pL = GetParameterList();
85 std::string matrixKey = pL.get<std::string>(
"matrix key");
// Factory timer/monitor labelled with the key being converted.
87 FactoryMonitor m(*
this,
"Converting " + matrixKey +
" to half precision", currentLevel);
89 RCP<Matrix> A = Get<RCP<Matrix> >(currentLevel, matrixKey);
// Conversion is only attempted for the Tpetra backend.
// NOTE(review): if Scalar is double in this specialization, the is_same test
// is always true. Confirm.
91 if ((A->getRowMap()->lib() == Xpetra::UseTpetra) && std::is_same<Scalar, double>::value) {
92 auto tpA = toTpetra(A);
// HalfScalar copy of the Tpetra matrix, produced by convert<HalfScalar>().
93 auto tpLowA = tpA->template convert<HalfScalar>();
// Wrap the low-precision matrix in an operator templated on both Scalar and
// HalfScalar. It presumably applies the HalfScalar matrix to Scalar vectors;
// confirm against the Tpetra::CrsMatrixMultiplyOp documentation.
94 auto tpLowOpA = rcp(
new Tpetra::CrsMatrixMultiplyOp<Scalar, HalfScalar, LocalOrdinal, GlobalOrdinal, Node>(tpLowA));
// Expose it through Xpetra as an Operator and store it under the same key,
// so the level now holds an Operator rather than the original Matrix.
95 auto xpTpLowOpA = rcp(
new TpetraOperator(tpLowOpA));
96 auto xpLowOpA = rcp_dynamic_cast<Operator>(xpTpLowOpA);
97 Set(currentLevel, matrixKey, xpLowOpA);
// Fallback (presumably the else branch; see the note above): warn and
// re-store the original matrix unchanged.
101 GetOStream(
Warnings) <<
"Matrix not converted to half precision. This only works for Tpetra and when both Scalar and HalfScalar have been instantiated." << std::endl;
102 Set(currentLevel, matrixKey, A);
106#if defined(HAVE_TPETRA_INST_COMPLEX_DOUBLE) && defined(HAVE_TPETRA_INST_COMPLEX_FLOAT)
// GetValidParameterList for the specialization guarded by
// HAVE_TPETRA_INST_COMPLEX_DOUBLE && HAVE_TPETRA_INST_COMPLEX_FLOAT (Scalar
// presumably std::complex<double>). It declares the same parameters as the
// generic version:
//   "matrix key"    - name of the level entry to convert (default "A");
//   "R", "A", "P"   - generating factories of the matrices stored under those
//                     keys (default Teuchos::null).
// NOTE(review): the signature line and closing brace are absent from this
// listing.
107template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
109 RCP<ParameterList> validParamList = rcp(
new ParameterList());
111 validParamList->set<std::string>(
"matrix key",
"A",
"");
// Each description names the matrix registered under its own key.
112 validParamList->set<RCP<const FactoryBase> >(
"R", Teuchos::null,
"Generating factory of the matrix R to be converted to lower precision");
113 validParamList->set<RCP<const FactoryBase> >(
"A", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
114 validParamList->set<RCP<const FactoryBase> >(
"P", Teuchos::null,
"Generating factory of the matrix P to be converted to lower precision");
116 return validParamList;
// Presumably DeclareInput for the complex-guarded specialization. The
// signature line and closing brace are missing from this listing.
// Reads the "matrix key" parameter and registers that level entry as an
// input of this factory.
119template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
121 const ParameterList& pL = GetParameterList();
122 std::string matrixKey = pL.get<std::string>(
"matrix key");
// Request the matrix named by "matrix key" from the current level.
123 Input(currentLevel, matrixKey);
// Presumably Build for the specialization guarded by
// HAVE_TPETRA_INST_COMPLEX_DOUBLE && HAVE_TPETRA_INST_COMPLEX_FLOAT (Scalar
// presumably std::complex<double>).
// For a Tpetra matrix, it converts the entry under "matrix key" to HalfScalar
// and stores it back as an Operator. Otherwise it warns and re-stores the
// original matrix.
// NOTE(review): this listing omits the signature line, source lines 145-147
// (presumably closing the Tpetra branch and opening the else branch), and
// the closing braces. Confirm against the real file.
126template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
128 using Teuchos::ParameterList;
// Lower-precision counterpart of Scalar, as defined by Teuchos::ScalarTraits.
129 using HalfScalar =
typename Teuchos::ScalarTraits<Scalar>::halfPrecision;
131 const ParameterList& pL = GetParameterList();
132 std::string matrixKey = pL.get<std::string>(
"matrix key");
// Factory timer/monitor labelled with the key being converted.
134 FactoryMonitor m(*
this,
"Converting " + matrixKey +
" to half precision", currentLevel);
136 RCP<Matrix> A = Get<RCP<Matrix> >(currentLevel, matrixKey);
// Conversion is only attempted for the Tpetra backend.
// NOTE(review): if Scalar is std::complex<double> in this specialization,
// the is_same test is always true. Confirm.
138 if ((A->getRowMap()->lib() == Xpetra::UseTpetra) && std::is_same<Scalar, std::complex<double> >::value) {
139 auto tpA = toTpetra(A);
// HalfScalar copy of the Tpetra matrix, produced by convert<HalfScalar>().
140 auto tpLowA = tpA->template convert<HalfScalar>();
// Wrap the low-precision matrix in an operator templated on both Scalar and
// HalfScalar. It presumably applies the HalfScalar matrix to Scalar vectors;
// confirm against the Tpetra::CrsMatrixMultiplyOp documentation.
141 auto tpLowOpA = rcp(
new Tpetra::CrsMatrixMultiplyOp<Scalar, HalfScalar, LocalOrdinal, GlobalOrdinal, Node>(tpLowA));
// Expose it through Xpetra as an Operator and store it under the same key,
// so the level now holds an Operator rather than the original Matrix.
142 auto xpTpLowOpA = rcp(
new TpetraOperator(tpLowOpA));
143 auto xpLowOpA = rcp_dynamic_cast<Operator>(xpTpLowOpA);
144 Set(currentLevel, matrixKey, xpLowOpA);
// Fallback (presumably the else branch; see the note above): warn and
// re-store the original matrix unchanged.
148 GetOStream(
Warnings) <<
"Matrix not converted to half precision. This only works for Tpetra and when both Scalar and HalfScalar have been instantiated." << std::endl;
149 Set(currentLevel, matrixKey, A);
MueLu::DefaultLocalOrdinal LocalOrdinal
MueLu::DefaultGlobalOrdinal GlobalOrdinal
Timer to be used in factories. Similar to Monitor but with additional timers.
Class that holds all level-specific information.
RCP< const ParameterList > GetValidParameterList() const
Return a const parameter list of valid parameters that setParameterList() will accept.
void DeclareInput(Level &currentLevel) const
Input.
void Build(Level &currentLevel) const
Build method.
Namespace for MueLu classes and methods.
@ Warnings
Print all warning messages.