11 #ifndef OOPS_BASE_WEIGHTEDDIFFTLAD_H_
12 #define OOPS_BASE_WEIGHTEDDIFFTLAD_H_
28 #include "oops/util/DateTime.h"
29 #include "oops/util/Duration.h"
45 template <
typename MODEL>
59 void setupAD(std::shared_ptr<const Increment_>);
63 const util::DateTime &,
const util::Duration &)
override;
68 const util::DateTime &,
const util::Duration &)
override;
81 std::unique_ptr<Accumulator<MODEL, Increment_, Increment_>>
avg_;
94 template <
typename MODEL>
96 const util::DateTime & vt,
97 const util::Duration & span,
98 const util::Duration & tstep,
102 vars_(vars), wfct_(wfct), wdiff_(vars, vt, span, tstep, resol, wfct_),
103 weights_(), forcing_(), avg_(), sum_(0.0), linit_(false),
104 vtime_(vt), bgn_(vt-span/2), end_(vt+span/2), tstep_(tstep),
107 Log::trace() <<
"WeightedDiffTLAD::WeightedDiffTLAD" << std::endl;
112 template <
typename MODEL>
114 const util::DateTime & end,
const util::Duration & tstep) {
115 Log::trace() <<
"WeightedDiffTLAD::doInitializeTraj start" << std::endl;
116 wdiff_.initialize(xx, end, tstep);
117 Log::trace() <<
"WeightedDiffTLAD::doInitializeTraj done" << std::endl;
122 template <
typename MODEL>
124 Log::trace() <<
"WeightedDiffTLAD::doProcessingTraj start" << std::endl;
126 Log::trace() <<
"WeightedDiffTLAD::doProcessingTraj done" << std::endl;
131 template <
typename MODEL>
133 Log::trace() <<
"WeightedDiffTLAD::doFinalizeTraj start" << std::endl;
135 Log::trace() <<
"WeightedDiffTLAD::doFinalizeTraj done" << std::endl;
140 template <
typename MODEL>
142 Log::trace() <<
"WeightedDiffTLAD::setupTL start" << std::endl;
144 Log::trace() <<
"WeightedDiffTLAD::setupTL done" << std::endl;
149 template <
typename MODEL>
151 const util::DateTime & end,
152 const util::Duration & tstep) {
153 Log::trace() <<
"WeightedDiffTLAD::doInitializeTL start" << std::endl;
154 const util::DateTime bgn(dx.
validTime());
156 if (!linit_ && bgn <= end_ && end >= bgn_) {
157 if (tstep_ == util::Duration(0)) tstep_ = tstep;
158 ASSERT(tstep_ > util::Duration(0));
159 weights_ = wfct_.setWeights(bgn_, end_, tstep_);
161 ASSERT(weights_.find(vtime_) != weights_.end());
162 weights_[vtime_] -= 1.0;
166 Log::trace() <<
"WeightedDiffTLAD::doInitializeTL done" << std::endl;
171 template <
typename MODEL>
173 Log::trace() <<
"WeightedDiffTLAD::doProcessingTL start" << std::endl;
174 const util::DateTime now(xx.
validTime());
175 if (((bgnleg_ < end_ && endleg_ > bgn_) || bgnleg_ == endleg_) &&
176 (now != endleg_ || now == end_ || now == bgnleg_)) {
177 ASSERT(weights_.find(now) != weights_.end());
178 const double zz = weights_[now];
179 avg_->axpy(zz, xx,
false);
182 Log::trace() <<
"WeightedDiffTLAD::doProcessingTL done" << std::endl;
187 template <
typename MODEL>
189 Log::trace() <<
"WeightedDiffTLAD::finalTL start" << std::endl;
191 ASSERT(std::abs(sum_) < 1.0e-8);
193 Log::trace() <<
"WeightedDiffTLAD::finalTL done" << std::endl;
198 template <
typename MODEL>
200 Log::trace() <<
"WeightedDiffTLAD::setupAD start" << std::endl;
202 Log::trace() <<
"WeightedDiffTLAD::setupAD done" << std::endl;
208 template <
typename MODEL>
210 const util::DateTime & bgn,
211 const util::Duration & tstep) {
212 Log::trace() <<
"WeightedDiffTLAD::doFirstAD start" << std::endl;
213 const util::DateTime end(dx.
validTime());
215 if (!linit_ && bgn <= end_ && end >= bgn_) {
216 if (tstep_ == util::Duration(0)) tstep_ = tstep;
217 ASSERT(tstep_ > util::Duration(0));
218 weights_ = wfct_.setWeights(bgn_, end_, tstep_);
220 ASSERT(weights_.find(vtime_) != weights_.end());
221 weights_[vtime_] -= 1.0;
225 Log::trace() <<
"WeightedDiffTLAD::doFirstAD done" << std::endl;
230 template <
typename MODEL>
232 Log::trace() <<
"WeightedDiffTLAD::doProcessingAD start" << std::endl;
234 const util::DateTime now(dx.
validTime());
235 if (((bgnleg_ < end_ && endleg_ > bgn_) || bgnleg_ == endleg_) &&
236 (now != endleg_ || now == end_ || now == bgnleg_)) {
237 ASSERT(weights_.find(now) != weights_.end());
238 const double zz = weights_[now];
239 dx.
axpy(zz, *forcing_,
false);
242 Log::trace() <<
"WeightedDiffTLAD::doProcessingAD done" << std::endl;
Geometry class used in oops; subclass of interface class interface::Geometry.
Increment class used in oops.
Handles post-processing of model fields related to cost function.
State class used in oops; subclass of interface class interface::State.
Compute time average of states or increments during model run.
Compute time average of states or increments during linear model run.
std::shared_ptr< const Increment_ > forcing_
Increment< MODEL > Increment_
void doProcessingTL(const Increment_ &) override
void setupTL(const Geometry_ &)
void setupAD(std::shared_ptr< const Increment_ >)
const util::DateTime end_
Increment_ * releaseDiff()
void finalTL(Increment_ &)
std::unique_ptr< Accumulator< MODEL, Increment_, Increment_ > > avg_
void doInitializeTL(const Increment_ &, const util::DateTime &, const util::Duration &) override
void doProcessingTraj(const State_ &) override
void doProcessingAD(Increment_ &) override
const util::DateTime vtime_
WeightedDiff< MODEL, Increment_, State_ > wdiff_
virtual ~WeightedDiffTLAD()
void doInitializeTraj(const State_ &, const util::DateTime &, const util::Duration &) override
void doLastAD(Increment_ &) override
const util::DateTime bgn_
std::map< util::DateTime, double > weights_
void doFinalizeTraj(const State_ &) override
void doFirstAD(Increment_ &, const util::DateTime &, const util::Duration &) override
Geometry< MODEL > Geometry_
WeightedDiffTLAD(const Variables &, const util::DateTime &, const util::Duration &, const util::Duration &, const Geometry_ &, WeightingFct &)
void doFinalizeTL(const Increment_ &) override
void axpy(const double &w, const Increment &dx, const bool check=true)
const util::DateTime validTime() const
Accessor to the time of this Increment.
The namespace for the main oops code.