Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion tensorrt_llm/_torch/peft/lora/adapter_slot_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,13 @@ def get_or_assign_task(self, task_id: int) -> tuple[int, Optional[int]]:
self.task2slot[task_id] = evicted_slot
return self.task2slot[task_id], evicted_task

def remove_evicted_slots_in_cpp(self, peft_cache_manager: PeftCacheManager):
def remove_evicted_slots_in_cpp(self, peft_cache_manager: Optional[PeftCacheManager]):
"""
Validate slots by removing tasks that are not cached in PeftCacheManager.
"""
if peft_cache_manager is None:
return

for task_id in self.slot2task:
if task_id is not None:
if not peft_cache_manager.is_task_cached_device(task_id):
Expand Down
13 changes: 8 additions & 5 deletions tensorrt_llm/_torch/peft/lora/cuda_graph_lora_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,14 +213,17 @@ def update_sorted_indices(self, slot_ids: List[int], tokens_per_seq: int = 1):
self.sorted_ids[:num_tokens].copy_(sorted_ids_host, non_blocking=True)

def update_weight_pointers(
self, peft_table: Dict[int, List], slot_to_task_mapping: tuple[Optional[int], ...]
self,
peft_table: Optional[Dict[int, List]],
slot_to_task_mapping: tuple[Optional[int], ...],
):
"""
Update weight pointers from PEFT cache manager.

Args:
peft_table: PEFT table from cache manager containing weight pointers, map task id to list of layer
module configs
module configs. Can be None when slot membership changes without any newly prepared PEFT
entries in the current batch.
slot_to_task_mapping: Mapping from slot_id to task_id, tuple of None for empty slots
"""

Expand All @@ -241,9 +244,9 @@ def zero_out_weight_pointers(slot_id: int):
if task_id is None: # empty slot
self.slot_ranks_host[slot_id] = 0
zero_out_weight_pointers(slot_id)
elif (
task_id not in peft_table
): # task has not changed in the slot, retain old rank / weight pointers
elif peft_table is None or task_id not in peft_table:
# No new PEFT entry was prepared for this task in the current batch, so retain
# the existing rank and weight pointers for the occupied slot.
continue
else: # task might have changed in the slot, update its rank
task_configs = peft_table[task_id]
Expand Down
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_h100.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ l0_h100:
- unittest/_torch/compilation
- unittest/_torch/debugger
- unittest/_torch/executor
- unittest/_torch/lora
- unittest/_torch/misc
# ------------- modules (non-MoE) ---------------
- unittest/_torch/modules/test_mla_helix.py
Expand Down
38 changes: 38 additions & 0 deletions tests/unittest/_torch/lora/test_lora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import torch

from tensorrt_llm._torch.peft.lora.adapter_slot_manager import AdapterSlotManager
from tensorrt_llm._torch.peft.lora.cuda_graph_lora_params import CudaGraphLoraParams


def test_cuda_graph_lora_params_handle_missing_peft_table():
    """update_weight_pointers(None, ...) must zero out empty slots while
    retaining the rank and weight pointers of slots whose task is still
    occupied (no new PEFT entry prepared in this batch)."""
    key = CudaGraphLoraParams.LoraLayerKey(layer_idx=0, module_ids=(1, 2))
    info = CudaGraphLoraParams.LoraLayerInfo(module_num=2, output_sizes=[16, 32])
    cg_params = CudaGraphLoraParams(max_batch_size=2,
                                    max_lora_size=2,
                                    max_rank=8,
                                    layer_info={key: info})
    per_layer = cg_params.layer_params[key]

    # Seed distinct sentinel pointer values per slot so "retained" vs.
    # "zeroed" outcomes are distinguishable after the call.
    per_layer.h_b_ptrs[:, 0] = torch.tensor([11, 22], dtype=torch.int64)
    per_layer.h_b_prime_ptrs[:, 0] = torch.tensor([33, 44], dtype=torch.int64)
    per_layer.h_b_ptrs[:, 1] = torch.tensor([55, 66], dtype=torch.int64)
    per_layer.h_b_prime_ptrs[:, 1] = torch.tensor([77, 88], dtype=torch.int64)
    cg_params.slot_ranks_host[:] = torch.tensor([4, 7], dtype=torch.int32)

    # Slot 0 holds task 123 (no peft table available), slot 1 is empty.
    cg_params.update_weight_pointers(None, (123, None))

    # Occupied slot 0 keeps its rank and pointers; empty slot 1 is zeroed.
    assert cg_params.slot_ranks_host.tolist() == [4, 0]
    assert per_layer.h_b_ptrs[:, 0].tolist() == [11, 22]
    assert per_layer.h_b_prime_ptrs[:, 0].tolist() == [33, 44]
    assert per_layer.h_b_ptrs[:, 1].tolist() == [0, 0]
    assert per_layer.h_b_prime_ptrs[:, 1].tolist() == [0, 0]


