OOPS
oops/mpi/mpi.cc
Go to the documentation of this file.
1 /*
2  * (C) Copyright 2013 ECMWF.
3  *
4  * This software is licensed under the terms of the Apache Licence Version 2.0
5  * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
6  * In applying this licence, ECMWF does not waive the privileges and immunities
7  * granted to it by virtue of its status as an intergovernmental organisation
8  * nor does it submit to any jurisdiction.
9  */
10 
11 #include "oops/mpi/mpi.h"
12 
13 #include <numeric> // for accumulate()
14 #include <string>
15 #include <utility>
16 
17 #include "eckit/exception/Exceptions.h"
18 #include "oops/util/DateTime.h"
19 
20 namespace {
21 
22 // Helper functions used by the implementation of the specialization of allGatherv for vectors
23 // of strings
24 
/// \brief Pack a list of strings into one flat character array ready for an MPI transfer.
///
/// \param strings
///   Strings to pack.
///
/// \returns A pair of vectors. The first holds the characters of all input strings laid out
/// back to back (no separating null characters). The second holds the length of each input
/// string, in input order.
std::pair<std::vector<char>, std::vector<size_t>> encodeStrings(
    const std::vector<std::string> &strings) {
  // First pass: record individual lengths and the total number of characters.
  std::vector<size_t> sizes;
  sizes.reserve(strings.size());
  size_t numChars = 0;
  for (const std::string &str : strings) {
    const size_t len = str.size();
    sizes.push_back(len);
    numChars += len;
  }

  // Second pass: copy the characters into a buffer allocated once.
  std::vector<char> joined;
  joined.reserve(numChars);
  for (const std::string &str : strings) {
    joined.insert(joined.end(), str.begin(), str.end());
  }

  return std::make_pair(std::move(joined), std::move(sizes));
}
52 
/// \brief Cut a flat character array back into the individual strings it was built from.
///
/// \param charArray
///   Characters of several strings stored back to back (no separating null characters).
///
/// \param lengths
///   Length of each string stored in \p charArray, in storage order. The lengths are trusted:
///   no check is made that their sum fits within \p charArray.
///
/// \returns The strings extracted from \p charArray, one per entry of \p lengths.
std::vector<std::string> decodeStrings(const std::vector<char> &charArray,
                                       const std::vector<size_t> &lengths) {
  std::vector<std::string> decoded;
  decoded.reserve(lengths.size());

  // Offset of the first character of the string currently being extracted.
  size_t offset = 0;
  for (size_t i = 0; i < lengths.size(); ++i) {
    const std::vector<char>::const_iterator first = charArray.begin() + offset;
    decoded.emplace_back(first, first + lengths[i]);
    offset += lengths[i];
  }

  return decoded;
}
76 
77 } // namespace
78 
79 namespace oops {
80 namespace mpi {
81 
82 // ------------------------------------------------------------------------------------------------
83 
84 const eckit::mpi::Comm & world() {
85  return eckit::mpi::comm();
86 }
87 
88 // ------------------------------------------------------------------------------------------------
89 
90 const eckit::mpi::Comm & myself() {
91  return eckit::mpi::self();
92 }
93 
94 // ------------------------------------------------------------------------------------------------
95 
96 void gather(const eckit::mpi::Comm & comm, const std::vector<double> & send,
97  std::vector<double> & recv, const size_t root) {
98  size_t ntasks = comm.size();
99  if (ntasks > 1) {
100  int mysize = send.size();
101  std::vector<int> sizes(ntasks);
102  comm.allGather(mysize, sizes.begin(), sizes.end());
103  std::vector<int> displs(ntasks);
104  size_t rcvsz = sizes[0];
105  displs[0] = 0;
106  for (size_t jj = 1; jj < ntasks; ++jj) {
107  displs[jj] = displs[jj - 1] + sizes[jj - 1];
108  rcvsz += sizes[jj];
109  }
110  if (comm.rank() == root) recv.resize(rcvsz);
111 
112  comm.gatherv(send, recv, sizes, displs, root);
113  } else {
114  recv = send;
115  }
116 }
117 
118 // ------------------------------------------------------------------------------------------------
119 
120 void allGather(const eckit::mpi::Comm & comm,
121  const Eigen::VectorXd & sendbuf, Eigen::MatrixXd & recvbuf) {
122  const int ntasks = comm.size();
123  int buf_size = sendbuf.size();
124 
125  std::vector<double> vbuf(sendbuf.data(), sendbuf.data() + buf_size);
126  std::vector<double> vbuf_total(ntasks * buf_size);
127 
128  std::vector<int> recvcounts(ntasks);
129  for (int ii = 0; ii < ntasks; ++ii) recvcounts[ii] = buf_size;
130 
131  std::vector<int> displs(ntasks);
132  for (int ii = 0; ii < ntasks; ++ii) displs[ii] = ii * buf_size;
133 
134  comm.allGatherv(vbuf.begin(), vbuf.end(),
135  vbuf_total.begin(), recvcounts.data(), displs.data());
136 
137  for (int ii = 0; ii < ntasks; ++ii) {
138  std::vector<double> vloc(vbuf_total.begin() + ii * buf_size,
139  vbuf_total.begin() + (ii + 1) * buf_size);
140  Eigen::VectorXd my_vect = Eigen::Map<Eigen::VectorXd, Eigen::Unaligned>(vloc.data(),
141  vloc.size());
142  recvbuf.col(ii) = my_vect;
143  }
144 }
145 
146 // ------------------------------------------------------------------------------------------------
147 
148 void allGatherv(const eckit::mpi::Comm & comm, std::vector<util::DateTime> &x) {
149  size_t globalSize = x.size();
150  comm.allReduceInPlace(globalSize, eckit::mpi::sum());
151  std::vector<util::DateTime> globalX(globalSize);
152  oops::mpi::allGathervUsingSerialize(comm, x.begin(), x.end(), globalX.begin());
153  x = std::move(globalX);
154 }
155 
156 // ------------------------------------------------------------------------------------------------
157 
158 void allGatherv(const eckit::mpi::Comm & comm, std::vector<std::string> &x) {
159  std::pair<std::vector<char>, std::vector<size_t>> encodedX = encodeStrings(x);
160 
161  // Gather all character arrays
162  eckit::mpi::Buffer<char> charBuffer(comm.size());
163  comm.allGatherv(encodedX.first.begin(), encodedX.first.end(), charBuffer);
164 
165  // Gather all string lengths
166  eckit::mpi::Buffer<size_t> lengthBuffer(comm.size());
167  comm.allGatherv(encodedX.second.begin(), encodedX.second.end(), lengthBuffer);
168 
169  // Free memory
170  encodedX = {};
171 
172  x = decodeStrings(charBuffer.buffer, lengthBuffer.buffer);
173 }
174 
175 // ------------------------------------------------------------------------------------------------
176 
177 void exclusiveScan(const eckit::mpi::Comm &comm, size_t &x) {
178  // Could be done with MPI_Exscan, but there's no wrapper for it in eckit::mpi.
179 
180  std::vector<size_t> xs(comm.size());
181  comm.allGather(x, xs.begin(), xs.end());
182  x = std::accumulate(xs.begin(), xs.begin() + comm.rank(), 0);
183 }
184 
185 } // namespace mpi
186 } // namespace oops
std::pair< std::vector< char >, std::vector< size_t > > encodeStrings(const std::vector< std::string > &strings)
Join strings into a single character array before MPI transfer.
Definition: oops/mpi/mpi.cc:32
std::vector< std::string > decodeStrings(const std::vector< char > &charArray, const std::vector< size_t > &lengths)
Split a character array into multiple strings.
Definition: oops/mpi/mpi.cc:63
void allGather(const eckit::mpi::Comm &comm, const Eigen::VectorXd &sendbuf, Eigen::MatrixXd &recvbuf)
const eckit::mpi::Comm & myself()
Default communicator with each MPI task by itself.
Definition: oops/mpi/mpi.cc:90
void allGathervUsingSerialize(const eckit::mpi::Comm &comm, CIter first, CIter last, Iter recvbuf)
A wrapper around the MPI all gather operation for serializable types.
Definition: oops/mpi/mpi.h:119
void gather(const eckit::mpi::Comm &comm, const std::vector< double > &send, std::vector< double > &recv, const size_t root)
Definition: oops/mpi/mpi.cc:96
void exclusiveScan(const eckit::mpi::Comm &comm, size_t &x)
Perform the exclusive scan operation.
void send(const eckit::mpi::Comm &comm, const SERIALIZABLE &sendobj, const int dest, const int tag)
Extend eckit Comm for Serializable oops objects.
Definition: oops/mpi/mpi.h:44
const eckit::mpi::Comm & world()
Default communicator with all MPI tasks (ie MPI_COMM_WORLD)
Definition: oops/mpi/mpi.cc:84
void allGatherv(const eckit::mpi::Comm &comm, std::vector< util::DateTime > &x)
Perform the MPI all gather operation on a vector of DateTime objects.
The namespace for the main oops code.