diff --git a/src/modules/transport/libfabric/libfabric.cpp b/src/modules/transport/libfabric/libfabric.cpp
index d76b85fb..7b2f7f3b 100644
--- a/src/modules/transport/libfabric/libfabric.cpp
+++ b/src/modules/transport/libfabric/libfabric.cpp
@@ -112,6 +112,30 @@ int get_next_ep(nvshmemt_libfabric_state_t *state, int qp_index) {
     }
 }
 
+/*
+ * Choose which of the target PE's domains an operation should be addressed to.
+ *
+ * Operations on NVSHMEMX_QP_HOST always get domain 0. Every other operation
+ * rotates over the proxy domains: the starting slot is (my_pe + pe) modulo
+ * num_proxy_domains, and each call post-increments state->remote_ep_cntr to
+ * step to the next slot. The chosen slot is offset by num_host_domains.
+ *
+ * Also returns 0 when num_proxy_domains is not positive, rather than taking a
+ * remainder by zero.
+ */
+int get_next_remote_domain(nvshmemt_libfabric_state_t *state, int qp_index, int my_pe, int pe) {
+    if (qp_index == NVSHMEMX_QP_HOST || state->num_proxy_domains <= 0) {
+        return 0;
+    }
+
+    /* Reduce in unsigned arithmetic: should the counter be a signed type that wraps negative,
+     * a signed remainder would go negative and push the result outside the proxy range. */
+    const unsigned n = (unsigned)state->num_proxy_domains;
+    const unsigned base = (unsigned)(my_pe + pe) % n;
+    const unsigned rr = (unsigned)(state->remote_ep_cntr++ % n);
+    return (int)((base + rr) % n) + state->num_host_domains;
+}
+
 nvshmemt_libfabric_imm_cq_data_hdr_t get_write_with_imm_hdr(uint64_t imm_data) {
     return (nvshmemt_libfabric_imm_cq_data_hdr_t)((uint32_t)imm_data >>
                                                   NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_SHIFT);
@@ -776,7 +800,7 @@ static int nvshmemt_libfabric_show_info(struct nvshmem_transport *transport, int
 static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, rma_verb_t verb,
                                        rma_memdesc_t *remote, rma_memdesc_t *local,
                                        rma_bytesdesc_t bytesdesc, int qp_index, uint32_t *imm_data,
-                                       nvshmemt_libfabric_endpoint_t &ep) {
+                                       nvshmemt_libfabric_endpoint_t &ep, int remote_domain_idx) {
     nvshmemt_libfabric_mem_handle_ep_t *remote_handle, *local_handle = NULL;
     void *local_mr_desc = NULL;
     nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state;
@@ -795,7 +819,8 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe,
 
     ep_idx = ep.ep_index;
     domain_idx = ep.domain_index;
-    target_ep = pe * libfabric_state->eps.size() + ep_idx;
+    /* Peer EP slot from remote_domain_idx. NOTE(review): assumes EP index == domain index. */
+    target_ep = pe * libfabric_state->eps.size() + remote_domain_idx;
 
     if (libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) {
         nvshmemt_libfabric_gdr_op_ctx_t *gdr_ctx;
@@ -814,7 +839,8 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe,
         }
     }
 
-    remote_handle = &((nvshmemt_libfabric_mem_handle_t *)remote->handle)->hdls[domain_idx];
+    /* Use remote_domain_idx to select remote MR key, distributing RX across remote NICs */
+    remote_handle = &((nvshmemt_libfabric_mem_handle_t *)remote->handle)->hdls[remote_domain_idx];
     op_size = bytesdesc.elembytes * bytesdesc.nelems;
 
     if (verb.desc == NVSHMEMI_OP_P) {
@@ -906,9 +932,10 @@ static int nvshmemt_libfabric_rma(struct nvshmem_transport *tcurr, int pe, rma_v
                                   rma_bytesdesc_t bytesdesc, int qp_index) {
     nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state;
     int ep_idx = get_next_ep(libfabric_state, qp_index);
+    int remote_domain_idx = get_next_remote_domain(libfabric_state, qp_index, tcurr->my_pe, pe);
     nvshmemt_libfabric_endpoint_t &ep = *(libfabric_state->eps[ep_idx]);
     return nvshmemt_libfabric_rma_impl(tcurr, pe, verb, remote, local, bytesdesc, qp_index, NULL,
-                                       ep);
+                                       ep, remote_domain_idx);
 }
 
 static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int pe, void *curetptr,
@@ -1114,7 +1141,8 @@ static int nvshmemt_libfabric_gdr_signal(struct nvshmem_transport *transport, in
                                          void *curetptr, amo_verb_t verb, amo_memdesc_t *remote,
                                          amo_bytesdesc_t bytesdesc, int qp_index,
                                          uint32_t sequence_count, uint16_t num_writes,
-                                         nvshmemt_libfabric_endpoint_t &ep) {
+                                         nvshmemt_libfabric_endpoint_t &ep,
+                                         int remote_domain_idx) {
     nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state;
     nvshmemt_libfabric_gdr_op_ctx_t *context;
     nvshmemt_libfabric_gdr_signal_op_t *signal;
@@ -1124,7 +1152,7 @@ static int nvshmemt_libfabric_gdr_signal(struct nvshmem_transport *transport, in
     int status = 0;
 
     domain_idx = ep.domain_index;
-    target_ep = pe * libfabric_state->eps.size() + ep.ep_index;
+    target_ep = pe * libfabric_state->eps.size() + remote_domain_idx;
 
     static_assert(sizeof(nvshmemt_libfabric_gdr_op_ctx) >=
                   sizeof(nvshmemt_libfabric_gdr_signal_op_t));
@@ -1175,6 +1203,7 @@ static int nvshmemt_libfabric_put_signal_unordered(struct nvshmem_transport *tcu
     int status = 0;
 
     int ep_idx = get_next_ep(libfabric_state, qp_index);
+    int remote_domain_idx = get_next_remote_domain(libfabric_state, qp_index, tcurr->my_pe, pe);
     nvshmemt_libfabric_endpoint_t &ep = *(libfabric_state->eps[ep_idx]);
 
     /* Get sequence number for this put-signal, with retry */
@@ -1200,7 +1229,8 @@ static int nvshmemt_libfabric_put_signal_unordered(struct nvshmem_transport *tcu
     for (size_t i = 0; i < write_remote.size(); i++) {
         status =
             nvshmemt_libfabric_rma_impl(tcurr, pe, write_verb, &write_remote[i], &write_local[i],
-                                        write_bytes_desc[i], qp_index, &sequence_count, ep);
+                                        write_bytes_desc[i], qp_index, &sequence_count, ep,
+                                        remote_domain_idx);
         if (unlikely(status)) {
             NVSHMEMI_ERROR_PRINT(
                 "Error in nvshmemt_put_signal_unordered, could not submit write #%lu\n", i);
@@ -1211,7 +1241,8 @@ static int nvshmemt_libfabric_put_signal_unordered(struct nvshmem_transport *tcu
 
     assert(use_staged_atomics == true);
     status = nvshmemt_libfabric_gdr_signal(tcurr, pe, NULL, sig_verb, sig_target, sig_bytes_desc,
-                                           qp_index, sequence_count, (uint16_t)write_remote.size(), ep);
+                                           qp_index, sequence_count, (uint16_t)write_remote.size(), ep,
+                                           remote_domain_idx);
 out:
     if (status) {
         NVSHMEMI_ERROR_PRINT(
@@ -1624,6 +1655,7 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele
     /* One-time initializations */
     t->max_op_len = UINT64_MAX;
     state->proxy_ep_cntr = 0;
+    state->remote_ep_cntr = 0;
     memset(&cq_attr, 0, sizeof(struct fi_cq_attr));
 
     if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) {
diff --git a/src/modules/transport/libfabric/libfabric.h b/src/modules/transport/libfabric/libfabric.h
index fb6bb1bb..1d2b50a1 100644
--- a/src/modules/transport/libfabric/libfabric.h
+++ b/src/modules/transport/libfabric/libfabric.h
@@ -417,6 +417,7 @@ typedef struct {
     int num_selected_devs;
     int max_nic_per_pe;
     std::atomic proxy_ep_cntr;
+    std::atomic remote_ep_cntr;
 
     /* Required for staged_amo */
     std::vector> op_queue;