Skip to content

Commit 0a62078

Browse files
committed
fixup! [test] Enable LoRA in PARD speculative decoding
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
1 parent 80dc58f commit 0a62078

2 files changed

Lines changed: 11 additions & 3 deletions

File tree

tensorrt_llm/_torch/peft/lora/cuda_graph_lora_manager.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,14 @@ def get_layer_idx(
106106
return None
107107

108108
# Ignore LoRA layers without at least one of the target modules.
109+
# Skip LoRA layers that belong to draft model subtrees (e.g., PARD
110+
# embeds a full HF model as a submodule whose layers share the same
111+
# layer_idx values as the target model, causing key collisions).
109112
for name, module in model.named_modules():
110113
if isinstance(module, LoraLayer):
114+
if name.startswith("draft_model."):
115+
logger.debug(f"Skipping draft model LoRA module {name}")
116+
continue
111117
layer_idx = get_layer_idx(model, module, name)
112118
# if target_modules_ids is None, by default enable all modules
113119
if self.target_modules_ids and not any(

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -537,9 +537,11 @@ def _init_cuda_graph_lora_manager(self, lora_config: LoraConfig):
537537
max_lora_size = lora_config.max_loras or 8 # Default fallback
538538
max_batch_size = self.batch_size # Use engine's max batch size
539539

540-
# For spec decode, each generation request contributes
541-
# max_draft_len + 1 tokens per forward pass.
542-
max_tokens_per_seq = (self.original_max_draft_len +
540+
# For spec decode, each generation request can contribute up to
541+
# tokens_per_gen_step tokens per forward pass. This is larger than
542+
# max_draft_len + 1 for modes like PARD, which use extra mask
543+
# tokens in the same generation step.
544+
max_tokens_per_seq = (self.original_max_total_draft_tokens +
543545
1) if self.is_spec_decode else 1
544546
self.cuda_graph_lora_manager = CudaGraphLoraManager(
545547
max_lora_size=max_lora_size,

0 commit comments

Comments
 (0)