Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@
#include "tensorrt_llm/executor/dataTransceiverState.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/runtime/utils/pgUtils.h"
#include <fstream>
#include <future>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <pybind11/pybind11.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/custom_class.h>
Expand Down Expand Up @@ -294,6 +296,13 @@ class CacheTransceiver : public BaseCacheTransceiver
// TODO(shreyasm): update this to use same container as kv by using base trans buffers instead
std::unique_ptr<rnn_state_manager::RnnCacheTransBufferManager> mRnnCacheTransBufferManager{nullptr};

// Unique instance identifier for CSV file naming (avoids collisions across gen instances)
std::string mInstanceId;

// Gen-side transfer summary CSV (written after timing sync)
std::ofstream mGenTransferSummaryFile;
std::mutex mGenTransferSummaryMutex;

// library handle to the communicator related features,
// this is used to defer dependency resolution until needed.
static std::mutex mDllMutex;
Expand Down
7 changes: 7 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@

namespace kvc = tensorrt_llm::executor::kv_cache;

namespace tensorrt_llm::batch_manager::kv_cache_manager
{
class FabricMemory;
} // namespace tensorrt_llm::batch_manager::kv_cache_manager

namespace tensorrt_llm::batch_manager::eviction_policy
{
class BaseEvictionPolicy;
Expand Down Expand Up @@ -1097,6 +1102,8 @@ class WindowBlockManager
bool mOnboardBlocks;
// Buffer manager
runtime::BufferManager mBufferManager;
// Fabric memory backing for primary pools (MNNVL-capable allocation)
std::vector<std::unique_ptr<kv_cache_manager::FabricMemory>> mFabricMemoryPools;

// Used to keep track of number of free blocks during scheduling
SizeType32 mSchedulingNumFreeBlocks;
Expand Down
246 changes: 205 additions & 41 deletions cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,38 @@
#include "tensorrt_llm/runtime/utils/pgUtils.h"
#include <algorithm>
#include <cstddef>
#include <filesystem>
#include <fstream>
#include <iomanip>
#include <numeric>
#include <random>
#include <sstream>
#include <unordered_set>

namespace tensorrt_llm::batch_manager
{

namespace
{

/// Generate a UUID-like hex string (e.g. "a1b2c3d4-e5f6-7890-abcd-ef1234567890")
/// to uniquely identify a CacheTransceiver instance across gen instances.
/// Produce a random UUID-like identifier in the canonical 8-4-4-4-12 layout
/// (e.g. "a1b2c3d4-e5f6-7890-abcd-ef1234567890"), used to distinguish
/// CacheTransceiver instances across gen instances when naming CSV files.
/// Not a RFC-4122 UUID (no version/variant bits) — uniqueness is probabilistic.
std::string generateInstanceId()
{
    std::random_device seedSource;
    std::mt19937_64 engine(seedSource());
    std::uniform_int_distribution<uint64_t> draw;

    // Render two 64-bit words as one 32-character hex string first...
    std::ostringstream hexStream;
    hexStream << std::hex << std::setfill('0') << std::setw(16) << draw(engine) << std::setw(16) << draw(engine);
    std::string const digits = hexStream.str();

    // ...then splice in dashes at the canonical 8-4-4-4-12 group boundaries.
    std::string uuid;
    uuid.reserve(36);
    std::size_t pos = 0;
    for (int groupLen : {8, 4, 4, 4, 12})
    {
        if (!uuid.empty())
        {
            uuid += '-';
        }
        uuid += digits.substr(pos, static_cast<std::size_t>(groupLen));
        pos += static_cast<std::size_t>(groupLen);
    }
    return uuid;
}

} // anonymous namespace

std::mutex CacheTransceiver::mDllMutex;

std::unique_ptr<BaseCacheTransceiver> CacheTransceiverFactory::createCacheTransceiver(
Expand Down Expand Up @@ -137,6 +163,74 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
mGroupComm = std::make_shared<CacheTransceiverComm>(tensorrt_llm::pg_utils::get_world_pg());
}

// Generate instance ID on rank 0 and broadcast to all ranks in the session
// so every rank in the same gen/ctx instance shares the same ID.
{
if (mGroupComm->getRank() == 0)
{
mInstanceId = generateInstanceId();
}
if (useMPI())
{
int len = static_cast<int>(mInstanceId.size());
tensorrt_llm::mpi::MpiComm::session().bcast(&len, 1, mpi::MpiType::kINT32, 0);
mInstanceId.resize(len);
tensorrt_llm::mpi::MpiComm::session().bcast(mInstanceId.data(), len, mpi::MpiType::kCHAR, 0);
}
else
{
// PG path: rank 0 sends via allgather, others receive.
constexpr int kUuidLen = 36;
std::vector<char> sendBuf(kUuidLen, '\0');
if (mGroupComm->getRank() == 0)
{
std::copy_n(mInstanceId.begin(), std::min<size_t>(mInstanceId.size(), kUuidLen), sendBuf.begin());
}
std::vector<char> recvBuf(kUuidLen * mGroupComm->getSize(), '\0');
mGroupComm->allgather(std::ref(sendBuf), std::ref(recvBuf), {});
// Take rank 0's segment.
mInstanceId = std::string(recvBuf.begin(), recvBuf.begin() + kUuidLen);
}
}

// Calibrate steady_clock across ranks so that cross-node allgather
// in batchUpdateKVCacheTransferBW can compare time points.
// Python's _set_global_steady_clock_offset writes to the nanobind
// module's copy of sGlobalSteadyClockOffset, which is invisible to
// libtensorrt_llm.so (separate inline-static instances across .so
// boundaries). We redo the calibration here in the C++ library.
if (!LlmRequest::sGlobalSteadyClockOffset.has_value())
{
using Duration = LlmRequest::Duration;
using TimePoint = LlmRequest::TimePoint;
// Barrier + take local timestamp
if (useMPI())
{
tensorrt_llm::mpi::MpiComm::session().barrier();
}
auto localNow = std::chrono::steady_clock::now();
auto localNs = std::chrono::duration_cast<std::chrono::nanoseconds>(localNow.time_since_epoch()).count();

// Allgather timestamps from all ranks
std::vector<int64_t> allNs(mGroupComm->getSize(), 0);
if (useMPI())
{
tensorrt_llm::mpi::MpiComm::session().allgather(&localNs, allNs.data(), 1, mpi::MpiType::kINT64);
}
else
{
mGroupComm->allgather(localNs, std::ref(allNs), {});
}

// Offset = rank0's timestamp - my timestamp (same formula as Python)
auto offsetNs = allNs[0] - localNs;
LlmRequest::sGlobalSteadyClockOffset = Duration(offsetNs);

TLLM_LOG_INFO(mGroupComm->getRank(),
"CacheTransceiver: set sGlobalSteadyClockOffset = %.6f sec for rank %d",
static_cast<double>(offsetNs) / 1e9, mGroupComm->getRank());
}

if (worldConfig.isTensorParallel() || worldConfig.isContextParallel())
{
mGroupTensorParaComm = std::make_shared<CacheTransceiverComm>(
Expand Down Expand Up @@ -285,8 +379,10 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
auto makeCacheTransferLayer
= [&]() { return CacheTransferLayer(*mCacheState, makeFormatter(), makeRnnFormatter()); };

mCacheSender = std::make_unique<CacheSender>(mManager.get(), worldConfig.getRank(), makeCacheTransferLayer());
mCacheReceiver = std::make_unique<CacheReceiver>(mManager.get(), worldConfig.getRank(), makeCacheTransferLayer());
mCacheSender
= std::make_unique<CacheSender>(mManager.get(), worldConfig.getRank(), makeCacheTransferLayer(), mInstanceId);
mCacheReceiver
= std::make_unique<CacheReceiver>(mManager.get(), worldConfig.getRank(), makeCacheTransferLayer(), mInstanceId);

initializeCommState();
}
Expand Down Expand Up @@ -416,67 +512,108 @@ std::vector<LlmRequest::RequestIdType> gatherRequestIds(
return retData;
}

void updateKVCacheTransferBW(std::shared_ptr<CacheTransceiverComm> const& mComm, LlmRequest* request)
void batchUpdateKVCacheTransferBW(
std::shared_ptr<CacheTransceiverComm> const& comm, std::vector<LlmRequest*> const& requests)
{
// Key-based merge: each rank serializes (requestId, start, end, size)
// tuples and we use allgatherv so ranks may have different request counts.
// The merge matches by requestId, not by position — this tolerates
// ordering differences and count mismatches across ranks.

namespace su = executor::serialize_utils;
int worldSize = mComm->getSize();
int const worldSize = comm->getSize();

// --- Serialize local entries keyed by requestId ---
std::size_t const numReqs = requests.size();

std::ostringstream oStream;
su::serialize(request->getKvCacheTransferStart(), oStream);
su::serialize(request->getKvCacheTransferEnd(), oStream);
su::serialize(numReqs, oStream);
for (auto* req : requests)
{
su::serialize(req->getContextPhaseParams().value().getReqId(), oStream);
su::serialize(req->getKvCacheTransferStart(), oStream);
su::serialize(req->getKvCacheTransferEnd(), oStream);
su::serialize(req->getKvCacheSize(), oStream);
}

auto str = oStream.str();
std::vector<char> sendBuffer(str.begin(), str.end());
auto sendBufferSize = sendBuffer.size();
auto recvBufferSize = sendBufferSize * worldSize;
std::vector<char> recvBuffer(recvBufferSize);
int const sendSize = static_cast<int>(sendBuffer.size());

// --- Step 1: allgather per-rank buffer sizes ---
std::vector<int> recvCounts(worldSize, 0);
if (useMPI())
{
mComm->allgather(sendBuffer.data(), recvBuffer.data(), sendBufferSize, mpi::MpiType::kCHAR);
comm->allgather(&sendSize, recvCounts.data(), 1, mpi::MpiType::kINT32);
}
else
{
mComm->allgather(std::ref(sendBuffer), std::ref(recvBuffer), {});
comm->allgather(sendSize, std::ref(recvCounts), {});
}

su::VectorWrapBuf<char> strbuf(recvBuffer);
std::istream is(&strbuf);

auto minStartTime = executor::RequestPerfMetrics::TimePoint::max();
auto maxEndTime = executor::RequestPerfMetrics::TimePoint::min();

for (int rank = 0; rank < worldSize; rank++)
// --- Step 2: allgatherv the serialized data ---
std::vector<int> displs(worldSize, 0);
int totalRecvSize = 0;
for (int r = 0; r < worldSize; ++r)
{
minStartTime = std::min(su::deserialize<executor::RequestPerfMetrics::TimePoint>(is), minStartTime);
maxEndTime = std::max(su::deserialize<executor::RequestPerfMetrics::TimePoint>(is), maxEndTime);
displs[r] = totalRecvSize;
totalRecvSize += recvCounts[r];
}

// Handle KV cache size separately - gather all sizes to the leader rank
std::size_t localKVCacheSize = request->getKvCacheSize();
std::vector<std::size_t> allKVCacheSizes(worldSize, 0);
std::vector<char> recvBuffer(totalRecvSize, 0);

if (useMPI())
{
mComm->allgather(&localKVCacheSize, allKVCacheSizes.data(), 1, mpi::MpiType::kUINT64);
comm->allgatherv(
sendBuffer.data(), sendSize, mpi::MpiType::kCHAR, recvBuffer.data(), recvCounts, displs, mpi::MpiType::kCHAR);
}
else
{
mComm->allgather(&localKVCacheSize, std::ref(allKVCacheSizes), {});
comm->allgatherv(std::ref(sendBuffer), std::ref(recvBuffer), recvCounts, {});
}

std::size_t totalKVCacheSize = 0;
for (int rank = 0; rank < worldSize; rank++)
// --- Step 3: Deserialize and merge by requestId ---
using TimePoint = executor::RequestPerfMetrics::TimePoint;
using ReqIdType = LlmRequest::RequestIdType;

struct MergedEntry
{
TimePoint minStart = TimePoint::max();
TimePoint maxEnd = TimePoint::min();
std::size_t totalSize = 0;
};
std::unordered_map<ReqIdType, MergedEntry> merged;

su::VectorWrapBuf<char> strbuf(recvBuffer);
std::istream is(&strbuf);

for (int rank = 0; rank < worldSize; ++rank)
{
totalKVCacheSize += allKVCacheSizes[rank];
auto rankNumReqs = su::deserialize<std::size_t>(is);
for (std::size_t i = 0; i < rankNumReqs; ++i)
{
auto rid = su::deserialize<ReqIdType>(is);
auto start = su::deserialize<TimePoint>(is);
auto end = su::deserialize<TimePoint>(is);
auto size = su::deserialize<std::size_t>(is);

auto& entry = merged[rid];
entry.minStart = std::min(entry.minStart, start);
entry.maxEnd = std::max(entry.maxEnd, end);
entry.totalSize += size;
}
}

// Update the latest KV cache transfer time for leader rank
if (mComm->getRank() == 0)
// --- Step 4: Update local requests ---
for (auto* req : requests)
{
request->setKvCacheTransferStart(minStartTime);
request->setKvCacheTransferEnd(maxEndTime);
request->setKvCacheSize(totalKVCacheSize);
auto reqId = req->getContextPhaseParams().value().getReqId();
auto it = merged.find(reqId);
if (it != merged.end())
{
req->setKvCacheTransferStart(it->second.minStart);
req->setKvCacheTransferEnd(it->second.maxEnd);
req->setKvCacheSize(it->second.totalSize);
}
}
}

Expand Down Expand Up @@ -709,6 +846,8 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastR
" checkGenTransferStatus toCompleteIdSet size: %zu, atLeastRequestNum: %d ", toCompleteIdSet.size(),
atLeastRequestNum.value_or(0));
}
// Phase 1: Wait on futures and collect completed requests.
std::vector<LlmRequest*> completedRequests;
for (auto it = mRequesterFutures.begin(); it != mRequesterFutures.end();)
{
if (blockAll || toCompleteIdSet.find(it->first->mRequestId) != toCompleteIdSet.end())
Expand All @@ -717,13 +856,7 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastR
{
it->second.get();
it->first->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE);

// Gather the kv cache transfer time from all workers and update to leader rank
if (!common::getEnvKVCacheTimeOutputPath().empty())
{
auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupDataComm : mGroupComm;
updateKVCacheTransferBW(syncComm, it->first);
}
completedRequests.push_back(it->first);
}
catch (std::exception const& e)
{
Expand All @@ -750,6 +883,37 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastR
++it;
}
}

