Tpetra parallel linear algebra Version of the Day
Loading...
Searching...
No Matches
Tpetra_Details_lclDot.hpp
Go to the documentation of this file.
1// @HEADER
2// *****************************************************************************
3// Tpetra: Templated Linear Algebra Services Package
4//
5// Copyright 2008 NTESS and the Tpetra contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9
10#ifndef TPETRA_DETAILS_LCLDOT_HPP
11#define TPETRA_DETAILS_LCLDOT_HPP
12
19
20#include "Kokkos_DualView.hpp"
21#if KOKKOS_VERSION >= 40799
22#include "KokkosKernels_ArithTraits.hpp"
23#else
24#include "Kokkos_ArithTraits.hpp"
25#endif
26#include "KokkosBlas1_dot.hpp"
27#include "Teuchos_ArrayView.hpp"
28#include "Teuchos_TestForException.hpp"
29
30namespace Tpetra {
31namespace Details {
32
33template <class RV, class XMV>
34void lclDot(const RV& dotsOut,
35 const XMV& X_lcl,
36 const XMV& Y_lcl,
37 const size_t lclNumRows,
38 const size_t numVecs,
39 const size_t whichVecsX[],
40 const size_t whichVecsY[],
41 const bool constantStrideX,
42 const bool constantStrideY) {
43 using Kokkos::ALL;
44 using Kokkos::subview;
45 typedef typename RV::non_const_value_type dot_type;
46#ifdef HAVE_TPETRA_DEBUG
47 const char prefix[] = "Tpetra::MultiVector::lclDotImpl: ";
48#endif // HAVE_TPETRA_DEBUG
49
50 static_assert(Kokkos::is_view<RV>::value,
51 "Tpetra::MultiVector::lclDotImpl: "
52 "The first argument dotsOut is not a Kokkos::View.");
53 static_assert(RV::rank == 1,
54 "Tpetra::MultiVector::lclDotImpl: "
55 "The first argument dotsOut must have rank 1.");
56 static_assert(Kokkos::is_view<XMV>::value,
57 "Tpetra::MultiVector::lclDotImpl: The type of the 2nd and "
58 "3rd arguments (X_lcl and Y_lcl) is not a Kokkos::View.");
59 static_assert(XMV::rank == 2,
60 "Tpetra::MultiVector::lclDotImpl: "
61 "X_lcl and Y_lcl must have rank 2.");
62
63 // In case the input dimensions don't match, make sure that we
64 // don't overwrite memory that doesn't belong to us, by using
65 // subset views with the minimum dimensions over all input.
66 const std::pair<size_t, size_t> rowRng(0, lclNumRows);
67 const std::pair<size_t, size_t> colRng(0, numVecs);
68 RV theDots = subview(dotsOut, colRng);
69 XMV X = subview(X_lcl, rowRng, ALL());
70 XMV Y = subview(Y_lcl, rowRng, ALL());
71
72#ifdef HAVE_TPETRA_DEBUG
73 if (lclNumRows != 0) {
74 TEUCHOS_TEST_FOR_EXCEPTION(X.extent(0) != lclNumRows, std::logic_error, prefix << "X.extent(0) = " << X.extent(0) << " != lclNumRows "
75 "= "
76 << lclNumRows << ". "
77 "Please report this bug to the Tpetra developers.");
78 TEUCHOS_TEST_FOR_EXCEPTION(Y.extent(0) != lclNumRows, std::logic_error, prefix << "Y.extent(0) = " << Y.extent(0) << " != lclNumRows "
79 "= "
80 << lclNumRows << ". "
81 "Please report this bug to the Tpetra developers.");
82 // If a MultiVector is constant stride, then numVecs should
83 // equal its View's number of columns. Otherwise, numVecs
84 // should be less than its View's number of columns.
85 TEUCHOS_TEST_FOR_EXCEPTION(constantStrideX &&
86 (X.extent(0) != lclNumRows || X.extent(1) != numVecs),
87 std::logic_error, prefix << "X is " << X.extent(0) << " x " << X.extent(1) << " (constant stride), which differs from the "
88 "local dimensions "
89 << lclNumRows << " x " << numVecs << ". "
90 "Please report this bug to the Tpetra developers.");
91 TEUCHOS_TEST_FOR_EXCEPTION(!constantStrideX &&
92 (X.extent(0) != lclNumRows || X.extent(1) < numVecs),
93 std::logic_error, prefix << "X is " << X.extent(0) << " x " << X.extent(1) << " (NOT constant stride), but the local "
94 "dimensions are "
95 << lclNumRows << " x " << numVecs << ". "
96 "Please report this bug to the Tpetra developers.");
97 TEUCHOS_TEST_FOR_EXCEPTION(constantStrideY &&
98 (Y.extent(0) != lclNumRows || Y.extent(1) != numVecs),
99 std::logic_error, prefix << "Y is " << Y.extent(0) << " x " << Y.extent(1) << " (constant stride), which differs from the "
100 "local dimensions "
101 << lclNumRows << " x " << numVecs << ". "
102 "Please report this bug to the Tpetra developers.");
103 TEUCHOS_TEST_FOR_EXCEPTION(!constantStrideY &&
104 (Y.extent(0) != lclNumRows || Y.extent(1) < numVecs),
105 std::logic_error, prefix << "Y is " << Y.extent(0) << " x " << Y.extent(1) << " (NOT constant stride), but the local "
106 "dimensions are "
107 << lclNumRows << " x " << numVecs << ". "
108 "Please report this bug to the Tpetra developers.");
109 }
110#endif // HAVE_TPETRA_DEBUG
111
112 if (lclNumRows == 0) {
113#if KOKKOS_VERSION >= 40799
114 const dot_type zero = KokkosKernels::ArithTraits<dot_type>::zero();
115#else
116 const dot_type zero = Kokkos::ArithTraits<dot_type>::zero();
117#endif
118 // DEEP_COPY REVIEW - NOT TESTED
119 Kokkos::deep_copy(theDots, zero);
120 } else { // lclNumRows != 0
121 if (constantStrideX && constantStrideY) {
122 if (X.extent(1) == 1) {
123 typename RV::non_const_value_type result =
124 KokkosBlas::dot(subview(X, ALL(), 0), subview(Y, ALL(), 0));
125 // DEEP_COPY REVIEW - NOT TESTED
126 Kokkos::deep_copy(theDots, result);
127 } else {
128 KokkosBlas::dot(theDots, X, Y);
129 }
130 } else { // not constant stride
131 // NOTE (mfh 15 Jul 2014) This does a kernel launch for
132 // every column. It might be better to have a kernel that
133 // does the work all at once. On the other hand, we don't
134 // prioritize performance of MultiVector views of
135 // noncontiguous columns.
136 for (size_t k = 0; k < numVecs; ++k) {
137 const size_t X_col = constantStrideX ? k : whichVecsX[k];
138 const size_t Y_col = constantStrideY ? k : whichVecsY[k];
139 KokkosBlas::dot(subview(theDots, k), subview(X, ALL(), X_col),
140 subview(Y, ALL(), Y_col));
141 } // for each column
142 } // constantStride
143 } // lclNumRows != 0
144}
145
146} // namespace Details
147} // namespace Tpetra
148
149#endif // TPETRA_DETAILS_LCLDOT_HPP
Implementation details of Tpetra.
Namespace Tpetra contains the class and methods constituting the Tpetra library.