Skip to content
108 changes: 108 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,72 @@ class WindowBlockManager
//! \brief Add a single sequence: match/reuse cached blocks (when enabled) and allocate the rest.
//! \return presumably the prepopulatedPromptLen (tokens already covered by reused blocks),
//!         mirroring addSequenceBatch's per-request result — confirm against the definition.
[[nodiscard]] SizeType32 addSequence(GenerationRequest& sequence, SizeType32 inputLength,
SizeType32 numContextBlocks, LlmRequest& llmRequest, bool isEnableBlockReuse);

//! \brief Per-request block allocation statistics from batch addSequence.
//! NOTE(review): field semantics below are inferred from names — confirm against the
//! addSequenceBatch implementation before relying on them.
struct BatchSeqStats
{
    SizeType32 prepopulatedLen{0};   //!< prepopulated prompt length: tokens already covered by matched blocks
    SizeType32 allocTotalDelta{0};   //!< increment to the total-allocated-blocks counter for this request
    SizeType32 allocNewDelta{0};     //!< increment to the newly-allocated-blocks counter
    SizeType32 reusedDelta{0};       //!< increment to the reused-blocks counter
    SizeType32 missedDelta{0};       //!< increment to the missed (not-reused) blocks counter
};

//! \brief Result of Phase 1 (claim-only) of batch addSequence.
//! \details Holds matched blocks and prepared data so Phase 2 can proceed without
//! re-traversing the radix tree.
struct ClaimResult
{
    //! One cached block claimed during Phase 1.
    struct ClaimedBlock
    {
        BlockPtr block;              //!< the claimed block
        SizeType32 numMatchedTokens; //!< tokens matched in this block
        bool isPartialMatch;         //!< presumably true when only part of the block's tokens matched — confirm
        bool needsCopy; //!< partial match on block with refs or non-leaf (needs getFreeBlock + copy in Phase 2)
        bool isPlaceholder; //!< placeholder block (linear attention recurrent states)
        bool shouldReleaseCopySource{false}; //!< last copier releases the claimed source after copy
    };

    // NOTE(review): trailing-field comments below are inferred from names/types — verify against the .cpp.
    std::vector<ClaimedBlock> claimedBlocks;             //!< blocks claimed for this request, in match order
    SizeType32 totalMatchedTokens{0};                    //!< sum of numMatchedTokens over claimedBlocks
    SizeType32 latestMatchingNonPlaceholderBlockIdx{-1}; //!< index of last non-placeholder match; -1 if none
    SizeType32 numSharedContextBlocks{0};                //!< context blocks shared among beams
    SizeType32 numContextBlocks{0};                      //!< total context blocks for this request
    bool shareLastContextBlockAmongBeams{true};          //!< whether the final context block is beam-shared
    std::vector<BlockKey> blockKeys;                     //!< per-block lookup keys for the context tokens
    std::vector<executor::RetentionPriorityAndDuration> perBlockRetentions; //!< per-block retention settings
    executor::KvCacheTransferMode mode{executor::KvCacheTransferMode::DRAM}; //!< transfer mode for onboarding
    std::string directory;                               //!< directory for file-backed transfer modes
};

//! \brief Tracks which request currently "owns" a partially-matched leaf block across
//! the batch Phase 1 loop, so that at most one request reuses the block in-place
//! while all others copy.
struct PartialClaimTracker
{
    //! Ownership record for one partially-matched block.
    struct Entry
    {
        size_t requestIdx; //!< index of the request that currently owns the reuse
        size_t claimedIdx; //!< index into that request's claimedBlocks vector
        bool fullyMatched; //!< true once any request fully matches this block
    };

    //! Keyed by block ID.
    std::unordered_map<KVCacheBlock::IdType, Entry> map;
};

//! \brief Batch add sequences with two-phase claim-then-onboard under a single lock.
//! \details Phase 1 claims all matching blocks across all requests (protecting from eviction).
//! Phase 2 onboards host blocks and allocates non-matching blocks.
//! The mCachedBlocksRootMutex is held for the entire operation.
//! \param sequences Per-request GenerationRequest references (parallel with other vectors).
//! \param inputLengths Per-request effective input length.
//! \param numContextBlocksVec Per-request number of context blocks.
//! \param llmRequests Per-request LlmRequest references.
//! \param isEnableBlockReuse Whether cached-block reuse is enabled (claim path vs. fresh-allocation path).
//! \return Per-request BatchSeqStats (prepopulated prompt length plus allocation counter deltas).
[[nodiscard]] std::vector<BatchSeqStats> addSequenceBatch(std::vector<GenerationRequest*> const& sequences,
std::vector<SizeType32> const& inputLengths, std::vector<SizeType32> const& numContextBlocksVec,
std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests, bool isEnableBlockReuse);

//! \brief Allocate new block for each beam of the sequence.
//! \details Might free cached blocks if no free blocks are available.
//! \param sequence The sequence to allocate blocks for.
//! \param shareAmongBeams presumably allocates a single block shared by all beams instead of one
//!        per beam — confirm against the definition.
void allocateBlock(GenerationRequest& sequence, bool shareAmongBeams);
Expand Down Expand Up @@ -1087,6 +1153,25 @@ class WindowBlockManager
executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = "",
bool isEnableBlockReuse = false);

