Tpetra_Details_iallreduce.hpp
// @HEADER
// *****************************************************************************
// Tpetra: Templated Linear Algebra Services Package
//
// Copyright 2008 NTESS and the Tpetra contributors.
// SPDX-License-Identifier: BSD-3-Clause
// *****************************************************************************
// @HEADER

#ifndef TPETRA_DETAILS_IALLREDUCE_HPP
#define TPETRA_DETAILS_IALLREDUCE_HPP

#include "TpetraCore_config.h"
#include "Teuchos_EReductionType.hpp"
#ifdef HAVE_TPETRACORE_MPI
#include "Tpetra_Details_extractMpiCommFromTeuchos.hpp"
#include "Tpetra_Details_MpiTypeTraits.hpp"
#include "Tpetra_Details_isInterComm.hpp"
#endif // HAVE_TPETRACORE_MPI
#include "Tpetra_Details_temporaryViewUtils.hpp"
#include "Tpetra_Details_Behavior.hpp"
#include "Kokkos_Core.hpp"
#include <memory>
#include <stdexcept>
#include <type_traits>
#include <functional>

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace Teuchos {
// forward declaration of Comm
template <class OrdinalType>
class Comm;
} // namespace Teuchos
#endif // NOT DOXYGEN_SHOULD_SKIP_THIS

namespace Tpetra {
namespace Details {

#ifdef HAVE_TPETRACORE_MPI
// Return a human-readable string describing the given MPI error code.
std::string getMpiErrorString(const int errCode);
#endif

//! Base class for the request (more or less a future) representing a
//! pending nonblocking MPI operation.
class CommRequest {
 public:
  //! Destructor (virtual for memory safety of derived classes).
  virtual ~CommRequest() {}

  //! Wait on this communication request to complete.
  virtual void wait() {}

  //! Cancel the pending communication request.
  virtual void cancel() {}
};
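
// Requests returned by the functions in this header are managed by
// std::shared_ptr.  The usual pattern (a sketch, not a requirement of this
// header) is to start the all-reduce, overlap independent work, and call
// wait() once before reading the output buffer:
//
//   std::shared_ptr<CommRequest> req = /* ... start an iallreduce ... */;
//   // ... do other work that does not touch the buffers ...
//   req->wait();  // after this, the reduced result may be read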

// Don't rely on anything in this namespace.
namespace Impl {

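// Return a trivial, already-completed request: wait() and cancel() on it
// are no-ops.  Used when nothing is actually pending, e.g. for a
// single-process communicator or the blocking (pre-MPI-3) fallback.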
std::shared_ptr<CommRequest>
emptyCommRequest();

#ifdef HAVE_TPETRACORE_MPI
#if MPI_VERSION >= 3
template <typename InputViewType, typename OutputViewType, typename ResultViewType>
struct MpiRequest : public CommRequest {
  MpiRequest(const InputViewType& send, const OutputViewType& recv,
             const ResultViewType& result, MPI_Request req_)
    : sendBuf(send)
    , recvBuf(recv)
    , resultBuf(result)
    , req(req_) {}

  ~MpiRequest() {
    // this is a no-op if wait() or cancel() have already been called
    cancel();
  }

  // Wait on the pending all-reduce; afterwards, the result is available
  // in the caller's output buffer.
  void wait() override {
    if (req != MPI_REQUEST_NULL) {
      const int err = MPI_Wait(&req, MPI_STATUS_IGNORE);
      TEUCHOS_TEST_FOR_EXCEPTION(err != MPI_SUCCESS, std::runtime_error,
                                 "MpiCommRequest::wait: MPI_Wait failed with error \""
                                     << getMpiErrorString(err) << "\".");
      // MPI_Wait should set the MPI_Request to MPI_REQUEST_NULL on
      // success. We'll do it here just to be conservative.
      req = MPI_REQUEST_NULL;
      // Since recvBuf contains the result, copy it to the user's resultBuf.
      Kokkos::deep_copy(resultBuf, recvBuf);
    }
  }

  void cancel() override {
    // BMK: Per https://www.mpi-forum.org/docs/mpi-3.1/mpi31-report/node126.htm,
    // MPI_Cancel cannot be used for collectives like iallreduce.
    req = MPI_REQUEST_NULL;
  }

 private:
  InputViewType sendBuf;
  OutputViewType recvBuf;
  ResultViewType resultBuf;
  // This request is active if and only if req != MPI_REQUEST_NULL.
  MPI_Request req;
};

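// Low-level wrapper that starts the nonblocking all-reduce through
// MPI_Iallreduce (hence the MPI_VERSION >= 3 guard) and returns the
// resulting MPI_Request.  Declared here, defined out of line.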
MPI_Request
iallreduceRaw(const void* sendbuf,
              void* recvbuf,
              const int count,
              MPI_Datatype mpiDatatype,
              const Teuchos::EReductionType op,
              MPI_Comm comm);
#endif

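// Blocking counterpart, used below when MPI_Iallreduce is not available
// (MPI_VERSION < 3): performs the all-reduce immediately.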
void allreduceRaw(const void* sendbuf,
                  void* recvbuf,
                  const int count,
                  MPI_Datatype mpiDatatype,
                  const Teuchos::EReductionType op,
                  MPI_Comm comm);

template <class InputViewType, class OutputViewType>
std::shared_ptr<CommRequest>
iallreduceImpl(const InputViewType& sendbuf,
               const OutputViewType& recvbuf,
               const ::Teuchos::EReductionType op,
               const ::Teuchos::Comm<int>& comm) {
  using Packet = typename InputViewType::non_const_value_type;
  if (comm.getSize() == 1) {
    Kokkos::deep_copy(recvbuf, sendbuf);
    return emptyCommRequest();
  }
  Packet examplePacket;
  MPI_Datatype mpiDatatype = sendbuf.extent(0) ? MpiTypeTraits<Packet>::getType(examplePacket) : MPI_BYTE;
  bool datatypeNeedsFree = MpiTypeTraits<Packet>::needsFree;
  MPI_Comm rawComm = ::Tpetra::Details::extractMpiCommFromTeuchos(comm);
  // Note BMK: Nonblocking collectives like iallreduce cannot use GPU buffers.
  // See https://www.open-mpi.org/faq/?category=runcuda#mpi-cuda-support
  auto sendMPI = Tpetra::Details::TempView::toMPISafe<InputViewType, false>(sendbuf);
  auto recvMPI = Tpetra::Details::TempView::toMPISafe<OutputViewType, false>(recvbuf);
  std::shared_ptr<CommRequest> req;
  // Next, if input/output alias and comm is an intercomm, make a deep copy of input.
  // Not possible to do in-place allreduce for intercomm.
  if (isInterComm(comm) && sendMPI.data() == recvMPI.data()) {
    // Can't do in-place collective on an intercomm,
    // so use a separate 1D copy as the input.
    Kokkos::View<Packet*, Kokkos::HostSpace> tempInput(Kokkos::ViewAllocateWithoutInitializing("tempInput"), sendMPI.extent(0));
    for (size_t i = 0; i < sendMPI.extent(0); i++)
      tempInput(i) = sendMPI.data()[i];
#if MPI_VERSION >= 3
    // MPI 3+: use async allreduce
    MPI_Request mpiReq = iallreduceRaw((const void*)tempInput.data(), (void*)recvMPI.data(), tempInput.extent(0), mpiDatatype, op, rawComm);
    req = std::shared_ptr<CommRequest>(new MpiRequest<decltype(tempInput), decltype(recvMPI), OutputViewType>(tempInput, recvMPI, recvbuf, mpiReq));
#else
    // Older MPI: Iallreduce not available. Instead do blocking all-reduce and return empty request.
    allreduceRaw((const void*)sendMPI.data(), (void*)recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
    Kokkos::deep_copy(recvbuf, recvMPI);
    req = emptyCommRequest();
#endif
  } else {
#if MPI_VERSION >= 3
    // MPI 3+: use async allreduce
    MPI_Request mpiReq = iallreduceRaw((const void*)sendMPI.data(), (void*)recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
    req = std::shared_ptr<CommRequest>(new MpiRequest<decltype(sendMPI), decltype(recvMPI), OutputViewType>(sendMPI, recvMPI, recvbuf, mpiReq));
#else
    // Older MPI: Iallreduce not available. Instead do blocking all-reduce and return empty request.
    allreduceRaw((const void*)sendMPI.data(), (void*)recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
    Kokkos::deep_copy(recvbuf, recvMPI);
    req = emptyCommRequest();
#endif
  }
  if (datatypeNeedsFree)
    MPI_Type_free(&mpiDatatype);
  return req;
}

#else

// No MPI: reduction is always the same as input.
template <class InputViewType, class OutputViewType>
std::shared_ptr<CommRequest>
iallreduceImpl(const InputViewType& sendbuf,
               const OutputViewType& recvbuf,
               const ::Teuchos::EReductionType,
               const ::Teuchos::Comm<int>&) {
  Kokkos::deep_copy(recvbuf, sendbuf);
  return emptyCommRequest();
}

#endif // HAVE_TPETRACORE_MPI

} // namespace Impl

//
// SKIP DOWN TO HERE
//

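/// \brief Nonblocking all-reduce over Kokkos::View objects.
///
/// Start an all-reduce from \c sendbuf into \c recvbuf over the
/// communicator \c comm, using \c op as the reduction operator.  Both
/// Views must have rank 0 or rank 1, matching value types, and
/// contiguous (non-LayoutStride) layouts; the static_asserts below
/// enforce this.  Wait on the returned request (CommRequest::wait())
/// before reading \c recvbuf, and do not modify or free either buffer
/// while the operation is pending.  If MPI is not available, or the MPI
/// version predates MPI_Iallreduce (MPI 3), the reduction is carried out
/// before this function returns and the returned request is a no-op.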
template <class InputViewType, class OutputViewType>
std::shared_ptr<CommRequest>
iallreduce(const InputViewType& sendbuf,
           const OutputViewType& recvbuf,
           const ::Teuchos::EReductionType op,
           const ::Teuchos::Comm<int>& comm) {
  static_assert(Kokkos::is_view<InputViewType>::value,
                "InputViewType must be a Kokkos::View specialization.");
  static_assert(Kokkos::is_view<OutputViewType>::value,
                "OutputViewType must be a Kokkos::View specialization.");
  constexpr int rank = static_cast<int>(OutputViewType::rank);
  static_assert(static_cast<int>(InputViewType::rank) == rank,
                "InputViewType and OutputViewType must have the same rank.");
  static_assert(rank == 0 || rank == 1,
                "InputViewType and OutputViewType must both have "
                "rank 0 or rank 1.");
  typedef typename OutputViewType::non_const_value_type packet_type;
  static_assert(std::is_same<typename OutputViewType::value_type,
                             packet_type>::value,
                "OutputViewType must be a nonconst Kokkos::View.");
  static_assert(std::is_same<typename InputViewType::non_const_value_type,
                             packet_type>::value,
                "InputViewType and OutputViewType must be Views "
                "whose entries have the same type.");
  // Make sure layouts are contiguous (don't accept strided 1D view)
  static_assert(!std::is_same<typename InputViewType::array_layout, Kokkos::LayoutStride>::value,
                "Input/Output views must be contiguous (not LayoutStride)");
  static_assert(!std::is_same<typename OutputViewType::array_layout, Kokkos::LayoutStride>::value,
                "Input/Output views must be contiguous (not LayoutStride)");

  return Impl::iallreduceImpl<InputViewType, OutputViewType>(sendbuf, recvbuf, op, comm);
}
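
// Example usage (a minimal sketch; the view names, extents, and the
// communicator "comm" are hypothetical, not defined in this header):
//
//   Kokkos::View<double*, Kokkos::HostSpace> localSums("localSums", 3);
//   Kokkos::View<double*, Kokkos::HostSpace> globalSums("globalSums", 3);
//   // ... fill localSums with this process's partial sums ...
//   auto req = Tpetra::Details::iallreduce(localSums, globalSums,
//                                          Teuchos::REDUCE_SUM, *comm);
//   // ... overlap work that does not touch localSums or globalSums ...
//   req->wait();  // globalSums now holds the global sums on every process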

//! Convenience overload: nonblocking all-reduce of a single int; the
//! result is written to globalValue once the returned request completes.
std::shared_ptr<CommRequest>
iallreduce(const int localValue,
           int& globalValue,
           const ::Teuchos::EReductionType op,
           const ::Teuchos::Comm<int>& comm);
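
// For example, the overload above can combine per-process error flags
// (a sketch; "comm" is hypothetical):
//
//   int lclErr = ...;  // nonzero if this process detected an error
//   int gblErr = 0;
//   auto req = Tpetra::Details::iallreduce(lclErr, gblErr,
//                                          Teuchos::REDUCE_MAX, *comm);
//   req->wait();  // gblErr is nonzero if any process reported an error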

} // namespace Details
} // namespace Tpetra

#endif // TPETRA_DETAILS_IALLREDUCE_HPP