Skip to content

Commit c4093ba

Browse files
committed
[None][fix] Fix review findings: docstring, encoder token guard, share-last-block logic
- Update addSequenceBatch docstring to reflect support for both block-reuse and non-reuse paths via buildClaimResultMetadata. - Guard encoder unique token access in claimMatchingBlocks and onboardAndAllocateBlocks with hasUniqueTokens check, matching buildClaimResultMetadata and WindowBlockManager::addSequence (PR #10437) for cross-KV requests without encoder tokens (e.g., Whisper). - Align shareLastContextBlockAmongBeams in claimMatchingBlocks with the unified formula from loadOrAllocateBlocks (PR #10437): isShareLastContextBlock = kCROSS || inputLength % tokensPerBlock == 0. Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
1 parent 847c922 commit c4093ba

3 files changed

Lines changed: 60 additions & 35 deletions

File tree

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1888,11 +1888,13 @@ class BaseKVCacheManager
18881888
OptionalRef<LlmRequest> llmRequest = std::nullopt)
18891889
= 0;
18901890

1891-
//! \brief Batch add sequences with two-phase claim-then-onboard to prevent host offloading eviction.
1892-
//! \details For each attention window, Phase 1 claims all matching blocks across all requests
1893-
//! (protecting them from eviction), then Phase 2 onboards host blocks and allocates
1894-
//! non-matching blocks. Supports variable sliding window attention (VSWA) by iterating
1895-
//! over all window sizes. Requires block reuse to be enabled.
1891+
//! \brief Batch add sequences with two-phase claim-then-onboard strategy.
1892+
//! \details For each attention window, when block reuse is enabled, Phase 1 claims all matching
1893+
//! blocks across all requests (protecting them from eviction via PartialClaimTracker),
1894+
//! then Phase 2 onboards host blocks and allocates non-matching blocks. When block reuse
1895+
//! is disabled, buildClaimResultMetadata() prepares ClaimResult metadata without radix
1896+
//! tree traversal, and Phase 2 performs fresh allocation only. Supports variable sliding
1897+
//! window attention (VSWA) by iterating over all window sizes.
18961898
virtual void addSequenceBatch(
18971899
std::vector<std::tuple<LlmRequest::RequestIdType, SizeType32, SizeType32>> const& requestInfos,
18981900
std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests)

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1409,19 +1409,26 @@ WindowBlockManager::ClaimResult WindowBlockManager::claimMatchingBlocks(Generati
14091409
auto const [seqIt, emplaceDone] = mAllocatedBlocksPerSeq.emplace(requestId, std::vector<BlockPtr>{});
14101410
TLLM_CHECK(emplaceDone);
14111411

1412-
// Prepare block keys — same logic as WindowBlockManager::addSequence lines 1437-1465
1412+
// Prepare block keys — guard for cross-KV without encoder tokens (e.g., Whisper).
14131413
auto constexpr beamIdx = 0;
1414-
auto const& uniqueTokens = (mCacheType == CacheType::kSELF || mCacheType == CacheType::kSELFKONLY)
1415-
? llmRequest.getUniqueTokens(beamIdx)
1416-
: *(llmRequest.getEncoderUniqueTokens().value());
1414+
bool const isSelfCache = mCacheType == CacheType::kSELF || mCacheType == CacheType::kSELFKONLY;
1415+
bool const hasUniqueTokens = isSelfCache
1416+
|| (llmRequest.getEncoderUniqueTokens().has_value() && llmRequest.getEncoderUniqueTokens().value());
14171417

1418-
auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>(uniqueTokens, inputLength - 1, mTokensPerBlock, true);
1419-
if (inputLength % mTokensPerBlock == 1)
1418+
if (hasUniqueTokens)
14201419
{
1421-
blockedUniqueTokens.emplace_back();
1422-
}
1420+
auto const& uniqueTokens
1421+
= isSelfCache ? llmRequest.getUniqueTokens(beamIdx) : *(llmRequest.getEncoderUniqueTokens().value());
14231422

1424-
result.blockKeys = buildBlockKeys(blockedUniqueTokens, llmRequest);
1423+
auto blockedUniqueTokens
1424+
= chopVectorIntoBlocks<UniqueToken>(uniqueTokens, inputLength - 1, mTokensPerBlock, true);
1425+
if (inputLength % mTokensPerBlock == 1)
1426+
{
1427+
blockedUniqueTokens.emplace_back();
1428+
}
1429+
1430+
result.blockKeys = buildBlockKeys(blockedUniqueTokens, llmRequest);
1431+
}
14251432