//! \brief Phase 1: Walk radix tree and claim matching blocks.
//! \details Caller must hold mCachedBlocksRootMutex.
//! Uses \p tracker to coordinate partial-match ownership across requests in
//! the same batch. \p claimResults is the full vector so that a previous
//! request's ClaimedBlock can be retroactively marked needsCopy.
//! \return ClaimResult holding the claimed blocks and the metadata Phase 2 needs.
[[nodiscard]] ClaimResult claimMatchingBlocks(GenerationRequest& sequence, SizeType32 inputLength,
SizeType32 numContextBlocks, LlmRequest& llmRequest, size_t requestIdx, PartialClaimTracker& tracker,
std::vector<ClaimResult>& claimResults);

//! \brief Build ClaimResult metadata without walking the radix tree.
//! \details Used for non-reuse path where all blocks are freshly allocated.
//! \return ClaimResult with no claimed blocks, carrying only the metadata Phase 2 needs.
[[nodiscard]] ClaimResult buildClaimResultMetadata(
GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest);

//! \brief Phase 2: Onboard claimed host blocks and allocate non-matching blocks.
//! \details Caller must hold mCachedBlocksRootMutex.
//! \return presumably the prepopulated prompt length for the sequence — confirm against the definition.
[[nodiscard]] SizeType32 onboardAndAllocateBlocks(
GenerationRequest& sequence, LlmRequest& llmRequest, ClaimResult& claimResult, bool isEnableBlockReuse);

//! \brief Free block and all its descendants. This makes block a claimed leaf block.
void freeChildren(BlockPtr const& block);

Expand Down Expand Up @@ -1286,6 +1371,13 @@ class BlockManager
//! \brief Add a single sequence for the given attention window size.
//! \param windowSize Attention window size selecting the target WindowBlockManager.
//! \return presumably the prepopulatedPromptLen for the sequence — confirm against the definition.
[[nodiscard]] SizeType32 addSequence(GenerationRequest& sequence, SizeType32 inputLength,
SizeType32 numContextBlocks, LlmRequest& llmRequest, SizeType32 windowSize, bool isEnableBlockReuse);

//! \brief Batch add sequences forwarding to WindowBlockManager::addSequenceBatch.
//! \param windowSize Attention window size selecting the target WindowBlockManager.
//! \return Per-request BatchSeqStats from the selected window's batch addSequence.
[[nodiscard]] std::vector<WindowBlockManager::BatchSeqStats> addSequenceBatch(
std::vector<GenerationRequest*> const& sequences, std::vector<SizeType32> const& inputLengths,
std::vector<SizeType32> const& numContextBlocksVec,
std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests, SizeType32 windowSize,
bool isEnableBlockReuse);

//! \brief Allocate a new block for the sequence within the given attention window.
void allocateBlock(GenerationRequest& sequence, SizeType32 windowSize);

//! \brief According to request's current position, copy data from the last full block to the next block (ignoring
Expand Down Expand Up @@ -1793,6 +1885,18 @@ class BaseKVCacheManager
OptionalRef<LlmRequest> llmRequest = std::nullopt)
= 0;

//! \brief Batch add sequences with two-phase claim-then-onboard strategy.
//! \details For each attention window, when block reuse is enabled, Phase 1 claims all matching
//! blocks across all requests (protecting them from eviction via PartialClaimTracker),
//! then Phase 2 onboards host blocks and allocates non-matching blocks. When block reuse
//! is disabled, buildClaimResultMetadata() prepares ClaimResult metadata without radix
//! tree traversal, and Phase 2 performs fresh allocation only. Supports variable sliding
//! window attention (VSWA) by iterating over all window sizes.
//! \param requestInfos Per-request tuples — presumably (requestId, inputLength, beamWidth),
//!        mirroring addSequence's parameters; confirm against the implementation.
//! \param llmRequests Per-request LlmRequest references (parallel with requestInfos).
virtual void addSequenceBatch(
std::vector<std::tuple<LlmRequest::RequestIdType, SizeType32, SizeType32>> const& requestInfos,
std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests)
= 0;

//! \brief Remove a sequence from the KV cache manager.
//! \param pinOnRelease NOTE(review): presumably keeps the sequence's blocks pinned instead of
//!        freeing them on release — confirm against the implementation.
//! \return Optional block ID; meaning not evident from this declaration — TODO document at the definition.
[[nodiscard]] virtual std::optional<KVCacheBlock::IdType> removeSequence(LlmRequest::RequestIdType requestId,
OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinOnRelease = false)
= 0;
Expand Down Expand Up @@ -2168,6 +2272,10 @@ class KVCacheManager : public BaseKVCacheManager
//! \brief See BaseKVCacheManager::addSequence.
void addSequence(LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth,
OptionalRef<LlmRequest> llmRequest = std::nullopt) override;

//! \brief See BaseKVCacheManager::addSequenceBatch.
void addSequenceBatch(
std::vector<std::tuple<LlmRequest::RequestIdType, SizeType32, SizeType32>> const& requestInfos,
std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests) override;

//! \brief See BaseKVCacheManager::removeSequence.
[[nodiscard]] std::optional<KVCacheBlock::IdType> removeSequence(LlmRequest::RequestIdType requestId,
OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinOnRelease = false) override;

Expand Down
Loading
Loading