ROL
ROL_PH_RiskObjective.hpp
Go to the documentation of this file.
1// @HEADER
2// *****************************************************************************
3// Rapid Optimization Library (ROL) Package
4//
5// Copyright 2014 NTESS and the ROL contributors.
6// SPDX-License-Identifier: BSD-3-Clause
7// *****************************************************************************
8// @HEADER
9
10#ifndef PH_RISKOBJECTIVE_H
11#define PH_RISKOBJECTIVE_H
12
13#include "ROL_Objective.hpp"
15
22namespace ROL {
23
/** \class ROL::PH_RiskObjective
    \brief Provides the interface for the progressive hedging risk objective.

    Wraps an objective f (obj_) and a regret function v supplied by an
    ExpectationQuad (quad_).  The optimization variable is a RiskVector that
    holds a primal vector x and a statistic vector; only the first statistic
    entry t is used.  The objective evaluated is

      F(x,t) = t + v(f(x) - t),

    where quad_->regret(s, k) is called with k = 0, 1, 2 to obtain v, v', v''
    (inferred from how value(), gradient() and hessVec() use it).

    NOTE(review): this text is a rendered source listing and several source
    lines are absent (the listing numbers skip, e.g. 29->31, 32->35,
    44->46, 84->86, 166->168).  The declarations of isValueComputed_ and
    isGradientComputed_, the guards inside getGradient, the first line of the
    constructor, the line defining `ed`, the switch case labels, and one line
    of setParameter are not visible here; consult the original header before
    editing.
*/
24template <class Real>
25class PH_RiskObjective : public Objective<Real> {
26private:
// Wrapped objective f.
27 const Ptr<Objective<Real>> obj_;
// Quadrangle providing the regret v; chosen in the constructor from parlist.
28 Ptr<ExpectationQuad<Real>> quad_;
29
// Cached value f(x); valid while isValueComputed_ is true.
31 Real val_;
32
// Cached gradient of f, allocated as a clone of x.dual() in getGradient().
35 Ptr<Vector<Real>> g_;
36
// Evaluates f(x) into val_ unless it is already cached; the cache is
// invalidated by update().
37 void getValue(const Vector<Real> &x, Real &tol) {
38 if (!isValueComputed_) {
39 val_ = obj_->value(x,tol);
40 isValueComputed_ = true;
41 }
42 }
43
// Evaluates the gradient of f at x into g_, allocating g_ as a clone of
// x.dual().  NOTE(review): listing lines 45, 47, 49 and 51 are missing, so the
// conditions guarding the allocation and the gradient evaluation are not
// visible here; presumably they consult isGradientComputed_ (reset in
// update()) -- confirm against the original header.
44 void getGradient(const Vector<Real> &x, Real &tol) {
46 g_ = x.dual().clone();
48 }
50 obj_->gradient(*g_,x,tol);
52 }
53 }
54
// Returns the primal component x of a RiskVector (read-only).
// The reference dynamic_cast throws std::bad_cast if x is not a RiskVector.
55 Ptr<const Vector<Real>> getConstVector(const Vector<Real> &x) const {
56 const RiskVector<Real> &xrv = dynamic_cast<const RiskVector<Real>&>(x);
57 return xrv.getVector();
58 }
59
// Mutable counterpart of getConstVector().
60 Ptr<Vector<Real>> getVector(Vector<Real> &x) const {
61 RiskVector<Real> &xrv = dynamic_cast<RiskVector<Real>&>(x);
62 return xrv.getVector();
63 }
64
// Returns the statistic vector of a RiskVector (read-only), or a new empty
// vector when the RiskVector carries none.
// NOTE(review): value(), gradient() and hessVec() index element [0] of the
// result unconditionally, so the empty-vector fallback leads to an
// out-of-bounds access; confirm callers always provide a statistic.
65 Ptr<const std::vector<Real>> getConstStat(const Vector<Real> &x) const {
66 const RiskVector<Real> &xrv = dynamic_cast<const RiskVector<Real>&>(x);
67 Ptr<const std::vector<Real>> xstat = xrv.getStatistic();
68 if (xstat == nullPtr) {
69 xstat = makePtr<const std::vector<Real>>(0);
70 }
71 return xstat;
72 }
73
// Mutable counterpart of getConstStat(); the same empty-vector caveat applies
// to the writes of (*gstat)[0] / (*hstat)[0] in gradient() and hessVec().
74 Ptr<std::vector<Real>> getStat(Vector<Real> &x) const {
75 RiskVector<Real> &xrv = dynamic_cast<RiskVector<Real>&>(x);
76 Ptr<std::vector<Real>> xstat = xrv.getStatistic();
77 if (xstat == nullPtr) {
78 xstat = makePtr<std::vector<Real>>(0);
79 }
80 return xstat;
81 }
82
83public:
84
// Constructor: stores the wrapped objective and selects the quadrangle from
// parlist.sublist("SOL").sublist("Risk Measure").get("Name", "CVaR").
// An unsupported name reaches the default branch, which raises
// std::invalid_argument through ROL_TEST_FOR_EXCEPTION.
// NOTE(review): the first signature line (listing line 85), the line that
// defines `ed` (presumably via StringToERiskMeasure(risk)) and the case labels
// are missing from this listing.
86 ParameterList &parlist)
87 : obj_(obj),
88 isValueComputed_(false),
90 isGradientComputed_(false) {
91 std::string risk = parlist.sublist("SOL").sublist("Risk Measure").get("Name","CVaR");
93 switch(ed) {
95 quad_ = makePtr<QuantileQuadrangle<Real>>(parlist); break;
97 quad_ = makePtr<MoreauYosidaCVaR<Real>>(parlist); break;
99 quad_ = makePtr<GenMoreauYosidaCVaR<Real>>(parlist); break;
101 quad_ = makePtr<LogExponentialQuadrangle<Real>>(parlist); break;
103 quad_ = makePtr<MeanVarianceQuadrangle<Real>>(parlist); break;
105 quad_ = makePtr<TruncatedMeanQuadrangle<Real>>(parlist); break;
107 quad_ = makePtr<LogQuantileQuadrangle<Real>>(parlist); break;
109 quad_ = makePtr<SmoothedWorstCaseQuadrangle<Real>>(parlist); break;
// NOTE(review): the disabled cases below use `return makePtr<...>`, which
// cannot be used in a constructor; re-enabling them would require assigning
// quad_ instead.
110// case RISKMEASURE_CHI2DIVERGENCE:
111// return makePtr<Chi2Divergence<Real>>(parlist);
112// case RISKMEASURE_KLDIVERGENCE:
113// return makePtr<KLDivergence<Real>>(parlist);
114 default:
115 ROL_TEST_FOR_EXCEPTION(true,std::invalid_argument,
116 "Invalid risk measure type " << risk << "!");
117 }
118 }
119
// Forwards the primal component of x to obj_->update and invalidates the
// cached value and gradient.
120 void update( const Vector<Real> &x, bool flag = true, int iter = -1 ) {
121 Ptr<const Vector<Real>> xvec = getConstVector(x);
122 obj_->update(*xvec,flag,iter);
123 isValueComputed_ = false;
124 isGradientComputed_ = false;
125 }
126
// Returns F(x,t) = t + v(f(x) - t), where t = (*xstat)[0].
127 Real value( const Vector<Real> &x, Real &tol ) {
128 Ptr<const Vector<Real>> xvec = getConstVector(x);
129 Ptr<const std::vector<Real>> xstat = getConstStat(x);
130 getValue(*xvec,tol);
131 Real reg = quad_->regret(val_-(*xstat)[0],0);
132 return (*xstat)[0] + reg;
133 }
134
// Gradient of F:
//   g_x = v'(f(x) - t) * grad f(x),
//   g_t = 1 - v'(f(x) - t).
135 void gradient( Vector<Real> &g, const Vector<Real> &x, Real &tol ) {
136 Ptr<Vector<Real>> gvec = getVector(g);
137 Ptr<std::vector<Real>> gstat = getStat(g);
138 Ptr<const Vector<Real>> xvec = getConstVector(x);
139 Ptr<const std::vector<Real>> xstat = getConstStat(x);
140 getValue(*xvec,tol);
141 Real reg = quad_->regret(val_-(*xstat)[0],1);
142 getGradient(*xvec,tol);
143 gvec->set(*g_); gvec->scale(reg);
144 (*gstat)[0] = static_cast<Real>(1)-reg;
145 }
146
// Hessian-vector product of F applied to v = (v_x, v_t), with s = f(x) - t
// and gv = <v_x, grad f(x)> (computed with Vector::apply):
//   hv_x = v'(s) * (hess f(x) v_x) + v''(s) * (gv - v_t) * grad f(x),
//   hv_t = v''(s) * (v_t - gv).
147 void hessVec( Vector<Real> &hv, const Vector<Real> &v, const Vector<Real> &x, Real &tol ) {
148 Ptr<Vector<Real>> hvec = getVector(hv);
149 Ptr<std::vector<Real>> hstat = getStat(hv);
150 Ptr<const Vector<Real>> vvec = getConstVector(v);
151 Ptr<const std::vector<Real>> vstat = getConstStat(v);
152 Ptr<const Vector<Real>> xvec = getConstVector(x);
153 Ptr<const std::vector<Real>> xstat = getConstStat(x);
154 getValue(*xvec,tol);
155 Real reg1 = quad_->regret(val_-(*xstat)[0],1);
156 Real reg2 = quad_->regret(val_-(*xstat)[0],2);
157 getGradient(*xvec,tol);
158 //Real gv = vvec->dot(g_->dual());
159 Real gv = vvec->apply(*g_);
160 obj_->hessVec(*hvec,*vvec,*xvec,tol);
161 hvec->scale(reg1); hvec->axpy(reg2*(gv-(*vstat)[0]),*g_);
162 (*hstat)[0] = reg2*((*vstat)[0]-gv);
163 }
164
// Forwards the parameter to the wrapped objective.
// NOTE(review): listing line 167 is missing; its content is not visible here.
165 void setParameter(const std::vector<Real> &param) {
166 obj_->setParameter(param);
168 }
169
170};
171
172}
173#endif
Provides the interface to evaluate objective functions.
virtual void setParameter(const std::vector< Real > &param)
Provides the interface for the progressive hedging risk objective.
Ptr< ExpectationQuad< Real > > quad_
void update(const Vector< Real > &x, bool flag=true, int iter=-1)
Update objective function.
const Ptr< Objective< Real > > obj_
Ptr< const std::vector< Real > > getConstStat(const Vector< Real > &x) const
Ptr< const Vector< Real > > getConstVector(const Vector< Real > &x) const
void setParameter(const std::vector< Real > &param)
void gradient(Vector< Real > &g, const Vector< Real > &x, Real &tol)
Compute gradient.
Real value(const Vector< Real > &x, Real &tol)
Compute value.
PH_RiskObjective(const Ptr< Objective< Real > > &obj, ParameterList &parlist)
Ptr< Vector< Real > > g_
void getGradient(const Vector< Real > &x, Real &tol)
void getValue(const Vector< Real > &x, Real &tol)
Ptr< std::vector< Real > > getStat(Vector< Real > &x) const
void hessVec(Vector< Real > &hv, const Vector< Real > &v, const Vector< Real > &x, Real &tol)
Apply Hessian approximation to vector.
Ptr< Vector< Real > > getVector(Vector< Real > &x) const
Ptr< std::vector< Real > > getStatistic(const int comp=0, const int index=0)
Ptr< const Vector< Real > > getVector(void) const
Defines the linear algebra or vector space interface.
virtual const Vector & dual() const
Return the dual representation of *this, for example, the result of applying a Riesz map, or change of basis,...
virtual ROL::Ptr< Vector > clone() const =0
Clone to make a new (uninitialized) vector.
ERiskMeasure StringToERiskMeasure(std::string s)
@ RISKMEASURE_GENMOREAUYOSIDACVAR
@ RISKMEASURE_SMOOTHEDWORSTCASE