10#ifndef TPETRA_DETAILS_NORMIMPL_HPP
11#define TPETRA_DETAILS_NORMIMPL_HPP
22#include "TpetraCore_config.h"
23#include "Kokkos_Core.hpp"
24#include "Teuchos_ArrayView.hpp"
25#include "Teuchos_CommHelpers.hpp"
26#include "KokkosBlas.hpp"
27#include "KokkosKernels_ArithTraits.hpp"
29#ifndef DOXYGEN_SHOULD_SKIP_THIS
// NOTE(review): only this template header is visible; the declaration it
// introduces is missing from this excerpt. It sits after the
// DOXYGEN_SHOULD_SKIP_THIS guard opened above.
33template <
class OrdinalType>
// Parameter-list tail of a function declaration (it ends with ';'). The
// return type, function name and leading parameter(s) are not visible in this
// excerpt.
// NOTE(review): these parameters match the normImpl() definition later in the
// file -- confirm this is its forward declaration.
//
// X: rank-2 Kokkos::View of const ValueType.
58 const Kokkos::View<const ValueType**, ArrayLayout, DeviceType>&
X,
// whichVecs: list of size_t indices. The normImpl() definition uses its size
// as the vector count when isConstantStride is false.
60 const Teuchos::ArrayView<const size_t>&
whichVecs,
// isConstantStride: when true, the normImpl() definition takes the vector
// count from X.extent(1) instead of whichVecs.size().
61 const bool isConstantStride,
62 const bool isDistributed,
// comm: non-owning raw pointer to the communicator.
63 const Teuchos::Comm<int>* comm);
// Computes one norm per column of the rank-2 view X and writes the results
// into the rank-1 view normsOut, using KokkosBlas kernels.
//
// Norm selection (whichNorm):
//   NORM_INF -> KokkosBlas::nrminf
//   NORM_ONE -> KokkosBlas::nrm1
//   NORM_TWO -> KokkosBlas::nrm2_squared (the square root is NOT taken here;
//               gblNormImpl in this file applies sqrt for NORM_TWO)
//
// Throws std::logic_error (via TEUCHOS_TEST_FOR_EXCEPTION) when X has rows and
// its column count is inconsistent with numVecs.
//
// NOTE(review): this excerpt is missing lines. X, numVecs, whichNorm and ALL
// are used below but their declarations are not visible here, and several
// braces / else lines are absent.
76template <
class RV,
class XMV>
77void lclNormImpl(
const RV& normsOut,
80 const Teuchos::ArrayView<const size_t>& whichVecs,
81 const bool constantStride,
84 using Kokkos::subview;
// Value type of the output view; used below for the zero fill.
85 using mag_type =
typename RV::non_const_value_type;
// Compile-time checks: normsOut is rank 1, X is a Kokkos::View of rank 2.
87 static_assert(
static_cast<int>(RV::rank) == 1,
88 "Tpetra::MultiVector::lclNormImpl: "
89 "The first argument normsOut must have rank 1.");
90 static_assert(Kokkos::is_view<XMV>::value,
91 "Tpetra::MultiVector::lclNormImpl: "
92 "The second argument X is not a Kokkos::View.");
93 static_assert(
static_cast<int>(XMV::rank) == 2,
94 "Tpetra::MultiVector::lclNormImpl: "
95 "The second argument X must have rank 2.");
97 const size_t lclNumRows =
static_cast<size_t>(X.extent(0));
// Constant stride: X must have exactly numVecs columns. The check is skipped
// when X has no rows.
98 TEUCHOS_TEST_FOR_EXCEPTION(lclNumRows != 0 && constantStride &&
99 static_cast<size_t>(X.extent(1)) != numVecs,
100 std::logic_error,
"Constant Stride X's dimensions are " << X.extent(0) <<
" x " << X.extent(1) <<
", which differ from the local dimensions " << lclNumRows <<
" x " << numVecs <<
". Please report this bug to "
101 "the Tpetra developers.");
// Non-constant stride: X must have at least numVecs columns. The check is
// skipped when X has no rows.
102 TEUCHOS_TEST_FOR_EXCEPTION(lclNumRows != 0 && !constantStride &&
103 static_cast<size_t>(X.extent(1)) < numVecs,
104 std::logic_error,
"Strided X's dimensions are " << X.extent(0) <<
" x " << X.extent(1) <<
", which are incompatible with the local dimensions " << lclNumRows <<
" x " << numVecs <<
". Please report this bug to "
105 "the Tpetra developers.");
// No rows: fill normsOut with zero instead of calling KokkosBlas.
107 if (lclNumRows == 0) {
108 const mag_type zeroMag = KokkosKernels::ArithTraits<mag_type>::zero();
110 using execution_space =
typename RV::execution_space;
// deep_copy overload taking an execution space instance.
111 Kokkos::deep_copy(execution_space(), normsOut, zeroMag);
// Constant stride: a single KokkosBlas call handles all columns of X.
113 if (constantStride) {
114 if (whichNorm == NORM_INF) {
115 KokkosBlas::nrminf(normsOut, X);
116 }
else if (whichNorm == NORM_ONE) {
117 KokkosBlas::nrm1(normsOut, X);
118 }
else if (whichNorm == NORM_TWO) {
// Squared 2-norm; the square root is applied later (see gblNormImpl).
119 KokkosBlas::nrm2_squared(normsOut, X);
// Presumably the fallback for an unrecognized whichNorm; the enclosing
// else line is not visible in this excerpt.
121 TEUCHOS_TEST_FOR_EXCEPTION(
true, std::logic_error,
"Should never get here!");
// Per-column path: one KokkosBlas call per output entry k, on a rank-1
// subview of X.
129 for (
size_t k = 0; k < numVecs; ++k) {
// Column of X feeding output k: k itself when constantStride is true,
// otherwise whichVecs[k].
130 const size_t X_col = constantStride ? k : whichVecs[k];
131 if (whichNorm == NORM_INF) {
132 KokkosBlas::nrminf(subview(normsOut, k),
133 subview(X, ALL(), X_col));
134 }
else if (whichNorm == NORM_ONE) {
135 KokkosBlas::nrm1(subview(normsOut, k),
136 subview(X, ALL(), X_col));
137 }
else if (whichNorm == NORM_TWO) {
// Squared 2-norm, as in the constant-stride path above.
138 KokkosBlas::nrm2_squared(subview(normsOut, k),
139 subview(X, ALL(), X_col));
// Presumably the fallback for an unrecognized whichNorm; the enclosing
// else line is not visible in this excerpt.
141 TEUCHOS_TEST_FOR_EXCEPTION(
true, std::logic_error,
"Should never get here!");
// Functor that replaces entry i of a view with its square root.
// gblNormImpl in this file passes it to Kokkos::parallel_for over
// [0, numVecs) to turn squared 2-norms into 2-norms.
//
// NOTE(review): access specifiers, the theView_ member declaration and the
// closing braces are not visible in this excerpt.
150template <
class ViewType>
151class SquareRootFunctor {
// Type aliases taken from the wrapped view type.
153 typedef typename ViewType::execution_space execution_space;
154 typedef typename ViewType::size_type size_type;
// Stores theView in the member theView_.
156 SquareRootFunctor(
const ViewType& theView)
157 : theView_(theView) {}
// Call operator: overwrites theView_(i) with ArithTraits sqrt of its current
// value.
159 KOKKOS_INLINE_FUNCTION
void
160 operator()(
const size_type& i)
const {
161 typedef typename ViewType::non_const_value_type value_type;
162 typedef KokkosKernels::ArithTraits<value_type> KAT;
163 theView_(i) = KAT::sqrt(theView_(i));
// Combines the values already stored in normsOut across the processes of
// comm, then finishes the 2-norm.
//
// Step 1 (only when distributed is true and comm is non-null): all-reduce the
// numVecs entries of normsOut with Teuchos::reduceAll, using REDUCE_MAX for
// NORM_INF and REDUCE_SUM otherwise.
// Step 2 (only for NORM_TWO): replace every entry of normsOut with its square
// root.
//
// NOTE(review): this excerpt is missing lines. The template header, the
// whichNorm parameter, the definition of commIsInterComm and several
// else / closing-brace lines are not visible here.
171void gblNormImpl(
const RV& normsOut,
172 const Teuchos::Comm<int>*
const comm,
173 const bool distributed,
175 using Teuchos::REDUCE_MAX;
176 using Teuchos::REDUCE_SUM;
177 using Teuchos::reduceAll;
178 typedef typename RV::non_const_value_type mag_type;
// One output entry per vector.
180 const size_t numVecs = normsOut.extent(0);
// Skip the reduction entirely when not distributed or when comm is null.
195 if (distributed && comm !=
nullptr) {
// reduceAll takes an int count.
// NOTE(review): size_t -> int narrowing; assumes numVecs fits in int.
199 const int nv =
static_cast<int>(numVecs);
// commIsInterComm branch: reduce from a separate copy of normsOut so the
// send buffer (lclSum) and receive buffer (gblSum) are distinct.
// NOTE(review): presumably commIsInterComm comes from isInterComm(*comm);
// its definition is not visible in this excerpt.
202 if (commIsInterComm) {
203 RV lclNorms(Kokkos::ViewAllocateWithoutInitializing(
"MV::normImpl lcl"), numVecs);
205 using execution_space =
typename RV::execution_space;
// NOTE(review): the deep_copy overload taking an execution space instance
// may return before the copy completes; no fence is visible in this excerpt
// before reduceAll reads lclSum -- confirm.
206 Kokkos::deep_copy(execution_space(), lclNorms, normsOut);
207 const mag_type*
const lclSum = lclNorms.data();
208 mag_type*
const gblSum = normsOut.data();
210 if (whichNorm == NORM_INF) {
211 reduceAll<int, mag_type>(*comm, REDUCE_MAX, nv, lclSum, gblSum);
213 reduceAll<int, mag_type>(*comm, REDUCE_SUM, nv, lclSum, gblSum);
// Other branch: normsOut's storage is passed as both the send and the
// receive buffer.
216 mag_type*
const gblSum = normsOut.data();
217 if (whichNorm == NORM_INF) {
218 reduceAll<int, mag_type>(*comm, REDUCE_MAX, nv, gblSum, gblSum);
220 reduceAll<int, mag_type>(*comm, REDUCE_SUM, nv, gblSum, gblSum);
// For NORM_TWO, normsOut is overwritten in place with square roots.
225 if (whichNorm == NORM_TWO) {
// True when RV's memory space is the same type as its host mirror's memory
// space.
231 const bool inHostMemory =
232 std::is_same<
typename RV::memory_space,
233 typename RV::host_mirror_space::memory_space>::value;
// Plain loop applying sqrt entry by entry.
// NOTE(review): presumably guarded by `if (inHostMemory)`; that line is not
// visible in this excerpt.
235 for (
size_t j = 0; j < numVecs; ++j) {
236 normsOut(j) = KokkosKernels::ArithTraits<mag_type>::sqrt(normsOut(j));
// Kokkos kernel applying sqrt to every entry via SquareRootFunctor over
// [0, numVecs).
243 SquareRootFunctor<RV> f(normsOut);
244 typedef typename RV::execution_space execution_space;
245 typedef Kokkos::RangePolicy<execution_space, size_t> range_type;
246 Kokkos::parallel_for(range_type(0, numVecs), f);
// Start of a function template definition.
// NOTE(review): the remaining template parameters, the return type, the
// function name and leading parameter(s), and everything after the numVecs
// computation are not visible in this excerpt. The visible parameters match
// the normImpl() signature -- confirm.
253template <
class ValueType,
// X: rank-2 Kokkos::View of const ValueType.
258 const Kokkos::View<const ValueType**, ArrayLayout, DeviceType>&
X,
260 const Teuchos::ArrayView<const size_t>&
whichVecs,
261 const bool isConstantStride,
262 const bool isDistributed,
263 const Teuchos::Comm<int>* comm) {
264 using execution_space =
typename DeviceType::execution_space;
// Rank-1 view of MagnitudeType in Kokkos::HostSpace.
265 using RV = Kokkos::View<MagnitudeType*, Kokkos::HostSpace>;
// Vector count: the number of columns of X when the stride is constant,
// otherwise the number of entries in whichVecs.
269 const size_t numVecs = isConstantStride ?
static_cast<size_t>(
X.extent(1)) :
static_cast<size_t>(
whichVecs.size());
Struct that holds views of the contents of a CrsMatrix.
Implementation details of Tpetra.
void normImpl(MagnitudeType norms[], const Kokkos::View< const ValueType **, ArrayLayout, DeviceType > &X, const EWhichNorm whichNorm, const Teuchos::ArrayView< const size_t > &whichVecs, const bool isConstantStride, const bool isDistributed, const Teuchos::Comm< int > *comm)
Implementation of MultiVector norms.
EWhichNorm
Input argument for normImpl(); see normImpl() for details.
bool isInterComm(const Teuchos::Comm< int > &)
Return true if and only if the input communicator wraps an MPI intercommunicator.
Namespace Tpetra contains the class and methods constituting the Tpetra library.