14261433
auto config = llmRequest.getKvCacheRetentionConfig();
14271434
result.perBlockRetentions = config.value_or(executor::KvCacheRetentionConfig())
@@ -1441,14 +1448,12 @@ WindowBlockManager::ClaimResult WindowBlockManager::claimMatchingBlocks(Generati
14411448
// Phase 1: Walk radix tree, claim matching blocks — no onboard, no getFreeBlock
14421449
// NOTE: Caller must hold mCachedBlocksRootMutex.
14431450

1444-
// Compute shareLastContextBlockAmongBeams — same logic as WindowBlockManager::addSequence
1445-
result.shareLastContextBlockAmongBeams = sequence.getBeamWidth() == 1;
1446-
if (isRecurrentState())
1447-
{
1448-
result.shareLastContextBlockAmongBeams |= inputLength % mTokensPerBlock == 0;
1449-
}
1450-
1451-
result.numSharedContextBlocks = result.shareLastContextBlockAmongBeams ? numContextBlocks : numContextBlocks - 1;
1451+
// Compute shareLastContextBlockAmongBeams — aligned with loadOrAllocateBlocks (PR #10437).
1452+
auto const beamWidth = sequence.getBeamWidth();
1453+
bool const isShareLastContextBlock = mCacheType == CacheType::kCROSS || inputLength % mTokensPerBlock == 0;
1454+
result.numSharedContextBlocks
1455+
= (beamWidth > 1 && !isShareLastContextBlock) ? numContextBlocks - 1 : numContextBlocks;
1456+
result.shareLastContextBlockAmongBeams = result.numSharedContextBlocks == numContextBlocks;
14521457
auto searchRoot = mCachedBlocksRoot;
14531458
auto blockItr = result.blockKeys.begin();
14541459

@@ -1507,19 +1512,28 @@ WindowBlockManager::ClaimResult WindowBlockManager::claimMatchingBlocks(Generati
15071512
auto tIt = tracker.map.find(blockId);
15081513
if (tIt != tracker.map.end())
15091514
{
1510-
// Previous copier no longer responsible for release.
1511-
claimResults[tIt->second.requestIdx]
1512-
.claimedBlocks[tIt->second.claimedIdx]
1513-
.shouldReleaseCopySource
1514-
= false;
1515+
if (tIt->second.fullyMatched)
1516+
{
1517+
// A full match holds this block — do not release.
1518+
claimed.shouldReleaseCopySource = false;
1519+
}
1520+
else
1521+
{
1522+
// Previous copier no longer responsible for release.
1523+
claimResults[tIt->second.requestIdx]
1524+
.claimedBlocks[tIt->second.claimedIdx]
1525+
.shouldReleaseCopySource
1526+
= false;
1527+
claimed.shouldReleaseCopySource = true;
1528+
}
15151529
tIt->second.requestIdx = requestIdx;
15161530
tIt->second.claimedIdx = result.claimedBlocks.size();
15171531
}
15181532
else
15191533
{
15201534
tracker.map[blockId] = {requestIdx, result.claimedBlocks.size(), /*fullyMatched=*/false};
1535+
claimed.shouldReleaseCopySource = true;
15211536
}
1522-
claimed.shouldReleaseCopySource = true;
15231537
}
15241538
}
15251539
else
@@ -1760,11 +1774,20 @@ SizeType32 WindowBlockManager::onboardAndAllocateBlocks(
17601774

17611775
// Update stats and return prepopulated length
17621776
mReusedTokens += static_cast<double>(numMatchedTokens);
1763-
auto constexpr beamIdx = 0;
1764-
auto const& uniqueTokens = (mCacheType == CacheType::kSELF || mCacheType == CacheType::kSELFKONLY)
1765-
? llmRequest.getUniqueTokens(beamIdx)
1766-
: *(llmRequest.getEncoderUniqueTokens().value());
1767-
mTotalInputTokens += static_cast<double>(uniqueTokens.size());
1777+
bool const isSelfCache = mCacheType == CacheType::kSELF || mCacheType == CacheType::kSELFKONLY;
1778+
bool const hasUniqueTokens = isSelfCache
1779+
|| (llmRequest.getEncoderUniqueTokens().has_value() && llmRequest.getEncoderUniqueTokens().value());
1780+
if (hasUniqueTokens)
1781+
{
1782+
auto constexpr beamIdx = 0;
1783+
auto const& uniqueTokens
1784+
= isSelfCache ? llmRequest.getUniqueTokens(beamIdx) : *(llmRequest.getEncoderUniqueTokens().value());
1785+
mTotalInputTokens += static_cast<double>(uniqueTokens.size());
1786+
}
1787+
else
1788+
{
1789+
mTotalInputTokens += static_cast<double>(claimResult.numContextBlocks * mTokensPerBlock);
1790+
}
17681791

17691792
SizeType32 numConnectorMatchedTokens = 0;
17701793
if (mKvCacheConnectorManager && !llmRequest.isDummyRequest())

cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
* SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
* SPDX-License-Identifier: Apache-2.0
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -68,7 +68,7 @@ std::optional<tensorrt_llm::runtime::ITensor::UniquePtr> from_torch(std::optiona
6868
class PyKvCacheManager : public tbk::BaseKVCacheManager
6969
{
7070
public:
71-
NB_TRAMPOLINE(tbk::BaseKVCacheManager, 37);
71+
NB_TRAMPOLINE(tbk::BaseKVCacheManager, 39);
7272

7373
// using BaseKVCacheManager::BaseKVCacheManager; // Inherit constructors
7474
void allocatePools(bool useUvm = false) override

Comments (0)