From dbdc97ffbe0273340ca1b4433f64b3aebf985153 Mon Sep 17 00:00:00 2001 From: Chang-Ning Tsai Date: Tue, 24 Mar 2026 22:41:56 -0700 Subject: [PATCH 1/2] libfabric: fix multi-NIC RX imbalance for EFA transport The multi-NIC round-robin in get_next_ep() only balanced TX by rotating the local sending EP. The remote target EP and MR key selection were coupled to the same local domain index, causing all incoming RDMA writes to land on a single NIC per GPU. Decouple remote NIC selection from local EP by introducing get_next_remote_domain(), which uses (my_pe + target_pe) % num_proxy_domains to distribute RX across all remote NICs. Different senders now target different NICs on the same destination PE. Before: TX distributed across 4 NICs, RX bottlenecked on 1 NIC After: TX and RX both distributed across 4 NICs --- src/modules/transport/libfabric/libfabric.cpp | 31 ++++++++++++++----- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/src/modules/transport/libfabric/libfabric.cpp b/src/modules/transport/libfabric/libfabric.cpp index d76b85fb..52bb7312 100644 --- a/src/modules/transport/libfabric/libfabric.cpp +++ b/src/modules/transport/libfabric/libfabric.cpp @@ -112,6 +112,14 @@ int get_next_ep(nvshmemt_libfabric_state_t *state, int qp_index) { } } +int get_next_remote_domain(nvshmemt_libfabric_state_t *state, int qp_index, int my_pe, int pe) { + if (qp_index == NVSHMEMX_QP_HOST) { + return 0; + } else { + return ((my_pe + pe) % state->num_proxy_domains) + state->num_host_domains; + } +} + nvshmemt_libfabric_imm_cq_data_hdr_t get_write_with_imm_hdr(uint64_t imm_data) { return (nvshmemt_libfabric_imm_cq_data_hdr_t)((uint32_t)imm_data >> NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_SHIFT); @@ -776,7 +784,7 @@ static int nvshmemt_libfabric_show_info(struct nvshmem_transport *transport, int static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, rma_verb_t verb, rma_memdesc_t *remote, rma_memdesc_t *local, rma_bytesdesc_t bytesdesc, int qp_index, uint32_t *imm_data, - nvshmemt_libfabric_endpoint_t &ep) { + nvshmemt_libfabric_endpoint_t &ep, int remote_domain_idx) { nvshmemt_libfabric_mem_handle_ep_t *remote_handle, *local_handle = NULL; void *local_mr_desc = NULL; nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state; @@ -795,7 +803,8 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, ep_idx = ep.ep_index; domain_idx = ep.domain_index; - target_ep = pe * libfabric_state->eps.size() + ep_idx; + /* Use remote_domain_idx for target EP addressing and remote MR key selection */ + target_ep = pe * libfabric_state->eps.size() + remote_domain_idx; if (libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { nvshmemt_libfabric_gdr_op_ctx_t *gdr_ctx; @@ -814,7 +823,8 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, } } - remote_handle = &((nvshmemt_libfabric_mem_handle_t *)remote->handle)->hdls[domain_idx]; + /* Use remote_domain_idx to select remote MR key, distributing RX across remote NICs */ + remote_handle = &((nvshmemt_libfabric_mem_handle_t *)remote->handle)->hdls[remote_domain_idx]; op_size = bytesdesc.elembytes * bytesdesc.nelems; if (verb.desc == NVSHMEMI_OP_P) { @@ -906,9 +916,10 @@ static int nvshmemt_libfabric_rma(struct nvshmem_transport *tcurr, int pe, rma_v rma_bytesdesc_t bytesdesc, int qp_index) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state; int ep_idx = get_next_ep(libfabric_state, qp_index); + int remote_domain_idx = get_next_remote_domain(libfabric_state, qp_index, tcurr->my_pe, pe); nvshmemt_libfabric_endpoint_t &ep = *(libfabric_state->eps[ep_idx]); return nvshmemt_libfabric_rma_impl(tcurr, pe, verb, remote, local, bytesdesc, qp_index, NULL, - ep); + ep, remote_domain_idx); } static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int pe, void *curetptr, @@ -1114,7 +1125,8 @@ static int nvshmemt_libfabric_gdr_signal(struct nvshmem_transport *transport, in void *curetptr, amo_verb_t verb, amo_memdesc_t *remote, amo_bytesdesc_t bytesdesc, int qp_index, uint32_t sequence_count, uint16_t num_writes, - nvshmemt_libfabric_endpoint_t &ep) { + nvshmemt_libfabric_endpoint_t &ep, + int remote_domain_idx) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; nvshmemt_libfabric_gdr_op_ctx_t *context; nvshmemt_libfabric_gdr_signal_op_t *signal; @@ -1124,7 +1136,7 @@ static int nvshmemt_libfabric_gdr_signal(struct nvshmem_transport *transport, in int status = 0; domain_idx = ep.domain_index; - target_ep = pe * libfabric_state->eps.size() + ep.ep_index; + target_ep = pe * libfabric_state->eps.size() + remote_domain_idx; static_assert(sizeof(nvshmemt_libfabric_gdr_op_ctx) >= sizeof(nvshmemt_libfabric_gdr_signal_op_t)); @@ -1175,6 +1187,7 @@ static int nvshmemt_libfabric_put_signal_unordered(struct nvshmem_transport *tcu int status = 0; int ep_idx = get_next_ep(libfabric_state, qp_index); + int remote_domain_idx = get_next_remote_domain(libfabric_state, qp_index, tcurr->my_pe, pe); nvshmemt_libfabric_endpoint_t &ep = *(libfabric_state->eps[ep_idx]); /* Get sequence number for this put-signal, with retry */ @@ -1200,7 +1213,8 @@ static int nvshmemt_libfabric_put_signal_unordered(struct nvshmem_transport *tcu for (size_t i = 0; i < write_remote.size(); i++) { status = nvshmemt_libfabric_rma_impl(tcurr, pe, write_verb, &write_remote[i], &write_local[i], - write_bytes_desc[i], qp_index, &sequence_count, ep); + write_bytes_desc[i], qp_index, &sequence_count, ep, + remote_domain_idx); if (unlikely(status)) { NVSHMEMI_ERROR_PRINT( "Error in nvshmemt_put_signal_unordered, could not submit write #%lu\n", i); @@ -1211,7 +1225,8 @@ static int nvshmemt_libfabric_put_signal_unordered(struct nvshmem_transport *tcu assert(use_staged_atomics == true); status = nvshmemt_libfabric_gdr_signal(tcurr, pe, NULL, sig_verb, sig_target, sig_bytes_desc, - qp_index, sequence_count, (uint16_t)write_remote.size(), ep); + qp_index, sequence_count, (uint16_t)write_remote.size(), ep, + remote_domain_idx); out: if (status) { NVSHMEMI_ERROR_PRINT( From 002f1518cb30f49767fd0fa249c0ebd115177a0d Mon Sep 17 00:00:00 2001 From: Chang-Ning Tsai Date: Wed, 25 Mar 2026 13:45:54 -0700 Subject: [PATCH 2/2] libfabric: add round-robin to remote NIC selection for RX balance MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The static remote NIC mapping (my_pe + target_pe) % num_proxy_domains distributes RX across remote NICs when multiple senders target the same PE, but does not balance RX when a single sender repeatedly puts to the same destination — all traffic lands on the same remote NIC. Add a remote_ep_cntr that round-robins the remote domain selection on top of the per-sender base offset. This ensures RX is distributed even in single-sender-to-single-receiver patterns (e.g., two-PE case), while still spreading traffic from different senders across different remote NICs. --- src/modules/transport/libfabric/libfabric.cpp | 5 ++++- src/modules/transport/libfabric/libfabric.h | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/modules/transport/libfabric/libfabric.cpp b/src/modules/transport/libfabric/libfabric.cpp index 52bb7312..7b2f7f3b 100644 --- a/src/modules/transport/libfabric/libfabric.cpp +++ b/src/modules/transport/libfabric/libfabric.cpp @@ -116,7 +116,9 @@ int get_next_remote_domain(nvshmemt_libfabric_state_t *state, int qp_index, int if (qp_index == NVSHMEMX_QP_HOST) { return 0; } else { - return ((my_pe + pe) % state->num_proxy_domains) + state->num_host_domains; + int base = (my_pe + pe) % state->num_proxy_domains; + int rr = (state->remote_ep_cntr++) % state->num_proxy_domains; + return ((base + rr) % state->num_proxy_domains) + state->num_host_domains; } } @@ -1639,6 +1641,7 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele /* One-time initializations */ t->max_op_len = UINT64_MAX; state->proxy_ep_cntr = 0; + state->remote_ep_cntr = 0; memset(&cq_attr, 0, sizeof(struct fi_cq_attr)); if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) { diff --git a/src/modules/transport/libfabric/libfabric.h b/src/modules/transport/libfabric/libfabric.h index fb6bb1bb..1d2b50a1 100644 --- a/src/modules/transport/libfabric/libfabric.h +++ b/src/modules/transport/libfabric/libfabric.h @@ -417,6 +417,7 @@ typedef struct { int num_selected_devs; int max_nic_per_pe; std::atomic proxy_ep_cntr; + std::atomic remote_ep_cntr; /* Required for staged_amo */ std::vector> op_queue;