MueLu Version of the Day
Loading...
Searching...
No Matches
MueLu_MatlabUtils.cpp
Go to the documentation of this file.
1// @HEADER
2// *****************************************************************************
3// MueLu: A package for multigrid based preconditioning
4//
5// Copyright 2012 NTESS and the MueLu contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9
11
12#if !defined(HAVE_MUELU_MATLAB) || !defined(HAVE_MUELU_TPETRA)
13#error "Muemex types require MATLAB and Tpetra."
14#else
15
/* MATLAB releases older than R2006b (MX_API_VER below 0x07030000) are assumed
   not to provide mwIndex, so fall back to plain int for them. */
#if !(defined(MX_API_VER) && MX_API_VER >= 0x07030000)
using mwIndex = int;
#endif
21
22using namespace std;
23using namespace Teuchos;
24
25namespace MueLu {
26
27/* Explicit instantiation of MuemexData variants */
28template class MuemexData<RCP<Xpetra::MultiVector<double, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >;
29template class MuemexData<RCP<Xpetra::MultiVector<complex_t, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >;
30template class MuemexData<RCP<Xpetra::Matrix<double, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >;
31template class MuemexData<RCP<Xpetra::Matrix<complex_t, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >;
32template class MuemexData<RCP<MAggregates> >;
33template class MuemexData<RCP<MAmalInfo> >;
34template class MuemexData<int>;
35template class MuemexData<bool>;
36template class MuemexData<complex_t>;
37template class MuemexData<string>;
38template class MuemexData<double>;
39template class MuemexData<RCP<Tpetra::CrsMatrix<double, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >;
40template class MuemexData<RCP<Tpetra::CrsMatrix<complex_t, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >;
41template class MuemexData<RCP<Tpetra::MultiVector<double, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >;
42template class MuemexData<RCP<Tpetra::MultiVector<complex_t, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >;
43template class MuemexData<RCP<Xpetra::Vector<mm_LocalOrd, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >;
44
45// Flag set to true if MATLAB's CSC matrix index type is not int (usually false)
46bool rewrap_ints = sizeof(int) != sizeof(mwIndex);
47
48int* mwIndex_to_int(int N, mwIndex* mwi_array) {
49 // int* rv = (int*) malloc(N * sizeof(int));
50 int* rv = new int[N]; // not really better but may avoid confusion for valgrind
51 for (int i = 0; i < N; i++)
52 rv[i] = (int)mwi_array[i];
53 return rv;
54}
55
56/* ******************************* */
57/* Specializations */
58/* ******************************* */
59
60template <>
61mxArray* createMatlabSparse<double>(int numRows, int numCols, int nnz) {
62 return mxCreateSparse(numRows, numCols, nnz, mxREAL);
63}
64
65template <>
66mxArray* createMatlabSparse<complex_t>(int numRows, int numCols, int nnz) {
67 return mxCreateSparse(numRows, numCols, nnz, mxCOMPLEX);
68}
69
70template <>
71void fillMatlabArray<double>(double* array, const mxArray* mxa, int n) {
72 memcpy(mxGetPr(mxa), array, n * sizeof(double));
73}
74
75template <>
76void fillMatlabArray<complex_t>(complex_t* array, const mxArray* mxa, int n) {
77 double* pr = mxGetPr(mxa);
78 double* pi = mxGetPi(mxa);
79 for (int i = 0; i < n; i++) {
80 pr[i] = std::real<double>(array[i]);
81 pi[i] = std::imag<double>(array[i]);
82 }
83}
84
85/******************************/
86/* Callback Functions */
87/******************************/
88
89void callMatlabNoArgs(std::string function) {
90 int result = mexEvalString(function.c_str());
91 if (result != 0)
92 mexPrintf("An error occurred while running a MATLAB command.");
93}
94
95std::vector<RCP<MuemexArg> > callMatlab(std::string function, int numOutputs, std::vector<RCP<MuemexArg> > args) {
96 using Teuchos::rcp_static_cast;
97 mxArray** matlabArgs = new mxArray*[args.size()];
98 mxArray** matlabOutput = new mxArray*[numOutputs];
99 std::vector<RCP<MuemexArg> > output;
100
101 for (int i = 0; i < int(args.size()); i++) {
102 try {
103 switch (args[i]->type) {
104 case BOOL:
105 matlabArgs[i] = rcp_static_cast<MuemexData<bool>, MuemexArg>(args[i])->convertToMatlab();
106 break;
107 case INT:
108 matlabArgs[i] = rcp_static_cast<MuemexData<int>, MuemexArg>(args[i])->convertToMatlab();
109 break;
110 case DOUBLE:
111 matlabArgs[i] = rcp_static_cast<MuemexData<double>, MuemexArg>(args[i])->convertToMatlab();
112 break;
113 case STRING:
114 matlabArgs[i] = rcp_static_cast<MuemexData<std::string>, MuemexArg>(args[i])->convertToMatlab();
115 break;
116 case COMPLEX:
117 matlabArgs[i] = rcp_static_cast<MuemexData<complex_t>, MuemexArg>(args[i])->convertToMatlab();
118 break;
119 case XPETRA_MAP:
120 matlabArgs[i] = rcp_static_cast<MuemexData<RCP<Xpetra_map> >, MuemexArg>(args[i])->convertToMatlab();
121 break;
123 matlabArgs[i] = rcp_static_cast<MuemexData<RCP<Xpetra_ordinal_vector> >, MuemexArg>(args[i])->convertToMatlab();
124 break;
126 matlabArgs[i] = rcp_static_cast<MuemexData<RCP<Tpetra::MultiVector<double, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >, MuemexArg>(args[i])->convertToMatlab();
127 break;
129 matlabArgs[i] = rcp_static_cast<MuemexData<RCP<Tpetra::MultiVector<complex_t, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >, MuemexArg>(args[i])->convertToMatlab();
130 break;
132 matlabArgs[i] = rcp_static_cast<MuemexData<RCP<Tpetra::CrsMatrix<double, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >, MuemexArg>(args[i])->convertToMatlab();
133 break;
135 matlabArgs[i] = rcp_static_cast<MuemexData<RCP<Tpetra::CrsMatrix<complex_t, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >, MuemexArg>(args[i])->convertToMatlab();
136 break;
138 matlabArgs[i] = rcp_static_cast<MuemexData<RCP<Xpetra_Matrix_double> >, MuemexArg>(args[i])->convertToMatlab();
139 break;
141 matlabArgs[i] = rcp_static_cast<MuemexData<RCP<Xpetra_Matrix_complex> >, MuemexArg>(args[i])->convertToMatlab();
142 break;
144 matlabArgs[i] = rcp_static_cast<MuemexData<RCP<Xpetra_MultiVector_double> >, MuemexArg>(args[i])->convertToMatlab();
145 break;
147 matlabArgs[i] = rcp_static_cast<MuemexData<RCP<Xpetra_MultiVector_complex> >, MuemexArg>(args[i])->convertToMatlab();
148 break;
149 case AGGREGATES:
150 matlabArgs[i] = rcp_static_cast<MuemexData<RCP<MAggregates> >, MuemexArg>(args[i])->convertToMatlab();
151 break;
153 matlabArgs[i] = rcp_static_cast<MuemexData<RCP<MAmalInfo> >, MuemexArg>(args[i])->convertToMatlab();
154 break;
155 case GRAPH:
156 matlabArgs[i] = rcp_static_cast<MuemexData<RCP<MGraph> >, MuemexArg>(args[i])->convertToMatlab();
157#ifdef HAVE_MUELU_INTREPID2
158 case FIELDCONTAINER_ORDINAL:
159 matlabArgs[i] = rcp_static_cast<MuemexData<RCP<FieldContainer_ordinal> >, MuemexArg>(args[i])->convertToMatlab();
160 break;
161#endif
162 }
163 } catch (std::exception& e) {
164 mexPrintf("An error occurred while converting arg #%d to MATLAB:\n", i);
165 std::cout << e.what() << std::endl;
166 mexPrintf("Passing 0 instead.\n");
167 matlabArgs[i] = mxCreateDoubleScalar(0);
168 }
169 }
170 // now matlabArgs is populated with MATLAB data types
171 int result = mexCallMATLAB(numOutputs, matlabOutput, args.size(), matlabArgs, function.c_str());
172 if (result != 0)
173 mexPrintf("Matlab encountered an error while running command through muemexCallbacks.\n");
174 // now, if all went well, matlabOutput contains all the output to return to user
175 for (int i = 0; i < numOutputs; i++) {
176 try {
177 output.push_back(convertMatlabVar(matlabOutput[i]));
178 } catch (std::exception& e) {
179 mexPrintf("An error occurred while converting output #%d from MATLAB:\n", i);
180 std::cout << e.what() << std::endl;
181 }
182 }
183 delete[] matlabOutput;
184 delete[] matlabArgs;
185 return output;
186}
187
188/******************************/
189/* More utility functions */
190/******************************/
191
192template <>
193mxArray* createMatlabMultiVector<double>(int numRows, int numCols) {
194 return mxCreateDoubleMatrix(numRows, numCols, mxREAL);
195}
196
197template <>
198mxArray* createMatlabMultiVector<complex_t>(int numRows, int numCols) {
199 return mxCreateDoubleMatrix(numRows, numCols, mxCOMPLEX);
200}
201
202mxArray* saveAmalInfo(RCP<MAmalInfo>& amalInfo) {
203 throw runtime_error("AmalgamationInfo not supported in MueMex yet.");
204 return mxCreateDoubleScalar(0);
205}
206
208 bool isValidAggregates = true;
209 if (!mxIsStruct(mxa))
210 return false;
211 int numFields = mxGetNumberOfFields(mxa); // check that struct has correct # of fields
212 if (numFields != 5)
213 isValidAggregates = false;
214 if (isValidAggregates) {
215 const char* mem1 = mxGetFieldNameByNumber(mxa, 0);
216 if (mem1 == NULL || strcmp(mem1, "nVertices") != 0)
217 isValidAggregates = false;
218 const char* mem2 = mxGetFieldNameByNumber(mxa, 1);
219 if (mem2 == NULL || strcmp(mem2, "nAggregates") != 0)
220 isValidAggregates = false;
221 const char* mem3 = mxGetFieldNameByNumber(mxa, 2);
222 if (mem3 == NULL || strcmp(mem3, "vertexToAggID") != 0)
223 isValidAggregates = false;
224 const char* mem4 = mxGetFieldNameByNumber(mxa, 3);
225 if (mem3 == NULL || strcmp(mem4, "rootNodes") != 0)
226 isValidAggregates = false;
227 const char* mem5 = mxGetFieldNameByNumber(mxa, 4);
228 if (mem4 == NULL || strcmp(mem5, "aggSizes") != 0)
229 isValidAggregates = false;
230 }
231 return isValidAggregates;
232}
233
234bool isValidMatlabGraph(const mxArray* mxa) {
235 bool isValidGraph = true;
236 if (!mxIsStruct(mxa))
237 return false;
238 int numFields = mxGetNumberOfFields(mxa); // check that struct has correct # of fields
239 if (numFields != 2)
240 isValidGraph = false;
241 if (isValidGraph) {
242 const char* mem1 = mxGetFieldNameByNumber(mxa, 0);
243 if (mem1 == NULL || strcmp(mem1, "edges") != 0)
244 isValidGraph = false;
245 const char* mem2 = mxGetFieldNameByNumber(mxa, 1);
246 if (mem2 == NULL || strcmp(mem2, "boundaryNodes") != 0)
247 isValidGraph = false;
248 }
249 return isValidGraph;
250}
251
/* Split a comma-separated list into its entries.
   Leading and trailing whitespace (space, \t, \n, \r, \f, \v) is trimmed from each
   entry. Entries that are empty or whitespace-only are skipped, so
   "A, P,,R, " yields {"A", "P", "R"}.
   Works directly on std::string: no strtok (which keeps hidden global state and is
   not reentrant) and no manual buffer allocation. */
