14#ifndef _ZOLTAN2_XPETRACRSMATRIXADAPTER_HPP_
15#define _ZOLTAN2_XPETRACRSMATRIXADAPTER_HPP_
22#include <Xpetra_CrsMatrix.hpp>
51template <
typename User,
typename UserCoord=User>
55#ifndef DOXYGEN_SHOULD_SKIP_THIS
62 using xmatrix_t = Xpetra::CrsMatrix<scalar_t, lno_t, gno_t, node_t>;
64 using userCoord_t = UserCoord;
78 int nWeightsPerRow=0);
130 return matrix_->getLocalNumRows();
134 return matrix_->getLocalNumCols();
138 return matrix_->getLocalNumEntries();
143 ArrayView<const gno_t> rowView = rowMap_->getLocalElementList();
144 rowIds = rowView.getRawPtr();
149 ArrayView<const gno_t> colView = colMap_->getLocalElementList();
150 colIds = colView.getRawPtr();
153 void getCRSView(ArrayRCP<const offset_t> &offsets, ArrayRCP<const gno_t> &colIds)
const
155 ArrayRCP< const lno_t > localColumnIds;
156 ArrayRCP<const scalar_t> values;
157 matrix_->getAllValues(offsets,localColumnIds,values);
164 ArrayRCP<const gno_t> &colIds,
165 ArrayRCP<const scalar_t> &values)
const {
166 ArrayRCP< const lno_t > localColumnIds;
167 matrix_->getAllValues(offsets,localColumnIds,values);
172 ArrayRCP<const gno_t> &rowIds)
const override {
173 ArrayRCP<const offset_t> crsOffsets;
174 ArrayRCP<const lno_t> crsLocalColumnIds;
175 ArrayRCP<const scalar_t> values;
176 matrix_->getAllValues(crsOffsets, crsLocalColumnIds, values);
178 const auto localRowIds = rowMap_->getLocalElementList();
179 const auto numLocalCols = colMap_->getLocalNumElements();
182 auto determineRow = [&crsOffsets, &localRowIds](
const int columnIdx) {
184 for (
int rowIdx = 0; rowIdx < localRowIds.size(); ++rowIdx) {
185 if (rowIdx < (localRowIds.size() - 1)) {
186 if (
static_cast<offset_t>(columnIdx) < crsOffsets[rowIdx + 1]) {
199 std::vector<std::vector<gno_t>> rowIDsPerCol(numLocalCols);
201 for (
int colIdx = 0; colIdx < crsLocalColumnIds.size(); ++colIdx) {
202 const auto colID = crsLocalColumnIds[colIdx];
203 const auto globalRow = rowMap_->getGlobalElement(determineRow(colIdx));
205 rowIDsPerCol[colID].push_back(globalRow);
208 size_t offsetWrite = 0;
209 ArrayRCP<gno_t> ccsRowIds(values.size());
210 ArrayRCP<offset_t> ccsOffsets(colMap_->getLocalNumElements() + 1);
213 for (int64_t colID = 1; colID < ccsOffsets.size(); ++colID) {
214 const auto &rowIDs = rowIDsPerCol[colID - 1];
216 if (not rowIDs.empty()) {
217 std::copy(rowIDs.begin(), rowIDs.end(),
218 ccsRowIds.begin() + offsetWrite);
219 offsetWrite += rowIDs.size();
222 ccsOffsets[colID] = offsetWrite;
225 ccsOffsets[numLocalCols] = crsLocalColumnIds.size();
228 offsets = ccsOffsets;
236 if(idx<0 || idx >= nWeightsPerRow_)
238 std::ostringstream emsg;
239 emsg << __FILE__ <<
":" << __LINE__
240 <<
" Invalid row weight index " << idx << std::endl;
241 throw std::runtime_error(emsg.str());
245 rowWeights_[idx].getStridedList(length,
weights, stride);
250 template <
typename Adapter>
254 template <
typename Adapter>
260 RCP<const User> inmatrix_;
261 RCP<const xmatrix_t> matrix_;
262 RCP<const Xpetra::Map<lno_t, gno_t, node_t> > rowMap_;
263 RCP<const Xpetra::Map<lno_t, gno_t, node_t> > colMap_;
265 ArrayRCP<gno_t> columnIds_;
268 ArrayRCP<StridedData<lno_t, scalar_t> > rowWeights_;
269 ArrayRCP<bool> numNzWeight_;
271 bool mayHaveDiagonalEntries;
278template <
typename User,
typename UserCoord>
280 const RCP<const User> &inmatrix,
int nWeightsPerRow):
281 inmatrix_(inmatrix), matrix_(), rowMap_(), colMap_(),
283 nWeightsPerRow_(nWeightsPerRow), rowWeights_(), numNzWeight_(),
284 mayHaveDiagonalEntries(true)
288 matrix_ = rcp_const_cast<const xmatrix_t>(
293 rowMap_ = matrix_->getRowMap();
294 colMap_ = matrix_->getColMap();
296 size_t nrows = matrix_->getLocalNumRows();
297 size_t nnz = matrix_->getLocalNumEntries();
300 ArrayRCP< const offset_t > offset;
301 ArrayRCP< const lno_t > localColumnIds;
302 ArrayRCP< const scalar_t > values;
303 matrix_->getAllValues(offset,localColumnIds,values);
304 columnIds_.resize(nnz, 0);
306 for (
offset_t i = 0; i < offset[nrows]; i++) {
307 columnIds_[i] = colMap_->getGlobalElement(localColumnIds[i]);
310 if (nWeightsPerRow_ > 0){
311 rowWeights_ = arcp(
new input_t [nWeightsPerRow_], 0, nWeightsPerRow_,
true);
312 numNzWeight_ = arcp(
new bool [nWeightsPerRow_], 0, nWeightsPerRow_,
true);
313 for (
int i=0; i < nWeightsPerRow_; i++)
314 numNzWeight_[i] =
false;
319template <
typename User,
typename UserCoord>
321 const scalar_t *weightVal,
int stride,
int idx)
323 if (this->getPrimaryEntityType() ==
MATRIX_ROW)
324 setRowWeights(weightVal, stride, idx);
327 std::ostringstream emsg;
328 emsg << __FILE__ <<
"," << __LINE__
329 <<
" error: setWeights not yet supported for"
330 <<
" columns or nonzeros."
332 throw std::runtime_error(emsg.str());
337template <
typename User,
typename UserCoord>
339 const scalar_t *weightVal,
int stride,
int idx)
342 if(idx<0 || idx >= nWeightsPerRow_)
344 std::ostringstream emsg;
345 emsg << __FILE__ <<
":" << __LINE__
346 <<
" Invalid row weight index " << idx << std::endl;
347 throw std::runtime_error(emsg.str());
350 size_t nvtx = getLocalNumRows();
351 ArrayRCP<const scalar_t> weightV(weightVal, 0, nvtx*stride,
false);
352 rowWeights_[idx] = input_t(weightV, stride);
356template <
typename User,
typename UserCoord>
360 if (this->getPrimaryEntityType() ==
MATRIX_ROW)
361 setRowWeightIsNumberOfNonZeros(idx);
364 std::ostringstream emsg;
365 emsg << __FILE__ <<
"," << __LINE__
366 <<
" error: setWeightIsNumberOfNonZeros not yet supported for"
367 <<
" columns" << std::endl;
368 throw std::runtime_error(emsg.str());
373template <
typename User,
typename UserCoord>
377 if(idx<0 || idx >= nWeightsPerRow_)
379 std::ostringstream emsg;
380 emsg << __FILE__ <<
":" << __LINE__
381 <<
" Invalid row weight index " << idx << std::endl;
382 throw std::runtime_error(emsg.str());
386 numNzWeight_[idx] =
true;
390template <
typename User,
typename UserCoord>
391 template <
typename Adapter>
393 const User &in, User *&out,
398 ArrayRCP<gno_t> importList;
402 (solution,
this, importList);
408 importList.getRawPtr());
409 out =
const_cast<User *
>(outPtr.get());
414template <
typename User,
typename UserCoord>
415 template <
typename Adapter>
417 const User &in, RCP<User> &out,
422 ArrayRCP<gno_t> importList;
426 (solution,
this, importList);
432 importList.getRawPtr());
#define Z2_FORWARD_EXCEPTIONS
Forward an exception back through the call stack.
Defines the MatrixAdapter interface.
Helper functions for Partitioning Problems.
This file defines the StridedData class.
Traits of Xpetra classes, including migration method.
typename InputTraits< User >::scalar_t scalar_t
typename InputTraits< User >::offset_t offset_t
typename InputTraits< User >::part_t part_t
MatrixAdapter defines the adapter interface for matrices.
A PartitioningSolution is a solution to a partitioning problem.
The StridedData class manages lists of weights or coordinates.
Provides access for Zoltan2 to Xpetra::CrsMatrix data.
void applyPartitioningSolution(const User &in, User *&out, const PartitioningSolution< Adapter > &solution) const
bool CRSViewAvailable() const
Indicates whether the MatrixAdapter implements a view of the matrix in compressed sparse row (CRS) format.
bool useNumNonzerosAsRowWeight(int idx) const
Indicate whether row weight with index idx should be the global number of nonzeros in the row.
void getCRSView(ArrayRCP< const offset_t > &offsets, ArrayRCP< const gno_t > &colIds) const
void getRowIDsView(const gno_t *&rowIds) const
~XpetraCrsMatrixAdapter()
Destructor.
size_t getLocalNumColumns() const
Returns the number of columns on this process.
void getCCSView(ArrayRCP< const offset_t > &offsets, ArrayRCP< const gno_t > &rowIds) const override
void getColumnIDsView(const gno_t *&colIds) const
size_t getLocalNumEntries() const
Returns the number of nonzeros on this process.
void setRowWeights(const scalar_t *weightVal, int stride, int idx=0)
Specify a weight for each row.
size_t getLocalNumRows() const
Returns the number of rows on this process.
int getNumWeightsPerRow() const
Returns the number of weights per row (0 or greater). Row weights may be used when partitioning matrix rows.
void setWeights(const scalar_t *weightVal, int stride, int idx=0)
Specify a weight for each entity of the primaryEntityType.
void getRowWeightsView(const scalar_t *&weights, int &stride, int idx=0) const
Provide a pointer to the row weights, if any.
void getCRSView(ArrayRCP< const offset_t > &offsets, ArrayRCP< const gno_t > &colIds, ArrayRCP< const scalar_t > &values) const
void setWeightIsDegree(int idx)
Specify an index for which the weight should be the degree of the entity.
void setRowWeightIsNumberOfNonZeros(int idx)
Specify an index for which the row weight should be the global number of nonzeros in the row.
XpetraCrsMatrixAdapter(const RCP< const User > &inmatrix, int nWeightsPerRow=0)
Constructor.
map_t::local_ordinal_type lno_t
map_t::global_ordinal_type gno_t
Created by mbenlioglu on Aug 31, 2020.
size_t getImportList(const PartitioningSolution< SolutionAdapter > &solution, const DataAdapter *const data, ArrayRCP< typename DataAdapter::gno_t > &imports)
From a PartitioningSolution, get a list of IDs to be imported. Assumes part numbers in the PartitioningSolution are to be used as process numbers.
Defines the traits required for Tpetra, Eptra and Xpetra objects.
static RCP< User > doMigration(const User &from, size_t numLocalRows, const gno_t *myNewRows)
Migrate the object. Given a user object and a new row distribution, create and return a new user object with the new distribution.