99from test_modeling_nemotron_h import extract_decode_logprobs
1010
1111from tensorrt_llm import LLM
12+ from tensorrt_llm ._torch .models .modeling_multimodal_utils import get_multimodal_embeddings
1213from tensorrt_llm ._torch .models .modeling_nemotron_nano import (
1314 NanoV2VLVisionEncoder ,
1415 NemotronH_Nano_VL_V2 ,
2021 default_multimodal_input_loader ,
2122 prompt_inputs ,
2223)
24+ from tensorrt_llm .inputs .multimodal import MultimodalParams , MultimodalRuntimeData
2325from tensorrt_llm .llmapi import KvCacheConfig
2426from tensorrt_llm .llmapi .llm_args import CudaGraphConfig
2527from tensorrt_llm .sampling_params import SamplingParams
@@ -220,6 +222,15 @@ def _make_mock_model(self):
220222 model .sound_encoder = mock .MagicMock (spec = ProjectedParakeet )
221223 return model
222224
225+ @staticmethod
226+ def _assert_compatible_with_chunked_prefill (multimodal_embeddings ):
227+ # NOTE: `multimodal_embeddings` is expected to be the output of `_encode_multimodal`.
228+ # The below checks help verify that we can make use of `get_multimodal_embeddings` and its
229+ # caching feature. Otherwise, we would be re-encoding the items each chunk during chunked
230+ # prefill.
231+ assert len (multimodal_embeddings ) == 1
232+ assert isinstance (multimodal_embeddings [0 ], torch .Tensor )
233+
223234 def test_encode_multimodal_dispatches_audio (self ):
224235 model = self ._make_mock_model ()
225236 fake_audio_embeds = torch .randn (10 , 128 )
@@ -229,16 +240,16 @@ def test_encode_multimodal_dispatches_audio(self):
229240 mm_param .multimodal_data = {"modality_type" : "audio" , "audio" : {}}
230241
231242 # Call the real method on our mock
232- result_embeds , result_nones = NemotronH_Nano_VL_V2 ._encode_multimodal (model , [mm_param ])
243+ result = NemotronH_Nano_VL_V2 ._encode_multimodal (model , [mm_param ])
233244
234245 model ._encode_audio .assert_called_once_with (mm_param )
235246 model .vision_encoder .assert_not_called ()
236- assert len ( result_embeds ) == 1
237- assert torch .equal (result_embeds [0 ], fake_audio_embeds )
247+ self . _assert_compatible_with_chunked_prefill ( result )
248+ assert torch .equal (result [0 ], fake_audio_embeds )
238249
239250 def test_encode_multimodal_dispatches_image (self ):
240251 model = self ._make_mock_model ()
241- fake_image_embeds = torch .randn (1 , 10 , 128 )
252+ fake_image_embeds = torch .randn (10 , 128 )
242253 model .vision_encoder .return_value = ([fake_image_embeds ], [None ])
243254
244255 mm_param = mock .MagicMock ()
@@ -247,10 +258,10 @@ def test_encode_multimodal_dispatches_image(self):
247258 "image" : {"pixel_values" : torch .randn (1 , 100 , 768 )},
248259 }
249260
250- result_embeds , result_nones = NemotronH_Nano_VL_V2 ._encode_multimodal (model , [mm_param ])
261+ result = NemotronH_Nano_VL_V2 ._encode_multimodal (model , [mm_param ])
251262
252263 model .vision_encoder .assert_called_once_with ([mm_param ])
253- assert len ( result_embeds ) == 1
264+ self . _assert_compatible_with_chunked_prefill ( result )
254265
255266 def test_encode_multimodal_unknown_modality_raises (self ):
256267 """Unknown modality raises ValueError."""
@@ -260,3 +271,229 @@ def test_encode_multimodal_unknown_modality_raises(self):
260271
261272 with pytest .raises (ValueError , match = "Unknown modality" ):
262273 NemotronH_Nano_VL_V2 ._encode_multimodal (model , [mm_param ])
274+
275+
class TestEncodeMultimodalContract:
    """Verify `_encode_multimodal` conforms to the contract expected by `get_multimodal_embeddings`.

    The key assumption is that the `encoder_forward_fn` passed to it returns something whose length
    is 1, and can be indexed by `[0]` to return a single `torch.Tensor`.
    """

    # Hidden size used for all fake embeddings in this class.
    HIDDEN = 128

    def _make_mock_model(self):
        """Build a mocked model whose encoders are spec'd so attribute typos fail loudly."""
        mocked = mock.MagicMock(spec=NemotronH_Nano_VL_V2)
        mocked.vision_encoder = mock.MagicMock(spec=NanoV2VLVisionEncoder)
        mocked.sound_encoder = mock.MagicMock(spec=ProjectedParakeet)
        return mocked

    def _make_mm_param(self, modality_type, **extra):
        """Build a mocked multimodal param carrying `multimodal_data` for `modality_type`."""
        mm_param = mock.MagicMock()
        mm_param.multimodal_data = {"modality_type": modality_type, **extra}
        return mm_param

    def test_returns_list_with_a_single_element(self):
        model = self._make_mock_model()
        model.vision_encoder.return_value = ([torch.randn(5, self.HIDDEN)], [None])

        result = NemotronH_Nano_VL_V2._encode_multimodal(model, [self._make_mm_param("image")])

        assert isinstance(result, list)
        assert len(result) == 1

    def test_single_concatenated_tensor_for_multiple_multimodal_items(self):
        """Multiple multimodal items must be concatenated into a single tensor.

        `get_multimodal_embeddings` requires `len(embeddings) == 1` and splits by per-request token
        counts in order to cache the embeddings.
        """
        model = self._make_mock_model()
        first_chunk = torch.randn(5, self.HIDDEN)
        second_chunk = torch.randn(3, self.HIDDEN)
        # One encoder call per param; each returns its own embedding list.
        model.vision_encoder.side_effect = [([first_chunk], [None]), ([second_chunk], [None])]

        result = NemotronH_Nano_VL_V2._encode_multimodal(
            model, [self._make_mm_param("image"), self._make_mm_param("image")]
        )

        assert len(result) == 1
        combined = result[0]
        assert combined.shape == (8, self.HIDDEN)
        # Verify concatenation order is preserved.
        assert torch.equal(combined[:5], first_chunk)
        assert torch.equal(combined[5:], second_chunk)

    def test_mixed_modalities_still_single_tensor(self):
        """Image + audio requests produce one concatenated tensor."""
        model = self._make_mock_model()
        image_part = torch.randn(5, self.HIDDEN)
        audio_part = torch.randn(3, self.HIDDEN)
        model.vision_encoder.return_value = ([image_part], [None])
        model._encode_audio = mock.MagicMock(return_value=audio_part)

        result = NemotronH_Nano_VL_V2._encode_multimodal(
            model,
            [self._make_mm_param("image"), self._make_mm_param("audio", audio={})],
        )

        assert len(result) == 1
        assert result[0].shape == (8, self.HIDDEN)

    def test_empty_params_returns_empty_list(self):
        model = self._make_mock_model()
        assert NemotronH_Nano_VL_V2._encode_multimodal(model, []) == []
