80 RCP<Matrix> A = Get<RCP<Matrix> >(currentLevel,
"A");
81 if (Get<bool>(currentLevel,
"Filtering") ==
false) {
82 GetOStream(
Runtime0) <<
"Filtered matrix is not being constructed as no filtering is being done" << std::endl;
83 Set(currentLevel,
"A", A);
87 const ParameterList& pL = GetParameterList();
88 bool lumping = pL.get<
bool>(
"filtered matrix: use lumping");
90 GetOStream(
Runtime0) <<
"Lumping dropped entries" << std::endl;
92 bool use_spread_lumping = pL.get<
bool>(
"filtered matrix: use spread lumping");
93 if (use_spread_lumping && (!lumping))
94 throw std::runtime_error(
"Must also request 'filtered matrix: use lumping' in order to use spread lumping");
96 if (use_spread_lumping) {
97 GetOStream(
Runtime0) <<
"using spread lumping " << std::endl;
100 double DdomAllowGrowthRate = 1.1;
101 double DdomCap = 2.0;
102 if (use_spread_lumping) {
103 DdomAllowGrowthRate = pL.get<
double>(
"filtered matrix: spread lumping diag dom growth factor");
104 DdomCap = pL.get<
double>(
"filtered matrix: spread lumping diag dom cap");
106 bool use_root_stencil = lumping && pL.get<
bool>(
"filtered matrix: use root stencil");
107 if (use_root_stencil)
108 GetOStream(
Runtime0) <<
"Using root stencil for dropping" << std::endl;
109 double dirichlet_threshold = pL.get<
double>(
"filtered matrix: Dirichlet threshold");
110 if (dirichlet_threshold >= 0.0)
111 GetOStream(
Runtime0) <<
"Filtering Dirichlet threshold of " << dirichlet_threshold << std::endl;
113 if (use_root_stencil || pL.get<
bool>(
"filtered matrix: reuse graph"))
114 GetOStream(
Runtime0) <<
"Reusing graph" << std::endl;
116 GetOStream(
Runtime0) <<
"Generating new graph" << std::endl;
118 RCP<LWGraph> G = Get<RCP<LWGraph> >(currentLevel,
"Graph");
120 FILE* f = fopen(
"graph.dat",
"w");
121 size_t numGRows = G->GetNodeNumVertices();
122 for (
size_t i = 0; i < numGRows; i++) {
124 auto indsG = G->getNeighborVertices(i);
125 for (
size_t j = 0; j < (size_t)indsG.length; j++) {
126 fprintf(f,
"%d %d 1.0\n", (
int)i, (
int)indsG(j));
132 RCP<ParameterList> fillCompleteParams(
new ParameterList);
133 fillCompleteParams->set(
"No Nonlocal Changes",
true);
135 RCP<Matrix> filteredA;
136 if (use_root_stencil) {
137 filteredA = MatrixFactory::Build(A->getCrsGraph());
138 filteredA->fillComplete(fillCompleteParams);
139 filteredA->resumeFill();
140 BuildNewUsingRootStencil(*A, *G, dirichlet_threshold, currentLevel, *filteredA, use_spread_lumping, DdomAllowGrowthRate, DdomCap);
141 filteredA->fillComplete(fillCompleteParams);
143 }
else if (pL.get<
bool>(
"filtered matrix: reuse graph")) {
144 filteredA = MatrixFactory::Build(A->getCrsGraph());
145 filteredA->resumeFill();
146 BuildReuse(*A, *G, (lumping != use_spread_lumping), dirichlet_threshold, *filteredA);
150 if (use_spread_lumping) ExperimentalLumping(*A, *filteredA, DdomAllowGrowthRate, DdomCap);
151 filteredA->fillComplete(fillCompleteParams);
154 filteredA = MatrixFactory::Build(A->getRowMap(), A->getColMap(), A->getLocalMaxNumRowEntries());
155 BuildNew(*A, *G, (lumping != use_spread_lumping), dirichlet_threshold, *filteredA);
158 if (use_spread_lumping) ExperimentalLumping(*A, *filteredA, DdomAllowGrowthRate, DdomCap);
159 filteredA->fillComplete(A->getDomainMap(), A->getRangeMap(), fillCompleteParams);
163 Xpetra::IO<SC, LO, GO, NO>::Write(
"filteredA.dat", *filteredA);
166 Xpetra::IO<SC, LO, GO, NO>::Write(
"A.dat", *A);
167 RCP<Matrix> origFilteredA = MatrixFactory::Build(A->getRowMap(), A->getColMap(), A->getLocalMaxNumRowEntries());
168 BuildNew(*A, *G, lumping, dirichlet_threshold, *origFilteredA);
169 if (use_spread_lumping) ExperimentalLumping(*A, *origFilteredA, DdomAllowGrowthRate, DdomCap);
170 origFilteredA->fillComplete(A->getDomainMap(), A->getRangeMap(), fillCompleteParams);
171 Xpetra::IO<SC, LO, GO, NO>::Write(
"origFilteredA.dat", *origFilteredA);
174 filteredA->SetFixedBlockSize(A->GetFixedBlockSize());
176 if (pL.get<
bool>(
"filtered matrix: reuse eigenvalue")) {
181 filteredA->SetMaxEigenvalueEstimate(A->GetMaxEigenvalueEstimate());
184 if (pL.get<
bool>(
"filtered matrix: count negative diagonals")) {
187 GetOStream(
Runtime0) <<
"FilteredA: Negative diagonals: " << neg_count << std::endl;
190 Set(currentLevel,
"A", filteredA);
// BuildReuse: fill filteredA with A's values, keeping only the entries whose
// column is a neighbor (in the filtering graph G) of the current row's node.
// The caller builds filteredA from A->getCrsGraph(), so filteredA shares A's
// sparsity pattern and values are edited in place rather than inserted.
// Parameters (as used in the visible lines):
//   A               - original matrix; rows are read with getLocalRowView.
//   G               - filtering graph over nodes; a node's neighbor list marks
//                     which columns of that node's DOF rows are kept.
//   lumping         - false: non-neighbor entries are handled by a loop whose
//                     body is not visible here (presumably zeroed). true: the
//                     dropped values are accumulated into diagExtra and added to
//                     the diagonal.
//   dirichletThresh - when >= 0 and the real part of the updated diagonal is
//                     <= this value, the row is treated as Dirichlet and its
//                     diagonal is set to one.
//   filteredA       - output matrix, modified in place.
// NOTE(review): this listing omits many source lines (the embedded numbers jump).
// The comments below describe only the visible lines.
210 BuildReuse(
const Matrix& A,
const LWGraph& G,
const bool lumping,
double dirichletThresh, Matrix& filteredA)
const {
211 using TST =
typename Teuchos::ScalarTraits<SC>;
212 SC zero = TST::zero();
214 size_t blkSize = A.GetFixedBlockSize();
216 ArrayView<const LO> inds;
217 ArrayView<const SC> valsA;
// ASSUME_DIRECT_ACCESS_TO_ROW chooses between two modes: editing filteredA's
// row storage in place, or editing a copy and writing it back (see the
// matching #ifdef/#ifndef sections below).
218#ifdef ASSUME_DIRECT_ACCESS_TO_ROW
// filter[c] != 0 marks column c as kept for the node currently being processed.
// It is sized to the larger of the block-expanded G import-map size and A's
// column-map size.
224 Array<char> filter(std::max(blkSize * G.
GetImportMap()->getLocalNumElements(),
225 A.getColMap()->getLocalNumElements()),
229 for (
size_t i = 0; i < numGRows; i++) {
// For node i, mark every DOF of each G-neighbor as kept. indsG is presumably
// G's neighbor list for node i; its definition is not visible here.
232 for (
size_t j = 0; j < as<size_t>(indsG.length); j++)
233 for (
size_t k = 0; k < blkSize; k++)
234 filter[indsG(j) * blkSize + k] = 1;
// Node i owns the DOF rows i * blkSize + k, for k in [0, blkSize).
236 for (
size_t k = 0; k < blkSize; k++) {
237 LO row = i * blkSize + k;
239 A.getLocalRowView(row, inds, valsA);
241 size_t nnz = inds.size();
// With direct row access, vals is a mutable view (obtained via const_cast) of
// filteredA's own row storage, and A's values are copied into it. Otherwise
// vals is a copy of A's row values.
245#ifdef ASSUME_DIRECT_ACCESS_TO_ROW
247 ArrayView<const SC> vals1;
248 filteredA.getLocalRowView(row, inds, vals1);
249 vals = ArrayView<SC>(
const_cast<SC*
>(vals1.getRawPtr()), nnz);
251 memcpy(vals.getRawPtr(), valsA.getRawPtr(), nnz *
sizeof(SC));
253 vals = Array<SC>(valsA);
// Row sum of A. How it is used is not visible in this listing.
256 SC ZERO = Teuchos::ScalarTraits<SC>::zero();
258 SC A_rowsum = ZERO, F_rowsum = ZERO;
259 for (LO l = 0; l < (LO)inds.size(); l++)
260 A_rowsum += valsA[l];
// Without lumping: entries whose column is not marked in filter are handled by
// the loop below. Its body is not visible; presumably they are zeroed.
262 if (lumping ==
false) {
263 for (
size_t j = 0; j < nnz; j++)
264 if (!filter[inds[j]])
271 for (
size_t j = 0; j < nnz; j++) {
// Lumping branch: among the kept entries, locate the diagonal. Dropped values
// are accumulated into diagExtra.
272 if (filter[inds[j]]) {
273 if (inds[j] == row) {
280 diagExtra += vals[j];
// Add the lumped values to the diagonal. If a Dirichlet threshold is active
// (>= 0) and the real part of the new diagonal is <= it, the row is treated as
// Dirichlet: the loop over the row runs (body not visible) and the diagonal
// becomes one.
290 if (diagIndex != -1) {
292 vals[diagIndex] += diagExtra;
293 if (dirichletThresh >= 0.0 && TST::real(vals[diagIndex]) <= dirichletThresh) {
295 for (LO l = 0; l < (LO)nnz; l++)
298 vals[diagIndex] = TST::one();
// Without direct access, the edited copy must be written back into filteredA.
303#ifndef ASSUME_DIRECT_ACCESS_TO_ROW
306 filteredA.replaceLocalValues(row, inds, vals);
// Clear node i's marks so filter starts clean for the next node.
311 for (
size_t j = 0; j < as<size_t>(indsG.length); j++)
312 for (
size_t k = 0; k < blkSize; k++)
313 filter[indsG(j) * blkSize + k] = 0;
// BuildNewUsingRootStencil: fill filteredA using the aggregation stored on
// currentLevel ("Aggregates", "UnAmalgamationInfo") together with the filtering
// graph G. The caller builds filteredA from A->getCrsGraph(), and this function
// overwrites its CSR values in place.
// Visible steps:
//   1. Rows of unaggregated nodes become identity rows.
//   2. Per aggregate i:
//      - Find its root node.
//      - Collect "bad" neighbor aggregates: reached through A's root rows but
//        not through G.
//      - For those with at most one G-link into aggregate i ("really bad"),
//        drop their couplings into aggregate i and lump the values onto the
//        bad rows' diagonals.
//   3. Per node of aggregate i: filter its rows against its G-neighbors, lumping
//      dropped values into diagExtra.
//   4. Without spread lumping:
//      - Add diagExtra to the diagonals.
//      - Fix Dirichlet or zero diagonals.
//   5. Copy the values back into filteredA. With spread lumping, apply
//      ExperimentalLumping.
//   6. Report the drop and fix counts.
// NOTE(review): this listing omits many source lines (the embedded numbers jump).
// The comments below describe only the visible lines.
414 BuildNewUsingRootStencil(
const Matrix& A,
const LWGraph& G,
double dirichletThresh,
Level& currentLevel, Matrix& filteredA,
bool use_spread_lumping,
double DdomAllowGrowthRate,
double DdomCap)
const {
415 using TST =
typename Teuchos::ScalarTraits<SC>;
416 using Teuchos::arcp_const_cast;
417 SC ZERO = Teuchos::ScalarTraits<SC>::zero();
418 SC ONE = Teuchos::ScalarTraits<SC>::one();
419 LO INVALID = Teuchos::OrdinalTraits<LO>::invalid();
// Aggregation data stored on currentLevel.
422 RCP<Aggregates> aggregates = Get<RCP<Aggregates> >(currentLevel,
"Aggregates");
423 RCP<AmalgamationInfo> amalgInfo = Get<RCP<AmalgamationInfo> >(currentLevel,
"UnAmalgamationInfo");
424 LO numAggs = aggregates->GetNumAggregates();
427 size_t blkSize = A.GetFixedBlockSize();
428 size_t numRows = A.getMap()->getLocalNumElements();
429 ArrayView<const LO> indsA;
430 ArrayView<const SC> valsA;
431 ArrayRCP<const size_t> rowptr;
432 ArrayRCP<const LO> inds;
433 ArrayRCP<const SC> vals_const;
// Edit filteredA's values directly: getAllValues exposes its CSR arrays, and the
// const on the values is cast away.
// NOTE(review): the dynamic_cast result is dereferenced without a null check.
// This assumes filteredA is always a CrsMatrixWrap; confirm.
439 RCP<CrsMatrix> filteredAcrs =
dynamic_cast<const CrsMatrixWrap*
>(&filteredA)->getCrsMatrix();
440 filteredAcrs->getAllValues(rowptr, inds, vals_const);
441 vals = arcp_const_cast<SC>(vals_const);
// vals_dropped_indicator[p] records that CSR entry p has already been dropped,
// so its value is not lumped twice.
442 Array<bool> vals_dropped_indicator(vals.size(),
false);
445 RCP<const Map> rowMap = A.getRowMap();
446 RCP<const Map> colMap = A.getColMap();
// diagIndex[row] is the in-row position of the diagonal (set in lines not
// visible here). diagExtra[row] accumulates dropped values to be added to that
// diagonal.
451 Array<LO> diagIndex(numRows, INVALID);
452 Array<SC> diagExtra(numRows, ZERO);
// Host copies of the aggregate -> node structure (ptr/nodes) and of the list of
// unaggregated nodes.
459 typename Aggregates::LO_view::host_mirror_type ptr_h, nodes_h, unaggregated_h;
461 aggregates->ComputeNodesInAggregate(nodesInAgg.ptr, nodesInAgg.nodes, nodesInAgg.unaggregated);
462 nodesInAgg.ptr_h = Kokkos::create_mirror_view(nodesInAgg.ptr);
463 nodesInAgg.nodes_h = Kokkos::create_mirror_view(nodesInAgg.nodes);
464 nodesInAgg.unaggregated_h = Kokkos::create_mirror_view(nodesInAgg.unaggregated);
465 Kokkos::deep_copy(nodesInAgg.ptr_h, nodesInAgg.ptr);
466 Kokkos::deep_copy(nodesInAgg.nodes_h, nodesInAgg.nodes);
467 Kokkos::deep_copy(nodesInAgg.unaggregated_h, nodesInAgg.unaggregated);
468 Teuchos::ArrayRCP<const LO> vertex2AggId = aggregates->GetVertex2AggId()->getData(0);
// filter[c] is true for the G-neighbors of the node currently being processed.
470 LO graphNumCols = G.
GetImportMap()->getLocalNumElements();
471 Array<bool> filter(graphNumCols,
false);
// Unaggregated nodes: each of their locally owned DOF rows gets ONE on the
// diagonal and, otherwise, ZERO. DOF rows >= numRows are skipped.
474 for (LO i = 0; i < (LO)nodesInAgg.unaggregated_h.extent(0); i++) {
475 for (LO m = 0; m < (LO)blkSize; m++) {
476 LO row = amalgInfo->ComputeLocalDOF(nodesInAgg.unaggregated_h(i), m);
477 if (row >= (LO)numRows)
continue;
478 size_t index_start = rowptr[row];
479 A.getLocalRowView(row, indsA, valsA);
480 for (LO k = 0; k < (LO)indsA.size(); k++) {
481 if (row == indsA[k]) {
482 vals[index_start + k] = ONE;
485 vals[index_start + k] = ZERO;
// badCount[agg] is a per-aggregate scratch counter. It is reset after each
// aggregate is processed.
490 std::vector<LO> badCount(numAggs, 0);
// Largest aggregate size. maxAggSize is declared in lines not visible here.
494 for (LO i = 0; i < numAggs; i++)
495 maxAggSize = std::max(maxAggSize, nodesInAgg.ptr_h(i + 1) - nodesInAgg.ptr_h(i));
// Statistics. numNewDrops, numOldDrops and numFixedDiags are reduced and
// reported at the end; numSymDrops is not reported in the visible lines.
501 size_t numNewDrops = 0;
502 size_t numOldDrops = 0;
503 size_t numFixedDiags = 0;
504 size_t numSymDrops = 0;
// Per-aggregate processing. Empty aggregates are skipped.
506 for (LO i = 0; i < numAggs; i++) {
507 LO numNodesInAggregate = nodesInAgg.ptr_h(i + 1) - nodesInAgg.ptr_h(i);
508 if (numNodesInAggregate == 0)
continue;
// Locate the aggregate's root node. An exception is raised if there is none.
511 LO root_node = INVALID;
512 for (LO k = nodesInAgg.ptr_h(i); k < nodesInAgg.ptr_h(i + 1); k++) {
513 if (aggregates->IsRoot(nodesInAgg.nodes_h(k))) {
514 root_node = nodesInAgg.nodes_h(k);
519 TEUCHOS_TEST_FOR_EXCEPTION(root_node == INVALID,
// Aggregates of the root's neighbors. goodNodeNeighbors is presumably G's
// neighbor list of root_node; it is not visible here.
// NOTE(review): std::binary_search below requires goodAggNeighbors to be sorted.
// The sort is not visible in this listing; confirm.
526 goodAggNeighbors.resize(0);
527 for (LO k = 0; k < (LO)goodNodeNeighbors.length; k++) {
528 goodAggNeighbors.push_back(vertex2AggId[goodNodeNeighbors(k)]);
// "Bad" neighbor aggregates: reached through a nonzero, locally owned entry of
// one of the root's rows in A, but not among the root's G-neighbor aggregates.
535 badAggNeighbors.resize(0);
536 for (LO j = 0; j < (LO)blkSize; j++) {
537 LO row = amalgInfo->ComputeLocalDOF(root_node, j);
538 if (row >= (LO)numRows)
continue;
539 A.getLocalRowView(row, indsA, valsA);
540 for (LO k = 0; k < (LO)indsA.size(); k++) {
541 if ((indsA[k] < (LO)numRows) && (TST::magnitude(valsA[k]) != TST::magnitude(ZERO))) {
542 LO node = amalgInfo->ComputeLocalNode(indsA[k]);
543 LO agg = vertex2AggId[node];
544 if (!std::binary_search(goodAggNeighbors.begin(), goodAggNeighbors.end(), agg))
545 badAggNeighbors.push_back(agg);
// For each valid aggregate id, count the links from nodes of aggregate i into
// it. nodeNeighbors is presumably G's neighbor list of node k; it is not
// visible here.
554 for (LO k = nodesInAgg.ptr_h(i); k < nodesInAgg.ptr_h(i + 1); k++) {
556 for (LO kk = 0; kk < nodeNeighbors.length; kk++) {
557 if ((vertex2AggId[nodeNeighbors(kk)] >= 0) && (vertex2AggId[nodeNeighbors(kk)] < numAggs))
558 (badCount[vertex2AggId[nodeNeighbors(kk)]])++;
// "Really bad" neighbors: bad aggregates with at most one such link.
562 reallyBadAggNeighbors.resize(0);
563 for (LO k = 0; k < (LO)badAggNeighbors.size(); k++) {
564 if (badCount[badAggNeighbors[k]] <= 1) reallyBadAggNeighbors.push_back(badAggNeighbors[k]);
// Reset badCount for the next aggregate.
566 for (LO k = nodesInAgg.ptr_h(i); k < nodesInAgg.ptr_h(i + 1); k++) {
568 for (LO kk = 0; kk < nodeNeighbors.length; kk++) {
569 if ((vertex2AggId[nodeNeighbors(kk)] >= 0) && (vertex2AggId[nodeNeighbors(kk)] < numAggs))
570 badCount[vertex2AggId[nodeNeighbors(kk)]] = 0;
// In each really-bad aggregate's locally owned rows, drop every not-yet-dropped
// entry whose column node belongs to aggregate i. Mark it dropped, zero it, and
// lump its original value into that row's diagExtra.
576 for (LO b = 0; b < (LO)reallyBadAggNeighbors.size(); b++) {
577 LO bad_agg = reallyBadAggNeighbors[b];
578 for (LO k = nodesInAgg.ptr_h(bad_agg); k < nodesInAgg.ptr_h(bad_agg + 1); k++) {
579 LO bad_node = nodesInAgg.nodes_h(k);
580 for (LO j = 0; j < (LO)blkSize; j++) {
581 LO bad_row = amalgInfo->ComputeLocalDOF(bad_node, j);
582 if (bad_row >= (LO)numRows)
continue;
583 size_t index_start = rowptr[bad_row];
584 A.getLocalRowView(bad_row, indsA, valsA);
585 for (LO l = 0; l < (LO)indsA.size(); l++) {
586 if (indsA[l] < (LO)numRows && vertex2AggId[amalgInfo->ComputeLocalNode(indsA[l])] == i && vals_dropped_indicator[index_start + l] ==
false) {
587 vals_dropped_indicator[index_start + l] =
true;
588 vals[index_start + l] = ZERO;
589 diagExtra[bad_row] += valsA[l];
// Filter the rows of every node of aggregate i against that node's G-neighbors.
// indsG is presumably G's neighbor list of row_node; it is not visible here.
600 for (LO k = nodesInAgg.ptr_h(i); k < nodesInAgg.ptr_h(i + 1); k++) {
601 LO row_node = nodesInAgg.nodes_h(k);
605 for (
size_t j = 0; j < as<size_t>(indsG.length); j++)
606 filter[indsG(j)] =
true;
608 for (LO m = 0; m < (LO)blkSize; m++) {
609 LO row = amalgInfo->ComputeLocalDOF(row_node, m);
610 if (row >= (LO)numRows)
continue;
611 size_t index_start = rowptr[row];
612 A.getLocalRowView(row, indsA, valsA);
614 for (LO l = 0; l < (LO)indsA.size(); l++) {
// NOTE(review): filter has graphNumCols entries and is indexed by col_node. This
// assumes ComputeLocalNode maps every column of A into G's import-map
// numbering; confirm.
615 int col_node = amalgInfo->ComputeLocalNode(indsA[l]);
616 bool is_good = filter[col_node];
// Diagonal entries keep A's value.
617 if (indsA[l] == row) {
619 vals[index_start + l] = valsA[l];
// Entries already dropped by the really-bad pass above (handling not visible).
624 if (vals_dropped_indicator[index_start + l] ==
true) {
// Locally owned G-neighbor columns keep A's value, unless (per branches only
// partly visible) the column's aggregate is really bad.
634 if (is_good && indsA[l] < (LO)numRows) {
635 int agg = vertex2AggId[col_node];
636 if (std::binary_search(reallyBadAggNeighbors.begin(), reallyBadAggNeighbors.end(), agg))
641 vals[index_start + l] = valsA[l];
// Remaining entries are dropped. They are presumably counted as old drops when
// the column is not a G-neighbor; the counters are not visible. Each is lumped
// into diagExtra[row], zeroed, and marked dropped.
643 if (!filter[col_node])
647 diagExtra[row] += valsA[l];
648 vals[index_start + l] = ZERO;
649 vals_dropped_indicator[index_start + l] =
true;
// Clear row_node's G-neighbor marks.
656 for (
size_t j = 0; j < as<size_t>(indsG.length); j++)
657 filter[indsG(j)] =
false;
// Without spread lumping, add the lumped values to each recorded diagonal here.
662 if (!use_spread_lumping) {
664 for (LO row = 0; row < (LO)numRows; row++) {
665 if (diagIndex[row] != INVALID) {
666 size_t index_start = rowptr[row];
667 size_t diagIndexInMatrix = index_start + diagIndex[row];
669 vals[diagIndexInMatrix] += diagExtra[row];
670 SC A_rowsum = ZERO, A_absrowsum = ZERO, F_rowsum = ZERO;
// The row is treated as Dirichlet in either case:
//   - a threshold is active (>= 0) and the diagonal's real part is <= it;
//   - the diagonal's real part is exactly zero.
// Row sums of A and of the filtered row are computed (their use is not visible),
// the row's entries are revisited (loop body not visible), and the diagonal is
// set to one.
672 if ((dirichletThresh >= 0.0 && TST::real(vals[diagIndexInMatrix]) <= dirichletThresh) || TST::real(vals[diagIndexInMatrix]) == ZERO) {
674 A.getLocalRowView(row, indsA, valsA);
678 for (LO l = 0; l < (LO)indsA.size(); l++) {
679 A_rowsum += valsA[l];
680 A_absrowsum += std::abs(valsA[l]);
682 for (LO l = 0; l < (LO)indsA.size(); l++)
683 F_rowsum += vals[index_start + l];
697 for (
size_t l = rowptr[row]; l < rowptr[row + 1]; l++) {
700 vals[diagIndexInMatrix] = TST::one();
// Rows without a recorded diagonal only produce a warning.
704 GetOStream(
Runtime0) <<
"WARNING: Row " << row <<
" has no diagonal " << std::endl;
// Write the edited CSR values back into filteredA, row by row.
710 for (LO row = 0; row < (LO)numRows; row++) {
711 filteredA.replaceLocalValues(row, inds(rowptr[row], rowptr[row + 1] - rowptr[row]), vals(rowptr[row], rowptr[row + 1] - rowptr[row]));
// With spread lumping, the diagonal pass above was skipped and
// ExperimentalLumping is applied instead.
713 if (use_spread_lumping) ExperimentalLumping(A, filteredA, DdomAllowGrowthRate, DdomCap);
// Report statistics summed over the row map's communicator.
715 size_t g_newDrops = 0, g_oldDrops = 0, g_fixedDiags = 0;
717 MueLu_sumAll(A.getRowMap()->getComm(), numNewDrops, g_newDrops);
718 MueLu_sumAll(A.getRowMap()->getComm(), numOldDrops, g_oldDrops);
719 MueLu_sumAll(A.getRowMap()->getComm(), numFixedDiags, g_fixedDiags);
720 GetOStream(
Runtime0) <<
"Filtering out " << g_newDrops <<
" edges, in addition to the " << g_oldDrops <<
" edges dropped earlier" << std::endl;
721 GetOStream(
Runtime0) <<
"Fixing " << g_fixedDiags <<
" zero diagonal values" << std::endl;
738 using TST =
typename Teuchos::ScalarTraits<SC>;
739 SC zero = TST::zero();
742 ArrayView<const LO> inds;
743 ArrayView<const SC> vals;
744 ArrayView<const LO> finds;
747 SC PosOffSum, NegOffSum, PosOffDropSum, NegOffDropSum;
748 SC diag, gamma, alpha;
749 LO NumPosKept, NumNegKept;
753 SC PosFilteredSum, NegFilteredSum;
756 SC rho = as<Scalar>(irho);
757 SC rho2 = as<Scalar>(irho2);
759 for (LO row = 0; row < (LO)A.getRowMap()->getLocalNumElements(); row++) {
760 noLumpDdom = as<Scalar>(10000.0);
773 ArrayView<const SC> tvals;
774 A.getLocalRowView(row, inds, vals);
775 size_t nnz = inds.size();
776 if (nnz == 0)
continue;
777 filteredA.getLocalRowView(row, finds, tvals);
779 fvals = ArrayView<SC>(
const_cast<SC*
>(tvals.getRawPtr()), nnz);
781 LO diagIndex = -1, fdiagIndex = -1;
785 PosOffDropSum = zero;
786 NegOffDropSum = zero;
792 for (
size_t j = 0; j < nnz; j++) {
793 if (inds[j] == row) {
797 if (TST::real(vals[j]) > TST::real(zero))
798 PosOffSum += vals[j];
800 NegOffSum += vals[j];
803 PosOffDropSum = PosOffSum;
804 NegOffDropSum = NegOffSum;
808 for (
size_t jj = 0; jj < (size_t)finds.size(); jj++) {
809 while (inds[j] != finds[jj]) j++;
811 if (finds[jj] == row)
814 if (TST::real(vals[j]) > TST::real(zero)) {
815 PosOffDropSum -= fvals[jj];
816 if (TST::real(fvals[jj]) != TST::real(zero)) NumPosKept++;
818 NegOffDropSum -= fvals[jj];
819 if (TST::real(fvals[jj]) != TST::real(zero)) NumNegKept++;
825 if (TST::magnitude(diag) != TST::magnitude(zero))
826 noLumpDdom = (PosOffSum - NegOffSum) / diag;
831 Target = rho * noLumpDdom;
832 if (TST::magnitude(Target) <= TST::magnitude(rho)) Target = rho2;
834 PosFilteredSum = PosOffSum - PosOffDropSum;
835 NegFilteredSum = NegOffSum - NegOffDropSum;
845 diag += PosOffDropSum;
848 gamma = -NegOffDropSum - PosFilteredSum;
850 if (TST::real(gamma) < TST::real(zero)) {
858 if (fdiagIndex != -1) fvals[fdiagIndex] = diag;
860 for (LO jj = 0; jj < (LO)finds.size(); jj++) {
861 while (inds[j] != finds[jj]) j++;
863 if ((j != diagIndex) && (TST::real(vals[j]) > TST::real(zero)) && (TST::magnitude(fvals[jj]) != TST::magnitude(zero)))
864 fvals[jj] = -gamma * (vals[j] / PosFilteredSum);
876 bool flipPosOffDiagsToNeg =
false;
888 if ((TST::real(diag) > TST::real(gamma)) &&
889 (TST::real((-NegFilteredSum) / (diag - gamma)) <= TST::real(Target))) {
896 if (fdiagIndex != -1) fvals[fdiagIndex] = diag - gamma;
897 }
else if (NumNegKept > 0) {
903 numer = -NegFilteredSum - Target * (diag - gamma);
904 denom = gamma * (Target - TST::one());
926 if (TST::magnitude(denom) < TST::magnitude(numer))
929 alpha = numer / denom;
930 if (TST::real(alpha) < TST::real(zero)) alpha = zero;
931 if (TST::real(diag) < TST::real((one - alpha) * gamma)) alpha = TST::one();
935 if (fdiagIndex != -1) fvals[fdiagIndex] = diag - (one - alpha) * gamma;
945 SC temp = (NegFilteredSum + alpha * gamma) / NegFilteredSum;
947 for (LO jj = 0; jj < (LO)finds.size(); jj++) {
948 while (inds[j] != finds[jj]) j++;
950 if ((jj != fdiagIndex) && (TST::magnitude(fvals[jj]) != TST::magnitude(zero)) &&
951 (TST::real(vals[j]) < TST::real(zero)))
952 fvals[jj] = temp * vals[j];
957 if (NumPosKept > 0) {
960 flipPosOffDiagsToNeg =
true;
963 for (LO jj = 0; jj < (LO)finds.size(); jj++) {
964 while (inds[j] != finds[jj]) j++;
966 if ((j != diagIndex) && (TST::magnitude(fvals[jj]) != TST::magnitude(zero)) &&
967 (TST::real(vals[j]) > TST::real(zero)))
968 fvals[jj] = -gamma / ((SC)NumPosKept);
973 if (!flipPosOffDiagsToNeg) {
978 for (LO jj = 0; jj < (LO)finds.size(); jj++) {
979 while (inds[j] != finds[jj]) j++;
981 if ((jj != fdiagIndex) && (TST::real(vals[j]) > TST::real(zero))) fvals[jj] = zero;