OOPS
WeightedDiffTLAD.h
Go to the documentation of this file.
1 /*
2  * (C) Copyright 2009-2016 ECMWF.
3  *
4  * This software is licensed under the terms of the Apache Licence Version 2.0
5  * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
6  * In applying this licence, ECMWF does not waive the privileges and immunities
7  * granted to it by virtue of its status as an intergovernmental organisation nor
8  * does it submit to any jurisdiction.
9  */
10 
11 #ifndef OOPS_BASE_WEIGHTEDDIFFTLAD_H_
12 #define OOPS_BASE_WEIGHTEDDIFFTLAD_H_
13 
14 #include <cmath>
15 #include <map>
16 #include <memory>
17 #include <utility>
18 
19 #include "oops/base/Accumulator.h"
21 #include "oops/base/Geometry.h"
22 #include "oops/base/Increment.h"
23 #include "oops/base/PostBaseTLAD.h"
24 #include "oops/base/State.h"
25 #include "oops/base/Variables.h"
26 #include "oops/base/WeightedDiff.h"
27 #include "oops/base/WeightingFct.h"
28 #include "oops/util/DateTime.h"
29 #include "oops/util/Duration.h"
30 
31 namespace oops {
32 
33 // -----------------------------------------------------------------------------
34 
35 /// Compute time average of states or increments during linear model run.
36 /*!
37  * Derived classes will compute different types of averages (plain
38  * mean, various types of digital filters) by overwriting the weights
39  * computation method.
40  *
41  * A lot of code here is common with WeightedDiff and even WeightedMean,
42  * the design could be improved to reduce code duplication. YT
43  */
44 
45 template <typename MODEL>
46 class WeightedDiffTLAD : public PostBaseTLAD<MODEL> {
50 
51  public:
52  WeightedDiffTLAD(const Variables &, const util::DateTime &, const util::Duration &,
53  const util::Duration &, const Geometry_ &, WeightingFct &);
54  virtual ~WeightedDiffTLAD() {}
55 
56  Increment_ * releaseDiff() {return wdiff_.releaseDiff();}
57  void setupTL(const Geometry_ &);
58  void finalTL(Increment_ &);
59  void setupAD(std::shared_ptr<const Increment_>);
60 
61  private:
62  void doInitializeTraj(const State_ &,
63  const util::DateTime &, const util::Duration &) override;
64  void doProcessingTraj(const State_ &) override;
65  void doFinalizeTraj(const State_ &) override;
66 
67  void doInitializeTL(const Increment_ &,
68  const util::DateTime &, const util::Duration &) override;
69  void doProcessingTL(const Increment_ &) override;
70  void doFinalizeTL(const Increment_ &) override {}
71 
72  void doFirstAD(Increment_ &, const util::DateTime &, const util::Duration &) override;
73  void doProcessingAD(Increment_ &) override;
74  void doLastAD(Increment_ &) override {}
75 
79  std::map< util::DateTime, double > weights_;
80  std::shared_ptr<const Increment_> forcing_;
81  std::unique_ptr<Accumulator<MODEL, Increment_, Increment_>> avg_;
82  double sum_;
83  bool linit_;
84  const util::DateTime vtime_;
85  const util::DateTime bgn_;
86  const util::DateTime end_;
87  util::Duration tstep_;
88  util::DateTime bgnleg_;
89  util::DateTime endleg_;
90 };
91 
92 // =============================================================================
93 
94 template <typename MODEL>
96  const util::DateTime & vt,
97  const util::Duration & span,
98  const util::Duration & tstep,
99  const Geometry_ & resol,
100  WeightingFct & wfct)
101  : PostBaseTLAD<MODEL>(vt-span/2, vt+span/2),
102  vars_(vars), wfct_(wfct), wdiff_(vars, vt, span, tstep, resol, wfct_),
103  weights_(), forcing_(), avg_(), sum_(0.0), linit_(false),
104  vtime_(vt), bgn_(vt-span/2), end_(vt+span/2), tstep_(tstep),
105  bgnleg_(), endleg_()
106 {
107  Log::trace() << "WeightedDiffTLAD::WeightedDiffTLAD" << std::endl;
108 }
109 
110 // -----------------------------------------------------------------------------
111 
112 template <typename MODEL>
114  const util::DateTime & end, const util::Duration & tstep) {
115  Log::trace() << "WeightedDiffTLAD::doInitializeTraj start" << std::endl;
116  wdiff_.initialize(xx, end, tstep);
117  Log::trace() << "WeightedDiffTLAD::doInitializeTraj done" << std::endl;
118 }
119 
120 // -----------------------------------------------------------------------------
121 
122 template <typename MODEL>
124  Log::trace() << "WeightedDiffTLAD::doProcessingTraj start" << std::endl;
125  wdiff_.process(xx);
126  Log::trace() << "WeightedDiffTLAD::doProcessingTraj done" << std::endl;
127 }
128 
129 // -----------------------------------------------------------------------------
130 
131 template <typename MODEL>
133  Log::trace() << "WeightedDiffTLAD::doFinalizeTraj start" << std::endl;
134  wdiff_.finalize(xx);
135  Log::trace() << "WeightedDiffTLAD::doFinalizeTraj done" << std::endl;
136 }
137 
138 // -----------------------------------------------------------------------------
139 
140 template <typename MODEL>
142  Log::trace() << "WeightedDiffTLAD::setupTL start" << std::endl;
143  avg_.reset(new Accumulator<MODEL, Increment_, Increment_>(resol, vars_, vtime_));
144  Log::trace() << "WeightedDiffTLAD::setupTL done" << std::endl;
145 }
146 
147 // -----------------------------------------------------------------------------
148 
149 template <typename MODEL>
151  const util::DateTime & end,
152  const util::Duration & tstep) {
153  Log::trace() << "WeightedDiffTLAD::doInitializeTL start" << std::endl;
154  const util::DateTime bgn(dx.validTime());
155  ASSERT(bgn <= end);
156  if (!linit_ && bgn <= end_ && end >= bgn_) {
157  if (tstep_ == util::Duration(0)) tstep_ = tstep;
158  ASSERT(tstep_ > util::Duration(0));
159  weights_ = wfct_.setWeights(bgn_, end_, tstep_);
160  linit_ = true;
161  ASSERT(weights_.find(vtime_) != weights_.end());
162  weights_[vtime_] -= 1.0;
163  }
164  bgnleg_ = bgn;
165  endleg_ = end;
166  Log::trace() << "WeightedDiffTLAD::doInitializeTL done" << std::endl;
167 }
168 
169 // -----------------------------------------------------------------------------
170 
171 template <typename MODEL>
173  Log::trace() << "WeightedDiffTLAD::doProcessingTL start" << std::endl;
174  const util::DateTime now(xx.validTime());
175  if (((bgnleg_ < end_ && endleg_ > bgn_) || bgnleg_ == endleg_) &&
176  (now != endleg_ || now == end_ || now == bgnleg_)) {
177  ASSERT(weights_.find(now) != weights_.end());
178  const double zz = weights_[now];
179  avg_->axpy(zz, xx, false);
180  sum_ += zz;
181  }
182  Log::trace() << "WeightedDiffTLAD::doProcessingTL done" << std::endl;
183 }
184 
185 // -----------------------------------------------------------------------------
186 
187 template <typename MODEL>
189  Log::trace() << "WeightedDiffTLAD::finalTL start" << std::endl;
190  ASSERT(linit_);
191  ASSERT(std::abs(sum_) < 1.0e-8);
192  out = *avg_;
193  Log::trace() << "WeightedDiffTLAD::finalTL done" << std::endl;
194 }
195 
196 // -----------------------------------------------------------------------------
197 
198 template <typename MODEL>
199 void WeightedDiffTLAD<MODEL>::setupAD(std::shared_ptr<const Increment_> forcing) {
200  Log::trace() << "WeightedDiffTLAD::setupAD start" << std::endl;
201  forcing_ = forcing;
202  Log::trace() << "WeightedDiffTLAD::setupAD done" << std::endl;
203 }
204 
205 // -----------------------------------------------------------------------------
206 
207 
208 template <typename MODEL>
210  const util::DateTime & bgn,
211  const util::Duration & tstep) {
212  Log::trace() << "WeightedDiffTLAD::doFirstAD start" << std::endl;
213  const util::DateTime end(dx.validTime());
214  ASSERT(bgn <= end);
215  if (!linit_ && bgn <= end_ && end >= bgn_) {
216  if (tstep_ == util::Duration(0)) tstep_ = tstep;
217  ASSERT(tstep_ > util::Duration(0));
218  weights_ = wfct_.setWeights(bgn_, end_, tstep_);
219  linit_ = true;
220  ASSERT(weights_.find(vtime_) != weights_.end());
221  weights_[vtime_] -= 1.0;
222  }
223  bgnleg_ = bgn;
224  endleg_ = end;
225  Log::trace() << "WeightedDiffTLAD::doFirstAD done" << std::endl;
226 }
227 
228 // -----------------------------------------------------------------------------
229 
230 template <typename MODEL>
232  Log::trace() << "WeightedDiffTLAD::doProcessingAD start" << std::endl;
233  ASSERT(forcing_);
234  const util::DateTime now(dx.validTime());
235  if (((bgnleg_ < end_ && endleg_ > bgn_) || bgnleg_ == endleg_) &&
236  (now != endleg_ || now == end_ || now == bgnleg_)) {
237  ASSERT(weights_.find(now) != weights_.end());
238  const double zz = weights_[now];
239  dx.axpy(zz, *forcing_, false);
240  sum_ += zz;
241  }
242  Log::trace() << "WeightedDiffTLAD::doProcessingAD done" << std::endl;
243 }
244 
245 // -----------------------------------------------------------------------------
246 
247 } // namespace oops
248 
249 #endif // OOPS_BASE_WEIGHTEDDIFFTLAD_H_
Geometry class used in oops; subclass of interface class interface::Geometry.
Increment class used in oops.
Handles post-processing of model fields related to cost function.
Definition: PostBaseTLAD.h:41
State class used in oops; subclass of interface class interface::State.
Compute time average of states or increments during model run.
Definition: WeightedDiff.h:39
Compute the weighted time difference (weighted mean minus the value at the central time) of increments during a linear model run.
std::shared_ptr< const Increment_ > forcing_
Increment< MODEL > Increment_
void doProcessingTL(const Increment_ &) override
void setupTL(const Geometry_ &)
void setupAD(std::shared_ptr< const Increment_ >)
const util::DateTime end_
Increment_ * releaseDiff()
void finalTL(Increment_ &)
std::unique_ptr< Accumulator< MODEL, Increment_, Increment_ > > avg_
void doInitializeTL(const Increment_ &, const util::DateTime &, const util::Duration &) override
void doProcessingTraj(const State_ &) override
void doProcessingAD(Increment_ &) override
const util::DateTime vtime_
WeightedDiff< MODEL, Increment_, State_ > wdiff_
void doInitializeTraj(const State_ &, const util::DateTime &, const util::Duration &) override
void doLastAD(Increment_ &) override
const util::DateTime bgn_
std::map< util::DateTime, double > weights_
void doFinalizeTraj(const State_ &) override
void doFirstAD(Increment_ &, const util::DateTime &, const util::Duration &) override
Geometry< MODEL > Geometry_
WeightedDiffTLAD(const Variables &, const util::DateTime &, const util::Duration &, const util::Duration &, const Geometry_ &, WeightingFct &)
void doFinalizeTL(const Increment_ &) override
Weighting Function.
Definition: WeightingFct.h:31
void axpy(const double &w, const Increment &dx, const bool check=true)
const util::DateTime validTime() const
Accessor to the time of this Increment.
The namespace for the main oops code.