OOPS
JqTermTLAD.h
Go to the documentation of this file.
1 /*
2  * (C) Copyright 2009-2016 ECMWF.
3  *
4  * This software is licensed under the terms of the Apache Licence Version 2.0
5  * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
6  * In applying this licence, ECMWF does not waive the privileges and immunities
7  * granted to it by virtue of its status as an intergovernmental organisation nor
8  * does it submit to any jurisdiction.
9  */
10 
11 #ifndef OOPS_ASSIMILATION_JQTERMTLAD_H_
12 #define OOPS_ASSIMILATION_JQTERMTLAD_H_
13 
#include <cassert>
#include <memory>
#include <vector>

#include "oops/base/Increment.h"
#include "oops/base/PostBaseTLAD.h"
#include "oops/base/State.h"
#include "oops/mpi/mpi.h"
#include "oops/util/DateTime.h"
#include "oops/util/Duration.h"
23 
24 namespace oops {
25 
26 // -----------------------------------------------------------------------------
27 
28 template <typename MODEL>
29 class JqTermTLAD : public PostBaseTLAD<MODEL> {
32 
33  public:
34  explicit JqTermTLAD(const eckit::mpi::Comm &);
36 
37  void clear() {xi_.reset();}
38 // void computeModelErrorTraj(const State_ &, Increment_ &); // not used
39  State_ & getMxi() const;
41 
42  void setupAD(const Increment_ & dx);
43 
44  private:
45  void doInitializeTraj(const State_ &, const util::DateTime &,
46  const util::Duration &) override {}
47  void doProcessingTraj(const State_ &) override {}
48  void doFinalizeTraj(const State_ &) override;
49 
50  void doInitializeTL(const Increment_ &, const util::DateTime &,
51  const util::Duration &) override {}
52  void doProcessingTL(const Increment_ &) override {}
53  void doFinalizeTL(const Increment_ &) override;
54 
55  void doFirstAD(Increment_ &, const util::DateTime &, const util::Duration &) override;
56  void doProcessingAD(Increment_ &) override {}
57  void doLastAD(Increment_ &) override {}
58 
59  const eckit::mpi::Comm & commTime_;
60  std::unique_ptr<State_> xtraj_;
61  std::unique_ptr<Increment_> mxi_;
62  std::unique_ptr<Increment_> xi_;
63 };
64 
65 // =============================================================================
66 
67 template <typename MODEL>
68 JqTermTLAD<MODEL>::JqTermTLAD(const eckit::mpi::Comm & comm)
69  : commTime_(comm), xtraj_(), mxi_(), xi_()
70 {
71  Log::trace() << "JqTermTLAD::JqTermTLAD" << std::endl;
72 }
73 
74 // -----------------------------------------------------------------------------
75 
76 template <typename MODEL>
78  Log::trace() << "JqTermTLAD::doFinalizeTraj start" << std::endl;
79  xtraj_.reset(new State_(xx));
80  Log::trace() << "JqTermTLAD::doFinalizeTraj done" << std::endl;
81 }
82 
83 // -----------------------------------------------------------------------------
84 
85 /*
86 template <typename MODEL>
87 void JqTermTLAD<MODEL>::computeModelErrorTraj(const State_ & fg, Increment_ & dx) {
88  Log::trace() << "JqTermTLAD::computeModelErrorTraj start" << std::endl;
89 
90  static int tag = 83655;
91  size_t mytime = commTime_.rank();
92 // Send values of M(x_i) at end of my subwindow to next subwindow
93  if (mytime + 1 < commTime_.size()) {
94  Log::debug() << "JqTermTLAD::computeModelErrorTraj: sending to " << mytime+1
95  << " " << tag << std::endl;
96  oops::mpi::send(commTime_, fg, mytime+1, tag);
97  Log::debug() << "JqTermTLAD::computeModelErrorTraj: sent to " << mytime+1
98  << " " << tag << std::endl;
99  }
100 
101 // Receive values at beginning of my subwindow from previous subwindow
102  if (mytime > 0) {
103  State_ mxi(fg);
104  Log::debug() << "JqTermTLAD::computeModelErrorTraj: receiving from " << mytime-1
105  << " " << tag << std::endl;
106  oops::mpi::receive(commTime_, mxi, mytime-1, tag);
107  Log::debug() << "JqTermTLAD::computeModelErrorTraj: received from " << mytime-1
108  << " " << tag << std::endl;
109 
110 // Compute x_i - M(x_{i-1})
111  dx.diff(fg, mxi);
112  }
113  ++tag;
114  Log::trace() << "JqTermTLAD::computeModelErrorTraj done" << std::endl;
115 }
116 */
117 
118 // -----------------------------------------------------------------------------
119 
120 template <typename MODEL>
122  Log::trace() << "JqTermTLAD::getMxi" << std::endl;
123 // Retrieve M(x-i)
124  return *xtraj_;
125 }
126 
127 // -----------------------------------------------------------------------------
128 
129 template <typename MODEL>
131  Log::trace() << "JqTermTLAD::doFinalizeTL start" << std::endl;
132  Log::debug() << "JqTermTLAD::doFinalizeTL MPI size " << commTime_.size() << std::endl;
133  Log::debug() << "JqTermTLAD::doFinalizeTL MPI rank " << commTime_.rank() << std::endl;
134  size_t mytime = commTime_.rank();
135  if (mytime + 1 < commTime_.size()) oops::mpi::send(commTime_, dx, mytime+1, 2468);
136  Log::trace() << "JqTermTLAD::doFinalizeTL done" << std::endl;
137 }
138 
139 // -----------------------------------------------------------------------------
140 
141 template <typename MODEL>
143  Log::trace() << "JqTermTLAD::computeModelErrorTL start" << std::endl;
144 // Compute x_i - M(x_{i-1})
145  Log::debug() << "JqTermTLAD::computeModelErrorTL MPI size " << commTime_.size() << std::endl;
146  Log::debug() << "JqTermTLAD::computeModelErrorTL MPI rank " << commTime_.rank() << std::endl;
147  size_t mytime = commTime_.rank();
148  if (mytime > 0) {
149  Increment_ mxim1(dx, false);
150  oops::mpi::receive(commTime_, mxim1, mytime-1, 2468);
151  dx -= mxim1;
152  }
153  Log::info() << "JqTermTLAD: x_i - M(x_i)" << dx << std::endl;
154  Log::trace() << "JqTermTLAD::computeModelErrorTL done" << std::endl;
155 }
156 
157 // -----------------------------------------------------------------------------
158 
159 template <typename MODEL>
161  Log::trace() << "JqTermTLAD::setupAD start" << std::endl;
162  size_t mytime = commTime_.rank();
163  if (mytime > 0) oops::mpi::send(commTime_, dx, mytime-1, 8642);
164  Log::trace() << "JqTermTLAD::setupAD done" << std::endl;
165 }
166 
167 // -----------------------------------------------------------------------------
168 
169 template <typename MODEL>
170 void JqTermTLAD<MODEL>::doFirstAD(Increment_ & dx, const util::DateTime &,
171  const util::Duration &) {
172  Log::trace() << "JqTermTLAD::doFirstAD start" << std::endl;
173  size_t mytime = commTime_.rank();
174  if (mytime + 1 < commTime_.size()) {
175  Increment_ xip1(dx, false);
176  oops::mpi::receive(commTime_, xip1, mytime+1, 8642);
177  dx -= xip1;
178  }
179  Log::trace() << "JqTermTLAD::doFirstAD done" << std::endl;
180 }
181 
182 // -----------------------------------------------------------------------------
183 
184 } // namespace oops
185 
186 #endif // OOPS_ASSIMILATION_JQTERMTLAD_H_
Increment class used in oops.
void computeModelErrorTL(Increment_ &)
Definition: JqTermTLAD.h:142
void doProcessingTraj(const State_ &) override
Definition: JqTermTLAD.h:47
JqTermTLAD(const eckit::mpi::Comm &)
Definition: JqTermTLAD.h:68
void doFinalizeTraj(const State_ &) override
Definition: JqTermTLAD.h:77
void setupAD(const Increment_ &dx)
Definition: JqTermTLAD.h:160
std::unique_ptr< Increment_ > mxi_
Definition: JqTermTLAD.h:61
std::unique_ptr< State_ > xtraj_
Definition: JqTermTLAD.h:60
void doFirstAD(Increment_ &, const util::DateTime &, const util::Duration &) override
Definition: JqTermTLAD.h:170
void doProcessingTL(const Increment_ &) override
Definition: JqTermTLAD.h:52
Increment< MODEL > Increment_
Definition: JqTermTLAD.h:30
State< MODEL > State_
Definition: JqTermTLAD.h:31
void doLastAD(Increment_ &) override
Definition: JqTermTLAD.h:57
const eckit::mpi::Comm & commTime_
Definition: JqTermTLAD.h:59
std::unique_ptr< Increment_ > xi_
Definition: JqTermTLAD.h:62
void doInitializeTraj(const State_ &, const util::DateTime &, const util::Duration &) override
Definition: JqTermTLAD.h:45
void doInitializeTL(const Increment_ &, const util::DateTime &, const util::Duration &) override
Definition: JqTermTLAD.h:50
State_ & getMxi() const
Definition: JqTermTLAD.h:121
void doFinalizeTL(const Increment_ &) override
Definition: JqTermTLAD.h:130
void doProcessingAD(Increment_ &) override
Definition: JqTermTLAD.h:56
Handles post-processing of model fields related to cost function.
Definition: PostBaseTLAD.h:41
State class used in oops; subclass of interface class interface::State.
void send(const eckit::mpi::Comm &comm, const SERIALIZABLE &sendobj, const int dest, const int tag)
Extend eckit Comm for Serializable oops objects.
Definition: oops/mpi/mpi.h:44
void receive(const eckit::mpi::Comm &comm, SERIALIZABLE &recvobj, const int source, const int tag)
Definition: oops/mpi/mpi.h:55
The namespace for the main oops code.