Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 26 additions & 8 deletions src/modules/transport/libfabric/libfabric.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,16 @@ int get_next_ep(nvshmemt_libfabric_state_t *state, int qp_index) {
}
}

/*
 * Select the index of the remote domain that the next operation to `pe`
 * should target.
 *
 * state    - transport state; reads num_proxy_domains and num_host_domains,
 *            and post-increments remote_ep_cntr on the non-host path.
 * qp_index - queue pair selector; NVSHMEMX_QP_HOST short-circuits to index 0.
 * my_pe    - calling PE, combined with `pe` to form a per-pair starting offset.
 * pe       - target PE.
 *
 * Returns 0 for the host QP. Otherwise returns
 *   ((my_pe + pe) % num_proxy_domains + counter % num_proxy_domains)
 *       % num_proxy_domains + num_host_domains,
 * where `counter` is the value of remote_ep_cntr before this call's increment.
 *
 * NOTE(review): the non-host path divides by num_proxy_domains, so it must be
 * non-zero here -- confirm that initialization guarantees this.
 * NOTE(review): assumes my_pe + pe is non-negative; a negative sum would yield
 * a negative offset -- confirm against callers.
 */
int get_next_remote_domain(nvshmemt_libfabric_state_t *state, int qp_index, int my_pe, int pe) {
    /* Host QP traffic is always mapped to index 0; the counter is untouched. */
    if (qp_index == NVSHMEMX_QP_HOST) return 0;

    /* Starting offset derived from the (my_pe, pe) pair. */
    const int pair_offset = (my_pe + pe) % state->num_proxy_domains;

    /* Advance the shared counter by one and reduce its previous value. */
    const int rotation = (state->remote_ep_cntr++) % state->num_proxy_domains;

    /* Fold both terms back into the proxy range, then skip past the host domains. */
    const int proxy_slot = (pair_offset + rotation) % state->num_proxy_domains;
    return proxy_slot + state->num_host_domains;
}

