// Patch summary: make the IBGDA RC queue-pair ACK timeout configurable.
//
// What the hunks below do:
//   * env_defs.h: registers a new transport option IBGDA_TIMEOUT (type int,
//     default 22) through NVSHMEMI_ENV_DEF. Per its help text the accepted
//     range is 0-31, 0 means infinite, and the effective timeout is
//     4.096us * 2^timeout.
//   * ibgda.cpp: ibgda_qp_rtr2rts() gains a leading
//     `nvshmemt_ibgda_state_t *ibgda_state` parameter, and the QP context
//     field primary_address_path.ack_timeout is now written from
//     ibgda_state->option->IBGDA_TIMEOUT instead of the literal 20.
//   * ibgda.cpp: the call in ibgda_setup_rc_endpoints() is updated to pass
//     ibgda_state as the new first argument.
//
// Review notes:
//   * Behavior change with the option unset: the value programmed into
//     ack_timeout goes from 20 to 22 (by the help-text formula roughly
//     4.3 s -> 17.2 s).
//   * NOTE(review): no range check on IBGDA_TIMEOUT is visible in these hunks;
//     the int is handed straight to DEVX_SET. Confirm whether values outside
//     0-31 (or negative) are rejected elsewhere.
//   * NOTE(review): the member is spelled `option` here; the definition of
//     nvshmemt_ibgda_state_t is not part of this diff - confirm the name.
//   * NOTE(review): only one call site of ibgda_qp_rtr2rts() is updated in
//     this diff; verify there are no other callers of the old signature.
//   * The original line breaks of this diff appear to have been collapsed;
//     the payload below is kept exactly as received.
diff --git a/src/modules/transport/common/env_defs.h b/src/modules/transport/common/env_defs.h index 8ab5aaa3..9bf9922d 100644 --- a/src/modules/transport/common/env_defs.h +++ b/src/modules/transport/common/env_defs.h @@ -152,6 +152,9 @@ NVSHMEMI_ENV_DEF(IBGDA_NIC_HANDLER, string, "auto", NVSHMEMI_ENV_CAT_TRANSPORT, "- gpu: use GPU SMs.\n" "- cpu: use CPU with gdrcopy backend.\n" "- cpu_host_memory: use CPU with CUDA memory.") +NVSHMEMI_ENV_DEF(IBGDA_TIMEOUT, int, 22, NVSHMEMI_ENV_CAT_TRANSPORT, + "IBGDA ibverbs timeout. Values can be 0-31, 0 means infinite. " + "Actual timeout value = 4.096us * 2^timeout .") NVSHMEMI_ENV_DEF(IB_ENABLE_IBGDA, bool, false, NVSHMEMI_ENV_CAT_TRANSPORT, "Set to enable GPU-initiated communication transport.") #endif diff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp index 3f762f9e..999525c0 100644 --- a/src/modules/transport/ibgda/ibgda.cpp +++ b/src/modules/transport/ibgda/ibgda.cpp @@ -1790,7 +1790,7 @@ static int ibgda_rc_init2rtr(nvshmemt_ibgda_state_t *ibgda_state, struct ibgda_e return status; } -static int ibgda_qp_rtr2rts(struct ibgda_ep *ep, const struct ibgda_device *device, int portid) { +static int ibgda_qp_rtr2rts(nvshmemt_ibgda_state_t *ibgda_state, struct ibgda_ep *ep, const struct ibgda_device *device, int portid) { int status = 0; uint8_t cmd_in[DEVX_ST_SZ_BYTES(rtr2rts_qp_in)] = { @@ -1815,7 +1815,7 @@ static int ibgda_qp_rtr2rts(struct ibgda_ep *ep, const struct ibgda_device *devi DEVX_SET(qpc, qpc, next_send_psn, 0x0); DEVX_SET(qpc, qpc, retry_count, 7); DEVX_SET(qpc, qpc, rnr_retry, 7); - DEVX_SET(qpc, qpc, primary_address_path.ack_timeout, 20); + DEVX_SET(qpc, qpc, primary_address_path.ack_timeout, ibgda_state->option->IBGDA_TIMEOUT); status = mlx5dv_devx_obj_modify(ep->devx_qp, cmd_in, sizeof(cmd_in), cmd_out, sizeof(cmd_out)); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, @@ -3147,7 +3147,7 @@ static int ibgda_setup_rc_endpoints(nvshmemt_ibgda_state_t 
*ibgda_state, NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibgda_rc_init2rtr failed on RC #%d.", ep_index); - status = ibgda_qp_rtr2rts(ibgda_state->rc.eps[ep_index], device, portid); + status = ibgda_qp_rtr2rts(ibgda_state, device->rc.eps[ep_index], device, portid); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibgda_qp_rtr2rts failed on RC #%d.", ep_index); }