Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CUDA: enable cuda support v1 - EAGER with GDR COPY #20

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
19 changes: 19 additions & 0 deletions src/ucp/core/ucp_ep.c
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,7 @@ void ucp_ep_config_init(ucp_worker_h worker, ucp_ep_config_t *config)
{
ucp_context_h context = worker->context;
ucp_ep_rma_config_t *rma_config;
ucp_ep_addr_domain_config_t *domain_config;
uct_iface_attr_t *iface_attr;
uct_md_attr_t *md_attr;
ucp_rsc_index_t rsc_index;
Expand All @@ -903,6 +904,7 @@ void ucp_ep_config_init(ucp_worker_h worker, ucp_ep_config_t *config)
config->tag.eager.zcopy_auto_thresh = 0;
config->am.zcopy_auto_thresh = 0;
config->p2p_lanes = 0;
config->domain_lanes = 0;
config->bcopy_thresh = context->config.ext.bcopy_thresh;
config->tag.lane = UCP_NULL_LANE;
config->tag.proto = &ucp_tag_eager_proto;
Expand Down Expand Up @@ -990,6 +992,23 @@ void ucp_ep_config_init(ucp_worker_h worker, ucp_ep_config_t *config)
}
}

/* Configuration for memory domains */
for (lane = 0; lane < config->key.num_lanes; ++lane) {
if (config->key.domain_lanes[lane] == UCP_NULL_LANE) {
continue;
}
config->domain_lanes |= UCS_BIT(lane);

domain_config = &config->domain[lane];
rsc_index = config->key.lanes[lane].rsc_index;
iface_attr = &worker->ifaces[rsc_index].attr;

domain_config->tag.eager.max_short = iface_attr->cap.am.max_short;
//TODO: zcopy threshold should be based on the ep AM lane capability with domain addr (i.e. can UCT do zcopy from domain)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thrshold -> threshold

memset(domain_config->tag.eager.zcopy_thresh, 0, UCP_MAX_IOV * sizeof(size_t));

}

