14#ifndef _ZOLTAN2_TPETRAROWMATRIXADAPTER_HPP_
15#define _ZOLTAN2_TPETRAROWMATRIXADAPTER_HPP_
21#include <Tpetra_RowMatrix.hpp>
41template <
typename User,
typename UserCoord = User>
46#ifndef DOXYGEN_SHOULD_SKIP_THIS
53 using device_t =
typename node_t::device_type;
54 using host_t =
typename Kokkos::HostSpace::memory_space;
56 using userCoord_t = UserCoord;
67 int nWeightsPerRow = 0);
182 typename Base::ConstIdsHostView &rowIds)
const override;
185 typename Base::ConstIdsDeviceView &rowIds)
const override;
188 ArrayRCP<const gno_t> &colIds)
const;
191 typename Base::ConstOffsetsHostView &offsets,
192 typename Base::ConstIdsHostView &colIds)
const override;
195 typename Base::ConstOffsetsDeviceView &offsets,
196 typename Base::ConstIdsDeviceView &colIds)
const override;
199 ArrayRCP<const gno_t> &colIds,
200 ArrayRCP<const scalar_t> &values)
const;
203 typename Base::ConstOffsetsHostView &offsets,
204 typename Base::ConstIdsHostView &colIds,
205 typename Base::ConstScalarsHostView &values)
const override;
208 typename Base::ConstOffsetsDeviceView &offsets,
209 typename Base::ConstIdsDeviceView &colIds,
210 typename Base::ConstScalarsDeviceView &values)
const override;
221 typename Base::WeightsDeviceView &
weights)
const override;
227 typename Base::WeightsHostView &
weights)
const override;
231 template <
typename Adapter>
233 const User &in, User *&out,
236 template <
typename Adapter>
238 const User &in, RCP<User> &out,
244 const RCP<const User> &inmatrix)
268 virtual RCP<User>
doMigration(
const User &from,
size_t numLocalRows,
269 const gno_t *myNewRows)
const;
276template <
typename User,
typename UserCoord>
278 const RCP<const User> &inmatrix,
int nWeightsPerRow):
279 matrix_(inmatrix), offset_(), columnIds_(),
280 nWeightsPerRow_(nWeightsPerRow), rowWeights_(),
281 mayHaveDiagonalEntries(true) {
283 using localInds_t =
typename User::nonconst_local_inds_host_view_type;
284 using localVals_t =
typename User::nonconst_values_host_view_type;
286 const auto nrows =
matrix_->getLocalNumRows();
287 const auto nnz =
matrix_->getLocalNumEntries();
288 auto maxNumEntries =
matrix_->getLocalMaxNumRowEntries();
293 colIdsHost_ =
typename Base::ConstIdsHostView(
"colIdsHost_", nnz);
294 offsHost_ =
typename Base::ConstOffsetsHostView(
"offsHost_", nrows + 1);
295 valuesHost_ =
typename Base::ScalarsHostView(
"valuesHost_", nnz);
297 localInds_t localColInds(
"localColInds", maxNumEntries);
298 localVals_t localVals(
"localVals", maxNumEntries);
300 for (
size_t r = 0; r < nrows; r++) {
301 size_t numEntries = 0;
302 matrix_->getLocalRowCopy(r, localColInds, localVals, numEntries);
308 for (
size_t j = 0; j < numEntries; j++) {
336template <
typename User,
typename UserCoord>
338 const scalar_t *weightVal,
int stride,
int idx) {
339 if (this->getPrimaryEntityType() ==
MATRIX_ROW)
340 setRowWeights(weightVal, stride, idx);
343 std::ostringstream emsg;
344 emsg << __FILE__ <<
"," << __LINE__
345 <<
" error: setWeights not yet supported for"
346 <<
" columns or nonzeros." << std::endl;
347 throw std::runtime_error(emsg.str());
352template <
typename User,
typename UserCoord>
354 typename Base::ConstWeightsDeviceView1D val,
int idx) {
355 if (this->getPrimaryEntityType() ==
MATRIX_ROW)
356 setRowWeightsDevice(val, idx);
359 std::ostringstream emsg;
360 emsg << __FILE__ <<
"," << __LINE__
361 <<
" error: setWeights not yet supported for"
362 <<
" columns or nonzeros." << std::endl;
363 throw std::runtime_error(emsg.str());
368template <
typename User,
typename UserCoord>
370 typename Base::ConstWeightsHostView1D val,
int idx) {
371 if (this->getPrimaryEntityType() ==
MATRIX_ROW)
372 setRowWeightsHost(val, idx);
375 std::ostringstream emsg;
376 emsg << __FILE__ <<
"," << __LINE__
377 <<
" error: setWeights not yet supported for"
378 <<
" columns or nonzeros." << std::endl;
379 throw std::runtime_error(emsg.str());
384template <
typename User,
typename UserCoord>
386 const scalar_t *weightVal,
int stride,
int idx) {
389 "Invalid row weight index: " + std::to_string(idx));
391 size_t nrows = getLocalNumRows();
392 ArrayRCP<const scalar_t> weightV(weightVal, 0, nrows * stride,
false);
393 rowWeights_[idx] = input_t(weightV, stride);
397template <
typename User,
typename UserCoord>
399 typename Base::ConstWeightsDeviceView1D
weights,
int idx) {
402 "Invalid row weight index: " + std::to_string(idx));
404 auto rowWeightsDevice = this->rowWeightsDevice_;
405 Kokkos::parallel_for(
406 rowWeightsDevice.extent(0), KOKKOS_LAMBDA(
const int rowID) {
407 rowWeightsDevice(rowID, idx) =
weights(rowID);
414template <
typename User,
typename UserCoord>
416 typename Base::ConstWeightsHostView1D weightsHost,
int idx) {
418 "Invalid row weight index: " + std::to_string(idx));
420 auto weightsDevice = Kokkos::create_mirror_view_and_copy(
421 typename Base::device_t(), weightsHost);
423 setRowWeightsDevice(weightsDevice, idx);
427template <
typename User,
typename UserCoord>
429 if (this->getPrimaryEntityType() ==
MATRIX_ROW)
430 setRowWeightIsNumberOfNonZeros(idx);
433 std::ostringstream emsg;
434 emsg << __FILE__ <<
"," << __LINE__
435 <<
" error: setWeightIsNumberOfNonZeros not yet supported for"
436 <<
" columns" << std::endl;
437 throw std::runtime_error(emsg.str());
442template <
typename User,
typename UserCoord>
445 if (idx < 0 || idx >= nWeightsPerRow_) {
446 std::ostringstream emsg;
447 emsg << __FILE__ <<
":" << __LINE__ <<
" Invalid row weight index " << idx
449 throw std::runtime_error(emsg.str());
452 numNzWeight_(idx) =
true;
456template <
typename User,
typename UserCoord>
458 return matrix_->getLocalNumRows();
462template <
typename User,
typename UserCoord>
464 return matrix_->getLocalNumCols();
468template <
typename User,
typename UserCoord>
470 return matrix_->getLocalNumEntries();
474template <
typename User,
typename UserCoord>
478template <
typename User,
typename UserCoord>
480 ArrayView<const gno_t> rowView = matrix_->getRowMap()->getLocalElementList();
481 rowIds = rowView.getRawPtr();
485template <
typename User,
typename UserCoord>
487 typename Base::ConstIdsHostView &rowIds)
const {
488 auto idsDevice = matrix_->getRowMap()->getMyGlobalIndices();
489 auto tmpIds =
typename Base::IdsHostView(
"", idsDevice.extent(0));
491 Kokkos::deep_copy(tmpIds, idsDevice);
497template <
typename User,
typename UserCoord>
499 typename Base::ConstIdsDeviceView &rowIds)
const {
501 auto idsDevice = matrix_->getRowMap()->getMyGlobalIndices();
502 auto tmpIds =
typename Base::IdsDeviceView(
"", idsDevice.extent(0));
504 Kokkos::deep_copy(tmpIds, idsDevice);
510template <
typename User,
typename UserCoord>
512 ArrayRCP<const gno_t> &colIds)
const {
518template <
typename User,
typename UserCoord>
520 typename Base::ConstOffsetsHostView &offsets,
521 typename Base::ConstIdsHostView &colIds)
const {
522 auto hostOffsets = Kokkos::create_mirror_view(offsDevice_);
523 Kokkos::deep_copy(hostOffsets, offsDevice_);
524 offsets = hostOffsets;
526 auto hostColIds = Kokkos::create_mirror_view(colIdsDevice_);
527 Kokkos::deep_copy(hostColIds, colIdsDevice_);
532template <
typename User,
typename UserCoord>
534 typename Base::ConstOffsetsDeviceView &offsets,
535 typename Base::ConstIdsDeviceView &colIds)
const {
536 offsets = offsDevice_;
537 colIds = colIdsDevice_;
541template <
typename User,
typename UserCoord>
543 ArrayRCP<const gno_t> &colIds,
544 ArrayRCP<const scalar_t> &values)
const {
551template <
typename User,
typename UserCoord>
553 typename Base::ConstOffsetsHostView &offsets,
554 typename Base::ConstIdsHostView &colIds,
555 typename Base::ConstScalarsHostView &values)
const {
556 auto hostOffsets = Kokkos::create_mirror_view(offsDevice_);
557 Kokkos::deep_copy(hostOffsets, offsDevice_);
558 offsets = hostOffsets;
560 auto hostColIds = Kokkos::create_mirror_view(colIdsDevice_);
561 Kokkos::deep_copy(hostColIds, colIdsDevice_);
564 auto hostValues = Kokkos::create_mirror_view(valuesDevice_);
565 Kokkos::deep_copy(hostValues, valuesDevice_);
570template <
typename User,
typename UserCoord>
572 typename Base::ConstOffsetsDeviceView &offsets,
573 typename Base::ConstIdsDeviceView &colIds,
574 typename Base::ConstScalarsDeviceView &values)
const {
575 offsets = offsDevice_;
576 colIds = colIdsDevice_;
577 values = valuesDevice_;
581template <
typename User,
typename UserCoord>
585template <
typename User,
typename UserCoord>
588 if (idx < 0 || idx >= nWeightsPerRow_) {
589 std::ostringstream emsg;
590 emsg << __FILE__ <<
":" << __LINE__ <<
" Invalid row weight index "
592 throw std::runtime_error(emsg.str());
596 rowWeights_[idx].getStridedList(length,
weights, stride);
600template <
typename User,
typename UserCoord>
602 typename Base::WeightsDeviceView1D &
weights,
int idx)
const {
604 "Invalid row weight index.");
606 const auto size = rowWeightsDevice_.extent(0);
607 weights =
typename Base::WeightsDeviceView1D(
"weights", size);
609 auto rowWeightsDevice = this->rowWeightsDevice_;
610 Kokkos::parallel_for(
611 size, KOKKOS_LAMBDA(
const int id) {
612 weights(
id) = rowWeightsDevice(
id, idx);
619template <
typename User,
typename UserCoord>
621 typename Base::WeightsDeviceView &
weights)
const {
627template <
typename User,
typename UserCoord>
629 typename Base::WeightsHostView1D &
weights,
int idx)
const {
631 "Invalid row weight index.");
633 auto weightsDevice =
typename Base::WeightsDeviceView1D(
634 "weights", rowWeightsDevice_.extent(0));
635 getRowWeightsDeviceView(weightsDevice, idx);
637 weights = Kokkos::create_mirror_view(weightsDevice);
638 Kokkos::deep_copy(
weights, weightsDevice);
642template <
typename User,
typename UserCoord>
644 typename Base::WeightsHostView &
weights)
const {
646 weights = Kokkos::create_mirror_view(rowWeightsDevice_);
647 Kokkos::deep_copy(
weights, rowWeightsDevice_);
651template <
typename User,
typename UserCoord>
655template <
typename User,
typename UserCoord>
656template <
typename Adapter>
658 const User &in, User *&out,
662 ArrayRCP<gno_t> importList;
665 Zoltan2::getImportList<Adapter, TpetraRowMatrixAdapter<User, UserCoord>>(
666 solution,
this, importList);
671 RCP<User> outPtr = doMigration(in, numNewRows, importList.getRawPtr());
677template <
typename User,
typename UserCoord>
678template <
typename Adapter>
680 const User &in, RCP<User> &out,
684 ArrayRCP<gno_t> importList;
687 Zoltan2::getImportList<Adapter, TpetraRowMatrixAdapter<User, UserCoord>>(
688 solution,
this, importList);
693 out = doMigration(in, numNewRows, importList.getRawPtr());
697template <
typename User,
typename UserCoord>
699 const User &from,
size_t numLocalRows,
const gno_t *myNewRows)
const {
700 typedef Tpetra::Map<lno_t, gno_t, node_t>
map_t;
701 typedef Tpetra::CrsMatrix<scalar_t, lno_t, gno_t, node_t> tcrsmatrix_t;
712 const tcrsmatrix_t *pCrsMatrix =
dynamic_cast<const tcrsmatrix_t *
>(&from);
715 throw std::logic_error(
"TpetraRowMatrixAdapter cannot migrate data for "
716 "your RowMatrix; it can migrate data only for "
717 "Tpetra::CrsMatrix. "
718 "You can inherit from TpetraRowMatrixAdapter and "
719 "implement migration for your RowMatrix.");
723 const RCP<const map_t> &smap = from.getRowMap();
724 gno_t numGlobalRows = smap->getGlobalNumElements();
725 gno_t base = smap->getMinAllGlobalIndex();
728 ArrayView<const gno_t> rowList(myNewRows, numLocalRows);
729 const RCP<const Teuchos::Comm<int>> &comm = from.getComm();
730 RCP<const map_t> tmap = rcp(
new map_t(numGlobalRows, rowList, base, comm));
733 Tpetra::Import<lno_t, gno_t, node_t> importer(smap, tmap);
735 int oldNumElts = smap->getLocalNumElements();
736 int newNumElts = numLocalRows;
739 typedef Tpetra::Vector<scalar_t, lno_t, gno_t, node_t> vector_t;
740 vector_t numOld(smap);
741 vector_t numNew(tmap);
742 for (
int lid = 0; lid < oldNumElts; lid++) {
743 numOld.replaceGlobalValue(smap->getGlobalElement(lid),
744 scalar_t(from.getNumEntriesInLocalRow(lid)));
746 numNew.doImport(numOld, importer, Tpetra::INSERT);
749 ArrayRCP<size_t> nnz(newNumElts);
750 if (newNumElts > 0) {
751 ArrayRCP<scalar_t> ptr = numNew.getDataNonConst(0);
752 for (
int lid = 0; lid < newNumElts; lid++) {
753 nnz[lid] =
static_cast<size_t>(ptr[lid]);
757 RCP<tcrsmatrix_t> M = rcp(
new tcrsmatrix_t(tmap, nnz()));
759 M->doImport(from, importer, Tpetra::INSERT);
762 return Teuchos::rcp_dynamic_cast<User>(M);
#define Z2_FORWARD_EXCEPTIONS
Forward an exception back through call stack.
Defines the MatrixAdapter interface.
Helper functions for Partitioning Problems.
This file defines the StridedData class.
typename InputTraits< User >::node_t node_t
typename InputTraits< User >::scalar_t scalar_t
typename InputTraits< User >::gno_t gno_t
typename Kokkos::HostSpace::memory_space host_t
typename InputTraits< User >::offset_t offset_t
typename InputTraits< User >::part_t part_t
typename node_t::device_type device_t
MatrixAdapter defines the adapter interface for matrices.
A PartitioningSolution is a solution to a partitioning problem.
The StridedData class manages lists of weights or coordinates.
Provides access for Zoltan2 to Tpetra::RowMatrix data.
bool CRSViewAvailable() const
Indicates whether the MatrixAdapter implements a view of the matrix in compressed sparse row (CRS) format.
Base::WeightsDeviceView rowWeightsDevice_
size_t getLocalNumEntries() const
Returns the number of nonzeros on this process.
size_t getLocalNumColumns() const
Returns the number of columns on this process.
Base::ConstIdsDeviceView colIdsDevice_
void getRowWeightsHostView(typename Base::WeightsHostView &weights) const override
size_t getLocalNumRows() const
Returns the number of rows on this process.
Base::ConstIdsHostView colIdsHost_
bool mayHaveDiagonalEntries
TpetraRowMatrixAdapter(const RCP< const User > &inmatrix, int nWeightsPerRow=0)
Constructor.
void getRowIDsHostView(typename Base::ConstIdsHostView &rowIds) const override
void getCRSView(ArrayRCP< const offset_t > &offsets, ArrayRCP< const gno_t > &colIds, ArrayRCP< const scalar_t > &values) const
ArrayRCP< StridedData< lno_t, scalar_t > > rowWeights_
void getRowWeightsView(const scalar_t *&weights, int &stride, int idx=0) const
Provide a pointer to the row weights, if any.
TpetraRowMatrixAdapter(int nWeightsPerRow, const RCP< const User > &inmatrix)
void setRowWeightIsNumberOfNonZeros(int idx)
Specify an index for which the row weight should be the global number of nonzeros in the row.
void getRowWeightsHostView(typename Base::WeightsHostView1D &weights, int idx=0) const
Kokkos::View< bool *, host_t > numNzWeight_
Base::ConstOffsetsDeviceView offsDevice_
void setWeights(const scalar_t *weightVal, int stride, int idx=0)
Specify a weight for each entity of the primaryEntityType.
void getRowIDsView(const gno_t *&rowIds) const override
void getCRSView(ArrayRCP< const offset_t > &offsets, ArrayRCP< const gno_t > &colIds) const
void setRowWeightsDevice(typename Base::ConstWeightsDeviceView1D val, int idx)
Provide a device view to row weights.
void setWeightIsDegree(int idx)
Specify an index for which the weight should be the degree of the entity.
RCP< const User > matrix_
void getCRSHostView(typename Base::ConstOffsetsHostView &offsets, typename Base::ConstIdsHostView &colIds) const override
ArrayRCP< offset_t > offset_
void getRowWeightsDeviceView(typename Base::WeightsDeviceView &weights) const override
void getCRSDeviceView(typename Base::ConstOffsetsDeviceView &offsets, typename Base::ConstIdsDeviceView &colIds) const override
bool useNumNonzerosAsRowWeight(int idx) const
Indicate whether row weight with index idx should be the global number of nonzeros in the row.
Base::ScalarsHostView valuesHost_
void setRowWeightsHost(typename Base::ConstWeightsHostView1D val, int idx)
Provide a host view to row weights.
void setWeightsHost(typename Base::ConstWeightsHostView1D val, int idx)
Provide a host view of weights for the primary entity type.
void getCRSDeviceView(typename Base::ConstOffsetsDeviceView &offsets, typename Base::ConstIdsDeviceView &colIds, typename Base::ConstScalarsDeviceView &values) const override
ArrayRCP< scalar_t > values_
int getNumWeightsPerRow() const
Returns the number of weights per row (0 or greater). Row weights may be used when partitioning matrix rows.
void getRowWeightsDeviceView(typename Base::WeightsDeviceView1D &weights, int idx=0) const
void getCRSHostView(typename Base::ConstOffsetsHostView &offsets, typename Base::ConstIdsHostView &colIds, typename Base::ConstScalarsHostView &values) const override
virtual RCP< User > doMigration(const User &from, size_t numLocalRows, const gno_t *myNewRows) const
void setRowWeights(const scalar_t *weightVal, int stride, int idx=0)
Specify a weight for each row.
void getRowIDsDeviceView(typename Base::ConstIdsDeviceView &rowIds) const override
void applyPartitioningSolution(const User &in, RCP< User > &out, const PartitioningSolution< Adapter > &solution) const
Base::ConstOffsetsHostView offsHost_
void setWeightsDevice(typename Base::ConstWeightsDeviceView1D val, int idx)
Provide a device view of weights for the primary entity type.
void applyPartitioningSolution(const User &in, User *&out, const PartitioningSolution< Adapter > &solution) const
ArrayRCP< gno_t > columnIds_
Base::ScalarsDeviceView valuesDevice_
map_t::local_ordinal_type lno_t
map_t::global_ordinal_type gno_t
Created by mbenlioglu on Aug 31, 2020.
static void AssertCondition(bool condition, const std::string &message, const char *file=__FILE__, int line=__LINE__)