79 TEMPUS_FUNC_TIME_MONITOR_DIFF(
80 "Tempus::IntegratorAdjointSensitivity::advanceTime()", TEMPUS_AS_AT);
83 using Teuchos::rcp_dynamic_cast;
85 using Thyra::createMember;
86 using Thyra::createMembers;
94 using Thyra::VectorSpaceBase;
95 typedef Thyra::ModelEvaluatorBase MEB;
96 typedef Thyra::DefaultMultiVectorProductVector<Scalar> DMVPV;
97 typedef Thyra::DefaultProductVector<Scalar> DPV;
100 RCP<const SolutionHistory<Scalar>> state_solution_history =
101 state_integrator_->getSolutionHistory();
102 RCP<const SolutionState<Scalar>> initial_state = (*state_solution_history)[0];
105 bool state_status =
true;
107 TEMPUS_FUNC_TIME_MONITOR_DIFF(
108 "Tempus::IntegratorAdjointSensitivity::advanceTime::state",
111 state_status = state_integrator_->advanceTime(timeFinal);
118 adjoint_aux_model_->setFinalTime(state_integrator_->getTime());
121 adjoint_aux_model_->setForwardSolutionHistory(state_solution_history);
124 RCP<const VectorSpaceBase<Scalar>> g_space = model_->get_g_space(g_index_);
125 RCP<const VectorSpaceBase<Scalar>> x_space = model_->get_x_space();
126 const int num_g = g_space->dim();
127 RCP<MultiVectorBase<Scalar>> dgdx = createMembers(x_space, num_g);
128 MEB::InArgs<Scalar> inargs = model_->getNominalValues();
129 RCP<const SolutionState<Scalar>> state =
130 state_solution_history->getCurrentState();
131 inargs.set_t(state->getTime());
132 inargs.set_x(state->getX());
133 inargs.set_x_dot(state->getXDot());
134 MEB::OutArgs<Scalar> outargs = model_->createOutArgs();
135 MEB::OutArgs<Scalar> adj_outargs = adjoint_model_->createOutArgs();
136 outargs.set_DgDx(g_index_,
137 MEB::Derivative<Scalar>(dgdx, MEB::DERIV_MV_GRADIENT_FORM));
138 model_->evalModel(inargs, outargs);
139 outargs.set_DgDx(g_index_, MEB::Derivative<Scalar>());
145 RCP<DPV> adjoint_init = rcp_dynamic_cast<DPV>(
146 Thyra::createMember(adjoint_aux_model_->get_x_space()));
147 RCP<MultiVectorBase<Scalar>> adjoint_init_mv =
148 rcp_dynamic_cast<DMVPV>(adjoint_init->getNonconstVectorBlock(0))
149 ->getNonconstMultiVector();
150 assign(adjoint_init->getNonconstVectorBlock(1).ptr(),
151 Teuchos::ScalarTraits<Scalar>::zero());
152 if (mass_matrix_is_identity_)
153 assign(adjoint_init_mv.ptr(), *dgdx);
155 inargs.set_alpha(1.0);
156 inargs.set_beta(0.0);
157 RCP<LinearOpWithSolveBase<Scalar>> W;
158 if (adj_outargs.supports(MEB::OUT_ARG_W)) {
160 W = adjoint_model_->create_W();
161 adj_outargs.set_W(W);
162 adjoint_model_->evalModel(inargs, adj_outargs);
163 adj_outargs.set_W(Teuchos::null);
167 RCP<const LinearOpWithSolveFactoryBase<Scalar>> lowsfb =
168 adjoint_model_->get_W_factory();
169 TEUCHOS_TEST_FOR_EXCEPTION(lowsfb == Teuchos::null, std::logic_error,
170 "Adjoint ME must support W out-arg or provide "
171 "a W_factory for non-identity mass matrix");
174 RCP<LinearOpBase<Scalar>> W_op = adjoint_model_->create_W_op();
175 adj_outargs.set_W_op(W_op);
176 RCP<PreconditionerFactoryBase<Scalar>> prec_factory =
177 lowsfb->getPreconditionerFactory();
178 RCP<PreconditionerBase<Scalar>> W_prec;
179 if (prec_factory != Teuchos::null)
180 W_prec = prec_factory->createPrec();
181 else if (adj_outargs.supports(MEB::OUT_ARG_W_prec)) {
182 W_prec = adjoint_model_->create_W_prec();
183 adj_outargs.set_W_prec(W_prec);
185 adjoint_model_->evalModel(inargs, adj_outargs);
186 adj_outargs.set_W_op(Teuchos::null);
187 if (adj_outargs.supports(MEB::OUT_ARG_W_prec))
188 adj_outargs.set_W_prec(Teuchos::null);
191 W = lowsfb->createOp();
192 if (W_prec != Teuchos::null) {
193 if (prec_factory != Teuchos::null)
194 prec_factory->initializePrec(
195 Thyra::defaultLinearOpSource<Scalar>(W_op), W_prec.get());
196 Thyra::initializePreconditionedOp<Scalar>(*lowsfb, W_op, W_prec,
200 Thyra::initializeOp<Scalar>(*lowsfb, W_op, W.ptr());
202 TEUCHOS_TEST_FOR_EXCEPTION(
203 W == Teuchos::null, std::logic_error,
204 "A null W has been encountered in "
205 "Tempus::IntegratorAdjointSensitivity::advanceTime!\n");
208 assign(adjoint_init_mv.ptr(), Teuchos::ScalarTraits<Scalar>::zero());
209 W->solve(Thyra::NOTRANS, *dgdx, adjoint_init_mv.ptr());
213 bool sens_status =
true;
215 TEMPUS_FUNC_TIME_MONITOR_DIFF(
216 "Tempus::IntegratorAdjointSensitivity::advanceTime::adjoint",
220 adjoint_integrator_->getTimeStepControl()->getInitTime();
221 adjoint_integrator_->initializeSolutionHistory(tinit, adjoint_init);
222 sens_status = adjoint_integrator_->advanceTime(timeFinal);
224 RCP<const SolutionHistory<Scalar>> adjoint_solution_history =
225 adjoint_integrator_->getSolutionHistory();
228 RCP<const VectorSpaceBase<Scalar>> p_space = model_->get_p_space(p_index_);
229 dgdp_ = createMembers(p_space, num_g);
230 if (g_depends_on_p_) {
231 MEB::DerivativeSupport dgdp_support =
232 outargs.supports(MEB::OUT_ARG_DgDp, g_index_, p_index_);
233 if (dgdp_support.supports(MEB::DERIV_MV_GRADIENT_FORM)) {
236 MEB::Derivative<Scalar>(dgdp_, MEB::DERIV_MV_GRADIENT_FORM));
237 model_->evalModel(inargs, outargs);
239 else if (dgdp_support.supports(MEB::DERIV_MV_JACOBIAN_FORM)) {
240 const int num_p = p_space->dim();
241 RCP<MultiVectorBase<Scalar>> dgdp_trans = createMembers(g_space, num_p);
244 MEB::Derivative<Scalar>(dgdp_trans, MEB::DERIV_MV_JACOBIAN_FORM));
245 model_->evalModel(inargs, outargs);
246 Thyra::DetachedMultiVectorView<Scalar> dgdp_view(*dgdp_);
247 Thyra::DetachedMultiVectorView<Scalar> dgdp_trans_view(*dgdp_trans);
248 for (
int i = 0; i < num_p; ++i)
249 for (
int j = 0; j < num_g; ++j) dgdp_view(i, j) = dgdp_trans_view(j, i);
252 TEUCHOS_TEST_FOR_EXCEPTION(
true, std::logic_error,
253 "Invalid dg/dp support");
254 outargs.set_DgDp(g_index_, p_index_, MEB::Derivative<Scalar>());
257 assign(dgdp_.ptr(), Scalar(0.0));
261 if (ic_depends_on_p_ && dxdp_init_ != Teuchos::null) {
262 RCP<const SolutionState<Scalar>> adjoint_state =
263 adjoint_solution_history->getCurrentState();
264 RCP<const VectorBase<Scalar>> adjoint_x =
265 rcp_dynamic_cast<const DPV>(adjoint_state->getX())->getVectorBlock(0);
266 RCP<const MultiVectorBase<Scalar>> adjoint_mv =
267 rcp_dynamic_cast<const DMVPV>(adjoint_x)->getMultiVector();
268 if (mass_matrix_is_identity_)
269 dxdp_init_->apply(Thyra::CONJTRANS, *adjoint_mv, dgdp_.ptr(), Scalar(1.0),
272 inargs.set_t(initial_state->getTime());
273 inargs.set_x(initial_state->getX());
274 inargs.set_x_dot(initial_state->getXDot());
275 inargs.set_alpha(1.0);
276 inargs.set_beta(0.0);
277 RCP<LinearOpBase<Scalar>> W_op = adjoint_model_->create_W_op();
278 adj_outargs.set_W_op(W_op);
279 adjoint_model_->evalModel(inargs, adj_outargs);
280 adj_outargs.set_W_op(Teuchos::null);
281 RCP<MultiVectorBase<Scalar>> tmp = createMembers(x_space, num_g);
282 W_op->apply(Thyra::NOTRANS, *adjoint_mv, tmp.ptr(), Scalar(1.0),
284 dxdp_init_->apply(Thyra::CONJTRANS, *tmp, dgdp_.ptr(), Scalar(1.0),
292 if (f_depends_on_p_) {
293 RCP<const SolutionState<Scalar>> adjoint_state =
294 adjoint_solution_history->getCurrentState();
295 RCP<const VectorBase<Scalar>> z =
296 rcp_dynamic_cast<const DPV>(adjoint_state->getX())->getVectorBlock(1);
297 RCP<const MultiVectorBase<Scalar>> z_mv =
298 rcp_dynamic_cast<const DMVPV>(z)->getMultiVector();
299 Thyra::V_VmV(dgdp_.ptr(), *dgdp_, *z_mv);
302 buildSolutionHistory(state_solution_history, adjoint_solution_history);
304 return state_status && sens_status;
564 using Teuchos::ParameterList;
567 using Teuchos::rcp_dynamic_cast;
569 using Thyra::createMembers;
571 using Thyra::multiVectorProductVector;
573 using Thyra::VectorSpaceBase;
574 typedef Thyra::DefaultProductVectorSpace<Scalar> DPVS;
575 typedef Thyra::DefaultProductVector<Scalar> DPV;
577 RCP<const VectorSpaceBase<Scalar>> x_space = model_->get_x_space();
578 RCP<const VectorSpaceBase<Scalar>> adjoint_space =
579 rcp_dynamic_cast<const DPVS>(adjoint_aux_model_->get_x_space())
581 Teuchos::Array<RCP<const VectorSpaceBase<Scalar>>> spaces(2);
583 spaces[1] = adjoint_space;
584 RCP<const DPVS> prod_space = Thyra::productVectorSpace(spaces());
586 int num_states = state_solution_history->getNumStates();
587 const Scalar t_init = state_integrator_->getTimeStepControl()->getInitTime();
588 const Scalar t_final = state_integrator_->getTime();
589 for (
int i = 0; i < num_states; ++i) {
590 RCP<const SolutionState<Scalar>> forward_state =
591 (*state_solution_history)[i];
592 RCP<const SolutionState<Scalar>> adjoint_state =
593 adjoint_solution_history->findState(t_final + t_init -
594 forward_state->getTime());
597 RCP<DPV> x = Thyra::defaultProductVector(prod_space);
598 RCP<const VectorBase<Scalar>> adjoint_x =
599 rcp_dynamic_cast<const DPV>(adjoint_state->getX())->getVectorBlock(0);
600 assign(x->getNonconstVectorBlock(0).ptr(), *(forward_state->getX()));
601 assign(x->getNonconstVectorBlock(1).ptr(), *(adjoint_x));
602 RCP<VectorBase<Scalar>> x_b = x;
605 RCP<DPV> x_dot = Thyra::defaultProductVector(prod_space);
606 RCP<const VectorBase<Scalar>> adjoint_x_dot =
607 rcp_dynamic_cast<const DPV>(adjoint_state->getXDot())
609 assign(x_dot->getNonconstVectorBlock(0).ptr(), *(forward_state->getXDot()));
610 assign(x_dot->getNonconstVectorBlock(1).ptr(), *(adjoint_x_dot));
611 RCP<VectorBase<Scalar>> x_dot_b = x_dot;
615 if (forward_state->getXDotDot() != Teuchos::null) {
616 x_dot_dot = Thyra::defaultProductVector(prod_space);
617 RCP<const VectorBase<Scalar>> adjoint_x_dot_dot =
618 rcp_dynamic_cast<const DPV>(adjoint_state->getXDotDot())
620 assign(x_dot_dot->getNonconstVectorBlock(0).ptr(),
621 *(forward_state->getXDotDot()));
622 assign(x_dot_dot->getNonconstVectorBlock(1).ptr(), *(adjoint_x_dot_dot));
624 RCP<VectorBase<Scalar>> x_dot_dot_b = x_dot_dot;
626 RCP<SolutionState<Scalar>> prod_state = forward_state->clone();
627 prod_state->setX(x_b);
628 prod_state->setXDot(x_dot_b);
629 prod_state->setXDotDot(x_dot_dot_b);
630 prod_state->setPhysicsState(Teuchos::null);
631 solutionHistory_->addState(prod_state);
639 Teuchos::RCP<Teuchos::ParameterList> inputPL,
644 Teuchos::RCP<Teuchos::ParameterList> spl = Teuchos::parameterList();
645 if (inputPL != Teuchos::null) *spl = inputPL->sublist(
"Sensitivities");
647 int p_index = spl->get<
int>(
"Sensitivity Parameter Index", 0);
648 int g_index = spl->get<
int>(
"Response Function Index", 0);
649 bool g_depends_on_p = spl->get<
bool>(
"Response Depends on Parameters",
true);
650 bool f_depends_on_p = spl->get<
bool>(
"Residual Depends on Parameters",
true);
651 bool ic_depends_on_p = spl->get<
bool>(
"IC Depends on Parameters",
true);
652 bool mass_matrix_is_identity =
653 spl->get<
bool>(
"Mass Matrix Is Identity",
false);
655 auto state_integrator = createIntegratorBasic<Scalar>(inputPL, model);
658 if (spl->isParameter(
"Response Depends on Parameters"))
659 spl->remove(
"Response Depends on Parameters");
660 if (spl->isParameter(
"Residual Depends on Parameters"))
661 spl->remove(
"Residual Depends on Parameters");
662 if (spl->isParameter(
"IC Depends on Parameters"))
663 spl->remove(
"IC Depends on Parameters");
665 const Scalar tinit = state_integrator->getTimeStepControl()->getInitTime();
666 const Scalar tfinal = state_integrator->getTimeStepControl()->getFinalTime();
671 Teuchos::RCP<Thyra::ModelEvaluator<Scalar>> adjt_model = adjoint_model;
672 if (adjoint_model == Teuchos::null)
675 auto adjoint_aux_model =
677 model, adjt_model, tinit, tfinal, spl));
681 auto integrator_name = inputPL->get<std::string>(
"Integrator Name");
682 auto integratorPL = Teuchos::sublist(inputPL, integrator_name,
true);
683 auto shPL = Teuchos::sublist(integratorPL,
"Solution History",
true);
684 auto combined_solution_History = createSolutionHistoryPL<Scalar>(shPL);
686 auto adjoint_integrator =
687 createIntegratorBasic<Scalar>(inputPL, adjoint_aux_model);
689 Teuchos::RCP<IntegratorAdjointSensitivity<Scalar>> integrator =
691 model, state_integrator, adjt_model, adjoint_aux_model,
692 adjoint_integrator, combined_solution_History, p_index, g_index,
693 g_depends_on_p, f_depends_on_p, ic_depends_on_p,
694 mass_matrix_is_identity));