ROL
ROL_KLDivergence.hpp
Go to the documentation of this file.
1// @HEADER
2// *****************************************************************************
3// Rapid Optimization Library (ROL) Package
4//
5// Copyright 2014 NTESS and the ROL contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9
10#ifndef ROL_KLDIVERGENCE_HPP
11#define ROL_KLDIVERGENCE_HPP
12
14
44namespace ROL {
45
46template<class Real>
47class KLDivergence : public RandVarFunctional<Real> {
48private:
49 Real eps_;
50
51 Real gval_;
52 Real gvval_;
53 Real hval_;
54 ROL::Ptr<Vector<Real> > scaledGradient_;
55 ROL::Ptr<Vector<Real> > scaledHessVec_;
56
58
59 using RandVarFunctional<Real>::val_;
60 using RandVarFunctional<Real>::gv_;
61 using RandVarFunctional<Real>::g_;
62 using RandVarFunctional<Real>::hv_;
64
65 using RandVarFunctional<Real>::point_;
67
72
73 void checkInputs(void) const {
74 Real zero(0);
75 ROL_TEST_FOR_EXCEPTION((eps_ <= zero), std::invalid_argument,
76 ">>> ERROR (ROL::KLDivergence): Threshold must be positive!");
77 }
78
79public:
84 KLDivergence(const Real eps = 1.e-2)
85 : RandVarFunctional<Real>(), eps_(eps), firstResetKLD_(true) {
87 }
88
97 KLDivergence(ROL::ParameterList &parlist)
98 : RandVarFunctional<Real>(), firstResetKLD_(true) {
99 ROL::ParameterList &list
100 = parlist.sublist("SOL").sublist("Risk Measure").sublist("KL Divergence");
101 eps_ = list.get<Real>("Threshold");
102 checkInputs();
103 }
104
105 void initialize(const Vector<Real> &x) {
107 if ( firstResetKLD_ ) {
108 scaledGradient_ = x.dual().clone();
109 scaledHessVec_ = x.dual().clone();
110 firstResetKLD_ = false;
111 }
112 const Real zero(0);
113 gval_ = zero; gvval_ = zero; hval_ = zero;
114 scaledGradient_->zero(); scaledHessVec_->zero();
115 }
116
118 const Vector<Real> &x,
119 const std::vector<Real> &xstat,
120 Real &tol) {
121 Real val = computeValue(obj,x,tol);
122 Real ev = exponential(val,xstat[0]*eps_);
123 val_ += weight_ * ev;
124 }
125
126 Real getValue(const Vector<Real> &x,
127 const std::vector<Real> &xstat,
128 SampleGenerator<Real> &sampler) {
129 if ( xstat[0] == static_cast<Real>(0) ) {
130 return ROL_INF<Real>();
131 }
132 Real ev(0);
133 sampler.sumAll(&val_,&ev,1);
134 return (static_cast<Real>(1) + std::log(ev)/eps_)/xstat[0];
135 }
136
138 const Vector<Real> &x,
139 const std::vector<Real> &xstat,
140 Real &tol) {
141 Real val = computeValue(obj,x,tol);
142 Real ev = exponential(val,xstat[0]*eps_);
143 val_ += weight_ * ev;
144 gval_ += weight_ * ev * val;
145 computeGradient(*dualVector_,obj,x,tol);
146 g_->axpy(weight_*ev,*dualVector_);
147 }
148
150 std::vector<Real> &gstat,
151 const Vector<Real> &x,
152 const std::vector<Real> &xstat,
153 SampleGenerator<Real> &sampler) {
154 std::vector<Real> local(2), global(2);
155 local[0] = val_;
156 local[1] = gval_;
157 sampler.sumAll(&local[0],&global[0],2);
158 Real ev = global[0], egval = global[1];
159
160 sampler.sumAll(*g_,g);
161 g.scale(static_cast<Real>(1)/ev);
162
163 if ( xstat[0] == static_cast<Real>(0) ) {
164 gstat[0] = ROL_INF<Real>();
165 }
166 else {
167 gstat[0] = -((static_cast<Real>(1) + std::log(ev)/eps_)/xstat[0]
168 - egval/ev)/xstat[0];
169 }
170 }
171
173 const Vector<Real> &v,
174 const std::vector<Real> &vstat,
175 const Vector<Real> &x,
176 const std::vector<Real> &xstat,
177 Real &tol) {
178 Real val = computeValue(obj,x,tol);
179 Real ev = exponential(val,xstat[0]*eps_);
180 Real gv = computeGradVec(*dualVector_,obj,v,x,tol);
181 val_ += weight_ * ev;
182 gv_ += weight_ * ev * gv;
183 gval_ += weight_ * ev * val;
184 gvval_ += weight_ * ev * val * gv;
185 hval_ += weight_ * ev * val * val;
186 g_->axpy(weight_*ev,*dualVector_);
187 scaledGradient_->axpy(weight_*ev*gv,*dualVector_);
188 scaledHessVec_->axpy(weight_*ev*val,*dualVector_);
189 computeHessVec(*dualVector_,obj,v,x,tol);
190 hv_->axpy(weight_*ev,*dualVector_);
191 }
192
194 std::vector<Real> &hvstat,
195 const Vector<Real> &v,
196 const std::vector<Real> &vstat,
197 const Vector<Real> &x,
198 const std::vector<Real> &xstat,
199 SampleGenerator<Real> &sampler) {
200 std::vector<Real> local(5), global(5);
201 local[0] = val_;
202 local[1] = gv_;
203 local[2] = gval_;
204 local[3] = gvval_;
205 local[4] = hval_;
206 sampler.sumAll(&local[0],&global[0],5);
207 Real ev = global[0], egv = global[1], egval = global[2];
208 Real egvval = global[3], ehval = global[4];
209 Real c0 = static_cast<Real>(1)/ev, c1 = c0*egval, c2 = c0*egv, c3 = eps_*c0;
210
211 sampler.sumAll(*hv_,hv);
212 dualVector_->zero();
214 hv.axpy(xstat[0]*eps_,*dualVector_);
215 hv.scale(c0);
216
217 dualVector_->zero();
218 sampler.sumAll(*g_,*dualVector_);
219 hv.axpy(-c3*(vstat[0]*c1 + xstat[0]*c2),*dualVector_);
220
221 dualVector_->zero();
223 hv.axpy(vstat[0]*c3,*dualVector_);
224
225 if ( xstat[0] == static_cast<Real>(0) ) {
226 hvstat[0] = ROL_INF<Real>();
227 }
228 else {
229 Real xstat2 = static_cast<Real>(2)/(xstat[0]*xstat[0]);
230 Real h11 = xstat2*((static_cast<Real>(1) + std::log(ev)/eps_)/xstat[0] - c1)
231 + (c3*ehval - eps_*c1*c1)/xstat[0];
232 hvstat[0] = vstat[0] * h11 + (c3*egvval - eps_*c1*c2);
233 }
234 }
235
236private:
237 Real exponential(const Real arg1, const Real arg2) const {
238 if ( arg1 < arg2 ) {
239 return power(exponential(arg1),arg2);
240 }
241 else {
242 return power(exponential(arg2),arg1);
243 }
244 }
245
246 Real exponential(const Real arg) const {
247 if ( arg >= std::log(ROL_INF<Real>()) ) {
248 return ROL_INF<Real>();
249 }
250 else {
251 return std::exp(arg);
252 }
253 }
254
255 Real power(const Real arg, const Real pow) const {
256 if ( arg >= std::pow(ROL_INF<Real>(),static_cast<Real>(1)/pow) ) {
257 return ROL_INF<Real>();
258 }
259 else {
260 return std::pow(arg,pow);
261 }
262 }
263};
264
265}
266
267#endif
Objective_SerialSimOpt(const Ptr< Obj > &obj, const V &ui) z0_ zero()
Provides an interface for the Kullback-Leibler distributionally robust expectation.
void checkInputs(void) const
void updateValue(Objective< Real > &obj, const Vector< Real > &x, const std::vector< Real > &xstat, Real &tol)
Update internal storage for value computation.
void getHessVec(Vector< Real > &hv, std::vector< Real > &hvstat, const Vector< Real > &v, const std::vector< Real > &vstat, const Vector< Real > &x, const std::vector< Real > &xstat, SampleGenerator< Real > &sampler)
Return risk measure Hessian-times-a-vector.
KLDivergence(ROL::ParameterList &parlist)
Constructor.
void updateHessVec(Objective< Real > &obj, const Vector< Real > &v, const std::vector< Real > &vstat, const Vector< Real > &x, const std::vector< Real > &xstat, Real &tol)
Update internal risk measure storage for Hessian-times-a-vector computation.
void getGradient(Vector< Real > &g, std::vector< Real > &gstat, const Vector< Real > &x, const std::vector< Real > &xstat, SampleGenerator< Real > &sampler)
Return risk measure (sub)gradient.
Real getValue(const Vector< Real > &x, const std::vector< Real > &xstat, SampleGenerator< Real > &sampler)
Return risk measure value.
KLDivergence(const Real eps=1.e-2)
Constructor.
Real power(const Real arg, const Real pow) const
void initialize(const Vector< Real > &x)
Initialize temporary variables.
Real exponential(const Real arg) const
Real exponential(const Real arg1, const Real arg2) const
void updateGradient(Objective< Real > &obj, const Vector< Real > &x, const std::vector< Real > &xstat, Real &tol)
Update internal risk measure storage for gradient computation.
ROL::Ptr< Vector< Real > > scaledGradient_
ROL::Ptr< Vector< Real > > scaledHessVec_
Provides the interface to evaluate objective functions.
Provides the interface to implement any functional that maps a random variable to a (extended) real n...
Real computeValue(Objective< Real > &obj, const Vector< Real > &x, Real &tol)
virtual void initialize(const Vector< Real > &x)
Initialize temporary variables.
void computeHessVec(Vector< Real > &hv, Objective< Real > &obj, const Vector< Real > &v, const Vector< Real > &x, Real &tol)
void computeGradient(Vector< Real > &g, Objective< Real > &obj, const Vector< Real > &x, Real &tol)
Ptr< Vector< Real > > dualVector_
Real computeGradVec(Vector< Real > &g, Objective< Real > &obj, const Vector< Real > &v, const Vector< Real > &x, Real &tol)
void sumAll(Real *input, Real *output, int dim) const
Defines the linear algebra or vector space interface.
virtual void scale(const Real alpha)=0
Compute $y \leftarrow \alpha y$ where $y = \mathtt{*this}$.
virtual const Vector & dual() const
Return dual representation of $\mathtt{*this}$, for example, the result of applying a Riesz map, or change of basis,...
virtual ROL::Ptr< Vector > clone() const =0
Clone to make a new (uninitialized) vector.
virtual void axpy(const Real alpha, const Vector &x)
Compute $y \leftarrow \alpha x + y$ where $y = \mathtt{*this}$.