From af636147878f65353893cb5c75564d484164a4a7 Mon Sep 17 00:00:00 2001 From: Thomas Vegas Date: Thu, 21 May 2026 18:41:40 +0200 Subject: [PATCH 01/15] UCP/RMA: Add RNDV PUT protocol --- src/ucp/Makefile.am | 1 + src/ucp/core/ucp_context.c | 3 +- src/ucp/core/ucp_request.c | 4 +- src/ucp/core/ucp_request.h | 14 +- src/ucp/core/ucp_types.h | 2 + src/ucp/proto/proto.c | 1 + src/ucp/rma/rma_rndv.c | 339 ++++++++++++++++++++++++++++++++++++ src/ucp/rndv/proto_rndv.c | 4 + src/ucp/rndv/proto_rndv.inl | 7 +- src/ucp/rndv/rndv.c | 12 +- 10 files changed, 376 insertions(+), 11 deletions(-) create mode 100644 src/ucp/rma/rma_rndv.c diff --git a/src/ucp/Makefile.am b/src/ucp/Makefile.am index b4f55c6fdaf..8280a58f5d8 100644 --- a/src/ucp/Makefile.am +++ b/src/ucp/Makefile.am @@ -141,6 +141,7 @@ libucp_la_SOURCES = \ rma/get_offload.c \ rma/put_am.c \ rma/put_offload.c \ + rma/rma_rndv.c \ rma/rma_send.c \ rma/rma_sw.c \ rma/flush.c \ diff --git a/src/ucp/core/ucp_context.c b/src/ucp/core/ucp_context.c index 4ae9573c7ba..e13771e76b8 100644 --- a/src/ucp/core/ucp_context.c +++ b/src/ucp/core/ucp_context.c @@ -61,7 +61,8 @@ _macro(UCP_AM_ID_AM_MIDDLE) \ _macro(UCP_AM_ID_AM_SINGLE_REPLY) \ _macro(UCP_AM_ID_AM_FIRST_PSN) \ - _macro(UCP_AM_ID_AM_MIDDLE_PSN) + _macro(UCP_AM_ID_AM_MIDDLE_PSN) \ + _macro(UCP_AM_ID_RMA_RNDV) #define UCP_AM_HANDLER_DECL(_id) extern ucp_am_handler_t ucp_am_handler_##_id; diff --git a/src/ucp/core/ucp_request.c b/src/ucp/core/ucp_request.c index 336ce95b248..d0247c875eb 100644 --- a/src/ucp/core/ucp_request.c +++ b/src/ucp/core/ucp_request.c @@ -44,6 +44,7 @@ static const char *ucp_request_flag_names[] = { [ucs_ilog2(UCP_REQUEST_FLAG_RECV_TAG)] = "rcv_tag", [ucs_ilog2(UCP_REQUEST_FLAG_RKEY_INUSE)] = "rk_use", [ucs_ilog2(UCP_REQUEST_FLAG_USER_HEADER_COPIED)] = "hdr_copy", + [ucs_ilog2(UCP_REQUEST_FLAG_RNDV_RECV_INTERNAL)] = "rndv_rcv_int", #if UCS_ENABLE_ASSERT [ucs_ilog2(UCP_REQUEST_FLAG_STREAM_RECV)] = "strm_rcv", @@ -59,7 +60,8 @@ static 
ucs_memory_type_t ucp_request_get_mem_type(ucp_request_t *req) } else if (req->flags & (UCP_REQUEST_FLAG_SEND_AM | UCP_REQUEST_FLAG_SEND_TAG)) { return req->send.mem_type; } else if (req->flags & - (UCP_REQUEST_FLAG_RECV_AM | UCP_REQUEST_FLAG_RECV_TAG)) { + (UCP_REQUEST_FLAG_RECV_AM | UCP_REQUEST_FLAG_RECV_TAG | + UCP_REQUEST_FLAG_RNDV_RECV_INTERNAL)) { return req->recv.dt_iter.mem_info.type; } else { return UCS_MEMORY_TYPE_UNKNOWN; diff --git a/src/ucp/core/ucp_request.h b/src/ucp/core/ucp_request.h index 3e6c0ad2555..53aa059d325 100644 --- a/src/ucp/core/ucp_request.h +++ b/src/ucp/core/ucp_request.h @@ -56,10 +56,11 @@ enum { UCP_REQUEST_FLAG_USER_HEADER_COPIED = UCS_BIT(19), UCP_REQUEST_FLAG_USAGE_TRACKED = UCS_BIT(20), UCP_REQUEST_FLAG_FENCE_REQUIRED = UCS_BIT(21), + UCP_REQUEST_FLAG_RNDV_RECV_INTERNAL = UCS_BIT(22), #if UCS_ENABLE_ASSERT - UCP_REQUEST_FLAG_STREAM_RECV = UCS_BIT(22), - UCP_REQUEST_DEBUG_FLAG_EXTERNAL = UCS_BIT(23), - UCP_REQUEST_FLAG_SUPER_VALID = UCS_BIT(24), + UCP_REQUEST_FLAG_STREAM_RECV = UCS_BIT(23), + UCP_REQUEST_DEBUG_FLAG_EXTERNAL = UCS_BIT(24), + UCP_REQUEST_FLAG_SUPER_VALID = UCS_BIT(25), #else UCP_REQUEST_FLAG_STREAM_RECV = 0, UCP_REQUEST_DEBUG_FLAG_EXTERNAL = 0, @@ -477,6 +478,13 @@ struct ucp_request { size_t length; /* Completion info to fill */ } stream; + struct { + /* Remote endpoint ID used to send internal completions */ + uint64_t ep_id; + /* Completion callback for internal RNDV receives */ + ucp_request_callback_t complete_cb; + } rndv; + struct { ucp_am_recv_data_nbx_callback_t cb; /* Completion callback */ ucp_recv_desc_t *desc; /* Receive desc */ diff --git a/src/ucp/core/ucp_types.h b/src/ucp/core/ucp_types.h index e78506403ad..abba8deed2b 100644 --- a/src/ucp/core/ucp_types.h +++ b/src/ucp/core/ucp_types.h @@ -204,6 +204,8 @@ typedef enum { carrying remote ep and PSN for tracking */ UCP_AM_ID_AM_MIDDLE_PSN = 28, + + UCP_AM_ID_RMA_RNDV = 29, /* RMA rendezvous control */ UCP_AM_ID_LAST } ucp_am_id_t; diff --git 
a/src/ucp/proto/proto.c b/src/ucp/proto/proto.c index 7cb232873c0..c2edc894b22 100644 --- a/src/ucp/proto/proto.c +++ b/src/ucp/proto/proto.c @@ -29,6 +29,7 @@ _macro(ucp_put_offload_short_proto) \ _macro(ucp_put_offload_bcopy_proto) \ _macro(ucp_put_offload_zcopy_proto) \ + _macro(ucp_put_rndv_proto) \ _macro(ucp_eager_bcopy_multi_proto) \ _macro(ucp_eager_sync_bcopy_multi_proto) \ _macro(ucp_eager_zcopy_multi_proto) \ diff --git a/src/ucp/rma/rma_rndv.c b/src/ucp/rma/rma_rndv.c new file mode 100644 index 00000000000..018ff89b12a --- /dev/null +++ b/src/ucp/rma/rma_rndv.c @@ -0,0 +1,339 @@ +/** + * Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2026. ALL RIGHTS RESERVED. + * + * See file LICENSE for terms. + */ + +#ifdef HAVE_CONFIG_H +# include "config.h" +#endif + +#include "rma.h" +#include "rma.inl" + +#include +#include +#include +#include +#include +#include + + +#define UCP_PROTO_RMA_RNDV_PUT_RTS_NAME "RMA PUT RTS" +#define UCP_PROTO_RMA_RNDV_PUT_DESC "RMA PUT rendezvous" + + +enum { + UCP_RMA_RNDV_AM_PUT_RTS +}; + + +typedef struct { + ucp_rndv_rts_hdr_t super; + uint64_t address; + ucs_sys_device_t sys_dev; + ucs_memory_type_t mem_type; +} UCS_S_PACKED ucp_rma_rndv_put_rts_hdr_t; + + + +static void +ucp_rma_rndv_dt_iter_init(ucp_datatype_iter_t *dt_iter, uint64_t address, + size_t length, ucs_memory_type_t mem_type, + ucs_sys_device_t sys_dev) +{ + dt_iter->dt_class = UCP_DATATYPE_CONTIG; + dt_iter->mem_info.type = mem_type; + dt_iter->mem_info.sys_dev = sys_dev; + dt_iter->length = length; + dt_iter->offset = 0; + dt_iter->type.contig.buffer = (void*)(uintptr_t)address; + dt_iter->type.contig.memh = NULL; +} + +static size_t ucp_proto_put_rndv_rts_pack(void *dest, void *arg) +{ + ucp_request_t *req = arg; + ucp_rma_rndv_put_rts_hdr_t *rts = dest; + ucp_rkey_config_t *rkey_config; + + rkey_config = ucp_rkey_config(req->send.ep->worker, req->send.rma.rkey); + + rts->super.hdr = UCP_RMA_RNDV_AM_PUT_RTS; + rts->super.opcode = UCP_RNDV_RTS_TAG_OK; + 
rts->address = req->send.rma.remote_addr; + rts->sys_dev = rkey_config->key.sys_dev; + rts->mem_type = req->send.rma.rkey->mem_type; + + return ucp_proto_rndv_rts_pack(req, &rts->super, sizeof(*rts)); +} + +static ucs_status_t ucp_proto_put_rndv_init(ucp_request_t *req) +{ + const ucp_proto_rndv_ctrl_priv_t *rpriv = req->send.proto_config->priv; + int was_initialized; + ucs_status_t status; + + was_initialized = req->flags & UCP_REQUEST_FLAG_PROTO_INITIALIZED; + status = ucp_proto_rndv_rts_request_init(req); + if ((status != UCS_OK) || was_initialized) { + return status; + } + + /* Nested RNDV data protocols are not RMA protocols, so the wrapper handles + * RMA fence ordering before exposing the operation to the peer. */ + return ucp_ep_rma_handle_fence(req->send.ep, req, UCS_BIT(rpriv->lane)); +} + +static ucs_status_t ucp_proto_put_rndv_progress(uct_pending_req_t *self) +{ + ucp_request_t *req = ucs_container_of(self, ucp_request_t, send.uct); + const ucp_proto_rndv_ctrl_priv_t *rpriv; + size_t max_rts_size; + ucs_status_t status; + ucp_ep_h ep; + + status = ucp_proto_put_rndv_init(req); + if (status != UCS_OK) { + ucp_proto_request_abort(req, status); + return UCS_OK; + } + + ep = req->send.ep; + rpriv = req->send.proto_config->priv; + max_rts_size = sizeof(ucp_rma_rndv_put_rts_hdr_t) + + rpriv->packed_rkey_size; + + ucp_worker_flush_ops_count_add(ep->worker, +1); + status = ucp_proto_am_bcopy_single_send(req, UCP_AM_ID_RMA_RNDV, + rpriv->lane, + ucp_proto_put_rndv_rts_pack, req, + max_rts_size, 0); + if (status == UCS_ERR_NO_RESOURCE) { + ucp_worker_flush_ops_count_add(ep->worker, -1); + req->send.lane = rpriv->lane; + return status; + } else if (status != UCS_OK) { + ucp_worker_flush_ops_count_add(ep->worker, -1); + ucp_proto_request_abort(req, status); + return UCS_OK; + } + + ucp_ep_rma_remote_request_sent(ep); + return UCS_OK; +} + +static void +ucp_proto_put_rndv_probe(const ucp_proto_init_params_t *init_params) +{ + ucp_context_h context = 
init_params->worker->context; + const ucp_proto_select_param_t *sel_param = init_params->select_param; + ucp_proto_rndv_ctrl_init_params_t params = { + .super.super = *init_params, + .super.latency = 0, + .super.overhead = context->config.ext.proto_overhead_rndv_rts, + .super.cfg_thresh = context->config.ext.zcopy_thresh, + .super.cfg_priority = 5, + .super.min_length = 1, + .super.max_length = SIZE_MAX, + .super.min_iov = 1, + .super.min_frag_offs = UCP_PROTO_COMMON_OFFSET_INVALID, + .super.max_frag_offs = ucs_offsetof(uct_iface_attr_t, cap.am.max_bcopy), + .super.max_iov_offs = UCP_PROTO_COMMON_OFFSET_INVALID, + .super.hdr_size = sizeof(ucp_rma_rndv_put_rts_hdr_t), + .super.send_op = UCT_EP_OP_AM_BCOPY, + .super.memtype_op = UCT_EP_OP_LAST, + .super.flags = UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, + .super.exclude_map = 0, + .super.reg_mem_info = ucp_proto_common_select_param_mem_info( + init_params->select_param), + /* For performance modeling, this control protocol is followed on the + * peer by a regular RNDV receive flow over the final RMA address. 
*/ + .remote_op_id = UCP_OP_ID_RNDV_RECV, + .lane = ucp_proto_rndv_find_ctrl_lane(init_params), + .unpack_perf = NULL, + .perf_bias = 0, + .ctrl_msg_name = UCP_PROTO_RMA_RNDV_PUT_RTS_NAME, + .md_map = 0 + }; + ucp_proto_rndv_ctrl_priv_t rpriv; + + if (!ucp_proto_init_check_op(init_params, UCS_BIT(UCP_OP_ID_PUT)) || + (sel_param->dt_class != UCP_DATATYPE_CONTIG) || + (init_params->rkey_config_key == NULL)) { + return; + } + + if (UCP_MEM_IS_HOST(sel_param->mem_type) && + UCP_MEM_IS_HOST(init_params->rkey_config_key->mem_type)) { + return; + } + + ucp_proto_rndv_ctrl_probe(&params, &rpriv, sizeof(rpriv)); +} + +static void +ucp_rma_rndv_send_ats_err(ucp_ep_h ep, ucs_ptr_map_key_t remote_req_id, + ucs_status_t status) +{ + ucp_request_t *req; + + req = ucp_request_get(ep->worker); + if (req == NULL) { + ucs_error("failed to allocate RMA RNDV error ATS"); + return; + } + + ucp_proto_request_send_init(req, ep, 0); + ucp_rndv_req_send_ack(req, 0, remote_req_id, status, UCP_AM_ID_RNDV_ATS, + "send_ats_err"); +} + +static void ucp_proto_rma_rndv_query(const ucp_proto_query_params_t *params, + ucp_proto_query_attr_t *attr, + const char *desc) +{ + const ucp_proto_rndv_ctrl_priv_t *rpriv = params->priv; + ucp_proto_query_attr_t remote_attr; + + ucp_proto_config_query(params->worker, &rpriv->remote_proto_config, + params->msg_length, &remote_attr); + + attr->is_estimation = 1; + attr->max_msg_length = remote_attr.max_msg_length; + attr->lane_map = UCS_BIT(rpriv->lane); + + ucs_snprintf_safe(attr->desc, sizeof(attr->desc), "%s using %s", desc, + remote_attr.desc); + ucs_snprintf_safe(attr->config, sizeof(attr->config), "ctrl lane %u, %s", + rpriv->lane, remote_attr.config); +} + +static void ucp_proto_put_rndv_query(const ucp_proto_query_params_t *params, + ucp_proto_query_attr_t *attr) +{ + ucp_proto_rma_rndv_query(params, attr, UCP_PROTO_RMA_RNDV_PUT_DESC); +} + + +static void ucp_rma_rndv_put_recv_complete(ucp_request_t *recv_req) +{ + ucp_worker_h worker = 
recv_req->recv.worker; + ucp_ep_h ep; + + UCP_WORKER_GET_EP_BY_ID(&ep, worker, recv_req->recv.rndv.ep_id, { + ucp_request_put(recv_req); + return; + }, "RMA RNDV PUT completion"); + + ucp_rma_sw_send_cmpl(ep); + if (recv_req->status != UCS_OK) { + ucp_rma_rndv_send_ats_err(ep, recv_req->recv.remote_req_id, + recv_req->status); + } + + ucp_request_put(recv_req); +} + +static ucs_status_t +ucp_rma_rndv_handle_put_rts(ucp_worker_h worker, void *data, size_t length) +{ + const ucp_rma_rndv_put_rts_hdr_t *rts = data; + const void *rkey_buffer; + ucp_request_t *recv_req; + ucp_ep_h ep; + + if (length < sizeof(*rts)) { + return UCS_ERR_MESSAGE_TRUNCATED; + } + + recv_req = ucp_request_get(worker); + if (recv_req == NULL) { + ucs_error("failed to allocate RMA RNDV PUT receive request"); + UCP_WORKER_GET_EP_BY_ID(&ep, worker, rts->super.sreq.ep_id, + return UCS_OK, "RMA RNDV PUT error"); + ucp_rma_sw_send_cmpl(ep); + ucp_rma_rndv_send_ats_err(ep, rts->super.sreq.req_id, + UCS_ERR_NO_MEMORY); + return UCS_OK; + } + + recv_req->flags = UCP_REQUEST_FLAG_RNDV_RECV_INTERNAL; + recv_req->recv.worker = worker; + recv_req->recv.op_attr = 0; + recv_req->recv.remote_req_id = rts->super.sreq.req_id; + ucp_rma_rndv_dt_iter_init(&recv_req->recv.dt_iter, rts->address, + rts->super.size, rts->mem_type, rts->sys_dev); + recv_req->recv.rndv.ep_id = rts->super.sreq.ep_id; + recv_req->recv.rndv.complete_cb = ucp_rma_rndv_put_recv_complete; + + rkey_buffer = UCS_PTR_BYTE_OFFSET(rts, sizeof(*rts)); + ucp_proto_rndv_receive_start(worker, recv_req, &rts->super, rkey_buffer, + length - sizeof(*rts)); + return UCS_OK; +} + +UCS_PROFILE_FUNC(ucs_status_t, ucp_rma_rndv_handler, + (arg, data, length, am_flags), void *arg, void *data, + size_t length, unsigned am_flags) +{ + const ucp_rndv_rts_hdr_t *hdr = data; + ucp_worker_h worker = arg; + + if (length < sizeof(*hdr)) { + return UCS_ERR_MESSAGE_TRUNCATED; + } + + switch (hdr->hdr) { + case UCP_RMA_RNDV_AM_PUT_RTS: + return 
ucp_rma_rndv_handle_put_rts(worker, data, length); + default: + ucs_debug("unexpected RMA RNDV AM sub-id %" PRIu64, hdr->hdr); + return UCS_ERR_UNSUPPORTED; + } +} + +static void +ucp_rma_rndv_dump_packet(ucp_worker_h worker, uct_am_trace_type_t type, + uint8_t id, const void *data, size_t length, + char *buffer, size_t max) +{ + const ucp_rma_rndv_put_rts_hdr_t *put_rts = data; + const ucp_rndv_rts_hdr_t *hdr = data; + + if (length < sizeof(*hdr)) { + return; + } + + switch (hdr->hdr) { + case UCP_RMA_RNDV_AM_PUT_RTS: + if (length < sizeof(*put_rts)) { + return; + } + + snprintf(buffer, max, "RMA_PUT_RTS [src 0x%" PRIx64 + " dst 0x%" PRIx64 " len %zu req_id 0x%" PRIx64 + " ep_id 0x%" PRIx64 " %s]", put_rts->super.address, + put_rts->address, put_rts->super.size, + put_rts->super.sreq.req_id, put_rts->super.sreq.ep_id, + ucs_memory_type_names[put_rts->mem_type]); + break; + default: + snprintf(buffer, max, "RMA_RNDV [sub-id %" PRIu64 "]", hdr->hdr); + break; + } +} + +UCP_DEFINE_AM_WITH_PROXY(UCP_FEATURE_RMA, UCP_AM_ID_RMA_RNDV, + ucp_rma_rndv_handler, ucp_rma_rndv_dump_packet, 0); + +ucp_proto_t ucp_put_rndv_proto = { + .name = "put/rndv", + .desc = UCP_PROTO_RMA_RNDV_PUT_DESC, + .flags = 0, + .probe = ucp_proto_put_rndv_probe, + .query = ucp_proto_put_rndv_query, + .progress = {ucp_proto_put_rndv_progress}, + .abort = ucp_proto_rndv_rts_abort, + .reset = ucp_proto_rndv_rts_reset +}; diff --git a/src/ucp/rndv/proto_rndv.c b/src/ucp/rndv/proto_rndv.c index d68bba2a514..2f15f13cfdd 100644 --- a/src/ucp/rndv/proto_rndv.c +++ b/src/ucp/rndv/proto_rndv.c @@ -817,6 +817,7 @@ UCS_PROFILE_FUNC_VOID(ucp_proto_rndv_receive_start, ucp_ep_h ep; UCP_WORKER_GET_VALID_EP_BY_ID(&ep, worker, rts->sreq.ep_id, { + ucp_datatype_iter_cleanup(&recv_req->recv.dt_iter, 1, UCP_DT_MASK_ALL); ucp_proto_rndv_recv_req_complete(recv_req, UCS_ERR_CANCELED); return; }, "RTS on non-existing endpoint"); @@ -824,6 +825,8 @@ UCS_PROFILE_FUNC_VOID(ucp_proto_rndv_receive_start, req = 
ucp_request_get(worker); if (req == NULL) { ucs_error("failed to allocate rendezvous reply"); + ucp_datatype_iter_cleanup(&recv_req->recv.dt_iter, 1, UCP_DT_MASK_ALL); + ucp_proto_rndv_recv_req_complete(recv_req, UCS_ERR_NO_MEMORY); return; } @@ -856,6 +859,7 @@ UCS_PROFILE_FUNC_VOID(ucp_proto_rndv_receive_start, if (status != UCS_OK) { ucp_datatype_iter_cleanup(&req->send.state.dt_iter, 1, UCP_DT_MASK_ALL); ucs_mpool_put(req); + ucp_proto_rndv_recv_req_complete(recv_req, status); return; } diff --git a/src/ucp/rndv/proto_rndv.inl b/src/ucp/rndv/proto_rndv.inl index 837e8dc1534..14fedcc19c3 100644 --- a/src/ucp/rndv/proto_rndv.inl +++ b/src/ucp/rndv/proto_rndv.inl @@ -79,7 +79,7 @@ ucp_proto_rndv_ats_handler(void *arg, void *data, size_t length, unsigned flags) ucp_tag_offload_cancel_rndv(req); } - if (length >= sizeof(*ats)) { + if ((status == UCS_OK) && (length >= sizeof(*ats))) { /* ATS message carries a size field */ ats = ucs_derived_of(rephdr, ucp_rndv_ack_hdr_t); if (!ucp_proto_common_frag_complete(req, ats->size, "rndv_ats")) { @@ -360,7 +360,10 @@ ucp_proto_rndv_recv_req_complete(ucp_request_t *recv_req, ucs_status_t status) ucp_trace_req(recv_req, "rndv_recv_req_complete status '%s'", ucs_status_string(status)); - if (recv_req->flags & UCP_REQUEST_FLAG_RECV_AM) { + if (recv_req->flags & UCP_REQUEST_FLAG_RNDV_RECV_INTERNAL) { + recv_req->status = status; + recv_req->recv.rndv.complete_cb(recv_req); + } else if (recv_req->flags & UCP_REQUEST_FLAG_RECV_AM) { ucp_request_complete_am_recv(recv_req, status); } else { ucs_assert(recv_req->flags & UCP_REQUEST_FLAG_RECV_TAG); diff --git a/src/ucp/rndv/rndv.c b/src/ucp/rndv/rndv.c index 76cdcc4c570..24ce90266e1 100644 --- a/src/ucp/rndv/rndv.c +++ b/src/ucp/rndv/rndv.c @@ -2511,11 +2511,15 @@ static void ucp_rndv_dump(ucp_worker_h worker, uct_am_trace_type_t type, UCP_DEFINE_AM_WITH_PROXY(UCP_FEATURE_TAG | UCP_FEATURE_AM, UCP_AM_ID_RNDV_RTS, ucp_rndv_rts_handler, ucp_rndv_dump, 0); 
-UCP_DEFINE_AM_WITH_PROXY(UCP_FEATURE_TAG | UCP_FEATURE_AM, UCP_AM_ID_RNDV_ATS, +UCP_DEFINE_AM_WITH_PROXY(UCP_FEATURE_TAG | UCP_FEATURE_AM | UCP_FEATURE_RMA, + UCP_AM_ID_RNDV_ATS, ucp_rndv_ats_handler, ucp_rndv_dump, 0); -UCP_DEFINE_AM_WITH_PROXY(UCP_FEATURE_TAG | UCP_FEATURE_AM, UCP_AM_ID_RNDV_ATP, +UCP_DEFINE_AM_WITH_PROXY(UCP_FEATURE_TAG | UCP_FEATURE_AM | UCP_FEATURE_RMA, + UCP_AM_ID_RNDV_ATP, ucp_rndv_atp_handler, ucp_rndv_dump, 0); -UCP_DEFINE_AM_WITH_PROXY(UCP_FEATURE_TAG | UCP_FEATURE_AM, UCP_AM_ID_RNDV_RTR, +UCP_DEFINE_AM_WITH_PROXY(UCP_FEATURE_TAG | UCP_FEATURE_AM | UCP_FEATURE_RMA, + UCP_AM_ID_RNDV_RTR, ucp_rndv_rtr_handler, ucp_rndv_dump, 0); -UCP_DEFINE_AM_WITH_PROXY(UCP_FEATURE_TAG | UCP_FEATURE_AM, UCP_AM_ID_RNDV_DATA, +UCP_DEFINE_AM_WITH_PROXY(UCP_FEATURE_TAG | UCP_FEATURE_AM | UCP_FEATURE_RMA, + UCP_AM_ID_RNDV_DATA, ucp_rndv_data_handler, ucp_rndv_dump, 0); From 004806302a273ce08b417d1e26394918612b6dfe Mon Sep 17 00:00:00 2001 From: Thomas Vegas Date: Fri, 22 May 2026 10:33:27 +0200 Subject: [PATCH 02/15] UCP/RMA: Add RNDV GET protocol --- src/ucp/core/ucp_ep.c | 6 + src/ucp/core/ucp_request.c | 1 + src/ucp/core/ucp_request.h | 3 +- src/ucp/proto/proto.c | 1 + src/ucp/rma/rma_rndv.c | 457 ++++++++++++++++++++++++++++++++++++- 5 files changed, 459 insertions(+), 9 deletions(-) diff --git a/src/ucp/core/ucp_ep.c b/src/ucp/core/ucp_ep.c index 1a0c09bd808..ab9c0e42d49 100644 --- a/src/ucp/core/ucp_ep.c +++ b/src/ucp/core/ucp_ep.c @@ -3696,6 +3696,12 @@ void ucp_ep_req_purge(ucp_ep_h ucp_ep, ucp_request_t *req, } ucp_request_put(req); + } else if (req->flags & UCP_REQUEST_FLAG_RNDV_SEND_INTERNAL) { + ucs_assert(req->send.ep == ucp_ep); + + ucp_datatype_iter_cleanup(&req->send.state.dt_iter, 1, + UCP_DT_MASK_ALL); + ucp_request_complete_send(req, status); } else if (req->send.uct.func == ucp_amo_sw_proto.progress_fetch) { /* Currently we don't support UCP EP request purging for proto mode */ ucs_assert(!ucp_ep->worker->context->config.ext.proto_enable); 
diff --git a/src/ucp/core/ucp_request.c b/src/ucp/core/ucp_request.c index d0247c875eb..2c07b1e4e9d 100644 --- a/src/ucp/core/ucp_request.c +++ b/src/ucp/core/ucp_request.c @@ -51,6 +51,7 @@ static const char *ucp_request_flag_names[] = { [ucs_ilog2(UCP_REQUEST_DEBUG_FLAG_EXTERNAL)] = "extrn", [ucs_ilog2(UCP_REQUEST_FLAG_SUPER_VALID)] = "spr_vld", #endif + [ucs_ilog2(UCP_REQUEST_FLAG_RNDV_SEND_INTERNAL)] = "rndv_snd_int", }; static ucs_memory_type_t ucp_request_get_mem_type(ucp_request_t *req) diff --git a/src/ucp/core/ucp_request.h b/src/ucp/core/ucp_request.h index 53aa059d325..81a6c373f4f 100644 --- a/src/ucp/core/ucp_request.h +++ b/src/ucp/core/ucp_request.h @@ -64,8 +64,9 @@ enum { #else UCP_REQUEST_FLAG_STREAM_RECV = 0, UCP_REQUEST_DEBUG_FLAG_EXTERNAL = 0, - UCP_REQUEST_FLAG_SUPER_VALID = 0 + UCP_REQUEST_FLAG_SUPER_VALID = 0, #endif + UCP_REQUEST_FLAG_RNDV_SEND_INTERNAL = UCS_BIT(26) }; diff --git a/src/ucp/proto/proto.c b/src/ucp/proto/proto.c index c2edc894b22..fae0b9bd1f1 100644 --- a/src/ucp/proto/proto.c +++ b/src/ucp/proto/proto.c @@ -25,6 +25,7 @@ _macro(ucp_get_am_bcopy_proto) \ _macro(ucp_get_offload_bcopy_proto) \ _macro(ucp_get_offload_zcopy_proto) \ + _macro(ucp_get_rndv_proto) \ _macro(ucp_put_am_bcopy_proto) \ _macro(ucp_put_offload_short_proto) \ _macro(ucp_put_offload_bcopy_proto) \ diff --git a/src/ucp/rma/rma_rndv.c b/src/ucp/rma/rma_rndv.c index 018ff89b12a..a278994fe2a 100644 --- a/src/ucp/rma/rma_rndv.c +++ b/src/ucp/rma/rma_rndv.c @@ -20,11 +20,15 @@ #define UCP_PROTO_RMA_RNDV_PUT_RTS_NAME "RMA PUT RTS" +#define UCP_PROTO_RMA_RNDV_GET_REQ_NAME "RMA GET REQ" #define UCP_PROTO_RMA_RNDV_PUT_DESC "RMA PUT rendezvous" +#define UCP_PROTO_RMA_RNDV_GET_DESC "RMA GET rendezvous" enum { - UCP_RMA_RNDV_AM_PUT_RTS + UCP_RMA_RNDV_AM_PUT_RTS, + UCP_RMA_RNDV_AM_GET_REQ, + UCP_RMA_RNDV_AM_GET_RTS }; @@ -36,6 +40,21 @@ typedef struct { } UCS_S_PACKED ucp_rma_rndv_put_rts_hdr_t; +typedef struct { + uint64_t hdr; + ucp_request_hdr_t req; + uint64_t 
address; + size_t size; + ucs_sys_device_t sys_dev; + ucs_memory_type_t mem_type; +} UCS_S_PACKED ucp_rma_rndv_get_req_hdr_t; + + +typedef struct { + ucp_rndv_rts_hdr_t super; + ucs_ptr_map_key_t get_req_id; +} UCS_S_PACKED ucp_rma_rndv_get_rts_hdr_t; + static void ucp_rma_rndv_dt_iter_init(ucp_datatype_iter_t *dt_iter, uint64_t address, @@ -51,6 +70,7 @@ ucp_rma_rndv_dt_iter_init(ucp_datatype_iter_t *dt_iter, uint64_t address, dt_iter->type.contig.memh = NULL; } + static size_t ucp_proto_put_rndv_rts_pack(void *dest, void *arg) { ucp_request_t *req = arg; @@ -215,6 +235,388 @@ static void ucp_proto_put_rndv_query(const ucp_proto_query_params_t *params, ucp_proto_rma_rndv_query(params, attr, UCP_PROTO_RMA_RNDV_PUT_DESC); } +static void ucp_proto_get_rndv_query(const ucp_proto_query_params_t *params, + ucp_proto_query_attr_t *attr) +{ + ucp_proto_rma_rndv_query(params, attr, UCP_PROTO_RMA_RNDV_GET_DESC); +} + +static size_t ucp_proto_get_rndv_req_pack(void *dest, void *arg) +{ + ucp_request_t *req = arg; + ucp_rma_rndv_get_req_hdr_t *hdr = dest; + ucp_rkey_config_t *rkey_config; + + rkey_config = ucp_rkey_config(req->send.ep->worker, req->send.rma.rkey); + + hdr->hdr = UCP_RMA_RNDV_AM_GET_REQ; + hdr->req.ep_id = ucp_send_request_get_ep_remote_id(req); + hdr->req.req_id = ucp_send_request_get_id(req); + hdr->address = req->send.rma.remote_addr; + hdr->size = req->send.state.dt_iter.length; + hdr->sys_dev = rkey_config->key.sys_dev; + hdr->mem_type = req->send.rma.rkey->mem_type; + + return sizeof(*hdr); +} + +static ucs_status_t ucp_proto_get_rndv_init(ucp_request_t *req) +{ + const ucp_proto_rndv_ctrl_priv_t *rpriv = req->send.proto_config->priv; + ucs_status_t status; + + if (req->flags & UCP_REQUEST_FLAG_PROTO_INITIALIZED) { + return UCS_OK; + } + + status = ucp_ep_resolve_remote_id(req->send.ep, rpriv->lane); + if (status != UCS_OK) { + return status; + } + + req->send.buffer = req->send.state.dt_iter.type.contig.buffer; + req->send.length = 
req->send.state.dt_iter.length; + ucp_send_request_id_alloc(req); + req->flags |= UCP_REQUEST_FLAG_PROTO_INITIALIZED; + + /* The nested RNDV receive starts only after GET_RTS; this wrapper still + * has to respect RMA fence ordering before the target can expose data. */ + return ucp_ep_rma_handle_fence(req->send.ep, req, UCS_BIT(rpriv->lane)); +} + +static ucs_status_t ucp_proto_get_rndv_progress(uct_pending_req_t *self) +{ + ucp_request_t *req = ucs_container_of(self, ucp_request_t, send.uct); + const ucp_proto_rndv_ctrl_priv_t *rpriv; + ucs_status_t status; + + status = ucp_proto_get_rndv_init(req); + if (status != UCS_OK) { + ucp_proto_request_abort(req, status); + return UCS_OK; + } + + rpriv = req->send.proto_config->priv; + status = ucp_proto_am_bcopy_single_send(req, UCP_AM_ID_RMA_RNDV, + rpriv->lane, + ucp_proto_get_rndv_req_pack, req, + sizeof(ucp_rma_rndv_get_req_hdr_t), + 0); + if (status == UCS_ERR_NO_RESOURCE) { + req->send.lane = rpriv->lane; + return status; + } else if (status != UCS_OK) { + ucp_proto_request_abort(req, status); + return UCS_OK; + } + + return UCS_OK; +} + +static void +ucp_proto_get_rndv_probe(const ucp_proto_init_params_t *init_params) +{ + ucp_context_h context = init_params->worker->context; + const ucp_proto_select_param_t *sel_param = init_params->select_param; + ucp_proto_rndv_ctrl_init_params_t params = { + .super.super = *init_params, + .super.latency = 0, + .super.overhead = context->config.ext.proto_overhead_rndv_rts, + .super.cfg_thresh = context->config.ext.zcopy_thresh, + .super.cfg_priority = 5, + .super.min_length = 1, + .super.max_length = SIZE_MAX, + .super.min_iov = 1, + .super.min_frag_offs = UCP_PROTO_COMMON_OFFSET_INVALID, + .super.max_frag_offs = ucs_offsetof(uct_iface_attr_t, cap.am.max_bcopy), + .super.max_iov_offs = UCP_PROTO_COMMON_OFFSET_INVALID, + .super.hdr_size = sizeof(ucp_rma_rndv_get_req_hdr_t), + .super.send_op = UCT_EP_OP_AM_BCOPY, + .super.memtype_op = UCT_EP_OP_LAST, + .super.flags = 
UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, + .super.exclude_map = 0, + .super.reg_mem_info = ucp_proto_common_select_param_mem_info( + init_params->select_param), + /* The peer turns GET_REQ into a synthetic RNDV sender. */ + .remote_op_id = UCP_OP_ID_RNDV_SEND, + .lane = ucp_proto_rndv_find_ctrl_lane(init_params), + .unpack_perf = NULL, + .perf_bias = 0, + .ctrl_msg_name = UCP_PROTO_RMA_RNDV_GET_REQ_NAME, + .md_map = 0 + }; + ucp_proto_rndv_ctrl_priv_t rpriv; + + if (!ucp_proto_init_check_op(init_params, UCS_BIT(UCP_OP_ID_GET)) || + (sel_param->dt_class != UCP_DATATYPE_CONTIG) || + (init_params->rkey_config_key == NULL)) { + return; + } + + if (UCP_MEM_IS_HOST(sel_param->mem_type) && + UCP_MEM_IS_HOST(init_params->rkey_config_key->mem_type)) { + return; + } + + ucp_proto_rndv_ctrl_probe(&params, &rpriv, sizeof(rpriv)); +} + +static void ucp_proto_get_rndv_abort(ucp_request_t *req, ucs_status_t status) +{ + if (req->id != UCS_PTR_MAP_KEY_INVALID) { + ucp_send_request_id_release(req); + } + + if (req->send.state.dt_iter.dt_class != UCP_DATATYPE_CLASS_MASK) { + ucp_datatype_iter_cleanup(&req->send.state.dt_iter, 0, + UCP_DT_MASK_ALL); + } + + ucp_request_complete_send(req, status); +} + +static ucs_status_t ucp_proto_get_rndv_reset(ucp_request_t *req) +{ + if (req->flags & UCP_REQUEST_FLAG_PROTO_INITIALIZED) { + if (req->id != UCS_PTR_MAP_KEY_INVALID) { + ucp_send_request_id_release(req); + } + + if (req->send.state.dt_iter.dt_class != UCP_DATATYPE_CLASS_MASK) { + ucp_datatype_iter_cleanup(&req->send.state.dt_iter, 0, + UCP_DT_MASK_ALL); + } + } + + req->flags &= ~UCP_REQUEST_FLAG_PROTO_INITIALIZED; + return UCS_OK; +} + +static void +ucp_rma_rndv_get_send_complete(void *request, + ucs_status_t UCS_V_UNUSED status, + void *UCS_V_UNUSED user_data) +{ + ucp_request_t *req = (ucp_request_t*)request - 1; + + ucp_request_put(req); +} + +static void +ucp_rma_rndv_get_send_abort(ucp_request_t *req, ucs_status_t status) +{ + if (req->send.rndv.remote_req_id != 
UCS_PTR_MAP_KEY_INVALID) { + ucp_rma_rndv_send_ats_err(req->send.ep, req->send.rndv.remote_req_id, + status); + req->send.rndv.remote_req_id = UCS_PTR_MAP_KEY_INVALID; + } + + if (req->id != UCS_PTR_MAP_KEY_INVALID) { + ucp_send_request_id_release(req); + } + + ucp_datatype_iter_cleanup(&req->send.state.dt_iter, 1, UCP_DT_MASK_ALL); + ucp_request_complete_send(req, status); +} + +static size_t ucp_rma_rndv_get_rts_pack(void *dest, void *arg) +{ + ucp_request_t *req = arg; + ucp_rma_rndv_get_rts_hdr_t *rts = dest; + ucp_datatype_iter_t *dt_iter = &req->send.state.dt_iter; + void *rkey_buffer = UCS_PTR_BYTE_OFFSET(rts, sizeof(*rts)); + size_t rkey_size = 0; + + rts->super.hdr = UCP_RMA_RNDV_AM_GET_RTS; + rts->super.opcode = UCP_RNDV_RTS_TAG_OK; + rts->super.sreq.req_id = ucp_send_request_get_id(req); + rts->super.sreq.ep_id = ucp_send_request_get_ep_remote_id(req); + rts->super.size = dt_iter->length; + rts->super.address = 0; + rts->get_req_id = req->send.rndv.remote_req_id; + + if ((dt_iter->length > 0) && (req->send.rndv.md_map != 0)) { + rkey_size = UCS_PROFILE_CALL(ucp_proto_request_pack_rkey, req, + req->send.rndv.md_map, 0, NULL, + rkey_buffer); + if (rkey_size > 0) { + rts->super.address = (uintptr_t)dt_iter->type.contig.buffer; + } + } + + return sizeof(*rts) + rkey_size; +} + +static ucs_status_t ucp_rma_rndv_get_rts_progress(uct_pending_req_t *self) +{ + ucp_request_t *req = ucs_container_of(self, ucp_request_t, send.uct); + ucp_lane_index_t lane; + ucs_status_t status; + + lane = ucp_ep_get_am_lane(req->send.ep); + req->send.lane = lane; + status = ucp_proto_am_bcopy_single_send( + req, UCP_AM_ID_RMA_RNDV, lane, ucp_rma_rndv_get_rts_pack, req, + sizeof(ucp_rma_rndv_get_rts_hdr_t) + + ucp_ep_config(req->send.ep)->rndv.rkey_size, 0); + if (status == UCS_ERR_NO_RESOURCE) { + return status; + } else if (status != UCS_OK) { + ucp_rma_rndv_get_send_abort(req, status); + } + + return UCS_OK; +} + +static ucs_status_t +ucp_rma_rndv_get_sreq_init(ucp_ep_h ep, 
ucp_request_t *req, + const ucp_rma_rndv_get_req_hdr_t *get_req) +{ + ucp_proto_select_param_t sel_param; + ucp_md_map_t md_map; + ucs_status_t status; + + ucp_proto_request_send_init(req, ep, + UCP_REQUEST_FLAG_RNDV_SEND_INTERNAL); + ucp_request_set_callback(req, send.cb, ucp_rma_rndv_get_send_complete); + req->send.buffer = (void*)(uintptr_t)get_req->address; + req->send.length = get_req->size; + req->send.mem_type = get_req->mem_type; + req->send.rndv.remote_req_id = get_req->req.req_id; + req->send.rndv.rkey = NULL; + req->send.rndv.remote_address = get_req->address; + req->send.rndv.md_map = 0; + ucp_rma_rndv_dt_iter_init(&req->send.state.dt_iter, get_req->address, + get_req->size, get_req->mem_type, + get_req->sys_dev); + + status = ucp_ep_resolve_remote_id(ep, ucp_ep_get_am_lane(ep)); + if (status != UCS_OK) { + goto err_cleanup; + } + + md_map = ucp_ep_config(ep)->key.rma_bw_md_map; + if ((get_req->size > 0) && (md_map != 0)) { + status = ucp_datatype_iter_mem_reg(ep->worker->context, + &req->send.state.dt_iter, md_map, + UCT_MD_MEM_ACCESS_RMA | + UCT_MD_MEM_FLAG_HIDE_ERRORS, + UCP_DT_MASK_ALL); + if (status != UCS_OK) { + goto err_cleanup; + } + + req->send.rndv.md_map = + req->send.state.dt_iter.type.contig.memh->md_map & md_map; + } + + ucp_proto_select_param_init(&sel_param, UCP_OP_ID_RNDV_SEND, 0, 0, + UCP_DATATYPE_CONTIG, + &req->send.state.dt_iter.mem_info, 1); + status = UCS_PROFILE_CALL(ucp_proto_request_lookup_proto, ep->worker, ep, + req, &ucp_ep_config(ep)->proto_select, + UCP_WORKER_CFG_INDEX_NULL, &sel_param, + get_req->size); + if (status != UCS_OK) { + goto err_cleanup; + } + + ucp_send_request_id_alloc(req); + req->flags |= UCP_REQUEST_FLAG_PROTO_INITIALIZED; + req->send.uct.func = ucp_rma_rndv_get_rts_progress; + return UCS_OK; + +err_cleanup: + ucp_datatype_iter_cleanup(&req->send.state.dt_iter, 1, UCP_DT_MASK_ALL); + return status; +} + +static void ucp_rma_rndv_get_recv_complete(ucp_request_t *recv_req) +{ + ucp_request_t *get_req = 
ucp_request_get_super(recv_req); + + ucp_request_complete_send(get_req, recv_req->status); + ucp_request_put(recv_req); +} + +static ucs_status_t +ucp_rma_rndv_handle_get_req(ucp_worker_h worker, void *data, size_t length) +{ + const ucp_rma_rndv_get_req_hdr_t *get_req = data; + ucp_request_t *req; + ucs_status_t status; + ucp_ep_h ep; + + if (length < sizeof(*get_req)) { + return UCS_ERR_MESSAGE_TRUNCATED; + } + + UCP_WORKER_GET_EP_BY_ID(&ep, worker, get_req->req.ep_id, return UCS_OK, + "RMA RNDV GET request"); + + req = ucp_request_get(worker); + if (req == NULL) { + ucs_error("failed to allocate RMA RNDV GET send request"); + ucp_rma_rndv_send_ats_err(ep, get_req->req.req_id, UCS_ERR_NO_MEMORY); + return UCS_OK; + } + + status = ucp_rma_rndv_get_sreq_init(ep, req, get_req); + if (status != UCS_OK) { + ucp_rma_rndv_send_ats_err(ep, get_req->req.req_id, status); + ucp_request_put(req); + return UCS_OK; + } + + ucp_request_send(req); + return UCS_OK; +} + +static ucs_status_t +ucp_rma_rndv_handle_get_rts(ucp_worker_h worker, void *data, size_t length) +{ + const ucp_rma_rndv_get_rts_hdr_t *rts = data; + ucp_request_t *get_req, *recv_req; + uint8_t UCS_V_UNUSED sg_count; + const void *rkey_buffer; + ucp_ep_h ep; + + if (length < sizeof(*rts)) { + return UCS_ERR_MESSAGE_TRUNCATED; + } + + UCP_SEND_REQUEST_GET_BY_ID(&get_req, worker, rts->get_req_id, 0, + return UCS_OK, "RMA RNDV GET_RTS %p", rts); + + recv_req = ucp_request_get(worker); + if (recv_req == NULL) { + ucs_error("failed to allocate RMA RNDV GET receive request"); + UCP_WORKER_GET_EP_BY_ID(&ep, worker, rts->super.sreq.ep_id, + return UCS_OK, "RMA RNDV GET_RTS error"); + ucp_rma_rndv_send_ats_err(ep, rts->super.sreq.req_id, + UCS_ERR_NO_MEMORY); + ucp_proto_get_rndv_abort(get_req, UCS_ERR_NO_MEMORY); + return UCS_OK; + } + + ucp_send_request_id_release(get_req); + recv_req->flags = UCP_REQUEST_FLAG_RNDV_RECV_INTERNAL; + recv_req->recv.worker = worker; + recv_req->recv.op_attr = 0; + 
recv_req->recv.remote_req_id = rts->super.sreq.req_id; + recv_req->recv.rndv.ep_id = rts->super.sreq.ep_id; + recv_req->recv.rndv.complete_cb = ucp_rma_rndv_get_recv_complete; + ucp_request_set_super(recv_req, get_req); + + UCS_PROFILE_CALL_VOID(ucp_datatype_iter_move, &recv_req->recv.dt_iter, + &get_req->send.state.dt_iter, + get_req->send.state.dt_iter.length, &sg_count); + + rkey_buffer = UCS_PTR_BYTE_OFFSET(rts, sizeof(*rts)); + ucp_proto_rndv_receive_start(worker, recv_req, &rts->super, rkey_buffer, + length - sizeof(*rts)); + return UCS_OK; +} static void ucp_rma_rndv_put_recv_complete(ucp_request_t *recv_req) { @@ -277,18 +679,22 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_rma_rndv_handler, (arg, data, length, am_flags), void *arg, void *data, size_t length, unsigned am_flags) { - const ucp_rndv_rts_hdr_t *hdr = data; - ucp_worker_h worker = arg; + const uint64_t *hdr = data; + ucp_worker_h worker = arg; if (length < sizeof(*hdr)) { return UCS_ERR_MESSAGE_TRUNCATED; } - switch (hdr->hdr) { + switch (*hdr) { case UCP_RMA_RNDV_AM_PUT_RTS: return ucp_rma_rndv_handle_put_rts(worker, data, length); + case UCP_RMA_RNDV_AM_GET_REQ: + return ucp_rma_rndv_handle_get_req(worker, data, length); + case UCP_RMA_RNDV_AM_GET_RTS: + return ucp_rma_rndv_handle_get_rts(worker, data, length); default: - ucs_debug("unexpected RMA RNDV AM sub-id %" PRIu64, hdr->hdr); + ucs_debug("unexpected RMA RNDV AM sub-id %" PRIu64, *hdr); return UCS_ERR_UNSUPPORTED; } } @@ -299,13 +705,15 @@ ucp_rma_rndv_dump_packet(ucp_worker_h worker, uct_am_trace_type_t type, char *buffer, size_t max) { const ucp_rma_rndv_put_rts_hdr_t *put_rts = data; - const ucp_rndv_rts_hdr_t *hdr = data; + const ucp_rma_rndv_get_req_hdr_t *get_req = data; + const ucp_rma_rndv_get_rts_hdr_t *get_rts = data; + const uint64_t *hdr = data; if (length < sizeof(*hdr)) { return; } - switch (hdr->hdr) { + switch (*hdr) { case UCP_RMA_RNDV_AM_PUT_RTS: if (length < sizeof(*put_rts)) { return; @@ -318,8 +726,30 @@ 
ucp_rma_rndv_dump_packet(ucp_worker_h worker, uct_am_trace_type_t type, put_rts->super.sreq.req_id, put_rts->super.sreq.ep_id, ucs_memory_type_names[put_rts->mem_type]); break; + case UCP_RMA_RNDV_AM_GET_REQ: + if (length < sizeof(*get_req)) { + return; + } + + snprintf(buffer, max, "RMA_GET_REQ [src 0x%" PRIx64 + " len %zu req_id 0x%" PRIx64 " ep_id 0x%" PRIx64 + " %s]", get_req->address, get_req->size, + get_req->req.req_id, get_req->req.ep_id, + ucs_memory_type_names[get_req->mem_type]); + break; + case UCP_RMA_RNDV_AM_GET_RTS: + if (length < sizeof(*get_rts)) { + return; + } + + snprintf(buffer, max, "RMA_GET_RTS [src 0x%" PRIx64 + " len %zu sreq_id 0x%" PRIx64 " ep_id 0x%" PRIx64 + " get_req_id 0x%" PRIx64 "]", get_rts->super.address, + get_rts->super.size, get_rts->super.sreq.req_id, + get_rts->super.sreq.ep_id, get_rts->get_req_id); + break; default: - snprintf(buffer, max, "RMA_RNDV [sub-id %" PRIu64 "]", hdr->hdr); + snprintf(buffer, max, "RMA_RNDV [sub-id %" PRIu64 "]", *hdr); break; } } @@ -337,3 +767,14 @@ ucp_proto_t ucp_put_rndv_proto = { .abort = ucp_proto_rndv_rts_abort, .reset = ucp_proto_rndv_rts_reset }; + +ucp_proto_t ucp_get_rndv_proto = { + .name = "get/rndv", + .desc = UCP_PROTO_RMA_RNDV_GET_DESC, + .flags = 0, + .probe = ucp_proto_get_rndv_probe, + .query = ucp_proto_get_rndv_query, + .progress = {ucp_proto_get_rndv_progress}, + .abort = ucp_proto_get_rndv_abort, + .reset = ucp_proto_get_rndv_reset +}; From cc950cb417581c8e9ae1cc04db0465bd45163726 Mon Sep 17 00:00:00 2001 From: Thomas Vegas Date: Fri, 22 May 2026 12:52:01 +0200 Subject: [PATCH 03/15] UCP/RMA: Add RNDV GET push protocol --- src/ucp/core/ucp_request.c | 1 + src/ucp/core/ucp_request.h | 6 +- src/ucp/proto/proto.c | 1 + src/ucp/proto/proto_debug.c | 3 +- src/ucp/proto/proto_select.h | 3 + src/ucp/rma/rma_rndv.c | 242 ++++++++++++++++++++++++++++++++++- src/ucp/rndv/proto_rndv.c | 121 ++++++++++++++++-- src/ucp/rndv/proto_rndv.h | 16 ++- src/ucp/rndv/proto_rndv.inl | 11 ++ 
src/ucp/rndv/rndv.c | 17 +++ src/ucp/rndv/rndv.h | 19 +++ src/ucp/rndv/rndv_get.c | 1 + src/ucp/rndv/rndv_ppln.c | 2 +- src/ucp/rndv/rndv_rkey_ptr.c | 3 +- src/ucp/rndv/rndv_rtr.c | 63 ++++++++- 15 files changed, 479 insertions(+), 30 deletions(-) diff --git a/src/ucp/core/ucp_request.c b/src/ucp/core/ucp_request.c index 2c07b1e4e9d..1e2b40d67f9 100644 --- a/src/ucp/core/ucp_request.c +++ b/src/ucp/core/ucp_request.c @@ -52,6 +52,7 @@ static const char *ucp_request_flag_names[] = { [ucs_ilog2(UCP_REQUEST_FLAG_SUPER_VALID)] = "spr_vld", #endif [ucs_ilog2(UCP_REQUEST_FLAG_RNDV_SEND_INTERNAL)] = "rndv_snd_int", + [ucs_ilog2(UCP_REQUEST_FLAG_RNDV_RTR_REQ)] = "rndv_rtr_req", }; static ucs_memory_type_t ucp_request_get_mem_type(ucp_request_t *req) diff --git a/src/ucp/core/ucp_request.h b/src/ucp/core/ucp_request.h index 81a6c373f4f..5a50904c110 100644 --- a/src/ucp/core/ucp_request.h +++ b/src/ucp/core/ucp_request.h @@ -66,7 +66,8 @@ enum { UCP_REQUEST_DEBUG_FLAG_EXTERNAL = 0, UCP_REQUEST_FLAG_SUPER_VALID = 0, #endif - UCP_REQUEST_FLAG_RNDV_SEND_INTERNAL = UCS_BIT(26) + UCP_REQUEST_FLAG_RNDV_SEND_INTERNAL = UCS_BIT(26), + UCP_REQUEST_FLAG_RNDV_RTR_REQ = UCS_BIT(27) }; @@ -263,6 +264,9 @@ struct ucp_request { /* Remote buffer address for get/put operation */ uint64_t remote_address; + /* Remote buffer memory info for RTR_REQ */ + ucp_memory_info_t remote_mem_info; + /* Key for remote buffer operation */ ucp_rkey_h rkey; diff --git a/src/ucp/proto/proto.c b/src/ucp/proto/proto.c index fae0b9bd1f1..988fc0446c0 100644 --- a/src/ucp/proto/proto.c +++ b/src/ucp/proto/proto.c @@ -25,6 +25,7 @@ _macro(ucp_get_am_bcopy_proto) \ _macro(ucp_get_offload_bcopy_proto) \ _macro(ucp_get_offload_zcopy_proto) \ + _macro(ucp_get_rndv_push_proto) \ _macro(ucp_get_rndv_proto) \ _macro(ucp_put_am_bcopy_proto) \ _macro(ucp_put_offload_short_proto) \ diff --git a/src/ucp/proto/proto_debug.c b/src/ucp/proto/proto_debug.c index 570d8580bc7..babdabe260c 100644 --- a/src/ucp/proto/proto_debug.c 
+++ b/src/ucp/proto/proto_debug.c @@ -400,7 +400,8 @@ void ucp_proto_select_param_str(const ucp_proto_select_param_t *select_param, [ucs_ilog2(UCP_OP_ATTR_FLAG_MULTI_SEND)] = "multi", }; static const char *rndv_flag_names[] = { - [ucs_ilog2(UCP_PROTO_SELECT_OP_FLAG_PPLN_FRAG)] = "frag" + [ucs_ilog2(UCP_PROTO_SELECT_OP_FLAG_PPLN_FRAG)] = "frag", + [ucs_ilog2(UCP_PROTO_SELECT_OP_FLAG_RNDV_PUSH)] = "push" }; static const char *am_flag_names[] = { [ucs_ilog2(UCP_PROTO_SELECT_OP_FLAG_AM_EAGER)] = "egr", diff --git a/src/ucp/proto/proto_select.h b/src/ucp/proto/proto_select.h index dddce54f951..a988a52e9a7 100644 --- a/src/ucp/proto/proto_select.h +++ b/src/ucp/proto/proto_select.h @@ -32,6 +32,9 @@ * Relevant for UCP_OP_ID_RNDV_SEND and UCP_OP_ID_RNDV_RECV. */ #define UCP_PROTO_SELECT_OP_FLAG_PPLN_FRAG (UCP_PROTO_SELECT_OP_FLAGS_BASE << 1) +/* Select only push-based rendezvous receive protocols. */ +#define UCP_PROTO_SELECT_OP_FLAG_RNDV_PUSH (UCP_PROTO_SELECT_OP_FLAGS_BASE << 3) + /* Select eager/rendezvous protocol for Active Message sends. * Relevant for UCP_OP_ID_AM_SEND and UCP_OP_ID_AM_SEND_REPLY. 
*/ diff --git a/src/ucp/rma/rma_rndv.c b/src/ucp/rma/rma_rndv.c index a278994fe2a..76c2dfddae1 100644 --- a/src/ucp/rma/rma_rndv.c +++ b/src/ucp/rma/rma_rndv.c @@ -14,15 +14,17 @@ #include #include #include +#include #include #include #include -#define UCP_PROTO_RMA_RNDV_PUT_RTS_NAME "RMA PUT RTS" -#define UCP_PROTO_RMA_RNDV_GET_REQ_NAME "RMA GET REQ" -#define UCP_PROTO_RMA_RNDV_PUT_DESC "RMA PUT rendezvous" -#define UCP_PROTO_RMA_RNDV_GET_DESC "RMA GET rendezvous" +#define UCP_PROTO_RMA_RNDV_PUT_RTS_NAME "RMA PUT RTS" +#define UCP_PROTO_RMA_RNDV_GET_REQ_NAME "RMA GET REQ" +#define UCP_PROTO_RMA_RNDV_PUT_DESC "RMA PUT rendezvous" +#define UCP_PROTO_RMA_RNDV_GET_DESC "RMA GET rendezvous" +#define UCP_PROTO_RMA_RNDV_GET_PUSH_DESC "RMA GET rendezvous push" enum { @@ -176,9 +178,10 @@ ucp_proto_put_rndv_probe(const ucp_proto_init_params_t *init_params) .ctrl_msg_name = UCP_PROTO_RMA_RNDV_PUT_RTS_NAME, .md_map = 0 }; - ucp_proto_rndv_ctrl_priv_t rpriv; + ucp_proto_rndv_ctrl_priv_t rpriv = {0}; if (!ucp_proto_init_check_op(init_params, UCS_BIT(UCP_OP_ID_PUT)) || + ucp_proto_rndv_init_params_is_ppln_frag(init_params) || (sel_param->dt_class != UCP_DATATYPE_CONTIG) || (init_params->rkey_config_key == NULL)) { return; @@ -221,7 +224,7 @@ static void ucp_proto_rma_rndv_query(const ucp_proto_query_params_t *params, attr->is_estimation = 1; attr->max_msg_length = remote_attr.max_msg_length; - attr->lane_map = UCS_BIT(rpriv->lane); + attr->lane_map = UCS_BIT(rpriv->lane) | remote_attr.lane_map; ucs_snprintf_safe(attr->desc, sizeof(attr->desc), "%s using %s", desc, remote_attr.desc); @@ -241,6 +244,14 @@ static void ucp_proto_get_rndv_query(const ucp_proto_query_params_t *params, ucp_proto_rma_rndv_query(params, attr, UCP_PROTO_RMA_RNDV_GET_DESC); } +static void +ucp_proto_get_rndv_push_query(const ucp_proto_query_params_t *params, + ucp_proto_query_attr_t *attr) +{ + ucp_proto_rma_rndv_query(params, attr, + UCP_PROTO_RMA_RNDV_GET_PUSH_DESC); +} + static size_t 
ucp_proto_get_rndv_req_pack(void *dest, void *arg) { ucp_request_t *req = arg; @@ -345,9 +356,10 @@ ucp_proto_get_rndv_probe(const ucp_proto_init_params_t *init_params) .ctrl_msg_name = UCP_PROTO_RMA_RNDV_GET_REQ_NAME, .md_map = 0 }; - ucp_proto_rndv_ctrl_priv_t rpriv; + ucp_proto_rndv_ctrl_priv_t rpriv = {0}; if (!ucp_proto_init_check_op(init_params, UCS_BIT(UCP_OP_ID_GET)) || + ucp_proto_rndv_init_params_is_ppln_frag(init_params) || (sel_param->dt_class != UCP_DATATYPE_CONTIG) || (init_params->rkey_config_key == NULL)) { return; @@ -361,6 +373,116 @@ ucp_proto_get_rndv_probe(const ucp_proto_init_params_t *init_params) ucp_proto_rndv_ctrl_probe(¶ms, &rpriv, sizeof(rpriv)); } +static void +ucp_proto_get_rndv_push_add_variant( + const ucp_proto_init_params_t *init_params, + const ucp_proto_select_param_t *select_param, + ucp_worker_cfg_index_t rkey_cfg_index, ucp_lane_index_t lane, + ucp_proto_init_elem_t *proto, const void *proto_priv) +{ + ucp_context_h context = init_params->worker->context; + const ucp_proto_perf_t *perf_elems[1]; + ucp_proto_rndv_ctrl_priv_t rpriv = {0}; + ucp_proto_init_params_t variant_params; + UCS_STRING_BUFFER_ONSTACK(perf_name, 128); + ucp_proto_perf_t *perf; + size_t cfg_thresh; + ucs_status_t status; + + perf_elems[0] = proto->perf; + ucs_string_buffer_appendf(&perf_name, "%s" UCP_PROTO_PERF_NODE_NEW_LINE + "%s", UCP_PROTO_RNDV_RTR_REQ_NAME, + ucp_proto_perf_name(proto->perf)); + status = ucp_proto_perf_aggregate(ucs_string_buffer_cstr(&perf_name), + perf_elems, 1, &perf); + if (status != UCS_OK) { + return; + } else if (ucp_proto_perf_is_empty(perf)) { + ucp_proto_perf_destroy(perf); + return; + } + + rpriv.lane = lane; + + variant_params = *init_params; + variant_params.rkey_cfg_index = rkey_cfg_index; + ucp_proto_rndv_set_variant_config(&variant_params, proto, select_param, + proto_priv, &rpriv.remote_proto_config); + + cfg_thresh = context->config.ext.zcopy_thresh; + if (proto->cfg_thresh != UCS_MEMUNITS_AUTO) { + cfg_thresh = 
proto->cfg_thresh; + } + + ucp_proto_select_add_proto(init_params, cfg_thresh, 6, perf, &rpriv, + sizeof(rpriv)); +} + +static void +ucp_proto_get_rndv_push_probe(const ucp_proto_init_params_t *init_params) +{ + ucp_worker_h worker = init_params->worker; + const ucp_proto_select_param_t *sel_param = init_params->select_param; + const ucp_proto_select_init_protocols_t *proto_init; + ucp_proto_select_param_t rndv_sel_param; + ucp_worker_cfg_index_t rkey_cfg_index; + ucp_proto_select_elem_t *select_elem; + ucp_proto_select_t *proto_select; + ucp_proto_init_elem_t *proto; + ucp_memory_info_t mem_info; + ucp_lane_index_t lane; + const void *priv; + + if (!ucp_proto_init_check_op(init_params, UCS_BIT(UCP_OP_ID_GET)) || + ucp_proto_rndv_init_params_is_ppln_frag(init_params) || + (sel_param->dt_class != UCP_DATATYPE_CONTIG) || + (init_params->rkey_config_key == NULL)) { + return; + } + + if (UCP_MEM_IS_HOST(sel_param->mem_type) && + UCP_MEM_IS_HOST(init_params->rkey_config_key->mem_type)) { + return; + } + + lane = ucp_proto_rndv_find_ctrl_lane(init_params); + if (lane == UCP_NULL_LANE) { + return; + } + + mem_info = ucp_proto_common_select_param_mem_info(sel_param); + ucp_proto_select_param_init(&rndv_sel_param, UCP_OP_ID_RNDV_RECV, 0, + UCP_PROTO_SELECT_OP_FLAG_RNDV_PUSH, + UCP_DATATYPE_CONTIG, &mem_info, 1); + + proto_select = ucp_proto_select_get(worker, init_params->ep_cfg_index, + init_params->rkey_cfg_index, + &rkey_cfg_index); + if (proto_select == NULL) { + return; + } + + select_elem = ucp_proto_select_lookup_slow(worker, proto_select, 1, + init_params->ep_cfg_index, + rkey_cfg_index, + &rndv_sel_param); + if (select_elem == NULL) { + return; + } + + proto_init = &select_elem->proto_init; + ucs_array_for_each(proto, &proto_init->protocols) { + if (ucp_proto_id_field(proto->proto_id, flags) & + UCP_PROTO_FLAG_INVALID) { + continue; + } + + priv = &ucs_array_elem(&proto_init->priv_buf, proto->priv_offset); + ucp_proto_get_rndv_push_add_variant(init_params, 
&rndv_sel_param, + rkey_cfg_index, lane, proto, priv); + } +} + static void ucp_proto_get_rndv_abort(ucp_request_t *req, ucs_status_t status) { if (req->id != UCS_PTR_MAP_KEY_INVALID) { @@ -539,6 +661,101 @@ static void ucp_rma_rndv_get_recv_complete(ucp_request_t *recv_req) ucp_request_put(recv_req); } +static ucs_status_t ucp_proto_get_rndv_push_init(ucp_request_t *get_req, + ucp_request_t **rndv_req_p) +{ + const ucp_proto_rndv_ctrl_priv_t *rpriv = get_req->send.proto_config->priv; + ucp_worker_h worker = get_req->send.ep->worker; + ucp_rkey_config_t *rkey_config; + ucp_request_t *recv_req; + ucp_request_t *rndv_req; + uint8_t UCS_V_UNUSED sg_count; + ucp_memory_info_t mem_info; + ucs_status_t status; + uint64_t address; + size_t length; + + status = ucp_ep_resolve_remote_id(get_req->send.ep, rpriv->lane); + if (status != UCS_OK) { + return status; + } + + status = ucp_ep_rma_handle_fence(get_req->send.ep, get_req, + UCS_BIT(rpriv->lane)); + if (status != UCS_OK) { + return status; + } + + address = get_req->send.rma.remote_addr; + rkey_config = ucp_rkey_config(worker, get_req->send.rma.rkey); + mem_info.type = get_req->send.rma.rkey->mem_type; + mem_info.sys_dev = rkey_config->key.sys_dev; + length = get_req->send.state.dt_iter.length; + get_req->send.buffer = + get_req->send.state.dt_iter.type.contig.buffer; + get_req->send.length = length; + + recv_req = ucp_request_get(worker); + if (recv_req == NULL) { + return UCS_ERR_NO_MEMORY; + } + + rndv_req = ucp_request_get(worker); + if (rndv_req == NULL) { + ucp_request_put(recv_req); + return UCS_ERR_NO_MEMORY; + } + + get_req->flags |= UCP_REQUEST_FLAG_PROTO_INITIALIZED; + recv_req->flags = UCP_REQUEST_FLAG_RNDV_RECV_INTERNAL; + recv_req->recv.worker = worker; + recv_req->recv.op_attr = 0; + recv_req->recv.remote_req_id = UCS_PTR_MAP_KEY_INVALID; + recv_req->recv.rndv.complete_cb = ucp_rma_rndv_get_recv_complete; + recv_req->status = UCS_OK; + ucp_request_set_super(recv_req, get_req); + + 
UCS_PROFILE_CALL_VOID(ucp_datatype_iter_move, &recv_req->recv.dt_iter, + &get_req->send.state.dt_iter, length, &sg_count); + + ucp_proto_request_send_init(rndv_req, get_req->send.ep, + UCP_REQUEST_FLAG_RNDV_RTR_REQ); + ucp_request_set_super(rndv_req, recv_req); + rndv_req->send.rndv.remote_req_id = UCS_PTR_MAP_KEY_INVALID; + rndv_req->send.rndv.remote_address = address; + rndv_req->send.rndv.remote_mem_info = mem_info; + rndv_req->send.rndv.rkey = NULL; + rndv_req->send.rndv.offset = 0; + + UCS_PROFILE_CALL_VOID(ucp_datatype_iter_move, + &rndv_req->send.state.dt_iter, + &recv_req->recv.dt_iter, length, &sg_count); + ucp_proto_request_set_proto(rndv_req, &rpriv->remote_proto_config, length); + + *rndv_req_p = rndv_req; + return UCS_OK; +} + +static ucs_status_t ucp_proto_get_rndv_push_progress(uct_pending_req_t *self) +{ + ucp_request_t *get_req = ucs_container_of(self, ucp_request_t, send.uct); + ucp_request_t *rndv_req; + ucs_status_t status; + + if (get_req->flags & UCP_REQUEST_FLAG_PROTO_INITIALIZED) { + return UCS_OK; + } + + status = ucp_proto_get_rndv_push_init(get_req, &rndv_req); + if (status != UCS_OK) { + ucp_proto_request_abort(get_req, status); + return UCS_OK; + } + + ucp_request_send(rndv_req); + return UCS_OK; +} + static ucs_status_t ucp_rma_rndv_handle_get_req(ucp_worker_h worker, void *data, size_t length) { @@ -778,3 +995,14 @@ ucp_proto_t ucp_get_rndv_proto = { .abort = ucp_proto_get_rndv_abort, .reset = ucp_proto_get_rndv_reset }; + +ucp_proto_t ucp_get_rndv_push_proto = { + .name = "get/rndv/push", + .desc = UCP_PROTO_RMA_RNDV_GET_PUSH_DESC, + .flags = 0, + .probe = ucp_proto_get_rndv_push_probe, + .query = ucp_proto_get_rndv_push_query, + .progress = {ucp_proto_get_rndv_push_progress}, + .abort = ucp_proto_get_rndv_abort, + .reset = ucp_proto_get_rndv_reset +}; diff --git a/src/ucp/rndv/proto_rndv.c b/src/ucp/rndv/proto_rndv.c index 2f15f13cfdd..fecdea52893 100644 --- a/src/ucp/rndv/proto_rndv.c +++ b/src/ucp/rndv/proto_rndv.c @@ -872,16 
+872,13 @@ UCS_PROFILE_FUNC_VOID(ucp_proto_rndv_receive_start, } UCS_PROFILE_FUNC(ucs_status_t, ucp_proto_rndv_send_start, - (worker, req, op_attr_mask, rtr, header_length, sg_count), + (worker, req, op_attr_mask, rtr, rkey_buffer, rkey_length, + sg_count), ucp_worker_h worker, ucp_request_t *req, uint32_t op_attr_mask, - const ucp_rndv_rtr_hdr_t *rtr, size_t header_length, - uint8_t sg_count) + const ucp_rndv_rtr_hdr_t *rtr, const void *rkey_buffer, + size_t rkey_length, uint8_t sg_count) { ucs_status_t status; - size_t rkey_length; - - ucs_assert(header_length >= sizeof(*rtr)); - rkey_length = header_length - sizeof(*rtr); ucp_proto_rndv_check_rkey_length(rtr->address, rkey_length, "rtr"); req->send.rndv.remote_address = rtr->address; @@ -890,7 +887,7 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_proto_rndv_send_start, ucs_assert(rtr->size == req->send.state.dt_iter.length); status = ucp_proto_rndv_send_reply(worker, req, UCP_OP_ID_RNDV_SEND, - op_attr_mask, rtr->size, rtr + 1, + op_attr_mask, rtr->size, rkey_buffer, rkey_length, sg_count); if (status != UCS_OK) { return status; @@ -916,6 +913,100 @@ static void ucp_proto_rndv_send_complete_one(void *request, ucs_status_t status, ucp_proto_request_zcopy_complete(req, status); } +static void +ucp_proto_rndv_rtr_req_send_complete(void *request, + ucs_status_t UCS_V_UNUSED status, + void *UCS_V_UNUSED user_data) +{ + ucp_request_t *req = (ucp_request_t*)request - 1; + + ucp_request_put(req); +} + +static void +ucp_proto_rndv_rtr_req_send_atp_err(ucp_ep_h ep, + ucs_ptr_map_key_t remote_req_id, + ucs_status_t status) +{ + ucp_request_t *req; + + req = ucp_request_get(ep->worker); + if (req == NULL) { + ucs_error("failed to allocate RNDV RTR_REQ error ATP"); + return; + } + + ucp_proto_request_send_init(req, ep, 0); + ucp_rndv_req_send_ack(req, 0, remote_req_id, status, UCP_AM_ID_RNDV_ATP, + "send_atp_err"); +} + +static void +ucp_proto_rndv_rtr_req_sreq_init(ucp_ep_h ep, ucp_request_t *req, + const ucp_rndv_rtr_req_hdr_t 
*rtr_req) +{ + const ucp_rndv_rtr_hdr_t *rtr = &rtr_req->super; + + ucp_proto_request_send_init(req, ep, + UCP_REQUEST_FLAG_RNDV_SEND_INTERNAL); + ucp_request_set_callback(req, send.cb, + ucp_proto_rndv_rtr_req_send_complete); + req->send.buffer = (void*)(uintptr_t)rtr_req->address; + req->send.length = rtr->size; + req->send.mem_type = rtr_req->mem_type; + req->send.rndv.remote_req_id = rtr->rreq_id; + req->send.rndv.rkey = NULL; + req->send.rndv.remote_address = rtr_req->address; + req->send.state.dt_iter.dt_class = UCP_DATATYPE_CONTIG; + req->send.state.dt_iter.mem_info.type = rtr_req->mem_type; + req->send.state.dt_iter.mem_info.sys_dev = rtr_req->sys_dev; + req->send.state.dt_iter.length = rtr->size; + req->send.state.dt_iter.offset = 0; + req->send.state.dt_iter.type.contig.buffer = + (void*)(uintptr_t)rtr_req->address; + req->send.state.dt_iter.type.contig.memh = NULL; +} + +static ucs_status_t +ucp_proto_rndv_handle_rtr_req(ucp_worker_h worker, void *data, size_t length) +{ + const ucp_rndv_rtr_req_hdr_t *rtr_req = data; + const void *rkey_buffer; + ucp_request_t *req; + ucs_status_t status; + ucp_ep_h ep; + + if (length < sizeof(*rtr_req)) { + return UCS_ERR_MESSAGE_TRUNCATED; + } + + UCP_WORKER_GET_EP_BY_ID(&ep, worker, rtr_req->req.ep_id, return UCS_OK, + "RNDV RTR_REQ"); + + req = ucp_request_get(worker); + if (req == NULL) { + ucs_error("failed to allocate RNDV RTR_REQ send request"); + ucp_proto_rndv_rtr_req_send_atp_err(ep, rtr_req->super.rreq_id, + UCS_ERR_NO_MEMORY); + return UCS_OK; + } + + rkey_buffer = rtr_req + 1; + ucp_proto_rndv_rtr_req_sreq_init(ep, req, rtr_req); + status = ucp_proto_rndv_send_start(worker, req, 0, &rtr_req->super, + rkey_buffer, + length - sizeof(*rtr_req), 1); + if (status != UCS_OK) { + ucp_proto_rndv_rtr_req_send_atp_err(ep, rtr_req->super.rreq_id, + status); + ucp_datatype_iter_cleanup(&req->send.state.dt_iter, 1, + UCP_DT_MASK_ALL); + ucp_request_put(req); + } + + return UCS_OK; +} + ucs_status_t 
ucp_proto_rndv_handle_rtr(void *arg, void *data, size_t length, unsigned flags) { @@ -927,6 +1018,14 @@ ucp_proto_rndv_handle_rtr(void *arg, void *data, size_t length, unsigned flags) ucs_status_t status; uint8_t sg_count; + if (length < sizeof(*rtr)) { + return UCS_ERR_MESSAGE_TRUNCATED; + } + + if (rtr->sreq_id == UCS_PTR_MAP_KEY_INVALID) { + return ucp_proto_rndv_handle_rtr_req(worker, data, length); + } + UCP_SEND_REQUEST_GET_BY_ID(&req, worker, rtr->sreq_id, 0, return UCS_OK, "RTR %p", rtr); @@ -954,7 +1053,8 @@ ucp_proto_rndv_handle_rtr(void *arg, void *data, size_t length, unsigned flags) sg_count = select_param->sg_count; status = ucp_proto_rndv_send_start(worker, req, op_attr_mask, rtr, - length, sg_count); + rtr + 1, length - sizeof(*rtr), + sg_count); if (status != UCS_OK) { goto err_request_fail; } @@ -984,7 +1084,8 @@ ucp_proto_rndv_handle_rtr(void *arg, void *data, size_t length, unsigned flags) status = ucp_proto_rndv_send_start(worker, freq, op_attr_mask | UCP_OP_ATTR_FLAG_MULTI_SEND, - rtr, length, 1); + rtr, rtr + 1, + length - sizeof(*rtr), 1); if (status != UCS_OK) { goto err_put_freq; } diff --git a/src/ucp/rndv/proto_rndv.h b/src/ucp/rndv/proto_rndv.h index 8299b740491..66cc69f30dd 100644 --- a/src/ucp/rndv/proto_rndv.h +++ b/src/ucp/rndv/proto_rndv.h @@ -13,10 +13,11 @@ /* Names of rendezvous control messages */ -#define UCP_PROTO_RNDV_RTS_NAME "RTS" -#define UCP_PROTO_RNDV_RTR_NAME "RTR" -#define UCP_PROTO_RNDV_ATS_NAME "ATS" -#define UCP_PROTO_RNDV_ATP_NAME "ATP" +#define UCP_PROTO_RNDV_RTS_NAME "RTS" +#define UCP_PROTO_RNDV_RTR_NAME "RTR" +#define UCP_PROTO_RNDV_RTR_REQ_NAME "RTR_REQ" +#define UCP_PROTO_RNDV_ATS_NAME "ATS" +#define UCP_PROTO_RNDV_ATP_NAME "ATP" /* Mask of rendezvous operations */ @@ -201,6 +202,13 @@ void ucp_proto_rndv_receive_start(ucp_worker_h worker, ucp_request_t *recv_req, ucs_status_t ucp_proto_rndv_handle_rtr(void *arg, void *data, size_t length, unsigned flags); +ucs_status_t ucp_proto_rndv_send_start(ucp_worker_h 
worker, + ucp_request_t *req, + uint32_t op_attr_mask, + const ucp_rndv_rtr_hdr_t *rtr, + const void *rkey_buffer, + size_t rkey_length, + uint8_t sg_count); ucs_status_t ucp_proto_rndv_rtr_handle_atp(void *arg, void *data, size_t length, unsigned flags); diff --git a/src/ucp/rndv/proto_rndv.inl b/src/ucp/rndv/proto_rndv.inl index 14fedcc19c3..9970e421ffe 100644 --- a/src/ucp/rndv/proto_rndv.inl +++ b/src/ucp/rndv/proto_rndv.inl @@ -182,6 +182,10 @@ static UCS_F_ALWAYS_INLINE ucs_status_t ucp_proto_rndv_frag_request_alloc( ucp_proto_request_send_init(freq, req->send.ep, UCP_REQUEST_FLAG_RNDV_FRAG); ucp_request_set_super(freq, req); + if (req->flags & UCP_REQUEST_FLAG_RNDV_RTR_REQ) { + freq->flags |= UCP_REQUEST_FLAG_RNDV_RTR_REQ; + freq->send.rndv.remote_mem_info = req->send.rndv.remote_mem_info; + } *freq_p = freq; return UCS_OK; @@ -346,6 +350,13 @@ ucp_proto_rndv_init_params_is_ppln_frag(const ucp_proto_init_params_t *params) UCP_PROTO_SELECT_OP_FLAG_PPLN_FRAG; } +static UCS_F_ALWAYS_INLINE int +ucp_proto_rndv_init_params_is_push(const ucp_proto_init_params_t *params) +{ + return ucp_proto_select_op_flags(params->select_param) & + UCP_PROTO_SELECT_OP_FLAG_RNDV_PUSH; +} + static UCS_F_ALWAYS_INLINE int ucp_proto_rndv_op_check(const ucp_proto_init_params_t *params, ucp_operation_id_t op_id, int support_ppln) diff --git a/src/ucp/rndv/rndv.c b/src/ucp/rndv/rndv.c index 24ce90266e1..e00a9aedaaa 100644 --- a/src/ucp/rndv/rndv.c +++ b/src/ucp/rndv/rndv.c @@ -2437,6 +2437,7 @@ static void ucp_rndv_dump(ucp_worker_h worker, uct_am_trace_type_t type, UCS_STRING_BUFFER_FIXED(strb, buffer, max); const ucp_rndv_rts_hdr_t *rndv_rts_hdr = data; const ucp_rndv_rtr_hdr_t *rndv_rtr_hdr = data; + const ucp_rndv_rtr_req_hdr_t *rtr_req = data; const ucp_request_data_hdr_t *rndv_data = data; const ucp_rndv_ack_hdr_t *ack_hdr = data; const ucp_reply_hdr_t *rep_hdr = data; @@ -2477,6 +2478,22 @@ static void ucp_rndv_dump(ucp_worker_h worker, uct_am_trace_type_t type, } break; case 
UCP_AM_ID_RNDV_RTR: + if ((rndv_rtr_hdr->sreq_id == UCS_PTR_MAP_KEY_INVALID) && + (length >= sizeof(*rtr_req))) { + ucs_string_buffer_appendf( + &strb, "RNDV_RTR_REQ src 0x%" PRIx64 " dst 0x%" PRIx64 + " rreq_id 0x%" PRIx64 " ep_id 0x%" PRIx64 " size %zu" + " offset %zu %s", rtr_req->address, + rtr_req->super.address, rtr_req->super.rreq_id, + rtr_req->req.ep_id, rtr_req->super.size, + rtr_req->super.offset, + ucs_memory_type_names[rtr_req->mem_type]); + if (rtr_req->super.address != 0) { + ucp_rndv_dump_rkey(rtr_req + 1, data_end, &strb); + } + break; + } + ucs_string_buffer_appendf(&strb, "RNDV_RTR sreq_id 0x%" PRIx64 " rreq_id 0x%" PRIx64 " address 0x%" PRIx64 diff --git a/src/ucp/rndv/rndv.h b/src/ucp/rndv/rndv.h index 3a7b9d6a1a7..3ff6ca604a1 100644 --- a/src/ucp/rndv/rndv.h +++ b/src/ucp/rndv/rndv.h @@ -63,6 +63,25 @@ typedef struct { } UCS_S_PACKED ucp_rndv_rtr_hdr_t; +/* + * RTR which requests the peer to create an internal sender. + */ +typedef struct { + /* Base RTR header; sreq_id is UCS_PTR_MAP_KEY_INVALID */ + ucp_rndv_rtr_hdr_t super; + + /* Endpoint on the RTR initiator side */ + ucp_request_hdr_t req; + + /* Address of the source buffer on the peer */ + uint64_t address; + + /* Memory locality of the source buffer on the peer */ + ucs_sys_device_t sys_dev; + ucs_memory_type_t mem_type; +} UCS_S_PACKED ucp_rndv_rtr_req_hdr_t; + + /* * RNDV_ATS/RNDV_ATP with size field */ diff --git a/src/ucp/rndv/rndv_get.c b/src/ucp/rndv/rndv_get.c index 9a8fb8c28f4..f9acf40c5f5 100644 --- a/src/ucp/rndv/rndv_get.c +++ b/src/ucp/rndv/rndv_get.c @@ -61,6 +61,7 @@ ucp_proto_rndv_get_common_probe(const ucp_proto_init_params_t *init_params, ucs_status_t status; if ((init_params->select_param->dt_class != UCP_DATATYPE_CONTIG) || + ucp_proto_rndv_init_params_is_push(init_params) || !ucp_proto_rndv_op_check(init_params, UCP_OP_ID_RNDV_RECV, support_ppln)) { return; diff --git a/src/ucp/rndv/rndv_ppln.c b/src/ucp/rndv/rndv_ppln.c index 17fc9e8b16b..9b407c4a669 100644 --- 
a/src/ucp/rndv/rndv_ppln.c +++ b/src/ucp/rndv/rndv_ppln.c @@ -84,7 +84,7 @@ ucp_proto_rndv_ppln_probe(const ucp_proto_init_params_t *init_params) /* Select a protocol for rndv recv */ sel_param = *select_param; - sel_param.op_id_flags = ucp_proto_select_op_id(select_param) | + sel_param.op_id_flags = select_param->op_id_flags | UCP_PROTO_SELECT_OP_FLAG_PPLN_FRAG; sel_param.op_attr = ucp_proto_select_op_attr_pack( UCP_OP_ATTR_FLAG_MULTI_SEND, UCP_PROTO_SELECT_OP_ATTR_MASK); diff --git a/src/ucp/rndv/rndv_rkey_ptr.c b/src/ucp/rndv/rndv_rkey_ptr.c index 7878e44234c..999c2e04ae3 100644 --- a/src/ucp/rndv/rndv_rkey_ptr.c +++ b/src/ucp/rndv/rndv_rkey_ptr.c @@ -101,7 +101,8 @@ ucp_proto_rndv_rkey_ptr_probe(const ucp_proto_init_params_t *init_params) ucp_proto_perf_t *perf; ucs_status_t status; - if (!ucp_proto_rndv_op_check(init_params, UCP_OP_ID_RNDV_RECV, 0) || + if (ucp_proto_rndv_init_params_is_push(init_params) || + !ucp_proto_rndv_op_check(init_params, UCP_OP_ID_RNDV_RECV, 0) || !ucp_proto_common_init_check_err_handling(¶ms.super) || (ucp_proto_select_op_flags(params.super.super.select_param) & UCP_PROTO_SELECT_OP_FLAG_RESUME)) { diff --git a/src/ucp/rndv/rndv_rtr.c b/src/ucp/rndv/rndv_rtr.c index 11493a511c9..bfe08c0100e 100644 --- a/src/ucp/rndv/rndv_rtr.c +++ b/src/ucp/rndv/rndv_rtr.c @@ -15,6 +15,8 @@ #include #include +#include + /** * RTR protocol callback, which is called when all incoming data is filled to @@ -38,6 +40,10 @@ typedef struct { ucs_sys_device_t frag_sys_dev; } ucp_proto_rndv_rtr_mtype_priv_t; + +static size_t ucp_proto_rndv_rtr_req_pack(void *dest, void *arg); + + static UCS_F_ALWAYS_INLINE void ucp_proto_rtr_common_request_init(ucp_request_t *req) { @@ -45,18 +51,37 @@ ucp_proto_rtr_common_request_init(ucp_request_t *req) req->send.state.completed_size = 0; } +static size_t ucp_proto_rndv_rtr_max_size(ucp_request_t *req) +{ + const ucp_proto_rndv_rtr_priv_t *rpriv = req->send.proto_config->priv; + + return sizeof(ucp_rndv_rtr_hdr_t) + 
rpriv->super.packed_rkey_size; +} + +static size_t ucp_proto_rndv_rtr_req_max_size(ucp_request_t *req) +{ + return ucp_proto_rndv_rtr_max_size(req) + + sizeof(ucp_rndv_rtr_req_hdr_t) - sizeof(ucp_rndv_rtr_hdr_t); +} + static ucs_status_t ucp_proto_rndv_rtr_common_send(ucp_request_t *req) { const ucp_proto_rndv_rtr_priv_t *rpriv = req->send.proto_config->priv; ucp_worker_h UCS_V_UNUSED worker = req->send.ep->worker; + uct_pack_callback_t pack_cb = rpriv->pack_cb; size_t max_rtr_size; ucs_status_t status; - max_rtr_size = sizeof(ucp_rndv_rtr_hdr_t) + rpriv->super.packed_rkey_size; - status = ucp_proto_am_bcopy_single_progress(req, UCP_AM_ID_RNDV_RTR, - rpriv->super.lane, - rpriv->pack_cb, req, - max_rtr_size, NULL, 0); + if (req->flags & UCP_REQUEST_FLAG_RNDV_RTR_REQ) { + pack_cb = ucp_proto_rndv_rtr_req_pack; + max_rtr_size = ucp_proto_rndv_rtr_req_max_size(req); + } else { + max_rtr_size = ucp_proto_rndv_rtr_max_size(req); + } + + status = ucp_proto_am_bcopy_single_progress(req, UCP_AM_ID_RNDV_RTR, + rpriv->super.lane, pack_cb, + req, max_rtr_size, NULL, 0); return status; } @@ -125,6 +150,29 @@ static size_t ucp_proto_rndv_rtr_pack_with_rkey(void *dest, void *arg) return sizeof(*rtr) + rkey_size; } +static size_t ucp_proto_rndv_rtr_req_pack(void *dest, void *arg) +{ + ucp_request_t *req = arg; + ucp_rndv_rtr_req_hdr_t *rtr_req = dest; + const ucp_proto_rndv_rtr_priv_t *rpriv = req->send.proto_config->priv; + size_t rkey_size, packed_size; + + packed_size = rpriv->pack_cb(&rtr_req->super, req); + rkey_size = packed_size - sizeof(rtr_req->super); + if (rkey_size > 0) { + memmove(rtr_req + 1, &rtr_req->super + 1, rkey_size); + } + + rtr_req->super.sreq_id = UCS_PTR_MAP_KEY_INVALID; + rtr_req->req.ep_id = ucp_send_request_get_ep_remote_id(req); + rtr_req->req.req_id = UCS_PTR_MAP_KEY_INVALID; + rtr_req->address = req->send.rndv.remote_address; + rtr_req->sys_dev = req->send.rndv.remote_mem_info.sys_dev; + rtr_req->mem_type = req->send.rndv.remote_mem_info.type; + + 
return sizeof(*rtr_req) + rkey_size; +} + static ucs_status_t ucp_proto_rndv_rtr_progress(uct_pending_req_t *self) { ucp_request_t *req = ucs_container_of(self, ucp_request_t, send.uct); @@ -496,6 +544,11 @@ ucs_status_t ucp_proto_rndv_rtr_handle_atp(void *arg, void *data, size_t length, UCP_SEND_REQUEST_GET_BY_ID(&req, worker, atp->super.req_id, 0, return UCS_OK, "ATP %p", atp); + if (atp->super.status != UCS_OK) { + req->send.proto_config->proto->abort(req, atp->super.status); + return UCS_OK; + } + if (!ucp_proto_common_frag_complete(req, atp->size, "rndv_atp")) { return UCS_OK; } From 0e3b87b3e4e161b717107f2f2a33d8fa871d8397 Mon Sep 17 00:00:00 2001 From: Thomas Vegas Date: Fri, 22 May 2026 16:36:00 +0300 Subject: [PATCH 04/15] UCP/RMA: Cleanup code --- src/ucp/proto/proto.h | 2 +- src/ucp/proto/proto_select.c | 13 +++++++++++-- src/ucp/rma/rma_rndv.c | 25 ++++++++++++------------- 3 files changed, 24 insertions(+), 16 deletions(-) diff --git a/src/ucp/proto/proto.h b/src/ucp/proto/proto.h index 2600a07e665..c17a23bbfe0 100644 --- a/src/ucp/proto/proto.h +++ b/src/ucp/proto/proto.h @@ -23,7 +23,7 @@ /* Maximal number of protocols in total */ -#define UCP_PROTO_MAX_COUNT 64 +#define UCP_PROTO_MAX_COUNT 65 /* Special value for non-existent protocol */ diff --git a/src/ucp/proto/proto_select.c b/src/ucp/proto/proto_select.c index ec66179280c..e1aaa0ae86f 100644 --- a/src/ucp/proto/proto_select.c +++ b/src/ucp/proto/proto_select.c @@ -550,9 +550,18 @@ ucp_proto_select_lookup_slow(ucp_worker_h worker, return NULL; } - /* add to hash after initializing the temp element, since calling - * ucp_proto_select_elem_init() can recursively modify the hash + /* Add to hash after initializing the temp element, since calling + * ucp_proto_select_elem_init() can recursively modify the hash. + * Re-check the key because recursive lookup may have initialized this + * exact selection already. 
*/ + khiter = kh_get(ucp_proto_select_hash, proto_select->hash, key.u64); + if (khiter != kh_end(proto_select->hash)) { + ucp_proto_select_elem_cleanup(&tmp_select_elem); + select_elem = &kh_value(proto_select->hash, khiter); + goto out; + } + khiter = kh_put(ucp_proto_select_hash, proto_select->hash, key.u64, &khret); ucs_assert_always(khret == UCS_KH_PUT_BUCKET_EMPTY); diff --git a/src/ucp/rma/rma_rndv.c b/src/ucp/rma/rma_rndv.c index 76c2dfddae1..ff79f77e077 100644 --- a/src/ucp/rma/rma_rndv.c +++ b/src/ucp/rma/rma_rndv.c @@ -20,11 +20,10 @@ #include -#define UCP_PROTO_RMA_RNDV_PUT_RTS_NAME "RMA PUT RTS" -#define UCP_PROTO_RMA_RNDV_GET_REQ_NAME "RMA GET REQ" -#define UCP_PROTO_RMA_RNDV_PUT_DESC "RMA PUT rendezvous" -#define UCP_PROTO_RMA_RNDV_GET_DESC "RMA GET rendezvous" -#define UCP_PROTO_RMA_RNDV_GET_PUSH_DESC "RMA GET rendezvous push" +#define UCP_PROTO_RMA_RNDV_PUT_RTS_NAME "RMA PUT RTS" +#define UCP_PROTO_RMA_RNDV_GET_REQ_NAME "RMA GET REQ" +#define UCP_PROTO_RMA_RNDV_DESC "rendezvous" +#define UCP_PROTO_RMA_RNDV_PUSH_DESC "rendezvous push" enum { @@ -228,20 +227,20 @@ static void ucp_proto_rma_rndv_query(const ucp_proto_query_params_t *params, ucs_snprintf_safe(attr->desc, sizeof(attr->desc), "%s using %s", desc, remote_attr.desc); - ucs_snprintf_safe(attr->config, sizeof(attr->config), "ctrl lane %u, %s", - rpriv->lane, remote_attr.config); + ucs_snprintf_safe(attr->config, sizeof(attr->config), "%s", + remote_attr.config); } static void ucp_proto_put_rndv_query(const ucp_proto_query_params_t *params, ucp_proto_query_attr_t *attr) { - ucp_proto_rma_rndv_query(params, attr, UCP_PROTO_RMA_RNDV_PUT_DESC); + ucp_proto_rma_rndv_query(params, attr, UCP_PROTO_RMA_RNDV_DESC); } static void ucp_proto_get_rndv_query(const ucp_proto_query_params_t *params, ucp_proto_query_attr_t *attr) { - ucp_proto_rma_rndv_query(params, attr, UCP_PROTO_RMA_RNDV_GET_DESC); + ucp_proto_rma_rndv_query(params, attr, UCP_PROTO_RMA_RNDV_DESC); } static void @@ -249,7 +248,7 @@ 
ucp_proto_get_rndv_push_query(const ucp_proto_query_params_t *params, ucp_proto_query_attr_t *attr) { ucp_proto_rma_rndv_query(params, attr, - UCP_PROTO_RMA_RNDV_GET_PUSH_DESC); + UCP_PROTO_RMA_RNDV_PUSH_DESC); } static size_t ucp_proto_get_rndv_req_pack(void *dest, void *arg) @@ -976,7 +975,7 @@ UCP_DEFINE_AM_WITH_PROXY(UCP_FEATURE_RMA, UCP_AM_ID_RMA_RNDV, ucp_proto_t ucp_put_rndv_proto = { .name = "put/rndv", - .desc = UCP_PROTO_RMA_RNDV_PUT_DESC, + .desc = UCP_PROTO_RMA_RNDV_DESC, .flags = 0, .probe = ucp_proto_put_rndv_probe, .query = ucp_proto_put_rndv_query, @@ -987,7 +986,7 @@ ucp_proto_t ucp_put_rndv_proto = { ucp_proto_t ucp_get_rndv_proto = { .name = "get/rndv", - .desc = UCP_PROTO_RMA_RNDV_GET_DESC, + .desc = UCP_PROTO_RMA_RNDV_DESC, .flags = 0, .probe = ucp_proto_get_rndv_probe, .query = ucp_proto_get_rndv_query, @@ -998,7 +997,7 @@ ucp_proto_t ucp_get_rndv_proto = { ucp_proto_t ucp_get_rndv_push_proto = { .name = "get/rndv/push", - .desc = UCP_PROTO_RMA_RNDV_GET_PUSH_DESC, + .desc = UCP_PROTO_RMA_RNDV_PUSH_DESC, .flags = 0, .probe = ucp_proto_get_rndv_push_probe, .query = ucp_proto_get_rndv_push_query, From 316b334817f64673eec2fd929276724be8c9efb4 Mon Sep 17 00:00:00 2001 From: Thomas Vegas Date: Fri, 22 May 2026 15:49:28 +0200 Subject: [PATCH 05/15] UCP/RMA: Merge common code --- src/ucp/rma/rma_rndv.c | 67 ++++++++++++++++-------------------------- 1 file changed, 26 insertions(+), 41 deletions(-) diff --git a/src/ucp/rma/rma_rndv.c b/src/ucp/rma/rma_rndv.c index ff79f77e077..44a03c17bf0 100644 --- a/src/ucp/rma/rma_rndv.c +++ b/src/ucp/rma/rma_rndv.c @@ -71,6 +71,23 @@ ucp_rma_rndv_dt_iter_init(ucp_datatype_iter_t *dt_iter, uint64_t address, dt_iter->type.contig.memh = NULL; } +static int +ucp_proto_rma_rndv_probe_check(const ucp_proto_init_params_t *init_params, + ucp_operation_id_t op_id) +{ + const ucp_proto_select_param_t *sel_param = init_params->select_param; + + if (!ucp_proto_init_check_op(init_params, UCS_BIT(op_id)) || + 
ucp_proto_rndv_init_params_is_ppln_frag(init_params) || + (sel_param->dt_class != UCP_DATATYPE_CONTIG) || + (init_params->rkey_config_key == NULL)) { + return 0; + } + + return !UCP_MEM_IS_HOST(sel_param->mem_type) || + !UCP_MEM_IS_HOST(init_params->rkey_config_key->mem_type); +} + static size_t ucp_proto_put_rndv_rts_pack(void *dest, void *arg) { @@ -148,7 +165,6 @@ static void ucp_proto_put_rndv_probe(const ucp_proto_init_params_t *init_params) { ucp_context_h context = init_params->worker->context; - const ucp_proto_select_param_t *sel_param = init_params->select_param; ucp_proto_rndv_ctrl_init_params_t params = { .super.super = *init_params, .super.latency = 0, @@ -179,15 +195,7 @@ ucp_proto_put_rndv_probe(const ucp_proto_init_params_t *init_params) }; ucp_proto_rndv_ctrl_priv_t rpriv = {0}; - if (!ucp_proto_init_check_op(init_params, UCS_BIT(UCP_OP_ID_PUT)) || - ucp_proto_rndv_init_params_is_ppln_frag(init_params) || - (sel_param->dt_class != UCP_DATATYPE_CONTIG) || - (init_params->rkey_config_key == NULL)) { - return; - } - - if (UCP_MEM_IS_HOST(sel_param->mem_type) && - UCP_MEM_IS_HOST(init_params->rkey_config_key->mem_type)) { + if (!ucp_proto_rma_rndv_probe_check(init_params, UCP_OP_ID_PUT)) { return; } @@ -231,14 +239,9 @@ static void ucp_proto_rma_rndv_query(const ucp_proto_query_params_t *params, remote_attr.config); } -static void ucp_proto_put_rndv_query(const ucp_proto_query_params_t *params, - ucp_proto_query_attr_t *attr) -{ - ucp_proto_rma_rndv_query(params, attr, UCP_PROTO_RMA_RNDV_DESC); -} - -static void ucp_proto_get_rndv_query(const ucp_proto_query_params_t *params, - ucp_proto_query_attr_t *attr) +static void +ucp_proto_rma_rndv_default_query(const ucp_proto_query_params_t *params, + ucp_proto_query_attr_t *attr) { ucp_proto_rma_rndv_query(params, attr, UCP_PROTO_RMA_RNDV_DESC); } @@ -327,7 +330,6 @@ static void ucp_proto_get_rndv_probe(const ucp_proto_init_params_t *init_params) { ucp_context_h context = init_params->worker->context; - 
const ucp_proto_select_param_t *sel_param = init_params->select_param; ucp_proto_rndv_ctrl_init_params_t params = { .super.super = *init_params, .super.latency = 0, @@ -357,15 +359,7 @@ ucp_proto_get_rndv_probe(const ucp_proto_init_params_t *init_params) }; ucp_proto_rndv_ctrl_priv_t rpriv = {0}; - if (!ucp_proto_init_check_op(init_params, UCS_BIT(UCP_OP_ID_GET)) || - ucp_proto_rndv_init_params_is_ppln_frag(init_params) || - (sel_param->dt_class != UCP_DATATYPE_CONTIG) || - (init_params->rkey_config_key == NULL)) { - return; - } - - if (UCP_MEM_IS_HOST(sel_param->mem_type) && - UCP_MEM_IS_HOST(init_params->rkey_config_key->mem_type)) { + if (!ucp_proto_rma_rndv_probe_check(init_params, UCP_OP_ID_GET)) { return; } @@ -421,7 +415,6 @@ static void ucp_proto_get_rndv_push_probe(const ucp_proto_init_params_t *init_params) { ucp_worker_h worker = init_params->worker; - const ucp_proto_select_param_t *sel_param = init_params->select_param; const ucp_proto_select_init_protocols_t *proto_init; ucp_proto_select_param_t rndv_sel_param; ucp_worker_cfg_index_t rkey_cfg_index; @@ -432,15 +425,7 @@ ucp_proto_get_rndv_push_probe(const ucp_proto_init_params_t *init_params) ucp_lane_index_t lane; const void *priv; - if (!ucp_proto_init_check_op(init_params, UCS_BIT(UCP_OP_ID_GET)) || - ucp_proto_rndv_init_params_is_ppln_frag(init_params) || - (sel_param->dt_class != UCP_DATATYPE_CONTIG) || - (init_params->rkey_config_key == NULL)) { - return; - } - - if (UCP_MEM_IS_HOST(sel_param->mem_type) && - UCP_MEM_IS_HOST(init_params->rkey_config_key->mem_type)) { + if (!ucp_proto_rma_rndv_probe_check(init_params, UCP_OP_ID_GET)) { return; } @@ -449,7 +434,7 @@ ucp_proto_get_rndv_push_probe(const ucp_proto_init_params_t *init_params) return; } - mem_info = ucp_proto_common_select_param_mem_info(sel_param); + mem_info = ucp_proto_common_select_param_mem_info(init_params->select_param); ucp_proto_select_param_init(&rndv_sel_param, UCP_OP_ID_RNDV_RECV, 0, UCP_PROTO_SELECT_OP_FLAG_RNDV_PUSH, 
UCP_DATATYPE_CONTIG, &mem_info, 1); @@ -978,7 +963,7 @@ ucp_proto_t ucp_put_rndv_proto = { .desc = UCP_PROTO_RMA_RNDV_DESC, .flags = 0, .probe = ucp_proto_put_rndv_probe, - .query = ucp_proto_put_rndv_query, + .query = ucp_proto_rma_rndv_default_query, .progress = {ucp_proto_put_rndv_progress}, .abort = ucp_proto_rndv_rts_abort, .reset = ucp_proto_rndv_rts_reset @@ -989,7 +974,7 @@ ucp_proto_t ucp_get_rndv_proto = { .desc = UCP_PROTO_RMA_RNDV_DESC, .flags = 0, .probe = ucp_proto_get_rndv_probe, - .query = ucp_proto_get_rndv_query, + .query = ucp_proto_rma_rndv_default_query, .progress = {ucp_proto_get_rndv_progress}, .abort = ucp_proto_get_rndv_abort, .reset = ucp_proto_get_rndv_reset From 309dcdd1ebd28f51fda72f1f93e00186b24008ce Mon Sep 17 00:00:00 2001 From: Thomas Vegas Date: Fri, 22 May 2026 16:56:07 +0200 Subject: [PATCH 06/15] UCP/RMA: Keep only push-based RNDV GET --- src/ucp/proto/proto.c | 1 - src/ucp/proto/proto.h | 2 +- src/ucp/rma/rma_rndv.c | 439 +++----------------------------------- src/ucp/rndv/proto_rndv.h | 3 +- 4 files changed, 28 insertions(+), 417 deletions(-) diff --git a/src/ucp/proto/proto.c b/src/ucp/proto/proto.c index 988fc0446c0..fae0b9bd1f1 100644 --- a/src/ucp/proto/proto.c +++ b/src/ucp/proto/proto.c @@ -25,7 +25,6 @@ _macro(ucp_get_am_bcopy_proto) \ _macro(ucp_get_offload_bcopy_proto) \ _macro(ucp_get_offload_zcopy_proto) \ - _macro(ucp_get_rndv_push_proto) \ _macro(ucp_get_rndv_proto) \ _macro(ucp_put_am_bcopy_proto) \ _macro(ucp_put_offload_short_proto) \ diff --git a/src/ucp/proto/proto.h b/src/ucp/proto/proto.h index c17a23bbfe0..2600a07e665 100644 --- a/src/ucp/proto/proto.h +++ b/src/ucp/proto/proto.h @@ -23,7 +23,7 @@ /* Maximal number of protocols in total */ -#define UCP_PROTO_MAX_COUNT 65 +#define UCP_PROTO_MAX_COUNT 64 /* Special value for non-existent protocol */ diff --git a/src/ucp/rma/rma_rndv.c b/src/ucp/rma/rma_rndv.c index 44a03c17bf0..87f941a6793 100644 --- a/src/ucp/rma/rma_rndv.c +++ b/src/ucp/rma/rma_rndv.c @@ 
-20,16 +20,11 @@ #include -#define UCP_PROTO_RMA_RNDV_PUT_RTS_NAME "RMA PUT RTS" -#define UCP_PROTO_RMA_RNDV_GET_REQ_NAME "RMA GET REQ" -#define UCP_PROTO_RMA_RNDV_DESC "rendezvous" -#define UCP_PROTO_RMA_RNDV_PUSH_DESC "rendezvous push" +#define UCP_PROTO_RMA_RNDV_RTS_NAME "RMA_RTS" enum { - UCP_RMA_RNDV_AM_PUT_RTS, - UCP_RMA_RNDV_AM_GET_REQ, - UCP_RMA_RNDV_AM_GET_RTS + UCP_RMA_RNDV_AM_PUT_RTS }; @@ -41,22 +36,6 @@ typedef struct { } UCS_S_PACKED ucp_rma_rndv_put_rts_hdr_t; -typedef struct { - uint64_t hdr; - ucp_request_hdr_t req; - uint64_t address; - size_t size; - ucs_sys_device_t sys_dev; - ucs_memory_type_t mem_type; -} UCS_S_PACKED ucp_rma_rndv_get_req_hdr_t; - - -typedef struct { - ucp_rndv_rts_hdr_t super; - ucs_ptr_map_key_t get_req_id; -} UCS_S_PACKED ucp_rma_rndv_get_rts_hdr_t; - - static void ucp_rma_rndv_dt_iter_init(ucp_datatype_iter_t *dt_iter, uint64_t address, size_t length, ucs_memory_type_t mem_type, @@ -190,7 +169,7 @@ ucp_proto_put_rndv_probe(const ucp_proto_init_params_t *init_params) .lane = ucp_proto_rndv_find_ctrl_lane(init_params), .unpack_perf = NULL, .perf_bias = 0, - .ctrl_msg_name = UCP_PROTO_RMA_RNDV_PUT_RTS_NAME, + .ctrl_msg_name = UCP_PROTO_RMA_RNDV_RTS_NAME, .md_map = 0 }; ucp_proto_rndv_ctrl_priv_t rpriv = {0}; @@ -240,134 +219,21 @@ static void ucp_proto_rma_rndv_query(const ucp_proto_query_params_t *params, } static void -ucp_proto_rma_rndv_default_query(const ucp_proto_query_params_t *params, - ucp_proto_query_attr_t *attr) +ucp_proto_put_rndv_query(const ucp_proto_query_params_t *params, + ucp_proto_query_attr_t *attr) { - ucp_proto_rma_rndv_query(params, attr, UCP_PROTO_RMA_RNDV_DESC); + ucp_proto_rma_rndv_query(params, attr, UCP_PROTO_RNDV_DESC); } static void -ucp_proto_get_rndv_push_query(const ucp_proto_query_params_t *params, - ucp_proto_query_attr_t *attr) +ucp_proto_get_rndv_query(const ucp_proto_query_params_t *params, + ucp_proto_query_attr_t *attr) { - ucp_proto_rma_rndv_query(params, attr, - 
UCP_PROTO_RMA_RNDV_PUSH_DESC); -} - -static size_t ucp_proto_get_rndv_req_pack(void *dest, void *arg) -{ - ucp_request_t *req = arg; - ucp_rma_rndv_get_req_hdr_t *hdr = dest; - ucp_rkey_config_t *rkey_config; - - rkey_config = ucp_rkey_config(req->send.ep->worker, req->send.rma.rkey); - - hdr->hdr = UCP_RMA_RNDV_AM_GET_REQ; - hdr->req.ep_id = ucp_send_request_get_ep_remote_id(req); - hdr->req.req_id = ucp_send_request_get_id(req); - hdr->address = req->send.rma.remote_addr; - hdr->size = req->send.state.dt_iter.length; - hdr->sys_dev = rkey_config->key.sys_dev; - hdr->mem_type = req->send.rma.rkey->mem_type; - - return sizeof(*hdr); -} - -static ucs_status_t ucp_proto_get_rndv_init(ucp_request_t *req) -{ - const ucp_proto_rndv_ctrl_priv_t *rpriv = req->send.proto_config->priv; - ucs_status_t status; - - if (req->flags & UCP_REQUEST_FLAG_PROTO_INITIALIZED) { - return UCS_OK; - } - - status = ucp_ep_resolve_remote_id(req->send.ep, rpriv->lane); - if (status != UCS_OK) { - return status; - } - - req->send.buffer = req->send.state.dt_iter.type.contig.buffer; - req->send.length = req->send.state.dt_iter.length; - ucp_send_request_id_alloc(req); - req->flags |= UCP_REQUEST_FLAG_PROTO_INITIALIZED; - - /* The nested RNDV receive starts only after GET_RTS; this wrapper still - * has to respect RMA fence ordering before the target can expose data. 
*/ - return ucp_ep_rma_handle_fence(req->send.ep, req, UCS_BIT(rpriv->lane)); -} - -static ucs_status_t ucp_proto_get_rndv_progress(uct_pending_req_t *self) -{ - ucp_request_t *req = ucs_container_of(self, ucp_request_t, send.uct); - const ucp_proto_rndv_ctrl_priv_t *rpriv; - ucs_status_t status; - - status = ucp_proto_get_rndv_init(req); - if (status != UCS_OK) { - ucp_proto_request_abort(req, status); - return UCS_OK; - } - - rpriv = req->send.proto_config->priv; - status = ucp_proto_am_bcopy_single_send(req, UCP_AM_ID_RMA_RNDV, - rpriv->lane, - ucp_proto_get_rndv_req_pack, req, - sizeof(ucp_rma_rndv_get_req_hdr_t), - 0); - if (status == UCS_ERR_NO_RESOURCE) { - req->send.lane = rpriv->lane; - return status; - } else if (status != UCS_OK) { - ucp_proto_request_abort(req, status); - return UCS_OK; - } - - return UCS_OK; -} - -static void -ucp_proto_get_rndv_probe(const ucp_proto_init_params_t *init_params) -{ - ucp_context_h context = init_params->worker->context; - ucp_proto_rndv_ctrl_init_params_t params = { - .super.super = *init_params, - .super.latency = 0, - .super.overhead = context->config.ext.proto_overhead_rndv_rts, - .super.cfg_thresh = context->config.ext.zcopy_thresh, - .super.cfg_priority = 5, - .super.min_length = 1, - .super.max_length = SIZE_MAX, - .super.min_iov = 1, - .super.min_frag_offs = UCP_PROTO_COMMON_OFFSET_INVALID, - .super.max_frag_offs = ucs_offsetof(uct_iface_attr_t, cap.am.max_bcopy), - .super.max_iov_offs = UCP_PROTO_COMMON_OFFSET_INVALID, - .super.hdr_size = sizeof(ucp_rma_rndv_get_req_hdr_t), - .super.send_op = UCT_EP_OP_AM_BCOPY, - .super.memtype_op = UCT_EP_OP_LAST, - .super.flags = UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, - .super.exclude_map = 0, - .super.reg_mem_info = ucp_proto_common_select_param_mem_info( - init_params->select_param), - /* The peer turns GET_REQ into a synthetic RNDV sender. 
*/ - .remote_op_id = UCP_OP_ID_RNDV_SEND, - .lane = ucp_proto_rndv_find_ctrl_lane(init_params), - .unpack_perf = NULL, - .perf_bias = 0, - .ctrl_msg_name = UCP_PROTO_RMA_RNDV_GET_REQ_NAME, - .md_map = 0 - }; - ucp_proto_rndv_ctrl_priv_t rpriv = {0}; - - if (!ucp_proto_rma_rndv_probe_check(init_params, UCP_OP_ID_GET)) { - return; - } - - ucp_proto_rndv_ctrl_probe(&params, &rpriv, sizeof(rpriv)); + ucp_proto_rma_rndv_query(params, attr, UCP_PROTO_RNDV_DESC); } static void -ucp_proto_get_rndv_push_add_variant( +ucp_proto_get_rndv_add_variant( const ucp_proto_init_params_t *init_params, const ucp_proto_select_param_t *select_param, ucp_worker_cfg_index_t rkey_cfg_index, ucp_lane_index_t lane, @@ -412,7 +278,7 @@ ucp_proto_get_rndv_push_add_variant( } static void -ucp_proto_get_rndv_push_probe(const ucp_proto_init_params_t *init_params) +ucp_proto_get_rndv_probe(const ucp_proto_init_params_t *init_params) { ucp_worker_h worker = init_params->worker; const ucp_proto_select_init_protocols_t *proto_init; @@ -434,7 +300,8 @@ ucp_proto_get_rndv_push_probe(const ucp_proto_init_params_t *init_params) return; } - mem_info = ucp_proto_common_select_param_mem_info(init_params->select_param); + mem_info = ucp_proto_common_select_param_mem_info( + init_params->select_param); ucp_proto_select_param_init(&rndv_sel_param, UCP_OP_ID_RNDV_RECV, 0, UCP_PROTO_SELECT_OP_FLAG_RNDV_PUSH, UCP_DATATYPE_CONTIG, &mem_info, 1); @@ -462,12 +329,13 @@ ucp_proto_get_rndv_push_probe(const ucp_proto_init_params_t *init_params) } priv = &ucs_array_elem(&proto_init->priv_buf, proto->priv_offset); - ucp_proto_get_rndv_push_add_variant(init_params, &rndv_sel_param, - rkey_cfg_index, lane, proto, priv); + ucp_proto_get_rndv_add_variant(init_params, &rndv_sel_param, + rkey_cfg_index, lane, proto, priv); } } -static void ucp_proto_get_rndv_abort(ucp_request_t *req, ucs_status_t status) +static void +ucp_proto_get_rndv_abort(ucp_request_t *req, ucs_status_t status) { if (req->id != UCS_PTR_MAP_KEY_INVALID) {
ucp_send_request_id_release(req); @@ -498,145 +366,6 @@ static ucs_status_t ucp_proto_get_rndv_reset(ucp_request_t *req) return UCS_OK; } -static void -ucp_rma_rndv_get_send_complete(void *request, - ucs_status_t UCS_V_UNUSED status, - void *UCS_V_UNUSED user_data) -{ - ucp_request_t *req = (ucp_request_t*)request - 1; - - ucp_request_put(req); -} - -static void -ucp_rma_rndv_get_send_abort(ucp_request_t *req, ucs_status_t status) -{ - if (req->send.rndv.remote_req_id != UCS_PTR_MAP_KEY_INVALID) { - ucp_rma_rndv_send_ats_err(req->send.ep, req->send.rndv.remote_req_id, - status); - req->send.rndv.remote_req_id = UCS_PTR_MAP_KEY_INVALID; - } - - if (req->id != UCS_PTR_MAP_KEY_INVALID) { - ucp_send_request_id_release(req); - } - - ucp_datatype_iter_cleanup(&req->send.state.dt_iter, 1, UCP_DT_MASK_ALL); - ucp_request_complete_send(req, status); -} - -static size_t ucp_rma_rndv_get_rts_pack(void *dest, void *arg) -{ - ucp_request_t *req = arg; - ucp_rma_rndv_get_rts_hdr_t *rts = dest; - ucp_datatype_iter_t *dt_iter = &req->send.state.dt_iter; - void *rkey_buffer = UCS_PTR_BYTE_OFFSET(rts, sizeof(*rts)); - size_t rkey_size = 0; - - rts->super.hdr = UCP_RMA_RNDV_AM_GET_RTS; - rts->super.opcode = UCP_RNDV_RTS_TAG_OK; - rts->super.sreq.req_id = ucp_send_request_get_id(req); - rts->super.sreq.ep_id = ucp_send_request_get_ep_remote_id(req); - rts->super.size = dt_iter->length; - rts->super.address = 0; - rts->get_req_id = req->send.rndv.remote_req_id; - - if ((dt_iter->length > 0) && (req->send.rndv.md_map != 0)) { - rkey_size = UCS_PROFILE_CALL(ucp_proto_request_pack_rkey, req, - req->send.rndv.md_map, 0, NULL, - rkey_buffer); - if (rkey_size > 0) { - rts->super.address = (uintptr_t)dt_iter->type.contig.buffer; - } - } - - return sizeof(*rts) + rkey_size; -} - -static ucs_status_t ucp_rma_rndv_get_rts_progress(uct_pending_req_t *self) -{ - ucp_request_t *req = ucs_container_of(self, ucp_request_t, send.uct); - ucp_lane_index_t lane; - ucs_status_t status; - - lane = 
ucp_ep_get_am_lane(req->send.ep); - req->send.lane = lane; - status = ucp_proto_am_bcopy_single_send( - req, UCP_AM_ID_RMA_RNDV, lane, ucp_rma_rndv_get_rts_pack, req, - sizeof(ucp_rma_rndv_get_rts_hdr_t) + - ucp_ep_config(req->send.ep)->rndv.rkey_size, 0); - if (status == UCS_ERR_NO_RESOURCE) { - return status; - } else if (status != UCS_OK) { - ucp_rma_rndv_get_send_abort(req, status); - } - - return UCS_OK; -} - -static ucs_status_t -ucp_rma_rndv_get_sreq_init(ucp_ep_h ep, ucp_request_t *req, - const ucp_rma_rndv_get_req_hdr_t *get_req) -{ - ucp_proto_select_param_t sel_param; - ucp_md_map_t md_map; - ucs_status_t status; - - ucp_proto_request_send_init(req, ep, - UCP_REQUEST_FLAG_RNDV_SEND_INTERNAL); - ucp_request_set_callback(req, send.cb, ucp_rma_rndv_get_send_complete); - req->send.buffer = (void*)(uintptr_t)get_req->address; - req->send.length = get_req->size; - req->send.mem_type = get_req->mem_type; - req->send.rndv.remote_req_id = get_req->req.req_id; - req->send.rndv.rkey = NULL; - req->send.rndv.remote_address = get_req->address; - req->send.rndv.md_map = 0; - ucp_rma_rndv_dt_iter_init(&req->send.state.dt_iter, get_req->address, - get_req->size, get_req->mem_type, - get_req->sys_dev); - - status = ucp_ep_resolve_remote_id(ep, ucp_ep_get_am_lane(ep)); - if (status != UCS_OK) { - goto err_cleanup; - } - - md_map = ucp_ep_config(ep)->key.rma_bw_md_map; - if ((get_req->size > 0) && (md_map != 0)) { - status = ucp_datatype_iter_mem_reg(ep->worker->context, - &req->send.state.dt_iter, md_map, - UCT_MD_MEM_ACCESS_RMA | - UCT_MD_MEM_FLAG_HIDE_ERRORS, - UCP_DT_MASK_ALL); - if (status != UCS_OK) { - goto err_cleanup; - } - - req->send.rndv.md_map = - req->send.state.dt_iter.type.contig.memh->md_map & md_map; - } - - ucp_proto_select_param_init(&sel_param, UCP_OP_ID_RNDV_SEND, 0, 0, - UCP_DATATYPE_CONTIG, - &req->send.state.dt_iter.mem_info, 1); - status = UCS_PROFILE_CALL(ucp_proto_request_lookup_proto, ep->worker, ep, - req, &ucp_ep_config(ep)->proto_select, - 
UCP_WORKER_CFG_INDEX_NULL, &sel_param, - get_req->size); - if (status != UCS_OK) { - goto err_cleanup; - } - - ucp_send_request_id_alloc(req); - req->flags |= UCP_REQUEST_FLAG_PROTO_INITIALIZED; - req->send.uct.func = ucp_rma_rndv_get_rts_progress; - return UCS_OK; - -err_cleanup: - ucp_datatype_iter_cleanup(&req->send.state.dt_iter, 1, UCP_DT_MASK_ALL); - return status; -} - static void ucp_rma_rndv_get_recv_complete(ucp_request_t *recv_req) { ucp_request_t *get_req = ucp_request_get_super(recv_req); @@ -645,8 +374,8 @@ static void ucp_rma_rndv_get_recv_complete(ucp_request_t *recv_req) ucp_request_put(recv_req); } -static ucs_status_t ucp_proto_get_rndv_push_init(ucp_request_t *get_req, - ucp_request_t **rndv_req_p) +static ucs_status_t ucp_proto_get_rndv_init(ucp_request_t *get_req, + ucp_request_t **rndv_req_p) { const ucp_proto_rndv_ctrl_priv_t *rpriv = get_req->send.proto_config->priv; ucp_worker_h worker = get_req->send.ep->worker; @@ -720,7 +449,7 @@ static ucs_status_t ucp_proto_get_rndv_push_init(ucp_request_t *get_req, return UCS_OK; } -static ucs_status_t ucp_proto_get_rndv_push_progress(uct_pending_req_t *self) +static ucs_status_t ucp_proto_get_rndv_progress(uct_pending_req_t *self) { ucp_request_t *get_req = ucs_container_of(self, ucp_request_t, send.uct); ucp_request_t *rndv_req; @@ -730,7 +459,7 @@ static ucs_status_t ucp_proto_get_rndv_push_progress(uct_pending_req_t *self) return UCS_OK; } - status = ucp_proto_get_rndv_push_init(get_req, &rndv_req); + status = ucp_proto_get_rndv_init(get_req, &rndv_req); if (status != UCS_OK) { ucp_proto_request_abort(get_req, status); return UCS_OK; @@ -740,85 +469,6 @@ static ucs_status_t ucp_proto_get_rndv_push_progress(uct_pending_req_t *self) return UCS_OK; } -static ucs_status_t -ucp_rma_rndv_handle_get_req(ucp_worker_h worker, void *data, size_t length) -{ - const ucp_rma_rndv_get_req_hdr_t *get_req = data; - ucp_request_t *req; - ucs_status_t status; - ucp_ep_h ep; - - if (length < sizeof(*get_req)) { - 
return UCS_ERR_MESSAGE_TRUNCATED; - } - - UCP_WORKER_GET_EP_BY_ID(&ep, worker, get_req->req.ep_id, return UCS_OK, - "RMA RNDV GET request"); - - req = ucp_request_get(worker); - if (req == NULL) { - ucs_error("failed to allocate RMA RNDV GET send request"); - ucp_rma_rndv_send_ats_err(ep, get_req->req.req_id, UCS_ERR_NO_MEMORY); - return UCS_OK; - } - - status = ucp_rma_rndv_get_sreq_init(ep, req, get_req); - if (status != UCS_OK) { - ucp_rma_rndv_send_ats_err(ep, get_req->req.req_id, status); - ucp_request_put(req); - return UCS_OK; - } - - ucp_request_send(req); - return UCS_OK; -} - -static ucs_status_t -ucp_rma_rndv_handle_get_rts(ucp_worker_h worker, void *data, size_t length) -{ - const ucp_rma_rndv_get_rts_hdr_t *rts = data; - ucp_request_t *get_req, *recv_req; - uint8_t UCS_V_UNUSED sg_count; - const void *rkey_buffer; - ucp_ep_h ep; - - if (length < sizeof(*rts)) { - return UCS_ERR_MESSAGE_TRUNCATED; - } - - UCP_SEND_REQUEST_GET_BY_ID(&get_req, worker, rts->get_req_id, 0, - return UCS_OK, "RMA RNDV GET_RTS %p", rts); - - recv_req = ucp_request_get(worker); - if (recv_req == NULL) { - ucs_error("failed to allocate RMA RNDV GET receive request"); - UCP_WORKER_GET_EP_BY_ID(&ep, worker, rts->super.sreq.ep_id, - return UCS_OK, "RMA RNDV GET_RTS error"); - ucp_rma_rndv_send_ats_err(ep, rts->super.sreq.req_id, - UCS_ERR_NO_MEMORY); - ucp_proto_get_rndv_abort(get_req, UCS_ERR_NO_MEMORY); - return UCS_OK; - } - - ucp_send_request_id_release(get_req); - recv_req->flags = UCP_REQUEST_FLAG_RNDV_RECV_INTERNAL; - recv_req->recv.worker = worker; - recv_req->recv.op_attr = 0; - recv_req->recv.remote_req_id = rts->super.sreq.req_id; - recv_req->recv.rndv.ep_id = rts->super.sreq.ep_id; - recv_req->recv.rndv.complete_cb = ucp_rma_rndv_get_recv_complete; - ucp_request_set_super(recv_req, get_req); - - UCS_PROFILE_CALL_VOID(ucp_datatype_iter_move, &recv_req->recv.dt_iter, - &get_req->send.state.dt_iter, - get_req->send.state.dt_iter.length, &sg_count); - - rkey_buffer = 
UCS_PTR_BYTE_OFFSET(rts, sizeof(*rts)); - ucp_proto_rndv_receive_start(worker, recv_req, &rts->super, rkey_buffer, - length - sizeof(*rts)); - return UCS_OK; -} - static void ucp_rma_rndv_put_recv_complete(ucp_request_t *recv_req) { ucp_worker_h worker = recv_req->recv.worker; @@ -890,10 +540,6 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_rma_rndv_handler, switch (*hdr) { case UCP_RMA_RNDV_AM_PUT_RTS: return ucp_rma_rndv_handle_put_rts(worker, data, length); - case UCP_RMA_RNDV_AM_GET_REQ: - return ucp_rma_rndv_handle_get_req(worker, data, length); - case UCP_RMA_RNDV_AM_GET_RTS: - return ucp_rma_rndv_handle_get_rts(worker, data, length); default: ucs_debug("unexpected RMA RNDV AM sub-id %" PRIu64, *hdr); return UCS_ERR_UNSUPPORTED; @@ -906,8 +552,6 @@ ucp_rma_rndv_dump_packet(ucp_worker_h worker, uct_am_trace_type_t type, char *buffer, size_t max) { const ucp_rma_rndv_put_rts_hdr_t *put_rts = data; - const ucp_rma_rndv_get_req_hdr_t *get_req = data; - const ucp_rma_rndv_get_rts_hdr_t *get_rts = data; const uint64_t *hdr = data; if (length < sizeof(*hdr)) { @@ -927,28 +571,6 @@ ucp_rma_rndv_dump_packet(ucp_worker_h worker, uct_am_trace_type_t type, put_rts->super.sreq.req_id, put_rts->super.sreq.ep_id, ucs_memory_type_names[put_rts->mem_type]); break; - case UCP_RMA_RNDV_AM_GET_REQ: - if (length < sizeof(*get_req)) { - return; - } - - snprintf(buffer, max, "RMA_GET_REQ [src 0x%" PRIx64 - " len %zu req_id 0x%" PRIx64 " ep_id 0x%" PRIx64 - " %s]", get_req->address, get_req->size, - get_req->req.req_id, get_req->req.ep_id, - ucs_memory_type_names[get_req->mem_type]); - break; - case UCP_RMA_RNDV_AM_GET_RTS: - if (length < sizeof(*get_rts)) { - return; - } - - snprintf(buffer, max, "RMA_GET_RTS [src 0x%" PRIx64 - " len %zu sreq_id 0x%" PRIx64 " ep_id 0x%" PRIx64 - " get_req_id 0x%" PRIx64 "]", get_rts->super.address, - get_rts->super.size, get_rts->super.sreq.req_id, - get_rts->super.sreq.ep_id, get_rts->get_req_id); - break; default: snprintf(buffer, max, "RMA_RNDV [sub-id 
%" PRIu64 "]", *hdr); break; @@ -960,10 +582,10 @@ UCP_DEFINE_AM_WITH_PROXY(UCP_FEATURE_RMA, UCP_AM_ID_RMA_RNDV, ucp_proto_t ucp_put_rndv_proto = { .name = "put/rndv", - .desc = UCP_PROTO_RMA_RNDV_DESC, + .desc = UCP_PROTO_RNDV_DESC, .flags = 0, .probe = ucp_proto_put_rndv_probe, - .query = ucp_proto_rma_rndv_default_query, + .query = ucp_proto_put_rndv_query, .progress = {ucp_proto_put_rndv_progress}, .abort = ucp_proto_rndv_rts_abort, .reset = ucp_proto_rndv_rts_reset @@ -971,22 +593,11 @@ ucp_proto_t ucp_put_rndv_proto = { ucp_proto_t ucp_get_rndv_proto = { .name = "get/rndv", - .desc = UCP_PROTO_RMA_RNDV_DESC, + .desc = UCP_PROTO_RNDV_DESC, .flags = 0, .probe = ucp_proto_get_rndv_probe, - .query = ucp_proto_rma_rndv_default_query, + .query = ucp_proto_get_rndv_query, .progress = {ucp_proto_get_rndv_progress}, .abort = ucp_proto_get_rndv_abort, .reset = ucp_proto_get_rndv_reset }; - -ucp_proto_t ucp_get_rndv_push_proto = { - .name = "get/rndv/push", - .desc = UCP_PROTO_RMA_RNDV_PUSH_DESC, - .flags = 0, - .probe = ucp_proto_get_rndv_push_probe, - .query = ucp_proto_get_rndv_push_query, - .progress = {ucp_proto_get_rndv_push_progress}, - .abort = ucp_proto_get_rndv_abort, - .reset = ucp_proto_get_rndv_reset -}; diff --git a/src/ucp/rndv/proto_rndv.h b/src/ucp/rndv/proto_rndv.h index 66cc69f30dd..419c233603c 100644 --- a/src/ucp/rndv/proto_rndv.h +++ b/src/ucp/rndv/proto_rndv.h @@ -12,7 +12,8 @@ #include -/* Names of rendezvous control messages */ +/* Rendezvous protocol description and control message names */ +#define UCP_PROTO_RNDV_DESC "rndv" #define UCP_PROTO_RNDV_RTS_NAME "RTS" #define UCP_PROTO_RNDV_RTR_NAME "RTR" #define UCP_PROTO_RNDV_RTR_REQ_NAME "RTR_REQ" From 7d8a0fecf7891b405a18c71a3c8f21ff733e9525 Mon Sep 17 00:00:00 2001 From: Thomas Vegas Date: Fri, 22 May 2026 17:48:50 +0200 Subject: [PATCH 07/15] UCP/RMA: Reuse RTS for RMA case --- src/ucp/Makefile.am | 1 + src/ucp/core/ucp_context.c | 3 +- src/ucp/core/ucp_types.h | 2 - src/ucp/rma/rma_rndv.c | 
89 +++++--------------------------------- src/ucp/rma/rma_rndv.h | 26 +++++++++++ src/ucp/rndv/rndv.c | 25 ++++++++--- src/ucp/rndv/rndv.h | 4 +- src/ucp/rndv/rndv.inl | 6 +++ 8 files changed, 66 insertions(+), 90 deletions(-) create mode 100644 src/ucp/rma/rma_rndv.h diff --git a/src/ucp/Makefile.am b/src/ucp/Makefile.am index 8280a58f5d8..50bca31c02d 100644 --- a/src/ucp/Makefile.am +++ b/src/ucp/Makefile.am @@ -68,6 +68,7 @@ noinst_HEADERS = \ proto/proto.h \ rma/rma.h \ rma/rma.inl \ + rma/rma_rndv.h \ rndv/proto_rndv.h \ rndv/proto_rndv.inl \ rndv/rndv_mtype.inl \ diff --git a/src/ucp/core/ucp_context.c b/src/ucp/core/ucp_context.c index e13771e76b8..4ae9573c7ba 100644 --- a/src/ucp/core/ucp_context.c +++ b/src/ucp/core/ucp_context.c @@ -61,8 +61,7 @@ _macro(UCP_AM_ID_AM_MIDDLE) \ _macro(UCP_AM_ID_AM_SINGLE_REPLY) \ _macro(UCP_AM_ID_AM_FIRST_PSN) \ - _macro(UCP_AM_ID_AM_MIDDLE_PSN) \ - _macro(UCP_AM_ID_RMA_RNDV) + _macro(UCP_AM_ID_AM_MIDDLE_PSN) #define UCP_AM_HANDLER_DECL(_id) extern ucp_am_handler_t ucp_am_handler_##_id; diff --git a/src/ucp/core/ucp_types.h b/src/ucp/core/ucp_types.h index abba8deed2b..e78506403ad 100644 --- a/src/ucp/core/ucp_types.h +++ b/src/ucp/core/ucp_types.h @@ -204,8 +204,6 @@ typedef enum { carrying remote ep and PSN for tracking */ UCP_AM_ID_AM_MIDDLE_PSN = 28, - - UCP_AM_ID_RMA_RNDV = 29, /* RMA rendezvous control */ UCP_AM_ID_LAST } ucp_am_id_t; diff --git a/src/ucp/rma/rma_rndv.c b/src/ucp/rma/rma_rndv.c index 87f941a6793..639cf4d42bb 100644 --- a/src/ucp/rma/rma_rndv.c +++ b/src/ucp/rma/rma_rndv.c @@ -10,6 +10,7 @@ #include "rma.h" #include "rma.inl" +#include "rma_rndv.h" #include #include @@ -23,19 +24,6 @@ #define UCP_PROTO_RMA_RNDV_RTS_NAME "RMA_RTS" -enum { - UCP_RMA_RNDV_AM_PUT_RTS -}; - - -typedef struct { - ucp_rndv_rts_hdr_t super; - uint64_t address; - ucs_sys_device_t sys_dev; - ucs_memory_type_t mem_type; -} UCS_S_PACKED ucp_rma_rndv_put_rts_hdr_t; - - static void ucp_rma_rndv_dt_iter_init(ucp_datatype_iter_t 
*dt_iter, uint64_t address, size_t length, ucs_memory_type_t mem_type, @@ -70,14 +58,14 @@ ucp_proto_rma_rndv_probe_check(const ucp_proto_init_params_t *init_params, static size_t ucp_proto_put_rndv_rts_pack(void *dest, void *arg) { - ucp_request_t *req = arg; - ucp_rma_rndv_put_rts_hdr_t *rts = dest; + ucp_request_t *req = arg; + ucp_rma_rndv_rts_hdr_t *rts = dest; ucp_rkey_config_t *rkey_config; rkey_config = ucp_rkey_config(req->send.ep->worker, req->send.rma.rkey); - rts->super.hdr = UCP_RMA_RNDV_AM_PUT_RTS; - rts->super.opcode = UCP_RNDV_RTS_TAG_OK; + rts->super.hdr = 0; + rts->super.opcode = UCP_RNDV_RTS_RMA; rts->address = req->send.rma.remote_addr; rts->sys_dev = rkey_config->key.sys_dev; rts->mem_type = req->send.rma.rkey->mem_type; @@ -118,11 +106,10 @@ static ucs_status_t ucp_proto_put_rndv_progress(uct_pending_req_t *self) ep = req->send.ep; rpriv = req->send.proto_config->priv; - max_rts_size = sizeof(ucp_rma_rndv_put_rts_hdr_t) + - rpriv->packed_rkey_size; + max_rts_size = sizeof(ucp_rma_rndv_rts_hdr_t) + rpriv->packed_rkey_size; ucp_worker_flush_ops_count_add(ep->worker, +1); - status = ucp_proto_am_bcopy_single_send(req, UCP_AM_ID_RMA_RNDV, + status = ucp_proto_am_bcopy_single_send(req, UCP_AM_ID_RNDV_RTS, rpriv->lane, ucp_proto_put_rndv_rts_pack, req, max_rts_size, 0); @@ -156,7 +143,7 @@ ucp_proto_put_rndv_probe(const ucp_proto_init_params_t *init_params) .super.min_frag_offs = UCP_PROTO_COMMON_OFFSET_INVALID, .super.max_frag_offs = ucs_offsetof(uct_iface_attr_t, cap.am.max_bcopy), .super.max_iov_offs = UCP_PROTO_COMMON_OFFSET_INVALID, - .super.hdr_size = sizeof(ucp_rma_rndv_put_rts_hdr_t), + .super.hdr_size = sizeof(ucp_rma_rndv_rts_hdr_t), .super.send_op = UCT_EP_OP_AM_BCOPY, .super.memtype_op = UCT_EP_OP_LAST, .super.flags = UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, @@ -488,10 +475,10 @@ static void ucp_rma_rndv_put_recv_complete(ucp_request_t *recv_req) ucp_request_put(recv_req); } -static ucs_status_t -ucp_rma_rndv_handle_put_rts(ucp_worker_h 
worker, void *data, size_t length) +ucs_status_t ucp_rma_rndv_process_rts(ucp_worker_h worker, + const ucp_rma_rndv_rts_hdr_t *rts, + size_t length) { - const ucp_rma_rndv_put_rts_hdr_t *rts = data; const void *rkey_buffer; ucp_request_t *recv_req; ucp_ep_h ep; @@ -526,60 +513,6 @@ ucp_rma_rndv_handle_put_rts(ucp_worker_h worker, void *data, size_t length) return UCS_OK; } -UCS_PROFILE_FUNC(ucs_status_t, ucp_rma_rndv_handler, - (arg, data, length, am_flags), void *arg, void *data, - size_t length, unsigned am_flags) -{ - const uint64_t *hdr = data; - ucp_worker_h worker = arg; - - if (length < sizeof(*hdr)) { - return UCS_ERR_MESSAGE_TRUNCATED; - } - - switch (*hdr) { - case UCP_RMA_RNDV_AM_PUT_RTS: - return ucp_rma_rndv_handle_put_rts(worker, data, length); - default: - ucs_debug("unexpected RMA RNDV AM sub-id %" PRIu64, *hdr); - return UCS_ERR_UNSUPPORTED; - } -} - -static void -ucp_rma_rndv_dump_packet(ucp_worker_h worker, uct_am_trace_type_t type, - uint8_t id, const void *data, size_t length, - char *buffer, size_t max) -{ - const ucp_rma_rndv_put_rts_hdr_t *put_rts = data; - const uint64_t *hdr = data; - - if (length < sizeof(*hdr)) { - return; - } - - switch (*hdr) { - case UCP_RMA_RNDV_AM_PUT_RTS: - if (length < sizeof(*put_rts)) { - return; - } - - snprintf(buffer, max, "RMA_PUT_RTS [src 0x%" PRIx64 - " dst 0x%" PRIx64 " len %zu req_id 0x%" PRIx64 - " ep_id 0x%" PRIx64 " %s]", put_rts->super.address, - put_rts->address, put_rts->super.size, - put_rts->super.sreq.req_id, put_rts->super.sreq.ep_id, - ucs_memory_type_names[put_rts->mem_type]); - break; - default: - snprintf(buffer, max, "RMA_RNDV [sub-id %" PRIu64 "]", *hdr); - break; - } -} - -UCP_DEFINE_AM_WITH_PROXY(UCP_FEATURE_RMA, UCP_AM_ID_RMA_RNDV, - ucp_rma_rndv_handler, ucp_rma_rndv_dump_packet, 0); - ucp_proto_t ucp_put_rndv_proto = { .name = "put/rndv", .desc = UCP_PROTO_RNDV_DESC, diff --git a/src/ucp/rma/rma_rndv.h b/src/ucp/rma/rma_rndv.h new file mode 100644 index 00000000000..4f49a1bde98 --- 
/dev/null +++ b/src/ucp/rma/rma_rndv.h @@ -0,0 +1,26 @@ +/** + * Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2026. ALL RIGHTS RESERVED. + * + * See file LICENSE for terms. + */ + +#ifndef UCP_RMA_RNDV_H_ +#define UCP_RMA_RNDV_H_ + +#include +#include + + +typedef struct { + ucp_rndv_rts_hdr_t super; + uint64_t address; + ucs_sys_device_t sys_dev; + ucs_memory_type_t mem_type; +} UCS_S_PACKED ucp_rma_rndv_rts_hdr_t; + + +ucs_status_t ucp_rma_rndv_process_rts(ucp_worker_h worker, + const ucp_rma_rndv_rts_hdr_t *rts, + size_t length); + +#endif diff --git a/src/ucp/rndv/rndv.c b/src/ucp/rndv/rndv.c index e00a9aedaaa..9415ee1fc9e 100644 --- a/src/ucp/rndv/rndv.c +++ b/src/ucp/rndv/rndv.c @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -1759,9 +1760,11 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_rts_handler, if (ucp_rndv_rts_is_am(rts_hdr)) { return ucp_am_rndv_process_rts(arg, data, length, tl_flags); - } else { - ucs_assert(ucp_rndv_rts_is_tag(rts_hdr)); + } else if (ucp_rndv_rts_is_tag(rts_hdr)) { return ucp_tag_rndv_process_rts(worker, rts_hdr, length, tl_flags); + } else { + ucs_assert(ucp_rndv_rts_is_rma(rts_hdr)); + return ucp_rma_rndv_process_rts(worker, data, length); } } @@ -2436,6 +2439,7 @@ static void ucp_rndv_dump(ucp_worker_h worker, uct_am_trace_type_t type, { UCS_STRING_BUFFER_FIXED(strb, buffer, max); const ucp_rndv_rts_hdr_t *rndv_rts_hdr = data; + const ucp_rma_rndv_rts_hdr_t *rma_rts = data; const ucp_rndv_rtr_hdr_t *rndv_rtr_hdr = data; const ucp_rndv_rtr_req_hdr_t *rtr_req = data; const ucp_request_data_hdr_t *rndv_data = data; @@ -2450,13 +2454,19 @@ static void ucp_rndv_dump(ucp_worker_h worker, uct_am_trace_type_t type, if (ucp_rndv_rts_is_am(rndv_rts_hdr)) { ucs_string_buffer_appendf(&strb, "am_id %u", ucp_am_hdr_from_rts(rndv_rts_hdr)->am_id); - } else { - ucs_assert(ucp_rndv_rts_is_tag(rndv_rts_hdr)); + rkey_buf = rndv_rts_hdr + 1; + } else if (ucp_rndv_rts_is_tag(rndv_rts_hdr)) { 
ucs_string_buffer_appendf(&strb, "tag %" PRIx64, ucp_tag_hdr_from_rts(rndv_rts_hdr)->tag); + rkey_buf = rndv_rts_hdr + 1; + } else { + ucs_assert(ucp_rndv_rts_is_rma(rndv_rts_hdr)); + ucs_string_buffer_appendf( + &strb, "RMA_RTS dst 0x%" PRIx64 " %s", rma_rts->address, + ucs_memory_type_names[rma_rts->mem_type]); + rkey_buf = rma_rts + 1; } - rkey_buf = rndv_rts_hdr + 1; ucs_string_buffer_appendf(&strb, " ep_id 0x%" PRIx64 " sreq_id 0x%" PRIx64 " address 0x%" PRIx64 " size %zu", @@ -2526,8 +2536,9 @@ static void ucp_rndv_dump(ucp_worker_h worker, uct_am_trace_type_t type, } } -UCP_DEFINE_AM_WITH_PROXY(UCP_FEATURE_TAG | UCP_FEATURE_AM, UCP_AM_ID_RNDV_RTS, - ucp_rndv_rts_handler, ucp_rndv_dump, 0); +UCP_DEFINE_AM_WITH_PROXY(UCP_FEATURE_TAG | UCP_FEATURE_AM | UCP_FEATURE_RMA, + UCP_AM_ID_RNDV_RTS, ucp_rndv_rts_handler, + ucp_rndv_dump, 0); UCP_DEFINE_AM_WITH_PROXY(UCP_FEATURE_TAG | UCP_FEATURE_AM | UCP_FEATURE_RMA, UCP_AM_ID_RNDV_ATS, ucp_rndv_ats_handler, ucp_rndv_dump, 0); diff --git a/src/ucp/rndv/rndv.h b/src/ucp/rndv/rndv.h index 3ff6ca604a1..534e320cbaf 100644 --- a/src/ucp/rndv/rndv.h +++ b/src/ucp/rndv/rndv.h @@ -17,7 +17,9 @@ typedef enum { * the previous UCP versions) */ UCP_RNDV_RTS_TAG_OK = UCS_OK, /* RNDV AM operation */ - UCP_RNDV_RTS_AM = 1 + UCP_RNDV_RTS_AM = 1, + /* One-sided RNDV operation */ + UCP_RNDV_RTS_RMA = 2 } UCS_S_PACKED ucp_rndv_rts_opcode_t; diff --git a/src/ucp/rndv/rndv.inl b/src/ucp/rndv/rndv.inl index e3f4f6e198f..43dbab85191 100644 --- a/src/ucp/rndv/rndv.inl +++ b/src/ucp/rndv/rndv.inl @@ -23,6 +23,12 @@ ucp_rndv_rts_is_tag(const ucp_rndv_rts_hdr_t *rts_hdr) return rts_hdr->opcode == UCP_RNDV_RTS_TAG_OK; } +static UCS_F_ALWAYS_INLINE int +ucp_rndv_rts_is_rma(const ucp_rndv_rts_hdr_t *rts_hdr) +{ + return rts_hdr->opcode == UCP_RNDV_RTS_RMA; +} + static UCS_F_ALWAYS_INLINE void ucp_rndv_receive_start(ucp_worker_h worker, ucp_request_t *rreq, const ucp_rndv_rts_hdr_t *rndv_rts_hdr, From 2eb3d8e2f8b6e409dad99ff97eb7ef93951d7647 Mon Sep 
17 00:00:00 2001 From: Thomas Vegas Date: Fri, 22 May 2026 18:25:58 +0200 Subject: [PATCH 08/15] UCP/RMA: Add GTEST for RMA rendezvous --- src/ucp/rma/rma_rndv.c | 2 +- test/gtest/ucp/test_ucp_rma.cc | 79 +++++++++++++++++++++++++++++++++- 2 files changed, 79 insertions(+), 2 deletions(-) diff --git a/src/ucp/rma/rma_rndv.c b/src/ucp/rma/rma_rndv.c index 639cf4d42bb..3a9ca9cbdd6 100644 --- a/src/ucp/rma/rma_rndv.c +++ b/src/ucp/rma/rma_rndv.c @@ -137,7 +137,7 @@ ucp_proto_put_rndv_probe(const ucp_proto_init_params_t *init_params) .super.overhead = context->config.ext.proto_overhead_rndv_rts, .super.cfg_thresh = context->config.ext.zcopy_thresh, .super.cfg_priority = 5, - .super.min_length = 1, + .super.min_length = 0, .super.max_length = SIZE_MAX, .super.min_iov = 1, .super.min_frag_offs = UCP_PROTO_COMMON_OFFSET_INVALID, diff --git a/test/gtest/ucp/test_ucp_rma.cc b/test/gtest/ucp/test_ucp_rma.cc index e9606bafbc2..a5c93fd131a 100755 --- a/test/gtest/ucp/test_ucp_rma.cc +++ b/test/gtest/ucp/test_ucp_rma.cc @@ -201,7 +201,7 @@ class test_ucp_rma : public test_ucp_memheap { return get_variant_value() & USER_MEMH; } -private: +protected: /* Test variants */ enum { FLUSH_EP = UCS_BIT(0), /* If not set, flush worker */ @@ -388,6 +388,83 @@ UCS_TEST_P(test_ucp_rma_dmabuf, put_registration_offset) UCP_INSTANTIATE_TEST_CASE_TLS(test_ucp_rma_dmabuf, ib_cuda, "ib,cuda_copy") +class test_ucp_rma_rndv : public test_ucp_rma { +public: + static constexpr size_t SIZE = 512 * UCS_KBYTE; + + static void get_test_variants(std::vector &variants) + { + add_variant_with_value(variants, UCP_FEATURE_RMA, 0, "flush_worker"); + add_variant_with_value(variants, UCP_FEATURE_RMA, FLUSH_EP, "flush_ep"); + } + + test_ucp_rma_rndv() + { + modify_config("PROTOS", "put/rndv,get/rndv,rndv/*"); + } + + void init() override + { + m_env.push_back( + new ucs::scoped_setenv("UCX_IB_GPU_DIRECT_RDMA", "n")); + test_ucp_rma::init(); + } + +protected: + static bool is_rndv_mem_type_pair(ucs_memory_type_t 
local_mem_type, + ucs_memory_type_t remote_mem_type) + { + return !UCP_MEM_IS_HOST(local_mem_type) || + !UCP_MEM_IS_HOST(remote_mem_type); + } + + void test_forced_rndv(send_func_t send_func) + { + unsigned num_tested = 0; + + for (const auto &pair : ucs::supported_mem_type_pairs()) { + if (!is_rndv_mem_type_pair(pair[0], pair[1])) { + continue; + } + + test_message_sizes(send_func, 128, default_max_size(), + pair[0], pair[1], 0); + ++num_tested; + if (HasFailure() || (num_errors() > 0)) { + break; + } + } + + if (num_tested == 0) { + UCS_TEST_SKIP_R("no memory type pair supports RMA/RNDV"); + } + } +}; + +UCS_TEST_P(test_ucp_rma_rndv, put_blocking) +{ + test_forced_rndv(static_cast(&test_ucp_rma::put_b)); +} + +UCS_TEST_P(test_ucp_rma_rndv, put_nonblocking) +{ + test_forced_rndv(static_cast(&test_ucp_rma::put_nbi)); +} + +UCS_TEST_P(test_ucp_rma_rndv, get_blocking) +{ + test_forced_rndv(static_cast(&test_ucp_rma::get_b)); +} + +UCS_TEST_P(test_ucp_rma_rndv, get_nonblocking) +{ + UCS_TEST_SKIP_R("TODO Fix crash"); + test_forced_rndv(static_cast(&test_ucp_rma::get_nbi)); +} + +UCP_INSTANTIATE_TEST_CASE_TLS_GPU_AWARE(test_ucp_rma_rndv, ib, "ib") + + class test_ucp_proto_emulation_enable : public test_ucp_rma { public: static constexpr size_t SMALL_SIZE = 8; From 22a4fc1c90c294cb102565ec50561baf434fa343 Mon Sep 17 00:00:00 2001 From: Thomas Vegas Date: Sat, 23 May 2026 11:53:14 +0300 Subject: [PATCH 09/15] UCP/RMA: Block flush when Get RMA rendezvous is in progress --- src/ucp/rma/rma_rndv.c | 4 ++++ src/ucp/rndv/rndv_rtr.c | 6 +++++- test/gtest/ucp/test_ucp_rma.cc | 1 - 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/ucp/rma/rma_rndv.c b/src/ucp/rma/rma_rndv.c index 3a9ca9cbdd6..e1343ec5d4f 100644 --- a/src/ucp/rma/rma_rndv.c +++ b/src/ucp/rma/rma_rndv.c @@ -356,8 +356,10 @@ static ucs_status_t ucp_proto_get_rndv_reset(ucp_request_t *req) static void ucp_rma_rndv_get_recv_complete(ucp_request_t *recv_req) { ucp_request_t *get_req = 
ucp_request_get_super(recv_req); + ucp_ep_h ep = get_req->send.ep; ucp_request_complete_send(get_req, recv_req->status); + ucp_ep_rma_remote_request_completed(ep); ucp_request_put(recv_req); } @@ -452,6 +454,8 @@ static ucs_status_t ucp_proto_get_rndv_progress(uct_pending_req_t *self) return UCS_OK; } + ucp_worker_flush_ops_count_add(get_req->send.ep->worker, +1); + ucp_ep_rma_remote_request_sent(get_req->send.ep); ucp_request_send(rndv_req); return UCS_OK; } diff --git a/src/ucp/rndv/rndv_rtr.c b/src/ucp/rndv/rndv_rtr.c index bfe08c0100e..d2668db9ec9 100644 --- a/src/ucp/rndv/rndv_rtr.c +++ b/src/ucp/rndv/rndv_rtr.c @@ -155,15 +155,19 @@ static size_t ucp_proto_rndv_rtr_req_pack(void *dest, void *arg) ucp_request_t *req = arg; ucp_rndv_rtr_req_hdr_t *rtr_req = dest; const ucp_proto_rndv_rtr_priv_t *rpriv = req->send.proto_config->priv; + void *rkey_src, *rkey_dst; size_t rkey_size, packed_size; packed_size = rpriv->pack_cb(&rtr_req->super, req); rkey_size = packed_size - sizeof(rtr_req->super); if (rkey_size > 0) { - memmove(rtr_req + 1, &rtr_req->super + 1, rkey_size); + rkey_src = UCS_PTR_BYTE_OFFSET(rtr_req, sizeof(rtr_req->super)); + rkey_dst = UCS_PTR_BYTE_OFFSET(rtr_req, sizeof(*rtr_req)); + memmove(rkey_dst, rkey_src, rkey_size); } rtr_req->super.sreq_id = UCS_PTR_MAP_KEY_INVALID; + rtr_req->super.offset = 0; rtr_req->req.ep_id = ucp_send_request_get_ep_remote_id(req); rtr_req->req.req_id = UCS_PTR_MAP_KEY_INVALID; rtr_req->address = req->send.rndv.remote_address; diff --git a/test/gtest/ucp/test_ucp_rma.cc b/test/gtest/ucp/test_ucp_rma.cc index a5c93fd131a..60d2cad34b8 100755 --- a/test/gtest/ucp/test_ucp_rma.cc +++ b/test/gtest/ucp/test_ucp_rma.cc @@ -458,7 +458,6 @@ UCS_TEST_P(test_ucp_rma_rndv, get_blocking) UCS_TEST_P(test_ucp_rma_rndv, get_nonblocking) { - UCS_TEST_SKIP_R("TODO Fix crash"); test_forced_rndv(static_cast(&test_ucp_rma::get_nbi)); } From 44c2a73bb3c8ed86eb95527f73eda6ce3251d217 Mon Sep 17 00:00:00 2001 From: Thomas Vegas Date: Sun, 24 
May 2026 12:41:34 +0300 Subject: [PATCH 10/15] UCP/RMA: Get RMA rendezvous waits for endpoint --- src/ucp/rma/rma_rndv.c | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/ucp/rma/rma_rndv.c b/src/ucp/rma/rma_rndv.c index e1343ec5d4f..dd2369bad3c 100644 --- a/src/ucp/rma/rma_rndv.c +++ b/src/ucp/rma/rma_rndv.c @@ -377,11 +377,6 @@ static ucs_status_t ucp_proto_get_rndv_init(ucp_request_t *get_req, uint64_t address; size_t length; - status = ucp_ep_resolve_remote_id(get_req->send.ep, rpriv->lane); - if (status != UCS_OK) { - return status; - } - status = ucp_ep_rma_handle_fence(get_req->send.ep, get_req, UCS_BIT(rpriv->lane)); if (status != UCS_OK) { @@ -441,6 +436,7 @@ static ucs_status_t ucp_proto_get_rndv_init(ucp_request_t *get_req, static ucs_status_t ucp_proto_get_rndv_progress(uct_pending_req_t *self) { ucp_request_t *get_req = ucs_container_of(self, ucp_request_t, send.uct); + const ucp_proto_rndv_ctrl_priv_t *rpriv; ucp_request_t *rndv_req; ucs_status_t status; @@ -448,6 +444,21 @@ static ucs_status_t ucp_proto_get_rndv_progress(uct_pending_req_t *self) return UCS_OK; } + rpriv = get_req->send.proto_config->priv; + get_req->send.lane = rpriv->lane; + status = ucp_ep_resolve_remote_id(get_req->send.ep, + rpriv->lane); + if (status == UCS_ERR_NO_RESOURCE) { + return status; + } else if (status != UCS_OK) { + ucp_proto_request_abort(get_req, status); + return UCS_OK; + } + + if (!(get_req->send.ep->flags & UCP_EP_FLAG_FLUSH_STATE_VALID)) { + return UCS_ERR_NO_RESOURCE; + } + status = ucp_proto_get_rndv_init(get_req, &rndv_req); if (status != UCS_OK) { ucp_proto_request_abort(get_req, status); From 7bf40d14003d2109710c3b9aeebc7e2a36b219e3 Mon Sep 17 00:00:00 2001 From: Thomas Vegas Date: Sun, 24 May 2026 13:53:15 +0300 Subject: [PATCH 11/15] UCP/RMA: Put rendezvous blocks flush --- src/ucp/core/ucp_request.c | 1 + src/ucp/core/ucp_request.h | 3 ++- src/ucp/core/ucp_request.inl | 11 +++++++++++ 
src/ucp/proto/proto_common.inl | 13 ++++++++++++- src/ucp/rma/rma_rndv.c | 24 ++++++++++++++++-------- src/ucp/rndv/proto_rndv.c | 3 ++- src/ucp/rndv/proto_rndv.inl | 24 ++++++++++++++++++++++++ src/ucp/rndv/rndv.c | 2 ++ src/ucp/rndv/rndv_am.c | 6 ++++-- src/ucp/rndv/rndv_ppln.c | 3 ++- src/ucp/rndv/rndv_put.c | 3 ++- src/ucp/rndv/rndv_rkey_ptr.c | 2 +- 12 files changed, 79 insertions(+), 16 deletions(-) diff --git a/src/ucp/core/ucp_request.c b/src/ucp/core/ucp_request.c index 1e2b40d67f9..dde6eb62b8c 100644 --- a/src/ucp/core/ucp_request.c +++ b/src/ucp/core/ucp_request.c @@ -53,6 +53,7 @@ static const char *ucp_request_flag_names[] = { #endif [ucs_ilog2(UCP_REQUEST_FLAG_RNDV_SEND_INTERNAL)] = "rndv_snd_int", [ucs_ilog2(UCP_REQUEST_FLAG_RNDV_RTR_REQ)] = "rndv_rtr_req", + [ucs_ilog2(UCP_REQUEST_FLAG_RNDV_FLUSH)] = "rndv_flush", }; static ucs_memory_type_t ucp_request_get_mem_type(ucp_request_t *req) diff --git a/src/ucp/core/ucp_request.h b/src/ucp/core/ucp_request.h index 5a50904c110..3bc8e278edb 100644 --- a/src/ucp/core/ucp_request.h +++ b/src/ucp/core/ucp_request.h @@ -67,7 +67,8 @@ enum { UCP_REQUEST_FLAG_SUPER_VALID = 0, #endif UCP_REQUEST_FLAG_RNDV_SEND_INTERNAL = UCS_BIT(26), - UCP_REQUEST_FLAG_RNDV_RTR_REQ = UCS_BIT(27) + UCP_REQUEST_FLAG_RNDV_RTR_REQ = UCS_BIT(27), + UCP_REQUEST_FLAG_RNDV_FLUSH = UCS_BIT(28) }; diff --git a/src/ucp/core/ucp_request.inl b/src/ucp/core/ucp_request.inl index 5eb90289be5..0c6c5a9f963 100644 --- a/src/ucp/core/ucp_request.inl +++ b/src/ucp/core/ucp_request.inl @@ -240,6 +240,17 @@ ucp_request_put(ucp_request_t *req) ucs_mpool_put_inline(req); } +static UCS_F_ALWAYS_INLINE void +ucp_request_rndv_flush_complete(ucp_request_t *req) +{ + /* Complete the extra flush op held by a RNDV wrapper until the RNDV data + * path completes. 
*/ + if (ucs_unlikely(req->flags & UCP_REQUEST_FLAG_RNDV_FLUSH)) { + req->flags &= ~UCP_REQUEST_FLAG_RNDV_FLUSH; + ucp_worker_flush_ops_count_add(req->send.ep->worker, -1); + } +} + static UCS_F_ALWAYS_INLINE void ucp_request_complete_send(ucp_request_t *req, ucs_status_t status) { diff --git a/src/ucp/proto/proto_common.inl b/src/ucp/proto/proto_common.inl index 8cb0747355e..c702e3672e0 100644 --- a/src/ucp/proto/proto_common.inl +++ b/src/ucp/proto/proto_common.inl @@ -94,7 +94,8 @@ ucp_proto_request_zcopy_clean(ucp_request_t *req, unsigned dt_mask) } static UCS_F_ALWAYS_INLINE void -ucp_proto_request_zcopy_complete(ucp_request_t *req, ucs_status_t status) +ucp_proto_request_zcopy_complete_cb(ucp_request_t *req, ucs_status_t status, + ucp_request_callback_t complete_cb) { ucp_datatype_iter_cleanup(&req->send.state.dt_iter, 1, UCP_DT_MASK_CONTIG_IOV); @@ -108,10 +109,20 @@ ucp_proto_request_zcopy_complete(ucp_request_t *req, ucs_status_t status) !(req->send.ep->flags & UCP_EP_FLAG_FAILED)) { ucp_proto_request_restart(req); } else { + if (complete_cb != NULL) { + complete_cb(req); + } + ucp_request_complete_send(req, status); } } +static UCS_F_ALWAYS_INLINE void +ucp_proto_request_zcopy_complete(ucp_request_t *req, ucs_status_t status) +{ + ucp_proto_request_zcopy_complete_cb(req, status, NULL); +} + static UCS_F_ALWAYS_INLINE ucs_status_t ucp_proto_request_zcopy_complete_success(ucp_request_t *req) { diff --git a/src/ucp/rma/rma_rndv.c b/src/ucp/rma/rma_rndv.c index dd2369bad3c..a0b4f570f9e 100644 --- a/src/ucp/rma/rma_rndv.c +++ b/src/ucp/rma/rma_rndv.c @@ -108,22 +108,30 @@ static ucs_status_t ucp_proto_put_rndv_progress(uct_pending_req_t *self) rpriv = req->send.proto_config->priv; max_rts_size = sizeof(ucp_rma_rndv_rts_hdr_t) + rpriv->packed_rkey_size; - ucp_worker_flush_ops_count_add(ep->worker, +1); + req->flags |= UCP_REQUEST_FLAG_RNDV_FLUSH; + ucp_worker_flush_ops_count_add(ep->worker, +2); status = ucp_proto_am_bcopy_single_send(req, UCP_AM_ID_RNDV_RTS, 
rpriv->lane, ucp_proto_put_rndv_rts_pack, req, max_rts_size, 0); + if (status != UCS_OK) { + if (status == UCS_ERR_NO_RESOURCE) { + req->send.lane = rpriv->lane; + } + goto err_flush_count; + } + + ucp_ep_rma_remote_request_sent(ep); + return UCS_OK; + +err_flush_count: + req->flags &= ~UCP_REQUEST_FLAG_RNDV_FLUSH; + ucp_worker_flush_ops_count_add(ep->worker, -2); if (status == UCS_ERR_NO_RESOURCE) { - ucp_worker_flush_ops_count_add(ep->worker, -1); - req->send.lane = rpriv->lane; return status; - } else if (status != UCS_OK) { - ucp_worker_flush_ops_count_add(ep->worker, -1); - ucp_proto_request_abort(req, status); - return UCS_OK; } - ucp_ep_rma_remote_request_sent(ep); + ucp_proto_request_abort(req, status); return UCS_OK; } diff --git a/src/ucp/rndv/proto_rndv.c b/src/ucp/rndv/proto_rndv.c index fecdea52893..f3c3524a03d 100644 --- a/src/ucp/rndv/proto_rndv.c +++ b/src/ucp/rndv/proto_rndv.c @@ -580,6 +580,7 @@ void ucp_proto_rndv_rts_query(const ucp_proto_query_params_t *params, void ucp_proto_rndv_rts_abort(ucp_request_t *req, ucs_status_t status) { ucp_am_release_user_header(req); + ucp_request_rndv_flush_complete(req); if (ucp_request_memh_invalidate(req, status)) { ucp_proto_rndv_rts_reset(req); @@ -910,7 +911,7 @@ static void ucp_proto_rndv_send_complete_one(void *request, ucs_status_t status, } ucp_send_request_id_release(req); - ucp_proto_request_zcopy_complete(req, status); + ucp_proto_rndv_request_zcopy_complete(req, status); } static void diff --git a/src/ucp/rndv/proto_rndv.inl b/src/ucp/rndv/proto_rndv.inl index 9970e421ffe..07db2f0de5c 100644 --- a/src/ucp/rndv/proto_rndv.inl +++ b/src/ucp/rndv/proto_rndv.inl @@ -89,11 +89,35 @@ ucp_proto_rndv_ats_handler(void *arg, void *data, size_t length, unsigned flags) ucp_send_request_id_release(req); ucp_datatype_iter_cleanup(&req->send.state.dt_iter, 1, UCP_DT_MASK_ALL); + ucp_request_rndv_flush_complete(req); ucp_request_complete_send(req, status); return UCS_OK; } +static UCS_F_ALWAYS_INLINE void 
+ucp_proto_rndv_request_zcopy_complete(ucp_request_t *req, ucs_status_t status) +{ + ucp_proto_request_zcopy_complete_cb(req, status, + ucp_request_rndv_flush_complete); +} + +static UCS_F_ALWAYS_INLINE ucs_status_t +ucp_proto_rndv_request_zcopy_complete_success(ucp_request_t *req) +{ + ucp_proto_rndv_request_zcopy_complete(req, UCS_OK); + return UCS_OK; +} + +static UCS_F_ALWAYS_INLINE void +ucp_proto_rndv_request_zcopy_completion(uct_completion_t *self) +{ + ucp_request_t *req = ucs_container_of(self, ucp_request_t, + send.state.uct_comp); + + ucp_proto_rndv_request_zcopy_complete(req, req->send.state.uct_comp.status); +} + static UCS_F_ALWAYS_INLINE size_t ucp_proto_rndv_rts_pack( ucp_request_t *req, ucp_rndv_rts_hdr_t *rts, size_t hdr_len) { diff --git a/src/ucp/rndv/rndv.c b/src/ucp/rndv/rndv.c index 9415ee1fc9e..2680b9613de 100644 --- a/src/ucp/rndv/rndv.c +++ b/src/ucp/rndv/rndv.c @@ -423,6 +423,7 @@ static void ucp_rndv_complete_rma_put_zcopy(ucp_request_t *sreq, int is_frag_put } ucp_request_send_buffer_dereg(sreq); + ucp_request_rndv_flush_complete(sreq); ucp_request_complete_send(sreq, status); } @@ -1891,6 +1892,7 @@ static void ucp_rndv_am_zcopy_send_req_complete(ucp_request_t *req, { ucs_assert(req->send.state.uct_comp.count == 0); ucp_request_send_buffer_dereg(req); + ucp_request_rndv_flush_complete(req); ucp_request_complete_send(req, status); } diff --git a/src/ucp/rndv/rndv_am.c b/src/ucp/rndv/rndv_am.c index adc94ce1c8b..c17f2033080 100644 --- a/src/ucp/rndv/rndv_am.c +++ b/src/ucp/rndv/rndv_am.c @@ -87,6 +87,7 @@ ucp_proto_rndv_am_bcopy_complete(ucp_request_t *req) { ucp_rndv_am_destroy_rkey(req); ucp_datatype_iter_mem_dereg(&req->send.state.dt_iter, UCP_DT_MASK_ALL); + ucp_request_rndv_flush_complete(req); return ucp_proto_request_bcopy_complete_success(req); } @@ -128,7 +129,8 @@ ucp_proto_rndv_am_bcopy_abort(ucp_request_t *req, ucs_status_t status) { ucp_rndv_am_destroy_rkey(req); ucp_datatype_iter_mem_dereg(&req->send.state.dt_iter, 
UCP_DT_MASK_ALL); - ucp_proto_request_bcopy_abort(req,status); + ucp_request_rndv_flush_complete(req); + ucp_proto_request_bcopy_abort(req, status); } ucp_proto_t ucp_rndv_am_bcopy_proto = { @@ -173,7 +175,7 @@ static ucs_status_t ucp_rndv_am_zcopy_proto_progress(uct_pending_req_t *uct_req) UCP_DT_MASK_CONTIG_IOV, ucp_rndv_am_zcopy_send_func, ucp_rndv_am_zcopy_complete, - ucp_proto_request_zcopy_completion); + ucp_proto_rndv_request_zcopy_completion); } static void ucp_rndv_am_zcopy_probe(const ucp_proto_init_params_t *init_params) diff --git a/src/ucp/rndv/rndv_ppln.c b/src/ucp/rndv/rndv_ppln.c index 9b407c4a669..18584268a7f 100644 --- a/src/ucp/rndv/rndv_ppln.c +++ b/src/ucp/rndv/rndv_ppln.c @@ -231,6 +231,7 @@ ucp_proto_rndv_ppln_frag_complete(ucp_request_t *freq, int send_ack, int abort, ucp_proto_request_set_stage(req, UCP_PROTO_RNDV_PPLN_STAGE_ACK); ucp_request_send(req); } else { + ucp_request_rndv_flush_complete(req); complete_func(req); } } @@ -337,7 +338,7 @@ ucp_proto_rndv_send_ppln_atp_progress(uct_pending_req_t *uct_req) return ucp_proto_rndv_ack_progress(req, &rpriv->ack, UCP_AM_ID_RNDV_ATP, ucp_proto_rndv_ppln_pack_ack, - ucp_proto_request_zcopy_complete_success); + ucp_proto_rndv_request_zcopy_complete_success); } ucp_proto_t ucp_rndv_send_ppln_proto = { diff --git a/src/ucp/rndv/rndv_put.c b/src/ucp/rndv/rndv_put.c index 68989a923ef..225844ae5a2 100644 --- a/src/ucp/rndv/rndv_put.c +++ b/src/ucp/rndv/rndv_put.c @@ -47,7 +47,8 @@ ucp_proto_rndv_put_common_complete(ucp_request_t *req) UCS_STATS_UPDATE_COUNTER(req->send.ep->worker->stats, rpriv->stat_counter, +1); ucp_proto_rndv_rkey_destroy(req); - ucp_proto_request_zcopy_complete(req, req->send.state.uct_comp.status); + ucp_proto_rndv_request_zcopy_complete(req, + req->send.state.uct_comp.status); } static void ucp_proto_rndv_put_zcopy_completion(uct_completion_t *uct_comp) diff --git a/src/ucp/rndv/rndv_rkey_ptr.c b/src/ucp/rndv/rndv_rkey_ptr.c index 999c2e04ae3..2555b123073 100644 --- 
a/src/ucp/rndv/rndv_rkey_ptr.c +++ b/src/ucp/rndv/rndv_rkey_ptr.c @@ -317,7 +317,7 @@ static ucs_status_t ucp_proto_rndv_rkey_ptr_mtype_completion(ucp_request_t *req) { ucp_trace_req(req, "ucp_proto_rndv_rkey_ptr_mtype_completion"); ucp_proto_rndv_rkey_destroy(req); - ucp_proto_request_zcopy_complete(req, UCS_OK); + ucp_proto_rndv_request_zcopy_complete(req, UCS_OK); return UCS_OK; } From 33859eb29959843462d5b0d0c44b16371378841f Mon Sep 17 00:00:00 2001 From: Thomas Vegas Date: Sun, 24 May 2026 15:21:31 +0300 Subject: [PATCH 12/15] UCP/RMA: Fix build failure --- src/ucp/rndv/rndv_ppln.c | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/src/ucp/rndv/rndv_ppln.c b/src/ucp/rndv/rndv_ppln.c index 18584268a7f..76b88fa97e6 100644 --- a/src/ucp/rndv/rndv_ppln.c +++ b/src/ucp/rndv/rndv_ppln.c @@ -204,9 +204,8 @@ static void ucp_proto_rndv_ppln_query(const ucp_proto_query_params_t *params, attr->lane_map |= UCS_BIT(rpriv->ack.lane); } -static void +static ucp_request_t * ucp_proto_rndv_ppln_frag_complete(ucp_request_t *freq, int send_ack, int abort, - ucp_proto_complete_cb_t complete_func, const char *title) { ucp_request_t *req = ucp_request_get_super(freq); @@ -218,7 +217,7 @@ ucp_proto_rndv_ppln_frag_complete(ucp_request_t *freq, int send_ack, int abort, /* In case of abort we don't destroy super request until all fragments are * completed */ if (!ucp_proto_rndv_frag_complete(req, freq, title)) { - return; + return NULL; } if (req->send.rndv.rkey != NULL) { @@ -230,25 +229,33 @@ ucp_proto_rndv_ppln_frag_complete(ucp_request_t *freq, int send_ack, int abort, if (!abort && (req->send.rndv.ppln.ack_data_size > 0)) { ucp_proto_request_set_stage(req, UCP_PROTO_RNDV_PPLN_STAGE_ACK); ucp_request_send(req); + return NULL; } else { - ucp_request_rndv_flush_complete(req); - complete_func(req); + return req; } } void ucp_proto_rndv_ppln_send_frag_complete(ucp_request_t *freq, int send_ack) { - ucp_proto_rndv_ppln_frag_complete(freq, 
send_ack, 0, - ucp_proto_request_complete_success, - "ppln_send"); + ucp_request_t *req; + + req = ucp_proto_rndv_ppln_frag_complete(freq, send_ack, 0, "ppln_send"); + if (req != NULL) { + ucp_request_rndv_flush_complete(req); + ucp_proto_request_complete_success(req); + } } void ucp_proto_rndv_ppln_recv_frag_complete(ucp_request_t *freq, int send_ack, int abort) { - ucp_proto_rndv_ppln_frag_complete(freq, send_ack, abort, - ucp_proto_rndv_recv_complete, - "ppln_recv"); + ucp_request_t *req; + + req = ucp_proto_rndv_ppln_frag_complete(freq, send_ack, abort, "ppln_recv"); + if (req != NULL) { + ucp_request_rndv_flush_complete(req); + ucp_proto_rndv_recv_complete(req); + } } static ucs_status_t ucp_proto_rndv_ppln_progress(uct_pending_req_t *uct_req) From 93544817da51b2bcbb4e2856ec9b219e1d594c4b Mon Sep 17 00:00:00 2001 From: Thomas Vegas Date: Fri, 29 May 2026 21:23:37 +0300 Subject: [PATCH 13/15] UCP/RMA/RNDV: Flush in RNDV code --- src/ucp/core/ucp_request.c | 1 + src/ucp/core/ucp_request.h | 3 ++- src/ucp/rma/rma_rndv.c | 53 ++++++++++++++++++++++++++++++++------ src/ucp/rma/rma_rndv.h | 5 ++++ src/ucp/rndv/rndv_rtr.c | 14 ++++++++-- 5 files changed, 65 insertions(+), 11 deletions(-) diff --git a/src/ucp/core/ucp_request.c b/src/ucp/core/ucp_request.c index dde6eb62b8c..b192baa3cb5 100644 --- a/src/ucp/core/ucp_request.c +++ b/src/ucp/core/ucp_request.c @@ -54,6 +54,7 @@ static const char *ucp_request_flag_names[] = { [ucs_ilog2(UCP_REQUEST_FLAG_RNDV_SEND_INTERNAL)] = "rndv_snd_int", [ucs_ilog2(UCP_REQUEST_FLAG_RNDV_RTR_REQ)] = "rndv_rtr_req", [ucs_ilog2(UCP_REQUEST_FLAG_RNDV_FLUSH)] = "rndv_flush", + [ucs_ilog2(UCP_REQUEST_FLAG_RNDV_START_FLUSH)] = "rndv_start_flush", }; static ucs_memory_type_t ucp_request_get_mem_type(ucp_request_t *req) diff --git a/src/ucp/core/ucp_request.h b/src/ucp/core/ucp_request.h index 3bc8e278edb..b5dc1ab0090 100644 --- a/src/ucp/core/ucp_request.h +++ b/src/ucp/core/ucp_request.h @@ -68,7 +68,8 @@ enum { #endif 
UCP_REQUEST_FLAG_RNDV_SEND_INTERNAL = UCS_BIT(26), UCP_REQUEST_FLAG_RNDV_RTR_REQ = UCS_BIT(27), - UCP_REQUEST_FLAG_RNDV_FLUSH = UCS_BIT(28) + UCP_REQUEST_FLAG_RNDV_FLUSH = UCS_BIT(28), + UCP_REQUEST_FLAG_RNDV_START_FLUSH = UCS_BIT(29) }; diff --git a/src/ucp/rma/rma_rndv.c b/src/ucp/rma/rma_rndv.c index a0b4f570f9e..241117bf21b 100644 --- a/src/ucp/rma/rma_rndv.c +++ b/src/ucp/rma/rma_rndv.c @@ -361,13 +361,55 @@ static ucs_status_t ucp_proto_get_rndv_reset(ucp_request_t *req) return UCS_OK; } +ucp_request_t *ucp_rma_rndv_rtr_flush_open(ucp_request_t *rtr_req) +{ + ucp_request_t *recv_req = rtr_req; + ucp_ep_h ep = rtr_req->send.ep; + + if (!(rtr_req->flags & UCP_REQUEST_FLAG_RNDV_RTR_REQ)) { + return NULL; + } + + while (!(recv_req->flags & UCP_REQUEST_FLAG_RNDV_RECV_INTERNAL)) { + ucs_assert(recv_req->flags & UCP_REQUEST_FLAG_RNDV_RTR_REQ); + recv_req = ucp_request_get_super(recv_req); + } + + if (recv_req->flags & UCP_REQUEST_FLAG_RNDV_START_FLUSH) { + /* Claim before the AM send, since SELF can complete inline. */ + recv_req->flags &= ~UCP_REQUEST_FLAG_RNDV_START_FLUSH; + ucp_worker_flush_ops_count_add(ep->worker, +1); + return recv_req; + } + + return NULL; +} + +void ucp_rma_rndv_rtr_flush_close(ucp_request_t *recv_req, ucp_ep_h ep, + ucs_status_t status) +{ + if (recv_req != NULL) { + if (status == UCS_OK) { + /* recv_req may complete inline, so only touch ep on success. 
*/ + ucp_ep_rma_remote_request_sent(ep); + } else { + ucp_worker_flush_ops_count_add(ep->worker, -1); + recv_req->flags |= UCP_REQUEST_FLAG_RNDV_START_FLUSH; + } + } +} + static void ucp_rma_rndv_get_recv_complete(ucp_request_t *recv_req) { ucp_request_t *get_req = ucp_request_get_super(recv_req); ucp_ep_h ep = get_req->send.ep; + int start_flush; + start_flush = recv_req->flags & UCP_REQUEST_FLAG_RNDV_START_FLUSH; ucp_request_complete_send(get_req, recv_req->status); - ucp_ep_rma_remote_request_completed(ep); + if (!start_flush) { + ucp_ep_rma_remote_request_completed(ep); + } ucp_request_put(recv_req); } @@ -412,7 +454,8 @@ static ucs_status_t ucp_proto_get_rndv_init(ucp_request_t *get_req, } get_req->flags |= UCP_REQUEST_FLAG_PROTO_INITIALIZED; - recv_req->flags = UCP_REQUEST_FLAG_RNDV_RECV_INTERNAL; + recv_req->flags = UCP_REQUEST_FLAG_RNDV_RECV_INTERNAL | + UCP_REQUEST_FLAG_RNDV_START_FLUSH; recv_req->recv.worker = worker; recv_req->recv.op_attr = 0; recv_req->recv.remote_req_id = UCS_PTR_MAP_KEY_INVALID; @@ -463,18 +506,12 @@ static ucs_status_t ucp_proto_get_rndv_progress(uct_pending_req_t *self) return UCS_OK; } - if (!(get_req->send.ep->flags & UCP_EP_FLAG_FLUSH_STATE_VALID)) { - return UCS_ERR_NO_RESOURCE; - } - status = ucp_proto_get_rndv_init(get_req, &rndv_req); if (status != UCS_OK) { ucp_proto_request_abort(get_req, status); return UCS_OK; } - ucp_worker_flush_ops_count_add(get_req->send.ep->worker, +1); - ucp_ep_rma_remote_request_sent(get_req->send.ep); ucp_request_send(rndv_req); return UCS_OK; } diff --git a/src/ucp/rma/rma_rndv.h b/src/ucp/rma/rma_rndv.h index 4f49a1bde98..906c41bd40b 100644 --- a/src/ucp/rma/rma_rndv.h +++ b/src/ucp/rma/rma_rndv.h @@ -23,4 +23,9 @@ ucs_status_t ucp_rma_rndv_process_rts(ucp_worker_h worker, const ucp_rma_rndv_rts_hdr_t *rts, size_t length); +ucp_request_t *ucp_rma_rndv_rtr_flush_open(ucp_request_t *rtr_req); + +void ucp_rma_rndv_rtr_flush_close(ucp_request_t *recv_req, ucp_ep_h ep, + ucs_status_t status); + 
#endif diff --git a/src/ucp/rndv/rndv_rtr.c b/src/ucp/rndv/rndv_rtr.c index d2668db9ec9..8793790c9af 100644 --- a/src/ucp/rndv/rndv_rtr.c +++ b/src/ucp/rndv/rndv_rtr.c @@ -14,6 +14,7 @@ #include #include #include +#include #include @@ -69,16 +70,25 @@ static ucs_status_t ucp_proto_rndv_rtr_common_send(ucp_request_t *req) const ucp_proto_rndv_rtr_priv_t *rpriv = req->send.proto_config->priv; ucp_worker_h UCS_V_UNUSED worker = req->send.ep->worker; uct_pack_callback_t pack_cb = rpriv->pack_cb; + ucp_request_t *recv_req = NULL; + ucp_ep_h ep = req->send.ep; size_t max_rtr_size; ucs_status_t status; if (req->flags & UCP_REQUEST_FLAG_RNDV_RTR_REQ) { pack_cb = ucp_proto_rndv_rtr_req_pack; max_rtr_size = ucp_proto_rndv_rtr_req_max_size(req); - } else { - max_rtr_size = ucp_proto_rndv_rtr_max_size(req); + recv_req = ucp_rma_rndv_rtr_flush_open(req); + status = ucp_proto_am_bcopy_single_send( + req, UCP_AM_ID_RNDV_RTR, rpriv->super.lane, pack_cb, req, + max_rtr_size, 0); + ucp_rma_rndv_rtr_flush_close(recv_req, ep, status); + + return ucp_proto_single_status_handle(req, 0, NULL, + rpriv->super.lane, status); } + max_rtr_size = ucp_proto_rndv_rtr_max_size(req); status = ucp_proto_am_bcopy_single_progress(req, UCP_AM_ID_RNDV_RTR, rpriv->super.lane, pack_cb, req, max_rtr_size, NULL, 0); From 4c0fe95bbb8e7c8ce1d8f70e2acf283ae958c0cb Mon Sep 17 00:00:00 2001 From: Thomas Vegas Date: Mon, 1 Jun 2026 11:12:00 +0300 Subject: [PATCH 14/15] UCP/RMA/RNDV: Fix static analysis failure --- src/ucp/rndv/rndv_rtr.c | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/src/ucp/rndv/rndv_rtr.c b/src/ucp/rndv/rndv_rtr.c index 8793790c9af..1bfa7bfbf2d 100644 --- a/src/ucp/rndv/rndv_rtr.c +++ b/src/ucp/rndv/rndv_rtr.c @@ -69,30 +69,28 @@ static ucs_status_t ucp_proto_rndv_rtr_common_send(ucp_request_t *req) { const ucp_proto_rndv_rtr_priv_t *rpriv = req->send.proto_config->priv; ucp_worker_h UCS_V_UNUSED worker = req->send.ep->worker; - 
uct_pack_callback_t pack_cb = rpriv->pack_cb; - ucp_request_t *recv_req = NULL; ucp_ep_h ep = req->send.ep; + ucp_request_t *recv_req = NULL; + uct_pack_callback_t pack_cb; size_t max_rtr_size; ucs_status_t status; if (req->flags & UCP_REQUEST_FLAG_RNDV_RTR_REQ) { pack_cb = ucp_proto_rndv_rtr_req_pack; max_rtr_size = ucp_proto_rndv_rtr_req_max_size(req); - recv_req = ucp_rma_rndv_rtr_flush_open(req); - status = ucp_proto_am_bcopy_single_send( - req, UCP_AM_ID_RNDV_RTR, rpriv->super.lane, pack_cb, req, - max_rtr_size, 0); - ucp_rma_rndv_rtr_flush_close(recv_req, ep, status); - - return ucp_proto_single_status_handle(req, 0, NULL, - rpriv->super.lane, status); + recv_req = ucp_rma_rndv_rtr_flush_open(req); + } else { + pack_cb = rpriv->pack_cb; + max_rtr_size = ucp_proto_rndv_rtr_max_size(req); } - max_rtr_size = ucp_proto_rndv_rtr_max_size(req); - status = ucp_proto_am_bcopy_single_progress(req, UCP_AM_ID_RNDV_RTR, - rpriv->super.lane, pack_cb, - req, max_rtr_size, NULL, 0); - return status; + status = ucp_proto_am_bcopy_single_send(req, UCP_AM_ID_RNDV_RTR, + rpriv->super.lane, pack_cb, req, + max_rtr_size, 0); + ucp_rma_rndv_rtr_flush_close(recv_req, ep, status); + + return ucp_proto_single_status_handle(req, 0, NULL, rpriv->super.lane, + status); } static UCS_F_ALWAYS_INLINE void From 9be1795353c059f5cb2ac65b31d743a711239ff2 Mon Sep 17 00:00:00 2001 From: Thomas Vegas Date: Mon, 1 Jun 2026 21:26:04 +0300 Subject: [PATCH 15/15] UCP/PROTO: Update maximum number of protocols and DT support --- src/ucp/proto/proto.h | 2 +- src/ucp/rma/rma_rndv.c | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/ucp/proto/proto.h b/src/ucp/proto/proto.h index ae7227b1998..26bc88f6cc2 100644 --- a/src/ucp/proto/proto.h +++ b/src/ucp/proto/proto.h @@ -23,7 +23,7 @@ /* Maximal number of protocols in total */ -#define UCP_PROTO_MAX_COUNT 64 +#define UCP_PROTO_MAX_COUNT 65 /* Special value for non-existent protocol */ diff --git a/src/ucp/rma/rma_rndv.c 
b/src/ucp/rma/rma_rndv.c index 241117bf21b..27221406d1d 100644 --- a/src/ucp/rma/rma_rndv.c +++ b/src/ucp/rma/rma_rndv.c @@ -577,6 +577,7 @@ ucp_proto_t ucp_put_rndv_proto = { .name = "put/rndv", .desc = UCP_PROTO_RNDV_DESC, .flags = 0, + .dt_mask = UCS_BIT(UCP_DATATYPE_CONTIG), .probe = ucp_proto_put_rndv_probe, .query = ucp_proto_put_rndv_query, .progress = {ucp_proto_put_rndv_progress}, @@ -588,6 +589,7 @@ ucp_proto_t ucp_get_rndv_proto = { .name = "get/rndv", .desc = UCP_PROTO_RNDV_DESC, .flags = 0, + .dt_mask = UCS_BIT(UCP_DATATYPE_CONTIG), .probe = ucp_proto_get_rndv_probe, .query = ucp_proto_get_rndv_query, .progress = {ucp_proto_get_rndv_progress},