Tpetra parallel linear algebra Version of the Day
Loading...
Searching...
No Matches
Tpetra_Details_lclDot.hpp
Go to the documentation of this file.
1// @HEADER
2// *****************************************************************************
3// Tpetra: Templated Linear Algebra Services Package
4//
5// Copyright 2008 NTESS and the Tpetra contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9
10#ifndef TPETRA_DETAILS_LCLDOT_HPP
11#define TPETRA_DETAILS_LCLDOT_HPP
12
19
20#include "Kokkos_DualView.hpp"
21#include "KokkosKernels_ArithTraits.hpp"
22#include "KokkosBlas1_dot.hpp"
23#include "Teuchos_ArrayView.hpp"
24#include "Teuchos_TestForException.hpp"
25
26namespace Tpetra {
27namespace Details {
28
29template <class RV, class XMV>
30void lclDot(const RV& dotsOut,
31 const XMV& X_lcl,
32 const XMV& Y_lcl,
33 const size_t lclNumRows,
34 const size_t numVecs,
35 const size_t whichVecsX[],
36 const size_t whichVecsY[],
37 const bool constantStrideX,
38 const bool constantStrideY) {
39 using Kokkos::ALL;
40 using Kokkos::subview;
41 typedef typename RV::non_const_value_type dot_type;
42#ifdef HAVE_TPETRA_DEBUG
43 const char prefix[] = "Tpetra::MultiVector::lclDotImpl: ";
44#endif // HAVE_TPETRA_DEBUG
45
46 static_assert(Kokkos::is_view<RV>::value,
47 "Tpetra::MultiVector::lclDotImpl: "
48 "The first argument dotsOut is not a Kokkos::View.");
49 static_assert(RV::rank == 1,
50 "Tpetra::MultiVector::lclDotImpl: "
51 "The first argument dotsOut must have rank 1.");
52 static_assert(Kokkos::is_view<XMV>::value,
53 "Tpetra::MultiVector::lclDotImpl: The type of the 2nd and "
54 "3rd arguments (X_lcl and Y_lcl) is not a Kokkos::View.");
55 static_assert(XMV::rank == 2,
56 "Tpetra::MultiVector::lclDotImpl: "
57 "X_lcl and Y_lcl must have rank 2.");
58
59 // In case the input dimensions don't match, make sure that we
60 // don't overwrite memory that doesn't belong to us, by using
61 // subset views with the minimum dimensions over all input.
62 const std::pair<size_t, size_t> rowRng(0, lclNumRows);
63 const std::pair<size_t, size_t> colRng(0, numVecs);
64 RV theDots = subview(dotsOut, colRng);
65 XMV X = subview(X_lcl, rowRng, ALL());
66 XMV Y = subview(Y_lcl, rowRng, ALL());
67
68#ifdef HAVE_TPETRA_DEBUG
69 if (lclNumRows != 0) {
70 TEUCHOS_TEST_FOR_EXCEPTION(X.extent(0) != lclNumRows, std::logic_error, prefix << "X.extent(0) = " << X.extent(0) << " != lclNumRows "
71 "= "
72 << lclNumRows << ". "
73 "Please report this bug to the Tpetra developers.");
74 TEUCHOS_TEST_FOR_EXCEPTION(Y.extent(0) != lclNumRows, std::logic_error, prefix << "Y.extent(0) = " << Y.extent(0) << " != lclNumRows "
75 "= "
76 << lclNumRows << ". "
77 "Please report this bug to the Tpetra developers.");
78 // If a MultiVector is constant stride, then numVecs should
79 // equal its View's number of columns. Otherwise, numVecs
80 // should be less than its View's number of columns.
81 TEUCHOS_TEST_FOR_EXCEPTION(constantStrideX &&
82 (X.extent(0) != lclNumRows || X.extent(1) != numVecs),
83 std::logic_error, prefix << "X is " << X.extent(0) << " x " << X.extent(1) << " (constant stride), which differs from the "
84 "local dimensions "
85 << lclNumRows << " x " << numVecs << ". "
86 "Please report this bug to the Tpetra developers.");
87 TEUCHOS_TEST_FOR_EXCEPTION(!constantStrideX &&
88 (X.extent(0) != lclNumRows || X.extent(1) < numVecs),
89 std::logic_error, prefix << "X is " << X.extent(0) << " x " << X.extent(1) << " (NOT constant stride), but the local "
90 "dimensions are "
91 << lclNumRows << " x " << numVecs << ". "
92 "Please report this bug to the Tpetra developers.");
93 TEUCHOS_TEST_FOR_EXCEPTION(constantStrideY &&
94 (Y.extent(0) != lclNumRows || Y.extent(1) != numVecs),
95 std::logic_error, prefix << "Y is " << Y.extent(0) << " x " << Y.extent(1) << " (constant stride), which differs from the "
96 "local dimensions "
97 << lclNumRows << " x " << numVecs << ". "
98 "Please report this bug to the Tpetra developers.");
99 TEUCHOS_TEST_FOR_EXCEPTION(!constantStrideY &&
100 (Y.extent(0) != lclNumRows || Y.extent(1) < numVecs),
101 std::logic_error, prefix << "Y is " << Y.extent(0) << " x " << Y.extent(1) << " (NOT constant stride), but the local "
102 "dimensions are "
103 << lclNumRows << " x " << numVecs << ". "
104 "Please report this bug to the Tpetra developers.");
105 }
106#endif // HAVE_TPETRA_DEBUG
107
108 if (lclNumRows == 0) {
109 const dot_type zero = KokkosKernels::ArithTraits<dot_type>::zero();
110 // DEEP_COPY REVIEW - NOT TESTED
111 Kokkos::deep_copy(theDots, zero);
112 } else { // lclNumRows != 0
113 if (constantStrideX && constantStrideY) {
114 if (X.extent(1) == 1) {
115 typename RV::non_const_value_type result =
116 KokkosBlas::dot(subview(X, ALL(), 0), subview(Y, ALL(), 0));
117 // DEEP_COPY REVIEW - NOT TESTED
118 Kokkos::deep_copy(theDots, result);
119 } else {
120 KokkosBlas::dot(theDots, X, Y);
121 }
122 } else { // not constant stride
123 // NOTE (mfh 15 Jul 2014) This does a kernel launch for
124 // every column. It might be better to have a kernel that
125 // does the work all at once. On the other hand, we don't
126 // prioritize performance of MultiVector views of
127 // noncontiguous columns.
128 for (size_t k = 0; k < numVecs; ++k) {
129 const size_t X_col = constantStrideX ? k : whichVecsX[k];
130 const size_t Y_col = constantStrideY ? k : whichVecsY[k];
131 KokkosBlas::dot(subview(theDots, k), subview(X, ALL(), X_col),
132 subview(Y, ALL(), Y_col));
133 } // for each column
134 } // constantStride
135 } // lclNumRows != 0
136}
137
138} // namespace Details
139} // namespace Tpetra
140
141#endif // TPETRA_DETAILS_LCLDOT_HPP
Implementation details of Tpetra.
Namespace Tpetra contains the class and methods constituting the Tpetra library.