Intrepid2
Intrepid2_Kernels.hpp
Go to the documentation of this file.
1// @HEADER
2// *****************************************************************************
3// Intrepid2 Package
4//
5// Copyright 2007 NTESS and the Intrepid2 contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9
15#ifndef __INTREPID2_KERNELS_HPP__
16#define __INTREPID2_KERNELS_HPP__
17
18#include "Intrepid2_ConfigDefs.hpp"
19
20#include "Intrepid2_Types.hpp"
21#include "Intrepid2_Utils.hpp"
22
23#include "Kokkos_Core.hpp"
24
25namespace Intrepid2 {
26
27 namespace Kernels {
28
29 struct Serial {
30 template<typename ScalarType,
31 typename AViewType,
32 typename BViewType,
33 typename CViewType>
34 KOKKOS_INLINE_FUNCTION
35 static void
36 gemm_trans_notrans(const ScalarType alpha,
37 const AViewType &A,
38 const BViewType &B,
39 const ScalarType beta,
40 const CViewType &C) {
41 //C = beta*C + alpha * A'*B
42 const ordinal_type
43 m = C.extent(0),
44 n = C.extent(1),
45 k = B.extent(0);
46
47 for (ordinal_type i=0;i<m;++i)
48 for (ordinal_type j=0;j<n;++j) {
49 C(i,j) *= beta;
50 for (ordinal_type l=0;l<k;++l)
51 C(i,j) += alpha*A(l,i)*B(l,j);
52 }
53 }
54
55 template<typename ScalarType,
56 typename AViewType,
57 typename BViewType,
58 typename CViewType>
59 KOKKOS_INLINE_FUNCTION
60 static void
61 gemm_notrans_trans(const ScalarType alpha,
62 const AViewType &A,
63 const BViewType &B,
64 const ScalarType beta,
65 const CViewType &C) {
66 //C = beta*C + alpha * A*B'
67 const ordinal_type
68 m = C.extent(0),
69 n = C.extent(1),
70 k = A.extent(1);
71
72 for (ordinal_type i=0;i<m;++i)
73 for (ordinal_type j=0;j<n;++j) {
74 C(i,j) *= beta;
75 for (ordinal_type l=0;l<k;++l)
76 C(i,j) += alpha*A(i,l)*B(j,l);
77 }
78 }
79
80 template<typename ScalarType,
81 typename AViewType,
82 typename xViewType,
83 typename yViewType>
84 KOKKOS_INLINE_FUNCTION
85 static void
86 gemv_trans(const ScalarType alpha,
87 const AViewType &A,
88 const xViewType &x,
89 const ScalarType beta,
90 const yViewType &y) {
91 //y = beta*y + alpha * A'*x
92 const ordinal_type
93 m = y.extent(0),
94 n = x.extent(0);
95
96 for (ordinal_type i=0;i<m;++i) {
97 y(i) *= beta;
98 for (ordinal_type j=0;j<n;++j)
99 y(i) += alpha*A(j,i)*x(j);
100 }
101 }
102
103 template<typename ScalarType,
104 typename AViewType,
105 typename xViewType,
106 typename yViewType>
107 KOKKOS_INLINE_FUNCTION
108 static void
109 gemv_notrans(const ScalarType alpha,
110 const AViewType &A,
111 const xViewType &x,
112 const ScalarType beta,
113 const yViewType &y) {
114 //y = beta*y + alpha * A*x
115 const ordinal_type
116 m = y.extent(0),
117 n = x.extent(0);
118
119 for (ordinal_type i=0;i<m;++i) {
120 y(i) *= beta;
121 for (ordinal_type j=0;j<n;++j)
122 y(i) += alpha*A(i,j)*x(j);
123 }
124 }
125
126 template<typename matViewType>
127 KOKKOS_INLINE_FUNCTION
128 static typename matViewType::non_const_value_type
129 determinant(const matViewType &mat) {
130 INTREPID2_TEST_FOR_ABORT(mat.extent(0) != mat.extent(1), "mat should be a square matrix.");
131 INTREPID2_TEST_FOR_ABORT(mat.extent(0) > 3, "Higher dimensions (> 3) are not supported.");
132
133 typename matViewType::non_const_value_type r_val(0);
134 const int m = mat.extent(0);
135 switch (m) {
136 case 1:
137 r_val = mat(0,0);
138 break;
139 case 2:
140 r_val = ( mat(0,0) * mat(1,1) -
141 mat(0,1) * mat(1,0) );
142 break;
143 case 3:
144 r_val = ( mat(0,0) * mat(1,1) * mat(2,2) +
145 mat(1,0) * mat(2,1) * mat(0,2) +
146 mat(2,0) * mat(0,1) * mat(1,2) -
147 mat(2,0) * mat(1,1) * mat(0,2) -
148 mat(0,0) * mat(2,1) * mat(1,2) -
149 mat(1,0) * mat(0,1) * mat(2,2) );
150 break;
151 }
152 return r_val;
153 }
154
155 template<typename matViewType,
156 typename invViewType>
157 KOKKOS_INLINE_FUNCTION
158 static void
159 inverse(const invViewType &inv,
160 const matViewType &mat) {
161 INTREPID2_TEST_FOR_ABORT(mat.extent(0) != mat.extent(1), "mat should be a square matrix.");
162 INTREPID2_TEST_FOR_ABORT(inv.extent(0) != inv.extent(1), "inv should be a square matrix.");
163 INTREPID2_TEST_FOR_ABORT(mat.extent(0) != inv.extent(0), "mat and inv must have the same dimension.");
164 INTREPID2_TEST_FOR_ABORT(mat.extent(0) > 3, "Higher dimensions (> 3) are not supported.");
165 INTREPID2_TEST_FOR_ABORT(mat.data() == inv.data(), "mat and inv must have different data pointer.");
166
167 const auto val = determinant(mat);
168 const int m = mat.extent(0);
169 switch (m) {
170 case 1: {
171 inv(0,0) = 1.0/mat(0,0);
172 break;
173 }
174 case 2: {
175 inv(0,0) = mat(1,1)/val;
176 inv(1,1) = mat(0,0)/val;
177
178 inv(1,0) = - mat(1,0)/val;
179 inv(0,1) = - mat(0,1)/val;
180 break;
181 }
182 case 3: {
183 {
184 const auto val0 = mat(1,1)*mat(2,2) - mat(2,1)*mat(1,2);
185 const auto val1 = - mat(1,0)*mat(2,2) + mat(2,0)*mat(1,2);
186 const auto val2 = mat(1,0)*mat(2,1) - mat(2,0)*mat(1,1);
187
188 inv(0,0) = val0/val;
189 inv(1,0) = val1/val;
190 inv(2,0) = val2/val;
191 }
192 {
193 const auto val0 = mat(2,1)*mat(0,2) - mat(0,1)*mat(2,2);
194 const auto val1 = mat(0,0)*mat(2,2) - mat(2,0)*mat(0,2);
195 const auto val2 = - mat(0,0)*mat(2,1) + mat(2,0)*mat(0,1);
196
197 inv(0,1) = val0/val;
198 inv(1,1) = val1/val;
199 inv(2,1) = val2/val;
200 }
201 {
202 const auto val0 = mat(0,1)*mat(1,2) - mat(1,1)*mat(0,2);
203 const auto val1 = - mat(0,0)*mat(1,2) + mat(1,0)*mat(0,2);
204 const auto val2 = mat(0,0)*mat(1,1) - mat(1,0)*mat(0,1);
205
206 inv(0,2) = val0/val;
207 inv(1,2) = val1/val;
208 inv(2,2) = val2/val;
209 }
210 break;
211 }
212 }
213 }
214
215 template<typename ScalarType,
216 typename xViewType,
217 typename yViewType,
218 typename zViewType>
219 KOKKOS_INLINE_FUNCTION
220 static void
221 z_is_axby(const zViewType &z,
222 const ScalarType alpha,
223 const xViewType &x,
224 const ScalarType beta,
225 const yViewType &y) {
226 //y = beta*y + alpha*x
227 const ordinal_type
228 m = z.extent(0);
229
230 for (ordinal_type i=0;i<m;++i)
231 z(i) = alpha*x(i) + beta*y(i);
232 }
233
234 template<typename AViewType>
235 KOKKOS_INLINE_FUNCTION
236 static double
237 norm(const AViewType &A, const ENorm normType) {
238 typedef typename AViewType::non_const_value_type value_type;
239 const ordinal_type m = A.extent(0), n = A.extent(1);
240 double r_val = 0;
241 switch(normType) {
242 case NORM_TWO:{
243 for (ordinal_type i=0;i<m;++i)
244 for (ordinal_type j=0;j<n;++j)
245 r_val += A.access(i,j)*A.access(i,j);
246 r_val = sqrt(r_val);
247 break;
248 }
249 case NORM_INF:{
250 for (ordinal_type i=0;i<m;++i)
251 for (ordinal_type j=0;j<n;++j) {
252 const value_type current = Util<value_type>::abs(A.access(i,j));
253 r_val = (r_val < current ? current : r_val);
254 }
255 break;
256 }
257 case NORM_ONE:{
258 for (ordinal_type i=0;i<m;++i)
259 for (ordinal_type j=0;j<n;++j)
260 r_val += Util<value_type>::abs(A.access(i,j));
261 break;
262 }
263 default: {
264 Kokkos::abort("norm type is not supported");
265 break;
266 }
267 }
268 return r_val;
269 }
270
271 template<typename dstViewType,
272 typename srcViewType>
273 KOKKOS_INLINE_FUNCTION
274 static void
275 copy(const dstViewType &dst, const srcViewType &src) {
276 if (dst.data() != src.data()) {
277 const ordinal_type m = dst.extent(0), n = dst.extent(1);
278 for (ordinal_type i=0;i<m;++i)
279 for (ordinal_type j=0;j<n;++j)
280 dst.access(i,j) = src.access(i,j);
281 }
282 }
283
284 // y = Ax
285 template<typename yViewType,
286 typename AViewType,
287 typename xViewType>
288 KOKKOS_FORCEINLINE_FUNCTION
289 static void
290 matvec_trans_product_d2( const yViewType &y,
291 const AViewType &A,
292 const xViewType &x ) {
293 y(0) = A(0,0)*x(0) + A(1,0)*x(1);
294 y(1) = A(0,1)*x(0) + A(1,1)*x(1);
295 }
296
297 template<typename yViewType,
298 typename AViewType,
299 typename xViewType>
300 KOKKOS_FORCEINLINE_FUNCTION
301 static void
302 matvec_trans_product_d3( const yViewType &y,
303 const AViewType &A,
304 const xViewType &x ) {
305 y(0) = A(0,0)*x(0) + A(1,0)*x(1) + A(2,0)*x(2);
306 y(1) = A(0,1)*x(0) + A(1,1)*x(1) + A(2,1)*x(2);
307 y(2) = A(0,2)*x(0) + A(1,2)*x(1) + A(2,2)*x(2);
308 }
309
310 // y = Ax
311 template<typename yViewType,
312 typename AViewType,
313 typename xViewType>
314 KOKKOS_FORCEINLINE_FUNCTION
315 static void
316 matvec_product_d2( const yViewType &y,
317 const AViewType &A,
318 const xViewType &x ) {
319 y(0) = A(0,0)*x(0) + A(0,1)*x(1);
320 y(1) = A(1,0)*x(0) + A(1,1)*x(1);
321 }
322
323 template<typename yViewType,
324 typename AViewType,
325 typename xViewType>
326 KOKKOS_FORCEINLINE_FUNCTION
327 static void
328 matvec_product_d3( const yViewType &y,
329 const AViewType &A,
330 const xViewType &x ) {
331 y(0) = A(0,0)*x(0) + A(0,1)*x(1) + A(0,2)*x(2);
332 y(1) = A(1,0)*x(0) + A(1,1)*x(1) + A(1,2)*x(2);
333 y(2) = A(2,0)*x(0) + A(2,1)*x(1) + A(2,2)*x(2);
334 }
335
336 template<typename yViewType,
337 typename AViewType,
338 typename xViewType>
339 KOKKOS_FORCEINLINE_FUNCTION
340 static void
341 matvec_product( const yViewType &y,
342 const AViewType &A,
343 const xViewType &x ) {
344 switch (y.extent(0)) {
345 case 2: matvec_product_d2(y, A, x); break;
346 case 3: matvec_product_d3(y, A, x); break;
347 default: {
348 INTREPID2_TEST_FOR_ABORT(true, "matvec only support dimension 2 and 3 (consider to use gemv interface).");
349 break;
350 }
351 }
352 }
353
354 template<typename zViewType,
355 typename xViewType,
356 typename yViewType>
357 KOKKOS_FORCEINLINE_FUNCTION
358 static void
359 vector_product_d2( const zViewType &z,
360 const xViewType &x,
361 const yViewType &y ) {
362 z(0) = x(0)*y(1) - x(1)*y(0);
363 }
364
365 template<typename zViewType,
366 typename xViewType,
367 typename yViewType>
368 KOKKOS_FORCEINLINE_FUNCTION
369 static void
370 vector_product_d3( const zViewType &z,
371 const xViewType &x,
372 const yViewType &y ) {
373 z(0) = x(1)*y(2) - x(2)*y(1);
374 z(1) = x(2)*y(0) - x(0)*y(2);
375 z(2) = x(0)*y(1) - x(1)*y(0);
376 }
377
378
379 };
380
381
382
383 template<typename xViewType,
384 typename yViewType>
385 KOKKOS_FORCEINLINE_FUNCTION
386 static typename xViewType::value_type
387 dot( const xViewType x,
388 const yViewType y ) {
389 typename xViewType::value_type r_val(0);
390 ordinal_type i = 0, iend = x.extent(0);
391 for (;i<iend;i+=4)
392 r_val += ( x(i )*y(i ) +
393 x(i+1)*y(i+1) +
394 x(i+2)*y(i+2) +
395 x(i+3)*y(i+3) );
396 for (;i<iend;++i)
397 r_val += x(i)*y(i);
398
399 return r_val;
400 }
401
402 template<typename xViewType,
403 typename yViewType>
404 KOKKOS_FORCEINLINE_FUNCTION
405 static typename xViewType::value_type
406 dot_d2( const xViewType x,
407 const yViewType y ) {
408 return ( x(0)*y(0) + x(1)*y(1) );
409 }
410
411 template<typename xViewType,
412 typename yViewType>
413 KOKKOS_FORCEINLINE_FUNCTION
414 static typename xViewType::value_type
415 dot_d3( const xViewType x,
416 const yViewType y ) {
417 return ( x(0)*y(0) + x(1)*y(1) + x(2)*y(2) );
418 }
419
420 template<typename AViewType,
421 typename alphaScalarType>
422 KOKKOS_FORCEINLINE_FUNCTION
423 static void
424 scale_mat( AViewType &A,
425 const alphaScalarType alpha ) {
426 const ordinal_type
427 iend = A.extent(0),
428 jend = A.extent(1);
429
430 for (ordinal_type i=0;i<iend;++i)
431 for (ordinal_type j=0;j<jend;++j)
432 A.access(i,j) *= alpha;
433 }
434
435 template<typename AViewType,
436 typename alphaScalarType>
437 KOKKOS_FORCEINLINE_FUNCTION
438 static void
439 inv_scale_mat( AViewType &A,
440 const alphaScalarType alpha ) {
441 const ordinal_type
442 iend = A.extent(0),
443 jend = A.extent(1);
444
445 for (ordinal_type i=0;i<iend;++i)
446 for (ordinal_type j=0;j<jend;++j)
447 A.access(i,j) /= alpha;
448 }
449
450 template<typename AViewType,
451 typename alphaScalarType,
452 typename BViewType>
453 KOKKOS_FORCEINLINE_FUNCTION
454 static void
455 scalar_mult_mat( AViewType &A,
456 const alphaScalarType alpha,
457 const BViewType &B ) {
458 const ordinal_type
459 iend = A.extent(0),
460 jend = A.extent(1);
461
462 for (ordinal_type i=0;i<iend;++i)
463 for (ordinal_type j=0;j<jend;++j)
464 A.access(i,j) = alpha*B.access(i,j);
465 }
466
467 template<typename AViewType,
468 typename alphaScalarType,
469 typename BViewType>
470 KOKKOS_FORCEINLINE_FUNCTION
471 static void
472 inv_scalar_mult_mat( AViewType &A,
473 const alphaScalarType alpha,
474 const BViewType &B ) {
475 const ordinal_type
476 iend = A.extent(0),
477 jend = A.extent(1);
478
479 for (ordinal_type i=0;i<iend;++i)
480 for (ordinal_type j=0;j<jend;++j)
481 A.access(i,j) = B.access(i,j)/alpha;
482 }
483
484 }
485}
486
487#endif
Contains definitions of custom data types in Intrepid2.
Header function for Intrepid2::Util class and other utility functions.
small utility functions