17#ifndef KOKKOS_MATHEMATICAL_FUNCTIONS_HPP
18#define KOKKOS_MATHEMATICAL_FUNCTIONS_HPP
19#ifndef KOKKOS_IMPL_PUBLIC_INCLUDE
20#define KOKKOS_IMPL_PUBLIC_INCLUDE
21#define KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_MATHFUNCTIONS
24#include <Kokkos_Macros.hpp>
29#ifdef KOKKOS_ENABLE_SYCL
31#if __has_include(<sycl/sycl.hpp>)
32#include <sycl/sycl.hpp>
41template <
class T,
bool = std::is_
integral_v<T>>
46struct promote<T, false> {};
48struct promote<long double> {
49 using type =
long double;
52struct promote<double> {
56struct promote<float> {
60using promote_t =
typename promote<T>::type;
61template <
class T,
class U,
62 bool = std::is_arithmetic_v<T>&& std::is_arithmetic_v<U>>
64 using type =
decltype(promote_t<T>() + promote_t<U>());
66template <
class T,
class U>
67struct promote_2<T, U, false> {};
68template <
class T,
class U>
69using promote_2_t =
typename promote_2<T, U>::type;
70template <
class T,
class U,
class V,
71 bool = std::is_arithmetic_v<T>&& std::is_arithmetic_v<U>&&
72 std::is_arithmetic_v<V>>
74 using type =
decltype(promote_t<T>() + promote_t<U>() + promote_t<V>());
76template <
class T,
class U,
class V>
77struct promote_3<T, U, V, false> {};
78template <
class T,
class U,
class V>
79using promote_3_t =
typename promote_3<T, U, V>::type;
84#if defined(KOKKOS_ENABLE_SYCL)
85#define KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE sycl
87#if (defined(KOKKOS_COMPILER_NVCC) || defined(KOKKOS_COMPILER_NVHPC)) && \
88 defined(__GNUC__) && (__GNUC__ < 6) && !defined(__clang__)
89#define KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE
91#define KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE std
95#define KOKKOS_IMPL_MATH_UNARY_FUNCTION(FUNC) \
96 KOKKOS_INLINE_FUNCTION float FUNC(float x) { \
97 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
100 KOKKOS_INLINE_FUNCTION double FUNC(double x) { \
101 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
104 inline long double FUNC(long double x) { \
108 KOKKOS_INLINE_FUNCTION float FUNC##f(float x) { \
109 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
112 inline long double FUNC##l(long double x) { \
117 KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_integral_v<T>, double> FUNC( \
119 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
120 return FUNC(static_cast<double>(x)); \
126#if defined(_WIN32) && defined(KOKKOS_ENABLE_CUDA)
127#define KOKKOS_IMPL_MATH_UNARY_PREDICATE(FUNC) \
128 KOKKOS_INLINE_FUNCTION bool FUNC(float x) { return ::FUNC(x); } \
129 KOKKOS_INLINE_FUNCTION bool FUNC(double x) { return ::FUNC(x); } \
130 inline bool FUNC(long double x) { \
135 KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_integral_v<T>, bool> FUNC( \
137 return ::FUNC(static_cast<double>(x)); \
140#define KOKKOS_IMPL_MATH_UNARY_PREDICATE(FUNC) \
141 KOKKOS_INLINE_FUNCTION bool FUNC(float x) { \
142 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
145 KOKKOS_INLINE_FUNCTION bool FUNC(double x) { \
146 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
149 inline bool FUNC(long double x) { \
154 KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_integral_v<T>, bool> FUNC( \
156 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
157 return FUNC(static_cast<double>(x)); \
161#define KOKKOS_IMPL_MATH_BINARY_FUNCTION(FUNC) \
162 KOKKOS_INLINE_FUNCTION float FUNC(float x, float y) { \
163 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
166 KOKKOS_INLINE_FUNCTION double FUNC(double x, double y) { \
167 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
170 inline long double FUNC(long double x, long double y) { \
174 KOKKOS_INLINE_FUNCTION float FUNC##f(float x, float y) { \
175 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
178 inline long double FUNC##l(long double x, long double y) { \
182 template <class T1, class T2> \
183 KOKKOS_INLINE_FUNCTION \
184 std::enable_if_t<std::is_arithmetic_v<T1> && std::is_arithmetic_v<T2> && \
185 !std::is_same_v<T1, long double> && \
186 !std::is_same_v<T2, long double>, \
187 Kokkos::Impl::promote_2_t<T1, T2>> \
189 using Promoted = Kokkos::Impl::promote_2_t<T1, T2>; \
190 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
191 return FUNC(static_cast<Promoted>(x), static_cast<Promoted>(y)); \
193 template <class T1, class T2> \
194 inline std::enable_if_t<std::is_arithmetic_v<T1> && \
195 std::is_arithmetic_v<T2> && \
196 (std::is_same_v<T1, long double> || \
197 std::is_same_v<T2, long double>), \
200 using Promoted = Kokkos::Impl::promote_2_t<T1, T2>; \
201 static_assert(std::is_same_v<Promoted, long double>); \
203 return FUNC(static_cast<Promoted>(x), static_cast<Promoted>(y)); \
206#define KOKKOS_IMPL_MATH_TERNARY_FUNCTION(FUNC) \
207 KOKKOS_INLINE_FUNCTION float FUNC(float x, float y, float z) { \
208 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
209 return FUNC(x, y, z); \
211 KOKKOS_INLINE_FUNCTION double FUNC(double x, double y, double z) { \
212 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
213 return FUNC(x, y, z); \
215 inline long double FUNC(long double x, long double y, long double z) { \
217 return FUNC(x, y, z); \
219 KOKKOS_INLINE_FUNCTION float FUNC##f(float x, float y, float z) { \
220 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
221 return FUNC(x, y, z); \
223 inline long double FUNC##l(long double x, long double y, long double z) { \
225 return FUNC(x, y, z); \
227 template <class T1, class T2, class T3> \
228 KOKKOS_INLINE_FUNCTION std::enable_if_t< \
229 std::is_arithmetic_v<T1> && std::is_arithmetic_v<T2> && \
230 std::is_arithmetic_v<T3> && !std::is_same_v<T1, long double> && \
231 !std::is_same_v<T2, long double> && \
232 !std::is_same_v<T3, long double>, \
233 Kokkos::Impl::promote_3_t<T1, T2, T3>> \
234 FUNC(T1 x, T2 y, T3 z) { \
235 using Promoted = Kokkos::Impl::promote_3_t<T1, T2, T3>; \
236 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC; \
237 return FUNC(static_cast<Promoted>(x), static_cast<Promoted>(y), \
238 static_cast<Promoted>(z)); \
240 template <class T1, class T2, class T3> \
241 inline std::enable_if_t<std::is_arithmetic_v<T1> && \
242 std::is_arithmetic_v<T2> && \
243 std::is_arithmetic_v<T3> && \
244 (std::is_same_v<T1, long double> || \
245 std::is_same_v<T2, long double> || \
246 std::is_same_v<T3, long double>), \
248 FUNC(T1 x, T2 y, T3 z) { \
249 using Promoted = Kokkos::Impl::promote_3_t<T1, T2, T3>; \
250 static_assert(std::is_same_v<Promoted, long double>); \
252 return FUNC(static_cast<Promoted>(x), static_cast<Promoted>(y), \
253 static_cast<Promoted>(z)); \
257KOKKOS_INLINE_FUNCTION
int abs(
int n) {
258 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::abs;
261KOKKOS_INLINE_FUNCTION
long abs(
long n) {
263#if defined(KOKKOS_COMPILER_NVHPC) && KOKKOS_COMPILER_NVHPC < 230700
264 return n > 0 ? n : -n;
266 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::abs;
270KOKKOS_INLINE_FUNCTION
long long abs(
long long n) {
272#if defined(KOKKOS_COMPILER_NVHPC) && KOKKOS_COMPILER_NVHPC < 230700
273 return n > 0 ? n : -n;
275 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::abs;
279KOKKOS_INLINE_FUNCTION
float abs(
float x) {
280#ifdef KOKKOS_ENABLE_SYCL
281 return sycl::fabs(x);
283 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::abs;
287KOKKOS_INLINE_FUNCTION
double abs(
double x) {
288#ifdef KOKKOS_ENABLE_SYCL
289 return sycl::fabs(x);
291 using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::abs;
295inline long double abs(
long double x) {
// Overload sets generated by the KOKKOS_IMPL_MATH_*_FUNCTION macros. Each
// wrapper forwards (via a using-declaration) to the function of the same name
// in KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE (sycl, std, or the global namespace,
// per the configuration selected earlier in this header). Every macro emits:
//   - float / double overloads marked KOKKOS_INLINE_FUNCTION,
//   - a long double overload declared plain `inline` (no KOKKOS_INLINE_FUNCTION),
//   - `f`- and `l`-suffixed variants (e.g. fabsf / fabsl),
//   - templated overloads for other argument types: integral arguments are
//     converted to double (unary), mixed arithmetic arguments are promoted via
//     Impl::promote_2_t / promote_3_t (binary / ternary).
299KOKKOS_IMPL_MATH_UNARY_FUNCTION(fabs)
300KOKKOS_IMPL_MATH_BINARY_FUNCTION(fmod)
301KOKKOS_IMPL_MATH_BINARY_FUNCTION(remainder)
// fma takes three arguments, hence the ternary generator.
303KOKKOS_IMPL_MATH_TERNARY_FUNCTION(fma)
304KOKKOS_IMPL_MATH_BINARY_FUNCTION(fmax)
305KOKKOS_IMPL_MATH_BINARY_FUNCTION(fmin)
306KOKKOS_IMPL_MATH_BINARY_FUNCTION(fdim)
307#ifndef KOKKOS_ENABLE_SYCL
308KOKKOS_INLINE_FUNCTION
float nanf(
char const* arg) { return ::nanf(arg); }
309KOKKOS_INLINE_FUNCTION
double nan(
char const* arg) { return ::nan(arg); }
315KOKKOS_INLINE_FUNCTION
float nanf(
char const*) {
return sycl::nan(0u); }
316KOKKOS_INLINE_FUNCTION
double nan(
char const*) {
return sycl::nan(0ul); }
/// Long double NaN factory.
/// Declared plain `inline` rather than KOKKOS_INLINE_FUNCTION, so presumably
/// host-only. Passes `arg` untouched to the global C library `::nanl`.
inline long double nanl(char const* arg) {
  long double const generated = ::nanl(arg);
  return generated;
}
// exp overload set: float/double/long double, expf/expl, and an integral
// overload that converts its argument to double (KOKKOS_IMPL_MATH_UNARY_FUNCTION).
320KOKKOS_IMPL_MATH_UNARY_FUNCTION(exp)
322#if defined(KOKKOS_COMPILER_NVHPC) && KOKKOS_COMPILER_NVHPC < 230700
323KOKKOS_INLINE_FUNCTION
float exp2(
float val) {
324 constexpr float ln2 = 0.693147180559945309417232121458176568L;
325 return exp(ln2 * val);
327KOKKOS_INLINE_FUNCTION
double exp2(
double val) {
328 constexpr double ln2 = 0.693147180559945309417232121458176568L;
329 return exp(ln2 * val);
331inline long double exp2(
long double val) {
332 constexpr long double ln2 = 0.693147180559945309417232121458176568L;
333 return exp(ln2 * val);
336KOKKOS_INLINE_FUNCTION
double exp2(T val) {
337 constexpr double ln2 = 0.693147180559945309417232121458176568L;
338 return exp(ln2 *
static_cast<double>(val));
// Generic exp2 overload set produced by the unary generator.
341KOKKOS_IMPL_MATH_UNARY_FUNCTION(exp2)
// Exponential and logarithmic overload sets; each wraps the same-named function
// from KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE and adds f/l-suffixed variants and
// an integral-argument overload returning double.
343KOKKOS_IMPL_MATH_UNARY_FUNCTION(expm1)
344KOKKOS_IMPL_MATH_UNARY_FUNCTION(log)
345KOKKOS_IMPL_MATH_UNARY_FUNCTION(log10)
346KOKKOS_IMPL_MATH_UNARY_FUNCTION(log2)
347KOKKOS_IMPL_MATH_UNARY_FUNCTION(log1p)
// Power and root overload sets. Binary generators also add templated overloads
// for mixed arithmetic arguments, promoted through Impl::promote_2_t.
349KOKKOS_IMPL_MATH_BINARY_FUNCTION(pow)
350KOKKOS_IMPL_MATH_UNARY_FUNCTION(sqrt)
351KOKKOS_IMPL_MATH_UNARY_FUNCTION(cbrt)
// Two-argument hypot; the three-argument form is declared separately.
352KOKKOS_IMPL_MATH_BINARY_FUNCTION(hypot)
353#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || \
354 defined(KOKKOS_ENABLE_SYCL)
355KOKKOS_INLINE_FUNCTION
float hypot(
float x,
float y,
float z) {
356 return sqrt(x * x + y * y + z * z);
358KOKKOS_INLINE_FUNCTION
double hypot(
double x,
double y,
double z) {
359 return sqrt(x * x + y * y + z * z);
361inline long double hypot(
long double x,
long double y,
long double z) {
362 return sqrt(x * x + y * y + z * z);
364KOKKOS_INLINE_FUNCTION
float hypotf(
float x,
float y,
float z) {
365 return sqrt(x * x + y * y + z * z);
367inline long double hypotl(
long double x,
long double y,
long double z) {
368 return sqrt(x * x + y * y + z * z);
371 class T1,
class T2,
class T3,
372 class Promoted = std::enable_if_t<
373 std::is_arithmetic_v<T1> && std::is_arithmetic_v<T2> &&
374 std::is_arithmetic_v<T3> && !std::is_same_v<T1, long double> &&
375 !std::is_same_v<T2, long double> &&
376 !std::is_same_v<T3, long double>,
377 Impl::promote_3_t<T1, T2, T3>>>
378KOKKOS_INLINE_FUNCTION Promoted hypot(T1 x, T2 y, T3 z) {
379 return hypot(
static_cast<Promoted
>(x),
static_cast<Promoted
>(y),
380 static_cast<Promoted
>(z));
383 class T1,
class T2,
class T3,
384 class = std::enable_if_t<
385 std::is_arithmetic_v<T1> && std::is_arithmetic_v<T2> &&
386 std::is_arithmetic_v<T3> &&
387 (std::is_same_v<T1, long double> || std::is_same_v<T2, long double> ||
388 std::is_same_v<T3, long double>)>>
389inline long double hypot(T1 x, T2 y, T3 z) {
390 return hypot(
static_cast<long double>(x),
static_cast<long double>(y),
391 static_cast<long double>(z));
// Three-argument hypot generated by the ternary macro (forwards to the
// namespace's own hypot(x, y, z); mixed arithmetic arguments are promoted via
// Impl::promote_3_t).
394KOKKOS_IMPL_MATH_TERNARY_FUNCTION(hypot)
// Trigonometric overload sets.
397KOKKOS_IMPL_MATH_UNARY_FUNCTION(sin)
398KOKKOS_IMPL_MATH_UNARY_FUNCTION(cos)
399KOKKOS_IMPL_MATH_UNARY_FUNCTION(tan)
400KOKKOS_IMPL_MATH_UNARY_FUNCTION(asin)
401KOKKOS_IMPL_MATH_UNARY_FUNCTION(acos)
402KOKKOS_IMPL_MATH_UNARY_FUNCTION(atan)
403KOKKOS_IMPL_MATH_BINARY_FUNCTION(atan2)
// Hyperbolic overload sets.
405KOKKOS_IMPL_MATH_UNARY_FUNCTION(sinh)
406KOKKOS_IMPL_MATH_UNARY_FUNCTION(cosh)
407KOKKOS_IMPL_MATH_UNARY_FUNCTION(tanh)
408KOKKOS_IMPL_MATH_UNARY_FUNCTION(asinh)
409KOKKOS_IMPL_MATH_UNARY_FUNCTION(acosh)
410KOKKOS_IMPL_MATH_UNARY_FUNCTION(atanh)
// Error and gamma function overload sets.
412KOKKOS_IMPL_MATH_UNARY_FUNCTION(erf)
413KOKKOS_IMPL_MATH_UNARY_FUNCTION(erfc)
414KOKKOS_IMPL_MATH_UNARY_FUNCTION(tgamma)
415KOKKOS_IMPL_MATH_UNARY_FUNCTION(lgamma)
// Rounding overload sets (floating-point results; integral inputs yield double).
417KOKKOS_IMPL_MATH_UNARY_FUNCTION(ceil)
418KOKKOS_IMPL_MATH_UNARY_FUNCTION(floor)
419KOKKOS_IMPL_MATH_UNARY_FUNCTION(trunc)
420KOKKOS_IMPL_MATH_UNARY_FUNCTION(round)
// nearbyint is generated only for non-SYCL builds.
424#ifndef KOKKOS_ENABLE_SYCL
425KOKKOS_IMPL_MATH_UNARY_FUNCTION(nearbyint)
// Floating-point manipulation overload sets.
437KOKKOS_IMPL_MATH_UNARY_FUNCTION(logb)
438KOKKOS_IMPL_MATH_BINARY_FUNCTION(nextafter)
440KOKKOS_IMPL_MATH_BINARY_FUNCTION(copysign)
// Classification predicates. KOKKOS_IMPL_MATH_UNARY_PREDICATE emits
// bool-returning overloads for float, double, long double (plain `inline`) and
// integral arguments (converted to double). On Windows CUDA builds they call
// the global ::FUNC directly; otherwise they forward to
// KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE.
443KOKKOS_IMPL_MATH_UNARY_PREDICATE(isfinite)
444KOKKOS_IMPL_MATH_UNARY_PREDICATE(isinf)
445KOKKOS_IMPL_MATH_UNARY_PREDICATE(isnan)
447KOKKOS_IMPL_MATH_UNARY_PREDICATE(signbit)
455#undef KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE
456#undef KOKKOS_IMPL_MATH_UNARY_FUNCTION
457#undef KOKKOS_IMPL_MATH_UNARY_PREDICATE
458#undef KOKKOS_IMPL_MATH_BINARY_FUNCTION
459#undef KOKKOS_IMPL_MATH_TERNARY_FUNCTION
462KOKKOS_INLINE_FUNCTION
float rsqrt(
float val) {
463#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)
464 KOKKOS_IF_ON_DEVICE(return ::rsqrtf(val);)
465 KOKKOS_IF_ON_HOST(
return 1.0f / Kokkos::sqrt(val);)
466#elif defined(KOKKOS_ENABLE_SYCL)
467 KOKKOS_IF_ON_DEVICE(
return sycl::rsqrt(val);)
468 KOKKOS_IF_ON_HOST(
return 1.0f / Kokkos::sqrt(val);)
470 return 1.0f / Kokkos::sqrt(val);
473KOKKOS_INLINE_FUNCTION
double rsqrt(
double val) {
474#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)
475 KOKKOS_IF_ON_DEVICE(return ::rsqrt(val);)
476 KOKKOS_IF_ON_HOST(
return 1.0 / Kokkos::sqrt(val);)
477#elif defined(KOKKOS_ENABLE_SYCL)
478 KOKKOS_IF_ON_DEVICE(
return sycl::rsqrt(val);)
479 KOKKOS_IF_ON_HOST(
return 1.0 / Kokkos::sqrt(val);)
481 return 1.0 / Kokkos::sqrt(val);
484inline long double rsqrt(
long double val) {
return 1.0l / Kokkos::sqrt(val); }
485KOKKOS_INLINE_FUNCTION
float rsqrtf(
float x) {
return Kokkos::rsqrt(x); }
486inline long double rsqrtl(
long double x) {
return Kokkos::rsqrt(x); }
488KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_integral_v<T>,
double> rsqrt(
490 return Kokkos::rsqrt(
static_cast<double>(x));
495#ifdef KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_MATHFUNCTIONS
496#undef KOKKOS_IMPL_PUBLIC_INCLUDE
497#undef KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_MATHFUNCTIONS