OOPS
test/mpi/mpi.h
Go to the documentation of this file.
1 /*
2  * (C) Copyright 2020 Met Office UK
3  *
4  * This software is licensed under the terms of the Apache Licence Version 2.0
5  * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
6  */
7 
8 #ifndef TEST_MPI_MPI_H_
9 #define TEST_MPI_MPI_H_
10 
11 #include <Eigen/Dense>
12 
13 #include <string>
14 #include <vector>
15 
16 #include "eckit/config/LocalConfiguration.h"
17 #include "eckit/mpi/Comm.h"
18 #include "eckit/testing/Test.h"
19 
21 #include "oops/mpi/mpi.h"
22 #include "oops/runs/Test.h"
23 #include "oops/util/DateTime.h"
24 #include "oops/util/Expect.h"
25 #include "oops/util/parameters/Parameters.h"
26 #include "oops/util/parameters/RequiredParameter.h"
27 
28 namespace eckit
29 {
30  // Don't use the contracted output for these types: the current implementation works only
31  // with integer types.
32  template <> struct VectorPrintSelector<util::DateTime> { typedef VectorPrintSimple selector; };
33 } // namespace eckit
34 
35 namespace test {
36 
37 // -----------------------------------------------------------------------------------------------
38 class TestParameters : public oops::Parameters {
39  OOPS_CONCRETE_PARAMETERS(TestParameters, Parameters)
40  public:
41  oops::RequiredParameter<std::vector<util::DateTime>> datetime{"datetime", this};
42  oops::RequiredParameter<std::vector<int>> int_{"int", this};
43  oops::RequiredParameter<std::vector<std::string>> string{"string", this};
44 };
45 
46 // -----------------------------------------------------------------------------------------------
47 template <typename T>
48 const std::vector<T> & getTestData(const TestParameters &params);
49 
50 template <>
51 const std::vector<util::DateTime> & getTestData(const TestParameters &params) {
52  return params.datetime;
53 }
54 
55 template <>
56 const std::vector<int> & getTestData(const TestParameters &params) {
57  return params.int_;
58 }
59 
60 template <>
61 const std::vector<std::string> & getTestData(const TestParameters &params) {
62  return params.string;
63 }
64 
65 // -----------------------------------------------------------------------------------------------
66 template <typename T>
68  const eckit::Configuration &conf = TestEnvironment::config();
69  const eckit::mpi::Comm &comm = oops::mpi::world();
70 
71  TestParameters localParams;
72  const size_t rank = comm.rank();
73  localParams.deserialize(conf.getSubConfiguration("local" + std::to_string(rank)));
74  std::vector<T> values = getTestData<T>(localParams);
75 
76  TestParameters globalParams;
77  globalParams.deserialize(conf.getSubConfiguration("global"));
78  const std::vector<T> &expectedResult = getTestData<T>(globalParams);
79 
80  oops::mpi::allGatherv(comm, values);
81  EXPECT_EQUAL(values, expectedResult);
82 }
83 
84 // -----------------------------------------------------------------------------------------------
85 CASE("mpi/mpi/defaultCommunicators") {
86  const eckit::mpi::Comm & world = oops::mpi::world();
87  size_t worldsize = world.size();
88  EXPECT_EQUAL(worldsize, 4);
89 
90  const eckit::mpi::Comm & talk_to_myself = oops::mpi::myself();
91  size_t myownsize = talk_to_myself.size();
92  EXPECT_EQUAL(myownsize, 1);
93 }
94 // -----------------------------------------------------------------------------------------------
95 CASE("mpi/mpi/allGathervUsingSerialize") {
96  const eckit::Configuration &conf = TestEnvironment::config();
97  const eckit::mpi::Comm &comm = oops::mpi::world();
98 
99  TestParameters localParams;
100  const size_t rank = comm.rank();
101  localParams.deserialize(conf.getSubConfiguration("local" + std::to_string(rank)));
102  const std::vector<util::DateTime> &localValues = localParams.datetime;
103 
104  TestParameters globalParams;
105  globalParams.deserialize(conf.getSubConfiguration("global"));
106  const std::vector<util::DateTime> &expectedGlobalValues = globalParams.datetime;
107 
108  size_t numGlobalValues;
109  comm.allReduce(localValues.size(), numGlobalValues, eckit::mpi::Operation::SUM);
110 
111  std::vector<util::DateTime> globalValues(numGlobalValues);
112  oops::mpi::allGathervUsingSerialize(comm, localValues.begin(), localValues.end(),
113  globalValues.begin());
114  EXPECT_EQUAL(globalValues, expectedGlobalValues);
115 }
116 // -----------------------------------------------------------------------------------------------
117 CASE("mpi/mpi/allGathervInt") {
118  testAllGatherv<int>();
119 }
120 // -----------------------------------------------------------------------------------------------
121 CASE("mpi/mpi/allGathervDateTime") {
122  testAllGatherv<util::DateTime>();
123 }
124 // -----------------------------------------------------------------------------------------------
125 CASE("mpi/mpi/allGathervInt") {
126  testAllGatherv<std::string>();
127 }
128 // -----------------------------------------------------------------------------------------------
129 CASE("mpi/mpi/SendReceive") {
130  const eckit::Configuration &conf = TestEnvironment::config();
131  const eckit::mpi::Comm &comm = oops::mpi::world();
132  const size_t rank = comm.rank();
133  int source = (rank + 3) % comm.size();
134  int destination = (rank + 1) % comm.size();
135  int tag_send = destination;
136  int tag_recv = rank;
137 
138  util::DateTime sendValue(conf.getString("send"+ std::to_string(rank)));
139  util::DateTime expectedValue(conf.getString("expected"+ std::to_string(rank)));
140  util::DateTime receivedValue;
141 
142  if (rank < 3) {
143  oops::mpi::send(comm, sendValue, destination, tag_send);
144  oops::mpi::receive(comm, receivedValue, source, tag_recv);
145  } else {
146  oops::mpi::receive(comm, receivedValue, source, tag_recv);
147  oops::mpi::send(comm, sendValue, destination, tag_send);
148  }
149  EXPECT_EQUAL(receivedValue, expectedValue);
150 }
151 // -----------------------------------------------------------------------------------------------
152 CASE("mpi/mpi/gatherSerializable") {
153  const eckit::Configuration &conf = TestEnvironment::config();
154  const eckit::mpi::Comm &comm = oops::mpi::world();
155 
156  TestParameters localParams;
157  const size_t rank = comm.rank();
158  localParams.deserialize(conf.getSubConfiguration("local" + std::to_string(rank)));
159  const std::vector<util::DateTime> &localValues = localParams.datetime;
160 
161  TestParameters globalParams;
162  globalParams.deserialize(conf.getSubConfiguration("global"));
163  const std::vector<util::DateTime> &expectedGlobalValues = globalParams.datetime;
164 
165  size_t numGlobalValues;
166  comm.allReduce(localValues.size(), numGlobalValues, eckit::mpi::Operation::SUM);
167 
168  std::vector<util::DateTime> globalValues(numGlobalValues);
169 
170  util::DateTime zeroDate("0001-01-01T00:00:00Z");
171  for (size_t ii = 0; ii < numGlobalValues; ++ii) {
172  globalValues[ii] = zeroDate;
173  }
174 
175  std::vector<util::DateTime> zeroValues = globalValues;
176 
177  size_t root_gather = conf.getInt("root for gathering", 0);
178 
179  oops::mpi::gather(comm, localValues, globalValues, root_gather);
180  if (rank == root_gather) {
181  EXPECT_EQUAL(globalValues, expectedGlobalValues);
182  } else {
183  EXPECT_EQUAL(globalValues, zeroValues);
184  }
185 }
186 // -----------------------------------------------------------------------------------------------
187 CASE("mpi/mpi/gatherDouble") {
188  const eckit::Configuration &conf = TestEnvironment::config();
189  const eckit::mpi::Comm &comm = oops::mpi::world();
190  const size_t rank = comm.rank();
191 
192  std::vector<double> localDouble;
193  conf.get("localDouble" + std::to_string(rank), localDouble);
194 
195  std::vector<double> globalDoubleExpected;
196  conf.get("globalDouble", globalDoubleExpected);
197 
198  size_t numGlobalDouble;
199  comm.allReduce(localDouble.size(), numGlobalDouble, eckit::mpi::Operation::SUM);
200 
201  std::vector<double> globalDouble(numGlobalDouble, 0.0);
202  std::vector<double> zerosDouble = globalDouble;
203 
204  size_t root_gather = conf.getInt("root for gathering", 0);
205 
206  oops::mpi::gather(comm, localDouble, globalDouble, root_gather);
207 
208  if (rank == root_gather) {
209  EXPECT_EQUAL(globalDouble, globalDoubleExpected);
210  } else {
211  EXPECT_EQUAL(globalDouble, zerosDouble);
212  }
213 }
214 // -----------------------------------------------------------------------------------------------
215 CASE("mpi/mpi/allGatherEigen") {
216  const eckit::mpi::Comm &comm = oops::mpi::world();
217  const size_t rank = comm.rank();
218 
219  Eigen::VectorXd localEigen = rank * Eigen::VectorXd::Ones(4);
220 
221  Eigen::MatrixXd globalEigen(4, 4);
222  globalEigen << Eigen::VectorXd::Zero(4),
223  Eigen::VectorXd::Zero(4),
224  Eigen::VectorXd::Zero(4),
225  Eigen::VectorXd::Zero(4);
226 
227  Eigen::MatrixXd expectedEigen(4, 4);
228  expectedEigen << 0*Eigen::VectorXd::Ones(4),
229  1*Eigen::VectorXd::Ones(4),
230  2*Eigen::VectorXd::Ones(4),
231  3*Eigen::VectorXd::Ones(4);
232 
233  oops::mpi::allGather(comm, localEigen, globalEigen);
234  EXPECT_EQUAL(expectedEigen, globalEigen);
235 }
236 // -----------------------------------------------------------------------------------------------
237 CASE("mpi/mpi/exclusiveScan") {
238  const eckit::mpi::Comm &comm = oops::mpi::world();
239  const size_t rank = comm.rank();
240 
241  size_t expectedResult = 0;
242  for (size_t lowerRank = 0; lowerRank < rank; ++lowerRank)
243  expectedResult += lowerRank;
244 
245  size_t result = rank;
246  oops::mpi::exclusiveScan(comm, result);
247  EXPECT_EQUAL(result, expectedResult);
248 }
249 // -----------------------------------------------------------------------------------------------
250 
251 class Mpi : public oops::Test {
252  private:
253  std::string testid() const override {return "test::mpi::mpi";}
254 
255  void register_tests() const override {}
256  void clear() const override {}
257 };
258 
259 } // namespace test
260 
261 #endif // TEST_MPI_MPI_H_
void register_tests() const override
Definition: test/mpi/mpi.h:255
std::string testid() const override
Definition: test/mpi/mpi.h:253
void clear() const override
Definition: test/mpi/mpi.h:256
static const eckit::Configuration & config()
oops::RequiredParameter< std::vector< util::DateTime > > datetime
Definition: test/mpi/mpi.h:41
oops::RequiredParameter< std::vector< int > > int_
Definition: test/mpi/mpi.h:42
Definition: FieldL95.h:22
void allGather(const eckit::mpi::Comm &comm, const Eigen::VectorXd &sendbuf, Eigen::MatrixXd &recvbuf)
const eckit::mpi::Comm & myself()
Default communicator with each MPI task by itself.
Definition: oops/mpi/mpi.cc:90
void allGathervUsingSerialize(const eckit::mpi::Comm &comm, CIter first, CIter last, Iter recvbuf)
A wrapper around the MPI all gather operation for serializable types.
Definition: oops/mpi/mpi.h:119
void gather(const eckit::mpi::Comm &comm, const std::vector< double > &send, std::vector< double > &recv, const size_t root)
Definition: oops/mpi/mpi.cc:96
void exclusiveScan(const eckit::mpi::Comm &comm, size_t &x)
Perform the exclusive scan operation.
void send(const eckit::mpi::Comm &comm, const SERIALIZABLE &sendobj, const int dest, const int tag)
Extend eckit Comm for Serializable oops objects.
Definition: oops/mpi/mpi.h:44
const eckit::mpi::Comm & world()
Default communicator with all MPI tasks (ie MPI_COMM_WORLD)
Definition: oops/mpi/mpi.cc:84
void allGatherv(const eckit::mpi::Comm &comm, std::vector< util::DateTime > &x)
Perform the MPI all gather operation on a vector of DateTime objects.
void receive(const eckit::mpi::Comm &comm, SERIALIZABLE &recvobj, const int source, const int tag)
Definition: oops/mpi/mpi.h:55
const std::vector< T > & getTestData(const TestParameters &params)
Definition: test/mpi/mpi.h:51
void testAllGatherv()
Definition: test/mpi/mpi.h:67
CASE("test_linearmodelparameterswrapper_valid_name")
Definition: TLML95.h:34