From 43cdf0dad2287734502e34c67385d8a1b57f37d1 Mon Sep 17 00:00:00 2001 From: quic-xiyushi Date: Mon, 18 May 2026 10:28:45 -0700 Subject: [PATCH 1/5] Add user_vision_size in VLM get_specializations for chunked embedding in vLLM v1 Signed-off-by: quic-xiyushi --- .../models/gemma3/modeling_gemma3.py | 21 ++++++++++++------- .../models/internvl/modeling_internvl.py | 7 ++++++- .../models/llama4/modeling_llama4.py | 7 ++++++- .../models/llava/modeling_llava.py | 7 ++++++- .../models/llava_next/modeling_llava_next.py | 7 ++++++- .../models/mistral3/modeling_mistral3.py | 11 +++++++--- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 15 ++++++------- .../models/qwen3_vl/modeling_qwen3_vl.py | 15 ++++++------- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 15 ++++++------- 9 files changed, 69 insertions(+), 36 deletions(-) diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index 8fb8cdbdda..1b28c38fc3 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -731,7 +731,12 @@ def get_specializations( elif img_size is None: img_size = 896 # FIXME based on gemma3 Image size logger.warning("Setting img_size to be 336, as it was neither passed nor found in vision_config") - mm_tokens_per_image = getattr(self.config, "mm_tokens_per_image", 256) + user_vision_size = compiler_options.pop("vision_size", None) + if user_vision_size: + assert user_vision_size < ctx_len, "vision_size must be less than ctx_len" + vision_size = user_vision_size + else: + vision_size = getattr(self.config, "mm_tokens_per_image", 256) vision = [ { @@ -752,7 +757,7 @@ def get_specializations( "comp_ctx_lengths": comp_ctx_lengths_prefill[i], "sliding_window": self.language_model.config.sliding_window, "img_size": img_size, - "mm_tokens_per_image": mm_tokens_per_image, + "vision_size": vision_size, "vision_batch_size": batch_size, } if continuous_batching: 
@@ -771,7 +776,7 @@ def get_specializations( "comp_ctx_lengths": comp_ctx_lengths_decode[i], "sliding_window": self.language_model.config.sliding_window, "img_size": img_size, - "mm_tokens_per_image": mm_tokens_per_image, + "vision_size": vision_size, "vision_batch_size": batch_size, } if continuous_batching: @@ -787,7 +792,7 @@ def get_specializations( "ctx_len": ctx_len, "sliding_window": self.language_model.config.sliding_window, "img_size": img_size, - "mm_tokens_per_image": mm_tokens_per_image, + "vision_size": vision_size, "vision_batch_size": batch_size, } if continuous_batching: @@ -803,7 +808,7 @@ def get_specializations( "ctx_len": ctx_len, "sliding_window": self.language_model.config.sliding_window, "img_size": img_size, - "mm_tokens_per_image": mm_tokens_per_image, + "vision_size": vision_size, "vision_batch_size": batch_size, } if continuous_batching: @@ -829,7 +834,7 @@ def get_onnx_dynamic_axes( lang_dynamic_axes = {} lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"} - lang_dynamic_axes["vision_embeds"] = {0: "vision_batch_size", 1: "mm_tokens_per_image"} + lang_dynamic_axes["vision_embeds"] = {0: "vision_batch_size", 1: "vision_size"} if continuous_batching: lang_dynamic_axes["batch_index"] = {0: "batch_size"} vision_dynamic_axes["pixel_values"] = {0: "batch_size", 2: "img_size", 3: "img_size"} @@ -911,13 +916,13 @@ def get_dummy_inputs( else: img_size = 896 - mm_tokens_per_image = getattr(self.config, "mm_tokens_per_image", 256) + vision_size = getattr(self.config, "mm_tokens_per_image", 256) # Define shapes inputs_shapes = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) inputs_shapes["vision_embeds"] = ( 1, # constants.INTERN_NUM_PATCHES, - mm_tokens_per_image, # constants.INTERN_FEATURE_SIZE, + vision_size, # constants.INTERN_FEATURE_SIZE, self.language_model.config.hidden_size, # 5120 ) 
inputs_shapes["position_ids"] = ( diff --git a/QEfficient/transformers/models/internvl/modeling_internvl.py b/QEfficient/transformers/models/internvl/modeling_internvl.py index e389e6a840..da4ec5758c 100644 --- a/QEfficient/transformers/models/internvl/modeling_internvl.py +++ b/QEfficient/transformers/models/internvl/modeling_internvl.py @@ -134,7 +134,12 @@ def get_specializations( raise NotImplementedError("Image Size other than 448 is not supported for Intern models yet.") per_patch_embed_size = (img_size // self.config.vision_config.patch_size * self.config.downsample_ratio) ** 2 - vision_size = int(batch_size * num_patches * per_patch_embed_size) + user_vision_size = compiler_options.pop("vision_size", None) + if user_vision_size: + assert user_vision_size < ctx_len, "vision_size must be less than ctx_len" + vision_size = user_vision_size + else: + vision_size = int(batch_size * num_patches * per_patch_embed_size) vision = [ { "batch_size": batch_size, diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py index a49f9a24be..7f6c160d19 100644 --- a/QEfficient/transformers/models/llama4/modeling_llama4.py +++ b/QEfficient/transformers/models/llama4/modeling_llama4.py @@ -1001,7 +1001,12 @@ def get_specializations( * (img_size // self.config.vision_config.patch_size) // downsample_ratio ) - vision_size = num_features_per_tile * max_num_tiles + user_vision_size = compiler_options.pop("vision_size", None) + if user_vision_size: + assert user_vision_size < ctx_len, "vision_size must be less than ctx_len" + vision_size = user_vision_size + else: + vision_size = num_features_per_tile * max_num_tiles vision = [ { diff --git a/QEfficient/transformers/models/llava/modeling_llava.py b/QEfficient/transformers/models/llava/modeling_llava.py index 48b002a31a..4a873bb0bd 100644 --- a/QEfficient/transformers/models/llava/modeling_llava.py +++ b/QEfficient/transformers/models/llava/modeling_llava.py @@ 
-237,7 +237,12 @@ def get_specializations( logger.warning("Setting img_size to be 336, as it was neither passed nor found in vision_config") if img_size != 336 and kv_offload: raise NotImplementedError("Image Size other than 336 is not supported for Llava models yet.") - vision_size = (img_size // self.config.vision_config.patch_size) ** 2 + user_vision_size = compiler_options.pop("vision_size", None) + if user_vision_size: + assert user_vision_size < ctx_len, "vision_size must be less than ctx_len" + vision_size = user_vision_size + else: + vision_size = (img_size // self.config.vision_config.patch_size) ** 2 vision = [ { "batch_size": batch_size, diff --git a/QEfficient/transformers/models/llava_next/modeling_llava_next.py b/QEfficient/transformers/models/llava_next/modeling_llava_next.py index 59d5cad229..fec5ad8253 100755 --- a/QEfficient/transformers/models/llava_next/modeling_llava_next.py +++ b/QEfficient/transformers/models/llava_next/modeling_llava_next.py @@ -326,7 +326,12 @@ def get_specializations( logger.warning("Setting img_size to be 384, as it was neither passed nor found in vision_config") if img_size != constants.GRANITEVISION_IMG_SIZE and kv_offload: logger.warning("Image Size other than 384 is not supported for LlavaNext models yet.") - vision_size = constants.GRANITEVISION_FEATURE_SIZE + user_vision_size = compiler_options.pop("vision_size", None) + if user_vision_size: + assert user_vision_size < ctx_len, "vision_size must be less than ctx_len" + vision_size = user_vision_size + else: + vision_size = constants.GRANITEVISION_FEATURE_SIZE vision = [ { "batch_size": batch_size, diff --git a/QEfficient/transformers/models/mistral3/modeling_mistral3.py b/QEfficient/transformers/models/mistral3/modeling_mistral3.py index a8fb34bafe..f5658025f3 100644 --- a/QEfficient/transformers/models/mistral3/modeling_mistral3.py +++ b/QEfficient/transformers/models/mistral3/modeling_mistral3.py @@ -370,9 +370,14 @@ def get_specializations( ctx_len = ctx_len if 
ctx_len else constants.INTERN_CTX_LEN patch_size = self.config.vision_config.patch_size kernel_size = self.config.spatial_merge_size - vision_size = ( - ((img_size // patch_size) * (img_size // patch_size)) * (batch_size) // (kernel_size * kernel_size) - ) + user_vision_size = compiler_options.pop("vision_size", None) + if user_vision_size: + assert user_vision_size < ctx_len, "vision_size must be less than ctx_len" + vision_size = user_vision_size + else: + vision_size = ( + ((img_size // patch_size) * (img_size // patch_size)) * (batch_size) // (kernel_size * kernel_size) + ) vision = [ { diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 45c6616018..dd0a23ccdd 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1091,13 +1091,14 @@ def get_specializations( "resolution." ) else: - assert vision_size * f <= user_vision_size, ( - f"Computed vision_size of {vision_size * f} tokens " - f"(vision_size={vision_size}, num_frames={f}) for image resolution " - f"(width={w}, height={h}) cannot exceed the provided " - f"vision_size={user_vision_size}. Please adjust the image resolution or " - "increase the vision_size." - ) + if vision_size * f >= user_vision_size: + logger.warning_once( + f"Computed vision_size of {vision_size * f} tokens " + f"(vision_size={vision_size}, num_frames={f}) for image resolution " + f"(width={w}, height={h}) exceed the provided " + f"vision_size={user_vision_size}. " + f"Vision embedding need to be chunked during prefill." 
+ ) vision.append( { diff --git a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 2d834423f6..39ff9e0138 100644 --- a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -972,13 +972,14 @@ def get_specializations( "resolution." ) else: - assert vision_size * f <= user_vision_size, ( - f"Computed vision_size of {vision_size * f} tokens " - f"(vision_size={vision_size}, num_frames={f}) for image resolution " - f"(width={w}, height={h}) cannot exceed the provided " - f"vision_size={user_vision_size}. Please adjust the image resolution or " - "increase the vision_size." - ) + if vision_size * f >= user_vision_size: + logger.warning_once( + f"Computed vision_size of {vision_size * f} tokens " + f"(vision_size={vision_size}, num_frames={f}) for image resolution " + f"(width={w}, height={h}) exceed the provided " + f"vision_size={user_vision_size}. " + f"Vision embedding need to be chunked during prefill." + ) vision.append( { diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 078cb4afb7..2f5cd86ee2 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -1001,13 +1001,14 @@ def get_specializations( "resolution." ) else: - assert vision_size * f <= user_vision_size, ( - f"Computed vision_size of {vision_size * f} tokens " - f"(vision_size={vision_size}, num_frames={f}) for image resolution " - f"(width={w}, height={h}) cannot exceed the provided " - f"vision_size={user_vision_size}. Please adjust the image resolution or " - "increase the vision_size." 
- ) + if vision_size * f >= user_vision_size: + logger.warning_once( + f"Computed vision_size of {vision_size * f} tokens " + f"(vision_size={vision_size}, num_frames={f}) for image resolution " + f"(width={w}, height={h}) exceed the provided " + f"vision_size={user_vision_size}. " + f"Vision embedding need to be chunked during prefill." + ) vision.append( { From 2293951e9515599388be5d2e13633c84dd03b265 Mon Sep 17 00:00:00 2001 From: sanising Date: Mon, 1 Jun 2026 14:41:42 -0700 Subject: [PATCH 2/5] Fix chunking warning condition: warn only when computed vision_size exceeds user vision_size Signed-off-by: sanising --- .../transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | 2 +- QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py | 2 +- .../transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index dd0a23ccdd..961d7412d7 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1091,7 +1091,7 @@ def get_specializations( "resolution." ) else: - if vision_size * f >= user_vision_size: + if vision_size * f > user_vision_size: logger.warning_once( f"Computed vision_size of {vision_size * f} tokens " f"(vision_size={vision_size}, num_frames={f}) for image resolution " diff --git a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 39ff9e0138..b0e4a4fea4 100644 --- a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -972,7 +972,7 @@ def get_specializations( "resolution." 
) else: - if vision_size * f >= user_vision_size: + if vision_size * f > user_vision_size: logger.warning_once( f"Computed vision_size of {vision_size * f} tokens " f"(vision_size={vision_size}, num_frames={f}) for image resolution " diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 2f5cd86ee2..65d483b1ae 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -1001,7 +1001,7 @@ def get_specializations( "resolution." ) else: - if vision_size * f >= user_vision_size: + if vision_size * f > user_vision_size: logger.warning_once( f"Computed vision_size of {vision_size * f} tokens " f"(vision_size={vision_size}, num_frames={f}) for image resolution " From 16967aa76f7fa54e5affdf88abd82d38e7224614 Mon Sep 17 00:00:00 2001 From: sanising Date: Mon, 1 Jun 2026 14:52:33 -0700 Subject: [PATCH 3/5] Change assert to exception Signed-off-by: sanising --- QEfficient/transformers/models/gemma3/modeling_gemma3.py | 3 ++- QEfficient/transformers/models/internvl/modeling_internvl.py | 3 ++- QEfficient/transformers/models/llama4/modeling_llama4.py | 3 ++- QEfficient/transformers/models/llava/modeling_llava.py | 3 ++- .../transformers/models/llava_next/modeling_llava_next.py | 3 ++- QEfficient/transformers/models/mistral3/modeling_mistral3.py | 3 ++- 6 files changed, 12 insertions(+), 6 deletions(-) diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index 1b28c38fc3..29c023d4f7 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -733,7 +733,8 @@ def get_specializations( logger.warning("Setting img_size to be 336, as it was neither passed nor found in vision_config") user_vision_size = 
compiler_options.pop("vision_size", None) if user_vision_size: - assert user_vision_size < ctx_len, "vision_size must be less than ctx_len" + if user_vision_size >= ctx_len: + raise ValueError("vision_size must be less than ctx_len") vision_size = user_vision_size else: vision_size = getattr(self.config, "mm_tokens_per_image", 256) diff --git a/QEfficient/transformers/models/internvl/modeling_internvl.py b/QEfficient/transformers/models/internvl/modeling_internvl.py index da4ec5758c..f95a657843 100644 --- a/QEfficient/transformers/models/internvl/modeling_internvl.py +++ b/QEfficient/transformers/models/internvl/modeling_internvl.py @@ -136,7 +136,8 @@ def get_specializations( per_patch_embed_size = (img_size // self.config.vision_config.patch_size * self.config.downsample_ratio) ** 2 user_vision_size = compiler_options.pop("vision_size", None) if user_vision_size: - assert user_vision_size < ctx_len, "vision_size must be less than ctx_len" + if user_vision_size >= ctx_len: + raise ValueError("vision_size must be less than ctx_len") vision_size = user_vision_size else: vision_size = int(batch_size * num_patches * per_patch_embed_size) diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py index 7f6c160d19..71e0f6000f 100644 --- a/QEfficient/transformers/models/llama4/modeling_llama4.py +++ b/QEfficient/transformers/models/llama4/modeling_llama4.py @@ -1003,7 +1003,8 @@ def get_specializations( ) user_vision_size = compiler_options.pop("vision_size", None) if user_vision_size: - assert user_vision_size < ctx_len, "vision_size must be less than ctx_len" + if user_vision_size >= ctx_len: + raise ValueError("vision_size must be less than ctx_len") vision_size = user_vision_size else: vision_size = num_features_per_tile * max_num_tiles diff --git a/QEfficient/transformers/models/llava/modeling_llava.py b/QEfficient/transformers/models/llava/modeling_llava.py index 4a873bb0bd..e6d19d2782 100644 --- 
a/QEfficient/transformers/models/llava/modeling_llava.py +++ b/QEfficient/transformers/models/llava/modeling_llava.py @@ -239,7 +239,8 @@ def get_specializations( raise NotImplementedError("Image Size other than 336 is not supported for Llava models yet.") user_vision_size = compiler_options.pop("vision_size", None) if user_vision_size: - assert user_vision_size < ctx_len, "vision_size must be less than ctx_len" + if user_vision_size >= ctx_len: + raise ValueError("vision_size must be less than ctx_len") vision_size = user_vision_size else: vision_size = (img_size // self.config.vision_config.patch_size) ** 2 diff --git a/QEfficient/transformers/models/llava_next/modeling_llava_next.py b/QEfficient/transformers/models/llava_next/modeling_llava_next.py index fec5ad8253..9d1943dc62 100755 --- a/QEfficient/transformers/models/llava_next/modeling_llava_next.py +++ b/QEfficient/transformers/models/llava_next/modeling_llava_next.py @@ -328,7 +328,8 @@ def get_specializations( logger.warning("Image Size other than 384 is not supported for LlavaNext models yet.") user_vision_size = compiler_options.pop("vision_size", None) if user_vision_size: - assert user_vision_size < ctx_len, "vision_size must be less than ctx_len" + if user_vision_size >= ctx_len: + raise ValueError("vision_size must be less than ctx_len") vision_size = user_vision_size else: vision_size = constants.GRANITEVISION_FEATURE_SIZE diff --git a/QEfficient/transformers/models/mistral3/modeling_mistral3.py b/QEfficient/transformers/models/mistral3/modeling_mistral3.py index f5658025f3..4dbe76d993 100644 --- a/QEfficient/transformers/models/mistral3/modeling_mistral3.py +++ b/QEfficient/transformers/models/mistral3/modeling_mistral3.py @@ -372,7 +372,8 @@ def get_specializations( kernel_size = self.config.spatial_merge_size user_vision_size = compiler_options.pop("vision_size", None) if user_vision_size: - assert user_vision_size < ctx_len, "vision_size must be less than ctx_len" + if user_vision_size >= 
ctx_len: + raise ValueError("vision_size must be less than ctx_len") vision_size = user_vision_size else: vision_size = ( From 4f65a7762dc12199619c060f0b1e7ad600d5b525 Mon Sep 17 00:00:00 2001 From: sanising Date: Mon, 1 Jun 2026 14:57:06 -0700 Subject: [PATCH 4/5] Add deprecation warning for mm_tokens_per_image in get_specializations Signed-off-by: sanising --- QEfficient/transformers/models/gemma3/modeling_gemma3.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index 29c023d4f7..b7673891a3 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -737,7 +737,12 @@ def get_specializations( raise ValueError("vision_size must be less than ctx_len") vision_size = user_vision_size else: - vision_size = getattr(self.config, "mm_tokens_per_image", 256) + _mm_tokens_per_image = getattr(self.config, "mm_tokens_per_image", None) + if _mm_tokens_per_image: + logger.warning_once("mm_tokens_per_image is deprecated and will be removed in the next release.") + vision_size = _mm_tokens_per_image + else: + vision_size = 256 vision = [ { From d308a404e5521b00e424af421316c645cd5d596c Mon Sep 17 00:00:00 2001 From: sanising Date: Mon, 1 Jun 2026 15:39:23 -0700 Subject: [PATCH 5/5] Add deprecation warning for mm_tokens_per_image in get_dummy_inputs Signed-off-by: sanising --- QEfficient/transformers/models/gemma3/modeling_gemma3.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index b7673891a3..6927529b76 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -922,7 +922,13 @@ def get_dummy_inputs( else: img_size = 896 - vision_size = 
getattr(self.config, "mm_tokens_per_image", 256) + 
_mm_tokens_per_image = getattr(self.config, "mm_tokens_per_image", None) + if _mm_tokens_per_image: + logger.warning_once("mm_tokens_per_image is deprecated and will be removed in the next release.") + vision_size = _mm_tokens_per_image + else: + vision_size = 256 + # Define shapes inputs_shapes = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)