Tpetra parallel linear algebra Version of the Day
Loading...
Searching...
No Matches
Tpetra_Details_iallreduce.hpp
Go to the documentation of this file.
1// @HEADER
2// *****************************************************************************
3// Tpetra: Templated Linear Algebra Services Package
4//
5// Copyright 2008 NTESS and the Tpetra contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9
10#ifndef TPETRA_DETAILS_IALLREDUCE_HPP
11#define TPETRA_DETAILS_IALLREDUCE_HPP
12
28
29#include "TpetraCore_config.h"
30#include "Teuchos_EReductionType.hpp"
32#ifdef HAVE_TPETRACORE_MPI
35#endif // HAVE_TPETRACORE_MPI
36#include "Tpetra_Details_temporaryViewUtils.hpp"
38#include "Kokkos_Core.hpp"
39#include <memory>
40#include <stdexcept>
41#include <type_traits>
42#include <functional>
43
44#ifndef DOXYGEN_SHOULD_SKIP_THIS
45namespace Teuchos {
46// forward declaration of Comm
47template <class OrdinalType>
48class Comm;
49} // namespace Teuchos
50#endif // NOT DOXYGEN_SHOULD_SKIP_THIS
51
52namespace Tpetra {
53namespace Details {
54
55#ifdef HAVE_TPETRACORE_MPI
/// \brief Produce a human-readable string for an MPI error code.
///
/// Only declared when Tpetra is built with MPI.  MpiRequest::wait() uses
/// the result to build its exception message when MPI_Wait fails.
///
/// \param errCode [in] Error code returned by an MPI call.
std::string
getMpiErrorString(const int errCode);
57#endif
58
66 public:
68 virtual ~CommRequest() {}
69
74 virtual void wait() {}
75
79 virtual void cancel() {}
80};
81
82// Don't rely on anything in this namespace.
83namespace Impl {
84
86std::shared_ptr<CommRequest>
87emptyCommRequest();
88
89#ifdef HAVE_TPETRACORE_MPI
90#if MPI_VERSION >= 3
91template <typename InputViewType, typename OutputViewType, typename ResultViewType>
92struct MpiRequest : public CommRequest {
94 : sendBuf(send)
95 , recvBuf(recv)
97 , req(req_) {}
98
99 ~MpiRequest() {
100 // this is a no-op if wait() or cancel() have already been called
101 cancel();
102 }
103
108 void wait() override {
109 if (req != MPI_REQUEST_NULL) {
110 Details::ProfilingRegion pr("Tpetra::Details::MpiRequest::wait");
111 const int err = MPI_Wait(&req, MPI_STATUS_IGNORE);
112 TEUCHOS_TEST_FOR_EXCEPTION(err != MPI_SUCCESS, std::runtime_error,
113 "MpiCommRequest::wait: MPI_Wait failed with error \""
114 << getMpiErrorString(err));
115 // MPI_Wait should set the MPI_Request to MPI_REQUEST_NULL on
116 // success. We'll do it here just to be conservative.
117 req = MPI_REQUEST_NULL;
118 // Since recvBuf contains the result, copy it to the user's resultBuf.
119 Kokkos::deep_copy(resultBuf, recvBuf);
120 }
121 }
122
126 void cancel() override {
127 // BMK: Per https://www.mpi-forum.org/docs/mpi-3.1/mpi31-report/node126.htm,
128 // MPI_Cancel cannot be used for collectives like iallreduce.
129 req = MPI_REQUEST_NULL;
130 }
131
132 private:
133 InputViewType sendBuf;
134 OutputViewType recvBuf;
135 ResultViewType resultBuf;
136 // This request is active if and only if req != MPI_REQUEST_NULL.
137 MPI_Request req;
138};
139
142MPI_Request
143iallreduceRaw(const void* sendbuf,
144 void* recvbuf,
145 const int count,
146 MPI_Datatype mpiDatatype,
147 const Teuchos::EReductionType op,
148 MPI_Comm comm);
149#endif
150
152void allreduceRaw(const void* sendbuf,
153 void* recvbuf,
154 const int count,
155 MPI_Datatype mpiDatatype,
156 const Teuchos::EReductionType op,
157 MPI_Comm comm);
158
159template <class InputViewType, class OutputViewType>
160std::shared_ptr<CommRequest>
161iallreduceImpl(const InputViewType& sendbuf,
162 const OutputViewType& recvbuf,
163 const ::Teuchos::EReductionType op,
164 const ::Teuchos::Comm<int>& comm) {
165 using Packet = typename InputViewType::non_const_value_type;
166 if (comm.getSize() == 1) {
167 // Avoid deep_copy precond violation if views are identical
168 if (recvbuf != sendbuf)
169 Kokkos::deep_copy(recvbuf, sendbuf);
170 return emptyCommRequest();
171 }
172 Packet examplePacket;
173 MPI_Datatype mpiDatatype = sendbuf.extent(0) ? MpiTypeTraits<Packet>::getType(examplePacket) : MPI_BYTE;
174 bool datatypeNeedsFree = MpiTypeTraits<Packet>::needsFree;
175 MPI_Comm rawComm = ::Tpetra::Details::extractMpiCommFromTeuchos(comm);
176 // Note BMK: Nonblocking collectives like iallreduce cannot use GPU buffers.
177 // See https://www.open-mpi.org/faq/?category=runcuda#mpi-cuda-support
178 auto sendMPI = Tpetra::Details::TempView::toMPISafe<InputViewType, false>(sendbuf);
179 auto recvMPI = Tpetra::Details::TempView::toMPISafe<OutputViewType, false>(recvbuf);
180 std::shared_ptr<CommRequest> req;
181 // Next, if input/output alias and comm is an intercomm, make a deep copy of input.
182 // Not possible to do in-place allreduce for intercomm.
183 if (isInterComm(comm) && sendMPI.data() == recvMPI.data()) {
184 // Can't do in-place collective on an intercomm,
185 // so use a separate 1D copy as the input.
186 Kokkos::View<Packet*, Kokkos::HostSpace> tempInput(Kokkos::ViewAllocateWithoutInitializing("tempInput"), sendMPI.extent(0));
187 for (size_t i = 0; i < sendMPI.extent(0); i++)
188 tempInput(i) = sendMPI.data()[i];
189#if MPI_VERSION >= 3
190 // MPI 3+: use async allreduce
191 MPI_Request mpiReq = iallreduceRaw((const void*)tempInput.data(), (void*)recvMPI.data(), tempInput.extent(0), mpiDatatype, op, rawComm);
192 req = std::shared_ptr<CommRequest>(new MpiRequest<decltype(tempInput), decltype(recvMPI), OutputViewType>(tempInput, recvMPI, recvbuf, mpiReq));
193#else
194 // Older MPI: Iallreduce not available. Instead do blocking all-reduce and return empty request.
195 allreduceRaw((const void*)sendMPI.data(), (void*)recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
196 Kokkos::deep_copy(recvbuf, recvMPI);
197 req = emptyCommRequest();
198#endif
199 } else {
200#if MPI_VERSION >= 3
201 // MPI 3+: use async allreduce
202 MPI_Request mpiReq = iallreduceRaw((const void*)sendMPI.data(), (void*)recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
203 req = std::shared_ptr<CommRequest>(new MpiRequest<decltype(sendMPI), decltype(recvMPI), OutputViewType>(sendMPI, recvMPI, recvbuf, mpiReq));
204#else
205 // Older MPI: Iallreduce not available. Instead do blocking all-reduce and return empty request.
206 allreduceRaw((const void*)sendMPI.data(), (void*)recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
207 Kokkos::deep_copy(recvbuf, recvMPI);
208 req = emptyCommRequest();
209#endif
210 }
211 if (datatypeNeedsFree)
212 MPI_Type_free(&mpiDatatype);
213 return req;
214}
215
216#else
217
218// No MPI: reduction is always the same as input.
219template <class InputViewType, class OutputViewType>
220std::shared_ptr<CommRequest>
221iallreduceImpl(const InputViewType& sendbuf,
222 const OutputViewType& recvbuf,
223 const ::Teuchos::EReductionType,
224 const ::Teuchos::Comm<int>&) {
225 // Avoid deep_copy precond violation if views are identical
226 if (recvbuf != sendbuf)
227 Kokkos::deep_copy(recvbuf, sendbuf);
228 return emptyCommRequest();
229}
230
231#endif // HAVE_TPETRACORE_MPI
232
233} // namespace Impl
234
235//
236// SKIP DOWN TO HERE
237//
238
264template <class InputViewType, class OutputViewType>
265std::shared_ptr<CommRequest>
266iallreduce(const InputViewType& sendbuf,
267 const OutputViewType& recvbuf,
268 const ::Teuchos::EReductionType op,
269 const ::Teuchos::Comm<int>& comm) {
270 static_assert(Kokkos::is_view<InputViewType>::value,
271 "InputViewType must be a Kokkos::View specialization.");
272 static_assert(Kokkos::is_view<OutputViewType>::value,
273 "OutputViewType must be a Kokkos::View specialization.");
274 constexpr int rank = static_cast<int>(OutputViewType::rank);
275 static_assert(static_cast<int>(InputViewType::rank) == rank,
276 "InputViewType and OutputViewType must have the same rank.");
277 static_assert(rank == 0 || rank == 1,
278 "InputViewType and OutputViewType must both have "
279 "rank 0 or rank 1.");
280 typedef typename OutputViewType::non_const_value_type packet_type;
281 static_assert(std::is_same<typename OutputViewType::value_type,
282 packet_type>::value,
283 "OutputViewType must be a nonconst Kokkos::View.");
284 static_assert(std::is_same<typename InputViewType::non_const_value_type,
285 packet_type>::value,
286 "InputViewType and OutputViewType must be Views "
287 "whose entries have the same type.");
288 // Make sure layouts are contiguous (don't accept strided 1D view)
289 static_assert(!std::is_same<typename InputViewType::array_layout, Kokkos::LayoutStride>::value,
290 "Input/Output views must be contiguous (not LayoutStride)");
291 static_assert(!std::is_same<typename OutputViewType::array_layout, Kokkos::LayoutStride>::value,
292 "Input/Output views must be contiguous (not LayoutStride)");
293
294 return Impl::iallreduceImpl<InputViewType, OutputViewType>(sendbuf, recvbuf, op, comm);
295}
296
297template <class ValueType>
298std::shared_ptr<CommRequest>
299iallreduce(const ValueType localValue,
301 const ::Teuchos::EReductionType op,
302 const ::Teuchos::Comm<int>& comm);
303
304} // namespace Details
305} // namespace Tpetra
306
307#endif // TPETRA_DETAILS_IALLREDUCE_HPP
Declaration of Tpetra::Details::Behavior, a class that describes Tpetra's behavior.
Add specializations of Teuchos::Details::MpiTypeTraits for Kokkos::complex<float> and Kokkos::complex<double>.
Declaration of Tpetra::Details::Profiling, a scope guard for Kokkos Profiling.
Declaration of Tpetra::Details::extractMpiCommFromTeuchos.
Struct that holds views of the contents of a CrsMatrix.
Base class for the request (more or less a future) representing a pending nonblocking MPI operation.
virtual ~CommRequest()
Destructor (virtual for memory safety of derived classes).
virtual void cancel()
Cancel the pending communication request.
virtual void wait()
Wait on this communication request to complete.
Implementation details of Tpetra.
bool isInterComm(const Teuchos::Comm< int > &)
Return true if and only if the input communicator wraps an MPI intercommunicator.
Namespace Tpetra contains the class and methods constituting the Tpetra library.