Skip to content

Commit a9a7f73

Browse files
committed
[None][fix] Batch addSequence with pre-claim to fix host offloading MNT overflow
When host offloading is enabled, onboarding a host block to GPU during addSequence can trigger eviction of other reusable host blocks from the radix tree. This causes actual KV cache reuse to be less than the scheduler estimated, leading to max_num_tokens (MNT) overflow assertions. Add a new addSequenceBatch API that processes all first-chunk context requests in two phases: - Phase 1: Walk the radix tree and claimBlock() for all matching blocks across all requests. No onboarding, no allocation. This protects reusable blocks from eviction. - Phase 2: Onboard host blocks and allocate non-matching blocks. Since all reusable blocks are already claimed, evictions during onboarding cannot touch them. On the Python side, replace the TOCTOU-prone revalidation loop (count_reusable_blocks + budget check) with a single batch call. Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
1 parent 3e942cc commit a9a7f73

4 files changed

Lines changed: 554 additions & 13 deletions

File tree

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,55 @@ class WindowBlockManager
735735
//! \brief Assign blocks for new sequence. Does not try to reuse blocks.
736736
void addSequence(GenerationRequest& sequence, SizeType32 numContextBlocks, bool isShareLastContextBlock);
737737

738+
//! \brief Per-request block allocation statistics from batch addSequence.
739+
struct BatchSeqStats
740+
{
741+
SizeType32 prepopulatedLen{0};
742+
SizeType32 allocTotalDelta{0};
743+
SizeType32 allocNewDelta{0};
744+
SizeType32 reusedDelta{0};
745+
SizeType32 missedDelta{0};
746+
};
747+
748+
//! \brief Result of Phase 1 (claim-only) of batch addSequence.
749+
//! \details Holds matched blocks and prepared data so Phase 2 can proceed without
750+
//! re-traversing the radix tree.
751+
struct ClaimResult
752+
{
753+
struct ClaimedBlock
754+
{
755+
BlockPtr block;
756+
SizeType32 numMatchedTokens; //!< tokens matched in this block
757+
bool isPartialMatch;
758+
bool needsCopy; //!< partial match on block with refs or non-leaf (needs getFreeBlock + copy in Phase 2)
759+
bool isPlaceholder; //!< placeholder block (linear attention recurrent states)
760+
};
761+
762+
std::vector<ClaimedBlock> claimedBlocks;
763+
SizeType32 totalMatchedTokens{0};
764+
SizeType32 latestMatchingNonPlaceholderBlockIdx{-1};
765+
SizeType32 numSharedContextBlocks{0};
766+
SizeType32 numContextBlocks{0};
767+
bool shareLastContextBlockAmongBeams{true};
768+
std::vector<BlockKey> blockKeys;
769+
std::vector<executor::RetentionPriorityAndDuration> perBlockRetentions;
770+
executor::KvCacheTransferMode mode{executor::KvCacheTransferMode::DRAM};
771+
std::string directory;
772+
};
773+
774+
//! \brief Batch add sequences with two-phase claim-then-onboard under a single lock.
775+
//! \details Phase 1 claims all matching blocks across all requests (protecting from eviction).
776+
//! Phase 2 onboards host blocks and allocates non-matching blocks.
777+
//! The mCachedBlocksRootMutex is held for the entire operation.
778+
//! \param sequences Per-request GenerationRequest pointers (parallel with the other per-request vectors).
779+
//! \param inputLengths Per-request effective input length.
780+
//! \param numContextBlocksVec Per-request number of context blocks.
781+
//! \param llmRequests Per-request LlmRequest references.
782+
//! \return Per-request BatchSeqStats, including prepopulatedPromptLen and allocation/reuse deltas.
783+
[[nodiscard]] std::vector<BatchSeqStats> addSequenceBatch(std::vector<GenerationRequest*> const& sequences,
784+
std::vector<SizeType32> const& inputLengths, std::vector<SizeType32> const& numContextBlocksVec,
785+
std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests);
786+
738787
//! \brief Allocate new block for each beam of the sequence.
739788
//! \details Might free cached blocks if no free blocks are available.
740789
void allocateBlock(GenerationRequest& sequence, bool shareAmongBeams);
@@ -1048,6 +1097,16 @@ class WindowBlockManager
10481097
bool shareLastContextBlockAmongBeams, executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM,
10491098
std::string const& directory = "");
10501099

