10#ifndef TPETRA_DETAILS_IALLREDUCE_HPP
11#define TPETRA_DETAILS_IALLREDUCE_HPP
29#include "TpetraCore_config.h"
30#include "Teuchos_EReductionType.hpp"
32#ifdef HAVE_TPETRACORE_MPI
36#include "Tpetra_Details_temporaryViewUtils.hpp"
38#include "Kokkos_Core.hpp"
44#ifndef DOXYGEN_SHOULD_SKIP_THIS
47template <
class OrdinalType>
55#ifdef HAVE_TPETRACORE_MPI
// Declaration only. MpiRequest::wait() passes it the return code of a failed
// MPI_Wait and streams the result into the exception message.
56std::string getMpiErrorString(
const int errCode);
86std::shared_ptr<CommRequest>
89#ifdef HAVE_TPETRACORE_MPI
91template <
typename InputViewType,
typename OutputViewType,
typename ResultViewType>
// Block until the pending reduction finishes, then copy the result from
// recvBuf (the buffer MPI wrote into) to resultBuf.
// Does nothing if req is already MPI_REQUEST_NULL, i.e. after a previous
// wait() or after cancel().
// Throws std::runtime_error if MPI_Wait reports an error.
108 void wait()
override {
109 if (req != MPI_REQUEST_NULL) {
110 Details::ProfilingRegion pr(
"Tpetra::Details::MpiRequest::wait");
111 const int err = MPI_Wait(&req, MPI_STATUS_IGNORE);
// The message prefix matches the profiling region name, and the quoted
// error text is closed.
112 TEUCHOS_TEST_FOR_EXCEPTION(err != MPI_SUCCESS, std::runtime_error,
113 "MpiRequest::wait: MPI_Wait failed with error \""
114 << getMpiErrorString(err) << "\"");
// Reset explicitly so that a second wait() is a no-op.
117 req = MPI_REQUEST_NULL;
// The reduction landed in recvBuf; publish it to the user's view.
119 Kokkos::deep_copy(resultBuf, recvBuf);
// Drop the pending request handle. In the visible lines this only resets req,
// which turns any later wait() into a no-op (resultBuf is then not filled).
// NOTE(review): the MPI standard makes MPI_Cancel / MPI_Request_free erroneous
// for nonblocking-collective requests, so the reduction itself is not
// stopped. Confirm callers never rely on cancel() before all ranks finish.
126 void cancel()
override {
129 req = MPI_REQUEST_NULL;
// Send-side view. At the visible construction sites its type is that of
// tempInput or sendMPI. Presumably it is held so the send data stays alive
// while MPI may still read it (TODO confirm against the constructor).
133 InputViewType sendBuf;
// Receive-side view passed to MPI (recvMPI at the construction sites).
// wait() copies it into resultBuf.
134 OutputViewType recvBuf;
// The caller's output view (recvbuf at the construction sites). wait() fills
// it from recvBuf.
135 ResultViewType resultBuf;
143iallreduceRaw(
const void* sendbuf,
146 MPI_Datatype mpiDatatype,
147 const Teuchos::EReductionType op,
152void allreduceRaw(
const void* sendbuf,
155 MPI_Datatype mpiDatatype,
156 const Teuchos::EReductionType op,
// All-reduce of sendbuf into recvbuf over `comm` (MPI build), possibly
// nonblocking.
//
// Steps visible below:
//  - One process: copy sendbuf into recvbuf (unless they are the same view)
//    and return emptyCommRequest().
//  - Otherwise: pick an MPI datatype for the entries, pass both views through
//    TempView::toMPISafe, and start the reduction.
//  - On an intercommunicator with aliased input/output, reduce from a
//    separate host copy of the input instead.
//  - Finally, free the MPI datatype if one was created.
//
// NOTE(review): not visible in this view are the preprocessor lines that
// choose between the nonblocking path (iallreduceRaw + MpiRequest) and the
// blocking path (allreduceRaw + deep_copy), the closing braces, and the final
// `return req;`.
159template <
class InputViewType,
class OutputViewType>
160std::shared_ptr<CommRequest>
161iallreduceImpl(
const InputViewType& sendbuf,
162 const OutputViewType& recvbuf,
163 const ::Teuchos::EReductionType op,
164 const ::Teuchos::Comm<int>& comm) {
165 using Packet =
typename InputViewType::non_const_value_type;
166 if (comm.getSize() == 1) {
// Single process: the reduction is just a copy.
168 if (recvbuf != sendbuf)
169 Kokkos::deep_copy(recvbuf, sendbuf);
170 return emptyCommRequest();
172 Packet examplePacket;
// An empty buffer uses the predefined MPI_BYTE as a placeholder datatype.
173 MPI_Datatype mpiDatatype = sendbuf.extent(0) ? MpiTypeTraits<Packet>::getType(examplePacket) : MPI_BYTE;
// Only free a datatype that getType() produced. For an empty buffer
// mpiDatatype is MPI_BYTE, and freeing a predefined datatype is erroneous.
174 bool datatypeNeedsFree = sendbuf.extent(0) != 0 && MpiTypeTraits<Packet>::needsFree;
175 MPI_Comm rawComm = ::Tpetra::Details::extractMpiCommFromTeuchos(comm);
// Views actually handed to MPI. toMPISafe may return the original view or a
// staged copy; see TempView for the exact rules.
178 auto sendMPI = Tpetra::Details::TempView::toMPISafe<InputViewType, false>(sendbuf);
179 auto recvMPI = Tpetra::Details::TempView::toMPISafe<OutputViewType, false>(recvbuf);
180 std::shared_ptr<CommRequest> req;
// In-place reductions are not permitted on intercommunicators. When input
// and output alias there, reduce from a separate host copy of the input.
183 if (
isInterComm(comm) && sendMPI.data() == recvMPI.data()) {
186 Kokkos::View<Packet*, Kokkos::HostSpace> tempInput(Kokkos::ViewAllocateWithoutInitializing(
"tempInput"), sendMPI.extent(0));
187 for (
size_t i = 0; i < sendMPI.extent(0); i++)
188 tempInput(i) = sendMPI.data()[i];
// Nonblocking path. The request stores tempInput alongside recvMPI/recvbuf.
191 MPI_Request mpiReq = iallreduceRaw((
const void*)tempInput.data(), (
void*)recvMPI.data(), tempInput.extent(0), mpiDatatype, op, rawComm);
192 req = std::shared_ptr<CommRequest>(
new MpiRequest<
decltype(tempInput),
decltype(recvMPI), OutputViewType>(tempInput, recvMPI, recvbuf, mpiReq));
// Blocking fallback. It must also reduce from tempInput: here sendMPI
// aliases recvMPI, which is exactly the in-place case this branch avoids.
195 allreduceRaw((
const void*)tempInput.data(), (
void*)recvMPI.data(), tempInput.extent(0), mpiDatatype, op, rawComm);
196 Kokkos::deep_copy(recvbuf, recvMPI);
197 req = emptyCommRequest();
// Remaining case (no intercomm aliasing): reduce directly from sendMPI.
202 MPI_Request mpiReq = iallreduceRaw((
const void*)sendMPI.data(), (
void*)recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
203 req = std::shared_ptr<CommRequest>(
new MpiRequest<
decltype(sendMPI),
decltype(recvMPI), OutputViewType>(sendMPI, recvMPI, recvbuf, mpiReq));
206 allreduceRaw((
const void*)sendMPI.data(), (
void*)recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
207 Kokkos::deep_copy(recvbuf, recvMPI);
208 req = emptyCommRequest();
// Freeing here is allowed even while a nonblocking reduction is pending.
// MPI lets communication already using the datatype complete normally.
211 if (datatypeNeedsFree)
212 MPI_Type_free(&mpiDatatype);
// Overload that ignores the reduction op and the communicator (both
// parameters are unnamed). It copies sendbuf into recvbuf and returns
// emptyCommRequest().
// NOTE(review): presumably this is the non-MPI build; the #else separating it
// from the MPI overload is not visible here.
219template <
class InputViewType,
class OutputViewType>
220std::shared_ptr<CommRequest>
221iallreduceImpl(
const InputViewType& sendbuf,
222 const OutputViewType& recvbuf,
223 const ::Teuchos::EReductionType,
224 const ::Teuchos::Comm<int>&) {
// Skip the copy when the caller passed the same view for input and output.
226 if (recvbuf != sendbuf)
227 Kokkos::deep_copy(recvbuf, sendbuf);
// Work is done synchronously. The MPI overload returns the same kind of
// request after a blocking allreduce.
228 return emptyCommRequest();
// Public entry point: reduce sendbuf into recvbuf across comm with `op`. The
// returned CommRequest is completed by calling its wait().
// NOTE(review): the line naming this function and its first two parameters
// is not visible here. From the forwarding call below they are `sendbuf`
// (InputViewType) and `recvbuf` (OutputViewType).
//
// The static_asserts below require:
//  - both types are Kokkos::Views of the same rank, 0 or 1;
//  - OutputViewType is nonconst;
//  - both hold the same entry type;
//  - neither uses LayoutStride, since the implementation hands MPI a raw
//    data() pointer plus a count.
264template <
class InputViewType,
class OutputViewType>
265std::shared_ptr<CommRequest>
268 const ::Teuchos::EReductionType op,
269 const ::Teuchos::Comm<int>& comm) {
270 static_assert(Kokkos::is_view<InputViewType>::value,
271 "InputViewType must be a Kokkos::View specialization.");
272 static_assert(Kokkos::is_view<OutputViewType>::value,
273 "OutputViewType must be a Kokkos::View specialization.");
274 constexpr int rank =
static_cast<int>(OutputViewType::rank);
275 static_assert(
static_cast<int>(InputViewType::rank) ==
rank,
276 "InputViewType and OutputViewType must have the same rank.");
277 static_assert(
rank == 0 ||
rank == 1,
278 "InputViewType and OutputViewType must both have "
279 "rank 0 or rank 1.");
// Entry type used by the two value-type checks below.
280 typedef typename OutputViewType::non_const_value_type packet_type;
281 static_assert(std::is_same<
typename OutputViewType::value_type,
283 "OutputViewType must be a nonconst Kokkos::View.");
284 static_assert(std::is_same<
typename InputViewType::non_const_value_type,
286 "InputViewType and OutputViewType must be Views "
287 "whose entries have the same type.");
// MPI receives pointer + count, so the entries must be contiguous.
289 static_assert(!std::is_same<typename InputViewType::array_layout, Kokkos::LayoutStride>::value,
290 "Input/Output views must be contiguous (not LayoutStride)");
291 static_assert(!std::is_same<typename OutputViewType::array_layout, Kokkos::LayoutStride>::value,
292 "Input/Output views must be contiguous (not LayoutStride)");
// Forward to Impl::iallreduceImpl. Two overloads of it appear in this file
// (MPI and non-MPI).
294 return Impl::iallreduceImpl<InputViewType, OutputViewType>(sendbuf, recvbuf, op, comm);
297template <
class ValueType>
298std::shared_ptr<CommRequest>
301 const ::Teuchos::EReductionType op,
302 const ::Teuchos::Comm<int>& comm);
Declaration of Tpetra::Details::Behavior, a class that describes Tpetra's behavior.
Add specializations of Teuchos::Details::MpiTypeTraits for Kokkos::complex<float> and Kokkos::complex<double>.
Declaration of Tpetra::Details::Profiling, a scope guard for Kokkos Profiling.
Struct that holds views of the contents of a CrsMatrix.
Base class for the request (more or less a future) representing a pending nonblocking MPI operation.
virtual ~CommRequest()
Destructor (virtual for memory safety of derived classes).
virtual void cancel()
Cancel the pending communication request.
virtual void wait()
Wait on this communication request to complete.
Implementation details of Tpetra.
bool isInterComm(const Teuchos::Comm< int > &)
Return true if and only if the input communicator wraps an MPI intercommunicator.
Namespace Tpetra contains the class and methods constituting the Tpetra library.