Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion tensorrt_llm/_torch/models/modeling_multimodal_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,13 @@ def get_multimodal_embeddings(

# TODO: support multiple multimodal modalities per request
if len(encoder_embeddings) > 1:
logger.warning("Multiple modalities caching is not supported yet.")
logger.warning(
f"Multiple modalities caching is not supported yet. "
f"encoder returned {len(encoder_embeddings)} embeddings "
f"(types: {[type(e).__name__ for e in encoder_embeddings]}, "
f"shapes: {[e.shape if hasattr(e, 'shape') else 'N/A' for e in encoder_embeddings]}) "
f"for {len(uncached_multimodal_params)} uncached params. "
f"encoder_forward_fn={encoder_forward_fn}")
return encoder_embeddings

# Validate that multimodal_runtime has required attributes for caching
Expand Down
40 changes: 30 additions & 10 deletions tensorrt_llm/_torch/models/modeling_nemotron_nano.py
Original file line number Diff line number Diff line change
Expand Up @@ -1788,24 +1788,37 @@ def _encode_audio(self, param: MultimodalParams) -> torch.Tensor:
result = torch.cat(truncated, dim=0) # [total_tokens, llm_hidden_size]
return result

def _encode_multimodal(
self, multimodal_params: List[MultimodalParams]
) -> Tuple[List[torch.Tensor], List[Optional[List[int]]]]:
"""Dispatch multimodal encoding to the appropriate encoder."""
def _encode_multimodal(self, multimodal_params: List[MultimodalParams]) -> List[torch.Tensor]:
"""Dispatch multimodal encoding to the appropriate encoder.

Returns a single-element `List[torch.Tensor]` (all per-request
embeddings concatenated) to conform to the contract expected by
`get_multimodal_embeddings`, which enables chunked-prefill
caching. Per-request `num_tokens_in_video` (needed by EVS) is
stashed in each param's `multimodal_data` dict as a
side-channel.
"""
mm_embeddings = []
mm_num_tokens = []
for param in multimodal_params:
modality_type = param.multimodal_data["modality_type"]
if modality_type in ("image", "video"):
embs, num_tokens = self.vision_encoder([param])
mm_embeddings.append(embs[0])
mm_num_tokens.append(num_tokens[0] if num_tokens is not None else None)

# Stash per-request token counts for later EVS adjustment.
if num_tokens is not None:
param.multimodal_data["num_tokens_in_video"] = num_tokens[0]
elif modality_type == "audio":
mm_embeddings.append(self._encode_audio(param))
mm_num_tokens.append(None)
else:
raise ValueError(f"Unknown modality: {modality_type}")
return mm_embeddings, mm_num_tokens

# Concatenate per-request embeddings into a single tensor.
# `get_multimodal_embeddings` expects a single-element list containing one tensor (all
# items' embeddings concatenated).
if mm_embeddings:
return [torch.cat(mm_embeddings, dim=0)]
return []

@torch.inference_mode()
def forward(
Expand All @@ -1832,7 +1845,7 @@ def forward(
mm_embedding = []
if len(multimodal_params) > 0:
if not _is_disagg():
mm_embedding, num_tokens_in_videos = get_multimodal_embeddings(
mm_embedding = get_multimodal_embeddings(
encoder_forward_fn=self._encode_multimodal,
multimodal_params=multimodal_params[:num_context_requests],
)
Expand All @@ -1843,9 +1856,16 @@ def forward(
)
# Adjust input_ids in videos if EVS is applied.
if self.video_pruning_rate > 0:
# Retrieve per-video count stashed by `_encode_multimodal`.
ctx_params = multimodal_params[:num_context_requests]
num_tokens_in_videos = [
param.multimodal_data.get("num_tokens_in_video")
for param in ctx_params
if param.has_content()
]
input_ids = self.merge_evs_mm_embeds(
num_tokens_in_videos,
multimodal_params=multimodal_params[:num_context_requests],
multimodal_params=ctx_params,
input_ids=input_ids,
)

Expand Down
249 changes: 243 additions & 6 deletions tests/unittest/_torch/modeling/test_modeling_nemotron_nano_v2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from test_modeling_nemotron_h import extract_decode_logprobs

from tensorrt_llm import LLM
from tensorrt_llm._torch.models.modeling_multimodal_utils import get_multimodal_embeddings
from tensorrt_llm._torch.models.modeling_nemotron_nano import (
NanoV2VLVisionEncoder,
NemotronH_Nano_VL_V2,
Expand All @@ -20,6 +21,7 @@
default_multimodal_input_loader,
prompt_inputs,
)
from tensorrt_llm.inputs.multimodal import MultimodalParams, MultimodalRuntimeData
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.llmapi.llm_args import CudaGraphConfig
from tensorrt_llm.sampling_params import SamplingParams
Expand Down Expand Up @@ -220,6 +222,15 @@ def _make_mock_model(self):
model.sound_encoder = mock.MagicMock(spec=ProjectedParakeet)
return model

@staticmethod
def _assert_compatible_with_chunked_prefill(multimodal_embeddings):
    """Assert that `multimodal_embeddings` is a single-element list holding one tensor.

    This is the shape of output that `get_multimodal_embeddings` can split and
    cache per request during chunked prefill.
    """
    # NOTE: `multimodal_embeddings` is expected to be the output of `_encode_multimodal`.
    # The below checks help verify that we can make use of `get_multimodal_embeddings` and its
    # caching feature. Otherwise, we would be re-encoding the items each chunk during chunked
    # prefill.
    assert len(multimodal_embeddings) == 1
    assert isinstance(multimodal_embeddings[0], torch.Tensor)

def test_encode_multimodal_dispatches_audio(self):
model = self._make_mock_model()
fake_audio_embeds = torch.randn(10, 128)
Expand All @@ -229,16 +240,16 @@ def test_encode_multimodal_dispatches_audio(self):
mm_param.multimodal_data = {"modality_type": "audio", "audio": {}}

# Call the real method on our mock
result_embeds, result_nones = NemotronH_Nano_VL_V2._encode_multimodal(model, [mm_param])
result = NemotronH_Nano_VL_V2._encode_multimodal(model, [mm_param])

model._encode_audio.assert_called_once_with(mm_param)
model.vision_encoder.assert_not_called()
assert len(result_embeds) == 1
assert torch.equal(result_embeds[0], fake_audio_embeds)
self._assert_compatible_with_chunked_prefill(result)
assert torch.equal(result[0], fake_audio_embeds)

def test_encode_multimodal_dispatches_image(self):
    """Image params are routed to the vision encoder and satisfy the caching contract."""
    model = self._make_mock_model()
    fake_image_embeds = torch.randn(10, 128)
    model.vision_encoder.return_value = ([fake_image_embeds], [None])

    mm_param = mock.MagicMock()
    mm_param.multimodal_data = {
        "modality_type": "image",
        "image": {"pixel_values": torch.randn(1, 100, 768)},
    }

    result = NemotronH_Nano_VL_V2._encode_multimodal(model, [mm_param])

    model.vision_encoder.assert_called_once_with([mm_param])
    self._assert_compatible_with_chunked_prefill(result)

def test_encode_multimodal_unknown_modality_raises(self):
"""Unknown modality raises ValueError."""
Expand All @@ -260,3 +271,229 @@ def test_encode_multimodal_unknown_modality_raises(self):

with pytest.raises(ValueError, match="Unknown modality"):
NemotronH_Nano_VL_V2._encode_multimodal(model, [mm_param])


class TestEncodeMultimodalContract:
    """Verify `_encode_multimodal` honors the return-type contract of `get_multimodal_embeddings`.

    That function assumes the `encoder_forward_fn` it is given returns a length-1 sequence whose
    `[0]` element is a single `torch.Tensor`.
    """

    HIDDEN = 128

    def _make_mock_model(self):
        # Mock out the heavyweight sub-encoders; only dispatch/concat logic is under test.
        mocked = mock.MagicMock(spec=NemotronH_Nano_VL_V2)
        mocked.vision_encoder = mock.MagicMock(spec=NanoV2VLVisionEncoder)
        mocked.sound_encoder = mock.MagicMock(spec=ProjectedParakeet)
        return mocked

    def _make_mm_param(self, modality_type, **extra):
        # Minimal stand-in for a MultimodalParams carrying only `multimodal_data`.
        mm_param = mock.MagicMock()
        mm_param.multimodal_data = dict(modality_type=modality_type, **extra)
        return mm_param

    def test_returns_list_with_a_single_element(self):
        mocked = self._make_mock_model()
        mocked.vision_encoder.return_value = ([torch.randn(5, self.HIDDEN)], [None])

        out = NemotronH_Nano_VL_V2._encode_multimodal(mocked, [self._make_mm_param("image")])

        assert isinstance(out, list)
        assert len(out) == 1

    def test_single_concatenated_tensor_for_multiple_multimodal_items(self):
        """Multiple multimodal items must be concatenated into a single tensor.

        `get_multimodal_embeddings` requires `len(embeddings) == 1` and splits by per-request token
        counts in order to cache the embeddings.
        """
        mocked = self._make_mock_model()
        first_emb = torch.randn(5, self.HIDDEN)
        second_emb = torch.randn(3, self.HIDDEN)
        mocked.vision_encoder.side_effect = [([first_emb], [None]), ([second_emb], [None])]

        out = NemotronH_Nano_VL_V2._encode_multimodal(
            mocked, [self._make_mm_param("image"), self._make_mm_param("image")]
        )

        assert len(out) == 1
        assert out[0].shape == (8, self.HIDDEN)
        # Concatenation must preserve request order.
        assert torch.equal(out[0][:5], first_emb)
        assert torch.equal(out[0][5:], second_emb)

    def test_mixed_modalities_still_single_tensor(self):
        """Image + audio requests produce one concatenated tensor."""
        mocked = self._make_mock_model()
        vision_emb = torch.randn(5, self.HIDDEN)
        sound_emb = torch.randn(3, self.HIDDEN)
        mocked.vision_encoder.return_value = ([vision_emb], [None])
        mocked._encode_audio = mock.MagicMock(return_value=sound_emb)

        mm_params = [
            self._make_mm_param("image"),
            self._make_mm_param("audio", audio={}),
        ]
        out = NemotronH_Nano_VL_V2._encode_multimodal(mocked, mm_params)

        assert len(out) == 1
        assert out[0].shape == (8, self.HIDDEN)

    def test_empty_params_returns_empty_list(self):
        assert NemotronH_Nano_VL_V2._encode_multimodal(self._make_mock_model(), []) == []


class TestChunkedPrefillCaching:
    """Verify that `_encode_multimodal` output is compatible with `get_multimodal_embeddings`.

    Specifically, we want to test that the caching functionality is exercised and not skipped due
    to the return type not being compatible.

    The test structure for each modality:

    1. Build `MultimodalParams` with a real `MultimodalRuntimeData` (past_seen_token_num=0,
       simulating the first chunk).
    2. Call `get_multimodal_embeddings` with the real _encode_multimodal wired to mock sub-encoders
       -> encoder MUST be invoked.
    3. Verify embeddings were cached in `multimodal_data`.
    4. Call `get_multimodal_embeddings` again with the SAME params (simulating a second chunk) ->
       encoder must NOT be invoked.
    """

    HIDDEN = 128
    NUM_TOKENS = 10

    def _make_mock_model(self):
        # Only dispatch/caching behavior is under test; encoders are mocked.
        mocked = mock.MagicMock(spec=NemotronH_Nano_VL_V2)
        mocked.vision_encoder = mock.MagicMock(spec=NanoV2VLVisionEncoder)
        mocked.sound_encoder = mock.MagicMock(spec=ProjectedParakeet)
        return mocked

    def _make_param_with_runtime(self, modality_type, num_tokens, **extra):
        """Build a real MultimodalParams with runtime data for caching."""
        # past_seen_token_num=0 models the very first prefill chunk.
        runtime = MultimodalRuntimeData(
            past_seen_token_num=0,
            mm_token_lengths=[num_tokens],
            mm_token_positions=[0],
            chunk_end_pos=num_tokens,
            special_token_offsets=[],
        )
        data = {"modality_type": modality_type, **extra}
        return MultimodalParams(multimodal_data=data, multimodal_runtime=runtime)

    def _make_encoder_fn(self, model):
        """Wrap `_encode_multimodal` as a callable for get_multimodal_embeddings."""
        return lambda params: NemotronH_Nano_VL_V2._encode_multimodal(model, params)

    @pytest.mark.parametrize("modality", ["image", "video"])
    def test_vision_encoder_not_called_on_second_chunk(self, modality):
        mocked = self._make_mock_model()
        mocked.vision_encoder.return_value = (
            [torch.randn(self.NUM_TOKENS, self.HIDDEN)],
            [None],
        )

        param = self._make_param_with_runtime(modality, self.NUM_TOKENS)
        encoder_fn = self._make_encoder_fn(mocked)

        # First call: encoder must run and cache the result.
        first_pass = get_multimodal_embeddings(
            encoder_forward_fn=encoder_fn,
            multimodal_params=[param],
        )
        assert len(first_pass) == 1
        assert first_pass[0].shape == (self.NUM_TOKENS, self.HIDDEN)
        assert mocked.vision_encoder.call_count == 1

        # Embedding is now cached in multimodal_data.
        assert "multimodal_embedding" in param.multimodal_data

        # Second call: encoder must NOT run - embeddings come from cache.
        second_pass = get_multimodal_embeddings(
            encoder_forward_fn=encoder_fn,
            multimodal_params=[param],
        )
        assert mocked.vision_encoder.call_count == 1, (
            "`vision_encoder` was called again on the second chunk. "
            "Caching is broken - `_encode_multimodal` likely violates the "
            "`get_multimodal_embeddings` return type contract."
        )
        assert len(second_pass) == 1
        assert torch.equal(second_pass[0], first_pass[0])

    def test_audio_encoder_not_called_on_second_chunk(self):
        mocked = self._make_mock_model()
        mocked._encode_audio = mock.MagicMock(
            return_value=torch.randn(self.NUM_TOKENS, self.HIDDEN)
        )

        param = self._make_param_with_runtime("audio", self.NUM_TOKENS, audio={})
        encoder_fn = self._make_encoder_fn(mocked)

        # First call.
        first_pass = get_multimodal_embeddings(
            encoder_forward_fn=encoder_fn,
            multimodal_params=[param],
        )
        assert len(first_pass) == 1
        assert mocked._encode_audio.call_count == 1

        # Second call - should use cache.
        second_pass = get_multimodal_embeddings(
            encoder_forward_fn=encoder_fn,
            multimodal_params=[param],
        )
        assert mocked._encode_audio.call_count == 1, (
            "`_encode_audio` was called again on the second chunk. "
            "Caching is broken - `_encode_multimodal` likely violates the "
            "`get_multimodal_embeddings` return type contract."
        )
        assert torch.equal(second_pass[0], first_pass[0])

    def test_multi_request_batch_caching(self):
        """Two image requests in one batch: both cached after one call."""
        mocked = self._make_mock_model()
        first_emb = torch.randn(5, self.HIDDEN)
        second_emb = torch.randn(3, self.HIDDEN)
        mocked.vision_encoder.side_effect = [
            ([first_emb], [None]),
            ([second_emb], [None]),
        ]

        params = [
            self._make_param_with_runtime("image", 5),
            self._make_param_with_runtime("image", 3),
        ]
        encoder_fn = self._make_encoder_fn(mocked)

        # First call: encoder runs for both.
        first_pass = get_multimodal_embeddings(
            encoder_forward_fn=encoder_fn,
            multimodal_params=params,
        )
        assert len(first_pass) == 1
        assert first_pass[0].shape == (8, self.HIDDEN)
        assert mocked.vision_encoder.call_count == 2  # once per param

        # Both should be cached.
        assert all("multimodal_embedding" in p.multimodal_data for p in params)

        # Second call: encoder must not run again.
        second_pass = get_multimodal_embeddings(
            encoder_forward_fn=encoder_fn,
            multimodal_params=params,
        )
        assert mocked.vision_encoder.call_count == 2, (
            "`vision_encoder` was called again on the second chunk. "
            "Caching is broken - `_encode_multimodal` likely violates the "
            "`get_multimodal_embeddings` return type contract."
        )
        assert torch.equal(second_pass[0], first_pass[0])
Loading