From 91935c554cb8265c4b7bdd20065f9a0e961c05d8 Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Tue, 27 Jan 2026 20:17:25 +0800 Subject: [PATCH 1/2] fix deepep init / finalize --- csrc/deep_ep.cpp | 7 ++++++- csrc/deep_ep.hpp | 21 +++++++++++++++++++++ csrc/kernels/runtime.cu | 5 ++++- 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp index 8a8b121e..72bb7fd2 100644 --- a/csrc/deep_ep.cpp +++ b/csrc/deep_ep.cpp @@ -322,7 +322,12 @@ void Buffer::destroy() { internode::free(mask_buffer_ptr); internode::free(sync_buffer_ptr); } - internode::finalize(); + // internode::finalize(); + + GlobalState::instance().counter++; + if (GlobalState::instance().counter > 1) { + internode::finalize(); + } } #endif diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp index cd5be4a3..8bf9d465 100644 --- a/csrc/deep_ep.hpp +++ b/csrc/deep_ep.hpp @@ -25,6 +25,27 @@ #define TORCH_EXTENSION_NAME deep_ep_cpp #endif +#ifndef GLOBALS_H +#define GLOBALS_H + +class GlobalState { +public: + static GlobalState& instance() { + static GlobalState inst; + return inst; + } + + int64_t counter; + +private: + GlobalState() : counter(0) {} + GlobalState(const GlobalState&) = delete; + GlobalState& operator=(const GlobalState&) = delete; +}; + +#endif + + namespace shared_memory { union MemHandleInner { diff --git a/csrc/kernels/runtime.cu b/csrc/kernels/runtime.cu index c4fbb8ed..2a823f8f 100644 --- a/csrc/kernels/runtime.cu +++ b/csrc/kernels/runtime.cu @@ -53,6 +53,9 @@ int init(const std::vector<uint8_t>& root_unique_id_val, int rank, int num_ranks nvshmemx_set_attr_uniqueid_args(rank, num_ranks, &root_unique_id, &attr); nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr); + // Initialize before nvshmem_team_split_strided + nvshmem_barrier_all(); + // Create sub-RDMA teams // NOTES: if `num_ranks <= NUM_MAX_NVL_PEERS` then only low-latency kernels are used if (low_latency_mode and num_ranks > NUM_MAX_NVL_PEERS) { @@ -68,7 +71,7 @@ int init(const std::vector<uint8_t>&
root_unique_id_val, int rank, int num_ranks EP_HOST_ASSERT(cpu_rdma_team != NVSHMEM_TEAM_INVALID); } - nvshmem_barrier_all(); + // nvshmem_barrier_all(); return nvshmem_my_pe(); } From f33d8a36964d3e43045923ae00f713d88b672b14 Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Mon, 2 Feb 2026 10:46:54 +0800 Subject: [PATCH 2/2] fix --- csrc/deep_ep.cpp | 7 +------ csrc/deep_ep.hpp | 21 --------------------- csrc/kernels/runtime.cu | 27 ++++++++++++++++++++++++++- 3 files changed, 27 insertions(+), 28 deletions(-) diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp index 72bb7fd2..8a8b121e 100644 --- a/csrc/deep_ep.cpp +++ b/csrc/deep_ep.cpp @@ -322,12 +322,7 @@ void Buffer::destroy() { internode::free(mask_buffer_ptr); internode::free(sync_buffer_ptr); } - // internode::finalize(); - - GlobalState::instance().counter++; - if (GlobalState::instance().counter > 1) { - internode::finalize(); - } + internode::finalize(); } #endif diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp index 8bf9d465..cd5be4a3 100644 --- a/csrc/deep_ep.hpp +++ b/csrc/deep_ep.hpp @@ -25,27 +25,6 @@ #define TORCH_EXTENSION_NAME deep_ep_cpp #endif -#ifndef GLOBALS_H -#define GLOBALS_H - -class GlobalState { -public: - static GlobalState& instance() { - static GlobalState inst; - return inst; - } - - int64_t counter; - -private: - GlobalState() : counter(0) {} - GlobalState(const GlobalState&) = delete; - GlobalState& operator=(const GlobalState&) = delete; -}; - -#endif - - namespace shared_memory { union MemHandleInner { diff --git a/csrc/kernels/runtime.cu b/csrc/kernels/runtime.cu index 2a823f8f..284638b9 100644 --- a/csrc/kernels/runtime.cu +++ b/csrc/kernels/runtime.cu @@ -38,6 +38,26 @@ namespace internode { nvshmem_team_t cpu_rdma_team = NVSHMEM_TEAM_INVALID; nvshmem_team_config_t cpu_rdma_team_config; +#ifndef GLOBALS_H +#define GLOBALS_H + +class GlobalState { +public: + static GlobalState& instance() { + static GlobalState inst; + return inst; + } + + int64_t counter; + +private: + 
GlobalState() : counter(0) {} + GlobalState(const GlobalState&) = delete; + GlobalState& operator=(const GlobalState&) = delete; +}; + +#endif + std::vector<uint8_t> get_unique_id() { nvshmemx_uniqueid_t unique_id; nvshmemx_get_uniqueid(&unique_id); @@ -92,7 +112,12 @@ void finalize() { nvshmem_team_destroy(cpu_rdma_team); cpu_rdma_team = NVSHMEM_TEAM_INVALID; } - nvshmem_finalize(); + + GlobalState::instance().counter++; + if (GlobalState::instance().counter > 1) { + nvshmem_finalize(); + } + // nvshmem_finalize(); } #endif