Skip to content

Commit

Permalink
Get rid of comm.setup()
Browse files Browse the repository at this point in the history
  • Loading branch information
olsaarik committed Sep 12, 2023
1 parent 015e29c commit cc08b23
Show file tree
Hide file tree
Showing 20 changed files with 243 additions and 333 deletions.
105 changes: 25 additions & 80 deletions include/mscclpp/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ class Bootstrap {
public:
Bootstrap(){};
virtual ~Bootstrap() = default;
virtual int getRank() = 0;
virtual int getNranks() = 0;
virtual int rank() = 0;
virtual int size() = 0;
virtual void send(void* data, int size, int peer, int tag) = 0;
virtual void recv(void* data, int size, int peer, int tag) = 0;
[[nodiscard]] virtual std::future<void> recv(void* data, int size, int peer, int tag) = 0;
virtual void allGather(void* allData, int size) = 0;
virtual void barrier() = 0;

void send(const std::vector<char>& data, int peer, int tag);
void recv(std::vector<char>& data, int peer, int tag);
std::future<std::vector<char>> recv(int peer, int tag);
};

/// A native implementation of the bootstrap using TCP sockets.
Expand Down Expand Up @@ -70,10 +70,10 @@ class TcpBootstrap : public Bootstrap {
void initialize(const std::string& ifIpPortTrio, int64_t timeoutSec = 30);

/// Return the rank of the process.
int getRank() override;
int rank() override;

/// Return the total number of ranks.
int getNranks() override;
int size() override;

/// Send data to another process.
///
Expand All @@ -95,7 +95,8 @@ class TcpBootstrap : public Bootstrap {
/// @param size The size of the data to receive.
/// @param peer The rank of the process to receive the data from.
/// @param tag The tag to receive the data with.
void recv(void* data, int size, int peer, int tag) override;
/// @return A future that will be ready when the data has been received.
[[nodiscard]] std::future<void> recv(void* data, int size, int peer, int tag) override;

/// Gather data from all processes.
///
Expand Down Expand Up @@ -324,17 +325,17 @@ class RegisteredMemory {
/// Get the size of the memory block.
///
/// @return The size of the memory block.
size_t size();
size_t size() const;

/// Get the transport flags associated with the memory block.
///
/// @return The transport flags associated with the memory block.
TransportFlags transports();
TransportFlags transports() const;

/// Serialize the RegisteredMemory object to a vector of characters.
///
/// @return A vector of characters representing the serialized RegisteredMemory object.
std::vector<char> serialize();
std::vector<char> serialize() const;

/// Deserialize a RegisteredMemory object from a vector of characters.
///
Expand Down Expand Up @@ -365,12 +366,12 @@ class Endpoint {
/// Get the transport used.
///
/// @return The transport used.
Transport transport();
Transport transport() const;

/// Serialize the Endpoint object to a vector of characters.
///
/// @return A vector of characters representing the serialized Endpoint object.
std::vector<char> serialize();
std::vector<char> serialize() const;

/// Deserialize an Endpoint object from a vector of characters.
///
Expand Down Expand Up @@ -527,50 +528,14 @@ struct Setuppable {
virtual void endSetup(std::shared_ptr<Bootstrap> bootstrap);
};

/// A non-blocking future that can be used to check if a value is ready and retrieve it.
template <typename T>
class NonblockingFuture {
std::shared_future<T> future;

public:
/// Default constructor.
NonblockingFuture() = default;

/// Constructor that takes a shared future and moves it into the NonblockingFuture.
///
/// @param future The shared future to move.
NonblockingFuture(std::shared_future<T>&& future) : future(std::move(future)) {}

/// Copy constructor.
///
/// @param other The @ref NonblockingFuture to copy.
NonblockingFuture(const NonblockingFuture& other) = default;

/// Check if the value is ready to be retrieved.
///
/// @return True if the value is ready, false otherwise.
bool ready() const { return future.wait_for(std::chrono::seconds(0)) == std::future_status::ready; }

/// Get the value.
///
/// @return The value.
///
/// @throws Error if the value is not ready.
T get() const {
if (!ready()) throw Error("NonblockingFuture::get() called before ready", ErrorCode::InvalidUsage);
return future.get();
}
};

/// A class that sets up all registered memories and connections between processes.
///
/// A typical way to use this class:
/// 1. Call @ref connectOnSetup() to declare connections between the calling process with other processes.
/// 1. Call @ref connect() to declare connections between the calling process and other processes.
/// 2. Call @ref registerMemory() to register memory regions that will be used for communication.
/// 3. Call @ref sendMemoryOnSetup() or @ref recvMemoryOnSetup() to send/receive registered memory regions to/from
/// 3. Call @ref sendMemory() or @ref recvMemory() to send/receive registered memory regions to/from
/// other processes.
/// 4. Call @ref setup() to set up all registered memories and connections declared in the previous steps.
/// 5. Call @ref NonblockingFuture<RegisteredMemory>::get() to get the registered memory regions received from other
/// 5. Call @ref std::future<RegisteredMemory>::get() to get the registered memory regions received from other
/// processes.
/// 6. All done; use connections and registered memories to build channels.
///
Expand Down Expand Up @@ -603,40 +568,32 @@ class Communicator {
/// @return RegisteredMemory A handle to the buffer.
RegisteredMemory registerMemory(void* ptr, size_t size, TransportFlags transports);

/// Send information of a registered memory to the remote side on setup.
///
/// This function registers a send to a remote process that will happen by a following call of @ref setup(). The send
/// will carry information about a registered memory on the local process.
/// Send information of a registered memory to the remote side.
///
/// @param memory The registered memory buffer to send information about.
/// @param remoteRank The rank of the remote process.
/// @param tag The tag to use for identifying the send.
void sendMemoryOnSetup(RegisteredMemory memory, int remoteRank, int tag);
void sendMemory(RegisteredMemory memory, int remoteRank, int tag);

/// Receive memory on setup.
///
/// This function registers a receive from a remote process that will happen by a following call of @ref setup(). The
/// receive will carry information about a registered memory on the remote process.
/// Receive memory.
///
/// @param remoteRank The rank of the remote process.
/// @param tag The tag to use for identifying the receive.
/// @return NonblockingFuture<RegisteredMemory> A non-blocking future of registered memory.
NonblockingFuture<RegisteredMemory> recvMemoryOnSetup(int remoteRank, int tag);
/// @return std::future<RegisteredMemory> A future of registered memory.
std::future<RegisteredMemory> recvMemory(int remoteRank, int tag);

/// Connect to a remote rank on setup.
/// Connect to a remote rank.
///
/// This function only prepares metadata for connection. The actual connection is made by a following call of
/// @ref setup(). Note that this function is two-way and a connection from rank `i` to remote rank `j` needs
/// Note that this function is two-way and a connection from rank `i` to remote rank `j` needs
/// to have a counterpart from rank `j` to rank `i`. Note that with IB, buffers are registered at a page level and if
/// a buffer is spread through multiple pages and does not fully utilize all of them, IB's QP has to register for all
/// involved pages. This potentially has security risks if the connection's accesses are given to a malicious process.
///
/// @param remoteRank The rank of the remote process.
/// @param tag The tag of the connection for identifying it.
/// @param config The configuration for the local endpoint.
/// @return NonblockingFuture<NonblockingFuture<std::shared_ptr<Connection>>> A non-blocking future of shared pointer
/// to the connection.
NonblockingFuture<std::shared_ptr<Connection>> connectOnSetup(int remoteRank, int tag, EndpointConfig localConfig);
/// @return std::future<std::shared_ptr<Connection>> A future of shared pointer to the connection.
std::future<std::shared_ptr<Connection>> connect(int remoteRank, int tag, EndpointConfig localConfig);

/// Get the remote rank a connection is connected to.
///
Expand All @@ -650,18 +607,6 @@ class Communicator {
/// @return The tag the connection was made with.
int tagOf(const Connection& connection);

/// Add a custom Setuppable object to a list of objects to be setup later, when @ref setup() is called.
///
/// @param setuppable A shared pointer to the Setuppable object.
void onSetup(std::shared_ptr<Setuppable> setuppable);

/// Setup all objects that have registered for setup.
///
/// This includes previous calls of @ref sendMemoryOnSetup(), @ref recvMemoryOnSetup(), @ref connectOnSetup(), and
/// @ref onSetup(). It is allowed to call this function multiple times, where the n-th call will only setup objects
/// that have been registered after the (n-1)-th call.
void setup();

private:
// The internal implementation.
struct Impl;
Expand Down
2 changes: 1 addition & 1 deletion include/mscclpp/semaphore.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ template <template <typename> typename InboundDeleter, template <typename> typen
class BaseSemaphore {
protected:
/// The registered memory for the remote peer's inbound semaphore ID.
NonblockingFuture<RegisteredMemory> remoteInboundSemaphoreIdsRegMem_;
std::shared_future<RegisteredMemory> remoteInboundSemaphoreIdsRegMem_;

/// The inbound semaphore ID that is incremented by the remote peer and waited on by the local peer.
///
Expand Down
33 changes: 14 additions & 19 deletions python/mscclpp/core_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,15 @@ extern void register_utils(nb::module_& m);
extern void register_numa(nb::module_& m);

template <typename T>
void def_nonblocking_future(nb::handle& m, const std::string& typestr) {
std::string pyclass_name = std::string("NonblockingFuture") + typestr;
nb::class_<NonblockingFuture<T>>(m, pyclass_name.c_str())
.def("ready", &NonblockingFuture<T>::ready)
.def("get", &NonblockingFuture<T>::get);
void def_future(nb::handle& m, const std::string& typestr) {
std::string pyclass_name = std::string("std_future_") + typestr;
nb::class_<std::future<T>>(m, pyclass_name.c_str()).def("get", &std::future<T>::get);
}

void register_core(nb::module_& m) {
nb::class_<Bootstrap>(m, "Bootstrap")
.def("get_rank", &Bootstrap::getRank)
.def("get_n_ranks", &Bootstrap::getNranks)
.def_prop_ro("rank", &Bootstrap::rank)
.def_prop_ro("size", &Bootstrap::size)
.def(
"send",
[](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) {
Expand All @@ -43,15 +41,15 @@ void register_core(nb::module_& m) {
"recv",
[](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) {
void* data = reinterpret_cast<void*>(ptr);
self->recv(data, size, peer, tag);
return self->recv(data, size, peer, tag);
},
nb::arg("data"), nb::arg("size"), nb::arg("peer"), nb::arg("tag"))
.def("all_gather", &Bootstrap::allGather, nb::arg("allData"), nb::arg("size"))
.def("barrier", &Bootstrap::barrier)
.def("send", (void (Bootstrap::*)(const std::vector<char>&, int, int)) & Bootstrap::send, nb::arg("data"),
nb::arg("peer"), nb::arg("tag"))
.def("recv", (void (Bootstrap::*)(std::vector<char>&, int, int)) & Bootstrap::recv, nb::arg("data"),
nb::arg("peer"), nb::arg("tag"));
.def("recv", (std::future<std::vector<char>>(Bootstrap::*)(int, int)) & Bootstrap::recv, nb::arg("peer"),
nb::arg("tag"));

nb::class_<UniqueId>(m, "UniqueId");

Expand Down Expand Up @@ -147,8 +145,8 @@ void register_core(nb::module_& m) {
.def("create_endpoint", &Context::createEndpoint, nb::arg("config"))
.def("connect", &Context::connect, nb::arg("local_endpoint"), nb::arg("remote_endpoint"));

def_nonblocking_future<RegisteredMemory>(m, "RegisteredMemory");
def_nonblocking_future<std::shared_ptr<Connection>>(m, "shared_ptr_Connection");
def_future<RegisteredMemory>(m, "RegisteredMemory");
def_future<std::shared_ptr<Connection>>(m, "shared_ptr_Connection");

nb::class_<Communicator>(m, "Communicator")
.def(nb::init<std::shared_ptr<Bootstrap>, std::shared_ptr<Context>>(), nb::arg("bootstrap"),
Expand All @@ -161,14 +159,11 @@ void register_core(nb::module_& m) {
return self->registerMemory((void*)ptr, size, transports);
},
nb::arg("ptr"), nb::arg("size"), nb::arg("transports"))
.def("send_memory_on_setup", &Communicator::sendMemoryOnSetup, nb::arg("memory"), nb::arg("remoteRank"),
nb::arg("tag"))
.def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
.def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
nb::arg("localConfig"))
.def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag"))
.def("recv_memory", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag"))
.def("connect", &Communicator::connect, nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
.def("remote_rank_of", &Communicator::remoteRankOf)
.def("tag_of", &Communicator::tagOf)
.def("setup", &Communicator::setup);
.def("tag_of", &Communicator::tagOf);
}

NB_MODULE(_mscclpp, m) {
Expand Down
Loading

0 comments on commit cc08b23

Please sign in to comment.