Thyra Version of the Day
Loading...
Searching...
No Matches
Thyra_TpetraLinearOp_def.hpp
1// @HEADER
2// *****************************************************************************
3// Thyra: Interfaces and Support for Abstract Numerical Algorithms
4//
5// Copyright 2004 NTESS and the Thyra contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9
10#ifndef THYRA_TPETRA_LINEAR_OP_HPP
11#define THYRA_TPETRA_LINEAR_OP_HPP
12
13#include "Thyra_TpetraLinearOp_decl.hpp"
14#include "Kokkos_Core.hpp"
15#include "Thyra_TpetraVectorSpace.hpp"
16#include "Teuchos_ScalarTraits.hpp"
17#include "Teuchos_TypeNameTraits.hpp"
18
19#include "Tpetra_CrsMatrix.hpp"
20
21#ifdef HAVE_THYRA_TPETRA_EPETRA
23#endif
24
25namespace Thyra {
26
27
28#ifdef HAVE_THYRA_TPETRA_EPETRA
29
30// Utilities
31
32
34 template<class Scalar, class LocalOrdinal, class GlobalOrdinal>
35class GetTpetraEpetraRowMatrixWrapper {
36public:
37 template<class TpetraMatrixType>
38 static
39 RCP<Tpetra::EpetraRowMatrix<TpetraMatrixType> >
40 get(const RCP<TpetraMatrixType> &tpetraMatrix)
41 {
42 return Teuchos::null;
43 }
44};
45
46
47// NOTE: We could support other ordinal types, but we have to
48// specialize the EpetraRowMatrix
// Specialization for Scalar=double, LocalOrdinal=int, GlobalOrdinal=int:
// get() heap-allocates a Tpetra::EpetraRowMatrix that wraps tpetraMatrix and
// returns it in a Teuchos::rcp (which takes ownership of the new object).
// The communicator argument is derived from the matrix row map's comm via
// convertTpetraToThyraComm(); part of that argument expression is not visible
// in this listing (presumably an Epetra_Comm conversion — confirm).
49template<>
50class GetTpetraEpetraRowMatrixWrapper<double, int, int> {
51public:
52 template<class TpetraMatrixType>
53 static
54 RCP<Tpetra::EpetraRowMatrix<TpetraMatrixType> >
55 get(const RCP<TpetraMatrixType> &tpetraMatrix)
56 {
57 return Teuchos::rcp(
58 new Tpetra::EpetraRowMatrix<TpetraMatrixType>(tpetraMatrix,
60 *convertTpetraToThyraComm(tpetraMatrix->getRowMap()->getComm())
61 )
62 )
63 );
64 }
65};
66
67
68#endif // HAVE_THYRA_TPETRA_EPETRA
69
70
// Map Thyra's CONJ mode (conjugate without transpose) onto a Teuchos::ETransp.
// Throws Exceptions::OpNotSupported, with a message naming the scalar type,
// when the check at the top of the body fires. That condition line is not
// visible here; per the message it is presumably the complex-scalar case.
// Otherwise it returns Teuchos::NO_TRANS, because CONJ is the same as NOTRANS
// for non-complex scalars.
71template <class Scalar>
72inline
74convertConjNoTransToTeuchosTransMode()
75{
78 Exceptions::OpNotSupported,
79 "For complex scalars such as " + Teuchos::TypeNameTraits<Scalar>::name() +
80 ", Tpetra does not support conjugation without transposition."
81 );
82 return Teuchos::NO_TRANS; // For non-complex scalars, CONJ is equivalent to NOTRANS.
83}
84
85
// Translate a Thyra::EOpTransp into the Teuchos::ETransp passed to
// Tpetra::Operator::apply(). NOTRANS, TRANS and CONJTRANS map one-to-one.
// CONJ is delegated to convertConjNoTransToTeuchosTransMode<Scalar>(), which
// may throw Exceptions::OpNotSupported.
86template <class Scalar>
87inline
89convertToTeuchosTransMode(const Thyra::EOpTransp transp)
90{
91 switch (transp) {
92 case NOTRANS: return Teuchos::NO_TRANS;
93 case CONJ: return convertConjNoTransToTeuchosTransMode<Scalar>();
94 case TRANS: return Teuchos::TRANS;
95 case CONJTRANS: return Teuchos::CONJ_TRANS;
96 }
97
98 // Should not escape the switch
100}
101
102
103// Constructors/initializers
104
105
// Constructor definition. Its signature and body lines are not visible in this
// listing; presumably this is the default constructor TpetraLinearOp(),
// documented as "Construct to uninitialized".
106template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
109
110
// Initialize from a non-const Tpetra::Operator and the Thyra range/domain
// spaces it maps between. All work is forwarded to initializeImpl().
111template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
113 const RCP<const VectorSpaceBase<Scalar> > &rangeSpace,
114 const RCP<const VectorSpaceBase<Scalar> > &domainSpace,
115 const RCP<Tpetra::Operator<Scalar,LocalOrdinal,GlobalOrdinal,Node> > &tpetraOperator
116 )
117{
118 initializeImpl(rangeSpace, domainSpace, tpetraOperator);
119}
120
121
// Initialize from a const Tpetra::Operator and the Thyra range/domain spaces
// it maps between. All work is forwarded to initializeImpl().
122template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
124 const RCP<const VectorSpaceBase<Scalar> > &rangeSpace,
125 const RCP<const VectorSpaceBase<Scalar> > &domainSpace,
126 const RCP<const Tpetra::Operator<Scalar,LocalOrdinal,GlobalOrdinal,Node> > &tpetraOperator
127 )
128{
129 initializeImpl(rangeSpace, domainSpace, tpetraOperator);
130}
131
132
// Return the embedded Tpetra::Operator as non-const via
// tpetraOperator_.getNonconstObj().
// NOTE(review): tpetraOperator_ is presumably a const/non-const object
// container, and getNonconstObj() is expected to fail if the operator was set
// through constInitialize() — confirm.
133template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
136{
137 return tpetraOperator_.getNonconstObj();
138}
139
140
// Member definition whose signature and body are not visible in this listing.
// It presumably is getConstTpetraOperator(), which returns the embedded
// Tpetra::Operator as const — confirm.
141template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
147
148
149// Public Overridden functions from LinearOpBase
150
151
// Member definition whose signature and body are not visible in this listing.
// It presumably is range(), returning the rangeSpace_ stored by
// initializeImpl() — confirm.
152template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
158
159
// Return the Thyra domain space stored in domainSpace_ by initializeImpl().
160template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
163{
164 return domainSpace_;
165}
166
167
168// Overridden from EpetraLinearOpBase
169
170
171#ifdef HAVE_THYRA_TPETRA_EPETRA
172
173
// Non-const Epetra_Operator view accessor (EpetraLinearOpBase interface).
// Its name line and its one body statement are not visible in this listing,
// so its behavior cannot be stated here. Presumably it is the non-const
// counterpart of getEpetraOpView() — confirm.
174template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
176 const Ptr<RCP<Epetra_Operator> > &epetraOp,
177 const Ptr<EOpTransp> &epetraOpTransp,
178 const Ptr<EApplyEpetraOpAs> &epetraOpApplyAs,
179 const Ptr<EAdjointEpetraOp> &epetraOpAdjointSupport
180 )
181{
183}
184
185
// Return a const Epetra_Operator view of the wrapped Tpetra operator.
//
// Operator set:
//   - The Epetra wrapper is built on the first call by
//     GetTpetraEpetraRowMatrixWrapper and cached in epetraOp_; later calls
//     reuse it.
//   - In this file only the <double,int,int> specialization returns a non-null
//     wrapper; the primary template returns null.
//   - rcp_dynamic_cast is called with throw_on_fail = true, so this throws if
//     the operator is not a Tpetra::RowMatrix.
//   - *epetraOp receives the (possibly null) wrapper.
//   - The op is reported as NOTRANS and applied via EPETRA_OP_APPLY_APPLY.
//   - Adjoint support is derived from tpetraOperator_->hasTransposeApply().
// No operator set: only *epetraOp is written (to null); the other outputs are
// left untouched.
186template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
187void TpetraLinearOp<Scalar,LocalOrdinal,GlobalOrdinal,Node>::getEpetraOpView(
188 const Ptr<RCP<const Epetra_Operator> > &epetraOp,
189 const Ptr<EOpTransp> &epetraOpTransp,
190 const Ptr<EApplyEpetraOpAs> &epetraOpApplyAs,
191 const Ptr<EAdjointEpetraOp> &epetraOpAdjointSupport
192 ) const
193{
194 using Teuchos::rcp_dynamic_cast;
195 typedef Tpetra::RowMatrix<Scalar,LocalOrdinal,GlobalOrdinal,Node> TpetraRowMatrix_t;
196 if (nonnull(tpetraOperator_)) {
// Build the wrapper lazily and cache it so repeated calls return the same object.
197 if (is_null(epetraOp_)) {
198 epetraOp_ = GetTpetraEpetraRowMatrixWrapper<Scalar,LocalOrdinal,GlobalOrdinal>::get(
199 rcp_dynamic_cast<const TpetraRowMatrix_t>(tpetraOperator_.getConstObj(), true));
200 }
201 *epetraOp = epetraOp_;
202 *epetraOpTransp = NOTRANS;
203 *epetraOpApplyAs = EPETRA_OP_APPLY_APPLY;
// The rest of this expression is not visible here; presumably it selects
// EPETRA_OP_ADJOINT_SUPPORTED / EPETRA_OP_ADJOINT_UNSUPPORTED.
204 *epetraOpAdjointSupport = ( tpetraOperator_->hasTransposeApply()
206 }
207 else {
208 *epetraOp = Teuchos::null;
209 }
210}
211
212
213#endif // HAVE_THYRA_TPETRA_EPETRA
214
215
216// Protected Overridden functions from LinearOpBase
217
218
// Report whether applyImpl() can handle the requested transpose mode M_trans.
//   - No operator set: nothing is supported.
//   - NOTRANS: always supported.
//   - CONJ: decided inside its branch below. The deciding statement is not
//     visible in this listing; per the branch's comments, CONJ is supported
//     for non-complex scalars only.
//   - TRANS / CONJTRANS fall through to the Tpetra operator's
//     hasTransposeApply().
219template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
221 Thyra::EOpTransp M_trans) const
222{
223 if (is_null(tpetraOperator_))
224 return false;
225
226 if (M_trans == NOTRANS)
227 return true;
228
229 if (M_trans == CONJ) {
230 // For non-complex scalars, CONJ is always supported since it is equivalent to NO_TRANS.
231 // For complex scalars, Tpetra does not support conjugation without transposition.
233 }
234
235 return tpetraOperator_->hasTransposeApply();
236}
237
238
// Apply the wrapped Tpetra operator to X_in, accumulating into Y_inout.
// Following Tpetra::Operator::apply()'s contract, this computes
// Y_inout = beta*Y_inout + alpha*op(M)*X_in, where op() is chosen by M_trans.
//
// Steps:
//   1. X_in and Y_inout are turned into Tpetra::MultiVector objects through
//      ConverterT's getConstTpetraMultiVector() / getTpetraMultiVector().
//   2. M_trans is translated by convertToTeuchosTransMode<Scalar>(), which may
//      throw Exceptions::OpNotSupported for CONJ.
//   3. The call is forwarded to Tpetra::Operator::apply().
//
// NOTE(review): tpetraOperator_ is dereferenced without a null check. Callers
// are presumably expected to consult opSupportedImpl() first — confirm.
239template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
241 const Thyra::EOpTransp M_trans,
244 const Scalar alpha,
245 const Scalar beta
246 ) const
247{
248 using Teuchos::rcpFromRef;
249 using Teuchos::rcpFromPtr;
251 ConverterT;
252 typedef Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node>
253 TpetraMultiVector_t;
254
255 // Get Tpetra::MultiVector objects for X and Y
256
258 ConverterT::getConstTpetraMultiVector(rcpFromRef(X_in));
259
260 const RCP<TpetraMultiVector_t> tY =
261 ConverterT::getTpetraMultiVector(rcpFromPtr(Y_inout));
262
263 const Teuchos::ETransp tTransp = convertToTeuchosTransMode<Scalar>(M_trans);
264
265 // Apply the operator
266
267 tpetraOperator_->apply(*tX, *tY, tTransp, alpha, beta);
268 // CAG: Commented out since the purpose seems unclear.
269 // Tpetra apply should do all the necessary fencing.
270 // Kokkos::fence();
271}
272
273// Protected member functions overridden from ScaledLinearOpBase
274
275
// Member definition whose signature and body are not visible in this listing.
// It presumably is supportsScaleLeftImpl() (ScaledLinearOpBase) — confirm.
276template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
281
282
// Member definition whose signature and body are not visible in this listing.
// It presumably is supportsScaleRightImpl() (ScaledLinearOpBase) — confirm.
283template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
288
289
// Left-scale (row-scale) the wrapped matrix in place by row_scaling_in, via
// rowMatrix->leftScale(). Kokkos::fence() is then called so the scaling has
// completed before returning.
// The statements that derive row_scaling from row_scaling_in and obtain
// rowMatrix from the wrapped operator are not visible in this listing.
290template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
291void
293scaleLeftImpl(const VectorBase<Scalar> &row_scaling_in)
294{
295 using Teuchos::rcpFromRef;
296
299
302
303 rowMatrix->leftScale(*row_scaling);
304 Kokkos::fence();
305}
306
307
// Right-scale (column-scale) the wrapped matrix in place by col_scaling_in,
// via rowMatrix->rightScale(). Kokkos::fence() is then called so the scaling
// has completed before returning.
// The statements that derive col_scaling from col_scaling_in and obtain
// rowMatrix from the wrapped operator are not visible in this listing.
308template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
309void
311scaleRightImpl(const VectorBase<Scalar> &col_scaling_in)
312{
313 using Teuchos::rcpFromRef;
314
317
320
321 rowMatrix->rightScale(*col_scaling);
322 Kokkos::fence();
323}
324
325// Protected member functions overridden from RowStatLinearOpBase
326
327
// Report whether getRowStatImpl() can compute the requested row statistic.
// Returns true only when an operator is set and rowStat is ROW_STAT_ROW_SUM or
// ROW_STAT_INV_ROW_SUM.
// NOTE(review): this checks only that the operator is non-null. getRowStatImpl()
// additionally relies on a concrete Tpetra::CrsMatrix whose row map matches its
// range map, so "supported" here can still fail there.
328template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
331 const RowStatLinearOpBaseUtils::ERowStat rowStat) const
332{
333 if (is_null(tpetraOperator_))
334 return false;
335
336 switch (rowStat) {
337 case RowStatLinearOpBaseUtils::ROW_STAT_INV_ROW_SUM:
338 case RowStatLinearOpBaseUtils::ROW_STAT_ROW_SUM:
339 return true;
340 default:
341 return false;
342 }
343
344}
345
346
// Fill rowStatVec_in with a per-row statistic of the wrapped matrix.
//
// ROW_STAT_ROW_SUM: each local row gets the sum of the magnitudes of its
// stored entries.
//
// ROW_STAT_INV_ROW_SUM: each row gets the reciprocal of that sum.
//   - A row whose sum is exactly zero throws std::runtime_error.
//   - A nonzero sum below STM::sfmin() is replaced by 1/sfmin, so the result
//     stays finite.
//
// Any other rowStat throws std::runtime_error. After a successful row-sum
// computation, Kokkos::fence() is called before returning.
//
// The lines that obtain tRowSumVec and tCrsMatrix are not visible in this
// listing.
347template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
349 const RowStatLinearOpBaseUtils::ERowStat rowStat,
350 const Ptr<VectorBase<Scalar> > &rowStatVec_in
351 ) const
352{
353 typedef Tpetra::Vector<Scalar,LocalOrdinal,GlobalOrdinal,Node>
354 TpetraVector_t;
356 typedef typename STS::magnitudeType MT;
357 typedef Teuchos::ScalarTraits<MT> STM;
358
359 if ( (rowStat == RowStatLinearOpBaseUtils::ROW_STAT_INV_ROW_SUM) ||
360 (rowStat == RowStatLinearOpBaseUtils::ROW_STAT_ROW_SUM) ) {
361
362 TEUCHOS_ASSERT(nonnull(tpetraOperator_));
363 TEUCHOS_ASSERT(nonnull(rowStatVec_in));
364
365 // Currently we only support the case of row sums for a concrete
366 // Tpetra::CrsMatrix where (1) the entire row is stored on a
367 // single process and (2) that the domain map, the range map and
368 // the row map are the SAME. These checks enforce that. Later on
369 // we hope to add complete support for any mapping to the concrete
370 // tpetra matrix types.
// NOTE(review): of the map checks below, only row-map == range-map and
// row-map == tRowSumVec map are active. The domain-map check is commented out
// (see the EGP note), so condition (2) above is only partially enforced.
371
372 const RCP<TpetraVector_t> tRowSumVec =
374
377
378 // EGP: The following assert fails when row sum scaling is applied to blocked Tpetra operators, but without the assert, the correct row sum scaling is obtained.
379 // Furthermore, no valgrind memory errors occur in this case when the assert is removed.
380 //TEUCHOS_ASSERT(tCrsMatrix->getRowMap()->isSameAs(*tCrsMatrix->getDomainMap()));
381 TEUCHOS_ASSERT(tCrsMatrix->getRowMap()->isSameAs(*tCrsMatrix->getRangeMap()));
382 TEUCHOS_ASSERT(tCrsMatrix->getRowMap()->isSameAs(*tRowSumVec->getMap()));
383
384 size_t numMyRows = tCrsMatrix->getLocalNumRows();
385
386 using crs_t = Tpetra::CrsMatrix<Scalar,LocalOrdinal,GlobalOrdinal,Node>;
387 typename crs_t::local_inds_host_view_type indices;
388 typename crs_t::values_host_view_type values;
389
390
// Per local row: sum |a_ij| over the row's stored entries (host row view).
391 for (size_t row=0; row < numMyRows; ++row) {
392 MT sum = STM::zero ();
393 tCrsMatrix->getLocalRowView (row, indices, values);
394
// NOTE(review): values.size() is narrowed to int here; a size_t index would
// avoid the cast.
395 for (int col = 0; col < (int) values.size(); ++col) {
396 sum += STS::magnitude (values[col]);
397 }
398
399 if (rowStat == RowStatLinearOpBaseUtils::ROW_STAT_INV_ROW_SUM) {
// Guard the reciprocal: a zero sum is an error. A tiny nonzero sum is clamped
// to sfmin so that 1/sum stays representable.
400 if (sum < STM::sfmin ()) {
401 TEUCHOS_TEST_FOR_EXCEPTION(sum == STM::zero (), std::runtime_error,
402 "Error - Thyra::TpetraLinearOp::getRowStatImpl() - Inverse row sum "
403 << "requested for a matrix where one of the rows has a zero row sum!");
404 sum = STM::one () / STM::sfmin ();
405 }
406 else {
407 sum = STM::one () / sum;
408 }
409 }
410
411 tRowSumVec->replaceLocalValue (row, Scalar (sum));
412 }
413
414 }
415 else {
416 TEUCHOS_TEST_FOR_EXCEPTION(true,std::runtime_error,
417 "Error - Thyra::TpetraLinearOp::getRowStatImpl() - Column sum support not implemented!");
418 }
419 Kokkos::fence();
420}
421
422
423// private
424
425
// Shared implementation of initialize() and constInitialize(). TpetraOperator_t
// is the (const or non-const) Tpetra::Operator type those callers pass.
// Stores the range space, domain space and operator. The null checks run only
// when THYRA_DEBUG is defined.
// NOTE(review): the cached Epetra wrapper epetraOp_ (filled lazily by
// getEpetraOpView() when HAVE_THYRA_TPETRA_EPETRA is defined) is not cleared
// here. Re-initializing with a different operator may leave a stale wrapper —
// confirm.
426template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
427template<class TpetraOperator_t>
429 const RCP<const VectorSpaceBase<Scalar> > &rangeSpace,
430 const RCP<const VectorSpaceBase<Scalar> > &domainSpace,
431 const RCP<TpetraOperator_t> &tpetraOperator
432 )
433{
434#ifdef THYRA_DEBUG
435 TEUCHOS_ASSERT(nonnull(rangeSpace));
436 TEUCHOS_ASSERT(nonnull(domainSpace));
437 TEUCHOS_ASSERT(nonnull(tpetraOperator));
438 // ToDo: Assert that spaces are compatible with tpetraOperator
439#endif
440 rangeSpace_ = rangeSpace;
441 domainSpace_ = domainSpace;
442 tpetraOperator_ = tpetraOperator;
443}
444
445
446} // namespace Thyra
447
448
449#endif // THYRA_TPETRA_LINEAR_OP_HPP
Interface for a collection of column vectors called a multi-vector.
Concrete Thyra::LinearOpBase subclass for Tpetra::Operator.
virtual bool rowStatIsSupportedImpl(const RowStatLinearOpBaseUtils::ERowStat rowStat) const
void applyImpl(const Thyra::EOpTransp M_trans, const Thyra::MultiVectorBase< Scalar > &X_in, const Teuchos::Ptr< Thyra::MultiVectorBase< Scalar > > &Y_inout, const Scalar alpha, const Scalar beta) const
bool opSupportedImpl(Thyra::EOpTransp M_trans) const
virtual bool supportsScaleLeftImpl() const
RCP< const Thyra::VectorSpaceBase< Scalar > > range() const
TpetraLinearOp()
Construct to uninitialized.
virtual void getRowStatImpl(const RowStatLinearOpBaseUtils::ERowStat rowStat, const Ptr< VectorBase< Scalar > > &rowStatVec) const
virtual void scaleLeftImpl(const VectorBase< Scalar > &row_scaling)
virtual void scaleRightImpl(const VectorBase< Scalar > &col_scaling)
void initialize(const RCP< const VectorSpaceBase< Scalar > > &rangeSpace, const RCP< const VectorSpaceBase< Scalar > > &domainSpace, const RCP< Tpetra::Operator< Scalar, LocalOrdinal, GlobalOrdinal, Node > > &tpetraOperator)
Initialize.
RCP< const Tpetra::Operator< Scalar, LocalOrdinal, GlobalOrdinal, Node > > getConstTpetraOperator() const
Get embedded const Tpetra::Operator.
void constInitialize(const RCP< const VectorSpaceBase< Scalar > > &rangeSpace, const RCP< const VectorSpaceBase< Scalar > > &domainSpace, const RCP< const Tpetra::Operator< Scalar, LocalOrdinal, GlobalOrdinal, Node > > &tpetraOperator)
Initialize.
virtual bool supportsScaleRightImpl() const
RCP< Tpetra::Operator< Scalar, LocalOrdinal, GlobalOrdinal, Node > > getTpetraOperator()
Get embedded non-const Tpetra::Operator.
RCP< const Thyra::VectorSpaceBase< Scalar > > domain() const
Traits class that enables the extraction of Tpetra operator/vector objects wrapped in Thyra operator/vector objects.
static RCP< Tpetra::Vector< Scalar, LocalOrdinal, GlobalOrdinal, Node > > getTpetraVector(const RCP< VectorBase< Scalar > > &v)
Get a non-const Tpetra::Vector from a non-const Thyra::VectorBase object.
static RCP< const Tpetra::Vector< Scalar, LocalOrdinal, GlobalOrdinal, Node > > getConstTpetraVector(const RCP< const VectorBase< Scalar > > &v)
Get a const Tpetra::Vector from a const Thyra::VectorBase object.
Abstract interface for finite-dimensional dense vectors.
Abstract interface for objects that represent a space for vectors.
RCP< const Epetra_Comm > get_Epetra_Comm(const Teuchos::Comm< Ordinal > &comm)
Get (or create) an Epetra_Comm given a Teuchos::Comm object.
@ EPETRA_OP_APPLY_APPLY
Apply using Epetra_Operator::Apply(...)
@ EPETRA_OP_ADJOINT_UNSUPPORTED
Adjoint not supported.
@ EPETRA_OP_ADJOINT_SUPPORTED
Adjoint supported.
#define TEUCHOS_ASSERT(assertion_test)
#define TEUCHOS_TEST_FOR_EXCEPT(throw_exception_test)
#define TEUCHOS_TEST_FOR_EXCEPTION(throw_exception_test, Exception, msg)
EOpTransp
Enumeration for determining how a linear operator is applied.
@ TRANS
Use the transposed operator.
@ NOTRANS
Use the non-transposed operator.
@ CONJTRANS
Use the transposed operator with complex-conjugate elements (same as TRANS for real scalar types).
@ CONJ
Use the non-transposed operator with complex-conjugate elements (same as NOTRANS for real scalar types).
RCP< const Teuchos::Comm< Ordinal > > convertTpetraToThyraComm(const RCP< const Teuchos::Comm< int > > &tpetraComm)
Given a Tpetra Teuchos::Comm<int> object, return an equivalent Teuchos::Comm<Ordinal> object.
T_To & dyn_cast(T_From &from)
TEUCHOS_DEPRECATED RCP< T > rcp(T *p, Dealloc_T dealloc, bool owns_mem)