10#ifndef TPETRA_DETAILS_IALLREDUCE_HPP
11#define TPETRA_DETAILS_IALLREDUCE_HPP
29#include "TpetraCore_config.h"
30#include "Teuchos_EReductionType.hpp"
31#ifdef HAVE_TPETRACORE_MPI
35#include "Tpetra_Details_temporaryViewUtils.hpp"
37#include "Kokkos_Core.hpp"
43#ifndef DOXYGEN_SHOULD_SKIP_THIS
46template <
class OrdinalType>
54#ifdef HAVE_TPETRACORE_MPI
// Declaration only: takes an MPI error code and returns a std::string.
// The wait() method in this file streams the returned string into the
// exception message it builds when MPI_Wait fails.
// (Presumably the string is MPI's human-readable description of errCode;
// the definition is not visible here -- confirm.)
55std::string getMpiErrorString(
const int errCode);
85std::shared_ptr<CommRequest>
88#ifdef HAVE_TPETRACORE_MPI
90template <
typename InputViewType,
typename OutputViewType,
typename ResultViewType>
// Complete the pending request and copy the received data to the
// caller-visible result buffer.
//
// When `req` is not MPI_REQUEST_NULL, calls MPI_Wait on it and raises
// std::runtime_error (through TEUCHOS_TEST_FOR_EXCEPTION) if MPI_Wait does
// not return MPI_SUCCESS.
//
// NOTE(review): this listing omits the closing braces, so it is not visible
// here whether the reset and the deep_copy below sit inside the
// `req != MPI_REQUEST_NULL` branch -- confirm against the full source.
107 void wait()
override {
108 if (req != MPI_REQUEST_NULL) {
109 const int err = MPI_Wait(&req, MPI_STATUS_IGNORE);
// NOTE(review): the message opens an escaped double quote before the MPI
// error text but never emits a matching closing quote, and it names
// "MpiCommRequest" whereas this file instantiates the class as MpiRequest.
110 TEUCHOS_TEST_FOR_EXCEPTION(err != MPI_SUCCESS, std::runtime_error,
111 "MpiCommRequest::wait: MPI_Wait failed with error \""
112 << getMpiErrorString(err));
// Explicitly mark the handle as null so a later wait() skips MPI_Wait.
115 req = MPI_REQUEST_NULL;
// recvBuf is the buffer handed to MPI; resultBuf is the view that was passed
// in as the caller's output, so copy the received data across.
117 Kokkos::deep_copy(resultBuf, recvBuf);
// Abandon the request by overwriting the handle with MPI_REQUEST_NULL.
// The only statement visible in this listing is the assignment below; no MPI
// call appears here (listing lines 125-126 are not shown -- presumably
// comments; confirm against the full source).
124 void cancel()
override {
127 req = MPI_REQUEST_NULL;
131 InputViewType sendBuf;
132 OutputViewType recvBuf;
133 ResultViewType resultBuf;
141iallreduceRaw(
const void* sendbuf,
144 MPI_Datatype mpiDatatype,
145 const Teuchos::EReductionType op,
150void allreduceRaw(
const void* sendbuf,
153 MPI_Datatype mpiDatatype,
154 const Teuchos::EReductionType op,
// MPI implementation of the all-reduce used by this header.
//
// Parameters:
//   sendbuf - input view; its non_const_value_type is the MPI packet type.
//   recvbuf - output view that receives the reduced result.
//   op      - Teuchos reduction operation forwarded to the raw MPI wrappers.
//   comm    - Teuchos communicator; its raw MPI_Comm is extracted below.
// Returns a std::shared_ptr<CommRequest>: either emptyCommRequest() or an
// MpiRequest wrapping the MPI_Request produced by iallreduceRaw.
//
// NOTE(review): this listing drops several lines (closing braces, the
// directives/else branches that choose between the paired code paths, and the
// final return). Comments below describe only the statements that are shown.
157template <
class InputViewType,
class OutputViewType>
158std::shared_ptr<CommRequest>
159iallreduceImpl(
const InputViewType& sendbuf,
160 const OutputViewType& recvbuf,
161 const ::Teuchos::EReductionType op,
162 const ::Teuchos::Comm<int>& comm) {
163 using Packet =
typename InputViewType::non_const_value_type;
// Single-process communicator: the reduction is just a copy, and there is
// nothing to wait on.
164 if (comm.getSize() == 1) {
165 Kokkos::deep_copy(recvbuf, sendbuf);
166 return emptyCommRequest();
// A Packet instance is passed to getType() to obtain the MPI datatype.
168 Packet examplePacket;
// When extent(0) is zero, getType() is not called and MPI_BYTE is used instead.
169 MPI_Datatype mpiDatatype = sendbuf.extent(0) ? MpiTypeTraits<Packet>::getType(examplePacket) : MPI_BYTE;
// NOTE(review): needsFree is taken from the traits regardless of whether
// getType() was actually called above. If needsFree is true and extent(0) is
// zero, the MPI_Type_free near the end would be applied to MPI_BYTE -- confirm
// whether that combination can occur.
170 bool datatypeNeedsFree = MpiTypeTraits<Packet>::needsFree;
171 MPI_Comm rawComm = ::Tpetra::Details::extractMpiCommFromTeuchos(comm);
// Views whose data() pointers are handed to MPI below (presumably host-
// accessible staging copies when the originals are not MPI-safe; the helper's
// definition is not visible here -- confirm).
174 auto sendMPI = Tpetra::Details::TempView::toMPISafe<InputViewType, false>(sendbuf);
175 auto recvMPI = Tpetra::Details::TempView::toMPISafe<OutputViewType, false>(recvbuf);
176 std::shared_ptr<CommRequest> req;
// Intercommunicator with aliased send/receive pointers: build a separate
// input buffer so the send and receive pointers differ.
179 if (
isInterComm(comm) && sendMPI.data() == recvMPI.data()) {
182 Kokkos::View<Packet*, Kokkos::HostSpace> tempInput(Kokkos::ViewAllocateWithoutInitializing(
"tempInput"), sendMPI.extent(0));
// Element-wise copy of the send data into the temporary host view.
183 for (
size_t i = 0; i < sendMPI.extent(0); i++)
184 tempInput(i) = sendMPI.data()[i];
// Path A (copied input): start the reduction via iallreduceRaw and wrap the
// MPI_Request in an MpiRequest that holds tempInput, recvMPI and recvbuf, so
// the buffers stay alive and wait() can copy recvMPI into recvbuf.
187 MPI_Request mpiReq = iallreduceRaw((
const void*)tempInput.data(), (
void*)recvMPI.data(), tempInput.extent(0), mpiDatatype, op, rawComm);
188 req = std::shared_ptr<CommRequest>(
new MpiRequest<
decltype(tempInput),
decltype(recvMPI), OutputViewType>(tempInput, recvMPI, recvbuf, mpiReq));
// Path B: call allreduceRaw, copy the result back immediately, and return
// an empty request.
// NOTE(review): this call sends from sendMPI, not tempInput, even though it
// sits in the aliased-intercomm branch. The directive selecting between
// path A and path B is not shown in this listing -- confirm.
191 allreduceRaw((
const void*)sendMPI.data(), (
void*)recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
192 Kokkos::deep_copy(recvbuf, recvMPI);
193 req = emptyCommRequest();
// Remaining case (the `else` line is not shown in this listing): same two
// alternatives as above, but sending directly from sendMPI.
198 MPI_Request mpiReq = iallreduceRaw((
const void*)sendMPI.data(), (
void*)recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
199 req = std::shared_ptr<CommRequest>(
new MpiRequest<
decltype(sendMPI),
decltype(recvMPI), OutputViewType>(sendMPI, recvMPI, recvbuf, mpiReq));
202 allreduceRaw((
const void*)sendMPI.data(), (
void*)recvMPI.data(), sendMPI.extent(0), mpiDatatype, op, rawComm);
203 Kokkos::deep_copy(recvbuf, recvMPI);
204 req = emptyCommRequest();
// Release the datatype handle when the traits say it was created for us.
207 if (datatypeNeedsFree)
208 MPI_Type_free(&mpiDatatype);
// Second iallreduceImpl definition with the same parameter types as the MPI
// one; presumably the alternative used when HAVE_TPETRACORE_MPI is not
// defined (the selecting directive is not shown in this listing -- confirm).
//
// The reduction operation and communicator parameters are unnamed and unused:
// the function copies sendbuf into recvbuf and returns emptyCommRequest().
215template <
class InputViewType,
class OutputViewType>
216std::shared_ptr<CommRequest>
217iallreduceImpl(
const InputViewType& sendbuf,
218 const OutputViewType& recvbuf,
219 const ::Teuchos::EReductionType,
220 const ::Teuchos::Comm<int>&) {
221 Kokkos::deep_copy(recvbuf, sendbuf);
222 return emptyCommRequest();
// Front-end template that validates its view types at compile time and then
// forwards (sendbuf, recvbuf, op, comm) to Impl::iallreduceImpl.
//
// NOTE(review): the listing omits the lines carrying the function name and
// the first two parameters (260-261) and one line inside each std::is_same
// check (276, 279). Comments below describe only what is shown.
258template <
class InputViewType,
class OutputViewType>
259std::shared_ptr<CommRequest>
262 const ::Teuchos::EReductionType op,
263 const ::Teuchos::Comm<int>& comm) {
// Both template arguments must be Kokkos::View specializations.
264 static_assert(Kokkos::is_view<InputViewType>::value,
265 "InputViewType must be a Kokkos::View specialization.");
266 static_assert(Kokkos::is_view<OutputViewType>::value,
267 "OutputViewType must be a Kokkos::View specialization.");
// The two views must have equal rank, and that rank must be 0 or 1.
268 constexpr int rank =
static_cast<int>(OutputViewType::rank);
269 static_assert(
static_cast<int>(InputViewType::rank) ==
rank,
270 "InputViewType and OutputViewType must have the same rank.");
271 static_assert(
rank == 0 ||
rank == 1,
272 "InputViewType and OutputViewType must both have "
273 "rank 0 or rank 1.");
274 typedef typename OutputViewType::non_const_value_type packet_type;
// Per the assertion message: the output view's value_type must be non-const.
275 static_assert(std::is_same<
typename OutputViewType::value_type,
277 "OutputViewType must be a nonconst Kokkos::View.");
// Per the assertion message: input and output entries must have the same type.
278 static_assert(std::is_same<
typename InputViewType::non_const_value_type,
280 "InputViewType and OutputViewType must be Views "
281 "whose entries have the same type.");
// Kokkos::LayoutStride is rejected for both views; the messages state the
// views must be contiguous.
283 static_assert(!std::is_same<typename InputViewType::array_layout, Kokkos::LayoutStride>::value,
284 "Input/Output views must be contiguous (not LayoutStride)");
285 static_assert(!std::is_same<typename OutputViewType::array_layout, Kokkos::LayoutStride>::value,
286 "Input/Output views must be contiguous (not LayoutStride)");
// All checks passed at compile time; delegate the actual work.
288 return Impl::iallreduceImpl<InputViewType, OutputViewType>(sendbuf, recvbuf, op, comm);
291std::shared_ptr<CommRequest>
294 const ::Teuchos::EReductionType op,
295 const ::Teuchos::Comm<int>& comm);
Declaration of Tpetra::Details::Behavior, a class that describes Tpetra's behavior.
Add specializations of Teuchos::Details::MpiTypeTraits for Kokkos::complex<float> and Kokkos::complex<double>.
Struct that holds views of the contents of a CrsMatrix.
Base class for the request (more or less a future) representing a pending nonblocking MPI operation.
virtual ~CommRequest()
Destructor (virtual for memory safety of derived classes).
virtual void cancel()
Cancel the pending communication request.
virtual void wait()
Wait on this communication request to complete.
Implementation details of Tpetra.
bool isInterComm(const Teuchos::Comm< int > &)
Return true if and only if the input communicator wraps an MPI intercommunicator.
Namespace Tpetra contains the class and methods constituting the Tpetra library.