Skip to content

Commit

Permalink
Clean up endpoint configuration
Browse files Browse the repository at this point in the history
Add EndpointConfig struct to remove duplication of constants and get
configuration options out of the function calls.
  • Loading branch information
olsaarik committed Aug 31, 2023
1 parent 3d17ac8 commit 7350de0
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 28 deletions.
30 changes: 25 additions & 5 deletions include/mscclpp/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,28 @@ class Connection {
static std::shared_ptr<Endpoint::Impl> getImpl(Endpoint& memory);
};

/// Used to configure an endpoint.
struct EndpointConfig {
const int DefaultMaxCqSize = 1024;
const int DefaultMaxCqPollNum = 1;
const int DefaultMaxSendWr = 8192;
const int DefaultMaxWrPerSend = 64;

Transport transport;
int ibMaxCqSize = DefaultMaxCqSize;
int ibMaxCqPollNum = DefaultMaxCqPollNum;
int ibMaxSendWr = DefaultMaxSendWr;
int ibMaxWrPerSend = DefaultMaxWrPerSend;

/// Default constructor. Sets transport to Transport::Unknown.
EndpointConfig() : transport(Transport::Unknown) {}

/// Constructor that takes a transport and sets the other fields to their default values.
///
/// @param transport The transport to use.
EndpointConfig(Transport transport) : transport(transport) {}
};

