diff --git a/csrc/kernels/runtime.cu b/csrc/kernels/runtime.cu
index c4fbb8ed..284638b9 100644
--- a/csrc/kernels/runtime.cu
+++ b/csrc/kernels/runtime.cu
@@ -38,6 +38,28 @@ namespace internode {
 nvshmem_team_t cpu_rdma_team = NVSHMEM_TEAM_INVALID;
 nvshmem_team_config_t cpu_rdma_team_config;
 
+// Process-wide singleton holding a counter that finalize() increments.
+// Copying is disabled; the only instance is the function-local static
+// returned by instance().
+class GlobalState {
+public:
+    // Returns the single instance, constructed on first use.
+    static GlobalState& instance() {
+        static GlobalState inst;
+        return inst;
+    }
+
+    // Starts at 0 and is incremented once per finalize() call.
+    // NOTE(review): plain int64_t, not atomic -- assumes finalize() is never
+    // called concurrently; confirm against callers.
+    int64_t counter;
+
+private:
+    GlobalState() : counter(0) {}
+    GlobalState(const GlobalState&) = delete;
+    GlobalState& operator=(const GlobalState&) = delete;
+};
+
 std::vector get_unique_id() {
     nvshmemx_uniqueid_t unique_id;
     nvshmemx_get_uniqueid(&unique_id);
@@ -53,6 +75,10 @@ int init(const std::vector& root_unique_id_val, int rank, int num_ranks
     nvshmemx_set_attr_uniqueid_args(rank, num_ranks, &root_unique_id, &attr);
     nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
 
+    // Barrier right after NVSHMEM init and before the sub-RDMA team creation
+    // (nvshmem_team_split_strided) below; it replaces the barrier that used to
+    // run just before returning.
+    nvshmem_barrier_all();
+
     // Create sub-RDMA teams
     // NOTES: if `num_ranks <= NUM_MAX_NVL_PEERS` then only low-latency kernels are used
     if (low_latency_mode and num_ranks > NUM_MAX_NVL_PEERS) {
@@ -68,7 +94,6 @@ int init(const std::vector& root_unique_id_val, int rank, int num_ranks
         EP_HOST_ASSERT(cpu_rdma_team != NVSHMEM_TEAM_INVALID);
     }
 
-    nvshmem_barrier_all();
     return nvshmem_my_pe();
 }
 
@@ -89,7 +114,15 @@ void finalize() {
         nvshmem_team_destroy(cpu_rdma_team);
         cpu_rdma_team = NVSHMEM_TEAM_INVALID;
     }
-    nvshmem_finalize();
+
+    // Count this call, then finalize NVSHMEM only from the second call onward.
+    // NOTE(review): the first finalize() therefore never calls
+    // nvshmem_finalize(), and every later call does -- confirm this is intended.
+    GlobalState& state = GlobalState::instance();
+    state.counter++;
+    if (state.counter > 1) {
+        nvshmem_finalize();
+    }
 }
 #endif
 