Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 8 additions & 25 deletions QEfficient/transformers/models/gemma3/modeling_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,18 +731,7 @@ def get_specializations(
elif img_size is None:
img_size = 896 # FIXME based on gemma3 Image size
logger.warning("Setting img_size to be 336, as it was neither passed nor found in vision_config")
user_vision_size = compiler_options.pop("vision_size", None)
if user_vision_size:
if user_vision_size >= ctx_len:
raise ValueError("vision_size must be less than ctx_len")
vision_size = user_vision_size
else:
_mm_tokens_per_image = getattr(self.config, "mm_tokens_per_image", None)
if _mm_tokens_per_image:
logger.warning_once("mm_tokens_per_image is deprecated and will be removed in the next release.")
vision_size = _mm_tokens_per_image
else:
vision_size = 256
mm_tokens_per_image = getattr(self.config, "mm_tokens_per_image", 256)

vision = [
{
Expand All @@ -763,7 +752,7 @@ def get_specializations(
"comp_ctx_lengths": comp_ctx_lengths_prefill[i],
"sliding_window": self.language_model.config.sliding_window,
"img_size": img_size,
"vision_size": vision_size,
"mm_tokens_per_image": mm_tokens_per_image,
"vision_batch_size": batch_size,
}
if continuous_batching:
Expand All @@ -782,7 +771,7 @@ def get_specializations(
"comp_ctx_lengths": comp_ctx_lengths_decode[i],
"sliding_window": self.language_model.config.sliding_window,
"img_size": img_size,
"vision_size": vision_size,
"mm_tokens_per_image": mm_tokens_per_image,
"vision_batch_size": batch_size,
}
if continuous_batching:
Expand All @@ -798,7 +787,7 @@ def get_specializations(
"ctx_len": ctx_len,
"sliding_window": self.language_model.config.sliding_window,
"img_size": img_size,
"vision_size": vision_size,
"mm_tokens_per_image": mm_tokens_per_image,
"vision_batch_size": batch_size,
}
if continuous_batching:
Expand All @@ -814,7 +803,7 @@ def get_specializations(
"ctx_len": ctx_len,
"sliding_window": self.language_model.config.sliding_window,
"img_size": img_size,
"vision_size": vision_size,
"mm_tokens_per_image": mm_tokens_per_image,
"vision_batch_size": batch_size,
}
if continuous_batching:
Expand All @@ -840,7 +829,7 @@ def get_onnx_dynamic_axes(
lang_dynamic_axes = {}
lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
lang_dynamic_axes["vision_embeds"] = {0: "vision_batch_size", 1: "vision_size"}
lang_dynamic_axes["vision_embeds"] = {0: "vision_batch_size", 1: "mm_tokens_per_image"}
if continuous_batching:
lang_dynamic_axes["batch_index"] = {0: "batch_size"}
vision_dynamic_axes["pixel_values"] = {0: "batch_size", 2: "img_size", 3: "img_size"}
Expand Down Expand Up @@ -922,19 +911,13 @@ def get_dummy_inputs(
else:
img_size = 896

_mm_tokens_per_image = getattr(self.config, "mm_tokens_per_image", None)
if _mm_tokens_per_image:
logger.warning_once("mm_tokens_per_image is deprecated and will be removed in the next release.")
vision_size = _mm_tokens_per_image
else:
vision_size = 256

mm_tokens_per_image = getattr(self.config, "mm_tokens_per_image", 256)
# Define shapes
inputs_shapes = {}
inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
inputs_shapes["vision_embeds"] = (
1, # constants.INTERN_NUM_PATCHES,
vision_size, # constants.INTERN_FEATURE_SIZE,
mm_tokens_per_image, # constants.INTERN_FEATURE_SIZE,
self.language_model.config.hidden_size, # 5120
)
inputs_shapes["position_ids"] = (
Expand Down
8 changes: 1 addition & 7 deletions QEfficient/transformers/models/internvl/modeling_internvl.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,7 @@ def get_specializations(
raise NotImplementedError("Image Size other than 448 is not supported for Intern models yet.")

per_patch_embed_size = (img_size // self.config.vision_config.patch_size * self.config.downsample_ratio) ** 2
user_vision_size = compiler_options.pop("vision_size", None)
if user_vision_size:
if user_vision_size >= ctx_len:
raise ValueError("vision_size must be less than ctx_len")
vision_size = user_vision_size
else:
vision_size = int(batch_size * num_patches * per_patch_embed_size)
vision_size = int(batch_size * num_patches * per_patch_embed_size)
vision = [
{
"batch_size": batch_size,
Expand Down
8 changes: 1 addition & 7 deletions QEfficient/transformers/models/llama4/modeling_llama4.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,13 +1001,7 @@ def get_specializations(
* (img_size // self.config.vision_config.patch_size)
// downsample_ratio
)
user_vision_size = compiler_options.pop("vision_size", None)
if user_vision_size:
if user_vision_size >= ctx_len:
raise ValueError("vision_size must be less than ctx_len")
vision_size = user_vision_size
else:
vision_size = num_features_per_tile * max_num_tiles
vision_size = num_features_per_tile * max_num_tiles

vision = [
{
Expand Down
8 changes: 1 addition & 7 deletions QEfficient/transformers/models/llava/modeling_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,13 +237,7 @@ def get_specializations(
logger.warning("Setting img_size to be 336, as it was neither passed nor found in vision_config")
if img_size != 336 and kv_offload:
raise NotImplementedError("Image Size other than 336 is not supported for Llava models yet.")
user_vision_size = compiler_options.pop("vision_size", None)
if user_vision_size:
if user_vision_size >= ctx_len:
raise ValueError("vision_size must be less than ctx_len")
vision_size = user_vision_size
else:
vision_size = (img_size // self.config.vision_config.patch_size) ** 2
vision_size = (img_size // self.config.vision_config.patch_size) ** 2
vision = [
{
"batch_size": batch_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,13 +326,7 @@ def get_specializations(
logger.warning("Setting img_size to be 384, as it was neither passed nor found in vision_config")
if img_size != constants.GRANITEVISION_IMG_SIZE and kv_offload:
logger.warning("Image Size other than 384 is not supported for LlavaNext models yet.")
user_vision_size = compiler_options.pop("vision_size", None)
if user_vision_size:
if user_vision_size >= ctx_len:
raise ValueError("vision_size must be less than ctx_len")
vision_size = user_vision_size
else:
vision_size = constants.GRANITEVISION_FEATURE_SIZE
vision_size = constants.GRANITEVISION_FEATURE_SIZE
vision = [
{
"batch_size": batch_size,
Expand Down
12 changes: 3 additions & 9 deletions QEfficient/transformers/models/mistral3/modeling_mistral3.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,15 +370,9 @@ def get_specializations(
ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
patch_size = self.config.vision_config.patch_size
kernel_size = self.config.spatial_merge_size
user_vision_size = compiler_options.pop("vision_size", None)
if user_vision_size:
if user_vision_size >= ctx_len:
raise ValueError("vision_size must be less than ctx_len")
vision_size = user_vision_size
else:
vision_size = (
((img_size // patch_size) * (img_size // patch_size)) * (batch_size) // (kernel_size * kernel_size)
)
vision_size = (
((img_size // patch_size) * (img_size // patch_size)) * (batch_size) // (kernel_size * kernel_size)
)

vision = [
{
Expand Down
15 changes: 7 additions & 8 deletions QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,14 +1091,13 @@ def get_specializations(
"resolution."
)
else:
if vision_size * f > user_vision_size:
logger.warning_once(
f"Computed vision_size of {vision_size * f} tokens "
f"(vision_size={vision_size}, num_frames={f}) for image resolution "
f"(width={w}, height={h}) exceed the provided "
f"vision_size={user_vision_size}. "
f"Vision embedding need to be chunked during prefill."
)
assert vision_size * f <= user_vision_size, (
f"Computed vision_size of {vision_size * f} tokens "
f"(vision_size={vision_size}, num_frames={f}) for image resolution "
f"(width={w}, height={h}) cannot exceed the provided "
f"vision_size={user_vision_size}. Please adjust the image resolution or "
"increase the vision_size."
)

vision.append(
{
Expand Down
15 changes: 7 additions & 8 deletions QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,14 +972,13 @@ def get_specializations(
"resolution."
)
else:
if vision_size * f > user_vision_size:
logger.warning_once(
f"Computed vision_size of {vision_size * f} tokens "
f"(vision_size={vision_size}, num_frames={f}) for image resolution "
f"(width={w}, height={h}) exceed the provided "
f"vision_size={user_vision_size}. "
f"Vision embedding need to be chunked during prefill."
)
assert vision_size * f <= user_vision_size, (
f"Computed vision_size of {vision_size * f} tokens "
f"(vision_size={vision_size}, num_frames={f}) for image resolution "
f"(width={w}, height={h}) cannot exceed the provided "
f"vision_size={user_vision_size}. Please adjust the image resolution or "
"increase the vision_size."
)

vision.append(
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1001,14 +1001,13 @@ def get_specializations(
"resolution."
)
else:
if vision_size * f > user_vision_size:
logger.warning_once(
f"Computed vision_size of {vision_size * f} tokens "
f"(vision_size={vision_size}, num_frames={f}) for image resolution "
f"(width={w}, height={h}) exceed the provided "
f"vision_size={user_vision_size}. "
f"Vision embedding need to be chunked during prefill."
)
assert vision_size * f <= user_vision_size, (
f"Computed vision_size of {vision_size * f} tokens "
f"(vision_size={vision_size}, num_frames={f}) for image resolution "
f"(width={w}, height={h}) cannot exceed the provided "
f"vision_size={user_vision_size}. Please adjust the image resolution or "
"increase the vision_size."
)

vision.append(
{
Expand Down
Loading