def test_adapter_slot_manager_handles_missing_peft_cache_manager():
    """remove_evicted_slots_in_cpp(None) must be a no-op: no slot is
    evicted and the task/slot mappings are left untouched."""
    slot_mgr = AdapterSlotManager(max_num_adapters=2)

    # Occupy slot 0 with task 123 by hand.
    slot_mgr.slot2task[0] = 123
    slot_mgr.task2slot[123] = 0

    slot_mgr.remove_evicted_slots_in_cpp(None)

    # Mapping survives unchanged: slot 0 -> task 123, slot 1 empty.
    assert slot_mgr.get_slot_to_task_mapping() == (123, None)
    assert slot_mgr.task2slot[123] == 0
70 changes: 68 additions & 2 deletions tests/unittest/_torch/speculative/test_eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm._torch.attention_backend.trtllm import TrtllmAttentionMetadata
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.llmapi import (CudaGraphConfig, Eagle3DecodingConfig,
KvCacheConfig)
from tensorrt_llm.lora_helper import LoraConfig

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

Expand Down Expand Up @@ -757,8 +759,9 @@ def test_eagle3_cuda_graph_padding(disable_overlap_scheduler: bool):
llm_spec = LLM(**llm_common_config, speculative_config=spec_config)

prompts = [
"The capital of France is", "The president of the United States is",
"The future of AI is"
"The capital of France is",
"The president of the United States is",
"The future of AI is",
]

sampling_params = SamplingParams(max_tokens=2048, temperature=0)
Expand Down Expand Up @@ -906,5 +909,68 @@ def test_llama_eagle3_dynamic_tree(use_cuda_graph: bool,
assert text_spec == text_ref


@pytest.mark.parametrize("use_cuda_graph", [True, False])
def test_eagle3_lora(use_cuda_graph: bool):
    """Test LoRA with 3 requests and max_batch_size=4.

    This test verifies that when using LoRA modules together with Eagle3
    speculative decoding, the system properly applies the LoRA
    configurations. With CUDA graph padding enabled, 3 requests against
    max_batch_size=4 also exercises the batch-padding path.
    """
    attn_backend = "TRTLLM"
    enable_block_reuse = False
    use_one_model = True
    enable_chunked_prefill = False

    # Target + draft model need roughly 35 GB of device memory.
    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    if total_mem_gb < 35:
        pytest.skip("Not enough memory to load target + draft model")

    models_path = llm_models_root()

    eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
    target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
    hf_lora_dir = f"{models_path}/llama-models/luotuo-lora-7b-0.1"

    # Test with 3 requests and max_batch_size=4 to trigger padding
    max_batch_size = 4
    max_draft_len = 4
    kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
                                    max_tokens=8192)
    cuda_graph_config = CudaGraphConfig(
        batch_sizes=[1, 2, 4], enable_padding=True) if use_cuda_graph else None
    lora_config = LoraConfig(max_lora_rank=64, max_loras=2, max_cpu_loras=2)

    llm_common_config = dict(
        model=target_model_dir,
        attn_backend=attn_backend,
        cuda_graph_config=cuda_graph_config,
        max_batch_size=max_batch_size,
        kv_cache_config=kv_cache_config,
        max_seq_len=1024,
        enable_chunked_prefill=enable_chunked_prefill,
        lora_config=lora_config,
    )

    spec_config = Eagle3DecodingConfig(
        max_draft_len=max_draft_len,
        speculative_model=eagle_model_dir,
        eagle3_one_model=use_one_model,
    )

    # Create the LLM instance
    llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
    try:
        prompts = [
            "The capital of France is",
            "The president of the United States is",
            "The future of AI is",
        ]
        # All requests share one adapter so both LoRA slots are exercised
        # under the same task id.
        lora_requests = [LoRARequest("luotuo", 1, hf_lora_dir)] * len(prompts)

        sampling_params = SamplingParams(max_tokens=20, temperature=0)
        llm_spec.generate(prompts, sampling_params, lora_request=lora_requests)
    finally:
        # Release engine/GPU resources even if generation raises, so a
        # failing run does not leak memory into subsequent tests.
        llm_spec.shutdown()


if __name__ == "__main__":
    # NOTE(review): unittest.main() only discovers unittest.TestCase
    # subclasses; the pytest-style test functions visible in this chunk
    # are not collected by it. Confirm whether this file defines TestCase
    # classes elsewhere, otherwise run the file via pytest.
    unittest.main()
Loading