00001
00002
00003
00004
00005 #include "PlayaMPIComm.hpp"
00006 #include "PlayaMPIDataType.hpp"
00007 #include "PlayaMPIOp.hpp"
00008 #include "PlayaErrorPolling.hpp"
00009
00010
00011
00012
00013 namespace Playa
00014 {
00015 using namespace Teuchos;
00016
00017
00018 MPIComm::MPIComm()
00019 :
00020 #ifdef HAVE_MPI
00021 comm_(MPI_COMM_WORLD),
00022 #endif
00023 nProc_(0), myRank_(0)
00024 {
00025 init();
00026 }
00027
#ifdef HAVE_MPI
/**
 * Wrap an existing MPI communicator handle.
 * @param comm communicator to wrap (stored by value, not duplicated).
 * Rank and size are queried from it by init().
 */
MPIComm::MPIComm(MPI_Comm comm)
  : comm_(comm),
    nProc_(0),
    myRank_(0)
{
  init();
}
#endif
00035
00036 int MPIComm::mpiIsRunning() const
00037 {
00038 int mpiStarted = 0;
00039 #ifdef HAVE_MPI
00040 MPI_Initialized(&mpiStarted);
00041 #endif
00042 return mpiStarted;
00043 }
00044
00045 void MPIComm::init()
00046 {
00047 #ifdef HAVE_MPI
00048
00049 if (mpiIsRunning())
00050 {
00051 errCheck(MPI_Comm_rank(comm_, &myRank_), "Comm_rank");
00052 errCheck(MPI_Comm_size(comm_, &nProc_), "Comm_size");
00053 }
00054 else
00055 {
00056 nProc_ = 1;
00057 myRank_ = 0;
00058 }
00059
00060 #else
00061 nProc_ = 1;
00062 myRank_ = 0;
00063 #endif
00064 }
00065
#ifdef USE_MPI_GROUPS

/**
 * Build a communicator for a subgroup of a parent communicator.
 *
 * MPI_Comm_create is collective over the parent communicator, so every
 * process that takes that branch must call this constructor with the same
 * group.
 *
 * Resulting state (MPI builds):
 *  - empty group:                         myRank_ = -1, nProc_ = 0
 *  - this process is in the new group:    rank/size queried from the new comm
 *  - otherwise:                           myRank_ = -1, nProc_ = -1
 *
 * NOTE(review): in the empty-group case comm_ is left as MPI_COMM_WORLD;
 * confirm callers never run collectives on such an object.
 */
MPIComm::MPIComm(const MPIComm& parent, const MPIGroup& group)
  :
#ifdef HAVE_MPI
  comm_(MPI_COMM_WORLD),
#endif
  nProc_(0), myRank_(0)
{
#ifdef HAVE_MPI
  if (group.getNProc()==0)
  {
    myRank_ = -1;
    nProc_ = 0;
  }
  else if (parent.containsMe())
  {
    MPI_Comm parentComm = parent.comm_;
    MPI_Group newGroup = group.group_;

    errCheck(MPI_Comm_create(parentComm, newGroup, &comm_),
      "Comm_create");

    if (group.containsProc(parent.getRank()))
    {
      /* The rank member of this class is myRank_ (the previous code wrote
       * to a non-existent rank_). */
      errCheck(MPI_Comm_rank(comm_, &myRank_), "Comm_rank");
      errCheck(MPI_Comm_size(comm_, &nProc_), "Comm_size");
    }
    else
    {
      myRank_ = -1;
      nProc_ = -1;
    }
  }
  else
  {
    myRank_ = -1;
    nProc_ = -1;
  }
#endif
}

#endif
00111
00112 MPIComm& MPIComm::world()
00113 {
00114 static MPIComm w = MPIComm();
00115 return w;
00116 }
00117
00118
00119 MPIComm& MPIComm::self()
00120 {
00121 #ifdef HAVE_MPI
00122 static MPIComm w = MPIComm(MPI_COMM_SELF);
00123 #else
00124 static MPIComm w = MPIComm();
00125 #endif
00126 return w;
00127 }
00128
00129
00130 void MPIComm::synchronize() const
00131 {
00132 #ifdef HAVE_MPI
00133
00134 {
00135 if (mpiIsRunning())
00136 {
00137
00138
00139 TEUCHOS_POLL_FOR_FAILURES(*this);
00140
00141
00142 errCheck(::MPI_Barrier(comm_), "Barrier");
00143 }
00144 }
00145
00146 #endif
00147 }
00148
00149 void MPIComm::allToAll(void* sendBuf, int sendCount,
00150 const MPIDataType& sendType,
00151 void* recvBuf, int recvCount, const MPIDataType& recvType) const
00152 {
00153 #ifdef HAVE_MPI
00154
00155 {
00156 MPI_Datatype mpiSendType = sendType.handle();
00157 MPI_Datatype mpiRecvType = recvType.handle();
00158
00159
00160 if (mpiIsRunning())
00161 {
00162
00163
00164 TEUCHOS_POLL_FOR_FAILURES(*this);
00165
00166
00167 errCheck(::MPI_Alltoall(sendBuf, sendCount, mpiSendType,
00168 recvBuf, recvCount, mpiRecvType,
00169 comm_), "Alltoall");
00170 }
00171 }
00172
00173 #else
00174 (void)sendBuf;
00175 (void)sendCount;
00176 (void)sendType;
00177 (void)recvBuf;
00178 (void)recvCount;
00179 (void)recvType;
00180 #endif
00181 }
00182
00183 void MPIComm::allToAllv(void* sendBuf, int* sendCount,
00184 int* sendDisplacements, const MPIDataType& sendType,
00185 void* recvBuf, int* recvCount,
00186 int* recvDisplacements, const MPIDataType& recvType) const
00187 {
00188 #ifdef HAVE_MPI
00189
00190 {
00191 MPI_Datatype mpiSendType = sendType.handle();
00192 MPI_Datatype mpiRecvType = recvType.handle();
00193
00194 if (mpiIsRunning())
00195 {
00196
00197
00198 TEUCHOS_POLL_FOR_FAILURES(*this);
00199
00200
00201 errCheck(::MPI_Alltoallv(sendBuf, sendCount, sendDisplacements, mpiSendType,
00202 recvBuf, recvCount, recvDisplacements, mpiRecvType,
00203 comm_), "Alltoallv");
00204 }
00205 }
00206
00207 #else
00208 (void)sendBuf;
00209 (void)sendCount;
00210 (void)sendDisplacements;
00211 (void)sendType;
00212 (void)recvBuf;
00213 (void)recvCount;
00214 (void)recvDisplacements;
00215 (void)recvType;
00216 #endif
00217 }
00218
00219 void MPIComm::gather(void* sendBuf, int sendCount, const MPIDataType& sendType,
00220 void* recvBuf, int recvCount, const MPIDataType& recvType,
00221 int root) const
00222 {
00223 #ifdef HAVE_MPI
00224
00225 {
00226 MPI_Datatype mpiSendType = sendType.handle();
00227 MPI_Datatype mpiRecvType = recvType.handle();
00228
00229
00230 if (mpiIsRunning())
00231 {
00232
00233
00234 TEUCHOS_POLL_FOR_FAILURES(*this);
00235
00236
00237 errCheck(::MPI_Gather(sendBuf, sendCount, mpiSendType,
00238 recvBuf, recvCount, mpiRecvType,
00239 root, comm_), "Gather");
00240 }
00241 }
00242
00243 #endif
00244 }
00245
00246 void MPIComm::gatherv(void* sendBuf, int sendCount, const MPIDataType& sendType,
00247 void* recvBuf, int* recvCount, int* displacements, const MPIDataType& recvType,
00248 int root) const
00249 {
00250 #ifdef HAVE_MPI
00251
00252 {
00253 MPI_Datatype mpiSendType = sendType.handle();
00254 MPI_Datatype mpiRecvType = recvType.handle();
00255
00256 if (mpiIsRunning())
00257 {
00258
00259
00260 TEUCHOS_POLL_FOR_FAILURES(*this);
00261
00262
00263 errCheck(::MPI_Gatherv(sendBuf, sendCount, mpiSendType,
00264 recvBuf, recvCount, displacements, mpiRecvType,
00265 root, comm_), "Gatherv");
00266 }
00267 }
00268
00269 #endif
00270 }
00271
00272 void MPIComm::allGather(void* sendBuf, int sendCount, const MPIDataType& sendType,
00273 void* recvBuf, int recvCount,
00274 const MPIDataType& recvType) const
00275 {
00276 #ifdef HAVE_MPI
00277
00278 {
00279 MPI_Datatype mpiSendType = sendType.handle();
00280 MPI_Datatype mpiRecvType = recvType.handle();
00281
00282 if (mpiIsRunning())
00283 {
00284
00285
00286 TEUCHOS_POLL_FOR_FAILURES(*this);
00287
00288
00289 errCheck(::MPI_Allgather(sendBuf, sendCount, mpiSendType,
00290 recvBuf, recvCount,
00291 mpiRecvType, comm_),
00292 "AllGather");
00293 }
00294 }
00295
00296 #endif
00297 }
00298
00299
00300 void MPIComm::allGatherv(void* sendBuf, int sendCount, const MPIDataType& sendType,
00301 void* recvBuf, int* recvCount,
00302 int* recvDisplacements,
00303 const MPIDataType& recvType) const
00304 {
00305 #ifdef HAVE_MPI
00306
00307 {
00308 MPI_Datatype mpiSendType = sendType.handle();
00309 MPI_Datatype mpiRecvType = recvType.handle();
00310
00311 if (mpiIsRunning())
00312 {
00313
00314
00315 TEUCHOS_POLL_FOR_FAILURES(*this);
00316
00317
00318 errCheck(::MPI_Allgatherv(sendBuf, sendCount, mpiSendType,
00319 recvBuf, recvCount, recvDisplacements,
00320 mpiRecvType,
00321 comm_),
00322 "AllGatherv");
00323 }
00324 }
00325
00326 #endif
00327 }
00328
00329
00330 void MPIComm::bcast(void* msg, int length,
00331 const MPIDataType& type, int src) const
00332 {
00333 #ifdef HAVE_MPI
00334
00335 {
00336 if (mpiIsRunning())
00337 {
00338
00339
00340 TEUCHOS_POLL_FOR_FAILURES(*this);
00341
00342
00343 MPI_Datatype mpiType = type.handle();
00344 errCheck(::MPI_Bcast(msg, length, mpiType, src,
00345 comm_), "Bcast");
00346 }
00347 }
00348
00349 #endif
00350 }
00351
00352 void MPIComm::allReduce(void* input, void* result, int inputCount,
00353 const MPIDataType& type,
00354 const MPIOp& op) const
00355 {
00356 #ifdef HAVE_MPI
00357
00358 {
00359 MPI_Op mpiOp = op.handle();
00360 MPI_Datatype mpiType = type.handle();
00361
00362 if (mpiIsRunning())
00363 {
00364 errCheck(::MPI_Allreduce(input, result, inputCount, mpiType,
00365 mpiOp, comm_),
00366 "Allreduce");
00367 }
00368 }
00369
00370 #endif
00371 }
00372
00373
/**
 * Throw if an MPI call reported failure.
 *
 * @param errCode    return code of the MPI routine (0 means success)
 * @param methodName routine name without the "MPI_" prefix, used in the message
 * @throws std::runtime_error when errCode is nonzero
 */
void MPIComm::errCheck(int errCode, const std::string& methodName)
{
  /* Keep the condition text exactly as written: TEUCHOS_TEST_FOR_EXCEPTION
   * may embed the stringified condition in the exception message. */
  TEUCHOS_TEST_FOR_EXCEPTION(errCode != 0, std::runtime_error,
    "MPI function MPI_" << methodName
    << " returned error code=" << errCode);
}
00380
00381
00382
00383 }