ROL
ROL_Sketch.hpp
Go to the documentation of this file.
1// @HEADER
2// *****************************************************************************
3// Rapid Optimization Library (ROL) Package
4//
5// Copyright 2014 NTESS and the ROL contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9
10#ifndef ROL_SKETCH_H
11#define ROL_SKETCH_H
12
13#include "ROL_Vector.hpp"
14#include "ROL_LinearAlgebra.hpp"
15#include "ROL_LAPACK.hpp"
16#include "ROL_UpdateType.hpp"
17#include "ROL_Types.hpp"
18#include <random>
19#include <chrono>
20
28namespace ROL {
29
30template <class Real>
31class Sketch {
32private:
33 // Sketch storage
34 std::vector<Ptr<Vector<Real>>> Y_;
35 LA::Matrix<Real> X_, Z_, C_;
36
37 // Random dimension reduction maps
38 std::vector<Ptr<Vector<Real>>> Upsilon_, Phi_;
39 LA::Matrix<Real> Omega_, Psi_;
40
42
43 const Real orthTol_;
44 const int orthIt_;
45
46 const bool truncate_;
47
48 LAPACK<int,Real> lapack_;
49
51
52 Ptr<std::ostream> out_;
53
54 Ptr<Elementwise::NormalRandom<Real>> nrand_;
55 Ptr<std::mt19937_64> gen_;
56 Ptr<std::normal_distribution<Real>> dist_;
57
58 void mgs2(std::vector<Ptr<Vector<Real>>> &Y) const {
59 const int nvec(Y.size());
60 const Real zero(0), one(1);
61 Real rjj(0), rij(0);
62 std::vector<Real> normQ(nvec,0);
63 bool flag(true);
64 for (int j = 0; j < nvec; ++j) {
65 rjj = Y[j]->norm();
66 if (rjj > ROL_EPSILON<Real>()) { // Ignore update if Y[i] is zero.
67 for (int k = 0; k < orthIt_; ++k) {
68 for (int i = 0; i < j; ++i) {
69 rij = Y[i]->dot(*Y[j]);
70 Y[j]->axpy(-rij,*Y[i]);
71 }
72 normQ[j] = Y[j]->norm();
73 flag = true;
74 for (int i = 0; i < j; ++i) {
75 rij = std::abs(Y[i]->dot(*Y[j]));
76 if (rij > orthTol_*normQ[j]*normQ[i]) {
77 flag = false;
78 break;
79 }
80 }
81 if (flag) break;
82 }
83 }
84 rjj = normQ[j];
85 if (rjj > zero) Y[j]->scale(one/rjj);
86 }
87 }
88
89 int LSsolver(LA::Matrix<Real> &A, LA::Matrix<Real> &B, bool trans = false) const {
90 int flag(0);
91 char TRANS = (trans ? 'T' : 'N');
92 int M = A.numRows();
93 int N = A.numCols();
94 int NRHS = B.numCols();
95 int LDA = M;
96 int LDB = std::max(M,N);
97 std::vector<Real> WORK(1);
98 int LWORK = -1;
99 int INFO;
100 lapack_.GELS(TRANS,M,N,NRHS,A.values(),LDA,B.values(),LDB,&WORK[0],LWORK,&INFO);
101 flag += INFO;
102 LWORK = static_cast<int>(WORK[0]);
103 WORK.resize(LWORK);
104 lapack_.GELS(TRANS,M,N,NRHS,A.values(),LDA,B.values(),LDB,&WORK[0],LWORK,&INFO);
105 flag += INFO;
106 return flag;
107 }
108
109 int lowRankApprox(LA::Matrix<Real> &A, int r) const {
110 const Real zero(0);
111 char JOBU = 'S';
112 char JOBVT = 'S';
113 int M = A.numRows();
114 int N = A.numCols();
115 int K = std::min(M,N);
116 int LDA = M;
117 std::vector<Real> S(K);
118 LA::Matrix<Real> U(M,K);
119 int LDU = M;
120 LA::Matrix<Real> VT(K,N);
121 int LDVT = K;
122 std::vector<Real> WORK(1), WORK0(1);
123 int LWORK = -1;
124 int INFO;
125 lapack_.GESVD(JOBU,JOBVT,M,N,A.values(),LDA,&S[0],U.values(),LDU,VT.values(),LDVT,&WORK[0],LWORK,&WORK0[0],&INFO);
126 LWORK = static_cast<int>(WORK[0]);
127 WORK.resize(LWORK);
128 lapack_.GESVD(JOBU,JOBVT,M,N,A.values(),LDA,&S[0],U.values(),LDU,VT.values(),LDVT,&WORK[0],LWORK,&WORK0[0],&INFO);
129 for (int i = 0; i < M; ++i) {
130 for (int j = 0; j < N; ++j) {
131 A(i,j) = zero;
132 for (int k = 0; k < r; ++k) {
133 A(i,j) += S[k] * U(i,k) * VT(k,j);
134 }
135 }
136 }
137 return INFO;
138 }
139
140 int computeP(void) {
141 int INFO(0);
142 if (!flagP_) {
143 // Solve least squares problem using LAPACK
144 int M = ncol_;
145 int N = k_;
146 int K = std::min(M,N);
147 int LDA = M;
148 std::vector<Real> TAU(K);
149 std::vector<Real> WORK(1);
150 int LWORK = -1;
151 // Compute QR factorization of X
152 lapack_.GEQRF(M,N,X_.values(),LDA,&TAU[0],&WORK[0],LWORK,&INFO);
153 LWORK = static_cast<int>(WORK[0]);
154 WORK.resize(LWORK);
155 lapack_.GEQRF(M,N,X_.values(),LDA,&TAU[0],&WORK[0],LWORK,&INFO);
156 // Generate Q
157 LWORK = -1;
158 lapack_.ORGQR(M,N,K,X_.values(),LDA,&TAU[0],&WORK[0],LWORK,&INFO);
159 LWORK = static_cast<int>(WORK[0]);
160 WORK.resize(LWORK);
161 lapack_.ORGQR(M,N,K,X_.values(),LDA,&TAU[0],&WORK[0],LWORK,&INFO);
162 flagP_ = true;
163 }
164 return INFO;
165 }
166
167 int computeQ(void) {
168 if (!flagQ_) {
169 mgs2(Y_);
170 flagQ_ = true;
171 }
172 return 0;
173 }
174
175 int computeC(void) {
176 int infoP(0), infoQ(0), infoLS1(0), infoLS2(0), infoLRA(0);
177 infoP = computeP();
178 infoQ = computeQ();
179 if (!flagC_) {
180 const Real zero(0);
181 LA::Matrix<Real> L(s_,k_), R(s_,k_);
182 for (int i = 0; i < s_; ++i) {
183 for (int j = 0; j < k_; ++j) {
184 L(i,j) = Phi_[i]->dot(*Y_[j]);
185 R(i,j) = zero;
186 for (int k = 0; k < ncol_; ++k) R(i,j) += Psi_(k,i) * X_(k,j);
187 }
188 }
189 // Solve least squares problems using LAPACK
190 infoLS1 = LSsolver(L,Z_,false);
191 LA::Matrix<Real> Zmat(s_,k_);
192 for (int i = 0; i < k_; ++i) {
193 for (int j = 0; j < s_; ++j) Zmat(j,i) = Z_(i,j);
194 }
195 infoLS2 = LSsolver(R,Zmat,false);
196 for (int i = 0; i < k_; ++i) {
197 for (int j = 0; j < k_; ++j) C_(j,i) = Zmat(i,j);
198 }
199 // Compute best rank r approximation
200 if (truncate_) infoLRA = lowRankApprox(C_,rank_);
201 // Set flag
202 flagC_ = true;
203 }
204 return std::abs(infoP)+std::abs(infoQ)+std::abs(infoLS1)
205 +std::abs(infoLS2)+std::abs(infoLRA);
206 }
207
208public:
209 virtual ~Sketch(void) {}
210
211 Sketch(const Vector<Real> &x, int ncol, int rank,
212 Real orthTol = 1e-8, int orthIt = 2, bool truncate = false,
213 unsigned dom_seed = 0, unsigned rng_seed = 0)
214 : ncol_(ncol), orthTol_(orthTol), orthIt_(orthIt), truncate_(truncate),
215 flagP_(false), flagQ_(false), flagC_(false),
216 out_(nullPtr) {
217 Real mu(0), sig(1);
218 nrand_ = makePtr<Elementwise::NormalRandom<Real>>(mu,sig,dom_seed);
219 unsigned seed = rng_seed;
220 if (seed == 0) seed = std::chrono::system_clock::now().time_since_epoch().count();
221 gen_ = makePtr<std::mt19937_64>(seed);
222 dist_ = makePtr<std::normal_distribution<Real>>(mu,sig);
223 // Compute reduced dimensions
224 maxRank_ = std::min(ncol_, x.dimension());
225 rank_ = std::min(rank, maxRank_);
226 k_ = std::min(2*rank_+1, maxRank_);
227 s_ = std::min(2*k_ +1, maxRank_);
228 // Initialize matrix storage
229 Upsilon_.resize(k_); Phi_.resize(s_); Omega_.reshape(ncol_,k_); Psi_.reshape(ncol_,s_);
230 X_.reshape(ncol_,k_); Y_.resize(k_); Z_.reshape(s_,s_); C_.reshape(k_,k_);
231 for (int i = 0; i < k_; ++i) {
232 Y_[i] = x.clone();
233 Upsilon_[i] = x.clone();
234 }
235 for (int i = 0; i < s_; ++i) Phi_[i] = x.clone();
236 reset(true);
237 }
238
239 void setStream(Ptr<std::ostream> &out) {
240 out_ = out;
241 }
242
243 void reset(bool randomize = true) {
244 const Real zero(0);
245 X_.putScalar(zero); Z_.putScalar(zero); C_.putScalar(zero);
246 for (int i = 0; i < k_; ++i) Y_[i]->zero();
247 flagP_ = false; flagQ_ = false; flagC_ = false;
248 if (randomize) {
249 for (int i = 0; i < s_; ++i) {
250 Phi_[i]->applyUnary(*nrand_);
251 for (int j = 0; j < ncol_; ++j) Psi_(j,i) = (*dist_)(*gen_);
252 }
253 for (int i = 0; i < k_; ++i) {
254 Upsilon_[i]->applyUnary(*nrand_);
255 for (int j = 0; j < ncol_; ++j) Omega_(j,i) = (*dist_)(*gen_);
256 }
257 }
258 }
259
260 void setRank(int rank) {
261 rank_ = std::min(rank, maxRank_);
262 // Compute reduced dimensions
263 int sold = s_, kold = k_;
264 k_ = std::min(2*rank_+1, maxRank_);
265 s_ = std::min(2*k_ +1, maxRank_);
266 Omega_.reshape(ncol_,k_); Psi_.reshape(ncol_,s_);
267 X_.reshape(ncol_,k_); Z_.reshape(s_,s_); C_.reshape(k_,k_);
268 if (s_ > sold) {
269 for (int i = sold; i < s_; ++i) Phi_.push_back(Phi_[0]->clone());
270 }
271 if (k_ > kold) {
272 for (int i = kold; i < k_; ++i) {
273 Y_.push_back(Y_[0]->clone());
274 Upsilon_.push_back(Upsilon_[0]->clone());
275 }
276 }
277 reset(true);
278 if ( out_ != nullPtr ) {
279 *out_ << std::string(80,'=') << std::endl;
280 *out_ << " ROL::Sketch::setRank" << std::endl;
281 *out_ << " **** Rank = " << rank_ << std::endl;
282 *out_ << " **** k = " << k_ << std::endl;
283 *out_ << " **** s = " << s_ << std::endl;
284 *out_ << std::string(80,'=') << std::endl;
285 }
286 }
287
288 void update(void) {
289 reset(true);
290 }
291
292 int advance(Real nu, const Vector<Real> &h, int col, Real eta = 1.0) {
293 // Check to see if col is less than ncol_
294 if ( col >= ncol_ || col < 0 ) return 1; // Input column index out of range!
295 if (!flagP_ && !flagQ_ && !flagC_) {
296 for (int i = 0; i < k_; ++i) {
297 // Update X
298 for (int j = 0; j < ncol_; ++j) X_(j,i) *= eta;
299 X_(col,i) += nu*h.dot(*Upsilon_[i]);
300 // Update Y
301 Y_[i]->scale(eta);
302 Y_[i]->axpy(nu*Omega_(col,i),h);
303 }
304 // Update Z
305 Real hphi(0);
306 for (int i = 0; i < s_; ++i) {
307 hphi = h.dot(*Phi_[i]);
308 for (int j = 0; j < s_; ++j) {
309 Z_(i,j) *= eta;
310 Z_(i,j) += nu*Psi_(col,j)*hphi;
311 }
312 }
313 if ( out_ != nullPtr ) {
314 *out_ << std::string(80,'=') << std::endl;
315 *out_ << " ROL::Sketch::advance" << std::endl;
316 *out_ << " **** col = " << col << std::endl;
317 *out_ << " **** norm(h) = " << h.norm() << std::endl;
318 *out_ << std::string(80,'=') << std::endl;
319 }
320 }
321 else {
322 // Reconstruct has already been called!
323 return 1;
324 }
325 return 0;
326 }
327
328 int reconstruct(Vector<Real> &a, const int col) {
329 // Check to see if col is less than ncol_
330 if ( col >= ncol_ || col < 0 ) return 2; // Input column index out of range!
331 const Real zero(0);
332 int flag(0);
333 // Compute QR factorization of X store in X
334 flag = computeP();
335 if (flag > 0 ) return 3;
336 // Compute QR factorization of Y store in Y
337 flag = computeQ();
338 if (flag > 0 ) return 4;
339 // Compute (Phi Q)\Z/(Psi P)* store in C
340 flag = computeC();
341 if (flag > 0 ) return 5;
342 // Recover sketch
343 a.zero();
344 Real coeff(0);
345 for (int i = 0; i < k_; ++i) {
346 coeff = zero;
347 for (int j = 0; j < k_; ++j) coeff += C_(i,j) * X_(col,j);
348 a.axpy(coeff,*Y_[i]);
349 }
350 if ( out_ != nullPtr ) {
351 *out_ << std::string(80,'=') << std::endl;
352 *out_ << " ROL::Sketch::reconstruct" << std::endl;
353 *out_ << " **** col = " << col << std::endl;
354 *out_ << " **** norm(a) = " << a.norm() << std::endl;
355 *out_ << std::string(80,'=') << std::endl;
356 }
357 return 0;
358 }
359
360 bool test(const int rank, std::ostream &outStream = std::cout, const int verbosity = 0) {
361 const Real one(1), tol(std::sqrt(ROL_EPSILON<Real>()));
362 using seed_type = std::mt19937_64::result_type;
363 seed_type const seed = 123;
364 std::mt19937_64 eng{seed};
365 std::uniform_real_distribution<Real> dist(static_cast<Real>(0),static_cast<Real>(1));
366 // Initialize low rank factors
367 std::vector<Ptr<Vector<Real>>> U(rank);
368 LA::Matrix<Real> V(ncol_,rank);
369 for (int i = 0; i < rank; ++i) {
370 U[i] = Y_[0]->clone();
371 U[i]->randomize(-one,one);
372 for (int j = 0; j < ncol_; ++j) V(j,i) = dist(eng);
373 }
374 // Initialize A and build sketch
375 update();
376 std::vector<Ptr<Vector<Real>>> A(ncol_);
377 for (int i = 0; i < ncol_; ++i) {
378 A[i] = Y_[0]->clone(); A[i]->zero();
379 for (int j = 0; j < rank; ++j) {
380 A[i]->axpy(V(i,j),*U[j]);
381 }
382 advance(one,*A[i],i,one);
383 }
384 // Test QR decomposition of X
385 bool flagP = testP(outStream, verbosity);
386 // Test QR decomposition of Y
387 bool flagQ = testQ(outStream, verbosity);
388 // Test reconstruction of A
389 Real nerr(0), maxerr(0);
390 Ptr<Vector<Real>> err = Y_[0]->clone();
391 for (int i = 0; i < ncol_; ++i) {
392 reconstruct(*err,i);
393 err->axpy(-one,*A[i]);
394 nerr = err->norm();
395 maxerr = (nerr > maxerr ? nerr : maxerr);
396 }
397 if (verbosity > 0) {
398 std::ios_base::fmtflags oflags(outStream.flags());
399 outStream << std::scientific << std::setprecision(3) << std::endl;
400 outStream << " TEST RECONSTRUCTION: Max Error = "
401 << std::setw(12) << std::right << maxerr
402 << std::endl << std::endl;
403 outStream.flags(oflags);
404 }
405 return flagP & flagQ & (maxerr < tol ? true : false);
406 }
407
408private:
409
410 // Test functions
411 bool testQ(std::ostream &outStream = std::cout, const int verbosity = 0) {
412 const Real one(1), tol(std::sqrt(ROL_EPSILON<Real>()));
413 computeQ();
414 Real qij(0), err(0), maxerr(0);
415 std::ios_base::fmtflags oflags(outStream.flags());
416 if (verbosity > 0) outStream << std::scientific << std::setprecision(3);
417 if (verbosity > 1) {
418 outStream << std::endl
419 << " Printing Q'Q...This should be approximately equal to I"
420 << std::endl << std::endl;
421 }
422 for (int i = 0; i < k_; ++i) {
423 for (int j = 0; j < k_; ++j) {
424 qij = Y_[i]->dot(*Y_[j]);
425 err = (i==j ? std::abs(qij-one) : std::abs(qij));
426 maxerr = (err > maxerr ? err : maxerr);
427 if (verbosity > 1) outStream << std::setw(12) << std::right << qij;
428 }
429 if (verbosity > 1) outStream << std::endl;
430 if (maxerr > tol) break;
431 }
432 if (verbosity > 0) {
433 outStream << std::endl << " TEST ORTHOGONALIZATION: Max Error = "
434 << std::setw(12) << std::right << maxerr
435 << std::endl;
436 outStream.flags(oflags);
437 }
438 return (maxerr < tol ? true : false);
439 }
440
441 bool testP(std::ostream &outStream = std::cout, const int verbosity = 0) {
442 const Real zero(0), one(1), tol(std::sqrt(ROL_EPSILON<Real>()));
443 computeP();
444 Real qij(0), err(0), maxerr(0);
445 std::ios_base::fmtflags oflags(outStream.flags());
446 if (verbosity > 0) outStream << std::scientific << std::setprecision(3);
447 if (verbosity > 1) {
448 outStream << std::endl
449 << " Printing P'P...This should be approximately equal to I"
450 << std::endl << std::endl;
451 }
452 for (int i = 0; i < k_; ++i) {
453 for (int j = 0; j < k_; ++j) {
454 qij = zero;
455 for (int k = 0; k < ncol_; ++k) qij += X_(k,i) * X_(k,j);
456 err = (i==j ? std::abs(qij-one) : std::abs(qij));
457 maxerr = (err > maxerr ? err : maxerr);
458 if (verbosity > 1) outStream << std::setw(12) << std::right << qij;
459 }
460 if (verbosity > 1) outStream << std::endl;
461 if (maxerr > tol) break;
462 }
463 if (verbosity > 0) {
464 outStream << std::endl << " TEST ORTHOGONALIZATION: Max Error = "
465 << std::setw(12) << std::right << maxerr
466 << std::endl;
467 outStream.flags(oflags);
468 }
469 return (maxerr < tol ? true : false);
470 }
471
472}; // class Sketch
473
474} // namespace ROL
475
476#endif
Vector< Real > V
Objective_SerialSimOpt(const Ptr< Obj > &obj, const V &ui) z0_ zero()
Contains definitions of custom data types in ROL.
Provides an interface for randomized sketching.
int advance(Real nu, const Vector< Real > &h, int col, Real eta=1.0)
int computeC(void)
LA::Matrix< Real > C_
Ptr< std::mt19937_64 > gen_
int LSsolver(LA::Matrix< Real > &A, LA::Matrix< Real > &B, bool trans=false) const
void setRank(int rank)
std::vector< Ptr< Vector< Real > > > Upsilon_
void mgs2(std::vector< Ptr< Vector< Real > > > &Y) const
LA::Matrix< Real > X_
LA::Matrix< Real > Psi_
virtual ~Sketch(void)
int reconstruct(Vector< Real > &a, const int col)
bool testP(std::ostream &outStream=std::cout, const int verbosity=0)
LA::Matrix< Real > Omega_
const Real orthTol_
LAPACK< int, Real > lapack_
Ptr< std::ostream > out_
Sketch(const Vector< Real > &x, int ncol, int rank, Real orthTol=1e-8, int orthIt=2, bool truncate=false, unsigned dom_seed=0, unsigned rng_seed=0)
void update(void)
std::vector< Ptr< Vector< Real > > > Y_
Ptr< std::normal_distribution< Real > > dist_
LA::Matrix< Real > Z_
int computeP(void)
void setStream(Ptr< std::ostream > &out)
void reset(bool randomize=true)
Ptr< Elementwise::NormalRandom< Real > > nrand_
const bool truncate_
std::vector< Ptr< Vector< Real > > > Phi_
bool testQ(std::ostream &outStream=std::cout, const int verbosity=0)
bool test(const int rank, std::ostream &outStream=std::cout, const int verbosity=0)
const int orthIt_
int computeQ(void)
int lowRankApprox(LA::Matrix< Real > &A, int r) const
Defines the linear algebra or vector space interface.
virtual Real norm() const =0
Returns $\| y \|$ where $y = \mathtt{*this}$.
virtual void zero()
Set to zero vector.
virtual ROL::Ptr< Vector > clone() const =0
Clone to make a new (uninitialized) vector.
virtual int dimension() const
Return dimension of the vector space.
virtual void axpy(const Real alpha, const Vector &x)
Compute $y \leftarrow \alpha x + y$ where $y = \mathtt{*this}$.
virtual Real dot(const Vector &x) const =0
Compute $\langle x, y \rangle$ where $y = \mathtt{*this}$.