Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions csrc/kernels/runtime.cu
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,26 @@ namespace internode {
nvshmem_team_t cpu_rdma_team = NVSHMEM_TEAM_INVALID;
nvshmem_team_config_t cpu_rdma_team_config;

#ifndef GLOBALS_H
#define GLOBALS_H

// Singleton holder for a single integer counter.
//
// The one instance is a function-local static inside instance(), so it is
// constructed on first use. The type cannot be copied, and its constructor is
// private, so instance() is the only way to obtain the object.
class GlobalState {
public:
    // Copying is disabled so that only the instance() object can exist.
    GlobalState(const GlobalState&) = delete;
    GlobalState& operator=(const GlobalState&) = delete;

    // Returns a reference to the single shared object, creating it on first call.
    static GlobalState& instance() {
        static GlobalState state;
        return state;
    }

    // Publicly readable/writable counter; starts at zero.
    int64_t counter = 0;

private:
    // Deliberately user-provided (not `= default`) so the class is never an
    // aggregate and cannot be brace-initialized from outside.
    GlobalState() {}
};

#endif

std::vector<uint8_t> get_unique_id() {
nvshmemx_uniqueid_t unique_id;
nvshmemx_get_uniqueid(&unique_id);
Expand All @@ -53,6 +73,9 @@ int init(const std::vector<uint8_t>& root_unique_id_val, int rank, int num_ranks
nvshmemx_set_attr_uniqueid_args(rank, num_ranks, &root_unique_id, &attr);
nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);

// Barrier across all PEs before nvshmem_team_split_strided
nvshmem_barrier_all();

// Create sub-RDMA teams
// NOTES: if `num_ranks <= NUM_MAX_NVL_PEERS` then only low-latency kernels are used
if (low_latency_mode and num_ranks > NUM_MAX_NVL_PEERS) {
Expand All @@ -68,7 +91,7 @@ int init(const std::vector<uint8_t>& root_unique_id_val, int rank, int num_ranks
EP_HOST_ASSERT(cpu_rdma_team != NVSHMEM_TEAM_INVALID);
}

nvshmem_barrier_all();
// nvshmem_barrier_all();
return nvshmem_my_pe();
}

Expand All @@ -89,7 +112,12 @@ void finalize() {
nvshmem_team_destroy(cpu_rdma_team);
cpu_rdma_team = NVSHMEM_TEAM_INVALID;
}
nvshmem_finalize();

GlobalState::instance().counter++;
if (GlobalState::instance().counter > 1) {
nvshmem_finalize();
}
// nvshmem_finalize();
}
#endif

Expand Down