Skip to content

Commit

Permalink
Remove possibility to specify tag in mpi connection. Now the tag is c…
Browse files Browse the repository at this point in the history
…hosen by the runtime allowing multiple connect between peers.

Changes to macros enabling mutexes in managers and protocols in order to make MTCL usable in a multithreaded applications.

Fixed and added some tests for the new features.
  • Loading branch information
Nicolò Tonci committed Jan 19, 2024
1 parent ddebd3c commit eaf034e
Show file tree
Hide file tree
Showing 21 changed files with 257 additions and 72 deletions.
4 changes: 2 additions & 2 deletions examples/hello_world.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ void Server() {
// Some of the following calls might fail, but at least one will succeed
Manager::listen("SHM:/MTCA-server");
Manager::listen("TCP:0.0.0.0:42000");
Manager::listen("MPI:0:10");
Manager::listen("MPI:0"); // can be omitted
Manager::listen("MPIP2P:test");
Manager::listen("MQTT:label");
Manager::listen("UCX:0.0.0.0:21000");
Expand Down Expand Up @@ -132,7 +132,7 @@ void Client() {
auto handle = []() {
auto h = Manager::connect("MPIP2P:test");
if (!h.isValid()) {
auto h = Manager::connect("MPI:0:10");
auto h = Manager::connect("MPI:0");
if (!h.isValid()) {
auto h = Manager::connect("MQTT:label");
if(!h.isValid()) {
Expand Down
2 changes: 1 addition & 1 deletion examples/p2p-perf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include <iomanip>
#include <vector>
#include "mtcl.hpp"

using namespace MTCL;
const int NROUND = 100;
const int N = 24;
const size_t minsize = 16; // bytes
Expand Down
4 changes: 2 additions & 2 deletions examples/pingpong.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ int main(int argc, char** argv){
to make the Manager happy in this "protocol-agnostic" example.
*/
#ifdef ENABLE_MPI
listen_str = {"MPI:"};
connect_str = {"MPI:0:5"};
listen_str = {"MPI:0"};
connect_str = {"MPI:0"};
#endif


Expand Down
24 changes: 12 additions & 12 deletions include/collectives/collectiveContext.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
#include "collectiveImpl.hpp"
#include "../handle.hpp"

#ifdef ENABLE_MPI
#ifdef MTCL_ENABLE_MPI
#include "mpiImpl.hpp"
#endif

#ifdef ENABLE_UCX
#ifdef MTCL_ENABLE_UCX
#include "uccImpl.hpp"
#endif

Expand Down Expand Up @@ -57,15 +57,15 @@ class CollectiveContext : public CommunicationHandle {
coll = new BroadcastGeneric(participants, size, root, rank, uniqtag);
break;
case MPI:
#ifdef ENABLE_MPI
#ifdef MTCL_ENABLE_MPI
void *max_tag;
int flag;
MPI_Comm_get_attr( MPI_COMM_WORLD, MPI_TAG_UB, &max_tag, &flag);
coll = new BroadcastMPI(participants, size, root, rank, uniqtag % (*(int*)max_tag));
#endif
break;
case UCC:
#ifdef ENABLE_UCX
#ifdef MTCL_ENABLE_UCX
coll = new BroadcastUCC(participants, size, root, rank, uniqtag);
#endif
break;
Expand All @@ -83,15 +83,15 @@ class CollectiveContext : public CommunicationHandle {
coll = new ScatterGeneric(participants, size, root, rank, uniqtag);
break;
case MPI:
#ifdef ENABLE_MPI
#ifdef MTCL_ENABLE_MPI
void *max_tag;
int flag;
MPI_Comm_get_attr( MPI_COMM_WORLD, MPI_TAG_UB, &max_tag, &flag);
coll = new ScatterMPI(participants, size, root, rank, uniqtag % (*(int*)max_tag));
#endif
break;
case UCC:
#ifdef ENABLE_UCX
#ifdef MTCL_ENABLE_UCX
coll = new ScatterUCC(participants, size, root, rank, uniqtag);
#endif
break;
Expand All @@ -111,15 +111,15 @@ class CollectiveContext : public CommunicationHandle {
coll = new GatherGeneric(participants, size, root, rank, uniqtag);
break;
case MPI:
#ifdef ENABLE_MPI
#ifdef MTCL_ENABLE_MPI
void *max_tag;
int flag;
MPI_Comm_get_attr( MPI_COMM_WORLD, MPI_TAG_UB, &max_tag, &flag);
coll = new GatherMPI(participants, size, root, rank, uniqtag % (*(int*)max_tag));
#endif
break;
case UCC:
#ifdef ENABLE_UCX
#ifdef MTCL_ENABLE_UCX
coll = new GatherUCC(participants, size, root, rank, uniqtag);
#endif
break;
Expand All @@ -137,15 +137,15 @@ class CollectiveContext : public CommunicationHandle {
coll = new AllGatherGeneric(participants, size, root, rank, uniqtag);
break;
case MPI:
#ifdef ENABLE_MPI
#ifdef MTCL_ENABLE_MPI
void *max_tag;
int flag;
MPI_Comm_get_attr( MPI_COMM_WORLD, MPI_TAG_UB, &max_tag, &flag);
coll = new AllGatherMPI(participants, size, root, rank, uniqtag % (*(int*)max_tag));
#endif
break;
case UCC:
#ifdef ENABLE_UCX
#ifdef MTCL_ENABLE_UCX
coll = new AllGatherUCC(participants, size, root, rank, uniqtag);
#endif
break;
Expand All @@ -163,15 +163,15 @@ class CollectiveContext : public CommunicationHandle {
coll = new AlltoallGeneric(participants, size, root, rank, uniqtag);
break;
case MPI:
#ifdef ENABLE_MPI
#ifdef MTCL_ENABLE_MPI
void *max_tag;
int flag;
MPI_Comm_get_attr( MPI_COMM_WORLD, MPI_TAG_UB, &max_tag, &flag);
coll = new AlltoallMPI(participants, size, root, rank, uniqtag % (*(int*)max_tag));
#endif
break;
case UCC:
#ifdef ENABLE_UCX
#ifdef MTCL_ENABLE_UCX
coll = new AlltoallUCC(participants, size, root, rank, uniqtag);
#endif
break;
Expand Down
40 changes: 22 additions & 18 deletions include/manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,23 @@
#endif
#endif

#ifdef ENABLE_MPI
#ifdef MTCL_ENABLE_MPI
#include "protocols/mpi.hpp"
#endif

#ifdef ENABLE_MPIP2P
#ifdef MTCL_ENABLE_MPIP2P
#include "protocols/mpip2p.hpp"
#endif

#ifdef ENABLE_MQTT
#ifdef MTCL_ENABLE_MQTT
#include "protocols/mqtt.hpp"
#endif

#ifdef ENABLE_UCX
#ifdef MTCL_ENABLE_UCX
#include "protocols/ucx.hpp"
#endif

#ifdef ENABLE_SHM
#ifdef MTCL_ENABLE_SHM
#include "protocols/shm.hpp"
#endif
namespace MTCL {
Expand Down Expand Up @@ -75,7 +75,9 @@ class Manager {
inline static std::map<std::string, std::tuple<std::string, std::vector<std::string>, std::vector<std::string>>> components;
#endif

REMOVE_CODE_IF(inline static std::thread t1);
#ifndef SINGLE_IO_THREAD
inline static std::thread t1;
#endif
inline static bool end;
inline static bool initialized = false;

Expand Down Expand Up @@ -340,23 +342,23 @@ class Manager {
// default transports protocol
registerType<ConnTcp>("TCP");

#ifdef ENABLE_SHM
#ifdef MTCL_ENABLE_SHM
registerType<ConnSHM>("SHM");
#endif

#ifdef ENABLE_MPI
#ifdef MTCL_ENABLE_MPI
registerType<ConnMPI>("MPI");
#endif

#ifdef ENABLE_MPIP2P
#ifdef MTCL_ENABLE_MQTT
registerType<ConnMPIP2P>("MPIP2P");
#endif

#ifdef ENABLE_MQTT
#ifdef MTCL_ENABLE_MQTT
registerType<ConnMQTT>("MQTT");
#endif

#ifdef ENABLE_UCX
#ifdef MTCL_ENABLE_UCX
registerType<ConnUCX>("UCX");
#endif

Expand Down Expand Up @@ -389,8 +391,9 @@ class Manager {
}
#endif

REMOVE_CODE_IF(t1 = std::thread([&](){Manager::getReadyBackend();}));

#ifndef SINGLE_IO_THREAD
t1 = std::thread([&](){Manager::getReadyBackend();});
#endif
initialized = true;
return 0;
}
Expand All @@ -406,8 +409,9 @@ class Manager {
*/
static void finalize(bool blockflag=false) {
end = true;
REMOVE_CODE_IF(t1.join());

#ifndef SINGLE_IO_THREAD
t1.join();
#endif
//while(!handleReady.empty()) handleReady.pop();
#ifndef MTCL_DISABLE_COLLECTIVES
for(auto& [ctx, _] : contexts) {
Expand Down Expand Up @@ -438,8 +442,8 @@ class Manager {
// if us is not multiple of the IO_THREAD_POLL_TIMEOUT we wait a bit less....
// if the poll timeout is 0, we just iterate us times
size_t niter = us.count(); // in case IO_THREAD_POLL_TIMEOUT is set to 0
if constexpr (IO_THREAD_POLL_TIMEOUT)
niter = us/std::chrono::milliseconds(IO_THREAD_POLL_TIMEOUT);
if constexpr (!!IO_THREAD_POLL_TIMEOUT)
niter = us/std::chrono::microseconds(IO_THREAD_POLL_TIMEOUT);
if (niter==0) niter++;
size_t i=0;
do {
Expand All @@ -465,7 +469,7 @@ class Manager {
}
if (i >= niter) break;
++i;
if constexpr (IO_THREAD_POLL_TIMEOUT)
if constexpr (!!IO_THREAD_POLL_TIMEOUT)
std::this_thread::sleep_for(std::chrono::microseconds(IO_THREAD_POLL_TIMEOUT));
} while(true);
return HandleUser(nullptr, true, true);
Expand Down
14 changes: 13 additions & 1 deletion include/mtcl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,36 @@

namespace MTCL {


#if defined(ENABLE_MPI)
const bool MPI_ENABLED = true;
#define MTCL_ENABLE_MPI
#else
const bool MPI_ENABLED = false;
#endif
#if defined(ENABLE_MPIP2P)
#if defined(ENABLE_MPIP2P) && !defined(NO_MTCL_MULTITHREADED)
const bool MPIP2P_ENABLED = true;
#define MTCL_ENABLE_MPIP2P
#else
const bool MPIP2P_ENABLED = false;
#endif
#if defined(ENABLE_UCX)
const bool UCX_ENABLED = true;
const bool UCC_ENABLED = true;
#define MTCL_ENABLE_UCX
#else
const bool UCX_ENABLED = false;
const bool UCC_ENABLED = false;
#endif

#ifdef ENABLE_SHM
#define MTCL_ENABLE_SHM
#endif

#ifdef ENABLE_MQTT
#define MTCL_ENABLE_MQTT
#endif

} // namespace


Expand Down
21 changes: 17 additions & 4 deletions include/protocols/mpi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <shared_mutex>
#include <thread>
#include <errno.h>
#include <atomic>

#include <mpi.h>

Expand Down Expand Up @@ -157,6 +158,7 @@ class ConnMPI : public ConnType {
// <rank, tag> => <HandleMPI, busy>
std::map<std::pair<int, int>, std::pair<HandleMPI*, bool>> connections;
std::shared_mutex shm;
std::atomic<unsigned int> tag_counter_even = 100, tag_counter_odd = 101;

public:

Expand Down Expand Up @@ -201,27 +203,38 @@ class ConnMPI : public ConnType {
}

int tag;
try {
/* try {
tag = stoi(dest.substr(dest.find(":") + 1, dest.length()));
}
catch(std::invalid_argument&) {
MTCL_MPI_PRINT(100, "ConnMPI::connect rank must be an integer greater than 0\n");
errno = EINVAL;
return nullptr;
}
}*/

if(rank < 0) {
MTCL_MPI_PRINT(100, "ConnMPI::connect the connection rank must be greater or equal than 0\n");
errno = EINVAL;
return nullptr;
}

if (tag <= (int)MPI_CONNECTION_TAG){
/*if (tag <= (int)MPI_CONNECTION_TAG){
MTCL_MPI_PRINT(100, "ConnMPI::connect the connection tag must be greater than 0\n");
errno = EINVAL;
return nullptr;
}

if (connections.count({rank, tag})){
MTCL_MPI_PRINT(100, "ConnMPI::connect: connection already done use the previous handler!\n");
errno = EINVAL;
return nullptr;
}*/

if (this->rank < rank)
tag = tag_counter_even.fetch_add(2);
else
tag = tag_counter_odd.fetch_add(2);

int header[1];
header[0] = tag;
if (MPI_Send(header, 1, MPI_INT, rank, MPI_CONNECTION_TAG, MPI_COMM_WORLD) != MPI_SUCCESS) {
Expand Down
8 changes: 3 additions & 5 deletions include/protocols/mpip2p.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,11 +227,9 @@ class ConnMPIP2P : public ConnType {
for(int i=0; i<remote_size; i++) {

HandleMPIP2P* handle = new HandleMPIP2P(this, i, client, false);
{
connections.insert({handle, false});
ADD_CODE_IF(addinQ(true, handle));
}
REMOVE_CODE_IF(addinQ(true, handle));
connections.insert({handle, false});
addinQ(true, handle);

}
}
MTCL_MPIP2P_PRINT(100, "Accept thread finalized.\n");
Expand Down
2 changes: 1 addition & 1 deletion include/protocols/shm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class ConnSHM : public ConnType {
shmBuffer connbuff;
std::map<HandleSHM*, bool> connections; // Active connections for this Connector

#if !defined(SINGLE_IO_THREAD)
#if defined(NO_MTCL_MULTITHREADED)
std::shared_mutex shm;
#endif

Expand Down
2 changes: 1 addition & 1 deletion include/protocols/tcp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class ConnTcp : public ConnType {

fd_set set, tmpset;
int listen_sck;
#if defined(SINGLE_IO_THREAD)
#if defined(NO_MTCL_MULTITHREADED)
int fdmax;
#else
std::atomic<int> fdmax;
Expand Down
2 changes: 1 addition & 1 deletion include/protocols/ucx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ class ConnUCX : public ConnType {
fd_set set, tmpset;
int listen_sck;

#if defined(SINGLE_IO_THREAD)
#if defined(NO_MTCL_MULTITHREADED)
int fdmax;
#else
std::atomic<int> fdmax;
Expand Down
Loading

0 comments on commit eaf034e

Please sign in to comment.