Skip to content

Commit 847c922

Browse files
committed
[None][fix] Release claimed non-leaf copy source after copy in addSequenceBatch
When claimMatchingBlocks claims an unreferenced non-leaf partial-match block to protect it during Phase 2 copies, the block is removed from the free queue. Previously it was only released at the end of the batch (deferred release), which was too late — subsequent getFreeBlock calls for non-matching blocks could not find free blocks in tight pools. Fix: use PartialClaimTracker to assign release responsibility to the last copier in the batch. After the copy completes in Phase 2, the responsible copier releases the source back to the free queue via shouldReleaseCopySource flag on ClaimedBlock. A full match on the same block revokes the release responsibility. Also migrate BlockManagerReuseTest to addSequenceBatch to validate the fix, and update the full-match tracker branch to handle both leaf and non-leaf blocks. Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
1 parent 033d8d9 commit 847c922

3 files changed

Lines changed: 74 additions & 24 deletions

File tree

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,7 @@ class WindowBlockManager
790790
bool isPartialMatch;
791791
bool needsCopy; //!< partial match on block with refs or non-leaf (needs getFreeBlock + copy in Phase 2)
792792
bool isPlaceholder; //!< placeholder block (linear attention recurrent states)
793+
bool shouldReleaseCopySource{false}; //!< last copier releases the claimed source after copy
793794
};
794795

795796
std::vector<ClaimedBlock> claimedBlocks;

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1497,11 +1497,29 @@ WindowBlockManager::ClaimResult WindowBlockManager::claimMatchingBlocks(Generati
14971497
claimed.needsCopy = true;
14981498
if (!matchingBlock->hasRefs())
14991499
{
1500-
// Unreferenced non-leaf: could be in the free queue and evictable.
1501-
// Claim to protect during Phase 2 copies; deferred release at batch end.
1500+
// Unreferenced non-leaf: claim to protect from eviction during copies.
1501+
// Use tracker to assign release responsibility to the last copier.
15021502
mEvictionPolicy->claimBlock(matchingBlock, result.perBlockRetentions[bi].retentionPriority,
15031503
result.perBlockRetentions[bi].durationMs);
15041504
result.claimedCopySource = matchingBlock;
1505+
1506+
auto const blockId = matchingBlock->getBlockId();
1507+
auto tIt = tracker.map.find(blockId);
1508+
if (tIt != tracker.map.end())
1509+
{
1510+
// Previous copier no longer responsible for release.
1511+
claimResults[tIt->second.requestIdx]
1512+
.claimedBlocks[tIt->second.claimedIdx]
1513+
.shouldReleaseCopySource
1514+
= false;
1515+
tIt->second.requestIdx = requestIdx;
1516+
tIt->second.claimedIdx = result.claimedBlocks.size();
1517+
}
1518+
else
1519+
{
1520+
tracker.map[blockId] = {requestIdx, result.claimedBlocks.size(), /*fullyMatched=*/false};
1521+
}
1522+
claimed.shouldReleaseCopySource = true;
15051523
}
15061524
}
15071525
else
@@ -1550,15 +1568,18 @@ WindowBlockManager::ClaimResult WindowBlockManager::claimMatchingBlocks(Generati
15501568
mEvictionPolicy->claimBlock(matchingBlock, result.perBlockRetentions[bi].retentionPriority,
15511569
result.perBlockRetentions[bi].durationMs);
15521570

1553-
// If a previous request was going to reuse this block via partial match,
1571+
// If a previous request was going to reuse or release this block via partial match,
15541572
// it must now copy instead — a full match takes priority.
1555-
if (matchingBlock->isLeaf())
15561573
{
15571574
auto const blockId = matchingBlock->getBlockId();
15581575
auto tIt = tracker.map.find(blockId);
15591576
if (tIt != tracker.map.end() && !tIt->second.fullyMatched)
15601577
{
15611578
claimResults[tIt->second.requestIdx].claimedBlocks[tIt->second.claimedIdx].needsCopy = true;
1579+
claimResults[tIt->second.requestIdx]
1580+
.claimedBlocks[tIt->second.claimedIdx]
1581+
.shouldReleaseCopySource
1582+
= false;
15621583
tIt->second.fullyMatched = true;
15631584
}
15641585
else
@@ -1611,6 +1632,14 @@ SizeType32 WindowBlockManager::onboardAndAllocateBlocks(
16111632
*blockItr, blockItr->uniqueTokens.size() == static_cast<size_t>(mTokensPerBlock));
16121633
}
16131634
claimed.block->setHash();
1635+
// Release the claimed non-leaf copy source back to the free queue now that
1636+
// the copy is done. The tracker ensures only the last copier releases.
1637+
if (claimed.shouldReleaseCopySource && claimResult.claimedCopySource
1638+
&& !claimResult.claimedCopySource->hasRefs())
1639+
{
1640+
mEvictionPolicy->releaseBlock(claimResult.claimedCopySource);
1641+
claimResult.claimedCopySource = nullptr;
1642+
}
16141643
TLLM_LOG_DEBUG("%s::onboardAndAllocateBlocks for request %lu - Copied partially filled block %d",
16151644
mLogPrefix.c_str(), sequence.getRequestId(), matchingBlockId);
16161645
}

cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -810,8 +810,10 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
810810
auto promptLen0 = llmRequest0->getNumTokens(beamIdx);
811811
auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
812812
blockManager.holdSequence(seq0.getRequestId());
813-
auto prepopulatedPromptLen0 = blockManager.addSequence(
814-
seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow, /*isEnableBlockReuse=*/true);
813+
auto prepopulatedPromptLen0 = blockManager
814+
.addSequenceBatch({&seq0}, {promptLen0}, {numContextBlocks0},
815+
{std::ref(*llmRequest0)}, maxAttentionWindow, /*isEnableBlockReuse=*/true)[0]
816+
.prepopulatedLen;
815817
llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock());
816818
EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0);
817819
EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
@@ -840,8 +842,10 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
840842
auto promptLen1 = llmRequest1->getNumTokens(beamIdx);
841843
auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
842844
blockManager.holdSequence(seq1.getRequestId());
843-
auto prepopulatedPromptLen1 = blockManager.addSequence(
844-
seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow, /*isEnableBlockReuse=*/true);
845+
auto prepopulatedPromptLen1 = blockManager
846+
.addSequenceBatch({&seq1}, {promptLen1}, {numContextBlocks1},
847+
{std::ref(*llmRequest1)}, maxAttentionWindow, /*isEnableBlockReuse=*/true)[0]
848+
.prepopulatedLen;
845849
llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock());
846850
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock);
847851
EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 3}));
@@ -870,8 +874,10 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
870874
promptLen0 = llmRequest0->getNumTokens(beamIdx);
871875
numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
872876
blockManager.holdSequence(seq0_dup.getRequestId());
873-
prepopulatedPromptLen0 = blockManager.addSequence(
874-
seq0_dup, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow, /*isEnableBlockReuse=*/true);
877+
prepopulatedPromptLen0 = blockManager
878+
.addSequenceBatch({&seq0_dup}, {promptLen0}, {numContextBlocks0},
879+
{std::ref(*llmRequest0)}, maxAttentionWindow, /*isEnableBlockReuse=*/true)[0]
880+
.prepopulatedLen;
875881
llmRequest0->setPrepopulatedPromptLen(prepopulatedPromptLen0, blockManager.getTokensPerBlock());
876882
EXPECT_EQ(llmRequest0->getContextCurrentPosition(), promptLen0 - 1);
877883
EXPECT_THAT(seq0_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
@@ -888,8 +894,10 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
888894
promptLen1 = llmRequest1->getNumTokens(beamIdx);
889895
numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
890896
blockManager.holdSequence(seq1_dup.getRequestId());
891-
prepopulatedPromptLen1 = blockManager.addSequence(
892-
seq1_dup, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow, /*isEnableBlockReuse=*/true);
897+
prepopulatedPromptLen1 = blockManager
898+
.addSequenceBatch({&seq1_dup}, {promptLen1}, {numContextBlocks1},
899+
{std::ref(*llmRequest1)}, maxAttentionWindow, /*isEnableBlockReuse=*/true)[0]
900+
.prepopulatedLen;
893901
llmRequest1->setPrepopulatedPromptLen(prepopulatedPromptLen1, blockManager.getTokensPerBlock());
894902
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock);
895903
EXPECT_THAT(seq1_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 4}));
@@ -925,8 +933,10 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
925933
auto promptLen2 = llmRequest2->getNumTokens(beamIdx);
926934
auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock());
927935
blockManager.holdSequence(seq2.getRequestId());
928-
auto prepopulatedPromptLen2 = blockManager.addSequence(
929-
seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow, /*isEnableBlockReuse=*/true);
936+
auto prepopulatedPromptLen2 = blockManager
937+
.addSequenceBatch({&seq2}, {promptLen2}, {numContextBlocks2},
938+
{std::ref(*llmRequest2)}, maxAttentionWindow, /*isEnableBlockReuse=*/true)[0]
939+
.prepopulatedLen;
930940
llmRequest2->setPrepopulatedPromptLen(prepopulatedPromptLen2, blockManager.getTokensPerBlock());
931941
EXPECT_EQ(llmRequest2->getContextCurrentPosition(), tokensPerBlock);
932942
EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 5}));
@@ -949,8 +959,10 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
949959
auto promptLen3 = llmRequest3->getNumTokens(beamIdx);
950960
auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock());
951961
blockManager.holdSequence(seq3.getRequestId());
952-
auto prepopulatedPromptLen3 = blockManager.addSequence(
953-
seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow, /*isEnableBlockReuse=*/true);
962+
auto prepopulatedPromptLen3 = blockManager
963+
.addSequenceBatch({&seq3}, {promptLen3}, {numContextBlocks3},
964+
{std::ref(*llmRequest3)}, maxAttentionWindow, /*isEnableBlockReuse=*/true)[0]
965+
.prepopulatedLen;
954966
llmRequest3->setPrepopulatedPromptLen(prepopulatedPromptLen3, blockManager.getTokensPerBlock());
955967
EXPECT_EQ(llmRequest3->getContextCurrentPosition(), numTokens - 1);
956968
EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 4}));
@@ -986,8 +998,10 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
986998
auto promptLen4 = llmRequest4->getNumTokens(beamIdx);
987999
auto numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock());
9881000
blockManager.holdSequence(seq4.getRequestId());
989-
auto prepopulatedPromptLen4 = blockManager.addSequence(
990-
seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow, /*isEnableBlockReuse=*/true);
1001+
auto prepopulatedPromptLen4 = blockManager
1002+
.addSequenceBatch({&seq4}, {promptLen4}, {numContextBlocks4},
1003+
{std::ref(*llmRequest4)}, maxAttentionWindow, /*isEnableBlockReuse=*/true)[0]
1004+
.prepopulatedLen;
9911005
llmRequest4->setPrepopulatedPromptLen(prepopulatedPromptLen4, blockManager.getTokensPerBlock());
9921006
EXPECT_EQ(llmRequest4->getContextCurrentPosition(), promptLen4 - 1);
9931007
EXPECT_THAT(seq4.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 4}));
@@ -1021,8 +1035,10 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
10211035
promptLen4 = llmRequest4->getNumTokens(beamIdx);
10221036
numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock());
10231037
blockManager.holdSequence(seq4_dup.getRequestId());
1024-
prepopulatedPromptLen4 = blockManager.addSequence(
1025-
seq4_dup, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow, /*isEnableBlockReuse=*/true);
1038+
prepopulatedPromptLen4 = blockManager
1039+
.addSequenceBatch({&seq4_dup}, {promptLen4}, {numContextBlocks4},
1040+
{std::ref(*llmRequest4)}, maxAttentionWindow, /*isEnableBlockReuse=*/true)[0]
1041+
.prepopulatedLen;
10261042
llmRequest4->setPrepopulatedPromptLen(prepopulatedPromptLen4, blockManager.getTokensPerBlock());
10271043
EXPECT_EQ(llmRequest4->getContextCurrentPosition(), promptLen4 - 2);
10281044
EXPECT_THAT(seq4_dup.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
@@ -1051,8 +1067,10 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
10511067
auto promptLen5 = llmRequest5->getNumTokens(beamIdx);
10521068
auto numContextBlocks5 = tc::ceilDiv(promptLen5, blockManager.getTokensPerBlock());
10531069
blockManager.holdSequence(seq5.getRequestId());
1054-
auto prepopulatedPromptLen5 = blockManager.addSequence(
1055-
seq5, promptLen5, numContextBlocks5, *llmRequest5, maxAttentionWindow, /*isEnableBlockReuse=*/true);
1070+
auto prepopulatedPromptLen5 = blockManager
1071+
.addSequenceBatch({&seq5}, {promptLen5}, {numContextBlocks5},
1072+
{std::ref(*llmRequest5)}, maxAttentionWindow, /*isEnableBlockReuse=*/true)[0]
1073+
.prepopulatedLen;
10561074
llmRequest5->setPrepopulatedPromptLen(prepopulatedPromptLen5, blockManager.getTokensPerBlock());
10571075
llmRequest5->addNewToken(0, beamIdx);
10581076
EXPECT_EQ(llmRequest5->getContextCurrentPosition(), 1); // incidental reuse
@@ -1079,8 +1097,10 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest)
10791097
auto promptLen6 = llmRequest6->getNumTokens(beamIdx);
10801098
auto numContextBlocks6 = tc::ceilDiv(promptLen6, blockManager.getTokensPerBlock());
10811099
blockManager.holdSequence(seq6.getRequestId());
1082-
auto prepopulatedPromptLen6 = blockManager.addSequence(
1083-
seq6, promptLen6, numContextBlocks6, *llmRequest6, maxAttentionWindow, /*isEnableBlockReuse=*/true);
1100+
auto prepopulatedPromptLen6 = blockManager
1101+
.addSequenceBatch({&seq6}, {promptLen6}, {numContextBlocks6},
1102+
{std::ref(*llmRequest6)}, maxAttentionWindow, /*isEnableBlockReuse=*/true)[0]
1103+
.prepopulatedLen;
10841104
llmRequest6->setPrepopulatedPromptLen(prepopulatedPromptLen6, blockManager.getTokensPerBlock());
10851105
llmRequest6->addNewToken(0, beamIdx);
10861106
// no reuse occurs because we are unable to reuse last input token and inputLength6 == 1.

0 commit comments

Comments
 (0)