nvshmemt_libfabric_imm_cq_data_hdr_t get_write_with_imm_hdr(uint64_t imm_data) {
return (nvshmemt_libfabric_imm_cq_data_hdr_t)((uint32_t)imm_data >>
NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_SHIFT);
Expand Down Expand Up @@ -776,7 +786,7 @@ static int nvshmemt_libfabric_show_info(struct nvshmem_transport *transport, int
static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, rma_verb_t verb,
rma_memdesc_t *remote, rma_memdesc_t *local,
rma_bytesdesc_t bytesdesc, int qp_index, uint32_t *imm_data,
nvshmemt_libfabric_endpoint_t &ep) {
nvshmemt_libfabric_endpoint_t &ep, int remote_domain_idx) {
nvshmemt_libfabric_mem_handle_ep_t *remote_handle, *local_handle = NULL;
void *local_mr_desc = NULL;
nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state;
Expand All @@ -795,7 +805,8 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe,

ep_idx = ep.ep_index;
domain_idx = ep.domain_index;
target_ep = pe * libfabric_state->eps.size() + ep_idx;
/* Use remote_domain_idx for target EP addressing and remote MR key selection */
target_ep = pe * libfabric_state->eps.size() + remote_domain_idx;

if (libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) {
nvshmemt_libfabric_gdr_op_ctx_t *gdr_ctx;
Expand All @@ -814,7 +825,8 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe,
}
}

remote_handle = &((nvshmemt_libfabric_mem_handle_t *)remote->handle)->hdls[domain_idx];
/* Use remote_domain_idx to select remote MR key, distributing RX across remote NICs */
remote_handle = &((nvshmemt_libfabric_mem_handle_t *)remote->handle)->hdls[remote_domain_idx];
op_size = bytesdesc.elembytes * bytesdesc.nelems;

if (verb.desc == NVSHMEMI_OP_P) {
Expand Down Expand Up @@ -906,9 +918,10 @@ static int nvshmemt_libfabric_rma(struct nvshmem_transport *tcurr, int pe, rma_v
rma_bytesdesc_t bytesdesc, int qp_index) {
nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state;
int ep_idx = get_next_ep(libfabric_state, qp_index);
int remote_domain_idx = get_next_remote_domain(libfabric_state, qp_index, tcurr->my_pe, pe);
nvshmemt_libfabric_endpoint_t &ep = *(libfabric_state->eps[ep_idx]);
return nvshmemt_libfabric_rma_impl(tcurr, pe, verb, remote, local, bytesdesc, qp_index, NULL,
ep);
ep, remote_domain_idx);
}

static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int pe, void *curetptr,
Expand Down Expand Up @@ -1114,7 +1127,8 @@ static int nvshmemt_libfabric_gdr_signal(struct nvshmem_transport *transport, in
void *curetptr, amo_verb_t verb, amo_memdesc_t *remote,
amo_bytesdesc_t bytesdesc, int qp_index,
uint32_t sequence_count, uint16_t num_writes,
nvshmemt_libfabric_endpoint_t &ep) {
nvshmemt_libfabric_endpoint_t &ep,
int remote_domain_idx) {
nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state;
nvshmemt_libfabric_gdr_op_ctx_t *context;
nvshmemt_libfabric_gdr_signal_op_t *signal;
Expand All @@ -1124,7 +1138,7 @@ static int nvshmemt_libfabric_gdr_signal(struct nvshmem_transport *transport, in
int status = 0;

domain_idx = ep.domain_index;
target_ep = pe * libfabric_state->eps.size() + ep.ep_index;
target_ep = pe * libfabric_state->eps.size() + remote_domain_idx;

static_assert(sizeof(nvshmemt_libfabric_gdr_op_ctx) >=
sizeof(nvshmemt_libfabric_gdr_signal_op_t));
Expand Down Expand Up @@ -1175,6 +1189,7 @@ static int nvshmemt_libfabric_put_signal_unordered(struct nvshmem_transport *tcu
int status = 0;

int ep_idx = get_next_ep(libfabric_state, qp_index);
int remote_domain_idx = get_next_remote_domain(libfabric_state, qp_index, tcurr->my_pe, pe);
nvshmemt_libfabric_endpoint_t &ep = *(libfabric_state->eps[ep_idx]);

/* Get sequence number for this put-signal, with retry */
Expand All @@ -1200,7 +1215,8 @@ static int nvshmemt_libfabric_put_signal_unordered(struct nvshmem_transport *tcu
for (size_t i = 0; i < write_remote.size(); i++) {
status =
nvshmemt_libfabric_rma_impl(tcurr, pe, write_verb, &write_remote[i], &write_local[i],
write_bytes_desc[i], qp_index, &sequence_count, ep);
write_bytes_desc[i], qp_index, &sequence_count, ep,
remote_domain_idx);
if (unlikely(status)) {
NVSHMEMI_ERROR_PRINT(
"Error in nvshmemt_put_signal_unordered, could not submit write #%lu\n", i);
Expand All @@ -1211,7 +1227,8 @@ static int nvshmemt_libfabric_put_signal_unordered(struct nvshmem_transport *tcu
assert(use_staged_atomics == true);
status =
nvshmemt_libfabric_gdr_signal(tcurr, pe, NULL, sig_verb, sig_target, sig_bytes_desc,
qp_index, sequence_count, (uint16_t)write_remote.size(), ep);
qp_index, sequence_count, (uint16_t)write_remote.size(), ep,
remote_domain_idx);
out:
if (status) {
NVSHMEMI_ERROR_PRINT(
Expand Down Expand Up @@ -1624,6 +1641,7 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele
/* One-time initializations */
t->max_op_len = UINT64_MAX;
state->proxy_ep_cntr = 0;
state->remote_ep_cntr = 0;

memset(&cq_attr, 0, sizeof(struct fi_cq_attr));
if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) {
Expand Down
1 change: 1 addition & 0 deletions src/modules/transport/libfabric/libfabric.h
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ typedef struct {
int num_selected_devs;
int max_nic_per_pe;
std::atomic<uint32_t> proxy_ep_cntr;
std::atomic<uint32_t> remote_ep_cntr;

/* Required for staged_amo */
std::vector<std::unique_ptr<threadSafeOpQueue>> op_queue;
Expand Down