diff --git a/src/ucp/Makefile.am b/src/ucp/Makefile.am index 0d570535234..f4074d3bc28 100644 --- a/src/ucp/Makefile.am +++ b/src/ucp/Makefile.am @@ -69,6 +69,7 @@ noinst_HEADERS = \ proto/proto.h \ rma/rma.h \ rma/rma.inl \ + rma/rma_rndv.h \ rndv/proto_rndv.h \ rndv/proto_rndv.inl \ rndv/rndv_mtype.inl \ @@ -143,6 +144,7 @@ libucp_la_SOURCES = \ rma/get_offload.c \ rma/put_am.c \ rma/put_offload.c \ + rma/rma_rndv.c \ rma/rma_send.c \ rma/rma_sw.c \ rma/flush.c \ diff --git a/src/ucp/core/ucp_ep.c b/src/ucp/core/ucp_ep.c index 1a0c09bd808..ab9c0e42d49 100644 --- a/src/ucp/core/ucp_ep.c +++ b/src/ucp/core/ucp_ep.c @@ -3696,6 +3696,12 @@ void ucp_ep_req_purge(ucp_ep_h ucp_ep, ucp_request_t *req, } ucp_request_put(req); + } else if (req->flags & UCP_REQUEST_FLAG_RNDV_SEND_INTERNAL) { + ucs_assert(req->send.ep == ucp_ep); + + ucp_datatype_iter_cleanup(&req->send.state.dt_iter, 1, + UCP_DT_MASK_ALL); + ucp_request_complete_send(req, status); } else if (req->send.uct.func == ucp_amo_sw_proto.progress_fetch) { /* Currently we don't support UCP EP request purging for proto mode */ ucs_assert(!ucp_ep->worker->context->config.ext.proto_enable); diff --git a/src/ucp/core/ucp_request.c b/src/ucp/core/ucp_request.c index dc3f793e6da..5363f3cb4b5 100644 --- a/src/ucp/core/ucp_request.c +++ b/src/ucp/core/ucp_request.c @@ -44,12 +44,17 @@ static const char *ucp_request_flag_names[] = { [ucs_ilog2(UCP_REQUEST_FLAG_RECV_TAG)] = "rcv_tag", [ucs_ilog2(UCP_REQUEST_FLAG_RKEY_INUSE)] = "rk_use", [ucs_ilog2(UCP_REQUEST_FLAG_USER_HEADER_COPIED)] = "hdr_copy", + [ucs_ilog2(UCP_REQUEST_FLAG_RNDV_RECV_INTERNAL)] = "rndv_rcv_int", #if UCS_ENABLE_ASSERT [ucs_ilog2(UCP_REQUEST_FLAG_STREAM_RECV)] = "strm_rcv", [ucs_ilog2(UCP_REQUEST_DEBUG_FLAG_EXTERNAL)] = "extrn", [ucs_ilog2(UCP_REQUEST_FLAG_SUPER_VALID)] = "spr_vld", #endif + [ucs_ilog2(UCP_REQUEST_FLAG_RNDV_SEND_INTERNAL)] = "rndv_snd_int", + [ucs_ilog2(UCP_REQUEST_FLAG_RNDV_RTR_REQ)] = "rndv_rtr_req", + [ucs_ilog2(UCP_REQUEST_FLAG_RNDV_FLUSH)] = "rndv_flush", + [ucs_ilog2(UCP_REQUEST_FLAG_RNDV_START_FLUSH)] = "rndv_start_flush", }; static ucs_memory_type_t ucp_request_get_mem_type(ucp_request_t *req) @@ -59,7 +64,8 @@ static ucs_memory_type_t ucp_request_get_mem_type(ucp_request_t *req) } else if (req->flags & (UCP_REQUEST_FLAG_SEND_AM | UCP_REQUEST_FLAG_SEND_TAG)) { return req->send.mem_type; } else if (req->flags & - (UCP_REQUEST_FLAG_RECV_AM | UCP_REQUEST_FLAG_RECV_TAG)) { + (UCP_REQUEST_FLAG_RECV_AM | UCP_REQUEST_FLAG_RECV_TAG | + UCP_REQUEST_FLAG_RNDV_RECV_INTERNAL)) { return req->recv.dt_iter.mem_info.type; } else { return UCS_MEMORY_TYPE_UNKNOWN; diff --git a/src/ucp/core/ucp_request.h b/src/ucp/core/ucp_request.h index 3e6c0ad2555..b5dc1ab0090 100644 --- a/src/ucp/core/ucp_request.h +++ b/src/ucp/core/ucp_request.h @@ -56,15 +56,20 @@ enum { UCP_REQUEST_FLAG_USER_HEADER_COPIED = UCS_BIT(19), UCP_REQUEST_FLAG_USAGE_TRACKED = UCS_BIT(20), UCP_REQUEST_FLAG_FENCE_REQUIRED = UCS_BIT(21), + UCP_REQUEST_FLAG_RNDV_RECV_INTERNAL = UCS_BIT(22), #if UCS_ENABLE_ASSERT - UCP_REQUEST_FLAG_STREAM_RECV = UCS_BIT(22), - UCP_REQUEST_DEBUG_FLAG_EXTERNAL = UCS_BIT(23), - UCP_REQUEST_FLAG_SUPER_VALID = UCS_BIT(24), + UCP_REQUEST_FLAG_STREAM_RECV = UCS_BIT(23), + UCP_REQUEST_DEBUG_FLAG_EXTERNAL = UCS_BIT(24), + UCP_REQUEST_FLAG_SUPER_VALID = UCS_BIT(25), #else UCP_REQUEST_FLAG_STREAM_RECV = 0, UCP_REQUEST_DEBUG_FLAG_EXTERNAL = 0, - UCP_REQUEST_FLAG_SUPER_VALID = 0 + UCP_REQUEST_FLAG_SUPER_VALID = 0, #endif + UCP_REQUEST_FLAG_RNDV_SEND_INTERNAL = UCS_BIT(26), + UCP_REQUEST_FLAG_RNDV_RTR_REQ = UCS_BIT(27), + UCP_REQUEST_FLAG_RNDV_FLUSH = UCS_BIT(28), + UCP_REQUEST_FLAG_RNDV_START_FLUSH = UCS_BIT(29) }; @@ -261,6 +266,9 @@ struct ucp_request { /* Remote buffer address for get/put operation */ uint64_t remote_address; + /* Remote buffer memory info for RTR_REQ */ + ucp_memory_info_t remote_mem_info; + /* Key for remote buffer operation */ ucp_rkey_h rkey; @@ -477,6 +485,13 @@ struct ucp_request { size_t length; /* Completion info to fill */ } stream; + struct { + /* Remote endpoint ID used to send internal completions */ + uint64_t ep_id; + /* Completion callback for internal RNDV receives */ + ucp_request_callback_t complete_cb; + } rndv; + struct { ucp_am_recv_data_nbx_callback_t cb; /* Completion callback */ ucp_recv_desc_t *desc; /* Receive desc */ diff --git a/src/ucp/core/ucp_request.inl b/src/ucp/core/ucp_request.inl index 007ccd828b9..33ec7cfe549 100644 --- a/src/ucp/core/ucp_request.inl +++ b/src/ucp/core/ucp_request.inl @@ -266,6 +266,17 @@ ucp_request_put(ucp_request_t *req) ucs_mpool_put_inline(req); } +static UCS_F_ALWAYS_INLINE void +ucp_request_rndv_flush_complete(ucp_request_t *req) +{ + /* Complete the extra flush op held by a RNDV wrapper until the RNDV data + * path completes. */ + if (ucs_unlikely(req->flags & UCP_REQUEST_FLAG_RNDV_FLUSH)) { + req->flags &= ~UCP_REQUEST_FLAG_RNDV_FLUSH; + ucp_worker_flush_ops_count_add(req->send.ep->worker, -1); + } +} + static UCS_F_ALWAYS_INLINE void ucp_request_complete_send(ucp_request_t *req, ucs_status_t status) { diff --git a/src/ucp/proto/proto.c b/src/ucp/proto/proto.c index 4eb5af6cd9e..d5f3c8f93ce 100644 --- a/src/ucp/proto/proto.c +++ b/src/ucp/proto/proto.c @@ -25,10 +25,12 @@ _macro(ucp_get_am_bcopy_proto) \ _macro(ucp_get_offload_bcopy_proto) \ _macro(ucp_get_offload_zcopy_proto) \ + _macro(ucp_get_rndv_proto) \ _macro(ucp_put_am_bcopy_proto) \ _macro(ucp_put_offload_short_proto) \ _macro(ucp_put_offload_bcopy_proto) \ _macro(ucp_put_offload_zcopy_proto) \ + _macro(ucp_put_rndv_proto) \ _macro(ucp_put_sgl_offload_proto) \ _macro(ucp_eager_bcopy_multi_proto) \ _macro(ucp_eager_sync_bcopy_multi_proto) \ diff --git a/src/ucp/proto/proto.h b/src/ucp/proto/proto.h index ae7227b1998..26bc88f6cc2 100644 --- a/src/ucp/proto/proto.h +++ b/src/ucp/proto/proto.h @@ -23,7 +23,7 @@ /* Maximal number of protocols in total */ -#define UCP_PROTO_MAX_COUNT 64 +#define UCP_PROTO_MAX_COUNT 65 /* Special value for non-existent protocol */ diff --git a/src/ucp/proto/proto_common.inl b/src/ucp/proto/proto_common.inl index 6def997b254..3b8fd7658a2 100644 --- a/src/ucp/proto/proto_common.inl +++ b/src/ucp/proto/proto_common.inl @@ -94,7 +94,8 @@ ucp_proto_request_zcopy_clean(ucp_request_t *req, unsigned dt_mask) } static UCS_F_ALWAYS_INLINE void -ucp_proto_request_zcopy_complete(ucp_request_t *req, ucs_status_t status) +ucp_proto_request_zcopy_complete_cb(ucp_request_t *req, ucs_status_t status, + ucp_request_callback_t complete_cb) { ucp_datatype_iter_cleanup(&req->send.state.dt_iter, 1, UCP_DT_MASK_CONTIG_IOV | @@ -109,10 +110,20 @@ ucp_proto_request_zcopy_complete(ucp_request_t *req, ucs_status_t status) !(req->send.ep->flags & UCP_EP_FLAG_FAILED)) { ucp_proto_request_restart(req); } else { + if (complete_cb != NULL) { + complete_cb(req); + } + ucp_request_complete_send(req, status); } } +static UCS_F_ALWAYS_INLINE void +ucp_proto_request_zcopy_complete(ucp_request_t *req, ucs_status_t status) +{ + ucp_proto_request_zcopy_complete_cb(req, status, NULL); +} + static UCS_F_ALWAYS_INLINE ucs_status_t ucp_proto_request_zcopy_complete_success(ucp_request_t *req) { diff --git a/src/ucp/proto/proto_debug.c b/src/ucp/proto/proto_debug.c index 570d8580bc7..babdabe260c 100644 --- a/src/ucp/proto/proto_debug.c +++ b/src/ucp/proto/proto_debug.c @@ -400,7 +400,8 @@ void ucp_proto_select_param_str(const ucp_proto_select_param_t *select_param, [ucs_ilog2(UCP_OP_ATTR_FLAG_MULTI_SEND)] = "multi", }; static const char *rndv_flag_names[] = { - [ucs_ilog2(UCP_PROTO_SELECT_OP_FLAG_PPLN_FRAG)] = "frag" + [ucs_ilog2(UCP_PROTO_SELECT_OP_FLAG_PPLN_FRAG)] = "frag", + [ucs_ilog2(UCP_PROTO_SELECT_OP_FLAG_RNDV_PUSH)] = "push" }; static const char *am_flag_names[] = { [ucs_ilog2(UCP_PROTO_SELECT_OP_FLAG_AM_EAGER)] = "egr", diff --git a/src/ucp/proto/proto_select.c b/src/ucp/proto/proto_select.c index b8c74c691d6..deba66e2649 100644 --- a/src/ucp/proto/proto_select.c +++ b/src/ucp/proto/proto_select.c @@ -558,9 +558,18 @@ ucp_proto_select_lookup_slow(ucp_worker_h worker, return NULL; } - /* add to hash after initializing the temp element, since calling - * ucp_proto_select_elem_init() can recursively modify the hash + /* Add to hash after initializing the temp element, since calling + * ucp_proto_select_elem_init() can recursively modify the hash. + * Re-check the key because recursive lookup may have initialized this + * exact selection already. */ + khiter = kh_get(ucp_proto_select_hash, proto_select->hash, key.u64); + if (khiter != kh_end(proto_select->hash)) { + ucp_proto_select_elem_cleanup(&tmp_select_elem); + select_elem = &kh_value(proto_select->hash, khiter); + goto out; + } + khiter = kh_put(ucp_proto_select_hash, proto_select->hash, key.u64, &khret); ucs_assert_always(khret == UCS_KH_PUT_BUCKET_EMPTY); diff --git a/src/ucp/proto/proto_select.h b/src/ucp/proto/proto_select.h index dddce54f951..a988a52e9a7 100644 --- a/src/ucp/proto/proto_select.h +++ b/src/ucp/proto/proto_select.h @@ -32,6 +32,9 @@ * Relevant for UCP_OP_ID_RNDV_SEND and UCP_OP_ID_RNDV_RECV. */ #define UCP_PROTO_SELECT_OP_FLAG_PPLN_FRAG (UCP_PROTO_SELECT_OP_FLAGS_BASE << 1) +/* Select only push-based rendezvous receive protocols. */ +#define UCP_PROTO_SELECT_OP_FLAG_RNDV_PUSH (UCP_PROTO_SELECT_OP_FLAGS_BASE << 3) + /* Select eager/rendezvous protocol for Active Message sends. * Relevant for UCP_OP_ID_AM_SEND and UCP_OP_ID_AM_SEND_REPLY. */ diff --git a/src/ucp/rma/rma_rndv.c b/src/ucp/rma/rma_rndv.c new file mode 100644 index 00000000000..27221406d1d --- /dev/null +++ b/src/ucp/rma/rma_rndv.c @@ -0,0 +1,598 @@ +/** + * Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2026. ALL RIGHTS RESERVED. + * + * See file LICENSE for terms. + */ + +#ifdef HAVE_CONFIG_H +# include "config.h" +#endif + +#include "rma.h" +#include "rma.inl" +#include "rma_rndv.h" + +#include +#include +#include +#include +#include +#include +#include + + +#define UCP_PROTO_RMA_RNDV_RTS_NAME "RMA_RTS" + + +static void +ucp_rma_rndv_dt_iter_init(ucp_datatype_iter_t *dt_iter, uint64_t address, + size_t length, ucs_memory_type_t mem_type, + ucs_sys_device_t sys_dev) +{ + dt_iter->dt_class = UCP_DATATYPE_CONTIG; + dt_iter->mem_info.type = mem_type; + dt_iter->mem_info.sys_dev = sys_dev; + dt_iter->length = length; + dt_iter->offset = 0; + dt_iter->type.contig.buffer = (void*)(uintptr_t)address; + dt_iter->type.contig.memh = NULL; +} + +static int +ucp_proto_rma_rndv_probe_check(const ucp_proto_init_params_t *init_params, + ucp_operation_id_t op_id) +{ + const ucp_proto_select_param_t *sel_param = init_params->select_param; + + if (!ucp_proto_init_check_op(init_params, UCS_BIT(op_id)) || + ucp_proto_rndv_init_params_is_ppln_frag(init_params) || + (sel_param->dt_class != UCP_DATATYPE_CONTIG) || + (init_params->rkey_config_key == NULL)) { + return 0; + } + + return !UCP_MEM_IS_HOST(sel_param->mem_type) || + !UCP_MEM_IS_HOST(init_params->rkey_config_key->mem_type); +} + + +static size_t ucp_proto_put_rndv_rts_pack(void *dest, void *arg) +{ + ucp_request_t *req = arg; + ucp_rma_rndv_rts_hdr_t *rts = dest; + ucp_rkey_config_t *rkey_config; + + rkey_config = ucp_rkey_config(req->send.ep->worker, req->send.rma.rkey); + + rts->super.hdr = 0; + rts->super.opcode = UCP_RNDV_RTS_RMA; + rts->address = req->send.rma.remote_addr; + rts->sys_dev = rkey_config->key.sys_dev; + rts->mem_type = req->send.rma.rkey->mem_type; + + return ucp_proto_rndv_rts_pack(req, &rts->super, sizeof(*rts)); +} + +static ucs_status_t ucp_proto_put_rndv_init(ucp_request_t *req) +{ + const ucp_proto_rndv_ctrl_priv_t *rpriv = req->send.proto_config->priv; + int was_initialized; + ucs_status_t status; + + was_initialized = req->flags & UCP_REQUEST_FLAG_PROTO_INITIALIZED; + status = ucp_proto_rndv_rts_request_init(req); + if ((status != UCS_OK) || was_initialized) { + return status; + } + + /* Nested RNDV data protocols are not RMA protocols, so the wrapper handles + * RMA fence ordering before exposing the operation to the peer. */ + return ucp_ep_rma_handle_fence(req->send.ep, req, UCS_BIT(rpriv->lane)); +} + +static ucs_status_t ucp_proto_put_rndv_progress(uct_pending_req_t *self) +{ + ucp_request_t *req = ucs_container_of(self, ucp_request_t, send.uct); + const ucp_proto_rndv_ctrl_priv_t *rpriv; + size_t max_rts_size; + ucs_status_t status; + ucp_ep_h ep; + + status = ucp_proto_put_rndv_init(req); + if (status != UCS_OK) { + ucp_proto_request_abort(req, status); + return UCS_OK; + } + + ep = req->send.ep; + rpriv = req->send.proto_config->priv; + max_rts_size = sizeof(ucp_rma_rndv_rts_hdr_t) + rpriv->packed_rkey_size; + + req->flags |= UCP_REQUEST_FLAG_RNDV_FLUSH; + ucp_worker_flush_ops_count_add(ep->worker, +2); + status = ucp_proto_am_bcopy_single_send(req, UCP_AM_ID_RNDV_RTS, + rpriv->lane, + ucp_proto_put_rndv_rts_pack, req, + max_rts_size, 0); + if (status != UCS_OK) { + if (status == UCS_ERR_NO_RESOURCE) { + req->send.lane = rpriv->lane; + } + goto err_flush_count; + } + + ucp_ep_rma_remote_request_sent(ep); + return UCS_OK; + +err_flush_count: + req->flags &= ~UCP_REQUEST_FLAG_RNDV_FLUSH; + ucp_worker_flush_ops_count_add(ep->worker, -2); + if (status == UCS_ERR_NO_RESOURCE) { + return status; + } + + ucp_proto_request_abort(req, status); + return UCS_OK; +} + +static void +ucp_proto_put_rndv_probe(const ucp_proto_init_params_t *init_params) +{ + ucp_context_h context = init_params->worker->context; + ucp_proto_rndv_ctrl_init_params_t params = { + .super.super = *init_params, + .super.latency = 0, + .super.overhead = context->config.ext.proto_overhead_rndv_rts, + .super.cfg_thresh = context->config.ext.zcopy_thresh, + .super.cfg_priority = 5, + .super.min_length = 0, + .super.max_length = SIZE_MAX, + .super.min_iov = 1, + .super.min_frag_offs = UCP_PROTO_COMMON_OFFSET_INVALID, + .super.max_frag_offs = ucs_offsetof(uct_iface_attr_t, cap.am.max_bcopy), + .super.max_iov_offs = UCP_PROTO_COMMON_OFFSET_INVALID, + .super.hdr_size = sizeof(ucp_rma_rndv_rts_hdr_t), + .super.send_op = UCT_EP_OP_AM_BCOPY, + .super.memtype_op = UCT_EP_OP_LAST, + .super.flags = UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, + .super.exclude_map = 0, + .super.reg_mem_info = ucp_proto_common_select_param_mem_info( + init_params->select_param), + /* For performance modeling, this control protocol is followed on the + * peer by a regular RNDV receive flow over the final RMA address. */ + .remote_op_id = UCP_OP_ID_RNDV_RECV, + .lane = ucp_proto_rndv_find_ctrl_lane(init_params), + .unpack_perf = NULL, + .perf_bias = 0, + .ctrl_msg_name = UCP_PROTO_RMA_RNDV_RTS_NAME, + .md_map = 0 + }; + ucp_proto_rndv_ctrl_priv_t rpriv = {0}; + + if (!ucp_proto_rma_rndv_probe_check(init_params, UCP_OP_ID_PUT)) { + return; + } + + ucp_proto_rndv_ctrl_probe(¶ms, &rpriv, sizeof(rpriv)); +} + +static void +ucp_rma_rndv_send_ats_err(ucp_ep_h ep, ucs_ptr_map_key_t remote_req_id, + ucs_status_t status) +{ + ucp_request_t *req; + + req = ucp_request_get(ep->worker); + if (req == NULL) { + ucs_error("failed to allocate RMA RNDV error ATS"); + return; + } + + ucp_proto_request_send_init(req, ep, 0); + ucp_rndv_req_send_ack(req, 0, remote_req_id, status, UCP_AM_ID_RNDV_ATS, + "send_ats_err"); +} + +static void ucp_proto_rma_rndv_query(const ucp_proto_query_params_t *params, + ucp_proto_query_attr_t *attr, + const char *desc) +{ + const ucp_proto_rndv_ctrl_priv_t *rpriv = params->priv; + ucp_proto_query_attr_t remote_attr; + + ucp_proto_config_query(params->worker, &rpriv->remote_proto_config, + params->msg_length, &remote_attr); + + attr->is_estimation = 1; + attr->max_msg_length = remote_attr.max_msg_length; + attr->lane_map = UCS_BIT(rpriv->lane) | remote_attr.lane_map; + + ucs_snprintf_safe(attr->desc, sizeof(attr->desc), "%s using %s", desc, + remote_attr.desc); + ucs_snprintf_safe(attr->config, sizeof(attr->config), "%s", + remote_attr.config); +} + +static void +ucp_proto_put_rndv_query(const ucp_proto_query_params_t *params, + ucp_proto_query_attr_t *attr) +{ + ucp_proto_rma_rndv_query(params, attr, UCP_PROTO_RNDV_DESC); +} + +static void +ucp_proto_get_rndv_query(const ucp_proto_query_params_t *params, + ucp_proto_query_attr_t *attr) +{ + ucp_proto_rma_rndv_query(params, attr, UCP_PROTO_RNDV_DESC); +} + +static void +ucp_proto_get_rndv_add_variant( + const ucp_proto_init_params_t *init_params, + const ucp_proto_select_param_t *select_param, + ucp_worker_cfg_index_t rkey_cfg_index, ucp_lane_index_t lane, + ucp_proto_init_elem_t *proto, const void *proto_priv) +{ + ucp_context_h context = init_params->worker->context; + const ucp_proto_perf_t *perf_elems[1]; + ucp_proto_rndv_ctrl_priv_t rpriv = {0}; + ucp_proto_init_params_t variant_params; + UCS_STRING_BUFFER_ONSTACK(perf_name, 128); + ucp_proto_perf_t *perf; + size_t cfg_thresh; + ucs_status_t status; + + perf_elems[0] = proto->perf; + ucs_string_buffer_appendf(&perf_name, "%s" UCP_PROTO_PERF_NODE_NEW_LINE + "%s", UCP_PROTO_RNDV_RTR_REQ_NAME, + ucp_proto_perf_name(proto->perf)); + status = ucp_proto_perf_aggregate(ucs_string_buffer_cstr(&perf_name), + perf_elems, 1, &perf); + if (status != UCS_OK) { + return; + } else if (ucp_proto_perf_is_empty(perf)) { + ucp_proto_perf_destroy(perf); + return; + } + + rpriv.lane = lane; + + variant_params = *init_params; + variant_params.rkey_cfg_index = rkey_cfg_index; + ucp_proto_rndv_set_variant_config(&variant_params, proto, select_param, + proto_priv, &rpriv.remote_proto_config); + + cfg_thresh = context->config.ext.zcopy_thresh; + if (proto->cfg_thresh != UCS_MEMUNITS_AUTO) { + cfg_thresh = proto->cfg_thresh; + } + + ucp_proto_select_add_proto(init_params, cfg_thresh, 6, perf, &rpriv, + sizeof(rpriv)); +} + +static void +ucp_proto_get_rndv_probe(const ucp_proto_init_params_t *init_params) +{ + ucp_worker_h worker = init_params->worker; + const ucp_proto_select_init_protocols_t *proto_init; + ucp_proto_select_param_t rndv_sel_param; + ucp_worker_cfg_index_t rkey_cfg_index; + ucp_proto_select_elem_t *select_elem; + ucp_proto_select_t *proto_select; + ucp_proto_init_elem_t *proto; + ucp_memory_info_t mem_info; + ucp_lane_index_t lane; + const void *priv; + + if (!ucp_proto_rma_rndv_probe_check(init_params, UCP_OP_ID_GET)) { + return; + } + + lane = ucp_proto_rndv_find_ctrl_lane(init_params); + if (lane == UCP_NULL_LANE) { + return; + } + + mem_info = ucp_proto_common_select_param_mem_info( + init_params->select_param); + ucp_proto_select_param_init(&rndv_sel_param, UCP_OP_ID_RNDV_RECV, 0, + UCP_PROTO_SELECT_OP_FLAG_RNDV_PUSH, + UCP_DATATYPE_CONTIG, &mem_info, 1); + + proto_select = ucp_proto_select_get(worker, init_params->ep_cfg_index, + init_params->rkey_cfg_index, + &rkey_cfg_index); + if (proto_select == NULL) { + return; + } + + select_elem = ucp_proto_select_lookup_slow(worker, proto_select, 1, + init_params->ep_cfg_index, + rkey_cfg_index, + &rndv_sel_param); + if (select_elem == NULL) { + return; + } + + proto_init = &select_elem->proto_init; + ucs_array_for_each(proto, &proto_init->protocols) { + if (ucp_proto_id_field(proto->proto_id, flags) & + UCP_PROTO_FLAG_INVALID) { + continue; + } + + priv = &ucs_array_elem(&proto_init->priv_buf, proto->priv_offset); + ucp_proto_get_rndv_add_variant(init_params, &rndv_sel_param, + rkey_cfg_index, lane, proto, priv); + } +} + +static void +ucp_proto_get_rndv_abort(ucp_request_t *req, ucs_status_t status) +{ + if (req->id != UCS_PTR_MAP_KEY_INVALID) { + ucp_send_request_id_release(req); + } + + if (req->send.state.dt_iter.dt_class != UCP_DATATYPE_CLASS_MASK) { + ucp_datatype_iter_cleanup(&req->send.state.dt_iter, 0, + UCP_DT_MASK_ALL); + } + + ucp_request_complete_send(req, status); +} + +static ucs_status_t ucp_proto_get_rndv_reset(ucp_request_t *req) +{ + if (req->flags & UCP_REQUEST_FLAG_PROTO_INITIALIZED) { + if (req->id != UCS_PTR_MAP_KEY_INVALID) { + ucp_send_request_id_release(req); + } + + if (req->send.state.dt_iter.dt_class != UCP_DATATYPE_CLASS_MASK) { + ucp_datatype_iter_cleanup(&req->send.state.dt_iter, 0, + UCP_DT_MASK_ALL); + } + } + + req->flags &= ~UCP_REQUEST_FLAG_PROTO_INITIALIZED; + return UCS_OK; +} + +ucp_request_t *ucp_rma_rndv_rtr_flush_open(ucp_request_t *rtr_req) +{ + ucp_request_t *recv_req = rtr_req; + ucp_ep_h ep = rtr_req->send.ep; + + if (!(rtr_req->flags & UCP_REQUEST_FLAG_RNDV_RTR_REQ)) { + return NULL; + } + + while (!(recv_req->flags & UCP_REQUEST_FLAG_RNDV_RECV_INTERNAL)) { + ucs_assert(recv_req->flags & UCP_REQUEST_FLAG_RNDV_RTR_REQ); + recv_req = ucp_request_get_super(recv_req); + } + + if (recv_req->flags & UCP_REQUEST_FLAG_RNDV_START_FLUSH) { + /* Claim before the AM send, since SELF can complete inline. */ + recv_req->flags &= ~UCP_REQUEST_FLAG_RNDV_START_FLUSH; + ucp_worker_flush_ops_count_add(ep->worker, +1); + return recv_req; + } + + return NULL; +} + +void ucp_rma_rndv_rtr_flush_close(ucp_request_t *recv_req, ucp_ep_h ep, + ucs_status_t status) +{ + if (recv_req != NULL) { + if (status == UCS_OK) { + /* recv_req may complete inline, so only touch ep on success. */ + ucp_ep_rma_remote_request_sent(ep); + } else { + ucp_worker_flush_ops_count_add(ep->worker, -1); + recv_req->flags |= UCP_REQUEST_FLAG_RNDV_START_FLUSH; + } + } +} + +static void ucp_rma_rndv_get_recv_complete(ucp_request_t *recv_req) +{ + ucp_request_t *get_req = ucp_request_get_super(recv_req); + ucp_ep_h ep = get_req->send.ep; + int start_flush; + + start_flush = recv_req->flags & UCP_REQUEST_FLAG_RNDV_START_FLUSH; + ucp_request_complete_send(get_req, recv_req->status); + if (!start_flush) { + ucp_ep_rma_remote_request_completed(ep); + } + ucp_request_put(recv_req); +} + +static ucs_status_t ucp_proto_get_rndv_init(ucp_request_t *get_req, + ucp_request_t **rndv_req_p) +{ + const ucp_proto_rndv_ctrl_priv_t *rpriv = get_req->send.proto_config->priv; + ucp_worker_h worker = get_req->send.ep->worker; + ucp_rkey_config_t *rkey_config; + ucp_request_t *recv_req; + ucp_request_t *rndv_req; + uint8_t UCS_V_UNUSED sg_count; + ucp_memory_info_t mem_info; + ucs_status_t status; + uint64_t address; + size_t length; + + status = ucp_ep_rma_handle_fence(get_req->send.ep, get_req, + UCS_BIT(rpriv->lane)); + if (status != UCS_OK) { + return status; + } + + address = get_req->send.rma.remote_addr; + rkey_config = ucp_rkey_config(worker, get_req->send.rma.rkey); + mem_info.type = get_req->send.rma.rkey->mem_type; + mem_info.sys_dev = rkey_config->key.sys_dev; + length = get_req->send.state.dt_iter.length; + get_req->send.buffer = + get_req->send.state.dt_iter.type.contig.buffer; + get_req->send.length = length; + + recv_req = ucp_request_get(worker); + if (recv_req == NULL) { + return UCS_ERR_NO_MEMORY; + } + + rndv_req = ucp_request_get(worker); + if (rndv_req == NULL) { + ucp_request_put(recv_req); + return UCS_ERR_NO_MEMORY; + } + + get_req->flags |= UCP_REQUEST_FLAG_PROTO_INITIALIZED; + recv_req->flags = UCP_REQUEST_FLAG_RNDV_RECV_INTERNAL | + UCP_REQUEST_FLAG_RNDV_START_FLUSH; + recv_req->recv.worker = worker; + recv_req->recv.op_attr = 0; + recv_req->recv.remote_req_id = UCS_PTR_MAP_KEY_INVALID; + recv_req->recv.rndv.complete_cb = ucp_rma_rndv_get_recv_complete; + recv_req->status = UCS_OK; + ucp_request_set_super(recv_req, get_req); + + UCS_PROFILE_CALL_VOID(ucp_datatype_iter_move, &recv_req->recv.dt_iter, + &get_req->send.state.dt_iter, length, &sg_count); + + ucp_proto_request_send_init(rndv_req, get_req->send.ep, + UCP_REQUEST_FLAG_RNDV_RTR_REQ); + ucp_request_set_super(rndv_req, recv_req); + rndv_req->send.rndv.remote_req_id = UCS_PTR_MAP_KEY_INVALID; + rndv_req->send.rndv.remote_address = address; + rndv_req->send.rndv.remote_mem_info = mem_info; + rndv_req->send.rndv.rkey = NULL; + rndv_req->send.rndv.offset = 0; + + UCS_PROFILE_CALL_VOID(ucp_datatype_iter_move, + &rndv_req->send.state.dt_iter, + &recv_req->recv.dt_iter, length, &sg_count); + ucp_proto_request_set_proto(rndv_req, &rpriv->remote_proto_config, length); + + *rndv_req_p = rndv_req; + return UCS_OK; +} + +static ucs_status_t ucp_proto_get_rndv_progress(uct_pending_req_t *self) +{ + ucp_request_t *get_req = ucs_container_of(self, ucp_request_t, send.uct); + const ucp_proto_rndv_ctrl_priv_t *rpriv; + ucp_request_t *rndv_req; + ucs_status_t status; + + if (get_req->flags & UCP_REQUEST_FLAG_PROTO_INITIALIZED) { + return UCS_OK; + } + + rpriv = get_req->send.proto_config->priv; + get_req->send.lane = rpriv->lane; + status = ucp_ep_resolve_remote_id(get_req->send.ep, + rpriv->lane); + if (status == UCS_ERR_NO_RESOURCE) { + return status; + } else if (status != UCS_OK) { + ucp_proto_request_abort(get_req, status); + return UCS_OK; + } + + status = ucp_proto_get_rndv_init(get_req, &rndv_req); + if (status != UCS_OK) { + ucp_proto_request_abort(get_req, status); + return UCS_OK; + } + + ucp_request_send(rndv_req); + return UCS_OK; +} + +static void ucp_rma_rndv_put_recv_complete(ucp_request_t *recv_req) +{ + ucp_worker_h worker = recv_req->recv.worker; + ucp_ep_h ep; + + UCP_WORKER_GET_EP_BY_ID(&ep, worker, recv_req->recv.rndv.ep_id, { + ucp_request_put(recv_req); + return; + }, "RMA RNDV PUT completion"); + + ucp_rma_sw_send_cmpl(ep); + if (recv_req->status != UCS_OK) { + ucp_rma_rndv_send_ats_err(ep, recv_req->recv.remote_req_id, + recv_req->status); + } + + ucp_request_put(recv_req); +} + +ucs_status_t ucp_rma_rndv_process_rts(ucp_worker_h worker, + const ucp_rma_rndv_rts_hdr_t *rts, + size_t length) +{ + const void *rkey_buffer; + ucp_request_t *recv_req; + ucp_ep_h ep; + + if (length < sizeof(*rts)) { + return UCS_ERR_MESSAGE_TRUNCATED; + } + + recv_req = ucp_request_get(worker); + if (recv_req == NULL) { + ucs_error("failed to allocate RMA RNDV PUT receive request"); + UCP_WORKER_GET_EP_BY_ID(&ep, worker, rts->super.sreq.ep_id, + return UCS_OK, "RMA RNDV PUT error"); + ucp_rma_sw_send_cmpl(ep); + ucp_rma_rndv_send_ats_err(ep, rts->super.sreq.req_id, + UCS_ERR_NO_MEMORY); + return UCS_OK; + } + + recv_req->flags = UCP_REQUEST_FLAG_RNDV_RECV_INTERNAL; + recv_req->recv.worker = worker; + recv_req->recv.op_attr = 0; + recv_req->recv.remote_req_id = rts->super.sreq.req_id; + ucp_rma_rndv_dt_iter_init(&recv_req->recv.dt_iter, rts->address, + rts->super.size, rts->mem_type, rts->sys_dev); + recv_req->recv.rndv.ep_id = rts->super.sreq.ep_id; + recv_req->recv.rndv.complete_cb = ucp_rma_rndv_put_recv_complete; + + rkey_buffer = UCS_PTR_BYTE_OFFSET(rts, sizeof(*rts)); + ucp_proto_rndv_receive_start(worker, recv_req, &rts->super, rkey_buffer, + length - sizeof(*rts)); + return UCS_OK; +} + +ucp_proto_t ucp_put_rndv_proto = { + .name = "put/rndv", + .desc = UCP_PROTO_RNDV_DESC, + .flags = 0, + .dt_mask = UCS_BIT(UCP_DATATYPE_CONTIG), + .probe = ucp_proto_put_rndv_probe, + .query = ucp_proto_put_rndv_query, + .progress = {ucp_proto_put_rndv_progress}, + .abort = ucp_proto_rndv_rts_abort, + .reset = ucp_proto_rndv_rts_reset +}; + +ucp_proto_t ucp_get_rndv_proto = { + .name = "get/rndv", + .desc = UCP_PROTO_RNDV_DESC, + .flags = 0, + .dt_mask = UCS_BIT(UCP_DATATYPE_CONTIG), + .probe = ucp_proto_get_rndv_probe, + .query = ucp_proto_get_rndv_query, + .progress = {ucp_proto_get_rndv_progress}, + .abort = ucp_proto_get_rndv_abort, + .reset = ucp_proto_get_rndv_reset +}; diff --git a/src/ucp/rma/rma_rndv.h b/src/ucp/rma/rma_rndv.h new file mode 100644 index 00000000000..906c41bd40b --- /dev/null +++ b/src/ucp/rma/rma_rndv.h @@ -0,0 +1,31 @@ +/** + * Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2026. ALL RIGHTS RESERVED. + * + * See file LICENSE for terms. + */ + +#ifndef UCP_RMA_RNDV_H_ +#define UCP_RMA_RNDV_H_ + +#include +#include + + +typedef struct { + ucp_rndv_rts_hdr_t super; + uint64_t address; + ucs_sys_device_t sys_dev; + ucs_memory_type_t mem_type; +} UCS_S_PACKED ucp_rma_rndv_rts_hdr_t; + + +ucs_status_t ucp_rma_rndv_process_rts(ucp_worker_h worker, + const ucp_rma_rndv_rts_hdr_t *rts, + size_t length); + +ucp_request_t *ucp_rma_rndv_rtr_flush_open(ucp_request_t *rtr_req); + +void ucp_rma_rndv_rtr_flush_close(ucp_request_t *recv_req, ucp_ep_h ep, + ucs_status_t status); + +#endif diff --git a/src/ucp/rndv/proto_rndv.c b/src/ucp/rndv/proto_rndv.c index 3b4f2937ee7..2db3e9e6724 100644 --- a/src/ucp/rndv/proto_rndv.c +++ b/src/ucp/rndv/proto_rndv.c @@ -580,6 +580,7 @@ void ucp_proto_rndv_rts_query(const ucp_proto_query_params_t *params, void ucp_proto_rndv_rts_abort(ucp_request_t *req, ucs_status_t status) { ucp_am_release_user_header(req); + ucp_request_rndv_flush_complete(req); if (ucp_request_memh_invalidate(req, status)) { ucp_proto_rndv_rts_reset(req); @@ -817,6 +818,7 @@ UCS_PROFILE_FUNC_VOID(ucp_proto_rndv_receive_start, ucp_ep_h ep; UCP_WORKER_GET_VALID_EP_BY_ID(&ep, worker, rts->sreq.ep_id, { + ucp_datatype_iter_cleanup(&recv_req->recv.dt_iter, 1, UCP_DT_MASK_ALL); ucp_proto_rndv_recv_req_complete(recv_req, UCS_ERR_CANCELED); return; }, "RTS on non-existing endpoint"); @@ -824,6 +826,8 @@ UCS_PROFILE_FUNC_VOID(ucp_proto_rndv_receive_start, req = ucp_request_get(worker); if (req == NULL) { ucs_error("failed to allocate rendezvous reply"); + ucp_datatype_iter_cleanup(&recv_req->recv.dt_iter, 1, UCP_DT_MASK_ALL); + ucp_proto_rndv_recv_req_complete(recv_req, UCS_ERR_NO_MEMORY); return; } @@ -856,6 +860,7 @@ UCS_PROFILE_FUNC_VOID(ucp_proto_rndv_receive_start, if (status != UCS_OK) { ucp_datatype_iter_cleanup(&req->send.state.dt_iter, 1, UCP_DT_MASK_ALL); ucs_mpool_put(req); + ucp_proto_rndv_recv_req_complete(recv_req, status); return; } @@ -868,16 +873,13 @@ UCS_PROFILE_FUNC_VOID(ucp_proto_rndv_receive_start, } UCS_PROFILE_FUNC(ucs_status_t, ucp_proto_rndv_send_start, - (worker, req, op_attr_mask, rtr, header_length, sg_count), + (worker, req, op_attr_mask, rtr, rkey_buffer, rkey_length, + sg_count), ucp_worker_h worker, ucp_request_t *req, uint32_t op_attr_mask, - const ucp_rndv_rtr_hdr_t *rtr, size_t header_length, - uint8_t sg_count) + const ucp_rndv_rtr_hdr_t *rtr, const void *rkey_buffer, + size_t rkey_length, uint8_t sg_count) { ucs_status_t status; - size_t rkey_length; - - ucs_assert(header_length >= sizeof(*rtr)); - rkey_length = header_length - sizeof(*rtr); ucp_proto_rndv_check_rkey_length(rtr->address, rkey_length, "rtr"); req->send.rndv.remote_address = rtr->address; @@ -886,7 +888,7 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_proto_rndv_send_start, ucs_assert(rtr->size == req->send.state.dt_iter.length); status = ucp_proto_rndv_send_reply(worker, req, UCP_OP_ID_RNDV_SEND, - op_attr_mask, rtr->size, rtr + 1, + op_attr_mask, rtr->size, rkey_buffer, rkey_length, sg_count); if (status != UCS_OK) { return status; @@ -909,7 +911,101 @@ static void ucp_proto_rndv_send_complete_one(void *request, ucs_status_t status, } ucp_send_request_id_release(req); - ucp_proto_request_zcopy_complete(req, status); + ucp_proto_rndv_request_zcopy_complete(req, status); +} + +static void +ucp_proto_rndv_rtr_req_send_complete(void *request, + ucs_status_t UCS_V_UNUSED status, + void *UCS_V_UNUSED user_data) +{ + ucp_request_t *req = (ucp_request_t*)request - 1; + + ucp_request_put(req); +} + +static void +ucp_proto_rndv_rtr_req_send_atp_err(ucp_ep_h ep, + ucs_ptr_map_key_t remote_req_id, + ucs_status_t status) +{ + ucp_request_t *req; + + req = ucp_request_get(ep->worker); + if (req == NULL) { + ucs_error("failed to allocate RNDV RTR_REQ error ATP"); + return; + } + + ucp_proto_request_send_init(req, ep, 0); + ucp_rndv_req_send_ack(req, 0, remote_req_id, status, UCP_AM_ID_RNDV_ATP, + "send_atp_err"); +} + +static void +ucp_proto_rndv_rtr_req_sreq_init(ucp_ep_h ep, ucp_request_t *req, + const ucp_rndv_rtr_req_hdr_t *rtr_req) +{ + const ucp_rndv_rtr_hdr_t *rtr = &rtr_req->super; + + ucp_proto_request_send_init(req, ep, + UCP_REQUEST_FLAG_RNDV_SEND_INTERNAL); + ucp_request_set_callback(req, send.cb, + ucp_proto_rndv_rtr_req_send_complete); + req->send.buffer = (void*)(uintptr_t)rtr_req->address; + req->send.length = rtr->size; + req->send.mem_type = rtr_req->mem_type; + req->send.rndv.remote_req_id = rtr->rreq_id; + req->send.rndv.rkey = NULL; + req->send.rndv.remote_address = rtr_req->address; + req->send.state.dt_iter.dt_class = UCP_DATATYPE_CONTIG; + req->send.state.dt_iter.mem_info.type = rtr_req->mem_type; + req->send.state.dt_iter.mem_info.sys_dev = rtr_req->sys_dev; + req->send.state.dt_iter.length = rtr->size; + req->send.state.dt_iter.offset = 0; + req->send.state.dt_iter.type.contig.buffer = + (void*)(uintptr_t)rtr_req->address; + req->send.state.dt_iter.type.contig.memh = NULL; +} + +static ucs_status_t +ucp_proto_rndv_handle_rtr_req(ucp_worker_h worker, void *data, size_t length) +{ + const ucp_rndv_rtr_req_hdr_t *rtr_req = data; + const void *rkey_buffer; + ucp_request_t *req; + ucs_status_t status; + ucp_ep_h ep; + + if (length < sizeof(*rtr_req)) { + return UCS_ERR_MESSAGE_TRUNCATED; + } + + UCP_WORKER_GET_EP_BY_ID(&ep, worker, rtr_req->req.ep_id, return UCS_OK, + "RNDV RTR_REQ"); + + req = ucp_request_get(worker); + if (req == NULL) { + ucs_error("failed to allocate RNDV RTR_REQ send request"); + ucp_proto_rndv_rtr_req_send_atp_err(ep, rtr_req->super.rreq_id, + UCS_ERR_NO_MEMORY); + return UCS_OK; + } + + rkey_buffer = rtr_req + 1; + ucp_proto_rndv_rtr_req_sreq_init(ep, req, rtr_req); + status = ucp_proto_rndv_send_start(worker, req, 0, &rtr_req->super, + rkey_buffer, + length - sizeof(*rtr_req), 1); + if (status != UCS_OK) { + ucp_proto_rndv_rtr_req_send_atp_err(ep, rtr_req->super.rreq_id, + status); + ucp_datatype_iter_cleanup(&req->send.state.dt_iter, 1, + UCP_DT_MASK_ALL); + ucp_request_put(req); + } + + return UCS_OK; } ucs_status_t @@ -923,6 +1019,14 @@ ucp_proto_rndv_handle_rtr(void *arg, void *data, size_t length, unsigned flags) ucs_status_t status; uint8_t sg_count; + if (length < sizeof(*rtr)) { + return UCS_ERR_MESSAGE_TRUNCATED; + } + + if (rtr->sreq_id == UCS_PTR_MAP_KEY_INVALID) { + return ucp_proto_rndv_handle_rtr_req(worker, data, length); + } + UCP_SEND_REQUEST_GET_BY_ID(&req, worker, rtr->sreq_id, 0, return UCS_OK, "RTR %p", rtr); @@ -950,7 +1054,8 @@ ucp_proto_rndv_handle_rtr(void *arg, void *data, size_t length, unsigned flags) sg_count = select_param->sg_count; status = ucp_proto_rndv_send_start(worker, req, op_attr_mask, rtr, - length, sg_count); + rtr + 1, length - sizeof(*rtr), + sg_count); if (status != UCS_OK) { goto err_request_fail; } @@ -980,7 +1085,8 @@ ucp_proto_rndv_handle_rtr(void *arg, void *data, size_t length, unsigned flags) status = ucp_proto_rndv_send_start(worker, freq, op_attr_mask | UCP_OP_ATTR_FLAG_MULTI_SEND, - rtr, length, 1); + rtr, rtr + 1, + length - sizeof(*rtr), 1); if (status != UCS_OK) { goto err_put_freq; } diff --git a/src/ucp/rndv/proto_rndv.h b/src/ucp/rndv/proto_rndv.h index 8299b740491..419c233603c 100644 --- a/src/ucp/rndv/proto_rndv.h +++ b/src/ucp/rndv/proto_rndv.h @@ -12,11 +12,13 @@ #include -/* Names of rendezvous control messages */ -#define UCP_PROTO_RNDV_RTS_NAME "RTS" -#define UCP_PROTO_RNDV_RTR_NAME "RTR" -#define UCP_PROTO_RNDV_ATS_NAME "ATS" -#define UCP_PROTO_RNDV_ATP_NAME "ATP" +/* Rendezvous protocol description and control message names */ +#define UCP_PROTO_RNDV_DESC "rndv" +#define UCP_PROTO_RNDV_RTS_NAME "RTS" +#define UCP_PROTO_RNDV_RTR_NAME "RTR" +#define UCP_PROTO_RNDV_RTR_REQ_NAME "RTR_REQ" +#define UCP_PROTO_RNDV_ATS_NAME "ATS" +#define UCP_PROTO_RNDV_ATP_NAME "ATP" /* Mask of rendezvous operations */ @@ -201,6 +203,13 @@ void ucp_proto_rndv_receive_start(ucp_worker_h worker, ucp_request_t *recv_req, ucs_status_t ucp_proto_rndv_handle_rtr(void *arg, void *data, size_t length, unsigned flags); +ucs_status_t ucp_proto_rndv_send_start(ucp_worker_h worker, + ucp_request_t *req, + uint32_t op_attr_mask, + const ucp_rndv_rtr_hdr_t *rtr, + const void *rkey_buffer, + size_t rkey_length, + uint8_t sg_count); ucs_status_t ucp_proto_rndv_rtr_handle_atp(void *arg, void *data, size_t length, unsigned flags); diff --git a/src/ucp/rndv/proto_rndv.inl b/src/ucp/rndv/proto_rndv.inl index 837e8dc1534..07db2f0de5c 100644 --- a/src/ucp/rndv/proto_rndv.inl +++ b/src/ucp/rndv/proto_rndv.inl @@ -79,7 +79,7 @@ ucp_proto_rndv_ats_handler(void *arg, void *data, size_t length, unsigned flags) ucp_tag_offload_cancel_rndv(req); } - if (length >= sizeof(*ats)) { + if ((status == UCS_OK) && (length >= sizeof(*ats))) { /* ATS message carries a size field */ ats = ucs_derived_of(rephdr, ucp_rndv_ack_hdr_t); if (!ucp_proto_common_frag_complete(req, ats->size, "rndv_ats")) { @@ -89,11 +89,35 @@ ucp_proto_rndv_ats_handler(void *arg, void *data, size_t length, unsigned flags) ucp_send_request_id_release(req); ucp_datatype_iter_cleanup(&req->send.state.dt_iter, 1, UCP_DT_MASK_ALL); + ucp_request_rndv_flush_complete(req); ucp_request_complete_send(req, status); return UCS_OK; } +static UCS_F_ALWAYS_INLINE void +ucp_proto_rndv_request_zcopy_complete(ucp_request_t *req, ucs_status_t status) +{ + ucp_proto_request_zcopy_complete_cb(req, status, + ucp_request_rndv_flush_complete); +} + +static UCS_F_ALWAYS_INLINE ucs_status_t +ucp_proto_rndv_request_zcopy_complete_success(ucp_request_t *req) +{ + ucp_proto_rndv_request_zcopy_complete(req, UCS_OK); + return UCS_OK; +} + +static UCS_F_ALWAYS_INLINE void +ucp_proto_rndv_request_zcopy_completion(uct_completion_t *self) +{ + ucp_request_t *req = ucs_container_of(self, ucp_request_t, + send.state.uct_comp); + + ucp_proto_rndv_request_zcopy_complete(req, req->send.state.uct_comp.status); +} + static UCS_F_ALWAYS_INLINE size_t ucp_proto_rndv_rts_pack( ucp_request_t *req, ucp_rndv_rts_hdr_t *rts, size_t hdr_len) { @@ -182,6 +206,10 @@ static UCS_F_ALWAYS_INLINE ucs_status_t ucp_proto_rndv_frag_request_alloc( ucp_proto_request_send_init(freq, req->send.ep, UCP_REQUEST_FLAG_RNDV_FRAG); ucp_request_set_super(freq, req); + if (req->flags & UCP_REQUEST_FLAG_RNDV_RTR_REQ) { + freq->flags |= UCP_REQUEST_FLAG_RNDV_RTR_REQ; + freq->send.rndv.remote_mem_info = req->send.rndv.remote_mem_info; + } *freq_p = freq; return UCS_OK; @@ -346,6 +374,13 @@ ucp_proto_rndv_init_params_is_ppln_frag(const ucp_proto_init_params_t *params) UCP_PROTO_SELECT_OP_FLAG_PPLN_FRAG; } +static UCS_F_ALWAYS_INLINE int +ucp_proto_rndv_init_params_is_push(const ucp_proto_init_params_t *params) +{ + return ucp_proto_select_op_flags(params->select_param) & + UCP_PROTO_SELECT_OP_FLAG_RNDV_PUSH; +} + static UCS_F_ALWAYS_INLINE int ucp_proto_rndv_op_check(const ucp_proto_init_params_t *params, ucp_operation_id_t op_id, int support_ppln) @@ -360,7 +395,10 @@ ucp_proto_rndv_recv_req_complete(ucp_request_t *recv_req, ucs_status_t status) ucp_trace_req(recv_req, "rndv_recv_req_complete status '%s'", ucs_status_string(status)); - if (recv_req->flags & UCP_REQUEST_FLAG_RECV_AM) { + if (recv_req->flags & UCP_REQUEST_FLAG_RNDV_RECV_INTERNAL) { + recv_req->status = status; + recv_req->recv.rndv.complete_cb(recv_req); + } else if (recv_req->flags & UCP_REQUEST_FLAG_RECV_AM) { ucp_request_complete_am_recv(recv_req, status); } else { ucs_assert(recv_req->flags & UCP_REQUEST_FLAG_RECV_TAG); diff --git a/src/ucp/rndv/rndv.c b/src/ucp/rndv/rndv.c index 76cdcc4c570..2680b9613de 100644 --- a/src/ucp/rndv/rndv.c +++ b/src/ucp/rndv/rndv.c @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -422,6 +423,7 @@ static void ucp_rndv_complete_rma_put_zcopy(ucp_request_t *sreq, int is_frag_put } ucp_request_send_buffer_dereg(sreq); + ucp_request_rndv_flush_complete(sreq); ucp_request_complete_send(sreq, status); } @@ -1759,9 +1761,11 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_rts_handler, if (ucp_rndv_rts_is_am(rts_hdr)) { return ucp_am_rndv_process_rts(arg, data, length, tl_flags); - } else { - ucs_assert(ucp_rndv_rts_is_tag(rts_hdr)); + } else if (ucp_rndv_rts_is_tag(rts_hdr)) { return ucp_tag_rndv_process_rts(worker, rts_hdr, length, tl_flags); + } else { + ucs_assert(ucp_rndv_rts_is_rma(rts_hdr)); + return ucp_rma_rndv_process_rts(worker, data, length); } } @@ -1888,6 +1892,7 @@ static void ucp_rndv_am_zcopy_send_req_complete(ucp_request_t *req, { ucs_assert(req->send.state.uct_comp.count == 0); ucp_request_send_buffer_dereg(req); + ucp_request_rndv_flush_complete(req); ucp_request_complete_send(req, status); } @@ -2436,7 +2441,9 @@ static void ucp_rndv_dump(ucp_worker_h worker, uct_am_trace_type_t type, { UCS_STRING_BUFFER_FIXED(strb, buffer, max); const ucp_rndv_rts_hdr_t *rndv_rts_hdr = data; + const ucp_rma_rndv_rts_hdr_t *rma_rts = data; const ucp_rndv_rtr_hdr_t *rndv_rtr_hdr = data; + const ucp_rndv_rtr_req_hdr_t *rtr_req = data; const ucp_request_data_hdr_t *rndv_data = data; const ucp_rndv_ack_hdr_t *ack_hdr = data; const ucp_reply_hdr_t *rep_hdr = data; @@ -2449,13 +2456,19 @@ static void ucp_rndv_dump(ucp_worker_h worker, uct_am_trace_type_t type, if (ucp_rndv_rts_is_am(rndv_rts_hdr)) { ucs_string_buffer_appendf(&strb, "am_id %u", ucp_am_hdr_from_rts(rndv_rts_hdr)->am_id); - } else { - ucs_assert(ucp_rndv_rts_is_tag(rndv_rts_hdr)); + rkey_buf = rndv_rts_hdr + 1; + } else if (ucp_rndv_rts_is_tag(rndv_rts_hdr)) { ucs_string_buffer_appendf(&strb, "tag %" PRIx64, ucp_tag_hdr_from_rts(rndv_rts_hdr)->tag); + rkey_buf = rndv_rts_hdr + 1; + } else { + ucs_assert(ucp_rndv_rts_is_rma(rndv_rts_hdr)); + ucs_string_buffer_appendf( + &strb, "RMA_RTS dst 0x%" PRIx64 " %s", rma_rts->address, + ucs_memory_type_names[rma_rts->mem_type]); + rkey_buf = rma_rts + 1; } - rkey_buf = rndv_rts_hdr + 1; ucs_string_buffer_appendf(&strb, " ep_id 0x%" PRIx64 " sreq_id 0x%" PRIx64 " address 0x%" PRIx64 " size %zu", @@ -2477,6 +2490,22 @@ static void ucp_rndv_dump(ucp_worker_h worker, uct_am_trace_type_t type, } break; case UCP_AM_ID_RNDV_RTR: + if ((rndv_rtr_hdr->sreq_id == UCS_PTR_MAP_KEY_INVALID) && + (length >= sizeof(*rtr_req))) { + ucs_string_buffer_appendf( + &strb, "RNDV_RTR_REQ src 0x%" PRIx64 " dst 0x%" PRIx64 + " rreq_id 0x%" PRIx64 " ep_id 0x%" PRIx64 " size %zu" + " offset %zu %s", rtr_req->address, + rtr_req->super.address, rtr_req->super.rreq_id, + rtr_req->req.ep_id, rtr_req->super.size, + rtr_req->super.offset, + ucs_memory_type_names[rtr_req->mem_type]); + if (rtr_req->super.address != 0) { + ucp_rndv_dump_rkey(rtr_req + 1, data_end, &strb); + } + break; + } + ucs_string_buffer_appendf(&strb, "RNDV_RTR sreq_id 0x%" PRIx64 " rreq_id 0x%" PRIx64 " address 0x%" PRIx64 @@ -2509,13 +2538,18 @@ static void ucp_rndv_dump(ucp_worker_h worker, uct_am_trace_type_t type, } } -UCP_DEFINE_AM_WITH_PROXY(UCP_FEATURE_TAG | UCP_FEATURE_AM, UCP_AM_ID_RNDV_RTS, - ucp_rndv_rts_handler, ucp_rndv_dump, 0); -UCP_DEFINE_AM_WITH_PROXY(UCP_FEATURE_TAG | UCP_FEATURE_AM, UCP_AM_ID_RNDV_ATS, +UCP_DEFINE_AM_WITH_PROXY(UCP_FEATURE_TAG | UCP_FEATURE_AM | UCP_FEATURE_RMA, + UCP_AM_ID_RNDV_RTS, ucp_rndv_rts_handler, + ucp_rndv_dump, 0); +UCP_DEFINE_AM_WITH_PROXY(UCP_FEATURE_TAG | UCP_FEATURE_AM | UCP_FEATURE_RMA, + UCP_AM_ID_RNDV_ATS, ucp_rndv_ats_handler, ucp_rndv_dump, 0); -UCP_DEFINE_AM_WITH_PROXY(UCP_FEATURE_TAG | UCP_FEATURE_AM, UCP_AM_ID_RNDV_ATP, +UCP_DEFINE_AM_WITH_PROXY(UCP_FEATURE_TAG | UCP_FEATURE_AM | UCP_FEATURE_RMA, + UCP_AM_ID_RNDV_ATP, ucp_rndv_atp_handler, ucp_rndv_dump, 0); -UCP_DEFINE_AM_WITH_PROXY(UCP_FEATURE_TAG | UCP_FEATURE_AM, UCP_AM_ID_RNDV_RTR, +UCP_DEFINE_AM_WITH_PROXY(UCP_FEATURE_TAG | UCP_FEATURE_AM | UCP_FEATURE_RMA, + UCP_AM_ID_RNDV_RTR, ucp_rndv_rtr_handler, ucp_rndv_dump, 0); -UCP_DEFINE_AM_WITH_PROXY(UCP_FEATURE_TAG | UCP_FEATURE_AM, UCP_AM_ID_RNDV_DATA, +UCP_DEFINE_AM_WITH_PROXY(UCP_FEATURE_TAG | UCP_FEATURE_AM | UCP_FEATURE_RMA, + UCP_AM_ID_RNDV_DATA, ucp_rndv_data_handler, ucp_rndv_dump, 0); diff --git a/src/ucp/rndv/rndv.h b/src/ucp/rndv/rndv.h index 3a7b9d6a1a7..534e320cbaf 100644 --- a/src/ucp/rndv/rndv.h +++ b/src/ucp/rndv/rndv.h @@ -17,7 +17,9 @@ typedef enum { * the previous UCP versions) */ UCP_RNDV_RTS_TAG_OK = UCS_OK, /* RNDV AM operation */ - UCP_RNDV_RTS_AM = 1 + UCP_RNDV_RTS_AM = 1, + /* One-sided RNDV operation */ + UCP_RNDV_RTS_RMA = 2 } UCS_S_PACKED ucp_rndv_rts_opcode_t; @@ -63,6 +65,25 @@ typedef struct { } UCS_S_PACKED ucp_rndv_rtr_hdr_t; +/* + * RTR which requests the peer to create an internal sender. + */ +typedef struct { + /* Base RTR header; sreq_id is UCS_PTR_MAP_KEY_INVALID */ + ucp_rndv_rtr_hdr_t super; + + /* Endpoint on the RTR initiator side */ + ucp_request_hdr_t req; + + /* Address of the source buffer on the peer */ + uint64_t address; + + /* Memory locality of the source buffer on the peer */ + ucs_sys_device_t sys_dev; + ucs_memory_type_t mem_type; +} UCS_S_PACKED ucp_rndv_rtr_req_hdr_t; + + /* * RNDV_ATS/RNDV_ATP with size field */ diff --git a/src/ucp/rndv/rndv.inl b/src/ucp/rndv/rndv.inl index e3f4f6e198f..43dbab85191 100644 --- a/src/ucp/rndv/rndv.inl +++ b/src/ucp/rndv/rndv.inl @@ -23,6 +23,12 @@ ucp_rndv_rts_is_tag(const ucp_rndv_rts_hdr_t *rts_hdr) return rts_hdr->opcode == UCP_RNDV_RTS_TAG_OK; } +static UCS_F_ALWAYS_INLINE int +ucp_rndv_rts_is_rma(const ucp_rndv_rts_hdr_t *rts_hdr) +{ + return rts_hdr->opcode == UCP_RNDV_RTS_RMA; +} + static UCS_F_ALWAYS_INLINE void ucp_rndv_receive_start(ucp_worker_h worker, ucp_request_t *rreq, const ucp_rndv_rts_hdr_t *rndv_rts_hdr, diff --git a/src/ucp/rndv/rndv_am.c b/src/ucp/rndv/rndv_am.c index 245dc788d2e..2629a19f0c1 100644 --- a/src/ucp/rndv/rndv_am.c +++ b/src/ucp/rndv/rndv_am.c @@ -87,6 +87,7 @@ ucp_proto_rndv_am_bcopy_complete(ucp_request_t *req) { ucp_rndv_am_destroy_rkey(req); ucp_datatype_iter_mem_dereg(&req->send.state.dt_iter, UCP_DT_MASK_ALL); + ucp_request_rndv_flush_complete(req); return ucp_proto_request_bcopy_complete_success(req); } @@ -128,7 +129,8 @@ ucp_proto_rndv_am_bcopy_abort(ucp_request_t *req, ucs_status_t status) { ucp_rndv_am_destroy_rkey(req); ucp_datatype_iter_mem_dereg(&req->send.state.dt_iter, UCP_DT_MASK_ALL); - ucp_proto_request_bcopy_abort(req,status); + ucp_request_rndv_flush_complete(req); + ucp_proto_request_bcopy_abort(req, status); } ucp_proto_t ucp_rndv_am_bcopy_proto = { @@ -174,7 +176,7 @@ static ucs_status_t ucp_rndv_am_zcopy_proto_progress(uct_pending_req_t *uct_req) UCP_DT_MASK_CONTIG_IOV, ucp_rndv_am_zcopy_send_func, ucp_rndv_am_zcopy_complete, - ucp_proto_request_zcopy_completion); + ucp_proto_rndv_request_zcopy_completion); } static void ucp_rndv_am_zcopy_probe(const ucp_proto_init_params_t *init_params) diff --git a/src/ucp/rndv/rndv_get.c b/src/ucp/rndv/rndv_get.c index 67c60290428..c85cdb1154c 100644 --- a/src/ucp/rndv/rndv_get.c +++ b/src/ucp/rndv/rndv_get.c @@ -60,7 +60,8 @@ ucp_proto_rndv_get_common_probe(const ucp_proto_init_params_t *init_params, ucp_proto_perf_t *perf; ucs_status_t status; - if (!ucp_proto_rndv_op_check(init_params, UCP_OP_ID_RNDV_RECV, + if (ucp_proto_rndv_init_params_is_push(init_params) || + !ucp_proto_rndv_op_check(init_params, UCP_OP_ID_RNDV_RECV, support_ppln)) { return; } diff --git a/src/ucp/rndv/rndv_ppln.c b/src/ucp/rndv/rndv_ppln.c index 3fdfeeeae03..2f1f47e1fd9 100644 --- a/src/ucp/rndv/rndv_ppln.c +++ b/src/ucp/rndv/rndv_ppln.c @@ -83,7 +83,7 @@ ucp_proto_rndv_ppln_probe(const ucp_proto_init_params_t *init_params) /* Select a protocol for rndv recv */ sel_param = *select_param; - sel_param.op_id_flags = ucp_proto_select_op_id(select_param) | + sel_param.op_id_flags = select_param->op_id_flags | UCP_PROTO_SELECT_OP_FLAG_PPLN_FRAG; sel_param.op_attr = ucp_proto_select_op_attr_pack( UCP_OP_ATTR_FLAG_MULTI_SEND, UCP_PROTO_SELECT_OP_ATTR_MASK); @@ -203,9 +203,8 @@ static void ucp_proto_rndv_ppln_query(const ucp_proto_query_params_t *params, attr->lane_map |= UCS_BIT(rpriv->ack.lane); } -static void +static ucp_request_t * ucp_proto_rndv_ppln_frag_complete(ucp_request_t *freq, int send_ack, int abort, - ucp_proto_complete_cb_t complete_func, const char *title) { ucp_request_t *req = ucp_request_get_super(freq); @@ -217,7 +216,7 @@ ucp_proto_rndv_ppln_frag_complete(ucp_request_t *freq, int send_ack, int abort, /* In case of abort we don't destroy super request until all fragments are * completed */ if (!ucp_proto_rndv_frag_complete(req, freq, title)) { - return; + return NULL; } if (req->send.rndv.rkey != NULL) { @@ -229,24 +228,33 @@ ucp_proto_rndv_ppln_frag_complete(ucp_request_t *freq, int send_ack, int abort, if (!abort && (req->send.rndv.ppln.ack_data_size > 0)) { ucp_proto_request_set_stage(req, UCP_PROTO_RNDV_PPLN_STAGE_ACK); ucp_request_send(req); + return NULL; } else { - complete_func(req); + return req; } } void ucp_proto_rndv_ppln_send_frag_complete(ucp_request_t *freq, int send_ack) { - ucp_proto_rndv_ppln_frag_complete(freq, send_ack, 0, - ucp_proto_request_complete_success, - "ppln_send"); + ucp_request_t *req; + + req = ucp_proto_rndv_ppln_frag_complete(freq, send_ack, 0, "ppln_send"); + if (req != NULL) { + ucp_request_rndv_flush_complete(req); + ucp_proto_request_complete_success(req); + } } void ucp_proto_rndv_ppln_recv_frag_complete(ucp_request_t *freq, int send_ack, int abort) { - ucp_proto_rndv_ppln_frag_complete(freq, send_ack, abort, - ucp_proto_rndv_recv_complete, - "ppln_recv"); + ucp_request_t *req; + + req = ucp_proto_rndv_ppln_frag_complete(freq, send_ack, abort, "ppln_recv"); + if (req != NULL) { + ucp_request_rndv_flush_complete(req); + ucp_proto_rndv_recv_complete(req); + } } static ucs_status_t ucp_proto_rndv_ppln_progress(uct_pending_req_t *uct_req) @@ -336,7 +344,7 @@ ucp_proto_rndv_send_ppln_atp_progress(uct_pending_req_t *uct_req) return ucp_proto_rndv_ack_progress(req, &rpriv->ack, UCP_AM_ID_RNDV_ATP, ucp_proto_rndv_ppln_pack_ack, - ucp_proto_request_zcopy_complete_success); + ucp_proto_rndv_request_zcopy_complete_success); } ucp_proto_t ucp_rndv_send_ppln_proto = { diff --git a/src/ucp/rndv/rndv_put.c b/src/ucp/rndv/rndv_put.c index 95ed67baf42..48eea35db8e 100644 --- a/src/ucp/rndv/rndv_put.c +++ b/src/ucp/rndv/rndv_put.c @@ -47,7 +47,8 @@ ucp_proto_rndv_put_common_complete(ucp_request_t *req) UCS_STATS_UPDATE_COUNTER(req->send.ep->worker->stats, rpriv->stat_counter, +1); ucp_proto_rndv_rkey_destroy(req); - ucp_proto_request_zcopy_complete(req, req->send.state.uct_comp.status); + ucp_proto_rndv_request_zcopy_complete(req, + req->send.state.uct_comp.status); } static void ucp_proto_rndv_put_zcopy_completion(uct_completion_t *uct_comp) diff --git a/src/ucp/rndv/rndv_rkey_ptr.c b/src/ucp/rndv/rndv_rkey_ptr.c index f18e00c8e52..acb37d1600a 100644 --- a/src/ucp/rndv/rndv_rkey_ptr.c +++ b/src/ucp/rndv/rndv_rkey_ptr.c @@ -101,7 +101,8 @@ ucp_proto_rndv_rkey_ptr_probe(const ucp_proto_init_params_t *init_params) ucp_proto_perf_t *perf; ucs_status_t status; - if (!ucp_proto_rndv_op_check(init_params, UCP_OP_ID_RNDV_RECV, 0) || + if (ucp_proto_rndv_init_params_is_push(init_params) || + !ucp_proto_rndv_op_check(init_params, UCP_OP_ID_RNDV_RECV, 0) || !ucp_proto_common_init_check_err_handling(¶ms.super) || (ucp_proto_select_op_flags(params.super.super.select_param) & UCP_PROTO_SELECT_OP_FLAG_RESUME)) { @@ -317,7 +318,7 @@ static ucs_status_t ucp_proto_rndv_rkey_ptr_mtype_completion(ucp_request_t *req) { ucp_trace_req(req, "ucp_proto_rndv_rkey_ptr_mtype_completion"); ucp_proto_rndv_rkey_destroy(req); - ucp_proto_request_zcopy_complete(req, UCS_OK); + ucp_proto_rndv_request_zcopy_complete(req, UCS_OK); return UCS_OK; } diff --git a/src/ucp/rndv/rndv_rtr.c b/src/ucp/rndv/rndv_rtr.c index 5c160c7db94..609ba31ac2d 100644 --- a/src/ucp/rndv/rndv_rtr.c +++ b/src/ucp/rndv/rndv_rtr.c @@ -14,6 +14,9 @@ #include #include #include +#include + +#include /** @@ -38,6 +41,10 @@ typedef struct { ucs_sys_device_t frag_sys_dev; } ucp_proto_rndv_rtr_mtype_priv_t; + +static size_t ucp_proto_rndv_rtr_req_pack(void *dest, void *arg); + + static UCS_F_ALWAYS_INLINE void ucp_proto_rtr_common_request_init(ucp_request_t *req) { @@ -45,19 +52,45 @@ ucp_proto_rtr_common_request_init(ucp_request_t *req) req->send.state.completed_size = 0; } +static size_t ucp_proto_rndv_rtr_max_size(ucp_request_t *req) +{ + const ucp_proto_rndv_rtr_priv_t *rpriv = req->send.proto_config->priv; + + return sizeof(ucp_rndv_rtr_hdr_t) + rpriv->super.packed_rkey_size; +} + +static size_t ucp_proto_rndv_rtr_req_max_size(ucp_request_t *req) +{ + return ucp_proto_rndv_rtr_max_size(req) + + sizeof(ucp_rndv_rtr_req_hdr_t) - sizeof(ucp_rndv_rtr_hdr_t); +} + static ucs_status_t ucp_proto_rndv_rtr_common_send(ucp_request_t *req) { const ucp_proto_rndv_rtr_priv_t *rpriv = req->send.proto_config->priv; ucp_worker_h UCS_V_UNUSED worker = req->send.ep->worker; + ucp_ep_h ep = req->send.ep; + ucp_request_t *recv_req = NULL; + uct_pack_callback_t pack_cb; size_t max_rtr_size; ucs_status_t status; - max_rtr_size = sizeof(ucp_rndv_rtr_hdr_t) + rpriv->super.packed_rkey_size; - status = ucp_proto_am_bcopy_single_progress(req, UCP_AM_ID_RNDV_RTR, - rpriv->super.lane, - rpriv->pack_cb, req, - max_rtr_size, NULL, 0); - return status; + if (req->flags & UCP_REQUEST_FLAG_RNDV_RTR_REQ) { + pack_cb = ucp_proto_rndv_rtr_req_pack; + max_rtr_size = ucp_proto_rndv_rtr_req_max_size(req); + recv_req = ucp_rma_rndv_rtr_flush_open(req); + } else { + pack_cb = rpriv->pack_cb; + max_rtr_size = ucp_proto_rndv_rtr_max_size(req); + } + + status = ucp_proto_am_bcopy_single_send(req, UCP_AM_ID_RNDV_RTR, + rpriv->super.lane, pack_cb, req, + max_rtr_size, 0); + ucp_rma_rndv_rtr_flush_close(recv_req, ep, status); + + return ucp_proto_single_status_handle(req, 0, NULL, rpriv->super.lane, + status); } static UCS_F_ALWAYS_INLINE void @@ -125,6 +158,33 @@ static size_t ucp_proto_rndv_rtr_pack_with_rkey(void *dest, void *arg) return sizeof(*rtr) + rkey_size; } +static size_t ucp_proto_rndv_rtr_req_pack(void *dest, void *arg) +{ + ucp_request_t *req = arg; + ucp_rndv_rtr_req_hdr_t *rtr_req = dest; + const ucp_proto_rndv_rtr_priv_t *rpriv = req->send.proto_config->priv; + void *rkey_src, *rkey_dst; + size_t rkey_size, packed_size; + + packed_size = rpriv->pack_cb(&rtr_req->super, req); + rkey_size = packed_size - sizeof(rtr_req->super); + if (rkey_size > 0) { + rkey_src = UCS_PTR_BYTE_OFFSET(rtr_req, sizeof(rtr_req->super)); + rkey_dst = UCS_PTR_BYTE_OFFSET(rtr_req, sizeof(*rtr_req)); + memmove(rkey_dst, rkey_src, rkey_size); + } + + rtr_req->super.sreq_id = UCS_PTR_MAP_KEY_INVALID; + rtr_req->super.offset = 0; + rtr_req->req.ep_id = ucp_send_request_get_ep_remote_id(req); + rtr_req->req.req_id = UCS_PTR_MAP_KEY_INVALID; + rtr_req->address = req->send.rndv.remote_address; + rtr_req->sys_dev = req->send.rndv.remote_mem_info.sys_dev; + rtr_req->mem_type = req->send.rndv.remote_mem_info.type; + + return sizeof(*rtr_req) + rkey_size; +} + static ucs_status_t ucp_proto_rndv_rtr_progress(uct_pending_req_t *self) { ucp_request_t *req = ucs_container_of(self, ucp_request_t, send.uct); @@ -498,6 +558,11 @@ ucs_status_t ucp_proto_rndv_rtr_handle_atp(void *arg, void *data, size_t length, UCP_SEND_REQUEST_GET_BY_ID(&req, worker, atp->super.req_id, 0, return UCS_OK, "ATP %p", atp); + if (atp->super.status != UCS_OK) { + req->send.proto_config->proto->abort(req, atp->super.status); + return UCS_OK; + } + if (!ucp_proto_common_frag_complete(req, atp->size, "rndv_atp")) { return UCS_OK; } diff --git a/test/gtest/ucp/test_ucp_rma.cc b/test/gtest/ucp/test_ucp_rma.cc index 750ae1d9cbd..f108c32f94f 100755 --- a/test/gtest/ucp/test_ucp_rma.cc +++ b/test/gtest/ucp/test_ucp_rma.cc @@ -203,7 +203,7 @@ class test_ucp_rma : public test_ucp_memheap { return get_variant_value() & USER_MEMH; } -private: +protected: /* Test variants */ enum { FLUSH_EP = UCS_BIT(0), /* If not set, flush worker */ @@ -390,6 +390,82 @@ UCS_TEST_P(test_ucp_rma_dmabuf, put_registration_offset) UCP_INSTANTIATE_TEST_CASE_TLS(test_ucp_rma_dmabuf, ib_cuda, "ib,cuda_copy") +class test_ucp_rma_rndv : public test_ucp_rma { +public: + static constexpr size_t SIZE = 512 * UCS_KBYTE; + + static void get_test_variants(std::vector &variants) + { + add_variant_with_value(variants, UCP_FEATURE_RMA, 0, "flush_worker"); + add_variant_with_value(variants, UCP_FEATURE_RMA, FLUSH_EP, "flush_ep"); + } + + test_ucp_rma_rndv() + { + modify_config("PROTOS", "put/rndv,get/rndv,rndv/*"); + } + + void init() override + { + m_env.push_back( + new ucs::scoped_setenv("UCX_IB_GPU_DIRECT_RDMA", "n")); + test_ucp_rma::init(); + } + +protected: + static bool is_rndv_mem_type_pair(ucs_memory_type_t local_mem_type, + ucs_memory_type_t remote_mem_type) + { + return !UCP_MEM_IS_HOST(local_mem_type) || + !UCP_MEM_IS_HOST(remote_mem_type); + } + + void test_forced_rndv(send_func_t send_func) + { + unsigned num_tested = 0; + + for (const auto &pair : ucs::supported_mem_type_pairs()) { + if (!is_rndv_mem_type_pair(pair[0], pair[1])) { + continue; + } + + test_message_sizes(send_func, 128, default_max_size(), + pair[0], pair[1], 0); + ++num_tested; + if (HasFailure() || (num_errors() > 0)) { + break; + } + } + + if (num_tested == 0) { + UCS_TEST_SKIP_R("no memory type pair supports RMA/RNDV"); + } + } +}; + +UCS_TEST_P(test_ucp_rma_rndv, put_blocking) +{ + test_forced_rndv(static_cast(&test_ucp_rma::put_b)); +} + +UCS_TEST_P(test_ucp_rma_rndv, put_nonblocking) +{ + test_forced_rndv(static_cast(&test_ucp_rma::put_nbi)); +} + +UCS_TEST_P(test_ucp_rma_rndv, get_blocking) +{ + test_forced_rndv(static_cast(&test_ucp_rma::get_b)); +} + +UCS_TEST_P(test_ucp_rma_rndv, get_nonblocking) +{ + test_forced_rndv(static_cast(&test_ucp_rma::get_nbi)); +} + +UCP_INSTANTIATE_TEST_CASE_TLS_GPU_AWARE(test_ucp_rma_rndv, ib, "ib") + + class test_ucp_proto_emulation_enable : public test_ucp_rma { public: static constexpr size_t SMALL_SIZE = 8;