[TRTLLM-11540][feat] Revert EAGLE3 dynamic tree speculative decoding support (#12062) #13006
Conversation
/bot run --disable-fail-fast
PR_Github #43076 [ run ] triggered by Bot. Commit:
…ding support (NVIDIA#12062)" This reverts commit 4ece13c. Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
43002c5 to 1e422bf
/bot run --disable-fail-fast
PR_Github #43096 [ run ] triggered by Bot. Commit:
Signed-off-by: Yiqing Yan <yiqingy@nvidia.com>
PR_Github #43096 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #43166 [ run ] triggered by Bot. Commit:
📝 Walkthrough
This PR removes dynamic tree speculative decoding support from the codebase, including CUDA kernels, Torch operators, Python worker implementations, configuration parameters, and related documentation. The linear drafting path for EAGLE3 speculative decoding is retained.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
🚥 Pre-merge checks: ❌ 1 failed (1 warning) | ✅ 2 passed
Actionable comments posted: 10
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
examples/llm-api/quickstart_advanced.py (1)
265-273: ⚠️ Potential issue | 🟠 Major
Stop wiring tree-only knobs into the EAGLE3 example.
This revert removes the PyTorch EAGLE3 tree path, but this helper still forwards `eagle_choices`, `use_dynamic_tree`, and `dynamic_tree_max_topK` into `Eagle3DecodingConfig`. `--spec_decode_algo EAGLE3` can still advertise an unsupported mode from the example.
🐛 Proposed fix

```diff
 spec_config = Eagle3DecodingConfig(
     max_draft_len=args.spec_decode_max_draft_len,
     speculative_model=args.draft_model_dir,
     eagle3_one_model=args.use_one_model,
-    eagle_choices=args.eagle_choices,
-    use_dynamic_tree=args.use_dynamic_tree,
-    dynamic_tree_max_topK=args.dynamic_tree_max_topK,
     allow_advanced_sampling=args.allow_advanced_sampling,
     eagle3_model_arch=args.eagle3_model_arch)
```

The parser flags should be removed or rejected as a follow-up outside this hunk.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/llm-api/quickstart_advanced.py` around lines 265 - 273, The Eagle3 example is still passing tree-only knobs into the Eagle3DecodingConfig; remove the unsupported parameters eagle_choices, use_dynamic_tree, and dynamic_tree_max_topK from the spec_config construction (the Eagle3DecodingConfig call that builds spec_config) so it only forwards supported args like max_draft_len, speculative_model, eagle3_one_model, allow_advanced_sampling, and eagle3_model_arch; do not add conditional logic to enable a tree path here—strip those args from the constructor call and leave parser flag removal/rejection as a separate follow-up.

tensorrt_llm/llmapi/llm_args.py (1)
1129-1174: ⚠️ Potential issue | 🟠 Major
Keep the linear-tree default out of the dynamic-tree branch.
Line 1130 unconditionally overwrites `max_total_draft_tokens` with `max_draft_len` before the dynamic-tree checks. That means lines 1171-1174 can never catch an omitted value, and a dynamic tree silently gets linear-tree sizing instead of the explicit total-node budget it needs.
🐛 Proposed fix

```diff
 self.num_eagle_layers = self.max_draft_len
-self.max_total_draft_tokens = self.max_draft_len  # If using linear-tree, the max_total_draft_tokens is the same as max_draft_len
+if not self.use_dynamic_tree:
+    self.max_total_draft_tokens = self.max_draft_len  # Linear tree only
 if self.eagle3_model_arch == "mistral_large3" and self.eagle3_layers_to_capture is None:
     # FIXME find a better way to setup it.
     self.eagle3_layers_to_capture = {-1}
```
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/llmapi/llm_args.py` around lines 1129 - 1174, The code unconditionally sets self.max_total_draft_tokens = self.max_draft_len (using symbols self.max_total_draft_tokens, self.max_draft_len) before handling dynamic-tree, which causes dynamic trees to inherit linear-tree sizing; remove that unconditional assignment and instead set max_total_draft_tokens only for the linear/static-tree path (e.g., guard it with if not self.use_dynamic_tree or move it into the static-tree block after eagle_choices handling), ensuring the dynamic-tree checks (use_dynamic_tree, dynamic_tree_max_topK, max_total_draft_tokens) can validate a provided total-node budget.

tensorrt_llm/_torch/speculative/eagle3.py (1)
75-84: ⚠️ Potential issue | 🟠 Major
This leaves `use_dynamic_tree` metadata without a `SpecTreeManager`.
After this change, `Eagle3ResourceManager` only builds `SpecTreeManager` when `config.eagle_choices` is set. But `tensorrt_llm/_torch/speculative/utils.py:create_spec_metadata()` still marks the EAGLE3 metadata as tree/dynamic-tree whenever `spec_config.use_dynamic_tree` is true. That leaves the runtime carrying dynamic-tree flags with `spec_tree_manager=None`, and `Eagle3SpecMetadata.prepare()` falls back to the linear first-draft path instead of the accepted-path/tree handling.
Please either clear `use_dynamic_tree` at metadata creation as part of this revert, or keep constructing the manager until those callers are updated too.
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/speculative/eagle3.py` around lines 75 - 84, The change stops constructing SpecTreeManager when config.eagle_choices is None but leaves spec_config.use_dynamic_tree true, causing metadata to advertise dynamic-tree behavior without a manager; update Eagle3ResourceManager so that if config.use_dynamic_tree is true but config.eagle_choices is None you either (A) clear/disable use_dynamic_tree on the created Eagle3SpecMetadata (set spec_config.use_dynamic_tree = False) before returning metadata, or (B) continue constructing SpecTreeManager even when eagle_choices is None so self.spec_tree_manager is always present when use_dynamic_tree is enabled; touch the Eagle3ResourceManager construction logic around the SpecTreeManager instantiation (and ensure create_spec_metadata()/Eagle3SpecMetadata.prepare() will see a consistent spec_config and self.spec_tree_manager).
🧹 Nitpick comments (2)
tensorrt_llm/_torch/speculative/spec_tree_manager.py (1)
18-18: Avoid the shared mutable class default for `top_k_list`.
This list is shared at the class level until `__init__` rebinds it. Keeping only the annotation here avoids accidental cross-instance state and clears the Ruff RUF012 warning.
♻️ Proposed fix

```diff
-    top_k_list = []
+    top_k_list: list[torch.Tensor]
```
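For illustration, a minimal hypothetical class showing why the shared class-level default bites; the real `SpecTreeManager` is more involved:

```python
class Manager:
    top_k_list = []  # class-level mutable default: one list shared by ALL instances

    def add(self, k):
        self.top_k_list.append(k)  # mutates the single shared class list

class FixedManager:
    top_k_list: list  # annotation only; no shared object is created

    def __init__(self):
        self.top_k_list = []  # each instance gets its own list

    def add(self, k):
        self.top_k_list.append(k)

a, b = Manager(), Manager()
a.add(1)
print(b.top_k_list)  # [1] -- b observes a's mutation

c, d = FixedManager(), FixedManager()
c.add(1)
print(d.top_k_list)  # [] -- state is isolated per instance
```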
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/speculative/spec_tree_manager.py` at line 18, The class-level mutable default "top_k_list = []" creates shared state across instances and triggers RUF012; remove the assignment so only the annotation remains (e.g., keep "top_k_list: list[int]" or appropriate type) and ensure the instance attribute is initialized in the class __init__ (set self.top_k_list = [] there) so each instance gets its own list; reference: top_k_list and __init__ in the same class.

tensorrt_llm/_torch/speculative/drafting_loops.py (1)
229-233: Consider simplifying the spec_tree_manager extraction.
Based on the executor creator logic (context snippet 3), `TreeDraftingLoopWrapper` is only instantiated when `EagleDecodingConfig` is used and `is_linear_tree` is False, guaranteeing `Eagle3SpecMetadata` will always be passed. The `isinstance` check and subsequent `assert` could be combined.
♻️ Optional simplification

```diff
-    spec_tree_manager = None
-    if isinstance(spec_metadata, Eagle3SpecMetadata):
-        spec_tree_manager = spec_metadata.eagle3_resource_manager.spec_tree_manager
-
-    assert spec_tree_manager is not None
+    assert isinstance(spec_metadata, Eagle3SpecMetadata), \
+        "TreeDraftingLoopWrapper requires Eagle3SpecMetadata"
+    spec_tree_manager = spec_metadata.eagle3_resource_manager.spec_tree_manager
+    assert spec_tree_manager is not None, "spec_tree_manager must be set"
```
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/speculative/drafting_loops.py` around lines 229 - 233, Summary: simplify extraction of spec_tree_manager by combining the type assertion and assignment. Replace the current two-step pattern (setting spec_tree_manager=None, then using isinstance check and a separate assert) with a single assert that spec_metadata is an Eagle3SpecMetadata followed immediately by assigning spec_tree_manager from spec_metadata.eagle3_resource_manager.spec_tree_manager (this aligns with the guarantee from TreeDraftingLoopWrapper/EagleDecodingConfig); update any type hints or mypy ignores if needed.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/prepareCustomMask.cu`:
- Around line 168-169: Update the comments in prepareCustomMask (in
trtllmGenKernels/fmha/prepareCustomMask.cu) so they consistently describe the
same mask layout: change the function-level input-shape comment to match the
inline comment that the input mask shape is [bs, seqLenQ, ceilDiv(seqLenQ, 32)]
and clarify that the KV dimension corresponds to Q positions (tree mask); ensure
there is a single definitive comment describing the third dimension as
ceilDiv(seqLenQ, 32) and remove any conflicting description elsewhere in the
function.
- Around line 168-174: The packed-mask indexing uses ceilDiv(seqLenQ, 32) as the
row stride when computing qMaskBaseIdx/packedMaskIdx, but the producer allocates
rows with a fixed stride numPackedMasks (divUp(maxDecodingTokens, 32)), causing
under-stride when seqLenQ < maxDecodingTokens; fix by using the actual allocated
stride: replace the per-kernel computed stride ceilDiv(seqLenQ, 32) with the
passed/known allocation stride (e.g., numPackedMasks or a new kernel parameter
maskRowStride) when computing qMaskBaseIdx/packedMaskIdx, or alternatively
change the producer to allocate rows using ceilDiv(seqLenQ, 32) per-batch so the
allocation and kernel agree.
In `@examples/llm-api/quickstart_advanced.py`:
- Around line 404-412: The code indexes optional dicts
sequence.additional_context_outputs and sequence.additional_generation_outputs
for each output_name from args.additional_model_outputs without guarding for
missing keys; update the loop in the printing block to check presence before
indexing (e.g., if sequence.additional_context_outputs and output_name in
sequence.additional_context_outputs) and similarly for
sequence.additional_generation_outputs, only printing the Context and Generation
lines when the corresponding map contains output_name (use .get or membership
checks) so printing [i]{sequence_id_text} lines never raises KeyError/TypeError.
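The guard described above can be sketched as follows; `sequence` here is a stand-in namespace, not the real TensorRT-LLM result object, and `collect_extra_outputs` is a hypothetical helper:

```python
from types import SimpleNamespace

def collect_extra_outputs(i, sequence, requested_names):
    lines = []
    for name in requested_names:
        ctx = sequence.additional_context_outputs
        gen = sequence.additional_generation_outputs
        # Membership checks keep a missing map (None) or a missing key
        # from raising TypeError/KeyError while formatting results.
        if ctx and name in ctx:
            lines.append(f"[{i}] Context {name}: {ctx[name]}")
        if gen and name in gen:
            lines.append(f"[{i}] Generation {name}: {gen[name]}")
    return lines

seq = SimpleNamespace(
    additional_context_outputs=None,  # optional map absent entirely
    additional_generation_outputs={"logits": [0.1, 0.9]},
)
for line in collect_extra_outputs(0, seq, ["logits", "hidden"]):
    print(line)  # only the Generation logits line survives the guards
```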
In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py`:
- Around line 600-608: The type name 'SpecDecodingTensor' used in the
update_spec_dec_param signature is not imported causing F821; add an import for
SpecDecodingTensor inside the existing TYPE_CHECKING block so static
type-checkers see the symbol without affecting runtime. Locate the TYPE_CHECKING
block near the top of tensorrt_llm/_torch/attention_backend/sparse/dsa.py and
add a from ... import SpecDecodingTensor (matching the module where
SpecDecodingTensor is defined) so the function signature referencing
'SpecDecodingTensor' resolves.
In `@tensorrt_llm/_torch/attention_backend/trtllm.py`:
- Around line 1536-1548: The code incorrectly forces
self.is_spec_decoding_enabled False on any SM that reports TRTLLM-gen-kernel
support; change the logic so spec-decoding is only disabled for SM-gen-kernel
machines when a spec-dec tree or dynamic tree is present. Concretely, compute a
local flag via self.is_sm_version_trtllm_gen_kernel(sm=get_sm_version()) and set
self.is_spec_decoding_enabled = is_spec_decoding_enabled unless (is_sm_gen and
(spec_tree_manager is not None or spec_tree_manager.use_dynamic_tree)), keep
self.is_spec_dec_tree and self.is_spec_dec_dynamic_tree assignments and the
existing assertion block (the assertion about spec-dec tree support) as-is, and
ensure self.use_spec_decoding continues to initialize from
self.is_spec_decoding_enabled; refer to attributes/methods
self.is_spec_decoding_enabled, self.is_spec_dec_tree,
self.is_spec_dec_dynamic_tree, self.use_spec_decoding and method
is_sm_version_trtllm_gen_kernel.
In `@tensorrt_llm/_torch/pyexecutor/model_engine.py`:
- Around line 2399-2405: spec_tree_manager.spec_dec_position_offsets[0] is a
device tensor and is being accessed element-wise inside per-request loops;
materialize it once into a Python list (e.g., host_offsets =
spec_tree_manager.spec_dec_position_offsets[0].tolist()) outside the request
loops and then use host_offsets to extend position_ids in both branches instead
of indexing the tensor inside the loop; apply the same change pattern wherever
spec_dec_position_offsets is iterated (also update the analogous places handling
position_ids in the other request-loop blocks).
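The hoisting pattern this prompt describes can be sketched with a pure-Python stand-in for the device tensor; real code would call `.tolist()` once on the torch tensor outside the loops. `FakeDeviceTensor` is purely illustrative:

```python
class FakeDeviceTensor:
    """Stand-in that counts simulated device->host transfers."""

    def __init__(self, data):
        self.data = list(data)
        self.sync_count = 0

    def __getitem__(self, idx):
        self.sync_count += 1  # each element read is one sync
        return self.data[idx]

    def tolist(self):
        self.sync_count += 1  # one sync for the whole tensor
        return list(self.data)

offsets = FakeDeviceTensor([0, 1, 1, 2])
requests = ["req_a", "req_b", "req_c"]

# Slow pattern: element-wise indexing inside the per-request loop.
position_ids_slow = []
for _ in requests:
    position_ids_slow.extend(offsets[i] for i in range(len(offsets.data)))
syncs_slow = offsets.sync_count

# Fast pattern: materialize once, reuse the host list in every loop.
offsets.sync_count = 0
host_offsets = offsets.tolist()
position_ids_fast = []
for _ in requests:
    position_ids_fast.extend(host_offsets)
syncs_fast = offsets.sync_count

assert position_ids_slow == position_ids_fast
print(syncs_slow, syncs_fast)  # 12 1
```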
In `@tensorrt_llm/_torch/speculative/spec_sampler_base.py`:
- Around line 241-253: The current padding uses self.draft_len but must use the
allocated store width from the sampler storage helpers; compute storage_width =
self._get_draft_tokens_storage_size() (or use _get_max_tokens() if appropriate)
and pad o_new_tokens and o_next_new_tokens up to storage_width, and pad
o_next_draft_tokens up to storage_width - 1, so the tensors match the
destination buffers before the index_copy_ calls (look for o_new_tokens,
o_next_draft_tokens, o_next_new_tokens and index_copy_ in SpecSamplerBase).
In `@tensorrt_llm/_torch/speculative/spec_tree_manager.py`:
- Around line 260-266: get_eagle_paths is indexing self.eagle_paths[tree_idx] as
if it's a 3-D tensor ([:, i, :]) but after self.eagle_paths[tree_idx] it's 2-D,
causing an IndexError when use_dynamic_tree is True; fix by constructing or
indexing the correct shape: either ensure self.eagle_paths[tree_idx] is a 3-D
tensor of shape (batch, max_total_draft_tokens+1, path_len) before the loop, or
change the assignment to match the 2-D layout (e.g. assign into
self.eagle_paths[tree_idx][:, i] or reshape the nonzero result to match the
third dimension). Update get_eagle_paths to use self.eagle_paths,
self.spec_dec_mask_matrix, use_dynamic_tree, and max_total_draft_tokens
consistently so the slice operations match tensor ranks (reshape or unsqueeze
nonzero results or preallocate a 3-D tensor) to avoid IndexError.
- Around line 73-75: The type annotation for the eagle_choices parameter in
SpecTreeManager.__init__ is using a list literal ([List[List[int]]]) instead of
a proper type; change the signature to use the modern typing form (e.g.,
eagle_choices: list[list[int]] | None) so type checkers and inspect-based
tooling can understand it, and update any callers or default behavior if you
make it optional (None) to preserve existing semantics.
In
`@tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py`:
- Around line 457-458: The file currently calls unittest.main() but defines
module-level pytest-style test_draft_token_static_tree_prepare_for_generation(),
so running the file finds zero tests; either remove the if __name__ ==
"__main__": unittest.main() block to keep this as a pytest-only test, or convert
the module-level function into a unittest.TestCase (e.g., create class
TestDraftTokenPrepareForGeneration(unittest.TestCase) with a method
test_draft_token_static_tree_prepare_for_generation that calls the same
assertions) and keep unittest.main(); update imports accordingly.
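The TestCase conversion described above might look like this sketch; the assertion body is a placeholder stand-in, not the real tree-preparation check:

```python
import unittest

def _static_tree_prepare_for_generation():
    # Placeholder for the real preparation logic under test.
    return [0, 1, 2]

class TestDraftTokenPrepareForGeneration(unittest.TestCase):
    def test_draft_token_static_tree_prepare_for_generation(self):
        # Wrapping the module-level check in a TestCase makes it
        # discoverable by unittest.main() as well as pytest.
        self.assertEqual(_static_tree_prepare_for_generation(), [0, 1, 2])

# In the real file, keep the existing guard:
#     if __name__ == "__main__":
#         unittest.main()
```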
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 314827d7-8cad-4813-81c8-dcc291c62b30
📒 Files selected for processing (37)
- cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu
- cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h
- cpp/tensorrt_llm/kernels/speculativeDecoding/kvCacheUpdateKernels.cu
- cpp/tensorrt_llm/kernels/speculativeDecoding/kvCacheUpdateKernels.h
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/prepareCustomMask.cu
- cpp/tensorrt_llm/kernels/xqaDispatcher.cpp
- cpp/tensorrt_llm/thop/CMakeLists.txt
- cpp/tensorrt_llm/thop/dynamicTreeOp.cpp
- cpp/tensorrt_llm/thop/parallelDecodeKVCacheUpdateOp.cpp
- docs/source/features/feature-combination-matrix.md
- docs/source/features/speculative-decoding.md
- docs/source/models/supported-models.md
- examples/llm-api/quickstart_advanced.py
- tensorrt_llm/_torch/attention_backend/interface.py
- tensorrt_llm/_torch/attention_backend/sparse/dsa.py
- tensorrt_llm/_torch/attention_backend/trtllm.py
- tensorrt_llm/_torch/modules/attention.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
- tensorrt_llm/_torch/pyexecutor/py_executor.py
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
- tensorrt_llm/_torch/pyexecutor/resource_manager.py
- tensorrt_llm/_torch/speculative/drafting_loops.py
- tensorrt_llm/_torch/speculative/dynamic_tree_ops.py
- tensorrt_llm/_torch/speculative/eagle3.py
- tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py
- tensorrt_llm/_torch/speculative/spec_sampler_base.py
- tensorrt_llm/_torch/speculative/spec_tree_manager.py
- tensorrt_llm/_torch/speculative/utils.py
- tensorrt_llm/llmapi/llm_args.py
- tests/integration/test_lists/waives.txt
- tests/unittest/_torch/modeling/test_modeling_llama.py
- tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py
- tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py
- tests/unittest/_torch/speculative/test_eagle3.py
- tests/unittest/_torch/thop/parallel/test_custom_ops.py
- tests/unittest/others/test_kv_cache_update.py
💤 Files with no reviewable changes (16)
- cpp/tensorrt_llm/thop/CMakeLists.txt
- tests/unittest/_torch/thop/parallel/test_custom_ops.py
- cpp/tensorrt_llm/kernels/xqaDispatcher.cpp
- cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h
- tests/integration/test_lists/waives.txt
- tests/unittest/_torch/speculative/test_eagle3.py
- tensorrt_llm/_torch/modules/attention.py
- cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h
- cpp/tensorrt_llm/thop/parallelDecodeKVCacheUpdateOp.cpp
- tests/unittest/others/test_kv_cache_update.py
- tensorrt_llm/_torch/speculative/dynamic_tree_ops.py
- cpp/tensorrt_llm/kernels/speculativeDecoding/kvCacheUpdateKernels.cu
- cpp/tensorrt_llm/kernels/speculativeDecoding/kvCacheUpdateKernels.h
- cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu
- cpp/tensorrt_llm/thop/dynamicTreeOp.cpp
- tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py
In cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/prepareCustomMask.cu:

```
// Input mask shape: [bs, seqLenQ, ceilDiv(seqLenQ, 32)]
// The KV dimension in the mask corresponds to Q positions (tree mask)
```

Mask-layout comments are inconsistent in this function.
Line 168 documents [bs, seqLenQ, ceilDiv(seqLenQ, 32)], but the function-level input-shape comment still describes a different third dimension. Please align both comments to one definitive layout contract.
In cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/prepareCustomMask.cu:

```diff
 // Input mask shape: [bs, seqLenQ, ceilDiv(seqLenQ, 32)]
 // The KV dimension in the mask corresponds to Q positions (tree mask)
 int32_t const qPosInTree = tokenIdxKv - firstSparseMaskOffsetKv;
 if (qPosInTree < seqLenQ)
 {
-    int32_t const qMaskBaseIdx = (batchIdx * packedMaskSeqLenQ + tokenIdxQ) * packedMaskNumBlocks;
+    int32_t const qMaskBaseIdx = (batchIdx * seqLenQ + tokenIdxQ) * ceilDiv(seqLenQ, 32);
     int32_t const packedMaskIdx = qMaskBaseIdx + (qPosInTree >> 5);
```
🧩 Analysis chain
(Verification scripts were executed against NVIDIA/TensorRT-LLM to trace the packed-mask producer/consumer wiring: the `spec_decoding_packed_mask` / `generalPackedCustoMaskPtr` plumbing, the `numPackedMasks` / `getNumPackedMasks` allocation sites in cpp/tensorrt_llm/batch_manager/medusaBuffers.cpp, the indexing in prepareCustomMask.cu, and the values exercised by cpp/tests/unit_tests/kernels/prepareCustomMaskTest.cpp.)
Confirmed stride mismatch in packed-mask indexing at line 173.
The kernel indexes the packed mask using ceilDiv(seqLenQ, 32) as the row stride, but the producer allocates with a fixed stride of numPackedMasks (computed as divUp(maxDecodingTokens, 32) at startup). Since seqLenQ is per-batch and can be less than maxDecodingTokens, the kernel will under-stride and read incorrect mask bits when seqLenQ < maxDecodingTokens.
Either ensure the allocated stride matches the kernel's stride computation at line 173, or compute and pass the correct stride value per batch to the kernel.
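The stride contract can be illustrated with a small sketch; helper names such as `mask_row_stride` and `packed_mask_idx` are hypothetical, not the actual kernel code:

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def packed_mask_idx(batch_idx: int, row_idx: int, q_pos: int,
                    max_rows: int, mask_row_stride: int) -> int:
    # mask_row_stride must be the stride used at allocation time,
    # i.e. ceil_div(max_decoding_tokens, 32), NOT ceil_div(seq_len_q, 32),
    # or the consumer walks a different grid than the producer wrote.
    base = (batch_idx * max_rows + row_idx) * mask_row_stride
    return base + (q_pos >> 5)  # which 32-bit word holds bit q_pos

max_decoding_tokens = 64
alloc_stride = ceil_div(max_decoding_tokens, 32)  # 2 words per row at allocation

# A per-batch seq_len_q of 20 would yield stride 1 and under-stride the
# buffer the producer allocated with stride 2:
print(ceil_div(20, 32), alloc_stride)  # 1 2

# Word index for bit 33 of row 3 in batch 1, using the allocation stride:
print(packed_mask_idx(1, 3, 33, max_decoding_tokens, alloc_stride))  # 135
```

Passing the allocation stride into the kernel (rather than recomputing it from the per-batch sequence length) is one way to keep both sides agreeing by construction.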
```diff
         if args.additional_model_outputs:
             for output_name in args.additional_model_outputs:
                 if sequence.additional_context_outputs:
                     print(
-                        f"[{i}]{sequence_id_text} Generation {output_name}: {sequence.additional_generation_outputs[output_name]}"
+                        f"[{i}]{sequence_id_text} Context {output_name}: {sequence.additional_context_outputs[output_name]}"
                     )
                 print(
                     f"[{i}]{sequence_id_text} Generation {output_name}: {sequence.additional_generation_outputs[output_name]}"
                 )
```
Guard optional extra-output maps before indexing them.
additional_context_outputs and additional_generation_outputs are optional. Any requested output_name that's absent from one of those dicts currently raises TypeError/KeyError while printing results.
🛡️ Proposed fix

```diff
         if args.additional_model_outputs:
             for output_name in args.additional_model_outputs:
-                if sequence.additional_context_outputs:
+                if (
+                    sequence.additional_context_outputs
+                    and output_name in sequence.additional_context_outputs
+                ):
                     print(
                         f"[{i}]{sequence_id_text} Context {output_name}: {sequence.additional_context_outputs[output_name]}"
                     )
-                print(
-                    f"[{i}]{sequence_id_text} Generation {output_name}: {sequence.additional_generation_outputs[output_name]}"
-                )
+                if (
+                    sequence.additional_generation_outputs
+                    and output_name in sequence.additional_generation_outputs
+                ):
+                    print(
+                        f"[{i}]{sequence_id_text} Generation {output_name}: {sequence.additional_generation_outputs[output_name]}"
+                    )
```

📝 Committable suggestion
‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
        if args.additional_model_outputs:
            for output_name in args.additional_model_outputs:
                if (
                    sequence.additional_context_outputs
                    and output_name in sequence.additional_context_outputs
                ):
                    print(
                        f"[{i}]{sequence_id_text} Context {output_name}: {sequence.additional_context_outputs[output_name]}"
                    )
                if (
                    sequence.additional_generation_outputs
                    and output_name in sequence.additional_generation_outputs
                ):
                    print(
                        f"[{i}]{sequence_id_text} Generation {output_name}: {sequence.additional_generation_outputs[output_name]}"
                    )
```
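The guard pattern in isolation, assuming a hypothetical `format_outputs` helper rather than the example script's actual structure:

```python
def format_outputs(requested, context_outputs, generation_outputs, tag=""):
    """Collect printable lines, skipping any optional map that is None or
    that does not contain the requested key."""
    lines = []
    for name in requested:
        if context_outputs and name in context_outputs:
            lines.append(f"{tag} Context {name}: {context_outputs[name]}")
        if generation_outputs and name in generation_outputs:
            lines.append(f"{tag} Generation {name}: {generation_outputs[name]}")
    return lines

# Context outputs absent entirely; generation outputs missing one key.
lines = format_outputs(["logits", "hidden"], None, {"logits": [0.1]}, tag="[0]")
assert lines == ["[0] Generation logits: [0.1]"]
```

A request that never produced one of the optional maps simply contributes no line instead of raising.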
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/llm-api/quickstart_advanced.py` around lines 404 - 412, The code
indexes optional dicts sequence.additional_context_outputs and
sequence.additional_generation_outputs for each output_name from
args.additional_model_outputs without guarding for missing keys; update the loop
in the printing block to check presence before indexing (e.g., if
sequence.additional_context_outputs and output_name in
sequence.additional_context_outputs) and similarly for
sequence.additional_generation_outputs, only printing the Context and Generation
lines when the corresponding map contains output_name (use .get or membership
checks) so printing [i]{sequence_id_text} lines never raises KeyError/TypeError.
```diff
         spec_decoding_tensor: Optional['SpecDecodingTensor'] = None,
     ):
         """Update speculative decoding parameters and create expanded buffers."""
         super().update_spec_dec_param(batch_size, is_spec_decoding_enabled,
                                       is_spec_dec_tree,
                                       is_spec_dec_dynamic_tree, max_draft_len,
                                       max_total_draft_tokens, model_is_wrapped,
-                                      spec_metadata, spec_tree_manager)
+                                      spec_metadata, spec_tree_manager,
+                                      spec_decoding_tensor)
```
Missing TYPE_CHECKING import for SpecDecodingTensor.
The type hint 'SpecDecodingTensor' is used but not imported. Static analysis correctly flags this as F821: undefined name 'SpecDecodingTensor'. You need to add the import under TYPE_CHECKING block.
🐛 Proposed fix to add the missing import
Add to the existing TYPE_CHECKING block (around line 38-39):
```diff
 if TYPE_CHECKING:
     from tensorrt_llm.llmapi.llm_args import DecodingBaseConfig
+    from tensorrt_llm._torch.speculative.utils import SpecDecodingTensor
```

📝 Committable suggestion
```python
if TYPE_CHECKING:
    from tensorrt_llm.llmapi.llm_args import DecodingBaseConfig
    from tensorrt_llm._torch.speculative.utils import SpecDecodingTensor
```
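A minimal, self-contained sketch of the `TYPE_CHECKING` pattern the fix relies on; `Decimal` stands in for `SpecDecodingTensor`:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by the type checker only; never executed at runtime.
    from decimal import Decimal  # illustrative stand-in for SpecDecodingTensor

def scale(x: "Decimal | None" = None) -> str:
    # The string annotation keeps the name unresolved until a checker needs it.
    return "default" if x is None else str(x)

assert scale() == "default"
assert "Decimal" not in globals()  # the guarded import never ran
```

The guarded import satisfies F821 for the checker without adding a runtime dependency or an import cycle.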
🧰 Tools
🪛 Flake8 (7.3.0)
[error] 600-600: undefined name 'SpecDecodingTensor'
(F821)
🪛 Ruff (0.15.9)
[error] 600-600: Undefined name SpecDecodingTensor
(F821)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py` around lines 600 - 608,
The type name 'SpecDecodingTensor' used in the update_spec_dec_param signature
is not imported causing F821; add an import for SpecDecodingTensor inside the
existing TYPE_CHECKING block so static type-checkers see the symbol without
affecting runtime. Locate the TYPE_CHECKING block near the top of
tensorrt_llm/_torch/attention_backend/sparse/dsa.py and add a from ... import
SpecDecodingTensor (matching the module where SpecDecodingTensor is defined) so
the function signature referencing 'SpecDecodingTensor' resolves.
```python
        # spec_dec mode should only be enabled for non-sm100 machines and when there's a spec-dec tree.
        self.is_spec_decoding_enabled = is_spec_decoding_enabled and (
            not self.is_sm_version_trtllm_gen_kernel(sm=get_sm_version()))

        self.is_spec_dec_tree = spec_tree_manager is not None
        self.is_spec_dec_dynamic_tree = spec_tree_manager is not None and spec_tree_manager.use_dynamic_tree

        if self.is_sm_version_trtllm_gen_kernel(sm=get_sm_version()):
            if self.is_spec_dec_tree or self.is_spec_dec_dynamic_tree:
                assert not self.is_spec_dec_tree, "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree."

        # use_spec_decoding is default to true by default, change in runtime by layers / requests
        self.use_spec_decoding = self.is_spec_decoding_enabled
```
Don't gate off the surviving linear spec-dec path on SM100+.
This now forces is_spec_decoding_enabled=False on every SM that reports TRTLLM-gen-kernel support, not just the removed tree/dynamic-tree modes. Eagle3OneModelWorker.forward() still sets attn_metadata.use_spec_decoding = True for the verification pass, so Blackwell+ ends up in an enabled=False / use=True state for the retained linear EAGLE3 flow.
Suggested direction
```diff
-        self.is_spec_decoding_enabled = is_spec_decoding_enabled and (
-            not self.is_sm_version_trtllm_gen_kernel(sm=get_sm_version()))
+        sm = get_sm_version()
+        is_trtllm_gen_sm = self.is_sm_version_trtllm_gen_kernel(sm=sm)
+        is_tree_mode = is_spec_dec_tree or is_spec_dec_dynamic_tree
+        self.is_spec_decoding_enabled = is_spec_decoding_enabled and (
+            not is_trtllm_gen_sm or not is_tree_mode)
```
+ not is_trtllm_gen_sm or not is_tree_mode)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/attention_backend/trtllm.py` around lines 1536 - 1548,
The code incorrectly forces self.is_spec_decoding_enabled False on any SM that
reports TRTLLM-gen-kernel support; change the logic so spec-decoding is only
disabled for SM-gen-kernel machines when a spec-dec tree or dynamic tree is
present. Concretely, compute a local flag via
self.is_sm_version_trtllm_gen_kernel(sm=get_sm_version()) and set
self.is_spec_decoding_enabled = is_spec_decoding_enabled unless (is_sm_gen and
(spec_tree_manager is not None or spec_tree_manager.use_dynamic_tree)), keep
self.is_spec_dec_tree and self.is_spec_dec_dynamic_tree assignments and the
existing assertion block (the assertion about spec-dec tree support) as-is, and
ensure self.use_spec_decoding continues to initialize from
self.is_spec_decoding_enabled; refer to attributes/methods
self.is_spec_decoding_enabled, self.is_spec_dec_tree,
self.is_spec_dec_dynamic_tree, self.use_spec_decoding and method
is_sm_version_trtllm_gen_kernel.
```python
        spec_resource_manager, spec_tree_manager = None, None
        if spec_config is not None:
            spec_resource_manager = resource_manager.get_resource_manager(
                ResourceManagerType.SPEC_RESOURCE_MANAGER)
            if spec_resource_manager is not None and hasattr(
                    spec_resource_manager, 'spec_tree_manager'):
                spec_tree_manager = spec_resource_manager.spec_tree_manager
```
Convert tree position offsets to a Python list before these request loops.
spec_tree_manager.spec_dec_position_offsets[0] is a device tensor, so extending position_ids from it here does per-element tensor access in a hot path. Materialize the offsets once with .tolist() outside the loop and reuse that host list in both branches.
♻️ Suggested change

```diff
-        spec_resource_manager, spec_tree_manager = None, None
+        spec_resource_manager, spec_tree_manager = None, None
+        spec_position_offsets = None
         if spec_config is not None:
             spec_resource_manager = resource_manager.get_resource_manager(
                 ResourceManagerType.SPEC_RESOURCE_MANAGER)
             if spec_resource_manager is not None and hasattr(
                     spec_resource_manager, 'spec_tree_manager'):
                 spec_tree_manager = spec_resource_manager.spec_tree_manager
+                if (spec_tree_manager is not None and not self.is_draft_model
+                        and not spec_config.is_linear_tree):
+                    spec_position_offsets = (
+                        spec_tree_manager.spec_dec_position_offsets[0].tolist()
+                    )
@@
-                    position_ids.extend(
-                        past_seen_token_num +
-                        spec_tree_manager.spec_dec_position_offsets[
-                            0]  # [max_total_draft_tokens + 1]
-                    )
+                    position_ids.extend(
+                        past_seen_token_num + offset
+                        for offset in spec_position_offsets
+                    )
@@
-                    position_ids.extend(
-                        past_seen_token_num +
-                        spec_tree_manager.spec_dec_position_offsets[
-                            0]  # [max_total_draft_tokens + 1]
-                    )
+                    position_ids.extend(
+                        past_seen_token_num + offset
+                        for offset in spec_position_offsets
+                    )
```

Based on learnings, in files under tensorrt_llm/_torch/pyexecutor, avoid accessing torch.Tensor objects inside for-loops when iterating over requests. Convert batched tensors to Python lists beforehand using tensor.tolist().
Also applies to: 2445-2458, 2478-2491
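The cost difference can be sketched without torch; `FakeDeviceTensor` is a made-up stand-in whose `__getitem__` models a per-element device sync:

```python
class FakeDeviceTensor:
    """Illustrative stand-in for a device tensor: each __getitem__ models a
    costly per-element device access, while tolist() is one bulk transfer."""

    def __init__(self, data):
        self.data = list(data)
        self.element_reads = 0

    def __getitem__(self, i):
        self.element_reads += 1
        return self.data[i]

    def tolist(self):
        return list(self.data)

offsets = FakeDeviceTensor([0, 1, 1, 2])
num_requests = 8

# Anti-pattern: index the tensor inside the per-request loop.
for _ in range(num_requests):
    _ = [10 + offsets[i] for i in range(4)]
assert offsets.element_reads == 32

# Preferred: materialize once, then reuse the host list.
host_offsets = offsets.tolist()
for _ in range(num_requests):
    _ = [10 + off for off in host_offsets]
assert offsets.element_reads == 32  # unchanged: no further element reads
```

With a real device tensor, each element read also forces a device-to-host sync, so hoisting `.tolist()` out of the request loop removes per-request transfers entirely.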
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/pyexecutor/model_engine.py` around lines 2399 - 2405,
spec_tree_manager.spec_dec_position_offsets[0] is a device tensor and is being
accessed element-wise inside per-request loops; materialize it once into a
Python list (e.g., host_offsets =
spec_tree_manager.spec_dec_position_offsets[0].tolist()) outside the request
loops and then use host_offsets to extend position_ids in both branches instead
of indexing the tensor inside the loop; apply the same change pattern wherever
spec_dec_position_offsets is iterated (also update the analogous places handling
position_ids in the other request-loop blocks).
```diff
             # Pad to match fixed-size store buffers for index_copy_.
             if o_new_tokens.shape[1] < (self.draft_len + 1):
                 o_new_tokens = torch.nn.functional.pad(
-                    o_new_tokens, (0, new_tokens_width - o_new_tokens.shape[1])
+                    o_new_tokens, (0, (self.draft_len + 1) - o_new_tokens.shape[1])
                 )
-            elif o_new_tokens.shape[1] > new_tokens_width:
-                o_new_tokens = o_new_tokens[:, :new_tokens_width]
-            if o_next_draft_tokens.shape[1] < draft_tokens_width:
+            if o_next_draft_tokens.shape[1] < self.draft_len:
                 o_next_draft_tokens = torch.nn.functional.pad(
-                    o_next_draft_tokens, (0, draft_tokens_width - o_next_draft_tokens.shape[1])
+                    o_next_draft_tokens, (0, self.draft_len - o_next_draft_tokens.shape[1])
                 )
-            elif o_next_draft_tokens.shape[1] > draft_tokens_width:
-                o_next_draft_tokens = o_next_draft_tokens[:, :draft_tokens_width]
-            if o_next_new_tokens.shape[1] < next_new_tokens_width:
+            if o_next_new_tokens.shape[1] < (self.draft_len + 1):
                 o_next_new_tokens = torch.nn.functional.pad(
-                    o_next_new_tokens, (0, next_new_tokens_width - o_next_new_tokens.shape[1])
+                    o_next_new_tokens, (0, (self.draft_len + 1) - o_next_new_tokens.shape[1])
                 )
```
Pad against the allocated store width, not draft_len.
SpecSamplerBase still allows subclasses to allocate wider buffers via _get_max_tokens() / _get_draft_tokens_storage_size(). MTPSampler does that with args.max_total_draft_tokens + 1, so padding only to self.draft_len leaves o_* narrower than the destination and index_copy_ at Lines 256-259 will fail on MTP batches whose tree width exceeds the runtime draft length.
🐛 Proposed fix
```diff
+        target_new_tokens = self.store.new_tokens.shape[0]
+        target_draft_tokens = self.store.next_draft_tokens.shape[1]
+
         # Pad to match fixed-size store buffers for index_copy_.
-        if o_new_tokens.shape[1] < (self.draft_len + 1):
+        if o_new_tokens.shape[1] < target_new_tokens:
             o_new_tokens = torch.nn.functional.pad(
-                o_new_tokens, (0, (self.draft_len + 1) - o_new_tokens.shape[1])
+                o_new_tokens, (0, target_new_tokens - o_new_tokens.shape[1])
             )
-        if o_next_draft_tokens.shape[1] < self.draft_len:
+        if o_next_draft_tokens.shape[1] < target_draft_tokens:
             o_next_draft_tokens = torch.nn.functional.pad(
-                o_next_draft_tokens, (0, self.draft_len - o_next_draft_tokens.shape[1])
+                o_next_draft_tokens, (0, target_draft_tokens - o_next_draft_tokens.shape[1])
             )
-        if o_next_new_tokens.shape[1] < (self.draft_len + 1):
+        if o_next_new_tokens.shape[1] < target_new_tokens:
             o_next_new_tokens = torch.nn.functional.pad(
-                o_next_new_tokens, (0, (self.draft_len + 1) - o_next_new_tokens.shape[1])
+                o_next_new_tokens, (0, target_new_tokens - o_next_new_tokens.shape[1])
             )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/speculative/spec_sampler_base.py` around lines 241 - 253,
The current padding uses self.draft_len but must use the allocated store width
from the sampler storage helpers; compute storage_width =
self._get_draft_tokens_storage_size() (or use _get_max_tokens() if appropriate)
and pad o_new_tokens and o_next_new_tokens up to storage_width, and pad
o_next_draft_tokens up to storage_width - 1, so the tensors match the
destination buffers before the index_copy_ calls (look for o_new_tokens,
o_next_draft_tokens, o_next_new_tokens and index_copy_ in SpecSamplerBase).
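The shape contract behind the fix, as a plain-Python sketch with made-up widths:

```python
def pad_row(row, target_width, fill=0):
    """Right-pad a token row so it matches the width of the destination buffer."""
    assert len(row) <= target_width
    return row + [fill] * (target_width - len(row))

# Store allocated for the widest tree (e.g. max_total_draft_tokens + 1 = 6),
# while this batch only produced draft_len + 1 = 4 tokens.
store_width = 6
runtime_row = [7, 3, 9, 1]

padded = pad_row(runtime_row, store_width)
assert len(padded) == store_width          # now safe to copy into the store row
assert padded == [7, 3, 9, 1, 0, 0]

# Padding only to draft_len + 1 would leave the row narrower than the store:
assert len(pad_row(runtime_row, len(runtime_row))) != store_width
```

An `index_copy_`-style write requires the source row to match the destination width exactly, which is why the pad target must come from the allocated store, not from the runtime draft length.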
```diff
     def __init__(self, max_num_requests: int, use_dynamic_tree: bool,
                  max_total_draft_tokens: int, max_draft_len: int,
-                 eagle_choices: List[List[int]] | None,
-                 dynamic_tree_max_topK: int):
+                 eagle_choices: [List[List[int]]], dynamic_tree_max_topK: int):
```
🧩 Analysis chain
🏁 Script executed:

```shell
#!/bin/bash
python - <<'PY'
from pathlib import Path
import ast

path = Path("tensorrt_llm/_torch/speculative/spec_tree_manager.py")
module = ast.parse(path.read_text())
for node in ast.walk(module):
    if isinstance(node, ast.FunctionDef) and node.name == "__init__":
        arg = next(a for a in node.args.args if a.arg == "eagle_choices")
        print(ast.dump(arg.annotation, indent=2))
        break
PY
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 296

🏁 Script executed:

```shell
#!/bin/bash
# Check the full __init__ signature and see if eagle_choices has a default value
cd tensorrt_llm/_torch/speculative
head -80 spec_tree_manager.py | tail -20
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 1000

🏁 Script executed:

```shell
#!/bin/bash
# Check project Python version requirement
fd -t f "pyproject.toml|setup.py|setup.cfg" -x cat {}
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 50376

🏁 Script executed:

```shell
#!/bin/bash
# Check if eagle_choices is used elsewhere to understand if it can be None
rg "eagle_choices" tensorrt_llm/_torch/speculative/spec_tree_manager.py -A 2 -B 2
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 1944
Fix the eagle_choices annotation.
[List[List[int]]] is a list literal in Python, not a type annotation, so type checkers and inspect-based tooling can't reason about this parameter correctly. Use list[list[int]] | None instead. Python 3.10+ supports this syntax.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/speculative/spec_tree_manager.py` around lines 73 - 75,
The type annotation for the eagle_choices parameter in SpecTreeManager.__init__
is using a list literal ([List[List[int]]]) instead of a proper type; change the
signature to use the modern typing form (e.g., eagle_choices: list[list[int]] |
None) so type checkers and inspect-based tooling can understand it, and update
any callers or default behavior if you make it optional (None) to preserve
existing semantics.
```python
        if self.use_dynamic_tree:
            self.eagle_paths[tree_idx].fill_(-1)
            # If dynamic tree, return the eagle_paths according to the mask.
            for i in range(self.max_total_draft_tokens + 1):
                self.eagle_paths[tree_idx][:, i, :] = self.spec_dec_mask_matrix[
                    tree_idx][i, :].nonzero()
        return self.eagle_paths[tree_idx]
```
get_eagle_paths() will raise before it can rebuild dynamic paths.
After self.eagle_paths[tree_idx], the tensor is 2-D, so [:, i, :] on Line 264 is an IndexError. tensorrt_llm/_torch/pyexecutor/sampler.py calls get_eagle_paths(seq_slot) during verification, so any dynamic-tree request will fail here instead of returning paths.
🐛 Proposed fix
```diff
         if self.use_dynamic_tree:
             self.eagle_paths[tree_idx].fill_(-1)
             # If dynamic tree, return the eagle_paths according to the mask.
             for i in range(self.max_total_draft_tokens + 1):
-                self.eagle_paths[tree_idx][:, i, :] = self.spec_dec_mask_matrix[
-                    tree_idx][i, :].nonzero()
+                path = self.spec_dec_mask_matrix[tree_idx, i].nonzero(
+                    as_tuple=False
+                ).flatten()
+                self.eagle_paths[tree_idx, i, :path.numel()] = path
         return self.eagle_paths[tree_idx]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/speculative/spec_tree_manager.py` around lines 260 - 266,
get_eagle_paths is indexing self.eagle_paths[tree_idx] as if it's a 3-D tensor
([:, i, :]) but after self.eagle_paths[tree_idx] it's 2-D, causing an IndexError
when use_dynamic_tree is True; fix by constructing or indexing the correct
shape: either ensure self.eagle_paths[tree_idx] is a 3-D tensor of shape (batch,
max_total_draft_tokens+1, path_len) before the loop, or change the assignment to
match the 2-D layout (e.g. assign into self.eagle_paths[tree_idx][:, i] or
reshape the nonzero result to match the third dimension). Update get_eagle_paths
to use self.eagle_paths, self.spec_dec_mask_matrix, use_dynamic_tree, and
max_total_draft_tokens consistently so the slice operations match tensor ranks
(reshape or unsqueeze nonzero results or preallocate a 3-D tensor) to avoid
IndexError.
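The intended mask-to-paths reconstruction can be sketched in plain Python; the mask layout below is illustrative:

```python
# mask[i][j] == 1 means node j lies on the path from the root to node i.
mask = [
    [1, 0, 0, 0],  # root
    [1, 1, 0, 0],  # first child of root
    [1, 0, 1, 0],  # second child of root
    [1, 1, 0, 1],  # child of node 1
]

def paths_from_mask(mask, pad=-1):
    """Rebuild per-node root paths from a lower-triangular ancestor mask,
    right-padded to a fixed width like a preallocated eagle_paths buffer."""
    width = len(mask)
    paths = []
    for row in mask:
        path = [j for j, bit in enumerate(row) if bit]
        paths.append(path + [pad] * (width - len(path)))
    return paths

paths = paths_from_mask(mask)
assert paths[0] == [0, -1, -1, -1]
assert paths[3] == [0, 1, 3, -1]
```

Each row yields one variable-length path (the nonzero column indices), which is why the tensor version must write into a per-row slice sized by the path length rather than assigning across a full fixed-width slice.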
```python
if __name__ == "__main__":
    unittest.main()
```
🧩 Analysis chain
🏁 Script executed:

```shell
cat tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py | head -50
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 1593

🏁 Script executed:

```shell
# Get the test file size first
wc -l tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 143

🏁 Script executed:

```shell
# Look at the specific test function definition and the end of file
sed -n '1,30p; 450,460p' tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 1108

🏁 Script executed:

```shell
# Search for all test function definitions in the file
rg "^def test_" tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 121

🏁 Script executed:

```shell
# Check for unittest.TestCase subclasses
rg "class.*\(.*unittest\.TestCase" tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 45
unittest.main() won't discover the module-level test function.
This file defines test_draft_token_static_tree_prepare_for_generation() as a module-level function (pytest convention), but unittest.main() only discovers unittest.TestCase subclasses. Running this file directly will find zero tests. Either remove the unittest.main() block if this is a pytest test, or convert the test to a unittest.TestCase subclass.
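The discovery gap is easy to demonstrate with a synthetic module:

```python
import types
import unittest

mod = types.ModuleType("fake_test_module")

def test_module_level():  # pytest-style; invisible to unittest's loader
    assert True

class TestDiscovered(unittest.TestCase):
    def test_found(self):
        self.assertTrue(True)

mod.test_module_level = test_module_level
mod.TestDiscovered = TestDiscovered

suite = unittest.TestLoader().loadTestsFromModule(mod)
# Only the TestCase subclass contributes; the bare function is never collected.
assert suite.countTestCases() == 1
```

`unittest.main()` uses the same loader on the current module, so a file containing only module-level `test_*` functions runs zero tests when executed directly.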
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py`
around lines 457 - 458, The file currently calls unittest.main() but defines
module-level pytest-style test_draft_token_static_tree_prepare_for_generation(),
so running the file finds zero tests; either remove the if __name__ ==
"__main__": unittest.main() block to keep this as a pytest-only test, or convert
the module-level function into a unittest.TestCase (e.g., create class
TestDraftTokenPrepareForGeneration(unittest.TestCase) with a method
test_draft_token_static_tree_prepare_for_generation that calls the same
assertions) and keep unittest.main(); update imports accordingly.
|
PR_Github #43166 [ run ] completed with state
|
Description
Checking if pre-merge can be stabilized by reverting this.
Test Coverage
Existing tests must suffice.
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.

Summary by CodeRabbit
Features
- Removed `max_batch_size` parameter from EAGLE-3 configuration.

Documentation

Examples