/* Configuration for remote memory access */
for (lane = 0; lane < config->key.num_lanes; ++lane) {
if (ucp_ep_config_get_rma_prio(config->key.rma_lanes, lane) == -1) {
Expand Down
27 changes: 26 additions & 1 deletion src/ucp/core/ucp_ep.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ typedef struct ucp_ep_config_key {
/* Lanes for atomic operations, sorted by priority, highest first */
ucp_lane_index_t amo_lanes[UCP_MAX_LANES];

/* Lanes for domain operations, sorted by priority, highest first */
ucp_lane_index_t domain_lanes[UCP_MAX_LANES];

/* Bitmap of remote mds which are reachable from this endpoint (with any set
* of transports which could be selected in the future).
*/
Expand All @@ -106,6 +109,17 @@ typedef struct ucp_ep_rma_config {
} ucp_ep_rma_config_t;


/* Nonzero when the request's address-domain handle is the dummy (host memory)
 * handle, i.e. the data needs no domain-specific copy path.
 * The macro argument is parenthesized so that any expression (e.g. a
 * conditional) can be passed safely. */
#define UCP_IS_DEFAULT_ADDR_DOMAIN(_addr_dn_h) ((_addr_dn_h) == &ucp_addr_dn_dummy_handle)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(_addr_dn_h == &ucp_addr_dn_dummy_handle) -> ((_addr_dn_h) == &ucp_addr_dn_dummy_handle)


/*
 * Per-lane protocol limits for endpoints whose user buffers live in a
 * non-default address domain (e.g. device memory). Filled in by
 * ucp_ep_config_init() for every lane present in key.domain_lanes.
 */
typedef struct ucp_ep_addr_domain_config {
    struct {
        struct {
            /* Maximal payload size for an eager short send on this lane
             * (taken from the iface AM capability); signed so a negative
             * value can mark "short protocol unavailable" —
             * NOTE(review): the negative convention is assumed from the
             * ssize_t type, confirm against users of this field. */
            ssize_t max_short;
            /* Per-IOV-count thresholds for switching eager sends to
             * zero-copy; zeroed at init (see ucp_ep_config_init). */
            size_t zcopy_thresh[UCP_MAX_IOV];
        } eager;
    } tag;
} ucp_ep_addr_domain_config_t;

/*
* Configuration for AM and tag offload protocols
*/
Expand Down Expand Up @@ -136,6 +150,10 @@ typedef struct ucp_ep_config {
*/
ucp_lane_map_t p2p_lanes;

/* Bitmap of which lanes are domain lanes
*/
ucp_lane_map_t domain_lanes;

/* Configuration for each lane that provides RMA */
ucp_ep_rma_config_t rma[UCP_MAX_LANES];
/* Threshold for switching from put_short to put_bcopy */
Expand Down Expand Up @@ -179,8 +197,11 @@ typedef struct ucp_ep_config {
* (currently it's only AM based). */
const ucp_proto_t *proto;
} stream;
} ucp_ep_config_t;

/* Configuration of all domains */
ucp_ep_addr_domain_config_t domain[UCP_MAX_LANES];

} ucp_ep_config_t;

/**
* Remote protocol layer endpoint
Expand Down Expand Up @@ -245,4 +266,8 @@ size_t ucp_ep_config_get_zcopy_auto_thresh(size_t iovcnt,
const ucp_context_h context,
double bandwidth);

ucp_lane_index_t ucp_config_find_domain_lane(const ucp_ep_config_t *config,
const ucp_lane_index_t *lanes,
ucp_md_map_t dn_md_map);
ucs_status_t ucp_ep_set_domain_lanes(ucp_ep_h ep, ucp_addr_dn_h addr_dn_h);
#endif
121 changes: 121 additions & 0 deletions src/ucp/dt/dt.c
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/

#include "dt.h"
#include <ucp/core/ucp_request.inl>


size_t ucp_dt_pack(ucp_datatype_t datatype, void *dest, const void *src,
Expand Down Expand Up @@ -44,3 +45,123 @@ size_t ucp_dt_pack(ucp_datatype_t datatype, void *dest, const void *src,
state->offset += result_len;
return result_len;
}

/*
 * Deliver received data into a buffer that resides in a non-default address
 * domain (e.g. GPU memory): a plain memcpy cannot write such memory, so the
 * destination is registered with a suitable memory domain and the payload is
 * pushed with a zero-copy put over a loopback endpoint.
 *
 * @param req          Receive request; provides the worker and the address
 *                     domain handle (req->addr_dn_h).
 * @param buffer       Destination buffer (domain memory).
 * @param buffer_size  Capacity of the destination buffer.
 * @param recv_data    Source bytes to deliver.
 * @param recv_length  Number of bytes to deliver; must not exceed buffer_size.
 *
 * @return UCS_OK on success, otherwise an error status.
 */
static UCS_F_ALWAYS_INLINE ucs_status_t
ucp_dn_dt_unpack(ucp_request_t *req, void *buffer, size_t buffer_size,
                 const void *recv_data, size_t recv_length)
{
    ucs_status_t status;
    ucp_worker_h worker     = req->recv.worker;
    ucp_context_h context   = worker->context;
    ucp_ep_h ep             = ucp_worker_ep_find(worker, worker->uuid);
    ucp_ep_config_t *config = ucp_ep_config(ep);
    ucp_md_map_t dn_md_map  = req->addr_dn_h->md_map;
    ucp_lane_index_t dn_lane;
    ucp_rsc_index_t rsc_index;
    uct_iface_attr_t *iface_attr;
    unsigned md_index;
    uct_mem_h memh;
    uct_iov_t iov;

    if (recv_length == 0) {
        return UCS_OK;
    }

    /* Select a domain lane whose interface supports zero-copy put; memory
     * domains of unsuitable lanes are removed from the candidate map and the
     * search is retried. */
    for (;;) {
        dn_lane = ucp_config_find_domain_lane(config, config->key.domain_lanes,
                                              dn_md_map);
        if (dn_lane == UCP_NULL_LANE) {
            ucs_error("failed to find an address domain lane");
            return UCS_ERR_IO_ERROR;
        }

        rsc_index  = ucp_ep_get_rsc_index(ep, dn_lane);
        iface_attr = &worker->ifaces[rsc_index].attr;
        md_index   = config->key.lanes[dn_lane].dst_md_index;
        if (iface_attr->cap.flags & UCT_IFACE_FLAG_PUT_ZCOPY) {
            break;
        }

        /* Exclude this md from the candidate map and retry. Bug fix: the
         * original used `dn_md_map |= ~UCS_BIT(md_index)`, which sets every
         * bit except md_index instead of clearing md_index. */
        dn_md_map &= ~UCS_BIT(md_index);
    }

    status = uct_md_mem_reg(context->tl_mds[md_index].md, buffer, buffer_size,
                            UCT_MD_MEM_ACCESS_REMOTE_PUT, &memh);
    if (status != UCS_OK) {
        ucs_error("Failed to reg address %p with md %s", buffer,
                  context->tl_mds[md_index].rsc.md_name);
        return status;
    }

    ucs_assert(buffer_size >= recv_length);

    /* Source is host memory received inline, so it needs no registration. */
    iov.buffer = (void *)recv_data;
    iov.length = recv_length;
    iov.count  = 1;
    iov.memh   = UCT_MEM_HANDLE_NULL;

    /* Loopback put: the local registration handle is passed as the rkey.
     * NOTE(review): this assumes the transport accepts a local memh as
     * uct_rkey_t for self-targeted puts — confirm per enabled transport.
     * A UCS_INPROGRESS return (completion still pending, comp == NULL) is
     * treated as failure here — confirm that is intended. */
    status = uct_ep_put_zcopy(ep->uct_eps[dn_lane], &iov, 1, (uint64_t)buffer,
                              (uct_rkey_t)memh, NULL);
    if (status != UCS_OK) {
        uct_md_mem_dereg(context->tl_mds[md_index].md, memh);
        ucs_error("Failed to perform uct_ep_put_zcopy to address %p", recv_data);
        return status;
    }

    status = uct_md_mem_dereg(context->tl_mds[md_index].md, memh);
    if (status != UCS_OK) {
        ucs_error("Failed to dereg address %p with md %s", buffer,
                  context->tl_mds[md_index].rsc.md_name);
        return status;
    }

    return UCS_OK;
}


/*
 * Unpack one received fragment into the request's destination buffer,
 * dispatching on the datatype class.
 *
 * @param req          Receive request; its addr_dn_h selects the plain-memcpy
 *                     path vs. the address-domain (ucp_dn_dt_unpack) path for
 *                     contiguous data.
 * @param datatype     UCP datatype descriptor (class encoded in low bits).
 * @param buffer       Destination buffer.
 * @param buffer_size  Total capacity of the destination buffer.
 * @param state        Datatype unpack state; state->offset is the number of
 *                     bytes already delivered (caller advances it).
 * @param recv_data    Fragment payload.
 * @param recv_length  Fragment length in bytes.
 * @param last         Nonzero when this is the final fragment; triggers the
 *                     generic datatype's finish() callback.
 *
 * @return UCS_OK / UCS_ERR_MESSAGE_TRUNCATED / UCS_ERR_INVALID_PARAM, or the
 *         status of the generic unpack callback.
 */
ucs_status_t ucp_dt_unpack(ucp_request_t *req, ucp_datatype_t datatype, void *buffer, size_t buffer_size,
                           ucp_dt_state_t *state, const void *recv_data, size_t recv_length, int last)
{
    ucp_dt_generic_t *dt_gen;
    size_t offset = state->offset;
    ucs_status_t status;

    /* Truncation check: the fragment must fit in the remaining buffer. For a
     * generic datatype the finish() callback must still run on the last
     * fragment so user state is released. */
    if (ucs_unlikely((recv_length + offset) > buffer_size)) {
        ucs_trace_req("message truncated: recv_length %zu offset %zu buffer_size %zu",
                      recv_length, offset, buffer_size);
        if (UCP_DT_IS_GENERIC(datatype) && last) {
            ucp_dt_generic(datatype)->ops.finish(state->dt.generic.state);
        }
        return UCS_ERR_MESSAGE_TRUNCATED;
    }

    switch (datatype & UCP_DATATYPE_CLASS_MASK) {
    case UCP_DATATYPE_CONTIG:
        if (ucs_likely(UCP_IS_DEFAULT_ADDR_DOMAIN(req->addr_dn_h))) {
            /* Host memory: straight memcpy into the buffer at the current
             * offset. NOTE(review): `buffer + offset` is void* arithmetic,
             * a GCC extension. */
            UCS_PROFILE_NAMED_CALL("memcpy_recv", memcpy, buffer + offset,
                                   recv_data, recv_length);
            return UCS_OK;
        } else {
            /* Non-default address domain (e.g. device memory): delegate to
             * the registered zero-copy put path. */
            return ucp_dn_dt_unpack(req, buffer, buffer_size, recv_data, recv_length);
        }

    case UCP_DATATYPE_IOV:
        /* Scatter the fragment across the IOV entries, resuming from the
         * iov/offset position tracked in state. */
        UCS_PROFILE_CALL(ucp_dt_iov_scatter, buffer, state->dt.iov.iovcnt,
                         recv_data, recv_length, &state->dt.iov.iov_offset,
                         &state->dt.iov.iovcnt_offset);
        return UCS_OK;

    case UCP_DATATYPE_GENERIC:
        /* User-provided unpack callback; finish() is invoked only on the
         * last fragment. */
        dt_gen = ucp_dt_generic(datatype);
        status = UCS_PROFILE_NAMED_CALL("dt_unpack", dt_gen->ops.unpack,
                                        state->dt.generic.state, offset,
                                        recv_data, recv_length);
        if (last) {
            UCS_PROFILE_NAMED_CALL_VOID("dt_finish", dt_gen->ops.finish,
                                        state->dt.generic.state);
        }
        return status;

    default:
        ucs_error("unexpected datatype=%lx", datatype);
        return UCS_ERR_INVALID_PARAM;
    }
}
50 changes: 4 additions & 46 deletions src/ucp/dt/dt.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <uct/api/uct.h>
#include <ucs/debug/profile.h>
#include <string.h>
#include <ucp/core/ucp_types.h>


/**
Expand Down Expand Up @@ -72,51 +73,8 @@ size_t ucp_dt_length(ucp_datatype_t datatype, size_t count,
size_t ucp_dt_pack(ucp_datatype_t datatype, void *dest, const void *src,
ucp_dt_state_t *state, size_t length);

/*
 * Unpack one received fragment into the destination buffer, dispatching on
 * the datatype class (header-inline variant; the PR shown in this diff moves
 * it to dt.c and adds a request parameter).
 *
 * @param datatype     UCP datatype descriptor (class encoded in low bits).
 * @param buffer       Destination buffer.
 * @param buffer_size  Total capacity of the destination buffer.
 * @param state        Datatype unpack state; state->offset is the number of
 *                     bytes already delivered (caller advances it).
 * @param recv_data    Fragment payload.
 * @param recv_length  Fragment length in bytes.
 * @param last         Nonzero when this is the final fragment; triggers the
 *                     generic datatype's finish() callback.
 *
 * @return UCS_OK / UCS_ERR_MESSAGE_TRUNCATED / UCS_ERR_INVALID_PARAM, or the
 *         status of the generic unpack callback.
 */
static UCS_F_ALWAYS_INLINE ucs_status_t
ucp_dt_unpack(ucp_datatype_t datatype, void *buffer, size_t buffer_size,
              ucp_dt_state_t *state, const void *recv_data,
              size_t recv_length, int last)
{
    ucp_dt_generic_t *dt_gen;
    size_t offset = state->offset;
    ucs_status_t status;

    /* Truncation check: the fragment must fit in the remaining buffer. For a
     * generic datatype the finish() callback must still run on the last
     * fragment so user state is released. */
    if (ucs_unlikely((recv_length + offset) > buffer_size)) {
        ucs_trace_req("message truncated: recv_length %zu offset %zu buffer_size %zu",
                      recv_length, offset, buffer_size);
        if (UCP_DT_IS_GENERIC(datatype) && last) {
            ucp_dt_generic(datatype)->ops.finish(state->dt.generic.state);
        }
        return UCS_ERR_MESSAGE_TRUNCATED;
    }

    switch (datatype & UCP_DATATYPE_CLASS_MASK) {
    case UCP_DATATYPE_CONTIG:
        /* Host memory: straight memcpy into the buffer at the current offset.
         * NOTE(review): `buffer + offset` is void* arithmetic, a GCC
         * extension. */
        UCS_PROFILE_NAMED_CALL("memcpy_recv", memcpy, buffer + offset,
                               recv_data, recv_length);
        return UCS_OK;

    case UCP_DATATYPE_IOV:
        /* Scatter the fragment across the IOV entries, resuming from the
         * iov/offset position tracked in state. */
        UCS_PROFILE_CALL(ucp_dt_iov_scatter, buffer, state->dt.iov.iovcnt,
                         recv_data, recv_length, &state->dt.iov.iov_offset,
                         &state->dt.iov.iovcnt_offset);
        return UCS_OK;

    case UCP_DATATYPE_GENERIC:
        /* User-provided unpack callback; finish() is invoked only on the
         * last fragment. */
        dt_gen = ucp_dt_generic(datatype);
        status = UCS_PROFILE_NAMED_CALL("dt_unpack", dt_gen->ops.unpack,
                                        state->dt.generic.state, offset,
                                        recv_data, recv_length);
        if (last) {
            UCS_PROFILE_NAMED_CALL_VOID("dt_finish", dt_gen->ops.finish,
                                        state->dt.generic.state);
        }
        return status;

    default:
        ucs_error("unexpected datatype=%lx", datatype);
        return UCS_ERR_INVALID_PARAM;
    }
}
ucs_status_t ucp_dt_unpack(ucp_request_t *req, ucp_datatype_t datatype,
void *buffer, size_t buffer_size, ucp_dt_state_t *state,
const void *recv_data, size_t recv_length, int last);

#endif
4 changes: 2 additions & 2 deletions src/ucp/tag/eager.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ static UCS_F_ALWAYS_INLINE ucs_status_t
ucp_eager_unexp_match(ucp_worker_h worker, ucp_recv_desc_t *rdesc, ucp_tag_t tag,
unsigned flags, void *buffer, size_t count,
ucp_datatype_t datatype, ucp_dt_state_t *state,
ucp_tag_recv_info_t *info)
ucp_request_t *req, ucp_tag_recv_info_t *info)
{
size_t recv_len, hdr_len;
ucs_status_t status;
Expand All @@ -110,7 +110,7 @@ ucp_eager_unexp_match(ucp_worker_h worker, ucp_recv_desc_t *rdesc, ucp_tag_t tag
UCP_WORKER_STAT_EAGER_CHUNK(worker, UNEXP);
hdr_len = rdesc->hdr_len;
recv_len = rdesc->length - hdr_len;
status = ucp_dt_unpack(datatype, buffer, count, state, data + hdr_len,
status = ucp_dt_unpack(req, datatype, buffer, count, state, data + hdr_len,
recv_len, flags & UCP_RECV_DESC_FLAG_LAST);
state->offset += recv_len;

Expand Down
2 changes: 1 addition & 1 deletion src/ucp/tag/eager_rcv.c
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ ucp_eager_handler(void *arg, void *data, size_t length, unsigned am_flags,
if (req != NULL) {
UCS_PROFILE_REQUEST_EVENT(req, "eager_recv", recv_len);

status = ucp_dt_unpack(req->recv.datatype, req->recv.buffer,
status = ucp_dt_unpack(req, req->recv.datatype, req->recv.buffer,
req->recv.length, &req->recv.state,
data + hdr_len, recv_len,
flags & UCP_RECV_DESC_FLAG_LAST);
Expand Down
2 changes: 1 addition & 1 deletion src/ucp/tag/offload.c
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ void ucp_tag_offload_completed(uct_tag_context_t *self, uct_tag_t stag,
}

if (req->recv.rdesc != NULL) {
status = ucp_dt_unpack(req->recv.datatype, req->recv.buffer, req->recv.length,
status = ucp_dt_unpack(req, req->recv.datatype, req->recv.buffer, req->recv.length,
&req->recv.state, req->recv.rdesc + 1, length, 1);
ucs_mpool_put_inline(req->recv.rdesc);
} else {
Expand Down
8 changes: 4 additions & 4 deletions src/ucp/tag/rndv.c
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_data_handler,
}

UCS_PROFILE_REQUEST_EVENT(rreq, "rndv_data_recv", recv_len);
status = ucp_dt_unpack(rreq->recv.datatype, rreq->recv.buffer,
status = ucp_dt_unpack(rreq, rreq->recv.datatype, rreq->recv.buffer,
rreq->recv.length, &rreq->recv.state,
data + hdr_len, recv_len, 0);
if ((status == UCS_OK) || (status == UCS_INPROGRESS)) {
Expand Down Expand Up @@ -764,9 +764,9 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_data_last_handler,
/* Check that total received length matches RTS->length */
ucs_assert(rreq->recv.info.length == rreq->recv.state.offset + recv_len);
UCS_PROFILE_REQUEST_EVENT(rreq, "rndv_data_last_recv", recv_len);
status = ucp_dt_unpack(rreq->recv.datatype, rreq->recv.buffer,
rreq->recv.length, &rreq->recv.state,
data + hdr_len, recv_len, 1);
status = ucp_dt_unpack(rreq, rreq->recv.datatype, rreq->recv.buffer,
rreq->recv.length, &rreq->recv.state,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

allignment

data + hdr_len, recv_len, 1);
} else {
ucs_trace_data("drop last segment for rreq %p, length %zu, status %s",
rreq, recv_len, ucs_status_string(rreq->status));
Expand Down
2 changes: 1 addition & 1 deletion src/ucp/tag/tag_recv.c
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ ucp_tag_search_unexp(ucp_worker_h worker, void *buffer, size_t buffer_size,
UCS_PROFILE_REQUEST_EVENT(req, "eager_match", 0);
status = ucp_eager_unexp_match(worker, rdesc, recv_tag, flags,
buffer, buffer_size, datatype,
&req->recv.state, info);
&req->recv.state, req, info);
ucs_trace_req("release receive descriptor %p", rdesc);
if (status != UCS_INPROGRESS) {
goto out_release_desc;
Expand Down
Loading