class Context {
public:
/// Create a context.
Expand All @@ -457,8 +479,7 @@ class Context {
/// @param ibMaxSendWr The maximum number of outstanding send work requests for IB. Unused if transport is not IB.
/// @param ibMaxWrPerSend The maximum number of work requests per send for IB. Unused if transport is not IB.
/// @return The newly created endpoint.
Endpoint createEndpoint(Transport transport, int ibMaxCqSize = 1024, int ibMaxCqPollNum = 1, int ibMaxSendWr = 8192,
int ibMaxWrPerSend = 64);
Endpoint createEndpoint(EndpointConfig config);

/// Establish a connection between two endpoints.
///
Expand Down Expand Up @@ -545,6 +566,7 @@ class Communicator {
/// Initializes the communicator with a given bootstrap implementation.
///
/// @param bootstrap An implementation of the Bootstrap that the communicator will use.
/// @param context An optional context to use for the communicator. If not provided, a new context will be created.
Communicator(std::shared_ptr<Bootstrap> bootstrap, std::shared_ptr<Context> context = nullptr);

/// Destroy the communicator.
Expand Down Expand Up @@ -606,9 +628,7 @@ class Communicator {
/// @param ibMaxWrPerSend The maximum number of work requests per send for IB. Unused if transport is not IB.
/// @return NonblockingFuture<NonblockingFuture<std::shared_ptr<Connection>>> A non-blocking future of shared pointer
/// to the connection.
NonblockingFuture<std::shared_ptr<Connection>> connectOnSetup(int remoteRank, int tag, Transport transport,
int ibMaxCqSize = 1024, int ibMaxCqPollNum = 1,
int ibMaxSendWr = 8192, int ibMaxWrPerSend = 64);
NonblockingFuture<std::shared_ptr<Connection>> connectOnSetup(int remoteRank, int tag, EndpointConfig localConfig);

/// Get the remote rank a connection is connected to.
///
Expand Down
15 changes: 11 additions & 4 deletions python/core_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,15 @@ void register_core(nb::module_& m) {
.def("serialize", &Endpoint::serialize)
.def_static("deserialize", &Endpoint::deserialize, nb::arg("data"));

nb::class_<EndpointConfig>(m, "EndpointConfig")
.def(nb::init<>())
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
.def_rw("transport", &EndpointConfig::transport)
.def_rw("ib_max_cq_size", &EndpointConfig::ibMaxCqSize)
.def_rw("ib_max_cq_poll_num", &EndpointConfig::ibMaxCqPollNum)
.def_rw("ib_max_send_wr", &EndpointConfig::ibMaxSendWr)
.def_rw("ib_max_wr_per_send", &EndpointConfig::ibMaxWrPerSend);

nb::class_<Context>(m, "Context")
.def(nb::init<>())
.def(
Expand All @@ -135,8 +144,7 @@ void register_core(nb::module_& m) {
return self->registerMemory((void*)ptr, size, transports);
},
nb::arg("ptr"), nb::arg("size"), nb::arg("transports"))
.def("create_endpoint", &Context::createEndpoint, nb::arg("transport"), nb::arg("ibMaxCqSize") = 1024,
nb::arg("ibMaxCqPollNum") = 1, nb::arg("ibMaxSendWr") = 8192, nb::arg("ibMaxWrPerSend") = 64)
.def("create_endpoint", &Context::createEndpoint, nb::arg("config"))
.def("connect", &Context::connect, nb::arg("local_endpoint"), nb::arg("remote_endpoint"));

def_nonblocking_future<RegisteredMemory>(m, "RegisteredMemory");
Expand All @@ -157,8 +165,7 @@ void register_core(nb::module_& m) {
nb::arg("tag"))
.def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
.def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
nb::arg("transport"), nb::arg("ibMaxCqSize") = 1024, nb::arg("ibMaxCqPollNum") = 1,
nb::arg("ibMaxSendWr") = 8192, nb::arg("ibMaxWrPerSend") = 64)
nb::arg("localConfig"))
.def("remote_rank_of", &Communicator::remoteRankOf)
.def("tag_of", &Communicator::tagOf)
.def("setup", &Communicator::setup);
Expand Down
12 changes: 4 additions & 8 deletions src/communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,12 @@ MSCCLPP_API_CPP NonblockingFuture<RegisteredMemory> Communicator::recvMemoryOnSe
}

struct Communicator::Impl::Connector : public Setuppable {
Connector(Communicator& comm, Communicator::Impl& commImpl_, int remoteRank, int tag, Transport transport,
int ibMaxCqSize, int ibMaxCqPollNum, int ibMaxSendWr, int ibMaxWrPerSend)
Connector(Communicator& comm, Communicator::Impl& commImpl_, int remoteRank, int tag, EndpointConfig localConfig)
: comm_(comm),
commImpl_(commImpl_),
remoteRank_(remoteRank),
tag_(tag),
localEndpoint_(
comm.context()->createEndpoint(transport, ibMaxCqSize, ibMaxCqPollNum, ibMaxSendWr, ibMaxWrPerSend)) {}
localEndpoint_(comm.context()->createEndpoint(localConfig)) {}

void beginSetup(std::shared_ptr<Bootstrap> bootstrap) override {
bootstrap->send(localEndpoint_.serialize(), remoteRank_, tag_);
Expand All @@ -98,10 +96,8 @@ struct Communicator::Impl::Connector : public Setuppable {
};

MSCCLPP_API_CPP NonblockingFuture<std::shared_ptr<Connection>> Communicator::connectOnSetup(
int remoteRank, int tag, Transport transport, int ibMaxCqSize, int ibMaxCqPollNum, int ibMaxSendWr,
int ibMaxWrPerSend) {
auto connector = std::make_shared<Communicator::Impl::Connector>(
*this, *pimpl_, remoteRank, tag, transport, ibMaxCqSize, ibMaxCqPollNum, ibMaxSendWr, ibMaxWrPerSend);
int remoteRank, int tag, EndpointConfig localConfig) {
auto connector = std::make_shared<Communicator::Impl::Connector>(*this, *pimpl_, remoteRank, tag, localConfig);
onSetup(connector);
return NonblockingFuture<std::shared_ptr<Connection>>(connector->connectionPromise_.get_future());
}
Expand Down
6 changes: 2 additions & 4 deletions src/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,8 @@ MSCCLPP_API_CPP RegisteredMemory Context::registerMemory(void* ptr, size_t size,
return RegisteredMemory(std::make_shared<RegisteredMemory::Impl>(ptr, size, transports, *pimpl_));
}

MSCCLPP_API_CPP Endpoint Context::createEndpoint(Transport transport, int ibMaxCqSize, int ibMaxCqPollNum,
int ibMaxSendWr, int ibMaxWrPerSend) {
return Endpoint(
std::make_shared<Endpoint::Impl>(transport, ibMaxCqSize, ibMaxCqPollNum, ibMaxSendWr, ibMaxWrPerSend, *pimpl_));
MSCCLPP_API_CPP Endpoint Context::createEndpoint(EndpointConfig config) {
return Endpoint(std::make_shared<Endpoint::Impl>(config, *pimpl_));
}

MSCCLPP_API_CPP std::shared_ptr<Connection> Context::connect(Endpoint localEndpoint, Endpoint remoteEndpoint) {
Expand Down
10 changes: 5 additions & 5 deletions src/endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@

namespace mscclpp {

Endpoint::Impl::Impl(Transport transport, int ibMaxCqSize, int ibMaxCqPollNum, int ibMaxSendWr, int ibMaxWrPerSend,
Context::Impl& contextImpl)
: transport_(transport), hostHash_(contextImpl.hostHash_) {
if (AllIBTransports.has(transport)) {
Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl)
: transport_(config.transport), hostHash_(contextImpl.hostHash_) {
if (AllIBTransports.has(transport_)) {
ibLocal_ = true;
ibQp_ = contextImpl.getIbContext(transport)->createQp(ibMaxCqSize, ibMaxCqPollNum, ibMaxSendWr, 0, ibMaxWrPerSend);
ibQp_ = contextImpl.getIbContext(transport_)
->createQp(config.ibMaxCqSize, config.ibMaxCqPollNum, config.ibMaxSendWr, 0, config.ibMaxWrPerSend);
ibQpInfo_ = ibQp_->getInfo();
}
}
Expand Down
3 changes: 1 addition & 2 deletions src/include/endpoint.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
namespace mscclpp {

struct Endpoint::Impl {
Impl(Transport transport, int ibMaxCqSize, int ibMaxCqPollNum, int ibMaxSendWr, int ibMaxWrPerSend,
Context::Impl& contextImpl);
Impl(EndpointConfig config, Context::Impl& contextImpl);
Impl(const std::vector<char>& serialization);

Transport transport_;
Expand Down

0 comments on commit 7350de0

Please sign in to comment.