350+
351+
class TestChunkedPrefillCaching:
    """Verify that `_encode_multimodal` output is compatible with `get_multimodal_embeddings`.

    Specifically, we want to test that the caching functionality is exercised and not skipped due
    to the return type not being compatible.

    The test structure for each modality:

    1. Build `MultimodalParams` with a real `MultimodalRuntimeData` (past_seen_token_num=0,
       simulating the first chunk).
    2. Call `get_multimodal_embeddings` with the real _encode_multimodal wired to mock sub-encoders
       -> encoder MUST be invoked.
    3. Verify embeddings were cached in `multimodal_data`.
    4. Call `get_multimodal_embeddings` again with the SAME params (simulating a second chunk) ->
       encoder must NOT be invoked.
    """

    # Hidden size and per-item token count shared by all fakes below.
    HIDDEN = 128
    NUM_TOKENS = 10

    def _make_mock_model(self):
        """Build a mocked model whose encoders are spec'd so attribute typos fail loudly."""
        mocked = mock.MagicMock(spec=NemotronH_Nano_VL_V2)
        mocked.vision_encoder = mock.MagicMock(spec=NanoV2VLVisionEncoder)
        mocked.sound_encoder = mock.MagicMock(spec=ProjectedParakeet)
        return mocked

    def _make_param_with_runtime(self, modality_type, num_tokens, **extra):
        """Build a real MultimodalParams with runtime data for caching."""
        runtime = MultimodalRuntimeData(
            past_seen_token_num=0,
            mm_token_lengths=[num_tokens],
            mm_token_positions=[0],
            chunk_end_pos=num_tokens,
            special_token_offsets=[],
        )
        return MultimodalParams(
            multimodal_data={"modality_type": modality_type, **extra},
            multimodal_runtime=runtime,
        )

    def _make_encoder_fn(self, model):
        """Wrap `_encode_multimodal` as a callable for get_multimodal_embeddings."""

        def encoder_fn(params):
            return NemotronH_Nano_VL_V2._encode_multimodal(model, params)

        return encoder_fn

    @pytest.mark.parametrize("modality", ["image", "video"])
    def test_vision_encoder_not_called_on_second_chunk(self, modality):
        model = self._make_mock_model()
        fake_emb = torch.randn(self.NUM_TOKENS, self.HIDDEN)
        model.vision_encoder.return_value = ([fake_emb], [None])

        encoder_fn = self._make_encoder_fn(model)
        mm_param = self._make_param_with_runtime(modality, self.NUM_TOKENS)

        # First call: encoder must run and cache the result.
        first = get_multimodal_embeddings(
            encoder_forward_fn=encoder_fn,
            multimodal_params=[mm_param],
        )
        assert len(first) == 1
        assert first[0].shape == (self.NUM_TOKENS, self.HIDDEN)
        assert model.vision_encoder.call_count == 1
        # Embedding is now cached in multimodal_data.
        assert "multimodal_embedding" in mm_param.multimodal_data

        # Second call: encoder must NOT run - embeddings come from cache.
        second = get_multimodal_embeddings(
            encoder_forward_fn=encoder_fn,
            multimodal_params=[mm_param],
        )
        assert model.vision_encoder.call_count == 1, (
            "`vision_encoder` was called again on the second chunk. "
            "Caching is broken - `_encode_multimodal` likely violates the "
            "`get_multimodal_embeddings` return type contract."
        )
        assert len(second) == 1
        assert torch.equal(second[0], first[0])

    def test_audio_encoder_not_called_on_second_chunk(self):
        model = self._make_mock_model()
        fake_emb = torch.randn(self.NUM_TOKENS, self.HIDDEN)
        model._encode_audio = mock.MagicMock(return_value=fake_emb)

        encoder_fn = self._make_encoder_fn(model)
        mm_param = self._make_param_with_runtime("audio", self.NUM_TOKENS, audio={})

        # First call: the audio sub-encoder runs exactly once.
        first = get_multimodal_embeddings(
            encoder_forward_fn=encoder_fn,
            multimodal_params=[mm_param],
        )
        assert len(first) == 1
        assert model._encode_audio.call_count == 1

        # Second call - should use cache.
        second = get_multimodal_embeddings(
            encoder_forward_fn=encoder_fn,
            multimodal_params=[mm_param],
        )
        assert model._encode_audio.call_count == 1, (
            "`_encode_audio` was called again on the second chunk. "
            "Caching is broken - `_encode_multimodal` likely violates the "
            "`get_multimodal_embeddings` return type contract."
        )
        assert torch.equal(second[0], first[0])

    def test_multi_request_batch_caching(self):
        """Two image requests in one batch: both cached after one call."""
        model = self._make_mock_model()
        emb_first = torch.randn(5, self.HIDDEN)
        emb_second = torch.randn(3, self.HIDDEN)
        # One encoder invocation per request in the batch.
        model.vision_encoder.side_effect = [([emb_first], [None]), ([emb_second], [None])]

        encoder_fn = self._make_encoder_fn(model)
        batch = [
            self._make_param_with_runtime("image", 5),
            self._make_param_with_runtime("image", 3),
        ]

        # First call: encoder runs for both.
        first = get_multimodal_embeddings(
            encoder_forward_fn=encoder_fn,
            multimodal_params=batch,
        )
        assert len(first) == 1
        assert first[0].shape == (8, self.HIDDEN)
        assert model.vision_encoder.call_count == 2  # once per param

        # Both should be cached.
        assert "multimodal_embedding" in batch[0].multimodal_data
        assert "multimodal_embedding" in batch[1].multimodal_data

        # Second call: encoder must not run again.
        second = get_multimodal_embeddings(
            encoder_forward_fn=encoder_fn,
            multimodal_params=batch,
        )
        assert model.vision_encoder.call_count == 2, (
            "`vision_encoder` was called again on the second chunk. "
            "Caching is broken - `_encode_multimodal` likely violates the "
            "`get_multimodal_embeddings` return type contract."
        )
        assert torch.equal(second[0], first[0])