Skip to content

Commit e7868dd

Browse files
committed
[None][fix] Fix chunked prefill API contract for nemotron nano VL
* Why? In order to opt into the caching functionality for chunked prefill, there are certain assumptions on the return type of the encoder's forward function. These assumptions did not hold for Nemotron Nano VL prior to this commit. * What? This commit fixes the issue and adds tests to catch regressions. Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
1 parent 89eae92 commit e7868dd

3 files changed

Lines changed: 280 additions & 17 deletions

File tree

tensorrt_llm/_torch/models/modeling_multimodal_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,13 @@ def get_multimodal_embeddings(
151151

152152
# TODO: support multiple multimodal modalities per request
153153
if len(encoder_embeddings) > 1:
154-
logger.warning("Multiple modalities caching is not supported yet.")
154+
logger.warning(
155+
f"Multiple modalities caching is not supported yet. "
156+
f"encoder returned {len(encoder_embeddings)} embeddings "
157+
f"(types: {[type(e).__name__ for e in encoder_embeddings]}, "
158+
f"shapes: {[e.shape if hasattr(e, 'shape') else 'N/A' for e in encoder_embeddings]}) "
159+
f"for {len(uncached_multimodal_params)} uncached params. "
160+
f"encoder_forward_fn={encoder_forward_fn}")
155161
return encoder_embeddings
156162

157163
# Validate that multimodal_runtime has required attributes for caching

tensorrt_llm/_torch/models/modeling_nemotron_nano.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1788,24 +1788,37 @@ def _encode_audio(self, param: MultimodalParams) -> torch.Tensor:
17881788
result = torch.cat(truncated, dim=0) # [total_tokens, llm_hidden_size]
17891789
return result
17901790

1791-
def _encode_multimodal(
1792-
self, multimodal_params: List[MultimodalParams]
1793-
) -> Tuple[List[torch.Tensor], List[Optional[List[int]]]]:
1794-
"""Dispatch multimodal encoding to the appropriate encoder."""
1791+
def _encode_multimodal(self, multimodal_params: List[MultimodalParams]) -> List[torch.Tensor]:
1792+
"""Dispatch multimodal encoding to the appropriate encoder.
1793+
1794+
Returns a single-element `List[torch.Tensor]` (all per-request
1795+
embeddings concatenated) to conform to the contract expected by
1796+
`get_multimodal_embeddings`, which enables chunked-prefill
1797+
caching. Per-request `num_tokens_in_video` (needed by EVS) is
1798+
stashed in each param's `multimodal_data` dict as a
1799+
side-channel.
1800+
"""
17951801
mm_embeddings = []
1796-
mm_num_tokens = []
17971802
for param in multimodal_params:
17981803
modality_type = param.multimodal_data["modality_type"]
17991804
if modality_type in ("image", "video"):
18001805
embs, num_tokens = self.vision_encoder([param])
18011806
mm_embeddings.append(embs[0])
1802-
mm_num_tokens.append(num_tokens[0] if num_tokens is not None else None)
1807+
1808+
# Stash per-request token counts for later EVS adjustment.
1809+
if num_tokens is not None:
1810+
param.multimodal_data["num_tokens_in_video"] = num_tokens[0]
18031811
elif modality_type == "audio":
18041812
mm_embeddings.append(self._encode_audio(param))
1805-
mm_num_tokens.append(None)
18061813
else:
18071814
raise ValueError(f"Unknown modality: {modality_type}")
1808-
return mm_embeddings, mm_num_tokens
1815+
1816+
# Concatenate per-request embeddings into a single tensor.
1817+
# `get_multimodal_embeddings` expects a single-element list containing one tensor (all
1818+
# items' embeddings concatenated).
1819+
if mm_embeddings:
1820+
return [torch.cat(mm_embeddings, dim=0)]
1821+
return []
18091822

18101823
@torch.inference_mode()
18111824
def forward(
@@ -1832,7 +1845,7 @@ def forward(
18321845
mm_embedding = []
18331846
if len(multimodal_params) > 0:
18341847
if not _is_disagg():
1835-
mm_embedding, num_tokens_in_videos = get_multimodal_embeddings(
1848+
mm_embedding = get_multimodal_embeddings(
18361849
encoder_forward_fn=self._encode_multimodal,
18371850
multimodal_params=multimodal_params[:num_context_requests],
18381851
)
@@ -1843,9 +1856,16 @@ def forward(
18431856
)
18441857
# Adjust input_ids in videos if EVS is applied.
18451858
if self.video_pruning_rate > 0:
1859+
# Retrieve per-video count stashed by `_encode_multimodal`.
1860+
ctx_params = multimodal_params[:num_context_requests]
1861+
num_tokens_in_videos = [
1862+
param.multimodal_data.get("num_tokens_in_video")
1863+
for param in ctx_params
1864+
if param.has_content()
1865+
]
18461866
input_ids = self.merge_evs_mm_embeds(
18471867
num_tokens_in_videos,
1848-
multimodal_params=multimodal_params[:num_context_requests],
1868+
multimodal_params=ctx_params,
18491869
input_ids=input_ids,
18501870
)
18511871

tests/unittest/_torch/modeling/test_modeling_nemotron_nano_v2_vl.py

Lines changed: 243 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from test_modeling_nemotron_h import extract_decode_logprobs
1010

1111
from tensorrt_llm import LLM
12+
from tensorrt_llm._torch.models.modeling_multimodal_utils import get_multimodal_embeddings
1213
from tensorrt_llm._torch.models.modeling_nemotron_nano import (
1314
NanoV2VLVisionEncoder,
1415
NemotronH_Nano_VL_V2,
@@ -20,6 +21,7 @@
2021
default_multimodal_input_loader,
2122
prompt_inputs,
2223
)
24+
from tensorrt_llm.inputs.multimodal import MultimodalParams, MultimodalRuntimeData
2325
from tensorrt_llm.llmapi import KvCacheConfig
2426
from tensorrt_llm.llmapi.llm_args import CudaGraphConfig
2527
from tensorrt_llm.sampling_params import SamplingParams
@@ -220,6 +222,15 @@ def _make_mock_model(self):
220222
model.sound_encoder = mock.MagicMock(spec=ProjectedParakeet)
221223
return model
222224

225+
@staticmethod
226+
def _assert_compatible_with_chunked_prefill(multimodal_embeddings):
227+
# NOTE: `multimodal_embeddings` is expected to be the output of `_encode_multimodal`.
228+
# The below checks help verify that we can make use of `get_multimodal_embeddings` and its
229+
# caching feature. Otherwise, we would be re-encoding the items each chunk during chunked
230+
# prefill.
231+
assert len(multimodal_embeddings) == 1
232+
assert isinstance(multimodal_embeddings[0], torch.Tensor)
233+
223234
def test_encode_multimodal_dispatches_audio(self):
224235
model = self._make_mock_model()
225236
fake_audio_embeds = torch.randn(10, 128)
@@ -229,16 +240,16 @@ def test_encode_multimodal_dispatches_audio(self):
229240
mm_param.multimodal_data = {"modality_type": "audio", "audio": {}}
230241

231242
# Call the real method on our mock
232-
result_embeds, result_nones = NemotronH_Nano_VL_V2._encode_multimodal(model, [mm_param])
243+
result = NemotronH_Nano_VL_V2._encode_multimodal(model, [mm_param])
233244

234245
model._encode_audio.assert_called_once_with(mm_param)
235246
model.vision_encoder.assert_not_called()
236-
assert len(result_embeds) == 1
237-
assert torch.equal(result_embeds[0], fake_audio_embeds)
247+
self._assert_compatible_with_chunked_prefill(result)
248+
assert torch.equal(result[0], fake_audio_embeds)
238249

239250
def test_encode_multimodal_dispatches_image(self):
240251
model = self._make_mock_model()
241-
fake_image_embeds = torch.randn(1, 10, 128)
252+
fake_image_embeds = torch.randn(10, 128)
242253
model.vision_encoder.return_value = ([fake_image_embeds], [None])
243254

244255
mm_param = mock.MagicMock()
@@ -247,10 +258,10 @@ def test_encode_multimodal_dispatches_image(self):
247258
"image": {"pixel_values": torch.randn(1, 100, 768)},
248259
}
249260

250-
result_embeds, result_nones = NemotronH_Nano_VL_V2._encode_multimodal(model, [mm_param])
261+
result = NemotronH_Nano_VL_V2._encode_multimodal(model, [mm_param])
251262

252263
model.vision_encoder.assert_called_once_with([mm_param])
253-
assert len(result_embeds) == 1
264+
self._assert_compatible_with_chunked_prefill(result)
254265

255266
def test_encode_multimodal_unknown_modality_raises(self):
256267
"""Unknown modality raises ValueError."""
@@ -260,3 +271,229 @@ def test_encode_multimodal_unknown_modality_raises(self):
260271

261272
with pytest.raises(ValueError, match="Unknown modality"):
262273
NemotronH_Nano_VL_V2._encode_multimodal(model, [mm_param])
274+
275+
276+
class TestEncodeMultimodalContract:
277+
"""Verify `_encode_multimodal` conforms to the contract expected by `get_multimodal_embeddings`.
278+
279+
The key assumption is that the `encoder_forward_fn` passed to it returns something whose length
280+
is 1, and can be indexed by `[0]` to return a single `torch.Tensor`.
281+
"""
282+
283+
HIDDEN = 128
284+
285+
def _make_mock_model(self):
286+
model = mock.MagicMock(spec=NemotronH_Nano_VL_V2)
287+
model.vision_encoder = mock.MagicMock(spec=NanoV2VLVisionEncoder)
288+
model.sound_encoder = mock.MagicMock(spec=ProjectedParakeet)
289+
return model
290+
291+
def _make_mm_param(self, modality_type, **extra):
292+
param = mock.MagicMock()
293+
param.multimodal_data = {"modality_type": modality_type, **extra}
294+
return param
295+
296+
def test_returns_list_with_a_single_element(self):
297+
model = self._make_mock_model()
298+
model.vision_encoder.return_value = ([torch.randn(5, self.HIDDEN)], [None])
299+
300+
param = self._make_mm_param("image")
301+
result = NemotronH_Nano_VL_V2._encode_multimodal(model, [param])
302+
303+
assert isinstance(result, list)
304+
assert len(result) == 1
305+
306+
def test_single_concatenated_tensor_for_multiple_multimodal_items(self):
307+
"""Multiple multimodal items must be concatenated into a single tensor.
308+
309+
`get_multimodal_embeddings` requires `len(embeddings) == 1` and splits by per-request token
310+
counts in order to cache the embeddings.
311+
"""
312+
model = self._make_mock_model()
313+
emb_a = torch.randn(5, self.HIDDEN)
314+
emb_b = torch.randn(3, self.HIDDEN)
315+
model.vision_encoder.side_effect = [
316+
([emb_a], [None]),
317+
([emb_b], [None]),
318+
]
319+
320+
params = [self._make_mm_param("image"), self._make_mm_param("image")]
321+
result = NemotronH_Nano_VL_V2._encode_multimodal(model, params)
322+
323+
assert len(result) == 1
324+
assert result[0].shape == (8, self.HIDDEN)
325+
# Verify concatenation order is preserved.
326+
assert torch.equal(result[0][:5], emb_a)
327+
assert torch.equal(result[0][5:], emb_b)
328+
329+
def test_mixed_modalities_still_single_tensor(self):
330+
"""Image + audio requests produce one concatenated tensor."""
331+
model = self._make_mock_model()
332+
img_emb = torch.randn(5, self.HIDDEN)
333+
audio_emb = torch.randn(3, self.HIDDEN)
334+
model.vision_encoder.return_value = ([img_emb], [None])
335+
model._encode_audio = mock.MagicMock(return_value=audio_emb)
336+
337+
params = [
338+
self._make_mm_param("image"),
339+
self._make_mm_param("audio", audio={}),
340+
]
341+
result = NemotronH_Nano_VL_V2._encode_multimodal(model, params)
342+
343+
assert len(result) == 1
344+
assert result[0].shape == (8, self.HIDDEN)
345+
346+
def test_empty_params_returns_empty_list(self):
347+
model = self._make_mock_model()
348+
result = NemotronH_Nano_VL_V2._encode_multimodal(model, [])
349+
assert result == []
350+
351+
352+
class TestChunkedPrefillCaching:
353+
"""Verify that `_encode_multimodal` output is compatible with `get_multimodal_embeddings`.
354+
355+
Specifically, we want to test that the caching functionality is exercised and not skipped due
356+
to the return type not being compatible.
357+
358+
The test structure for each modality:
359+
360+
1. Build `MultimodalParams` with a real `MultimodalRuntimeData` (past_seen_token_num=0,
361+
simulating the first chunk).
362+
2. Call `get_multimodal_embeddings` with the real _encode_multimodal wired to mock sub-encoders
363+
-> encoder MUST be invoked.
364+
3. Verify embeddings were cached in `multimodal_data`.
365+
4. Call `get_multimodal_embeddings` again with the SAME params (simulating a second chunk) ->
366+
encoder must NOT be invoked.
367+
"""
368+
369+
HIDDEN = 128
370+
NUM_TOKENS = 10
371+
372+
def _make_mock_model(self):
373+
model = mock.MagicMock(spec=NemotronH_Nano_VL_V2)
374+
model.vision_encoder = mock.MagicMock(spec=NanoV2VLVisionEncoder)
375+
model.sound_encoder = mock.MagicMock(spec=ProjectedParakeet)
376+
return model
377+
378+
def _make_param_with_runtime(self, modality_type, num_tokens, **extra):
379+
"""Build a real MultimodalParams with runtime data for caching."""
380+
runtime = MultimodalRuntimeData(
381+
past_seen_token_num=0,
382+
mm_token_lengths=[num_tokens],
383+
mm_token_positions=[0],
384+
chunk_end_pos=num_tokens,
385+
special_token_offsets=[],
386+
)
387+
return MultimodalParams(
388+
multimodal_data={"modality_type": modality_type, **extra},
389+
multimodal_runtime=runtime,
390+
)
391+
392+
def _make_encoder_fn(self, model):
393+
"""Wrap `_encode_multimodal` as a callable for get_multimodal_embeddings."""
394+
395+
def encoder_fn(params):
396+
return NemotronH_Nano_VL_V2._encode_multimodal(model, params)
397+
398+
return encoder_fn
399+
400+
@pytest.mark.parametrize("modality", ["image", "video"])
401+
def test_vision_encoder_not_called_on_second_chunk(self, modality):
402+
model = self._make_mock_model()
403+
fake_emb = torch.randn(self.NUM_TOKENS, self.HIDDEN)
404+
model.vision_encoder.return_value = ([fake_emb], [None])
405+
406+
param = self._make_param_with_runtime(modality, self.NUM_TOKENS)
407+
encoder_fn = self._make_encoder_fn(model)
408+
409+
# First call: encoder must run and cache the result.
410+
result = get_multimodal_embeddings(
411+
encoder_forward_fn=encoder_fn,
412+
multimodal_params=[param],
413+
)
414+
assert len(result) == 1
415+
assert result[0].shape == (self.NUM_TOKENS, self.HIDDEN)
416+
assert model.vision_encoder.call_count == 1
417+
418+
# Embedding is now cached in multimodal_data.
419+
assert "multimodal_embedding" in param.multimodal_data
420+
421+
# Second call: encoder must NOT run - embeddings come from cache.
422+
result2 = get_multimodal_embeddings(
423+
encoder_forward_fn=encoder_fn,
424+
multimodal_params=[param],
425+
)
426+
assert model.vision_encoder.call_count == 1, (
427+
"`vision_encoder` was called again on the second chunk. "
428+
"Caching is broken - `_encode_multimodal` likely violates the "
429+
"`get_multimodal_embeddings` return type contract."
430+
)
431+
assert len(result2) == 1
432+
assert torch.equal(result2[0], result[0])
433+
434+
def test_audio_encoder_not_called_on_second_chunk(self):
435+
model = self._make_mock_model()
436+
fake_emb = torch.randn(self.NUM_TOKENS, self.HIDDEN)
437+
model._encode_audio = mock.MagicMock(return_value=fake_emb)
438+
439+
param = self._make_param_with_runtime("audio", self.NUM_TOKENS, audio={})
440+
encoder_fn = self._make_encoder_fn(model)
441+
442+
# First call.
443+
result = get_multimodal_embeddings(
444+
encoder_forward_fn=encoder_fn,
445+
multimodal_params=[param],
446+
)
447+
assert len(result) == 1
448+
assert model._encode_audio.call_count == 1
449+
450+
# Second call - should use cache.
451+
result2 = get_multimodal_embeddings(
452+
encoder_forward_fn=encoder_fn,
453+
multimodal_params=[param],
454+
)
455+
assert model._encode_audio.call_count == 1, (
456+
"`_encode_audio` was called again on the second chunk. "
457+
"Caching is broken - `_encode_multimodal` likely violates the "
458+
"`get_multimodal_embeddings` return type contract."
459+
)
460+
assert torch.equal(result2[0], result[0])
461+
462+
def test_multi_request_batch_caching(self):
463+
"""Two image requests in one batch: both cached after one call."""
464+
model = self._make_mock_model()
465+
emb_a = torch.randn(5, self.HIDDEN)
466+
emb_b = torch.randn(3, self.HIDDEN)
467+
model.vision_encoder.side_effect = [
468+
([emb_a], [None]),
469+
([emb_b], [None]),
470+
]
471+
472+
param_a = self._make_param_with_runtime("image", 5)
473+
param_b = self._make_param_with_runtime("image", 3)
474+
encoder_fn = self._make_encoder_fn(model)
475+
476+
# First call: encoder runs for both.
477+
result = get_multimodal_embeddings(
478+
encoder_forward_fn=encoder_fn,
479+
multimodal_params=[param_a, param_b],
480+
)
481+
assert len(result) == 1
482+
assert result[0].shape == (8, self.HIDDEN)
483+
assert model.vision_encoder.call_count == 2 # once per param
484+
485+
# Both should be cached.
486+
assert "multimodal_embedding" in param_a.multimodal_data
487+
assert "multimodal_embedding" in param_b.multimodal_data
488+
489+
# Second call: encoder must not run again.
490+
result2 = get_multimodal_embeddings(
491+
encoder_forward_fn=encoder_fn,
492+
multimodal_params=[param_a, param_b],
493+
)
494+
assert model.vision_encoder.call_count == 2, (
495+
"`vision_encoder` was called again on the second chunk. "
496+
"Caching is broken - `_encode_multimodal` likely violates the "
497+
"`get_multimodal_embeddings` return type contract."
498+
)
499+
assert torch.equal(result2[0], result[0])

0 commit comments

Comments
 (0)