1100+
//! \brief Phase 1 (claim-only; takes no lock itself): Walk radix tree and claim matching blocks.
1101+
//! \details Caller must hold mCachedBlocksRootMutex.
1102+
[[nodiscard]] ClaimResult claimMatchingBlocks(
1103+
GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest);
1104+
1105+
//! \brief Phase 2 (takes no lock itself): Onboard claimed host blocks and allocate non-matching blocks.
1106+
//! \details Caller must hold mCachedBlocksRootMutex.
1107+
[[nodiscard]] SizeType32 onboardAndAllocateBlocks(
1108+
GenerationRequest& sequence, LlmRequest& llmRequest, ClaimResult& claimResult);
1109+
10511110
//! \brief Free block and all its descendants. This makes block a claimed leaf block.
10521111
void freeChildren(BlockPtr const& block);
10531112

@@ -1242,6 +1301,12 @@ class BlockManager
12421301
void addSequence(
12431302
GenerationRequest& sequence, SizeType32 numContextBlocks, SizeType32 windowSize, bool isShareLastContextBlock);
12441303

1304+
//! \brief Batch add sequences forwarding to WindowBlockManager::addSequenceBatch.
1305+
[[nodiscard]] std::vector<WindowBlockManager::BatchSeqStats> addSequenceBatch(
1306+
std::vector<GenerationRequest*> const& sequences, std::vector<SizeType32> const& inputLengths,
1307+
std::vector<SizeType32> const& numContextBlocksVec,
1308+
std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests, SizeType32 windowSize);
1309+
12451310
void allocateBlock(GenerationRequest& sequence, SizeType32 windowSize);
12461311

12471312
//! \brief According to request's current position, copy data from the last full block to the next block (ignoring
@@ -1732,6 +1797,15 @@ class BaseKVCacheManager
17321797
OptionalRef<LlmRequest> llmRequest = std::nullopt)
17331798
= 0;
17341799

1800+
//! \brief Batch add sequences with two-phase claim-then-onboard to prevent host offloading eviction.
1801+
//! \details Phase 1 claims all matching blocks across all requests (protecting them from eviction).
1802+
//! Phase 2 onboards host blocks and allocates non-matching blocks.
1803+
//! Requires block reuse enabled and single attention window.
1804+
virtual void addSequenceBatch(
1805+
std::vector<std::tuple<LlmRequest::RequestIdType, SizeType32, SizeType32>> const& requestInfos,
1806+
std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests)
1807+
= 0;
1808+
17351809
[[nodiscard]] virtual std::optional<KVCacheBlock::IdType> removeSequence(LlmRequest::RequestIdType requestId,
17361810
OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinOnRelease = false)
17371811
= 0;
@@ -2102,6 +2176,10 @@ class KVCacheManager : public BaseKVCacheManager
21022176
void addSequence(LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth,
21032177
OptionalRef<LlmRequest> llmRequest = std::nullopt) override;
21042178

2179+
void addSequenceBatch(
2180+
std::vector<std::tuple<LlmRequest::RequestIdType, SizeType32, SizeType32>> const& requestInfos,
2181+
std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests) override;
2182+
21052183
[[nodiscard]] std::optional<KVCacheBlock::IdType> removeSequence(LlmRequest::RequestIdType requestId,
21062184
OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinOnRelease = false) override;
21072185

0 commit comments

Comments
 (0)