// Phase 2: Batch-sync timing across ranks in one allgather (instead of per-request).
if (!completedRequests.empty() && !common::getEnvKVCacheTimeOutputPath().empty())
{
auto bwSyncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupDataComm : mGroupComm;
batchUpdateKVCacheTransferBW(bwSyncComm, completedRequests);

// Write gen-side transfer summary CSV
{
std::lock_guard<std::mutex> lock(mGenTransferSummaryMutex);
if (!mGenTransferSummaryFile.is_open())
{
namespace fs = std::filesystem;
auto outputPath = fs::path(common::getEnvKVCacheTimeOutputPath());
fs::create_directories(outputPath);
int rank
= useMPI() ? mpi::MpiComm::world().getRank() : tensorrt_llm::pg_utils::get_world_pg()->getRank();
auto filePath
= outputPath / (mInstanceId + "_" + std::to_string(rank) + "_gen_transfer_summary.csv");
mGenTransferSummaryFile.open(filePath);
mGenTransferSummaryFile << "RequestID,gen_side_transfer_time(ms),kv_cache_size" << '\n';
}
for (auto* req : completedRequests)
{
auto reqId = req->getContextPhaseParams().value().getReqId();
mGenTransferSummaryFile << reqId << "," << req->getKvCacheTransferTimeMS() << ","
<< req->getKvCacheSize() << '\n';
}
mGenTransferSummaryFile << std::flush;
}
}
}

bool CacheTransceiver::checkGenTransferComplete() const
Expand Down
Loading
Loading