ROL
ROL_FDivergence.hpp
Go to the documentation of this file.
1// @HEADER
2// *****************************************************************************
3// Rapid Optimization Library (ROL) Package
4//
5// Copyright 2014 NTESS and the ROL contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9
10#ifndef ROL_FDIVERGENCE_HPP
11#define ROL_FDIVERGENCE_HPP
12
14#include "ROL_Types.hpp"
15
50namespace ROL {
51
/** \brief Base class for the F-divergence distributionally robust expectation.

    Derived classes implement the scalar primal function Fprimal and its
    conjugate Fdual (the pair checked against the Fenchel-Young relations in
    check()).  The functional uses two auxiliary statistics xstat = {lambda, mu}
    and evaluates
      value = lambda * (thresh_ + E[Fdual((f - mu)/lambda)]) + mu,
    where f is the objective value of a sample.

    NOTE(review): this text is a Doxygen source-listing export; source lines
    whose text carried hyperlinks (e.g. 58, 66, 69, 71-74) are absent here.
    Line 58 presumably declares `Real valLam2_;` (it is initialized, reset and
    accumulated below), and the missing `using` lines presumably import
    weight_, dualVector_ and the computeValue/computeGradient/computeGradVec/
    computeHessVec helpers used by the update functions -- confirm against the
    original header.
*/
52template<class Real>
53class FDivergence : public RandVarFunctional<Real> {
54private:
// Divergence threshold; must be positive (see checkInputs()). Enters the
// value as lambda * (thresh_ + ...).
55 Real thresh_;
56
// Per-sample accumulators for the lambda- and mu-components of the gradient
// (valLam_, valMu_) and Hessian-vector product (additionally valLam2_, valMu2_).
// Reset in initialize() and summed through the sampler in getGradient()/getHessVec().
57 Real valLam_;
59 Real valMu_;
60 Real valMu2_;
61
// Storage inherited from RandVarFunctional: val_ accumulates scalar sample
// contributions (also reused as a Hessian statistic in updateHessVec), g_ and
// hv_ the gradient and Hessian-vector contributions.
62 using RandVarFunctional<Real>::val_;
63 using RandVarFunctional<Real>::gv_;
64 using RandVarFunctional<Real>::g_;
65 using RandVarFunctional<Real>::hv_;
67
68 using RandVarFunctional<Real>::point_;
70
75
76 void checkInputs(void) const {
77 Real zero(0);
78 ROL_TEST_FOR_EXCEPTION((thresh_ <= zero), std::invalid_argument,
79 ">>> ERROR (ROL::FDivergence): Threshold must be positive!");
80 }
81
82public:
/** \brief Construct from an explicit threshold.

    @param[in] thresh  divergence threshold; must be positive (see checkInputs()).

    The auxiliary lambda/mu accumulators start at zero.
    NOTE(review): source line 89 (the constructor body) is missing from this
    listing; presumably it calls checkInputs(), as the ParameterList
    constructor does -- confirm.
*/
87 FDivergence(const Real thresh) : RandVarFunctional<Real>(), thresh_(thresh),
88 valLam_(0),valLam2_(0), valMu_(0), valMu2_(0) {
90 }
91
100 FDivergence(ROL::ParameterList &parlist) : RandVarFunctional<Real>(),
101 valLam_(0),valLam2_(0), valMu_(0), valMu2_(0) {
102 ROL::ParameterList &list
103 = parlist.sublist("SOL").sublist("Risk Measure").sublist("F-Divergence");
104 thresh_ = list.get<Real>("Threshold");
105 checkInputs();
106 }
107
/** \brief Scalar primal function F, implemented by derived classes.

    @param[in] x      evaluation point.
    @param[in] deriv  derivative selector: check() uses deriv = 0 for F(x)
                      and deriv = 1 for d/dx F(x).
*/
115 virtual Real Fprimal(Real x, int deriv = 0) const = 0;
116
/** \brief Scalar dual function F* (the conjugate of Fprimal tested in check()).

    @param[in] x      evaluation point.
    @param[in] deriv  derivative selector: 0 for F*(x) (value updates), 1 for
                      d/dx F*(x) (gradient updates and check()), and 2 as
                      requested by updateHessVec (presumably the second
                      derivative).
*/
129 virtual Real Fdual(Real x, int deriv = 0) const = 0;
130
131 bool check(std::ostream &outStream = std::cout) const {
132 const Real tol(std::sqrt(ROL_EPSILON<Real>()));
133 bool flag = true;
134
135 Real x = static_cast<Real>(rand())/static_cast<Real>(RAND_MAX);
136 Real t = static_cast<Real>(rand())/static_cast<Real>(RAND_MAX);
137 Real fp = Fprimal(x);
138 Real fd = Fdual(t);
139 outStream << "Check Fenchel-Young Inequality: F(x) + F*(t) >= xt" << std::endl;
140 outStream << "x = " << x << std::endl;
141 outStream << "t = " << t << std::endl;
142 outStream << "F(x) = " << fp << std::endl;
143 outStream << "F*(t) = " << fd << std::endl;
144 outStream << "Is Valid? " << (fp+fd >= x*t) << std::endl;
145 flag = (fp+fd >= x*t) ? flag : false;
146
147 x = static_cast<Real>(rand())/static_cast<Real>(RAND_MAX);
148 t = Fprimal(x,1);
149 fp = Fprimal(x);
150 fd = Fdual(t);
151 outStream << "Check Fenchel-Young Equality: F(x) + F(t) = xt for t = d/dx F(x)" << std::endl;
152 outStream << "x = " << x << std::endl;
153 outStream << "t = " << t << std::endl;
154 outStream << "F(x) = " << fp << std::endl;
155 outStream << "F*(t) = " << fd << std::endl;
156 outStream << "Is Valid? " << (std::abs(fp+fd - x*t)<=tol) << std::endl;
157 flag = (std::abs(fp+fd - x*t)<=tol) ? flag : false;
158
159 t = static_cast<Real>(rand())/static_cast<Real>(RAND_MAX);
160 x = Fdual(t,1);
161 fp = Fprimal(x);
162 fd = Fdual(t);
163 outStream << "Check Fenchel-Young Equality: F(x) + F(t) = xt for x = d/dt F*(t)" << std::endl;
164 outStream << "x = " << x << std::endl;
165 outStream << "t = " << t << std::endl;
166 outStream << "F(x) = " << fp << std::endl;
167 outStream << "F*(t) = " << fd << std::endl;
168 outStream << "Is Valid? " << (std::abs(fp+fd - x*t)<=tol) << std::endl;
169 flag = (std::abs(fp+fd - x*t)<=tol) ? flag : false;
170
171 return flag;
172 }
173
/** \brief Reset the auxiliary lambda/mu accumulators before a new pass.

    NOTE(review): source line 175 is missing from this listing; presumably it
    calls RandVarFunctional<Real>::initialize(x) to reset the inherited
    storage (val_, g_, hv_, ...) -- confirm.
*/
174 void initialize(const Vector<Real> &x) {
176 valLam_ = 0; valLam2_ = 0; valMu_ = 0; valMu2_ = 0;
177 }
178
179 // Value update and get functions
181 const Vector<Real> &x,
182 const std::vector<Real> &xstat,
183 Real &tol) {
184 Real val = computeValue(obj,x,tol);
185 Real xlam = xstat[0];
186 Real xmu = xstat[1];
187 Real r = Fdual((val-xmu)/xlam,0);
188 val_ += weight_ * r;
189 }
190
191 Real getValue(const Vector<Real> &x,
192 const std::vector<Real> &xstat,
193 SampleGenerator<Real> &sampler) {
194 Real val(0);
195 sampler.sumAll(&val_,&val,1);
196 Real xlam = xstat[0];
197 Real xmu = xstat[1];
198 return xlam*(thresh_ + val) + xmu;
199 }
200
201 // Gradient update and get functions
203 const Vector<Real> &x,
204 const std::vector<Real> &xstat,
205 Real &tol) {
206 Real val = computeValue(obj,x,tol);
207 Real xlam = xstat[0];
208 Real xmu = xstat[1];
209 Real inp = (val-xmu)/xlam;
210 Real r0 = Fdual(inp,0), r1 = Fdual(inp,1);
211
212 if (std::abs(r0) >= ROL_EPSILON<Real>()) {
213 val_ += weight_ * r0;
214 }
215 if (std::abs(r1) >= ROL_EPSILON<Real>()) {
216 valLam_ -= weight_ * r1 * inp;
217 valMu_ -= weight_ * r1;
218 computeGradient(*dualVector_,obj,x,tol);
219 g_->axpy(weight_*r1,*dualVector_);
220 }
221 }
222
224 std::vector<Real> &gstat,
225 const Vector<Real> &x,
226 const std::vector<Real> &xstat,
227 SampleGenerator<Real> &sampler) {
228 std::vector<Real> mygval(3), gval(3);
229 mygval[0] = val_;
230 mygval[1] = valLam_;
231 mygval[2] = valMu_;
232 sampler.sumAll(&mygval[0],&gval[0],3);
233
234 gstat[0] = thresh_ + gval[0] + gval[1];
235 gstat[1] = static_cast<Real>(1) + gval[2];
236
237 sampler.sumAll(*g_,g);
238 }
239
241 const Vector<Real> &v,
242 const std::vector<Real> &vstat,
243 const Vector<Real> &x,
244 const std::vector<Real> &xstat,
245 Real &tol) {
246 Real val = computeValue(obj,x,tol);
247 Real xlam = xstat[0];
248 Real xmu = xstat[1];
249 Real vlam = vstat[0];
250 Real vmu = vstat[1];
251 Real inp = (val-xmu)/xlam;
252 Real r1 = Fdual(inp,1), r2 = Fdual(inp,2);
253 if (std::abs(r2) >= ROL_EPSILON<Real>()) {
254 Real gv = computeGradVec(*dualVector_,obj,v,x,tol);
255 val_ += weight_ * r2 * inp;
256 valLam_ += weight_ * r2 * inp * inp;
257 valLam2_ -= weight_ * r2 * gv * inp;
258 valMu_ += weight_ * r2;
259 valMu2_ -= weight_ * r2 * gv;
260 hv_->axpy(weight_ * r2 * (gv - vmu - vlam*inp)/xlam, *dualVector_);
261 }
262 if (std::abs(r1) >= ROL_EPSILON<Real>()) {
263 computeHessVec(*dualVector_,obj,v,x,tol);
264 hv_->axpy(weight_ * r1, *dualVector_);
265 }
266 }
267
269 std::vector<Real> &hvstat,
270 const Vector<Real> &v,
271 const std::vector<Real> &vstat,
272 const Vector<Real> &x,
273 const std::vector<Real> &xstat,
274 SampleGenerator<Real> &sampler) {
275 std::vector<Real> myhval(5), hval(5);
276 myhval[0] = val_;
277 myhval[1] = valLam_;
278 myhval[2] = valLam2_;
279 myhval[3] = valMu_;
280 myhval[4] = valMu2_;
281 sampler.sumAll(&myhval[0],&hval[0],5);
282
283 std::vector<Real> stat(2);
284 Real xlam = xstat[0];
285 //Real xmu = xstat[1];
286 Real vlam = vstat[0];
287 Real vmu = vstat[1];
288 hvstat[0] = (vlam * hval[1] + vmu * hval[0] + hval[2])/xlam;
289 hvstat[1] = (vlam * hval[0] + vmu * hval[3] + hval[4])/xlam;
290
291 sampler.sumAll(*hv_,hv);
292 }
293};
294
295}
296
297#endif
Objective_SerialSimOpt(const Ptr< Obj > &obj, const V &ui) z0_ zero()
Contains definitions of custom data types in ROL.
Provides a general interface for the F-divergence distributionally robust expectation.
void updateGradient(Objective< Real > &obj, const Vector< Real > &x, const std::vector< Real > &xstat, Real &tol)
Update internal risk measure storage for gradient computation.
void updateValue(Objective< Real > &obj, const Vector< Real > &x, const std::vector< Real > &xstat, Real &tol)
Update internal storage for value computation.
FDivergence(ROL::ParameterList &parlist)
Constructor.
void getHessVec(Vector< Real > &hv, std::vector< Real > &hvstat, const Vector< Real > &v, const std::vector< Real > &vstat, const Vector< Real > &x, const std::vector< Real > &xstat, SampleGenerator< Real > &sampler)
Return risk measure Hessian-times-a-vector.
void getGradient(Vector< Real > &g, std::vector< Real > &gstat, const Vector< Real > &x, const std::vector< Real > &xstat, SampleGenerator< Real > &sampler)
Return risk measure (sub)gradient.
virtual Real Fprimal(Real x, int deriv=0) const =0
Implementation of the scalar primal F function.
Real getValue(const Vector< Real > &x, const std::vector< Real > &xstat, SampleGenerator< Real > &sampler)
Return risk measure value.
void initialize(const Vector< Real > &x)
Initialize temporary variables.
FDivergence(const Real thresh)
Constructor.
void checkInputs(void) const
bool check(std::ostream &outStream=std::cout) const
virtual Real Fdual(Real x, int deriv=0) const =0
Implementation of the scalar dual F function.
void updateHessVec(Objective< Real > &obj, const Vector< Real > &v, const std::vector< Real > &vstat, const Vector< Real > &x, const std::vector< Real > &xstat, Real &tol)
Update internal risk measure storage for Hessian-times-a-vector computation.
Provides the interface to evaluate objective functions.
Provides the interface to implement any functional that maps a random variable to a (extended) real n...
Real computeValue(Objective< Real > &obj, const Vector< Real > &x, Real &tol)
virtual void initialize(const Vector< Real > &x)
Initialize temporary variables.
void computeHessVec(Vector< Real > &hv, Objective< Real > &obj, const Vector< Real > &v, const Vector< Real > &x, Real &tol)
void computeGradient(Vector< Real > &g, Objective< Real > &obj, const Vector< Real > &x, Real &tol)
Ptr< Vector< Real > > dualVector_
Real computeGradVec(Vector< Real > &g, Objective< Real > &obj, const Vector< Real > &v, const Vector< Real > &x, Real &tol)
void sumAll(Real *input, Real *output, int dim) const
Defines the linear algebra or vector space interface.