10#ifndef TPETRA_DETAILS_SCALEBLOCKDIAGONAL_HPP
11#define TPETRA_DETAILS_SCALEBLOCKDIAGONAL_HPP
13#include "TpetraCore_config.h"
14#include "Tpetra_CrsMatrix.hpp"
15#include "Teuchos_ScalarTraits.hpp"
17#include "KokkosBatched_Util.hpp"
18#include "KokkosBatched_LU_Decl.hpp"
19#include "KokkosBatched_LU_Serial_Impl.hpp"
20#include "KokkosBatched_Trsm_Decl.hpp"
21#include "KokkosBatched_Trsm_Serial_Impl.hpp"
/// Scale a multivector by the inverse of a block diagonal, one block at a time.
///
/// For each block i, the rows [i*blocksize, (i+1)*blocksize) of blockDiagonal are
/// handed to KokkosBatched::SerialLU and the matching rows of
/// multiVectorToBeScaled are then passed through KokkosBatched::SerialTrsm
/// triangular solves against that factored block.
///
/// \param blockDiagonal          Block diagonal stored as a multivector; its number
///                               of vectors (columns) is used as the block size.
///                               SerialLU is invoked directly on subviews of its
///                               device data, so its contents are not preserved
///                               (presumably they become the L/U factors).
/// \param doTranspose            Not referenced on any line visible in this
///                               listing; presumably selects the Trans::Transpose
///                               solve pair over the Trans::NoTranspose pair.
/// \param multiVectorToBeScaled  Multivector updated in place by the solves.
///
/// NOTE(review): this listing is an extract in which some original lines (the
/// call that receives the "scaleBlockDiagonal" label, the branching around the
/// solves, closing braces) are not visible; the comments below only describe
/// what is shown.
33template <
class MultiVectorType>
34void inverseScaleBlockDiagonal(MultiVectorType& blockDiagonal,
bool doTranspose, MultiVectorType& multiVectorToBeScaled) {
35 using LO =
typename MultiVectorType::local_ordinal_type;
// Range policy on the multivector's execution space, indexed by local ordinals.
36 using range_type = Kokkos::RangePolicy<typename MultiVectorType::node_type::execution_space, LO>;
37 using namespace KokkosBatched;
// Scalar "one", used as the alpha argument of every SerialTrsm call below.
38 typename MultiVectorType::impl_scalar_type SC_one = Teuchos::ScalarTraits<typename MultiVectorType::impl_scalar_type>::one();
// Both multivectors must live on the same map; otherwise std::runtime_error is
// raised through TEUCHOS_TEST_FOR_EXCEPTION.
// NOTE(review): the message names "scaledBlockDiagonal", which does not match
// this function's name (inverseScaleBlockDiagonal) -- confirm whether intended.
42 TEUCHOS_TEST_FOR_EXCEPTION(!blockDiagonal.getMap()->isSameAs(*multiVectorToBeScaled.getMap()),
43 std::runtime_error,
"Tpetra::Details::scaledBlockDiagonal was given incompatible maps");
// Block size is the number of vectors (columns) of blockDiagonal.
// NOTE(review): no guard is visible for blocksize == 0 (integer division by
// zero) or for numrows not being a multiple of blocksize (trailing rows would
// fall outside range_type(0, numblocks)) -- confirm callers guarantee both.
46 LO numrows = blockDiagonal.getLocalLength();
47 LO blocksize = blockDiagonal.getNumVectors();
48 LO numblocks = numrows / blocksize;
// Device views of both multivectors.
// NOTE(review): blockDiagonal is requested with Access::OverwriteAll, yet the
// LU factorization below appears to read its existing entries; confirm that
// OverwriteAll (rather than ReadWrite) is intended here.
52 auto blockDiag = blockDiagonal.getLocalViewDevice(Access::OverwriteAll);
53 auto toScale = multiVectorToBeScaled.getLocalViewDevice(Access::ReadWrite);
// Algorithm tag shared by the LU and Trsm invocations.
55 typedef Algo::Level3::Unblocked algo_type;
// One iteration per block, labelled "scaleBlockDiagonal"; i is the block index.
57 "scaleBlockDiagonal", range_type(0, numblocks), KOKKOS_LAMBDA(
const LO i) {
// Rows belonging to block i; A and B take those rows and all columns of the
// block diagonal and of the multivector being scaled, respectively.
58 Kokkos::pair<LO, LO> row_range(i * blocksize, (i + 1) * blocksize);
59 auto A = Kokkos::subview(blockDiag, row_range, Kokkos::ALL());
60 auto B = Kokkos::subview(toScale, row_range, Kokkos::ALL());
// Factor the block; A is a subview, so this acts on blockDiagonal's device data.
63 SerialLU<algo_type>::invoke(A);
// Transposed solves: upper triangle (non-unit diagonal) first, then lower
// triangle (unit diagonal), both applied from the left to B.
67 SerialTrsm<Side::Left, Uplo::Upper, Trans::Transpose, Diag::NonUnit, algo_type>::invoke(SC_one, A, B);
69 SerialTrsm<Side::Left, Uplo::Lower, Trans::Transpose, Diag::Unit, algo_type>::invoke(SC_one, A, B);
// Non-transposed solves: lower triangle (unit diagonal) first, then upper
// triangle (non-unit diagonal), both applied from the left to B.
72 SerialTrsm<Side::Left, Uplo::Lower, Trans::NoTranspose, Diag::Unit, algo_type>::invoke(SC_one, A, B);
74 SerialTrsm<Side::Left, Uplo::Upper, Trans::NoTranspose, Diag::NonUnit, algo_type>::invoke(SC_one, A, B);
Declaration of Tpetra::Details::Behavior, a class that describes Tpetra's behavior.
static bool debug()
Whether Tpetra is in debug mode.
Implementation details of Tpetra.
Namespace Tpetra contains the class and methods constituting the Tpetra library.