10#include "Tpetra_Details_Ialltofewv.hpp"
17#include <Kokkos_Core.hpp>
// Helper tied to a named Kokkos profiling region: constructing it with a name
// calls Kokkos::Profiling::pushRegion(name), and the struct also contains a
// matching Kokkos::Profiling::popRegion() call.
26struct ProfilingRegion {
// A region cannot be created without a name.
27 ProfilingRegion() =
delete;
// Copying is disabled.
28 ProfilingRegion(
const ProfilingRegion &other) =
delete;
// Moving is disabled.
29 ProfilingRegion(ProfilingRegion &&other) =
delete;
// Opens the profiling region called `name`.
31 ProfilingRegion(
const std::string &name) {
32 Kokkos::Profiling::pushRegion(name);
// NOTE(review): the function enclosing this call is not visible here;
// presumably it is the destructor, so the region is popped when the object
// goes out of scope -- confirm.
35 Kokkos::Profiling::popRegion();
46KOKKOS_INLINE_FUNCTION
bool is_compatible(
const MemcpyArg &arg) {
47 return (0 == (uintptr_t(arg.dst) %
sizeof(T))) && (0 == (uintptr_t(arg.src) &
sizeof(T))) && (0 == (arg.count %
sizeof(T)));
// Copies `count` elements of type T from `src` to `dst`, one assignment per
// index `i`, over a Kokkos::TeamThreadRange(member, count).
//
// T      - element type used for each assignment
// Member - type of the team handle handed to Kokkos::TeamThreadRange
// member - team handle
// dst    - destination, reinterpreted as T *
// src    - source, reinterpreted as T const *
// count  - number of T elements (not bytes) to copy
//
// NOTE(review): the call that consumes the range and the functor header that
// introduces `i` are not visible in this listing; presumably a
// Kokkos::parallel_for over the range -- confirm.
50template <
typename T,
typename Member>
51KOKKOS_INLINE_FUNCTION
void team_memcpy_as(
const Member &member,
void *dst,
void *
const src,
size_t count) {
// Iteration space: indices [0, count) distributed over the team's threads.
53 Kokkos::TeamThreadRange(member, count),
// Element-wise copy of the i-th T.
55 reinterpret_cast<T *
>(dst)[i] =
reinterpret_cast<T
const *
>(src)[i];
59template <
typename Member>
60KOKKOS_INLINE_FUNCTION
void team_memcpy(
const Member &member, MemcpyArg &arg) {
61 if (is_compatible<uint64_t>(arg)) {
62 team_memcpy_as<uint64_t>(member, arg.dst, arg.src, arg.count /
sizeof(uint64_t));
63 }
else if (is_compatible<uint32_t>(arg)) {
64 team_memcpy_as<uint32_t>(member, arg.dst, arg.src, arg.count /
sizeof(uint32_t));
66 team_memcpy_as<uint8_t>(member, arg.dst, arg.src, arg.count);
// Implementation object behind Ialltofewv::Cache. It owns a device-side and a
// host-side copy of three reusable buffers (root buffer, aggregation buffer,
// MemcpyArg list) and hands out subviews of them through get_rootBuf,
// get_aggBuf and get_args, growing a buffer only when a request exceeds its
// current extent.
74struct Ialltofewv::Cache::impl {
// Every view starts with extent 0 and is grown on demand by the accessors
// below. Each view's label matches its member name.
76 : rootBufDev(
"rootBufDev", 0)
77 , rootBufHost(
"rootBufHost", 0)
78 , aggBufDev(
"aggBufDev", 0)
79 , aggBufHost(
"aggBufHost", 0)
80 , argsDev(
"argsDev", 0)
81 , argsHost(
"argsHost", 0)
93 , aggBufHostSize_(0) {}
// Root-side staging buffer, in the default execution space's memory space and
// in the default host execution space's memory space.
96 Kokkos::View<uint8_t *, typename Kokkos::DefaultExecutionSpace::memory_space> rootBufDev;
97 Kokkos::View<uint8_t *, typename Kokkos::DefaultHostExecutionSpace::memory_space> rootBufHost;
// Aggregation buffer, same two memory spaces.
98 Kokkos::View<char *, typename Kokkos::DefaultExecutionSpace::memory_space> aggBufDev;
99 Kokkos::View<char *, typename Kokkos::DefaultHostExecutionSpace::memory_space> aggBufHost;
// MemcpyArg list, same two memory spaces.
100 Kokkos::View<MemcpyArg *, typename Kokkos::DefaultExecutionSpace::memory_space> argsDev;
101 Kokkos::View<MemcpyArg *, typename Kokkos::DefaultHostExecutionSpace::memory_space> argsHost;
// Size most recently passed to Kokkos::resize for the corresponding view; each
// field is written only by the accessor branch that resizes that same view.
110 size_t rootBufDevSize_, aggBufDevSize_;
111 size_t argsDevSize_, argsHostSize_;
112 size_t rootBufHostSize_, aggBufHostSize_;
// Returns the subview [0, size) of the cached root buffer. When ExecSpace is
// Kokkos::DefaultExecutionSpace the device-side view is used; the remaining
// code path uses the host-side view. The view is resized (contents left
// uninitialized) only when its extent is smaller than `size`.
114 template <
typename ExecSpace>
115 auto get_rootBuf(
size_t size) {
117 if constexpr (std::is_same_v<ExecSpace, Kokkos::DefaultExecutionSpace>) {
118 if (rootBufDev.extent(0) < size) {
119 Kokkos::resize(Kokkos::WithoutInitializing, rootBufDev, size);
120 rootBufDevSize_ = size;
124 return Kokkos::subview(rootBufDev, Kokkos::pair{size_t(0), size});
126 if (rootBufHost.extent(0) < size) {
127 Kokkos::resize(Kokkos::WithoutInitializing, rootBufHost, size);
128 rootBufHostSize_ = size;
132 return Kokkos::subview(rootBufHost, Kokkos::pair{size_t(0), size});
// Returns the subview [0, size) of the cached aggregation buffer; same
// selection and grow-on-demand rules as get_rootBuf.
136 template <
typename ExecSpace>
137 auto get_aggBuf(
size_t size) {
139 if constexpr (std::is_same_v<ExecSpace, Kokkos::DefaultExecutionSpace>) {
140 if (aggBufDev.extent(0) < size) {
141 Kokkos::resize(Kokkos::WithoutInitializing, aggBufDev, size);
// The device view was resized, so the device-side size field is recorded.
142 aggBufDevSize_ = size;
146 return Kokkos::subview(aggBufDev, Kokkos::pair{size_t(0), size});
148 if (aggBufHost.extent(0) < size) {
149 Kokkos::resize(Kokkos::WithoutInitializing, aggBufHost, size);
150 aggBufHostSize_ = size;
154 return Kokkos::subview(aggBufHost, Kokkos::pair{size_t(0), size});
// Returns the subview [0, size) of the cached MemcpyArg list; same selection
// and grow-on-demand rules as get_rootBuf. `size` counts MemcpyArg entries.
158 template <
typename ExecSpace>
159 auto get_args(
size_t size) {
161 if constexpr (std::is_same_v<ExecSpace, Kokkos::DefaultExecutionSpace>) {
162 if (argsDev.extent(0) < size) {
163 Kokkos::resize(Kokkos::WithoutInitializing, argsDev, size);
// The device view was resized, so the device-side size field is recorded.
164 argsDevSize_ = size;
168 return Kokkos::subview(argsDev, Kokkos::pair{size_t(0), size});
170 if (argsHost.extent(0) < size) {
171 Kokkos::resize(Kokkos::WithoutInitializing, argsHost, size);
172 argsHostSize_ = size;
176 return Kokkos::subview(argsHost, Kokkos::pair{size_t(0), size});
181Ialltofewv::Cache::Cache() =
default;
182Ialltofewv::Cache::~Cache() =
default;
// Carries out the all-to-few exchange described by `req` using buffers cached
// in `cache`. In the visible code the data takes two hops: ranks send their
// per-root counts and data to an aggregator rank (`myAgg`), aggregators send
// the collected per-root data on to the roots, and on the receiving side a
// team-parallel kernel copies each source rank's data from a contiguous
// staging buffer to its `req.rdispls` position in `req.recvbuf`.
//
// RecvExecSpace selects which cached buffers are used (get_aggBuf, get_rootBuf,
// get_args) and the execution space of the final copy kernel.
//
// NOTE(review): many lines of this function are not visible in this listing,
// including the guards deciding which ranks run each phase, several local
// declarations (e.g. `displ`, `aggBytes`, `srcOff`, `rreq`, `sreq`) and the
// return statements. Comments below describe only the visible statements.
185template <
typename RecvExecSpace>
186int wait_impl(Ialltofewv::Req &req, Ialltofewv::Cache &cache) {
// The request is flagged as completed up front.
188 req.completed =
true;
// Special case for an exchange with no roots (body not visible here).
192 if (0 == req.nroots) {
196 ProfilingRegion pr(
"alltofewv::wait");
// Installs a fresh cache implementation.
// NOTE(review): presumably guarded by a check that pimpl is unset -- confirm.
200 cache.pimpl = std::make_shared<Ialltofewv::Cache::impl>();
// This process's rank in req.comm.
203 const int rank = [&]() ->
int {
205 MPI_Comm_rank(req.comm, &_rank);
// Number of processes in req.comm.
209 const int size = [&]() ->
int {
211 MPI_Comm_size(req.comm, &_size);
// Byte size of one element of the send datatype.
215 const size_t sendSize = [&]() ->
size_t {
217 MPI_Type_size(req.sendtype, &_size);
// Byte size of one element of the receive datatype.
221 const size_t recvSize = [&]() ->
size_t {
223 MPI_Type_size(req.recvtype, &_size);
// True when this rank appears in req.roots[0 .. nroots).
228 const bool isRoot = std::find(req.roots, req.roots + req.nroots, rank) != req.roots + req.nroots;
// Number of aggregators: sqrt(size * nroots); adding 0.5 before the
// conversion to int rounds to the nearest integer.
237 const int naggs = std::sqrt(
size_t(size) *
size_t(req.nroots)) + 0.5;
// Ranks per aggregator group, rounded up.
240 const int srcsPerAgg = (size + naggs - 1) / naggs;
// Aggregator for this rank: the first rank of the block of `srcsPerAgg`
// consecutive ranks that contains `rank`.
243 const int myAgg = rank / srcsPerAgg * srcsPerAgg;
// Room for `nroots` send counts from each of up to `srcsPerAgg` sources.
247 std::vector<int> groupSendCounts(
size_t(req.nroots) *
size_t(srcsPerAgg));
248 std::vector<MPI_Request> reqs;
250 reqs.reserve(srcsPerAgg);
// Post one receive per source rank (rank + si, limited to the communicator)
// for that source's `nroots` send counts.
252 for (
int si = 0; si < srcsPerAgg && si + rank < size; ++si) {
// Diagnostic: the slice about to be received into must fit in the vector.
256 if (
size_t(si) * req.nroots + req.nroots > groupSendCounts.size()) {
257 std::stringstream ss;
258 ss << __FILE__ <<
":" << __LINE__
259 <<
" [" << rank <<
"] tpetra internal Ialltofewv error: OOB access in recv buffer\n";
260 std::cerr << ss.str();
263 MPI_Irecv(&groupSendCounts[
size_t(si) *
size_t(req.nroots)], req.nroots, MPI_INT, si + rank,
264 req.aggTag, req.comm, &rreq);
265 reqs.push_back(rreq);
// Send this rank's per-root send counts to its aggregator (blocking send).
269 MPI_Send(req.sendcounts, req.nroots, MPI_INT, myAgg, req.aggTag, req.comm);
// Wait for the posted count receives.
271 MPI_Waitall(reqs.size(), reqs.data(), MPI_STATUSES_IGNORE);
// Zero-length request, replaced below once the required byte count is known.
279 auto aggBuf = cache.pimpl->get_aggBuf<RecvExecSpace>(0);
// Total element count destined for each root, summed over this group's sources.
280 std::vector<size_t> rootCount(req.nroots, 0);
283 for (
int si = 0; si < srcsPerAgg && si + rank < size; ++si) {
284 for (
int ri = 0; ri < req.nroots; ++ri) {
285 int count = groupSendCounts[si * req.nroots + ri];
286 rootCount[ri] += count;
287 aggBytes += count * sendSize;
// Acquire an aggregation buffer of `aggBytes` bytes.
291 aggBuf = cache.pimpl->get_aggBuf<RecvExecSpace>(aggBytes);
299 reqs.reserve(srcsPerAgg + req.nroots);
// Post data receives root-major (outer loop over roots, inner over sources)
// at increasing `displ`, so all data bound for one root is contiguous in
// aggBuf.
304 for (
int ri = 0; ri < req.nroots; ++ri) {
305 for (
int si = 0; si < srcsPerAgg && si + rank < size; ++si) {
307 const int count = groupSendCounts[si * req.nroots + ri];
// Diagnostic: the receive must fit inside aggBuf.
310 if (displ + count * sendSize > aggBuf.size()) {
311 std::stringstream ss;
312 ss << __FILE__ <<
":" << __LINE__
313 <<
" [" << rank <<
"] tpetra internal Ialltofewv error: OOB access in send buffer\n";
314 std::cerr << ss.str();
319 MPI_Irecv(aggBuf.data() + displ, count, req.sendtype, si + rank, req.aggTag, req.comm, &rreq);
320 reqs.push_back(rreq);
321 displ += size_t(count) * sendSize;
326 reqs.reserve(req.nroots);
// Send this rank's data for each root to its aggregator; the byte offset into
// sendbuf is sdispls[ri] scaled by the send element size.
330 for (
int ri = 0; ri < req.nroots; ++ri) {
331 const size_t displ = size_t(req.sdispls[ri]) * sendSize;
332 const int count = req.sendcounts[ri];
335 MPI_Isend(&
reinterpret_cast<const char *
>(req.sendbuf)[displ], req.sendcounts[ri],
336 req.sendtype, myAgg, req.aggTag, req.comm, &sreq);
337 reqs.push_back(sreq);
// Wait for the outstanding requests collected in `reqs`.
341 MPI_Waitall(reqs.size(), reqs.data(), MPI_STATUSES_IGNORE);
// Zero-length request, replaced below once the total receive size is known.
346 auto rootBuf = cache.pimpl->get_rootBuf<RecvExecSpace>(0);
// Total bytes to receive: sum of recvcounts over all ranks times recvSize.
350 const size_t totalRecvd = recvSize * [&]() ->
size_t {
352 for (
int i = 0; i < size; ++i) {
353 acc += req.recvcounts[i];
357 rootBuf = cache.pimpl->get_rootBuf<RecvExecSpace>(totalRecvd);
// One receive per aggregator rank (every srcsPerAgg-th rank). Its element
// count is the sum of recvcounts over the ranks of that aggregator's group.
363 for (
int aggSrc = 0; aggSrc < size; aggSrc += srcsPerAgg) {
366 for (
int origSrc = aggSrc;
367 origSrc < aggSrc + srcsPerAgg && origSrc < size; ++origSrc) {
368 count += req.recvcounts[origSrc];
373 MPI_Irecv(rootBuf.data() + displ, count, req.recvtype, aggSrc, req.rootTag, req.comm, &rreq);
374 reqs.push_back(rreq);
375 displ += size_t(count) * recvSize;
// Forward each root's contiguous slice of aggBuf to that root (blocking send).
385 for (
int ri = 0; ri < req.nroots; ++ri) {
386 const size_t count = rootCount[ri];
389 MPI_Send(aggBuf.data() + displ, count, req.sendtype, req.roots[ri], req.rootTag, req.comm);
390 displ += count * sendSize;
// Wait for the outstanding requests collected in `reqs`.
395 MPI_Waitall(reqs.size(), reqs.data(), MPI_STATUSES_IGNORE);
// One MemcpyArg per source rank, filled in through a mirror view and then
// copied to `args`.
400 auto args = cache.pimpl->get_args<RecvExecSpace>(size);
401 auto args_h = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, args);
404 for (
int sRank = 0; sRank < size; ++sRank) {
// Destination: recvbuf at rdispls[sRank] elements, expressed in bytes.
405 const size_t dstOff = req.rdispls[sRank] * recvSize;
407 void *dst = &
reinterpret_cast<char *
>(req.recvbuf)[dstOff];
// Source: rootBuf at byte offset srcOff (its definition and update are not
// visible in this listing).
408 void *
const src = rootBuf.data() + srcOff;
// Byte count for this source rank.
409 const size_t count = req.recvcounts[sRank] * recvSize;
410 args_h(sRank) = MemcpyArg{dst, src, count};
// Diagnostic: the source range must lie inside rootBuf.
413 if (srcOff + count > rootBuf.extent(0)) {
414 std::stringstream ss;
415 ss << __FILE__ <<
":" << __LINE__ <<
" Tpetra internal Ialltofewv error: src access OOB in memcpy\n";
416 std::cerr << ss.str();
// Copy the argument list to `args` and launch one team per source rank; each
// team performs the copy described by args(league_rank).
423 Kokkos::deep_copy(args, args_h);
424 using Policy = Kokkos::TeamPolicy<RecvExecSpace>;
425 Policy policy(size, Kokkos::AUTO);
426 Kokkos::parallel_for(
427 "Tpetra::Details::Ialltofewv: apply rdispls to contiguous root buffer", policy,
428 KOKKOS_LAMBDA(
typename Policy::member_type member) {
429 team_memcpy(member, args(member.league_rank()));
// Fence after the copy kernel.
431 Kokkos::fence(
"Tpetra::Details::Ialltofewv: after apply rdispls to contiguous root buffer");
// Completes `req` by forwarding it, together with this object's `cache_`, to
// wait_impl and returning wait_impl's result.
//
// NOTE(review): two return statements are visible, one instantiating wait_impl
// for Kokkos::DefaultExecutionSpace and one for
// Kokkos::DefaultHostExecutionSpace. The condition selecting between them is
// not visible in this listing; presumably a build-configuration switch --
// confirm.
438int Ialltofewv::wait(Req &req) {
440 return wait_impl<Kokkos::DefaultExecutionSpace>(req, cache_);
442 return wait_impl<Kokkos::DefaultHostExecutionSpace>(req, cache_);
// Reports the completion state of `req`: writes `req.completed` to `*flag`.
// The MPI_Status pointer parameter is unnamed and not used by the visible code.
//
// NOTE(review): the return statement is not visible in this listing.
446int Ialltofewv::get_status(
const Req &req,
int *flag, MPI_Status * )
const {
447 *flag = req.completed;
Nonmember function that computes a residual: R = B - A * X.
void finalize()
Finalize Tpetra.