10#ifndef THYRA_TPETRA_MULTIVECTOR_HPP
11#define THYRA_TPETRA_MULTIVECTOR_HPP
13#include "Thyra_TpetraMultiVector_decl.hpp"
14#include "Thyra_TpetraVectorSpace.hpp"
15#include "Thyra_TpetraVector.hpp"
16#include "Teuchos_Assert.hpp"
17#include "Kokkos_Core.hpp"
26template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
30template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
31TpetraMultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node>::~TpetraMultiVector() =
default;
34template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
38 const RCP<Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> > &tpetraMultiVector
41 initializeImpl(tpetraVectorSpace, domainSpace, tpetraMultiVector);
45template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
49 const RCP<
const Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> > &tpetraMultiVector
52 initializeImpl(tpetraVectorSpace, domainSpace, tpetraMultiVector);
56template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
60 return tpetraMultiVector_.getNonconstObj();
64template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
68 return tpetraMultiVector_;
75template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
86template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
90 tpetraMultiVector_.getNonconstObj()->putScalar(alpha);
94template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
98 auto tmv = this->getConstTpetraMultiVector(Teuchos::rcpFromRef(mv));
103 tpetraMultiVector_.getNonconstObj()->assign(*tmv);
110template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
114 tpetraMultiVector_.getNonconstObj()->scale(alpha);
118template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
124 auto tmv = this->getConstTpetraMultiVector(Teuchos::rcpFromRef(mv));
130 tpetraMultiVector_.getNonconstObj()->update(alpha, *tmv, ST::one());
137template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
149 typedef Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> TMV;
152 bool allCastsSuccessful =
true;
154 auto mvIter = mv.begin();
155 auto tmvIter = tmvs.
begin();
156 for (; mvIter != mv.end(); ++mvIter, ++tmvIter) {
157 tmv = this->getConstTpetraMultiVector(Teuchos::rcpFromPtr(*mvIter));
161 allCastsSuccessful =
false;
169 auto len = tmvs.
size();
171 tpetraMultiVector_.getNonconstObj()->scale(beta);
172 }
else if (len == 1 && allCastsSuccessful) {
173 tpetraMultiVector_.getNonconstObj()->update(alpha[0], *tmvs[0], beta);
174 }
else if (len == 2 && allCastsSuccessful) {
175 tpetraMultiVector_.getNonconstObj()->update(alpha[0], *tmvs[0], alpha[1], *tmvs[1], beta);
176 }
else if (allCastsSuccessful) {
178 auto tmvIter = tmvs.
begin();
179 auto alphaIter = alpha.
begin();
184 for (; tmvIter != tmvs.
end(); ++tmvIter) {
185 if (tmvIter->getRawPtr() == tpetraMultiVector_.getConstObj().getRawPtr()) {
192 tmvIter = tmvs.
begin();
196 if ((tmvs.
size() % 2) == 0) {
197 tpetraMultiVector_.getNonconstObj()->scale(beta);
199 tpetraMultiVector_.getNonconstObj()->update(*alphaIter, *(*tmvIter), beta);
203 for (; tmvIter != tmvs.
end(); tmvIter+=2, alphaIter+=2) {
204 tpetraMultiVector_.getNonconstObj()->update(
205 *alphaIter, *(*tmvIter), *(alphaIter+1), *(*(tmvIter+1)), ST::one());
213template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
219 auto tmv = this->getConstTpetraMultiVector(Teuchos::rcpFromRef(mv));
224 tpetraMultiVector_.getConstObj()->dot(*tmv, prods);
231template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
236 tpetraMultiVector_.getConstObj()->norm1(norms);
240template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
245 tpetraMultiVector_.getConstObj()->norm2(norms);
249template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
254 tpetraMultiVector_.getConstObj()->normInf(norms);
258template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
265 return constTpetraVector<Scalar>(
267 tpetraMultiVector_->getVector(j)
272template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
279 return tpetraVector<Scalar>(
281 tpetraMultiVector_.getNonconstObj()->getVectorNonConst(j)
286template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
292#ifdef THYRA_DEFAULT_SPMD_MULTI_VECTOR_VERBOSE_TO_ERROR_OUT
293 std::cerr <<
"\nTpetraMultiVector::subView(Range1D) const called!\n";
295 const Range1D colRng = this->validateColRange(col_rng_in);
298 this->getConstTpetraMultiVector()->subView(colRng);
301 tpetraVectorSpace<Scalar>(
302 Tpetra::createLocalMapWithNode<LocalOrdinal,GlobalOrdinal,Node>(
303 tpetraView->getNumVectors(),
304 tpetraView->getMap()->getComm()
308 return constTpetraMultiVector(
316template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
322#ifdef THYRA_DEFAULT_SPMD_MULTI_VECTOR_VERBOSE_TO_ERROR_OUT
323 std::cerr <<
"\nTpetraMultiVector::subView(Range1D) called!\n";
325 const Range1D colRng = this->validateColRange(col_rng_in);
328 this->getTpetraMultiVector()->subViewNonConst(colRng);
331 tpetraVectorSpace<Scalar>(
332 Tpetra::createLocalMapWithNode<LocalOrdinal,GlobalOrdinal,Node>(
333 tpetraView->getNumVectors(),
334 tpetraView->getMap()->getComm()
338 return tpetraMultiVector(
346template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
352#ifdef THYRA_DEFAULT_SPMD_MULTI_VECTOR_VERBOSE_TO_ERROR_OUT
353 std::cerr <<
"\nTpetraMultiVector::subView(ArrayView) const called!\n";
358 cols[i] =
static_cast<std::size_t
>(cols_in[i]);
361 this->getConstTpetraMultiVector()->subView(cols());
364 tpetraVectorSpace<Scalar>(
365 Tpetra::createLocalMapWithNode<LocalOrdinal,GlobalOrdinal,Node>(
366 tpetraView->getNumVectors(),
367 tpetraView->getMap()->getComm()
371 return constTpetraMultiVector(
379template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
385#ifdef THYRA_DEFAULT_SPMD_MULTI_VECTOR_VERBOSE_TO_ERROR_OUT
386 std::cerr <<
"\nTpetraMultiVector::subView(ArrayView) called!\n";
391 cols[i] =
static_cast<std::size_t
>(cols_in[i]);
394 this->getTpetraMultiVector()->subViewNonConst(cols());
397 tpetraVectorSpace<Scalar>(
398 Tpetra::createLocalMapWithNode<LocalOrdinal,GlobalOrdinal,Node>(
399 tpetraView->getNumVectors(),
400 tpetraView->getMap()->getComm()
404 return tpetraMultiVector(
412template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
419 const Ordinal primary_global_offset
424 primary_op, multi_vecs, targ_multi_vecs, reduct_objs, primary_global_offset);
428template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
441template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
454template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
513template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
517 return tpetraVectorSpace_;
521template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
526 *localValues = tpetraMultiVector_.getNonconstObj()->get1dViewNonConst();
527 *leadingDim = tpetraMultiVector_->getStride();
531template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
536 *localValues = tpetraMultiVector_->get1dView();
537 *leadingDim = tpetraMultiVector_->getStride();
541template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
551 typedef Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> TMV;
557 if (nonnull(X_tpetra) && nonnull(Y_tpetra)) {
561 "Error, conjugation without transposition is not allowed for complex scalar types!");
579 Y_tpetra->multiply(trans,
Teuchos::NO_TRANS, alpha, *tpetraMultiVector_.getConstObj(), *X_tpetra, beta);
590template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
591template<
class TpetraMultiVector_t>
605 tpetraVectorSpace_ = tpetraVectorSpace;
606 domainSpace_ = domainSpace;
607 tpetraMultiVector_.initialize(tpetraMultiVector);
608 this->updateSpmdSpace();
612template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
613RCP<Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> >
617 using Teuchos::rcp_dynamic_cast;
621 RCP<TMV> tmv = rcp_dynamic_cast<TMV>(mv);
623 return tmv->getTpetraMultiVector();
626 RCP<TV> tv = rcp_dynamic_cast<TV>(mv);
628 return tv->getTpetraVector();
631 return Teuchos::null;
634template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
635RCP<const Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> >
639 using Teuchos::rcp_dynamic_cast;
643 RCP<const TMV> tmv = rcp_dynamic_cast<const TMV>(mv);
645 return tmv->getConstTpetraMultiVector();
648 RCP<const TV> tv = rcp_dynamic_cast<const TV>(mv);
650 return tv->getConstTpetraVector();
653 return Teuchos::null;
657template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
658RCP<TpetraMultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> >
662 const RCP<Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> > &tpetraMultiVector
667 tmv->initialize(tpetraVectorSpace, domainSpace, tpetraMultiVector);
672template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
674constTpetraMultiVector(
677 const RCP<
const Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> > &tpetraMultiVector
682 tmv->constInitialize(tpetraVectorSpace, domainSpace, tpetraMultiVector);
688#define THYRATPETRAADAPTERS_TPETRAMULTIVECTOR_INSTANT(S, LO, GO, N) \
689 template class Thyra::TpetraMultiVector<S, LO, GO, N>; \
691 template Teuchos::RCP<Thyra::TpetraMultiVector<S, LO, GO, N>> \
692 Thyra::tpetraMultiVector( \
693 const Teuchos::RCP<const Thyra::TpetraVectorSpace<S, LO, GO, N>> &, \
694 const Teuchos::RCP<const Thyra::ScalarProdVectorSpaceBase<S>> &, \
695 const Teuchos::RCP<Tpetra::MultiVector<S, LO, GO, N>> &); \
697 template Teuchos::RCP<const Thyra::TpetraMultiVector<S, LO, GO, N>> \
698 Thyra::constTpetraMultiVector( \
699 const Teuchos::RCP<const Thyra::TpetraVectorSpace<S, LO, GO, N>> &, \
700 const Teuchos::RCP<const Thyra::ScalarProdVectorSpaceBase<S>> &, \
701 const Teuchos::RCP<const Tpetra::MultiVector<S, LO, GO, N>> &r);
Interface for a collection of column vectors called a multi-vector.
virtual void mvMultiReductApplyOpImpl(const RTOpPack::RTOpT< Scalar > &primary_op, const ArrayView< const Ptr< const MultiVectorBase< Scalar > > > &multi_vecs, const ArrayView< const Ptr< MultiVectorBase< Scalar > > > &targ_multi_vecs, const ArrayView< const Ptr< RTOpPack::ReductTarget > > &reduct_objs, const Ordinal primary_global_offset) const =0
Apply a reduction/transformation operator column by column and return an array of the reduction objects.
virtual void dotsImpl(const MultiVectorBase< Scalar > &mv, const ArrayView< Scalar > &prods) const
Default implementation of dots using RTOps.
virtual void linearCombinationImpl(const ArrayView< const Scalar > &alpha, const ArrayView< const Ptr< const MultiVectorBase< Scalar > > > &mv, const Scalar &beta)
Default implementation of linear_combination using RTOps.
virtual void assignMultiVecImpl(const MultiVectorBase< Scalar > &mv)
Default implementation of assign(MV) using RTOps.
virtual void updateImpl(Scalar alpha, const MultiVectorBase< Scalar > &mv)
Default implementation of update using RTOps.
void acquireDetachedMultiVectorViewImpl(const Range1D &rowRng, const Range1D &colRng, RTOpPack::ConstSubMultiVectorView< Scalar > *sub_mv) const
void euclideanApply(const EOpTransp M_trans, const MultiVectorBase< Scalar > &X, const Ptr< MultiVectorBase< Scalar > > &Y, const Scalar alpha, const Scalar beta) const
Uses GEMM() and Teuchos::reduceAll() to implement.
void acquireNonconstDetachedMultiVectorViewImpl(const Range1D &rowRng, const Range1D &colRng, RTOpPack::SubMultiVectorView< Scalar > *sub_mv)
void commitNonconstDetachedMultiVectorViewImpl(RTOpPack::SubMultiVectorView< Scalar > *sub_mv)
Concrete implementation of Thyra::MultiVector in terms of Tpetra::MultiVector.
RCP< Tpetra::MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > > getTpetraMultiVector()
Extract the underlying non-const Tpetra::MultiVector object.
void initialize(const RCP< const TpetraVectorSpace< Scalar, LocalOrdinal, GlobalOrdinal, Node > > &tpetraVectorSpace, const RCP< const ScalarProdVectorSpaceBase< Scalar > > &domainSpace, const RCP< Tpetra::MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > > &tpetraMultiVector)
Initialize.
RCP< const SpmdVectorSpaceBase< Scalar > > spmdSpaceImpl() const
RCP< const ScalarProdVectorSpaceBase< Scalar > > domainScalarProdVecSpc() const
RCP< const Tpetra::MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > > getConstTpetraMultiVector() const
Extract the underlying const Tpetra::MultiVector object.
virtual void norms2Impl(const ArrayView< typename ScalarTraits< Scalar >::magnitudeType > &norms) const
void commitNonconstDetachedMultiVectorViewImpl(RTOpPack::SubMultiVectorView< Scalar > *sub_mv)
virtual void updateImpl(Scalar alpha, const MultiVectorBase< Scalar > &mv)
void acquireNonconstDetachedMultiVectorViewImpl(const Range1D &rowRng, const Range1D &colRng, RTOpPack::SubMultiVectorView< Scalar > *sub_mv)
RCP< const VectorBase< Scalar > > colImpl(Ordinal j) const
void acquireDetachedMultiVectorViewImpl(const Range1D &rowRng, const Range1D &colRng, RTOpPack::ConstSubMultiVectorView< Scalar > *sub_mv) const
void getLocalMultiVectorDataImpl(const Ptr< ArrayRCP< const Scalar > > &localValues, const Ptr< Ordinal > &leadingDim) const
RCP< MultiVectorBase< Scalar > > nonconstNonContigSubViewImpl(const ArrayView< const int > &cols_in)
RCP< const MultiVectorBase< Scalar > > nonContigSubViewImpl(const ArrayView< const int > &cols_in) const
RCP< MultiVectorBase< Scalar > > nonconstContigSubViewImpl(const Range1D &colRng)
RCP< const MultiVectorBase< Scalar > > contigSubViewImpl(const Range1D &colRng) const
void constInitialize(const RCP< const TpetraVectorSpace< Scalar, LocalOrdinal, GlobalOrdinal, Node > > &tpetraVectorSpace, const RCP< const ScalarProdVectorSpaceBase< Scalar > > &domainSpace, const RCP< const Tpetra::MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > > &tpetraMultiVector)
Initialize.
virtual void norms1Impl(const ArrayView< typename ScalarTraits< Scalar >::magnitudeType > &norms) const
TpetraMultiVector()
Construct to uninitialized.
RCP< VectorBase< Scalar > > nonconstColImpl(Ordinal j)
virtual void dotsImpl(const MultiVectorBase< Scalar > &mv, const ArrayView< Scalar > &prods) const
virtual void euclideanApply(const EOpTransp M_trans, const MultiVectorBase< Scalar > &X, const Ptr< MultiVectorBase< Scalar > > &Y, const Scalar alpha, const Scalar beta) const
virtual void assignMultiVecImpl(const MultiVectorBase< Scalar > &mv)
virtual void scaleImpl(Scalar alpha)
void getNonconstLocalMultiVectorDataImpl(const Ptr< ArrayRCP< Scalar > > &localValues, const Ptr< Ordinal > &leadingDim)
virtual void mvMultiReductApplyOpImpl(const RTOpPack::RTOpT< Scalar > &primary_op, const ArrayView< const Ptr< const MultiVectorBase< Scalar > > > &multi_vecs, const ArrayView< const Ptr< MultiVectorBase< Scalar > > > &targ_multi_vecs, const ArrayView< const Ptr< RTOpPack::ReductTarget > > &reduct_objs, const Ordinal primary_global_offset) const
virtual void assignImpl(Scalar alpha)
virtual void linearCombinationImpl(const ArrayView< const Scalar > &alpha, const ArrayView< const Ptr< const MultiVectorBase< Scalar > > > &mv, const Scalar &beta)
virtual void normsInfImpl(const ArrayView< typename ScalarTraits< Scalar >::magnitudeType > &norms) const
Concrete Thyra::SpmdVectorBase using Tpetra::Vector.
Concrete implementation of an SPMD vector space for Tpetra.
#define TEUCHOS_ASSERT(assertion_test)
#define TEUCHOS_ASSERT_IN_RANGE_UPPER_EXCLUSIVE(index, lower_inclusive, upper_exclusive)
#define TEUCHOS_TEST_FOR_EXCEPTION(throw_exception_test, Exception, msg)
#define TEUCHOS_ASSERT_EQUALITY(val1, val2)
bool nonnull(const std::shared_ptr< T > &p)
EOpTransp
Enumeration for determining how a linear operator is applied.
Teuchos::Ordinal Ordinal
Type for the dimension of a vector space.
@ TRANS
Use the transposed operator.
@ NOTRANS
Use the non-transposed operator.
@ CONJTRANS
Use the transposed operator with complex-conjugate elements (same as TRANS for real scalar types).
@ CONJ
Use the non-transposed operator with complex-conjugate elements (same as NOTRANS for real scalar types).
TEUCHOS_DEPRECATED RCP< T > rcp(T *p, Dealloc_T dealloc, bool owns_mem)