14#ifndef _ZOLTAN2_XPETRAMULTIVECTORADAPTER_HPP_
15#define _ZOLTAN2_XPETRAMULTIVECTORADAPTER_HPP_
22#include <Xpetra_TpetraMultiVector.hpp>
42template <
typename User>
46#ifndef DOXYGEN_SHOULD_SKIP_THIS
54 typedef User userCoord_t;
56 typedef Xpetra::MultiVector<scalar_t, lno_t, gno_t, node_t> x_mvector_t;
57 typedef Xpetra::TpetraMultiVector<
77 std::vector<const scalar_t *> &
weights, std::vector<int> &weightStrides);
95 ids = map_->getLocalElementList().getRawPtr();
99 Kokkos::View<const gno_t *, typename node_t::device_type> &ids)
const {
100 if (map_->lib() == Xpetra::UseTpetra) {
101 using device_type =
typename node_t::device_type;
102 const xt_mvector_t *tvector =
103 dynamic_cast<const xt_mvector_t *
>(vector_.get());
111 ids = Kokkos::create_mirror_view_and_copy(device_type(),
112 tvector->getTpetra_MultiVector()->getMap()->getMyGlobalIndices());
115 throw std::logic_error(
"getIDsKokkosView called but not on Tpetra!");
123 if(idx<0 || idx >= numWeights_)
125 std::ostringstream emsg;
126 emsg << __FILE__ <<
":" << __LINE__
127 <<
" Invalid weight index " << idx << std::endl;
128 throw std::runtime_error(emsg.str());
132 weights_[idx].getStridedList(length,
weights, stride);
136 typename node_t::device_type> &wgt)
const {
137 typedef Kokkos::View<scalar_t**, typename node_t::device_type> view_t;
138 wgt = view_t(
"wgts", vector_->getLocalLength(), numWeights_);
139 typename view_t::host_mirror_type host_wgt = Kokkos::create_mirror_view(wgt);
140 for(
int idx = 0; idx < numWeights_; ++idx) {
144 weights_[idx].getStridedList(length,
weights, stride);
145 size_t fill_index = 0;
146 for(
size_t n = 0; n < length; n += stride) {
147 host_wgt(fill_index++,idx) =
weights[n];
150 Kokkos::deep_copy(wgt, host_wgt);
163 Kokkos::View<impl_scalar_t **, Kokkos::LayoutLeft,
164 typename node_t::device_type> & elements)
const;
166 template <
typename Adapter>
170 template <
typename Adapter>
176 RCP<const User> invector_;
177 RCP<const x_mvector_t> vector_;
178 RCP<const Xpetra::Map<lno_t, gno_t, node_t> > map_;
181 ArrayRCP<StridedData<lno_t, scalar_t> > weights_;
188template <
typename User>
190 const RCP<const User> &invector,
191 std::vector<const scalar_t *> &
weights, std::vector<int> &weightStrides):
192 invector_(invector), vector_(), map_(),
198 RCP<x_mvector_t> tmp =
200 vector_ = rcp_const_cast<const x_mvector_t>(tmp);
204 map_ = vector_->getMap();
206 size_t length = vector_->getLocalLength();
208 if (length > 0 && numWeights_ > 0){
210 for (
int w=0; w < numWeights_; w++){
211 if (weightStrides.size())
212 stride = weightStrides[w];
213 ArrayRCP<const scalar_t> wgtV(
weights[w], 0, stride*length,
false);
214 weights_[w] = input_t(wgtV, stride);
221template <
typename User>
223 const RCP<const User> &invector):
224 invector_(invector), vector_(), map_(),
225 numWeights_(0), weights_()
228 RCP<x_mvector_t> tmp =
230 vector_ = rcp_const_cast<const x_mvector_t>(tmp);
234 map_ = vector_->getMap();
238template <
typename User>
240 const scalar_t *&elements,
int &stride,
int idx)
const
245 if (map_->lib() == Xpetra::UseTpetra){
246 const xt_mvector_t *tvector =
247 dynamic_cast<const xt_mvector_t *
>(vector_.get());
249 vecsize = tvector->getLocalLength();
251 ArrayRCP<const scalar_t> data = tvector->getData(idx);
252 elements = data.get();
256 throw std::logic_error(
"invalid underlying lib");
261template <
typename User>
264 Kokkos::View<impl_scalar_t **, Kokkos::LayoutLeft, typename node_t::device_type> & elements)
const
266 if (map_->lib() == Xpetra::UseTpetra){
267 const xt_mvector_t *tvector =
268 dynamic_cast<const xt_mvector_t *
>(vector_.get());
271 tvector->getTpetra_MultiVector()->template getLocalView<typename node_t::device_type>(Tpetra::Access::ReadWrite);
278 throw std::logic_error(
"getEntriesKokkosView called but not using Tpetra!");
283template <
typename User>
284 template <
typename Adapter>
286 const User &in, User *&out,
291 ArrayRCP<gno_t> importList;
295 (solution,
this, importList);
301 importList.getRawPtr());
307template <
typename User>
308 template <
typename Adapter>
310 const User &in, RCP<User> &out,
315 ArrayRCP<gno_t> importList;
319 (solution,
this, importList);
325 importList.getRawPtr());
Zoltan2::BasicUserTypes< zscalar_t, zlno_t, zgno_t > user_t
#define Z2_FORWARD_EXCEPTIONS
Forward an exception back through the call stack.
Helper functions for Partitioning Problems.
This file defines the StridedData class.
Defines the VectorAdapter interface.
Traits of Xpetra classes, including migration method.
typename BaseAdapter< User >::scalar_t scalar_t
typename InputTraits< User >::node_t node_t
typename InputTraits< User >::lno_t lno_t
typename InputTraits< User >::gno_t gno_t
typename InputTraits< User >::part_t part_t
A PartitioningSolution is a solution to a partitioning problem.
The StridedData class manages lists of weights or coordinates.
VectorAdapter defines the interface for vector input.
An adapter for Xpetra::MultiVector.
void getEntriesView(const scalar_t *&elements, int &stride, int idx=0) const
Provide a pointer to the elements of the specified vector.
void applyPartitioningSolution(const User &in, User *&out, const PartitioningSolution< Adapter > &solution) const
void getIDsView(const gno_t *&ids) const
void getWeightsKokkos2dView(Kokkos::View< scalar_t **, typename node_t::device_type > &wgt) const
XpetraMultiVectorAdapter(const RCP< const User > &invector, std::vector< const scalar_t * > &weights, std::vector< int > &weightStrides)
Constructor.
void getIDsKokkosView(Kokkos::View< const gno_t *, typename node_t::device_type > &ids) const
int getNumEntriesPerID() const
Return the number of vectors.
int getNumWeightsPerID() const
Returns the number of weights per object. Number of weights per object should be zero or greater....
void getWeightsView(const scalar_t *&weights, int &stride, int idx) const
void getEntriesKokkosView(Kokkos::View< impl_scalar_t **, Kokkos::LayoutLeft, typename node_t::device_type > &elements) const
size_t getLocalNumIDs() const
Returns the number of objects on this process.
map_t::global_ordinal_type gno_t
Created by mbenlioglu on Aug 31, 2020.
size_t getImportList(const PartitioningSolution< SolutionAdapter > &solution, const DataAdapter *const data, ArrayRCP< typename DataAdapter::gno_t > &imports)
From a PartitioningSolution, get a list of IDs to be imported. Assumes part numbers in PartitioningSo...
static RCP< User > doMigration(const User &from, size_t numLocalRows, const gno_t *myNewRows)
Migrate the object. Given a user object and a new row distribution, create and return a new user objec...
static RCP< User > convertToXpetra(const RCP< User > &a)
Convert the object to its Xpetra wrapped version.