
[TRTLLM-11540][feat] Revert EAGLE3 dynamic tree speculative decoding support (#12062)#13006

Merged
litaotju merged 3 commits into NVIDIA:main from brb-nv:user/brb/revert-to-unblock-premerge
Apr 14, 2026

Conversation

@brb-nv
Collaborator

@brb-nv brb-nv commented Apr 13, 2026

Description

Checking if pre-merge can be stabilized by reverting this.

Test Coverage

Existing tests must suffice.

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

Summary by CodeRabbit

  • Features Removed

    • Removed dynamic tree mode support from EAGLE-3 speculative decoding. EAGLE-3 now uses linear draft token generation only.
    • Removed max_batch_size parameter from EAGLE-3 configuration.
  • Documentation

    • Updated documentation to reflect removal of dynamic tree speculative decoding support and simplified EAGLE-3 configuration options.
  • Examples

    • Simplified example scripts by removing dynamic tree and streaming-related parameters.

@brb-nv
Collaborator Author

brb-nv commented Apr 13, 2026

/bot run --disable-fail-fast

@brb-nv brb-nv changed the title from Revert "[TRTLLM-11540][feat] Add EAGLE3 dynamic tree speculative decoding support (#12062)" to [TRTLLM-11540][feat] Revert EAGLE3 dynamic tree speculative decoding support (#12062) on Apr 13, 2026
@tensorrt-cicd
Collaborator

PR_Github #43076 [ run ] triggered by Bot. Commit: 43002c5 Link to invocation

…ding support (NVIDIA#12062)"

This reverts commit 4ece13c.

Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
@brb-nv brb-nv force-pushed the user/brb/revert-to-unblock-premerge branch from 43002c5 to 1e422bf on April 13, 2026 18:17
@brb-nv
Collaborator Author

brb-nv commented Apr 13, 2026

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #43096 [ run ] triggered by Bot. Commit: 1e422bf Link to invocation

Signed-off-by: Yiqing Yan <yiqingy@nvidia.com>
Signed-off-by: Yiqing Yan <yiqingy@nvidia.com>
@tensorrt-cicd
Collaborator

PR_Github #43096 [ run ] completed with state SUCCESS. Commit: 1e422bf
/LLM/main/L0_MergeRequest_PR pipeline #33734 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@yiqingy0
Collaborator

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #43166 [ run ] triggered by Bot. Commit: d76cabe Link to invocation

@litaotju litaotju marked this pull request as ready for review April 14, 2026 09:10
@litaotju litaotju requested review from a team as code owners April 14, 2026 09:10
@coderabbitai
Contributor

coderabbitai bot commented Apr 14, 2026

📝 Walkthrough

Walkthrough

This PR removes dynamic tree speculative decoding support from the codebase, including CUDA kernels, Torch operators, Python worker implementations, configuration parameters, and related documentation. The linear drafting path for EAGLE3 speculative decoding is retained.

Changes

Cohort / File(s) Summary
Dynamic Tree CUDA Kernels
cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu, cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h
Removed entire CUDA implementation of dynamic tree construction and verification kernels, including buildDynamicTreeKernel, verifyDynamicTreeGreedyKernel, and their host-side wrappers.
KV Cache Update (2D Variant)
cpp/tensorrt_llm/kernels/speculativeDecoding/kvCacheUpdateKernels.cu, cpp/tensorrt_llm/kernels/speculativeDecoding/kvCacheUpdateKernels.h, cpp/tensorrt_llm/thop/parallelDecodeKVCacheUpdateOp.cpp
Removed 2D variant of KV cache update operations and kept only the packed/batched update path.
FMHA Configuration and Dispatcher
cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h, cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/prepareCustomMask.cu, cpp/tensorrt_llm/kernels/xqaDispatcher.cpp
Removed mPackedMaskMaxSeqLenQ field and updated mask indexing to use runtime seqLenQ directly instead of pre-computed padded values.
Torch Operators and Build System
cpp/tensorrt_llm/thop/CMakeLists.txt, cpp/tensorrt_llm/thop/dynamicTreeOp.cpp
Removed dynamic tree Torch operator bindings (build_dynamic_tree_op, verify_dynamic_tree_greedy_op, verify_dynamic_tree_greedy_out_op) and removed file from build compilation.
Python Speculative Decoding Infrastructure
tensorrt_llm/_torch/speculative/dynamic_tree_ops.py, tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py
Completely removed dynamic tree operations converter and dynamic tree worker implementation (1049 lines deleted).
EAGLE3 Configuration and Resource Management
tensorrt_llm/_torch/speculative/eagle3.py, tensorrt_llm/llmapi/llm_args.py
Removed Eagle3OneModelDynamicTreeResourceManager, simplified resource initialization, removed max_batch_size config field, eliminated dynamic-tree-specific branching, and fixed tree mode flags to false.
Attention Backend Updates
tensorrt_llm/_torch/attention_backend/interface.py, tensorrt_llm/_torch/attention_backend/sparse/dsa.py, tensorrt_llm/_torch/attention_backend/trtllm.py
Added spec_decoding_tensor parameter to spec-dec update methods; refactored Blackwell mask allocation strategy and removed position-offset reshape logic.
Speculative Decoding Managers and Wrappers
tensorrt_llm/_torch/speculative/spec_tree_manager.py, tensorrt_llm/_torch/speculative/drafting_loops.py, tensorrt_llm/_torch/speculative/spec_sampler_base.py
Removed dynamic tree buffer management and slot scattering/gathering; renamed StaticTreeDraftingLoopWrapper to TreeDraftingLoopWrapper; simplified buffer allocation logic.
Model Engine and Executor
tensorrt_llm/_torch/pyexecutor/model_engine.py, tensorrt_llm/_torch/pyexecutor/py_executor.py, tensorrt_llm/_torch/pyexecutor/py_executor_creator.py, tensorrt_llm/_torch/pyexecutor/resource_manager.py
Removed spec tree data gathering, updated drafting wrapper selection to use is_linear_tree, changed runtime draft-length fallback, removed _get_spec_worker() method.
Speculative Utilities and Core Logic
tensorrt_llm/_torch/modules/attention.py, tensorrt_llm/_torch/speculative/utils.py
Removed _adjust_position_ids_for_spec_dec() helper and dynamic tree import paths; simplified Eagle3 sampler initialization and resource/worker selection.
Examples and Configuration
examples/llm-api/quickstart_advanced.py
Removed EAGLE3 dynamic tree CLI arguments (--max_total_draft_tokens, --max_batch_size); simplified prompt sourcing and removed async/streaming generation with acceptance-rate computation.
Documentation
docs/source/features/feature-combination-matrix.md, docs/source/features/speculative-decoding.md, docs/source/models/supported-models.md
Consolidated EAGLE-3 columns by removing "Linear/Dynamic" variants; removed dynamic tree mode documentation and updated feature descriptions to reflect linear-only path.
Tests
tests/integration/test_lists/waives.txt, tests/unittest/_torch/modeling/test_modeling_llama.py, tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py, tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py, tests/unittest/_torch/speculative/test_eagle3.py, tests/unittest/_torch/thop/parallel/test_custom_ops.py, tests/unittest/others/test_kv_cache_update.py
Removed dynamic tree test cases, skip entries, and 2D KV cache tests; updated wrapper class references and simplified spec-decoding tensor usage in unit tests.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs

Suggested reviewers

  • syuoni
  • mikeiovine
  • ziyixiong-nv
  • kris1025
  • lfr-0531
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 26.23%, which is below the required 80.00% threshold. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check — ✅ Passed: The title clearly indicates this is a revert of EAGLE3 dynamic tree speculative decoding support, with the JIRA ticket reference.
  • Description check — ✅ Passed: The description is minimal but adequate, stating the purpose is to revert changes for pre-merge stabilization and noting that existing tests suffice.


Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 10

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
examples/llm-api/quickstart_advanced.py (1)

265-273: ⚠️ Potential issue | 🟠 Major

Stop wiring tree-only knobs into the EAGLE3 example.

This revert removes the PyTorch EAGLE3 tree path, but this helper still forwards eagle_choices, use_dynamic_tree, and dynamic_tree_max_topK into Eagle3DecodingConfig. --spec_decode_algo EAGLE3 can still advertise an unsupported mode from the example.

🐛 Proposed fix
         spec_config = Eagle3DecodingConfig(
             max_draft_len=args.spec_decode_max_draft_len,
             speculative_model=args.draft_model_dir,
             eagle3_one_model=args.use_one_model,
-            eagle_choices=args.eagle_choices,
-            use_dynamic_tree=args.use_dynamic_tree,
-            dynamic_tree_max_topK=args.dynamic_tree_max_topK,
             allow_advanced_sampling=args.allow_advanced_sampling,
             eagle3_model_arch=args.eagle3_model_arch)

The parser flags should be removed or rejected as a follow-up outside this hunk.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/llm-api/quickstart_advanced.py` around lines 265 - 273, The Eagle3
example is still passing tree-only knobs into the Eagle3DecodingConfig; remove
the unsupported parameters eagle_choices, use_dynamic_tree, and
dynamic_tree_max_topK from the spec_config construction (the
Eagle3DecodingConfig call that builds spec_config) so it only forwards supported
args like max_draft_len, speculative_model, eagle3_one_model,
allow_advanced_sampling, and eagle3_model_arch; do not add conditional logic to
enable a tree path here—strip those args from the constructor call and leave
parser flag removal/rejection as a separate follow-up.
tensorrt_llm/llmapi/llm_args.py (1)

1129-1174: ⚠️ Potential issue | 🟠 Major

Keep the linear-tree default out of the dynamic-tree branch.

Line 1130 unconditionally overwrites max_total_draft_tokens with max_draft_len before the dynamic-tree checks. That means Lines 1171-1174 can never catch an omitted value, and a dynamic tree silently gets linear-tree sizing instead of the explicit total-node budget it needs.

🐛 Proposed fix
         self.num_eagle_layers = self.max_draft_len
-        self.max_total_draft_tokens = self.max_draft_len  # If using linear-tree, the max_total_draft_tokens is the same as max_draft_len
+        if not self.use_dynamic_tree:
+            self.max_total_draft_tokens = self.max_draft_len  # Linear tree only

         if self.eagle3_model_arch == "mistral_large3" and self.eagle3_layers_to_capture is None:
             # FIXME find a better way to setup it.
             self.eagle3_layers_to_capture = {-1}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/llmapi/llm_args.py` around lines 1129 - 1174, The code
unconditionally sets self.max_total_draft_tokens = self.max_draft_len (using
symbols self.max_total_draft_tokens, self.max_draft_len) before handling
dynamic-tree, which causes dynamic trees to inherit linear-tree sizing; remove
that unconditional assignment and instead set max_total_draft_tokens only for
the linear/static-tree path (e.g., guard it with if not self.use_dynamic_tree or
move it into the static-tree block after eagle_choices handling), ensuring the
dynamic-tree checks (use_dynamic_tree, dynamic_tree_max_topK,
max_total_draft_tokens) can validate a provided total-node budget.
tensorrt_llm/_torch/speculative/eagle3.py (1)

75-84: ⚠️ Potential issue | 🟠 Major

This leaves use_dynamic_tree metadata without a SpecTreeManager.

After this change, Eagle3ResourceManager only builds SpecTreeManager when config.eagle_choices is set. But tensorrt_llm/_torch/speculative/utils.py:create_spec_metadata() still marks the EAGLE3 metadata as tree/dynamic-tree whenever spec_config.use_dynamic_tree is true. That leaves the runtime carrying dynamic-tree flags with spec_tree_manager=None, and Eagle3SpecMetadata.prepare() falls back to the linear first-draft path instead of the accepted-path/tree handling.

Please either clear use_dynamic_tree at metadata creation as part of this revert, or keep constructing the manager until those callers are updated too.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/speculative/eagle3.py` around lines 75 - 84, The change
stops constructing SpecTreeManager when config.eagle_choices is None but leaves
spec_config.use_dynamic_tree true, causing metadata to advertise dynamic-tree
behavior without a manager; update Eagle3ResourceManager so that if
config.use_dynamic_tree is true but config.eagle_choices is None you either (A)
clear/disable use_dynamic_tree on the created Eagle3SpecMetadata (set
spec_config.use_dynamic_tree = False) before returning metadata, or (B) continue
constructing SpecTreeManager even when eagle_choices is None so
self.spec_tree_manager is always present when use_dynamic_tree is enabled; touch
the Eagle3ResourceManager construction logic around the SpecTreeManager
instantiation (and ensure create_spec_metadata()/Eagle3SpecMetadata.prepare()
will see a consistent spec_config and self.spec_tree_manager).
🧹 Nitpick comments (2)
tensorrt_llm/_torch/speculative/spec_tree_manager.py (1)

18-18: Avoid the shared mutable class default for top_k_list.

This list is shared at the class level until __init__ rebinds it. Keeping only the annotation here avoids accidental cross-instance state and clears the Ruff RUF012 warning.

♻️ Proposed fix
-    top_k_list = []
+    top_k_list: list[torch.Tensor]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/speculative/spec_tree_manager.py` at line 18, The
class-level mutable default "top_k_list = []" creates shared state across
instances and triggers RUF012; remove the assignment so only the annotation
remains (e.g., keep "top_k_list: list[int]" or appropriate type) and ensure the
instance attribute is initialized in the class __init__ (set self.top_k_list =
[] there) so each instance gets its own list; reference: top_k_list and __init__
in the same class.
tensorrt_llm/_torch/speculative/drafting_loops.py (1)

229-233: Consider simplifying the spec_tree_manager extraction.

Based on the executor creator logic (context snippet 3), TreeDraftingLoopWrapper is only instantiated when EagleDecodingConfig is used and is_linear_tree is False, guaranteeing Eagle3SpecMetadata will always be passed. The isinstance check and subsequent assert could be combined.

♻️ Optional simplification
-        spec_tree_manager = None
-        if isinstance(spec_metadata, Eagle3SpecMetadata):
-            spec_tree_manager = spec_metadata.eagle3_resource_manager.spec_tree_manager
-
-        assert spec_tree_manager is not None
+        assert isinstance(spec_metadata, Eagle3SpecMetadata), \
+            "TreeDraftingLoopWrapper requires Eagle3SpecMetadata"
+        spec_tree_manager = spec_metadata.eagle3_resource_manager.spec_tree_manager
+        assert spec_tree_manager is not None, "spec_tree_manager must be set"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/speculative/drafting_loops.py` around lines 229 - 233,
Summary: simplify extraction of spec_tree_manager by combining the type
assertion and assignment. Replace the current two-step pattern (setting
spec_tree_manager=None, then using isinstance check and a separate assert) with
a single assert that spec_metadata is an Eagle3SpecMetadata followed immediately
by assigning spec_tree_manager from
spec_metadata.eagle3_resource_manager.spec_tree_manager (this aligns with the
guarantee from TreeDraftingLoopWrapper/EagleDecodingConfig); update any type
hints or mypy ignores if needed.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/prepareCustomMask.cu`:
- Around line 168-169: Update the comments in prepareCustomMask (in
trtllmGenKernels/fmha/prepareCustomMask.cu) so they consistently describe the
same mask layout: change the function-level input-shape comment to match the
inline comment that the input mask shape is [bs, seqLenQ, ceilDiv(seqLenQ, 32)]
and clarify that the KV dimension corresponds to Q positions (tree mask); ensure
there is a single definitive comment describing the third dimension as
ceilDiv(seqLenQ, 32) and remove any conflicting description elsewhere in the
function.
- Around line 168-174: The packed-mask indexing uses ceilDiv(seqLenQ, 32) as the
row stride when computing qMaskBaseIdx/packedMaskIdx, but the producer allocates
rows with a fixed stride numPackedMasks (divUp(maxDecodingTokens, 32)), causing
under-stride when seqLenQ < maxDecodingTokens; fix by using the actual allocated
stride: replace the per-kernel computed stride ceilDiv(seqLenQ, 32) with the
passed/known allocation stride (e.g., numPackedMasks or a new kernel parameter
maskRowStride) when computing qMaskBaseIdx/packedMaskIdx, or alternatively
change the producer to allocate rows using ceilDiv(seqLenQ, 32) per-batch so the
allocation and kernel agree.
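The stride mismatch described above can be shown with plain index arithmetic. This is an illustrative sketch, not the kernel code: it assumes the producer lays out packed-mask rows with a fixed stride of `ceil_div(maxDecodingTokens, 32)` 32-bit words while the consumer computes offsets with `ceil_div(seqLenQ, 32)`; the two disagree whenever `seqLenQ < maxDecodingTokens`.

```python
# Illustrative arithmetic for the packed-mask row-stride mismatch.
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

max_decoding_tokens = 64   # producer's allocation bound (assumed)
seq_len_q = 8              # actual query length for this step

producer_stride = ceil_div(max_decoding_tokens, 32)  # 2 words per mask row
consumer_stride = ceil_div(seq_len_q, 32)            # 1 word per mask row

batch_idx, token_idx_q = 0, 1
# Word offset where the producer wrote row 1 (assumed row layout):
produced_row_start = (batch_idx * max_decoding_tokens + token_idx_q) * producer_stride
# Word offset where the consumer reads row 1:
consumed_row_start = (batch_idx * seq_len_q + token_idx_q) * consumer_stride

print(produced_row_start, consumed_row_start)  # 2 1 -- consumer under-strides
```

From row 1 onward every read lands short of where the producer wrote, which is why the fix must make both sides use the same allocated stride (e.g., pass `maskRowStride` into the kernel).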

In `@examples/llm-api/quickstart_advanced.py`:
- Around line 404-412: The code indexes optional dicts
sequence.additional_context_outputs and sequence.additional_generation_outputs
for each output_name from args.additional_model_outputs without guarding for
missing keys; update the loop in the printing block to check presence before
indexing (e.g., if sequence.additional_context_outputs and output_name in
sequence.additional_context_outputs) and similarly for
sequence.additional_generation_outputs, only printing the Context and Generation
lines when the corresponding map contains output_name (use .get or membership
checks) so printing [i]{sequence_id_text} lines never raises KeyError/TypeError.
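The guarded-printing pattern suggested above can be sketched as follows. The `Sequence` stand-in and attribute names mirror the review's description of quickstart_advanced.py but are simplified placeholders, not the real classes; the point is the membership check before indexing.

```python
# Sketch: guard optional output maps before indexing to avoid KeyError/TypeError.
class Sequence:  # minimal stand-in for the real sequence object
    def __init__(self, ctx=None, gen=None):
        self.additional_context_outputs = ctx
        self.additional_generation_outputs = gen


def collect_output_lines(sequence, output_names):
    lines = []
    for name in output_names:
        ctx = sequence.additional_context_outputs
        if ctx and name in ctx:  # guard: map exists AND key is present
            lines.append(f"Context {name}: {ctx[name]}")
        gen = sequence.additional_generation_outputs
        if gen and name in gen:
            lines.append(f"Generation {name}: {gen[name]}")
    return lines


seq = Sequence(ctx={"logits": [0.1]}, gen=None)
print(collect_output_lines(seq, ["logits", "hidden"]))
# ['Context logits: [0.1]'] -- missing keys and a None map are skipped silently
```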

In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py`:
- Around line 600-608: The type name 'SpecDecodingTensor' used in the
update_spec_dec_param signature is not imported causing F821; add an import for
SpecDecodingTensor inside the existing TYPE_CHECKING block so static
type-checkers see the symbol without affecting runtime. Locate the TYPE_CHECKING
block near the top of tensorrt_llm/_torch/attention_backend/sparse/dsa.py and
add a from ... import SpecDecodingTensor (matching the module where
SpecDecodingTensor is defined) so the function signature referencing
'SpecDecodingTensor' resolves.

In `@tensorrt_llm/_torch/attention_backend/trtllm.py`:
- Around line 1536-1548: The code incorrectly forces
self.is_spec_decoding_enabled False on any SM that reports TRTLLM-gen-kernel
support; change the logic so spec-decoding is only disabled for SM-gen-kernel
machines when a spec-dec tree or dynamic tree is present. Concretely, compute a
local flag via self.is_sm_version_trtllm_gen_kernel(sm=get_sm_version()) and set
self.is_spec_decoding_enabled = is_spec_decoding_enabled unless is_sm_gen and a
spec-dec tree is in use (spec_tree_manager is not None, which covers both the
static and dynamic tree cases), keep
self.is_spec_dec_tree and self.is_spec_dec_dynamic_tree assignments and the
existing assertion block (the assertion about spec-dec tree support) as-is, and
ensure self.use_spec_decoding continues to initialize from
self.is_spec_decoding_enabled; refer to attributes/methods
self.is_spec_decoding_enabled, self.is_spec_dec_tree,
self.is_spec_dec_dynamic_tree, self.use_spec_decoding and method
is_sm_version_trtllm_gen_kernel.

In `@tensorrt_llm/_torch/pyexecutor/model_engine.py`:
- Around line 2399-2405: spec_tree_manager.spec_dec_position_offsets[0] is a
device tensor and is being accessed element-wise inside per-request loops;
materialize it once into a Python list (e.g., host_offsets =
spec_tree_manager.spec_dec_position_offsets[0].tolist()) outside the request
loops and then use host_offsets to extend position_ids in both branches instead
of indexing the tensor inside the loop; apply the same change pattern wherever
spec_dec_position_offsets is iterated (also update the analogous places handling
position_ids in the other request-loop blocks).
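The hoisting pattern recommended here can be sketched without torch. The stand-in class below counts accesses to model the device-to-host synchronization that each element-wise read of a real CUDA tensor triggers; converting once with `.tolist()` outside the per-request loop pays that cost a single time.

```python
# Sketch: hoist device->host conversion out of per-request loops.
class FakeDeviceTensor:
    """Stand-in that counts reads to model per-element device->host syncs."""

    def __init__(self, data):
        self._data = list(data)
        self.sync_count = 0

    def __getitem__(self, i):
        self.sync_count += 1  # each indexed read would sync in real torch
        return self._data[i]

    def tolist(self):
        self.sync_count += 1  # one bulk transfer
        return list(self._data)


num_requests = 100

# Anti-pattern: index the "device tensor" inside the request loop.
slow = FakeDeviceTensor([0, 1, 1, 2])
for _ in range(num_requests):
    _ = [slow[i] for i in range(4)]  # 4 reads per request -> 400 "syncs"

# Recommended: materialize once, then reuse the plain host list.
fast = FakeDeviceTensor([0, 1, 1, 2])
host_offsets = fast.tolist()  # 1 "sync" total
for _ in range(num_requests):
    _ = list(host_offsets)

print(slow.sync_count, fast.sync_count)  # 400 1
```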

In `@tensorrt_llm/_torch/speculative/spec_sampler_base.py`:
- Around line 241-253: The current padding uses self.draft_len but must use the
allocated store width from the sampler storage helpers; compute storage_width =
self._get_draft_tokens_storage_size() (or use _get_max_tokens() if appropriate)
and pad o_new_tokens and o_next_new_tokens up to storage_width, and pad
o_next_draft_tokens up to storage_width - 1, so the tensors match the
destination buffers before the index_copy_ calls (look for o_new_tokens,
o_next_draft_tokens, o_next_new_tokens and index_copy_ in SpecSamplerBase).

In `@tensorrt_llm/_torch/speculative/spec_tree_manager.py`:
- Around line 260-266: get_eagle_paths is indexing self.eagle_paths[tree_idx] as
if it's a 3-D tensor ([:, i, :]) but after self.eagle_paths[tree_idx] it's 2-D,
causing an IndexError when use_dynamic_tree is True; fix by constructing or
indexing the correct shape: either ensure self.eagle_paths[tree_idx] is a 3-D
tensor of shape (batch, max_total_draft_tokens+1, path_len) before the loop, or
change the assignment to match the 2-D layout (e.g. assign into
self.eagle_paths[tree_idx][:, i] or reshape the nonzero result to match the
third dimension). Update get_eagle_paths to use self.eagle_paths,
self.spec_dec_mask_matrix, use_dynamic_tree, and max_total_draft_tokens
consistently so the slice operations match tensor ranks (reshape or unsqueeze
nonzero results or preallocate a 3-D tensor) to avoid IndexError.
- Around line 73-75: The type annotation for the eagle_choices parameter in
SpecTreeManager.__init__ is using a list literal ([List[List[int]]]) instead of
a proper type; change the signature to use the modern typing form (e.g.,
eagle_choices: list[list[int]] | None) so type checkers and inspect-based
tooling can understand it, and update any callers or default behavior if you
make it optional (None) to preserve existing semantics.

In
`@tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py`:
- Around line 457-458: The file currently calls unittest.main() but defines
module-level pytest-style test_draft_token_static_tree_prepare_for_generation(),
so running the file finds zero tests; either remove the if __name__ ==
"__main__": unittest.main() block to keep this as a pytest-only test, or convert
the module-level function into a unittest.TestCase (e.g., create class
TestDraftTokenPrepareForGeneration(unittest.TestCase) with a method
test_draft_token_static_tree_prepare_for_generation that calls the same
assertions) and keep unittest.main(); update imports accordingly.
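The second option from this comment (wrapping the module-level function in a `unittest.TestCase`) can be sketched like this; the test body here is a placeholder for the real assertions:

```python
# Sketch: make a pytest-style module-level test discoverable by unittest.main().
import unittest


def test_draft_token_static_tree_prepare_for_generation():
    assert 1 + 1 == 2  # placeholder for the real test assertions


class TestDraftTokenPrepareForGeneration(unittest.TestCase):
    def test_draft_token_static_tree_prepare_for_generation(self):
        # Delegate to the module-level function so pytest and unittest
        # both exercise the same assertions.
        test_draft_token_static_tree_prepare_for_generation()


# With this wrapper in place, the file's existing
# `if __name__ == "__main__": unittest.main()` block discovers one test
# instead of zero, while pytest still collects the module-level function.
```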


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 314827d7-8cad-4813-81c8-dcc291c62b30

📥 Commits

Reviewing files that changed from the base of the PR and between f6a80d2 and d76cabe.

📒 Files selected for processing (37)
  • cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu
  • cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h
  • cpp/tensorrt_llm/kernels/speculativeDecoding/kvCacheUpdateKernels.cu
  • cpp/tensorrt_llm/kernels/speculativeDecoding/kvCacheUpdateKernels.h
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/prepareCustomMask.cu
  • cpp/tensorrt_llm/kernels/xqaDispatcher.cpp
  • cpp/tensorrt_llm/thop/CMakeLists.txt
  • cpp/tensorrt_llm/thop/dynamicTreeOp.cpp
  • cpp/tensorrt_llm/thop/parallelDecodeKVCacheUpdateOp.cpp
  • docs/source/features/feature-combination-matrix.md
  • docs/source/features/speculative-decoding.md
  • docs/source/models/supported-models.md
  • examples/llm-api/quickstart_advanced.py
  • tensorrt_llm/_torch/attention_backend/interface.py
  • tensorrt_llm/_torch/attention_backend/sparse/dsa.py
  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • tensorrt_llm/_torch/modules/attention.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
  • tensorrt_llm/_torch/pyexecutor/resource_manager.py
  • tensorrt_llm/_torch/speculative/drafting_loops.py
  • tensorrt_llm/_torch/speculative/dynamic_tree_ops.py
  • tensorrt_llm/_torch/speculative/eagle3.py
  • tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py
  • tensorrt_llm/_torch/speculative/spec_sampler_base.py
  • tensorrt_llm/_torch/speculative/spec_tree_manager.py
  • tensorrt_llm/_torch/speculative/utils.py
  • tensorrt_llm/llmapi/llm_args.py
  • tests/integration/test_lists/waives.txt
  • tests/unittest/_torch/modeling/test_modeling_llama.py
  • tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py
  • tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py
  • tests/unittest/_torch/speculative/test_eagle3.py
  • tests/unittest/_torch/thop/parallel/test_custom_ops.py
  • tests/unittest/others/test_kv_cache_update.py
💤 Files with no reviewable changes (16)
  • cpp/tensorrt_llm/thop/CMakeLists.txt
  • tests/unittest/_torch/thop/parallel/test_custom_ops.py
  • cpp/tensorrt_llm/kernels/xqaDispatcher.cpp
  • cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h
  • tests/integration/test_lists/waives.txt
  • tests/unittest/_torch/speculative/test_eagle3.py
  • tensorrt_llm/_torch/modules/attention.py
  • cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h
  • cpp/tensorrt_llm/thop/parallelDecodeKVCacheUpdateOp.cpp
  • tests/unittest/others/test_kv_cache_update.py
  • tensorrt_llm/_torch/speculative/dynamic_tree_ops.py
  • cpp/tensorrt_llm/kernels/speculativeDecoding/kvCacheUpdateKernels.cu
  • cpp/tensorrt_llm/kernels/speculativeDecoding/kvCacheUpdateKernels.h
  • cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu
  • cpp/tensorrt_llm/thop/dynamicTreeOp.cpp
  • tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py

Comment on lines +168 to 169
// Input mask shape: [bs, seqLenQ, ceilDiv(seqLenQ, 32)]
// The KV dimension in the mask corresponds to Q positions (tree mask)
Contributor


⚠️ Potential issue | 🟡 Minor

Mask-layout comments are inconsistent in this function.

Line 168 documents [bs, seqLenQ, ceilDiv(seqLenQ, 32)], but the function-level input-shape comment still describes a different third dimension. Please align both comments to one definitive layout contract.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/prepareCustomMask.cu` around
lines 168 - 169, Update the comments in prepareCustomMask (in
trtllmGenKernels/fmha/prepareCustomMask.cu) so they consistently describe the
same mask layout: change the function-level input-shape comment to match the
inline comment that the input mask shape is [bs, seqLenQ, ceilDiv(seqLenQ, 32)]
and clarify that the KV dimension corresponds to Q positions (tree mask); ensure
there is a single definitive comment describing the third dimension as
ceilDiv(seqLenQ, 32) and remove any conflicting description elsewhere in the
function.

Comment on lines +168 to 174
+// Input mask shape: [bs, seqLenQ, ceilDiv(seqLenQ, 32)]
+// The KV dimension in the mask corresponds to Q positions (tree mask)
 int32_t const qPosInTree = tokenIdxKv - firstSparseMaskOffsetKv;
 if (qPosInTree < seqLenQ)
 {
-    int32_t const qMaskBaseIdx = (batchIdx * packedMaskSeqLenQ + tokenIdxQ) * packedMaskNumBlocks;
+    int32_t const qMaskBaseIdx = (batchIdx * seqLenQ + tokenIdxQ) * ceilDiv(seqLenQ, 32);
     int32_t const packedMaskIdx = qMaskBaseIdx + (qPosInTree >> 5);

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Producer/consumer wiring =="
rg -n -C3 'spec_decoding_packed_mask|generalPackedCustoMaskPtr' cpp

echo "== Packed-mask allocation/reshape sites =="
rg -n -C3 'attentionPackedMask|numPackedMasks|makeShape\\(|reshape\\(' cpp/tensorrt_llm/batch_manager

echo "== Remaining packed-stride params in FMHA runner path =="
rg -n -C3 'PackedMask|packedMaskNumBlocks|packedMaskSeqLen|mPackedMaskMaxSeqLenQ' \
  cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha

Repository: NVIDIA/TensorRT-LLM

Length of output: 25113


🏁 Script executed:

# Find where the packed mask tensor is actually allocated and reshaped
rg -n "attentionPackedMask|spec_decoding_packed_mask" cpp/tensorrt_llm/batch_manager --type cpp | head -40

Repository: NVIDIA/TensorRT-LLM

Length of output: 1519


🏁 Script executed:

# Look at medusaBuffers.cpp to find allocation of attentionPackedMaskDevice
sed -n '1,200p' cpp/tensorrt_llm/batch_manager/medusaBuffers.cpp

Repository: NVIDIA/TensorRT-LLM

Length of output: 7527


🏁 Script executed:

# Find reshape/reshape operations and makeShape in the codebase
rg -n "reshape\|makeShape" cpp/tensorrt_llm/batch_manager --type cpp -A 2 | head -60

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Find the definition of getNumPackedMasks
rg -n "getNumPackedMasks|NumPackedMasks" cpp/tensorrt_llm/runtime --type cpp -B 2 -A 5 | head -80

Repository: NVIDIA/TensorRT-LLM

Length of output: 1835


🏁 Script executed:

# Find MedusaModule definition
fd -name "*medusaModule*" cpp/tensorrt_llm/runtime

Repository: NVIDIA/TensorRT-LLM

Length of output: 293


🏁 Script executed:

# Look at the prepareCustomMask kernel more carefully to understand indexing
sed -n '165,180p' cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/prepareCustomMask.cu

Repository: NVIDIA/TensorRT-LLM

Length of output: 726


🏁 Script executed:

# Find MedusaModule header and implementation
fd "MedusaModule" cpp/tensorrt_llm/runtime

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Search for getNumPackedMasks implementation
rg "getNumPackedMasks" cpp/tensorrt_llm/runtime --type cpp -B 3 -A 5

Repository: NVIDIA/TensorRT-LLM

Length of output: 2013


🏁 Script executed:

# Find where MedusaModule is defined
rg "class MedusaModule|class SpeculativeDecodingModule" cpp/tensorrt_llm --type cpp -A 10

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Search in headers
fd "\.h$" cpp/tensorrt_llm/runtime | xargs rg "getNumPackedMasks" | head -20

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Look at the comment in prepareCustomMask.cu more carefully around the input shape
sed -n '90,115p' cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/prepareCustomMask.cu

Repository: NVIDIA/TensorRT-LLM

Length of output: 1415


🏁 Script executed:

# Search for getNumPackedMasks in header files
find cpp/tensorrt_llm -name "*.h" -type f | xargs rg "getNumPackedMasks" -B 2 -A 3

Repository: NVIDIA/TensorRT-LLM

Length of output: 47


🏁 Script executed:

# Search for implementations of getNumPackedMasks more broadly
rg "NumPackedMasks\s*\(" cpp/tensorrt_llm --type cpp -B 3 -A 3

Repository: NVIDIA/TensorRT-LLM

Length of output: 2263


🏁 Script executed:

# Look at what maxDecodingTokens and numPackedMasks relationship might be
rg "maxDecodingTokens|maxMedusaTokens" cpp/tensorrt_llm/runtime --type cpp | grep -i "packed\|num" | head -20

Repository: NVIDIA/TensorRT-LLM

Length of output: 2717


🏁 Script executed:

# Found it in explicitDraftTokensBuffers - check if there's a pattern for numPackedMasks
rg "ceilDiv.*32.*packed|getNumPackedMasks" cpp/tensorrt_llm --type cpp -B 2 -A 2

Repository: NVIDIA/TensorRT-LLM

Length of output: 1622


🏁 Script executed:

# Look at medusaModule or similar to find numPackedMasks implementation
find cpp/tensorrt_llm -name "*.h" -o -name "*.cpp" | xargs rg "class.*Module.*{" | grep -i "medusa\|speculative" | head -10

Repository: NVIDIA/TensorRT-LLM

Length of output: 734


🏁 Script executed:

# Search for where numPackedMasks is actually set/returned
rg "numPackedMasks\s*=" cpp/tensorrt_llm --type cpp -B 2 -A 2

Repository: NVIDIA/TensorRT-LLM

Length of output: 1712


🏁 Script executed:

# Look at where seqLenQ comes from in prepareCustomMask kernel
sed -n '100,180p' cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/prepareCustomMask.cu

Repository: NVIDIA/TensorRT-LLM

Length of output: 3356


🏁 Script executed:

# Check the medusaModule definition to see what getNumPackedMasks returns
find cpp/tensorrt_llm/runtime -name "*.h" | xargs grep -l "MedusaModule\|getNumPackedMasks" | head -5

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Look at test file to understand the relationship
sed -n '1,50p' cpp/tests/unit_tests/kernels/prepareCustomMaskTest.cpp

Repository: NVIDIA/TensorRT-LLM

Length of output: 1806


🏁 Script executed:

# Check if seqLensQPtr is set to a fixed value or varies
rg "seqLensQPtr\[" cpp/tensorrt_llm --type cpp -B 3 -A 3 | head -60

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Look at the test to see what values are used
sed -n '100,200p' cpp/tests/unit_tests/kernels/prepareCustomMaskTest.cpp

Repository: NVIDIA/TensorRT-LLM

Length of output: 4265


🏁 Script executed:

# Continue looking at the test to see actual values used
sed -n '200,320p' cpp/tests/unit_tests/kernels/prepareCustomMaskTest.cpp

Repository: NVIDIA/TensorRT-LLM

Length of output: 6746


🏁 Script executed:

# Check if there's documentation about the relationship between seqLenQ and maxDecodingTokens
rg "seqLenQ|maxDecodingTokens" cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/prepareCustomMask.cu | head -20

Repository: NVIDIA/TensorRT-LLM

Length of output: 740


🏁 Script executed:

# Verify what getNumPackedMasks returns in MedusaModule
rg "getNumPackedMasks" cpp/tensorrt_llm --type cpp -A 3 -B 1

Repository: NVIDIA/TensorRT-LLM

Length of output: 1565


🏁 Script executed:

# Double-check: does the test actually allocate with uniform or variable seqLenQ?
sed -n '210,245p' cpp/tests/unit_tests/kernels/prepareCustomMaskTest.cpp

Repository: NVIDIA/TensorRT-LLM

Length of output: 1701


Confirmed stride mismatch in packed-mask indexing at line 173.

The kernel indexes the packed mask using ceilDiv(seqLenQ, 32) as the row stride, but the producer allocates with a fixed stride of numPackedMasks (computed as divUp(maxDecodingTokens, 32) at startup). Since seqLenQ is per-batch and can be less than maxDecodingTokens, the kernel will under-stride and read incorrect mask bits when seqLenQ < maxDecodingTokens.

Either ensure the allocated stride matches the kernel's stride computation at line 173, or compute and pass the correct stride value per batch to the kernel.
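
As a minimal illustration of the failure mode (hypothetical sizes, plain Python standing in for the CUDA indexing): when the producer and the kernel disagree on the per-row stride of the packed-mask buffer, every row past the first resolves to the wrong flat offset.

```python
from math import ceil

# Hypothetical sizes: the producer allocates the packed mask with a fixed
# per-row stride derived from maxDecodingTokens, while the kernel derives
# its stride from the per-batch seqLenQ.
max_decoding_tokens = 64   # allocation-time bound
seq_len_q = 8              # per-batch query length (< bound)

alloc_stride = ceil(max_decoding_tokens / 32)   # 2 uint32 words per row
kernel_stride = ceil(seq_len_q / 32)            # 1 uint32 word per row

def flat_index(row, word, stride):
    # Linear offset of (row, word) in a row-major packed-mask buffer.
    return row * stride + word

# Word 0 of row 1: the producer wrote it at offset 2, but the
# under-strided kernel reads offset 1 -- a word that belongs to row 0.
producer_off = flat_index(1, 0, alloc_stride)
consumer_off = flat_index(1, 0, kernel_stride)
assert (producer_off, consumer_off) == (2, 1)
```

Passing the allocation stride into the kernel (or allocating per-batch) makes the two computations agree by construction.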

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/prepareCustomMask.cu` around
lines 168 - 174, The packed-mask indexing uses ceilDiv(seqLenQ, 32) as the row
stride when computing qMaskBaseIdx/packedMaskIdx, but the producer allocates
rows with a fixed stride numPackedMasks (divUp(maxDecodingTokens, 32)), causing
under-stride when seqLenQ < maxDecodingTokens; fix by using the actual allocated
stride: replace the per-kernel computed stride ceilDiv(seqLenQ, 32) with the
passed/known allocation stride (e.g., numPackedMasks or a new kernel parameter
maskRowStride) when computing qMaskBaseIdx/packedMaskIdx, or alternatively
change the producer to allocate rows using ceilDiv(seqLenQ, 32) per-batch so the
allocation and kernel agree.

Comment on lines +404 to +412
             if args.additional_model_outputs:
                 for output_name in args.additional_model_outputs:
                     if sequence.additional_context_outputs:
                         print(
-                            f"[{i}]{sequence_id_text} Generation {output_name}: {sequence.additional_generation_outputs[output_name]}"
+                            f"[{i}]{sequence_id_text} Context {output_name}: {sequence.additional_context_outputs[output_name]}"
                         )
                     print(
                         f"[{i}]{sequence_id_text} Generation {output_name}: {sequence.additional_generation_outputs[output_name]}"
                     )

⚠️ Potential issue | 🟡 Minor

Guard optional extra-output maps before indexing them.

additional_context_outputs and additional_generation_outputs are optional. Any requested output_name that's absent from one of those dicts currently raises TypeError/KeyError while printing results.

🛡️ Proposed fix
             if args.additional_model_outputs:
                 for output_name in args.additional_model_outputs:
-                    if sequence.additional_context_outputs:
+                    if (
+                        sequence.additional_context_outputs
+                        and output_name in sequence.additional_context_outputs
+                    ):
                         print(
                             f"[{i}]{sequence_id_text} Context {output_name}: {sequence.additional_context_outputs[output_name]}"
                         )
-                    print(
-                        f"[{i}]{sequence_id_text} Generation {output_name}: {sequence.additional_generation_outputs[output_name]}"
-                    )
+                    if (
+                        sequence.additional_generation_outputs
+                        and output_name in sequence.additional_generation_outputs
+                    ):
+                        print(
+                            f"[{i}]{sequence_id_text} Generation {output_name}: {sequence.additional_generation_outputs[output_name]}"
+                        )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if args.additional_model_outputs:
for output_name in args.additional_model_outputs:
if sequence.additional_context_outputs:
print(
f"[{i}]{sequence_id_text} Generation {output_name}: {sequence.additional_generation_outputs[output_name]}"
f"[{i}]{sequence_id_text} Context {output_name}: {sequence.additional_context_outputs[output_name]}"
)
print(
f"[{i}]{sequence_id_text} Generation {output_name}: {sequence.additional_generation_outputs[output_name]}"
)
if args.additional_model_outputs:
for output_name in args.additional_model_outputs:
if (
sequence.additional_context_outputs
and output_name in sequence.additional_context_outputs
):
print(
f"[{i}]{sequence_id_text} Context {output_name}: {sequence.additional_context_outputs[output_name]}"
)
if (
sequence.additional_generation_outputs
and output_name in sequence.additional_generation_outputs
):
print(
f"[{i}]{sequence_id_text} Generation {output_name}: {sequence.additional_generation_outputs[output_name]}"
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/llm-api/quickstart_advanced.py` around lines 404 - 412, The code
indexes optional dicts sequence.additional_context_outputs and
sequence.additional_generation_outputs for each output_name from
args.additional_model_outputs without guarding for missing keys; update the loop
in the printing block to check presence before indexing (e.g., if
sequence.additional_context_outputs and output_name in
sequence.additional_context_outputs) and similarly for
sequence.additional_generation_outputs, only printing the Context and Generation
lines when the corresponding map contains output_name (use .get or membership
checks) so printing [i]{sequence_id_text} lines never raises KeyError/TypeError.

Comment on lines +600 to +608
+        spec_decoding_tensor: Optional['SpecDecodingTensor'] = None,
     ):
         """Update speculative decoding parameters and create expanded buffers."""
         super().update_spec_dec_param(batch_size, is_spec_decoding_enabled,
                                       is_spec_dec_tree,
                                       is_spec_dec_dynamic_tree, max_draft_len,
                                       max_total_draft_tokens, model_is_wrapped,
-                                      spec_metadata, spec_tree_manager)
+                                      spec_metadata, spec_tree_manager,
+                                      spec_decoding_tensor)

⚠️ Potential issue | 🔴 Critical

Missing TYPE_CHECKING import for SpecDecodingTensor.

The type hint 'SpecDecodingTensor' is used but not imported. Static analysis correctly flags this as F821: undefined name 'SpecDecodingTensor'. You need to add the import under TYPE_CHECKING block.

🐛 Proposed fix to add the missing import

Add to the existing TYPE_CHECKING block (around line 38-39):

 if TYPE_CHECKING:
     from tensorrt_llm.llmapi.llm_args import DecodingBaseConfig
+    from tensorrt_llm._torch.speculative.utils import SpecDecodingTensor
📝 Committable suggestion


Suggested change
spec_decoding_tensor: Optional['SpecDecodingTensor'] = None,
):
"""Update speculative decoding parameters and create expanded buffers."""
super().update_spec_dec_param(batch_size, is_spec_decoding_enabled,
is_spec_dec_tree,
is_spec_dec_dynamic_tree, max_draft_len,
max_total_draft_tokens, model_is_wrapped,
spec_metadata, spec_tree_manager)
spec_metadata, spec_tree_manager,
spec_decoding_tensor)
if TYPE_CHECKING:
from tensorrt_llm.llmapi.llm_args import DecodingBaseConfig
from tensorrt_llm._torch.speculative.utils import SpecDecodingTensor
🧰 Tools
🪛 Flake8 (7.3.0)

[error] 600-600: undefined name 'SpecDecodingTensor'

(F821)

🪛 Ruff (0.15.9)

[error] 600-600: Undefined name SpecDecodingTensor

(F821)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py` around lines 600 - 608,
The type name 'SpecDecodingTensor' used in the update_spec_dec_param signature
is not imported causing F821; add an import for SpecDecodingTensor inside the
existing TYPE_CHECKING block so static type-checkers see the symbol without
affecting runtime. Locate the TYPE_CHECKING block near the top of
tensorrt_llm/_torch/attention_backend/sparse/dsa.py and add a from ... import
SpecDecodingTensor (matching the module where SpecDecodingTensor is defined) so
the function signature referencing 'SpecDecodingTensor' resolves.

Comment on lines +1536 to 1548
# spec_dec mode should only be enabled for non-sm100 machines and when there's a spec-dec tree.
self.is_spec_decoding_enabled = is_spec_decoding_enabled and (
not self.is_sm_version_trtllm_gen_kernel(sm=get_sm_version()))

self.is_spec_dec_tree = spec_tree_manager is not None
self.is_spec_dec_dynamic_tree = spec_tree_manager is not None and spec_tree_manager.use_dynamic_tree

if self.is_sm_version_trtllm_gen_kernel(sm=get_sm_version()):
if self.is_spec_dec_tree or self.is_spec_dec_dynamic_tree:
assert not self.is_spec_dec_tree, "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree."

# use_spec_decoding is default to true by default, change in runtime by layers / requests
self.use_spec_decoding = self.is_spec_decoding_enabled

⚠️ Potential issue | 🟠 Major

Don't gate off the surviving linear spec-dec path on SM100+.

This now forces is_spec_decoding_enabled=False on every SM that reports TRTLLM-gen-kernel support, not just the removed tree/dynamic-tree modes. Eagle3OneModelWorker.forward() still sets attn_metadata.use_spec_decoding = True for the verification pass, so Blackwell+ ends up in an enabled=False / use=True state for the retained linear EAGLE3 flow.

Suggested direction
-        self.is_spec_decoding_enabled = is_spec_decoding_enabled and (
-            not self.is_sm_version_trtllm_gen_kernel(sm=get_sm_version()))
+        sm = get_sm_version()
+        is_trtllm_gen_sm = self.is_sm_version_trtllm_gen_kernel(sm=sm)
+        is_tree_mode = is_spec_dec_tree or is_spec_dec_dynamic_tree
+        self.is_spec_decoding_enabled = is_spec_decoding_enabled and (
+            not is_trtllm_gen_sm or not is_tree_mode)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/attention_backend/trtllm.py` around lines 1536 - 1548,
The code incorrectly forces self.is_spec_decoding_enabled False on any SM that
reports TRTLLM-gen-kernel support; change the logic so spec-decoding is only
disabled for SM-gen-kernel machines when a spec-dec tree or dynamic tree is
present. Concretely, compute a local flag via
self.is_sm_version_trtllm_gen_kernel(sm=get_sm_version()) and set
self.is_spec_decoding_enabled = is_spec_decoding_enabled unless (is_sm_gen and
spec_tree_manager is not None), which also covers the dynamic-tree case, keep
self.is_spec_dec_tree and self.is_spec_dec_dynamic_tree assignments and the
existing assertion block (the assertion about spec-dec tree support) as-is, and
ensure self.use_spec_decoding continues to initialize from
self.is_spec_decoding_enabled; refer to attributes/methods
self.is_spec_decoding_enabled, self.is_spec_dec_tree,
self.is_spec_dec_dynamic_tree, self.use_spec_decoding and method
is_sm_version_trtllm_gen_kernel.

Comment on lines +2399 to +2405
spec_resource_manager, spec_tree_manager = None, None
if spec_config is not None:
spec_resource_manager = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)
if spec_resource_manager is not None and hasattr(
spec_resource_manager, 'spec_tree_manager'):
spec_tree_manager = spec_resource_manager.spec_tree_manager

⚠️ Potential issue | 🟠 Major

Convert tree position offsets to a Python list before these request loops.

spec_tree_manager.spec_dec_position_offsets[0] is a device tensor, so extending position_ids from it here does per-element tensor access in a hot path. Materialize the offsets once with .tolist() outside the loop and reuse that host list in both branches.

♻️ Suggested change
-        spec_resource_manager, spec_tree_manager = None, None
+        spec_resource_manager, spec_tree_manager = None, None
+        spec_position_offsets = None
         if spec_config is not None:
             spec_resource_manager = resource_manager.get_resource_manager(
                 ResourceManagerType.SPEC_RESOURCE_MANAGER)
             if spec_resource_manager is not None and hasattr(
                     spec_resource_manager, 'spec_tree_manager'):
                 spec_tree_manager = spec_resource_manager.spec_tree_manager
+                if (spec_tree_manager is not None and not self.is_draft_model
+                        and not spec_config.is_linear_tree):
+                    spec_position_offsets = (
+                        spec_tree_manager.spec_dec_position_offsets[0].tolist()
+                    )
@@
-                    position_ids.extend(
-                        past_seen_token_num +
-                        spec_tree_manager.spec_dec_position_offsets[
-                            0]  # [max_total_draft_tokens + 1]
-                    )
+                    position_ids.extend(
+                        past_seen_token_num + offset
+                        for offset in spec_position_offsets
+                    )
@@
-                    position_ids.extend(
-                        past_seen_token_num +
-                        spec_tree_manager.spec_dec_position_offsets[
-                            0]  # [max_total_draft_tokens + 1]
-                    )
+                    position_ids.extend(
+                        past_seen_token_num + offset
+                        for offset in spec_position_offsets
+                    )

Based on learnings, in files under tensorrt_llm/_torch/pyexecutor, avoid accessing torch.Tensor objects inside for-loops when iterating over requests. Convert batched tensors to Python lists beforehand using tensor.tolist().

Also applies to: 2445-2458, 2478-2491

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/pyexecutor/model_engine.py` around lines 2399 - 2405,
spec_tree_manager.spec_dec_position_offsets[0] is a device tensor and is being
accessed element-wise inside per-request loops; materialize it once into a
Python list (e.g., host_offsets =
spec_tree_manager.spec_dec_position_offsets[0].tolist()) outside the request
loops and then use host_offsets to extend position_ids in both branches instead
of indexing the tensor inside the loop; apply the same change pattern wherever
spec_dec_position_offsets is iterated (also update the analogous places handling
position_ids in the other request-loop blocks).

Comment on lines +241 to 253
         # Pad to match fixed-size store buffers for index_copy_.
         if o_new_tokens.shape[1] < (self.draft_len + 1):
             o_new_tokens = torch.nn.functional.pad(
-                o_new_tokens, (0, new_tokens_width - o_new_tokens.shape[1])
+                o_new_tokens, (0, (self.draft_len + 1) - o_new_tokens.shape[1])
             )
         elif o_new_tokens.shape[1] > new_tokens_width:
             o_new_tokens = o_new_tokens[:, :new_tokens_width]
-        if o_next_draft_tokens.shape[1] < draft_tokens_width:
+        if o_next_draft_tokens.shape[1] < self.draft_len:
             o_next_draft_tokens = torch.nn.functional.pad(
-                o_next_draft_tokens, (0, draft_tokens_width - o_next_draft_tokens.shape[1])
+                o_next_draft_tokens, (0, self.draft_len - o_next_draft_tokens.shape[1])
             )
         elif o_next_draft_tokens.shape[1] > draft_tokens_width:
             o_next_draft_tokens = o_next_draft_tokens[:, :draft_tokens_width]
-        if o_next_new_tokens.shape[1] < next_new_tokens_width:
+        if o_next_new_tokens.shape[1] < (self.draft_len + 1):
             o_next_new_tokens = torch.nn.functional.pad(
-                o_next_new_tokens, (0, next_new_tokens_width - o_next_new_tokens.shape[1])
+                o_next_new_tokens, (0, (self.draft_len + 1) - o_next_new_tokens.shape[1])
             )

⚠️ Potential issue | 🟠 Major

Pad against the allocated store width, not draft_len.

SpecSamplerBase still allows subclasses to allocate wider buffers via _get_max_tokens() / _get_draft_tokens_storage_size(). MTPSampler does that with args.max_total_draft_tokens + 1, so padding only to self.draft_len leaves o_* narrower than the destination and index_copy_ at Lines 256-259 will fail on MTP batches whose tree width exceeds the runtime draft length.

🐛 Proposed fix
+        target_new_tokens = self.store.new_tokens.shape[0]
+        target_draft_tokens = self.store.next_draft_tokens.shape[1]
+
         # Pad to match fixed-size store buffers for index_copy_.
-        if o_new_tokens.shape[1] < (self.draft_len + 1):
+        if o_new_tokens.shape[1] < target_new_tokens:
             o_new_tokens = torch.nn.functional.pad(
-                o_new_tokens, (0, (self.draft_len + 1) - o_new_tokens.shape[1])
+                o_new_tokens, (0, target_new_tokens - o_new_tokens.shape[1])
             )
-        if o_next_draft_tokens.shape[1] < self.draft_len:
+        if o_next_draft_tokens.shape[1] < target_draft_tokens:
             o_next_draft_tokens = torch.nn.functional.pad(
-                o_next_draft_tokens, (0, self.draft_len - o_next_draft_tokens.shape[1])
+                o_next_draft_tokens, (0, target_draft_tokens - o_next_draft_tokens.shape[1])
             )
-        if o_next_new_tokens.shape[1] < (self.draft_len + 1):
+        if o_next_new_tokens.shape[1] < target_new_tokens:
             o_next_new_tokens = torch.nn.functional.pad(
-                o_next_new_tokens, (0, (self.draft_len + 1) - o_next_new_tokens.shape[1])
+                o_next_new_tokens, (0, target_new_tokens - o_next_new_tokens.shape[1])
             )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/speculative/spec_sampler_base.py` around lines 241 - 253,
The current padding uses self.draft_len but must use the allocated store width
from the sampler storage helpers; compute storage_width =
self._get_draft_tokens_storage_size() (or use _get_max_tokens() if appropriate)
and pad o_new_tokens and o_next_new_tokens up to storage_width, and pad
o_next_draft_tokens up to storage_width - 1, so the tensors match the
destination buffers before the index_copy_ calls (look for o_new_tokens,
o_next_draft_tokens, o_next_new_tokens and index_copy_ in SpecSamplerBase).

Comment on lines 73 to +75
     def __init__(self, max_num_requests: int, use_dynamic_tree: bool,
                  max_total_draft_tokens: int, max_draft_len: int,
-                 eagle_choices: List[List[int]] | None,
-                 dynamic_tree_max_topK: int):
+                 eagle_choices: [List[List[int]]], dynamic_tree_max_topK: int):

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
python - <<'PY'
from pathlib import Path
import ast

path = Path("tensorrt_llm/_torch/speculative/spec_tree_manager.py")
module = ast.parse(path.read_text())

for node in ast.walk(module):
    if isinstance(node, ast.FunctionDef) and node.name == "__init__":
        arg = next(a for a in node.args.args if a.arg == "eagle_choices")
        print(ast.dump(arg.annotation, indent=2))
        break
PY

Repository: NVIDIA/TensorRT-LLM

Length of output: 296


🏁 Script executed:

#!/bin/bash
# Check the full __init__ signature and see if eagle_choices has a default value
cd tensorrt_llm/_torch/speculative
head -80 spec_tree_manager.py | tail -20

Repository: NVIDIA/TensorRT-LLM

Length of output: 1000


🏁 Script executed:

#!/bin/bash
# Check project Python version requirement
fd -t f "pyproject.toml|setup.py|setup.cfg" -x cat {}

Repository: NVIDIA/TensorRT-LLM

Length of output: 50376


🏁 Script executed:

#!/bin/bash
# Check if eagle_choices is used elsewhere to understand if it can be None
rg "eagle_choices" tensorrt_llm/_torch/speculative/spec_tree_manager.py -A 2 -B 2

Repository: NVIDIA/TensorRT-LLM

Length of output: 1944


Fix the eagle_choices annotation.

[List[List[int]]] is a list literal in Python, not a type annotation, so type checkers and inspect-based tooling can't reason about this parameter correctly. Use list[list[int]] | None instead. Python 3.10+ supports this syntax.
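
A reduced sketch of the difference (hypothetical function names, not the actual `SpecTreeManager.__init__`): the list literal evaluates without error at definition time, so nothing fails until a type checker or introspection tooling looks at the annotation.

```python
from typing import List, Optional, get_type_hints

# Hypothetical stand-ins for the parameter under review.
def broken(eagle_choices: [List[List[int]]]):
    # The "annotation" is a one-element list object, not a type.
    return eagle_choices

def fixed(eagle_choices: Optional[List[List[int]]] = None):
    # A real type annotation (list[list[int]] | None on Python 3.10+).
    return eagle_choices

# Introspection sees a plain list on `broken`, a proper Optional on `fixed`.
assert isinstance(broken.__annotations__["eagle_choices"], list)
assert get_type_hints(fixed)["eagle_choices"] == Optional[List[List[int]]]
```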

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/speculative/spec_tree_manager.py` around lines 73 - 75,
The type annotation for the eagle_choices parameter in SpecTreeManager.__init__
is using a list literal ([List[List[int]]]) instead of a proper type; change the
signature to use the modern typing form (e.g., eagle_choices: list[list[int]] |
None) so type checkers and inspect-based tooling can understand it, and update
any callers or default behavior if you make it optional (None) to preserve
existing semantics.

Comment on lines +260 to +266
if self.use_dynamic_tree:
self.eagle_paths[tree_idx].fill_(-1)
# If dynamic tree, return the eagle_paths according to the mask.
for i in range(self.max_total_draft_tokens + 1):
self.eagle_paths[tree_idx][:, i, :] = self.spec_dec_mask_matrix[
tree_idx][i, :].nonzero()
return self.eagle_paths[tree_idx]

⚠️ Potential issue | 🔴 Critical

get_eagle_paths() will raise before it can rebuild dynamic paths.

After self.eagle_paths[tree_idx], the tensor is 2-D, so [:, i, :] on Line 264 is an IndexError. tensorrt_llm/_torch/pyexecutor/sampler.py calls get_eagle_paths(seq_slot) during verification, so any dynamic-tree request will fail here instead of returning paths.

🐛 Proposed fix
         if self.use_dynamic_tree:
             self.eagle_paths[tree_idx].fill_(-1)
             # If dynamic tree, return the eagle_paths according to the mask.
             for i in range(self.max_total_draft_tokens + 1):
-                self.eagle_paths[tree_idx][:, i, :] = self.spec_dec_mask_matrix[
-                    tree_idx][i, :].nonzero()
+                path = self.spec_dec_mask_matrix[tree_idx, i].nonzero(
+                    as_tuple=False
+                ).flatten()
+                self.eagle_paths[tree_idx, i, :path.numel()] = path
             return self.eagle_paths[tree_idx]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/speculative/spec_tree_manager.py` around lines 260 - 266,
get_eagle_paths is indexing self.eagle_paths[tree_idx] as if it's a 3-D tensor
([:, i, :]) but after self.eagle_paths[tree_idx] it's 2-D, causing an IndexError
when use_dynamic_tree is True; fix by constructing or indexing the correct
shape: either ensure self.eagle_paths[tree_idx] is a 3-D tensor of shape (batch,
max_total_draft_tokens+1, path_len) before the loop, or change the assignment to
match the 2-D layout (e.g. assign into self.eagle_paths[tree_idx][:, i] or
reshape the nonzero result to match the third dimension). Update get_eagle_paths
to use self.eagle_paths, self.spec_dec_mask_matrix, use_dynamic_tree, and
max_total_draft_tokens consistently so the slice operations match tensor ranks
(reshape or unsqueeze nonzero results or preallocate a 3-D tensor) to avoid
IndexError.

Comment on lines +457 to +458
if __name__ == "__main__":
unittest.main()

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

cat tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py | head -50

Repository: NVIDIA/TensorRT-LLM



🏁 Script executed:

# Get the test file size first
wc -l tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py



🏁 Script executed:

# Look at the specific test function definition and the end of file
sed -n '1,30p; 450,460p' tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py



🏁 Script executed:

# Search for all test function definitions in the file
rg "^def test_" tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py



🏁 Script executed:

# Check for unittest.TestCase subclasses
rg "class.*\(.*unittest\.TestCase" tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py



unittest.main() won't discover the module-level test function.

This file defines test_draft_token_static_tree_prepare_for_generation() as a module-level function (pytest convention), but unittest.main() only discovers unittest.TestCase subclasses. Running this file directly will find zero tests. Either remove the unittest.main() block if this is a pytest test, or convert the test to a unittest.TestCase subclass.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py`
around lines 457 - 458, the file currently calls unittest.main() but defines
module-level pytest-style test_draft_token_static_tree_prepare_for_generation(),
so running the file finds zero tests; either remove the if __name__ ==
"__main__": unittest.main() block to keep this as a pytest-only test, or convert
the module-level function into a unittest.TestCase (e.g., create class
TestDraftTokenPrepareForGeneration(unittest.TestCase) with a method
test_draft_token_static_tree_prepare_for_generation that calls the same
assertions) and keep unittest.main(); update imports accordingly.
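The second option from the review (keeping unittest.main() and making the test discoverable) can be sketched as follows. The test body is a placeholder assumption, not the real test from the repository; the wrapper class simply delegates to the existing module-level function so both pytest and unittest discovery find it:

```python
import unittest


# Existing pytest-style module-level test (body is a stand-in here).
def test_draft_token_static_tree_prepare_for_generation():
    assert 1 + 1 == 2  # placeholder for the real test logic


# Thin unittest.TestCase wrapper so that unittest.main() discovers the test.
class TestDraftTokenPrepareForGeneration(unittest.TestCase):

    def test_draft_token_static_tree_prepare_for_generation(self):
        test_draft_token_static_tree_prepare_for_generation()
```

With this wrapper in place, the existing `if __name__ == "__main__": unittest.main()` block discovers one test instead of zero when the file is run directly.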

@litaotju
Collaborator

I will skip merge to unblock CI.
The failures are known and not related to this PR

WAN failures #13026
Other failures waived by this: #13035 and #13016

@litaotju litaotju merged commit e31d8d7 into NVIDIA:main Apr 14, 2026
6 of 7 checks passed
@tensorrt-cicd
Collaborator

PR_Github #43166 [ run ] completed with state SUCCESS. Commit: d76cabe
/LLM/main/L0_MergeRequest_PR pipeline #33796 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

8 participants