Amesos2 - Direct Sparse Solver Interfaces Version of the Day
Amesos2_cuSOLVER_FunctionMap.hpp
1// @HEADER
2// *****************************************************************************
3// Amesos2: Templated Direct Sparse Solver Package
4//
5// Copyright 2011 NTESS and the Amesos2 contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9
10#ifndef AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
11#define AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
12
14#include "Amesos2_cuSOLVER_TypeMap.hpp"
15
16#include <cuda.h>
17#include <cusolverSp.h>
18#include <cusolverDn.h>
19#include <cusparse.h>
20#include <cusolverSp_LOWLEVEL_PREVIEW.h>
21
22#ifdef HAVE_TEUCHOS_COMPLEX
23#include <cuComplex.h>
24#endif
25
26namespace Amesos2 {
27
28 template <>
29 struct FunctionMap<cuSOLVER,double>
30 {
31 static cusolverStatus_t bufferInfo(
32 cusolverSpHandle_t handle,
33 int size,
34 int nnz,
35 cusparseMatDescr_t & desc,
36 const double * values,
37 const int * rowPtr,
38 const int * colIdx,
39 csrcholInfo_t & chol_info,
40 size_t * internalDataInBytes,
41 size_t * workspaceInBytes)
42 {
43 cusolverStatus_t status =
44 cusolverSpDcsrcholBufferInfo(handle, size, nnz, desc, values,
45 rowPtr, colIdx, chol_info, internalDataInBytes, workspaceInBytes);
46 return status;
47 }
48
49 static cusolverStatus_t numeric(
50 cusolverSpHandle_t handle,
51 int size,
52 int nnz,
53 cusparseMatDescr_t & desc,
54 const double * values,
55 const int * rowPtr,
56 const int * colIdx,
57 csrcholInfo_t & chol_info,
58 void * buffer)
59 {
60 cusolverStatus_t status = cusolverSpDcsrcholFactor(
61 handle, size, nnz, desc, values, rowPtr, colIdx, chol_info, buffer);
62 cudaDeviceSynchronize();
63 return status;
64 }
65
66 static cusolverStatus_t solve(
67 cusolverSpHandle_t handle,
68 int size,
69 const double * b,
70 double * x,
71 csrcholInfo_t & chol_info,
72 void * buffer)
73 {
74 cusolverStatus_t status = cusolverSpDcsrcholSolve(
75 handle, size, b, x, chol_info, buffer);
76 cudaDeviceSynchronize();
77 return status;
78 }
79 };
80
81 template <>
82 struct FunctionMap<cuSOLVER,float>
83 {
84 static cusolverStatus_t bufferInfo(
85 cusolverSpHandle_t handle,
86 int size,
87 int nnz,
88 cusparseMatDescr_t & desc,
89 const float * values,
90 const int * rowPtr,
91 const int * colIdx,
92 csrcholInfo_t & chol_info,
93 size_t * internalDataInBytes,
94 size_t * workspaceInBytes)
95 {
96 cusolverStatus_t status =
97 cusolverSpScsrcholBufferInfo(handle, size, nnz, desc, values,
98 rowPtr, colIdx, chol_info, internalDataInBytes, workspaceInBytes);
99 return status;
100 }
101
102 static cusolverStatus_t numeric(
103 cusolverSpHandle_t handle,
104 int size,
105 int nnz,
106 cusparseMatDescr_t & desc,
107 const float * values,
108 const int * rowPtr,
109 const int * colIdx,
110 csrcholInfo_t & chol_info,
111 void * buffer)
112 {
113 cusolverStatus_t status = cusolverSpScsrcholFactor(
114 handle, size, nnz, desc, values, rowPtr, colIdx, chol_info, buffer);
115 cudaDeviceSynchronize();
116 return status;
117 }
118
119 static cusolverStatus_t solve(
120 cusolverSpHandle_t handle,
121 int size,
122 const float * b,
123 float * x,
124 csrcholInfo_t & chol_info,
125 void * buffer)
126 {
127 cusolverStatus_t status = cusolverSpScsrcholSolve(
128 handle, size, b, x, chol_info, buffer);
129 cudaDeviceSynchronize();
130 return status;
131 }
132 };
133
134#ifdef HAVE_TEUCHOS_COMPLEX
135 template <>
136 struct FunctionMap<cuSOLVER,Kokkos::complex<double>>
137 {
138 static cusolverStatus_t bufferInfo(
139 cusolverSpHandle_t handle,
140 int size,
141 int nnz,
142 cusparseMatDescr_t & desc,
143 const void * values,
144 const int * rowPtr,
145 const int * colIdx,
146 csrcholInfo_t & chol_info,
147 size_t * internalDataInBytes,
148 size_t * workspaceInBytes)
149 {
150 typedef cuDoubleComplex scalar_t;
151 const scalar_t * cu_values = reinterpret_cast<const scalar_t *>(values);
152 cusolverStatus_t status =
153 cusolverSpZcsrcholBufferInfo(handle, size, nnz, desc,
154 cu_values, rowPtr, colIdx, chol_info,
155 internalDataInBytes, workspaceInBytes);
156 return status;
157 }
158
159 static cusolverStatus_t numeric(
160 cusolverSpHandle_t handle,
161 int size,
162 int nnz,
163 cusparseMatDescr_t & desc,
164 const void * values,
165 const int * rowPtr,
166 const int * colIdx,
167 csrcholInfo_t & chol_info,
168 void * buffer)
169 {
170 typedef cuDoubleComplex scalar_t;
171 const scalar_t * cu_values =
172 reinterpret_cast<const scalar_t *>(values);
173 cusolverStatus_t status = cusolverSpZcsrcholFactor(
174 handle, size, nnz, desc, cu_values, rowPtr, colIdx, chol_info, buffer);
175 cudaDeviceSynchronize();
176 return status;
177 }
178
179 static cusolverStatus_t solve(
180 cusolverSpHandle_t handle,
181 int size,
182 const void * b,
183 void * x,
184 csrcholInfo_t & chol_info,
185 void * buffer)
186 {
187 typedef cuDoubleComplex scalar_t;
188 const scalar_t * cu_b = reinterpret_cast<const scalar_t *>(b);
189 scalar_t * cu_x = reinterpret_cast<scalar_t *>(x);
190 cusolverStatus_t status = cusolverSpZcsrcholSolve(
191 handle, size, cu_b, cu_x, chol_info, buffer);
192 cudaDeviceSynchronize();
193 return status;
194 }
195 };
196
197 template <>
198 struct FunctionMap<cuSOLVER,Kokkos::complex<float>>
199 {
200 static cusolverStatus_t bufferInfo(
201 cusolverSpHandle_t handle,
202 int size,
203 int nnz,
204 cusparseMatDescr_t & desc,
205 const void * values,
206 const int * rowPtr,
207 const int * colIdx,
208 csrcholInfo_t & chol_info,
209 size_t * internalDataInBytes,
210 size_t * workspaceInBytes)
211 {
212 typedef cuFloatComplex scalar_t;
213 const scalar_t * cu_values = reinterpret_cast<const scalar_t *>(values);
214 cusolverStatus_t status =
215 cusolverSpCcsrcholBufferInfo(handle, size, nnz, desc,
216 cu_values, rowPtr, colIdx, chol_info,
217 internalDataInBytes, workspaceInBytes);
218 return status;
219 }
220
221 static cusolverStatus_t numeric(
222 cusolverSpHandle_t handle,
223 int size,
224 int nnz,
225 cusparseMatDescr_t & desc,
226 const void * values,
227 const int * rowPtr,
228 const int * colIdx,
229 csrcholInfo_t & chol_info,
230 void * buffer)
231 {
232 typedef cuFloatComplex scalar_t;
233 const scalar_t * cu_values = reinterpret_cast<const scalar_t *>(values);
234 cusolverStatus_t status = cusolverSpCcsrcholFactor(
235 handle, size, nnz, desc, cu_values, rowPtr, colIdx, chol_info, buffer);
236 cudaDeviceSynchronize();
237 return status;
238 }
239
240 static cusolverStatus_t solve(
241 cusolverSpHandle_t handle,
242 int size,
243 const void * b,
244 void * x,
245 csrcholInfo_t & chol_info,
246 void * buffer)
247 {
248 typedef cuFloatComplex scalar_t;
249 const scalar_t * cu_b = reinterpret_cast<const scalar_t *>(b);
250 scalar_t * cu_x = reinterpret_cast<scalar_t *>(x);
251 cusolverStatus_t status = cusolverSpCcsrcholSolve(
252 handle, size, cu_b, cu_x, chol_info, buffer);
253 cudaDeviceSynchronize();
254 return status;
255 }
256 };
257#endif
258
259} // end namespace Amesos2
260
261#endif // AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
Declaration of Function mapping class for Amesos2.
const int size
Definition klu2_simple.cpp:50