std::vector<std::string> tokenizeList(const std::string& params) {
  const char* const whitespace = " \t\n\r\f\v";
  std::vector<std::string> rlist;
  std::string::size_type start = 0;
  while (start <= params.size()) {
    std::string::size_type comma = params.find(',', start);
    if (comma == std::string::npos)
      comma = params.size();
    // The current entry is [start, comma); locate its first and last non-blank chars.
    const std::string::size_type first = params.find_first_not_of(whitespace, start);
    if (first != std::string::npos && first < comma) {
      // comma > first >= start here, so comma - 1 cannot underflow.
      const std::string::size_type last = params.find_last_not_of(whitespace, comma - 1);
      rlist.push_back(params.substr(first, last - first + 1));
    }
    start = comma + 1;
  }
  return rlist;
}
275
276Teuchos::RCP<Teuchos::ParameterList> getInputParamList() {
277 using namespace Teuchos;
278 RCP<ParameterList> validParamList = rcp(new ParameterList());
279 validParamList->set<RCP<const FactoryBase> >("A", Teuchos::null, "Factory for the matrix A.");
280 validParamList->set<RCP<const FactoryBase> >("P", Teuchos::null, "Factory for the prolongator.");
281 validParamList->set<RCP<const FactoryBase> >("R", Teuchos::null, "Factory for the restrictor.");
282 validParamList->set<RCP<const FactoryBase> >("Ptent", Teuchos::null, "Factory for the tentative (unsmoothed) prolongator.");
283 validParamList->set<RCP<const FactoryBase> >("Coordinates", Teuchos::null, "Factory for the node coordinates.");
284 validParamList->set<RCP<const FactoryBase> >("Nullspace", Teuchos::null, "Factory for the nullspace.");
285 validParamList->set<RCP<const FactoryBase> >("Aggregates", Teuchos::null, "Factory for the aggregates.");
286 validParamList->set<RCP<const FactoryBase> >("UnamalgamationInfo", Teuchos::null, "Factory for amalgamation.");
287#ifdef HAVE_MUELU_INTREPID2
288 validParamList->set<RCP<const FactoryBase> >("pcoarsen: element to node map", Teuchos::null, "Generating factory of the element to node map");
289#endif
290 return validParamList;
291}
292
293Teuchos::RCP<MuemexArg> convertMatlabVar(const mxArray* mxa) {
294 switch (mxGetClassID(mxa)) {
295 case mxCHAR_CLASS:
296 // string
297 return rcp_implicit_cast<MuemexArg>(rcp(new MuemexData<std::string>(mxa)));
298 break;
299 case mxLOGICAL_CLASS:
300 // boolean
301 return rcp_implicit_cast<MuemexArg>(rcp(new MuemexData<bool>(mxa)));
302 break;
303 case mxINT32_CLASS:
304 if (mxGetM(mxa) == 1 && mxGetN(mxa) == 1)
305 // individual integer
306 return rcp_implicit_cast<MuemexArg>(rcp(new MuemexData<int>(mxa)));
307 else if (mxGetM(mxa) != 1 || mxGetN(mxa) != 1)
308 // ordinal vector
309 return rcp_implicit_cast<MuemexArg>(rcp(new MuemexData<RCP<Xpetra_ordinal_vector> >(mxa)));
310 else
311 throw std::runtime_error("Error: Don't know what to do with integer array.\n");
312 break;
313 case mxDOUBLE_CLASS:
314 if (mxGetM(mxa) == 1 && mxGetN(mxa) == 1) {
315 if (mxIsComplex(mxa))
316 // single double (scalar, real)
317 return rcp_implicit_cast<MuemexArg>(rcp(new MuemexData<complex_t>(mxa)));
318 else
319 // single complex scalar
320 return rcp_implicit_cast<MuemexArg>(rcp(new MuemexData<double>(mxa)));
321 } else if (mxIsSparse(mxa)) // use a CRS matrix
322 {
323 // Default to Tpetra matrix for this
324 if (mxIsComplex(mxa))
325 // complex matrix
326 return rcp_implicit_cast<MuemexArg>(rcp(new MuemexData<RCP<Xpetra_Matrix_complex> >(mxa)));
327 else
328 // real-valued matrix
329 return rcp_implicit_cast<MuemexArg>(rcp(new MuemexData<RCP<Xpetra_Matrix_double> >(mxa)));
330 } else {
331 // Default to Xpetra multivector for this case
332 if (mxIsComplex(mxa))
333 return rcp_implicit_cast<MuemexArg>(rcp(new MuemexData<RCP<Xpetra::MultiVector<complex_t, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >(mxa)));
334 else
335 return rcp_implicit_cast<MuemexArg>(rcp(new MuemexData<RCP<Xpetra::MultiVector<double, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >(mxa)));
336 }
337 break;
338 case mxSTRUCT_CLASS: {
339 // the only thing that should get here currently is an Aggregates struct or Graph struct
340 // verify that it has the correct fields with the correct types
341 // also assume that aggregates data will not be stored in an array of more than 1 element.
342 if (isValidMatlabAggregates(mxa)) {
343 return rcp_implicit_cast<MuemexArg>(rcp(new MuemexData<RCP<MAggregates> >(mxa)));
344 } else if (isValidMatlabGraph(mxa)) {
345 return rcp_implicit_cast<MuemexArg>(rcp(new MuemexData<RCP<MGraph> >(mxa)));
346 } else {
347 throw runtime_error("Invalid aggregates or graph struct passed in from MATLAB.");
348 return Teuchos::null;
349 }
350 break;
351 }
352 default:
353 throw std::runtime_error("MATLAB returned an unsupported type as a function output.\n");
354 return Teuchos::null;
355 }
356}
357
358/******************************/
359/* Explicit Instantiations */
360/******************************/
361
362template bool loadDataFromMatlab<bool>(const mxArray* mxa);
363template int loadDataFromMatlab<int>(const mxArray* mxa);
364template double loadDataFromMatlab<double>(const mxArray* mxa);
366template string loadDataFromMatlab<string>(const mxArray* mxa);
367template RCP<Xpetra_ordinal_vector> loadDataFromMatlab<RCP<Xpetra_ordinal_vector> >(const mxArray* mxa);
368template RCP<Tpetra_MultiVector_double> loadDataFromMatlab<RCP<Tpetra_MultiVector_double> >(const mxArray* mxa);
369template RCP<Tpetra_MultiVector_complex> loadDataFromMatlab<RCP<Tpetra_MultiVector_complex> >(const mxArray* mxa);
370template RCP<Tpetra_CrsMatrix_double> loadDataFromMatlab<RCP<Tpetra_CrsMatrix_double> >(const mxArray* mxa);
371template RCP<Tpetra_CrsMatrix_complex> loadDataFromMatlab<RCP<Tpetra_CrsMatrix_complex> >(const mxArray* mxa);
372template RCP<Xpetra_Matrix_double> loadDataFromMatlab<RCP<Xpetra_Matrix_double> >(const mxArray* mxa);
373template RCP<Xpetra_Matrix_complex> loadDataFromMatlab<RCP<Xpetra_Matrix_complex> >(const mxArray* mxa);
374template RCP<Xpetra_MultiVector_double> loadDataFromMatlab<RCP<Xpetra_MultiVector_double> >(const mxArray* mxa);
375template RCP<Xpetra_MultiVector_complex> loadDataFromMatlab<RCP<Xpetra_MultiVector_complex> >(const mxArray* mxa);
376template RCP<MAggregates> loadDataFromMatlab<RCP<MAggregates> >(const mxArray* mxa);
377template RCP<MAmalInfo> loadDataFromMatlab<RCP<MAmalInfo> >(const mxArray* mxa);
378
379template mxArray* saveDataToMatlab(bool& data);
380template mxArray* saveDataToMatlab(int& data);
381template mxArray* saveDataToMatlab(double& data);
382template mxArray* saveDataToMatlab(complex_t& data);
383template mxArray* saveDataToMatlab(string& data);
384template mxArray* saveDataToMatlab(RCP<Xpetra_ordinal_vector>& data);
385template mxArray* saveDataToMatlab(RCP<Tpetra_MultiVector_double>& data);
386template mxArray* saveDataToMatlab(RCP<Tpetra_MultiVector_complex>& data);
387template mxArray* saveDataToMatlab(RCP<Tpetra_CrsMatrix_double>& data);
388template mxArray* saveDataToMatlab(RCP<Tpetra_CrsMatrix_complex>& data);
389template mxArray* saveDataToMatlab(RCP<Xpetra_Matrix_double>& data);
390template mxArray* saveDataToMatlab(RCP<Xpetra_Matrix_complex>& data);
391template mxArray* saveDataToMatlab(RCP<Xpetra_MultiVector_double>& data);
392template mxArray* saveDataToMatlab(RCP<Xpetra_MultiVector_complex>& data);
393template mxArray* saveDataToMatlab(RCP<MAggregates>& data);
394template mxArray* saveDataToMatlab(RCP<MAmalInfo>& data);
395
396template vector<RCP<MuemexArg> > processNeeds<double>(const Factory* factory, string& needsParam, Level& lvl);
397template vector<RCP<MuemexArg> > processNeeds<complex_t>(const Factory* factory, string& needsParam, Level& lvl);
398template void processProvides<double>(vector<RCP<MuemexArg> >& mexOutput, const Factory* factory, string& providesParam, Level& lvl);
399template void processProvides<complex_t>(vector<RCP<MuemexArg> >& mexOutput, const Factory* factory, string& providesParam, Level& lvl);
400
401} // namespace MueLu
402#endif // HAVE_MUELU_MATLAB
int mwIndex
struct mxArray_tag mxArray
Class that holds all level-specific information.
Namespace for MueLu classes and methods.
template int loadDataFromMatlab< int >(const mxArray *mxa)
bool isValidMatlabGraph(const mxArray *mxa)
Teuchos::RCP< Teuchos::ParameterList > getInputParamList()
Teuchos::RCP< MuemexArg > convertMatlabVar(const mxArray *mxa)
template double loadDataFromMatlab< double >(const mxArray *mxa)
mxArray * createMatlabSparse< double >(int numRows, int numCols, int nnz)
template bool loadDataFromMatlab< bool >(const mxArray *mxa)
template vector< RCP< MuemexArg > > processNeeds< double >(const Factory *factory, string &needsParam, Level &lvl)
bool isValidMatlabAggregates(const mxArray *mxa)
std::vector< RCP< MuemexArg > > callMatlab(std::string function, int numOutputs, std::vector< RCP< MuemexArg > > args)
template string loadDataFromMatlab< string >(const mxArray *mxa)
template complex_t loadDataFromMatlab< complex_t >(const mxArray *mxa)
int * mwIndex_to_int(int N, mwIndex *mwi_array)
void fillMatlabArray< double >(double *array, const mxArray *mxa, int n)
template void processProvides< double >(vector< RCP< MuemexArg > > &mexOutput, const Factory *factory, string &providesParam, Level &lvl)
mxArray * saveAmalInfo(RCP< MAmalInfo > &amalInfo)
template void processProvides< complex_t >(vector< RCP< MuemexArg > > &mexOutput, const Factory *factory, string &providesParam, Level &lvl)
void fillMatlabArray< complex_t >(complex_t *array, const mxArray *mxa, int n)
mxArray * createMatlabMultiVector< complex_t >(int numRows, int numCols)
std::complex< double > complex_t
std::vector< std::string > tokenizeList(const std::string &params)
mxArray * createMatlabSparse< complex_t >(int numRows, int numCols, int nnz)
void callMatlabNoArgs(std::string function)
mxArray * createMatlabMultiVector< double >(int numRows, int numCols)
template vector< RCP< MuemexArg > > processNeeds< complex_t >(const Factory *factory, string &needsParam, Level &lvl)
template mxArray * saveDataToMatlab(bool &data)