Tpetra parallel linear algebra Version of the Day
Loading...
Searching...
No Matches
Tpetra_TsqrAdaptor.hpp
Go to the documentation of this file.
1// @HEADER
2// *****************************************************************************
3// Tpetra: Templated Linear Algebra Services Package
4//
5// Copyright 2008 NTESS and the Tpetra contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9
10#ifndef TPETRA_TSQRADAPTOR_HPP
11#define TPETRA_TSQRADAPTOR_HPP
12
16
17#include "Tpetra_ConfigDefs.hpp"
18
19#ifdef HAVE_TPETRA_TSQR
20#include "Tsqr_NodeTsqrFactory.hpp" // create intranode TSQR object
21#include "Tsqr.hpp" // full (internode + intranode) TSQR
22#include "Tsqr_DistTsqr.hpp" // internode TSQR
23// Subclass of TSQR::MessengerBase, implemented using Teuchos
24// communicator template helper functions
25#include "Tsqr_TeuchosMessenger.hpp"
26#include "Tpetra_MultiVector.hpp"
27#include "Teuchos_ParameterListAcceptorDefaultBase.hpp"
28#include <stdexcept>
29
30namespace Tpetra {
31
53template <class MV>
54class TsqrAdaptor : public Teuchos::ParameterListAcceptorDefaultBase {
55 public:
56 using scalar_type = typename MV::scalar_type;
57 using ordinal_type = typename MV::local_ordinal_type;
58 using dense_matrix_type =
59 Teuchos::SerialDenseMatrix<ordinal_type, scalar_type>;
60 using magnitude_type =
61 typename Teuchos::ScalarTraits<scalar_type>::magnitudeType;
62
63 private:
64 using node_tsqr_factory_type =
65 TSQR::NodeTsqrFactory<scalar_type, ordinal_type,
66 typename MV::device_type>;
67 using node_tsqr_type = TSQR::NodeTsqr<ordinal_type, scalar_type>;
68 using dist_tsqr_type = TSQR::DistTsqr<ordinal_type, scalar_type>;
69 using tsqr_type = TSQR::Tsqr<ordinal_type, scalar_type>;
70
71 TSQR::MatView<ordinal_type, scalar_type>
72 get_mat_view(MV& X) {
73 TEUCHOS_ASSERT(!tsqr_.is_null());
74 // FIXME (mfh 18 Oct 2010, 22 Dec 2019) Check Teuchos::Comm<int>
75 // object in Q to make sure it is the same communicator as the
76 // one we are using in our dist_tsqr_type implementation.
77
78 const ordinal_type lclNumRows(X.getLocalLength());
79 const ordinal_type numCols(X.getNumVectors());
80 scalar_type* X_ptr = nullptr;
81 // LAPACK and BLAS functions require "LDA" >= 1, even if the
82 // corresponding matrix dimension is zero.
83 ordinal_type X_stride = 1;
84 if (tsqr_->wants_device_memory()) {
85 auto X_view = X.getLocalViewDevice(Access::ReadWrite);
86 X_ptr = reinterpret_cast<scalar_type*>(X_view.data());
87 X_stride = static_cast<ordinal_type>(X_view.stride(1));
88 if (X_stride == 0) {
89 X_stride = ordinal_type(1); // see note above
90 }
91 } else {
92 auto X_view = X.getLocalViewHost(Access::ReadWrite);
93 X_ptr = reinterpret_cast<scalar_type*>(X_view.data());
94 X_stride = static_cast<ordinal_type>(X_view.stride(1));
95 if (X_stride == 0) {
96 X_stride = ordinal_type(1); // see note above
97 }
98 }
99 using mat_view_type = TSQR::MatView<ordinal_type, scalar_type>;
100 return mat_view_type(lclNumRows, numCols, X_ptr, X_stride);
101 }
102
103 public:
110 TsqrAdaptor(const Teuchos::RCP<Teuchos::ParameterList>& plist)
111 : nodeTsqr_(node_tsqr_factory_type::getNodeTsqr())
112 , distTsqr_(new dist_tsqr_type)
113 , tsqr_(new tsqr_type(nodeTsqr_, distTsqr_)) {
114 setParameterList(plist);
115 }
116
118 TsqrAdaptor()
119 : nodeTsqr_(node_tsqr_factory_type::getNodeTsqr())
120 , distTsqr_(new dist_tsqr_type)
121 , tsqr_(new tsqr_type(nodeTsqr_, distTsqr_)) {
122 setParameterList(Teuchos::null);
123 }
124
126 Teuchos::RCP<const Teuchos::ParameterList>
127 getValidParameters() const {
128 if (defaultParams_.is_null()) {
129 auto params = Teuchos::parameterList("TSQR implementation");
130 params->set("NodeTsqr", *(nodeTsqr_->getValidParameters()));
131 params->set("DistTsqr", *(distTsqr_->getValidParameters()));
132 defaultParams_ = params;
133 }
134 return defaultParams_;
135 }
136
162 void
163 setParameterList(const Teuchos::RCP<Teuchos::ParameterList>& plist) {
164 auto params = plist.is_null() ? Teuchos::parameterList(*getValidParameters()) : plist;
165 using Teuchos::sublist;
166 nodeTsqr_->setParameterList(sublist(params, "NodeTsqr"));
167 distTsqr_->setParameterList(sublist(params, "DistTsqr"));
168
169 this->setMyParamList(params);
170 }
171
193 void
194 factorExplicit(MV& A,
195 MV& Q,
196 dense_matrix_type& R,
197 const bool forceNonnegativeDiagonal = false) {
198 TEUCHOS_TEST_FOR_EXCEPTION(!A.isConstantStride(), std::invalid_argument,
199 "TsqrAdaptor::"
200 "factorExplicit: Input MultiVector A must have constant stride.");
201 TEUCHOS_TEST_FOR_EXCEPTION(!Q.isConstantStride(), std::invalid_argument,
202 "TsqrAdaptor::"
203 "factorExplicit: Input MultiVector Q must have constant stride.");
204 prepareTsqr(Q); // Finish initializing TSQR.
205 TEUCHOS_ASSERT(!tsqr_.is_null());
206
207 auto A_view = get_mat_view(A);
208 auto Q_view = get_mat_view(Q);
209 constexpr bool contiguousCacheBlocks = false;
210 tsqr_->factorExplicitRaw(A_view.extent(0),
211 A_view.extent(1),
212 A_view.data(), A_view.stride(1),
213 Q_view.data(), Q_view.stride(1),
214 R.values(), R.stride(),
215 contiguousCacheBlocks,
216 forceNonnegativeDiagonal);
217 }
218
249 int revealRank(MV& Q,
250 dense_matrix_type& R,
251 const magnitude_type& tol) {
252 TEUCHOS_TEST_FOR_EXCEPTION(!Q.isConstantStride(), std::invalid_argument,
253 "TsqrAdaptor::"
254 "revealRank: Input MultiVector Q must have constant stride.");
255 prepareTsqr(Q); // Finish initializing TSQR.
256
257 auto Q_view = get_mat_view(Q);
258 constexpr bool contiguousCacheBlocks = false;
259 return tsqr_->revealRankRaw(Q_view.extent(0),
260 Q_view.extent(1),
261 Q_view.data(), Q_view.stride(1),
262 R.values(), R.stride(),
263 tol, contiguousCacheBlocks);
264 }
265
266 private:
268 Teuchos::RCP<node_tsqr_type> nodeTsqr_;
269
271 Teuchos::RCP<dist_tsqr_type> distTsqr_;
272
274 Teuchos::RCP<tsqr_type> tsqr_;
275
277 mutable Teuchos::RCP<const Teuchos::ParameterList> defaultParams_;
278
280 bool ready_ = false;
281
302 void
303 prepareTsqr(const MV& mv) {
304 if (!ready_) {
305 prepareDistTsqr(mv);
306 ready_ = true;
307 }
308 }
309
316 void
317 prepareDistTsqr(const MV& mv) {
318 using Teuchos::RCP;
319 using Teuchos::rcp_implicit_cast;
320 using mess_type = TSQR::TeuchosMessenger<scalar_type>;
321 using base_mess_type = TSQR::MessengerBase<scalar_type>;
322
323 auto comm = mv.getMap()->getComm();
324 RCP<mess_type> mess(new mess_type(comm));
325 auto messBase = rcp_implicit_cast<base_mess_type>(mess);
326 distTsqr_->init(messBase);
327 }
328};
329
330} // namespace Tpetra
331
332#endif // HAVE_TPETRA_TSQR
333
334#endif // TPETRA_TSQRADAPTOR_HPP
Namespace Tpetra contains the classes and methods constituting the Tpetra library.