Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions tensorrt_llm/_torch/peft/lora/cuda_graph_lora_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,14 @@ def get_layer_idx(
return None

# Ignore LoRA layers without at least one of the target modules.
# Skip LoRA layers that belong to draft model subtrees (e.g., PARD
# embeds a full HF model as a submodule whose layers share the same
# layer_idx values as the target model, causing key collisions).
for name, module in model.named_modules():
if isinstance(module, LoraLayer):
if name.startswith("draft_model."):
logger.debug(f"Skipping draft model LoRA module {name}")
continue
layer_idx = get_layer_idx(model, module, name)
# if target_modules_ids is None, by default enable all modules
if self.target_modules_ids and not any(
Expand Down
8 changes: 5 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,9 +538,11 @@ def _init_cuda_graph_lora_manager(self, lora_config: LoraConfig):
max_lora_size = lora_config.max_loras or 8 # Default fallback
max_batch_size = self.batch_size # Use engine's max batch size

# For spec decode, each generation request contributes
# max_draft_len + 1 tokens per forward pass.
max_tokens_per_seq = (self.original_max_draft_len +
# For spec decode, each generation request can contribute up to
# tokens_per_gen_step tokens per forward pass. This is larger than
# max_draft_len + 1 for modes like PARD, which use extra mask
# tokens in the same generation step.
max_tokens_per_seq = (self.original_max_total_draft_tokens +
1) if self.is_spec_decode else 1
self.cuda_graph_lora_manager = CudaGraphLoraManager(
max_lora_size=max_lora_size,
Expand Down
62 changes: 62 additions & 0 deletions tests/unittest/_torch/speculative/test_pard.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from utils.llm_data import llm_models_root

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, PARDDecodingConfig
from tensorrt_llm.lora_helper import LoraConfig

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

Expand Down Expand Up @@ -76,5 +78,65 @@ def test_pard(disable_overlap_scheduler: bool):
llm_spec.shutdown()


@pytest.mark.parametrize("use_cuda_graph", [True, False])
def test_pard_lora(use_cuda_graph: bool):
    """Test PARD speculative decoding with LoRA support.

    This test verifies that PARD (Parallel Draft) speculative decoding works
    correctly with LoRA modules, both with and without CUDA graphs.
    """
    attn_backend = "TRTLLM"
    enable_block_reuse = False
    enable_chunked_prefill = False

    # Target (8B) + draft model need a large device; skip on small GPUs.
    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    if total_mem_gb < 35:
        pytest.skip("Not enough memory to load target + draft model")

    models_path = llm_models_root()
    pard_model_dir = f"{models_path}/PARD-Llama-3.2-1B"
    target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
    hf_lora_dir = f"{models_path}/llama-models/luotuo-lora-7b-0.1"

    # Test with 3 requests and max_batch_size=4 to trigger padding
    max_batch_size = 4
    max_draft_len = 4
    kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse, max_tokens=2048)
    cuda_graph_config = (
        CudaGraphConfig(batch_sizes=[1, 2, 4], enable_padding=True) if use_cuda_graph else None
    )
    lora_config = LoraConfig(max_lora_rank=64, max_loras=2, max_cpu_loras=2)

    llm_common_config = dict(
        model=target_model_dir,
        attn_backend=attn_backend,
        cuda_graph_config=cuda_graph_config,
        max_batch_size=max_batch_size,
        kv_cache_config=kv_cache_config,
        max_seq_len=2048,
        enable_chunked_prefill=enable_chunked_prefill,
        lora_config=lora_config,
    )

    spec_config = PARDDecodingConfig(
        max_draft_len=max_draft_len,
        speculative_model=pard_model_dir,
    )

    # Create the LLM instance
    llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
    try:
        prompts = [
            "The capital of France is",
            "The president of the United States is",
            "The future of AI is",
        ]
        lora_requests = [LoRARequest("luotuo", 1, hf_lora_dir)] * len(prompts)

        sampling_params = SamplingParams(max_tokens=1024, temperature=0, add_special_tokens=False)
        outputs = llm_spec.generate(prompts, sampling_params, lora_request=lora_requests)
        # Smoke-check: every prompt must produce a result; a crash or hang in
        # the PARD + LoRA path is the failure mode this test guards against.
        assert len(outputs) == len(prompts)
    finally:
        # Always release engine resources, even when generation raises.
        llm_spec.shutdown()


if __name__ == "__main__":
    # The tests in this file are plain pytest functions (parametrized with
    # @pytest.mark.parametrize), not unittest.TestCase subclasses, so
    # unittest.main() would discover and run nothing. Invoke pytest on this
    # file instead; `pytest` is already imported for the decorators above.
    pytest.main([__file__])
Loading