14#ifndef _ZOLTAN2_TPETRAROWMATRIXADAPTER_HPP_
15#define _ZOLTAN2_TPETRAROWMATRIXADAPTER_HPP_
21#include <Tpetra_RowMatrix.hpp>
41template <
typename User,
typename UserCoord = User>
46#ifndef DOXYGEN_SHOULD_SKIP_THIS
53 using device_t =
typename node_t::device_type;
54 using host_t =
typename Kokkos::HostSpace::memory_space;
56 using userCoord_t = UserCoord;
67 int nWeightsPerRow = 0);
182 typename Base::ConstIdsHostView &rowIds)
const override;
185 typename Base::ConstIdsDeviceView &rowIds)
const override;
188 ArrayRCP<const gno_t> &colIds)
const override;
191 typename Base::ConstOffsetsHostView &offsets,
192 typename Base::ConstIdsHostView &colIds)
const override;
195 typename Base::ConstOffsetsDeviceView &offsets,
196 typename Base::ConstIdsDeviceView &colIds)
const override;
199 ArrayRCP<const gno_t> &colIds,
200 ArrayRCP<const scalar_t> &values)
const override;
203 typename Base::ConstOffsetsHostView &offsets,
204 typename Base::ConstIdsHostView &colIds,
205 typename Base::ConstScalarsHostView &values)
const override;
208 typename Base::ConstOffsetsDeviceView &offsets,
209 typename Base::ConstIdsDeviceView &colIds,
210 typename Base::ConstScalarsDeviceView &values)
const override;
215 int idx = 0)
const override;
218 int idx = 0)
const override;
221 typename Base::WeightsDeviceView &
weights)
const override;
224 int idx = 0)
const override;
227 typename Base::WeightsHostView &
weights)
const override;
231 template <
typename Adapter>
233 const User &in, User *&out,
236 template <
typename Adapter>
238 const User &in, RCP<User> &out,
244 const RCP<const User> &inmatrix)
268 virtual RCP<User>
doMigration(
const User &from,
size_t numLocalRows,
269 const gno_t *myNewRows)
const;
276template <
typename User,
typename UserCoord>
278 const RCP<const User> &inmatrix,
int nWeightsPerRow):
279 matrix_(inmatrix), offset_(), columnIds_(),
280 nWeightsPerRow_(nWeightsPerRow), rowWeights_(),
281 mayHaveDiagonalEntries(true) {
283 using localInds_t =
typename User::nonconst_local_inds_host_view_type;
284 using localVals_t =
typename User::nonconst_values_host_view_type;
286 const auto nrows =
matrix_->getLocalNumRows();
287 const auto nnz =
matrix_->getLocalNumEntries();
288 auto maxNumEntries =
matrix_->getLocalMaxNumRowEntries();
293 colIdsHost_ =
typename Base::ConstIdsHostView(
"colIdsHost_", nnz);
294 offsHost_ =
typename Base::ConstOffsetsHostView(
"offsHost_", nrows + 1);
295 valuesHost_ =
typename Base::ScalarsHostView(
"valuesHost_", nnz);
297 localInds_t localColInds(
"localColInds", maxNumEntries);
298 localVals_t localVals(
"localVals", maxNumEntries);
300 for (
size_t r = 0; r < nrows; r++) {
301 size_t numEntries = 0;
302 matrix_->getLocalRowCopy(r, localColInds, localVals, numEntries);
308 for (
size_t j = 0; j < numEntries; j++) {
336template <
typename User,
typename UserCoord>
338 const scalar_t *weightVal,
int stride,
int idx) {
339 if (this->getPrimaryEntityType() ==
MATRIX_ROW)
340 setRowWeights(weightVal, stride, idx);
343 std::ostringstream emsg;
344 emsg << __FILE__ <<
"," << __LINE__
345 <<
" error: setWeights not yet supported for"
346 <<
" columns or nonzeros." << std::endl;
347 throw std::runtime_error(emsg.str());
352template <
typename User,
typename UserCoord>
354 typename Base::ConstWeightsDeviceView1D val,
int idx) {
355 if (this->getPrimaryEntityType() ==
MATRIX_ROW)
356 setRowWeightsDevice(val, idx);
359 std::ostringstream emsg;
360 emsg << __FILE__ <<
"," << __LINE__
361 <<
" error: setWeights not yet supported for"
362 <<
" columns or nonzeros." << std::endl;
363 throw std::runtime_error(emsg.str());
368template <
typename User,
typename UserCoord>
370 typename Base::ConstWeightsHostView1D val,
int idx) {
371 if (this->getPrimaryEntityType() ==
MATRIX_ROW)
372 setRowWeightsHost(val, idx);
375 std::ostringstream emsg;
376 emsg << __FILE__ <<
"," << __LINE__
377 <<
" error: setWeights not yet supported for"
378 <<
" columns or nonzeros." << std::endl;
379 throw std::runtime_error(emsg.str());
384template <
typename User,
typename UserCoord>
386 const scalar_t *weightVal,
int stride,
int idx) {
389 "Invalid row weight index: " + std::to_string(idx));
391 size_t nrows = getLocalNumRows();
392 ArrayRCP<const scalar_t> weightV(weightVal, 0, nrows * stride,
false);
393 rowWeights_[idx] = input_t(weightV, stride);
397template <
typename User,
typename UserCoord>
399 typename Base::ConstWeightsDeviceView1D
weights,
int idx) {
402 "Invalid row weight index: " + std::to_string(idx));
404 auto rowWeightsDevice = this->rowWeightsDevice_;
405 Kokkos::parallel_for(
406 rowWeightsDevice.extent(0), KOKKOS_LAMBDA(
const int rowID) {
407 rowWeightsDevice(rowID, idx) =
weights(rowID);
414template <
typename User,
typename UserCoord>
416 typename Base::ConstWeightsHostView1D weightsHost,
int idx) {
418 "Invalid row weight index: " + std::to_string(idx));
420 auto weightsDevice = Kokkos::create_mirror_view_and_copy(
421 typename Base::device_t(), weightsHost);
423 setRowWeightsDevice(weightsDevice, idx);
427template <
typename User,
typename UserCoord>
429 if (this->getPrimaryEntityType() ==
MATRIX_ROW)
430 setRowWeightIsNumberOfNonZeros(idx);
433 std::ostringstream emsg;
434 emsg << __FILE__ <<
"," << __LINE__
435 <<
" error: setWeightIsNumberOfNonZeros not yet supported for"
436 <<
" columns" << std::endl;
437 throw std::runtime_error(emsg.str());
442template <
typename User,
typename UserCoord>
445 if (idx < 0 || idx >= nWeightsPerRow_) {
446 std::ostringstream emsg;
447 emsg << __FILE__ <<
":" << __LINE__ <<
" Invalid row weight index " << idx
449 throw std::runtime_error(emsg.str());
452 numNzWeight_(idx) =
true;
456template <
typename User,
typename UserCoord>
458 return matrix_->getLocalNumRows();
462template <
typename User,
typename UserCoord>
464 return matrix_->getLocalNumCols();
468template <
typename User,
typename UserCoord>
470 return matrix_->getLocalNumEntries();
474template <
typename User,
typename UserCoord>
478template <
typename User,
typename UserCoord>
480 ArrayView<const gno_t> rowView = matrix_->getRowMap()->getLocalElementList();
481 rowIds = rowView.getRawPtr();
485template <
typename User,
typename UserCoord>
487 typename Base::ConstIdsHostView &rowIds)
const {
488 auto idsDevice = matrix_->getRowMap()->getMyGlobalIndices();
489 auto tmpIds =
typename Base::IdsHostView(
"", idsDevice.extent(0));
491 Kokkos::deep_copy(tmpIds, idsDevice);
497template <
typename User,
typename UserCoord>
499 typename Base::ConstIdsDeviceView &rowIds)
const {
501 auto idsDevice = matrix_->getRowMap()->getMyGlobalIndices();
502 auto tmpIds =
typename Base::IdsDeviceView(
"", idsDevice.extent(0));
504 Kokkos::deep_copy(tmpIds, idsDevice);
510template <
typename User,
typename UserCoord>
512 ArrayRCP<const gno_t> &colIds)
const {
518template <
typename User,
typename UserCoord>
520 typename Base::ConstOffsetsHostView &offsets,
521 typename Base::ConstIdsHostView &colIds)
const {
522 auto hostOffsets = Kokkos::create_mirror_view(offsDevice_);
523 Kokkos::deep_copy(hostOffsets, offsDevice_);
524 offsets = hostOffsets;
526 auto hostColIds = Kokkos::create_mirror_view(colIdsDevice_);
527 Kokkos::deep_copy(hostColIds, colIdsDevice_);
532template <
typename User,
typename UserCoord>
534 typename Base::ConstOffsetsDeviceView &offsets,
535 typename Base::ConstIdsDeviceView &colIds)
const {
536 offsets = offsDevice_;
537 colIds = colIdsDevice_;
541template <
typename User,
typename UserCoord>
543 ArrayRCP<const gno_t> &colIds,
544 ArrayRCP<const scalar_t> &values)
const {
551template <
typename User,
typename UserCoord>
553 typename Base::ConstOffsetsHostView &offsets,
554 typename Base::ConstIdsHostView &colIds,
555 typename Base::ConstScalarsHostView &values)
const {
556 auto hostOffsets = Kokkos::create_mirror_view(offsDevice_);
557 Kokkos::deep_copy(hostOffsets, offsDevice_);
558 offsets = hostOffsets;
560 auto hostColIds = Kokkos::create_mirror_view(colIdsDevice_);
561 Kokkos::deep_copy(hostColIds, colIdsDevice_);
564 auto hostValues = Kokkos::create_mirror_view(valuesDevice_);
565 Kokkos::deep_copy(hostValues, valuesDevice_);
570template <
typename User,
typename UserCoord>
572 typename Base::ConstOffsetsDeviceView &offsets,
573 typename Base::ConstIdsDeviceView &colIds,
574 typename Base::ConstScalarsDeviceView &values)
const {
575 offsets = offsDevice_;
576 colIds = colIdsDevice_;
577 values = valuesDevice_;
581template <
typename User,
typename UserCoord>
585template <
typename User,
typename UserCoord>
588 if (idx < 0 || idx >= nWeightsPerRow_) {
589 std::ostringstream emsg;
590 emsg << __FILE__ <<
":" << __LINE__ <<
" Invalid row weight index "
592 throw std::runtime_error(emsg.str());
596 rowWeights_[idx].getStridedList(length,
weights, stride);
600template <
typename User,
typename UserCoord>
602 typename Base::WeightsDeviceView1D &
weights,
int idx)
const {
604 "Invalid row weight index.");
606 const auto size = rowWeightsDevice_.extent(0);
607 weights =
typename Base::WeightsDeviceView1D(
"weights", size);
609 auto rowWeightsDevice = this->rowWeightsDevice_;
610 Kokkos::parallel_for(
611 size, KOKKOS_LAMBDA(
const int id) {
612 weights(
id) = rowWeightsDevice(
id, idx);
619template <
typename User,
typename UserCoord>
621 typename Base::WeightsDeviceView &
weights)
const {
627template <
typename User,
typename UserCoord>
629 typename Base::WeightsHostView1D &
weights,
int idx)
const {
631 "Invalid row weight index.");
633 auto weightsDevice =
typename Base::WeightsDeviceView1D(
634 "weights", rowWeightsDevice_.extent(0));
635 getRowWeightsDeviceView(weightsDevice, idx);
637 weights = Kokkos::create_mirror_view(weightsDevice);
638 Kokkos::deep_copy(
weights, weightsDevice);
642template <
typename User,
typename UserCoord>
644 typename Base::WeightsHostView &
weights)
const {
646 weights = Kokkos::create_mirror_view(rowWeightsDevice_);
647 Kokkos::deep_copy(
weights, rowWeightsDevice_);
651template <
typename User,
typename UserCoord>
655template <
typename User,
typename UserCoord>
656template <
typename Adapter>
658 const User &in, User *&out,
662 ArrayRCP<gno_t> importList;
665 Zoltan2::getImportList<Adapter, TpetraRowMatrixAdapter<User, UserCoord>>(
666 solution,
this, importList);
671 RCP<User> outPtr = doMigration(in, numNewRows, importList.getRawPtr());
677template <
typename User,
typename UserCoord>
678template <
typename Adapter>
680 const User &in, RCP<User> &out,
684 ArrayRCP<gno_t> importList;
687 Zoltan2::getImportList<Adapter, TpetraRowMatrixAdapter<User, UserCoord>>(
688 solution,
this, importList);
693 out = doMigration(in, numNewRows, importList.getRawPtr());
697template <
typename User,
typename UserCoord>
699 const User &from,
size_t numLocalRows,
const gno_t *myNewRows)
const {
700 typedef Tpetra::Map<lno_t, gno_t, node_t>
map_t;
701 typedef Tpetra::CrsMatrix<scalar_t, lno_t, gno_t, node_t> tcrsmatrix_t;
712 const tcrsmatrix_t *pCrsMatrix =
dynamic_cast<const tcrsmatrix_t *
>(&from);
715 throw std::logic_error(
"TpetraRowMatrixAdapter cannot migrate data for "
716 "your RowMatrix; it can migrate data only for "
717 "Tpetra::CrsMatrix. "
718 "You can inherit from TpetraRowMatrixAdapter and "
719 "implement migration for your RowMatrix.");
723 const RCP<const map_t> &smap = from.getRowMap();
724 gno_t numGlobalRows = smap->getGlobalNumElements();
725 gno_t base = smap->getMinAllGlobalIndex();
728 ArrayView<const gno_t> rowList(myNewRows, numLocalRows);
729 const RCP<const Teuchos::Comm<int>> &comm = from.getComm();
730 RCP<const map_t> tmap = rcp(
new map_t(numGlobalRows, rowList, base, comm));
733 Tpetra::Import<lno_t, gno_t, node_t> importer(smap, tmap);
735 int oldNumElts = smap->getLocalNumElements();
736 int newNumElts = numLocalRows;
739 typedef Tpetra::Vector<scalar_t, lno_t, gno_t, node_t> vector_t;
740 vector_t numOld(smap);
741 vector_t numNew(tmap);
742 for (
int lid = 0; lid < oldNumElts; lid++) {
743 numOld.replaceGlobalValue(smap->getGlobalElement(lid),
744 scalar_t(from.getNumEntriesInLocalRow(lid)));
746 numNew.doImport(numOld, importer, Tpetra::INSERT);
749 ArrayRCP<size_t> nnz(newNumElts);
750 if (newNumElts > 0) {
751 ArrayRCP<scalar_t> ptr = numNew.getDataNonConst(0);
752 for (
int lid = 0; lid < newNumElts; lid++) {
753 nnz[lid] =
static_cast<size_t>(ptr[lid]);
757 RCP<tcrsmatrix_t> M = rcp(
new tcrsmatrix_t(tmap, nnz()));
759 M->doImport(from, importer, Tpetra::INSERT);
762 return Teuchos::rcp_dynamic_cast<User>(M);
#define Z2_FORWARD_EXCEPTIONS
Forward an exception back through the call stack.
Defines the MatrixAdapter interface.
Helper functions for Partitioning Problems.
This file defines the StridedData class.
typename InputTraits< User >::node_t node_t
typename InputTraits< User >::scalar_t scalar_t
typename InputTraits< User >::gno_t gno_t
typename Kokkos::HostSpace::memory_space host_t
typename InputTraits< User >::offset_t offset_t
typename InputTraits< User >::part_t part_t
typename node_t::device_type device_t
MatrixAdapter defines the adapter interface for matrices.
A PartitioningSolution is a solution to a partitioning problem.
The StridedData class manages lists of weights or coordinates.
Provides access for Zoltan2 to Tpetra::RowMatrix data.
Base::WeightsDeviceView rowWeightsDevice_
Base::ConstIdsDeviceView colIdsDevice_
size_t getLocalNumRows() const override
Returns the number of rows on this process.
void getRowWeightsHostView(typename Base::WeightsHostView &weights) const override
void getCRSView(ArrayRCP< const offset_t > &offsets, ArrayRCP< const gno_t > &colIds, ArrayRCP< const scalar_t > &values) const override
Base::ConstIdsHostView colIdsHost_
bool mayHaveDiagonalEntries
TpetraRowMatrixAdapter(const RCP< const User > &inmatrix, int nWeightsPerRow=0)
Constructor.
void getRowIDsHostView(typename Base::ConstIdsHostView &rowIds) const override
size_t getLocalNumColumns() const override
Returns the number of columns on this process.
ArrayRCP< StridedData< lno_t, scalar_t > > rowWeights_
TpetraRowMatrixAdapter(int nWeightsPerRow, const RCP< const User > &inmatrix)
bool CRSViewAvailable() const override
Indicates whether the MatrixAdapter implements a view of the matrix in compressed sparse row (CRS) format.
void setRowWeightIsNumberOfNonZeros(int idx)
Specify an index for which the row weight should be the global number of nonzeros in the row.
Kokkos::View< bool *, host_t > numNzWeight_
Base::ConstOffsetsDeviceView offsDevice_
size_t getLocalNumEntries() const override
Returns the number of nonzeros on this process.
void setWeights(const scalar_t *weightVal, int stride, int idx=0)
Specify a weight for each entity of the primaryEntityType.
void getRowIDsView(const gno_t *&rowIds) const override
void getRowWeightsHostView(typename Base::WeightsHostView1D &weights, int idx=0) const override
void setRowWeightsDevice(typename Base::ConstWeightsDeviceView1D val, int idx)
Provide a device view to row weights.
void setWeightIsDegree(int idx)
Specify an index for which the weight should be the degree of the entity.
int getNumWeightsPerRow() const override
Returns the number of weights per row (0 or greater). Row weights may be used when partitioning matrix rows.
RCP< const User > matrix_
void getCRSHostView(typename Base::ConstOffsetsHostView &offsets, typename Base::ConstIdsHostView &colIds) const override
ArrayRCP< offset_t > offset_
void getRowWeightsDeviceView(typename Base::WeightsDeviceView &weights) const override
void getCRSDeviceView(typename Base::ConstOffsetsDeviceView &offsets, typename Base::ConstIdsDeviceView &colIds) const override
void getRowWeightsDeviceView(typename Base::WeightsDeviceView1D &weights, int idx=0) const override
Base::ScalarsHostView valuesHost_
void setRowWeightsHost(typename Base::ConstWeightsHostView1D val, int idx)
Provide a host view to row weights.
void setWeightsHost(typename Base::ConstWeightsHostView1D val, int idx)
Provide a host view of weights for the primary entity type.
void getCRSDeviceView(typename Base::ConstOffsetsDeviceView &offsets, typename Base::ConstIdsDeviceView &colIds, typename Base::ConstScalarsDeviceView &values) const override
ArrayRCP< scalar_t > values_
void getCRSHostView(typename Base::ConstOffsetsHostView &offsets, typename Base::ConstIdsHostView &colIds, typename Base::ConstScalarsHostView &values) const override
void getRowWeightsView(const scalar_t *&weights, int &stride, int idx=0) const override
Provide a pointer to the row weights, if any.
virtual RCP< User > doMigration(const User &from, size_t numLocalRows, const gno_t *myNewRows) const
void setRowWeights(const scalar_t *weightVal, int stride, int idx=0)
Specify a weight for each row.
void getRowIDsDeviceView(typename Base::ConstIdsDeviceView &rowIds) const override
void applyPartitioningSolution(const User &in, RCP< User > &out, const PartitioningSolution< Adapter > &solution) const
Base::ConstOffsetsHostView offsHost_
void setWeightsDevice(typename Base::ConstWeightsDeviceView1D val, int idx)
Provide a device view of weights for the primary entity type.
void applyPartitioningSolution(const User &in, User *&out, const PartitioningSolution< Adapter > &solution) const
void getCRSView(ArrayRCP< const offset_t > &offsets, ArrayRCP< const gno_t > &colIds) const override
ArrayRCP< gno_t > columnIds_
bool useNumNonzerosAsRowWeight(int idx) const override
Indicate whether row weight with index idx should be the global number of nonzeros in the row.
Base::ScalarsDeviceView valuesDevice_
map_t::local_ordinal_type lno_t
map_t::global_ordinal_type gno_t
Created by mbenlioglu on Aug 31, 2020.
static void AssertCondition(bool condition, const std::string &message, const char *file=__FILE__, int line=__LINE__)