Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions QEfficient/blocking/attention_blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def past_key_value_update(
position_ids: Optional[torch.LongTensor] = None,
sliding_window: Optional[int] = None,
):
cache_kwargs = {}
if past_key_value is not None:
cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
if sliding_window is not None:
Expand Down
20 changes: 6 additions & 14 deletions QEfficient/transformers/models/gemma3/modeling_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,16 +969,8 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len, dtype=None):
return past_key_values

def get_dummy_inputs(
self,
comp_ctx_lengths: Optional[List[int]] = None,
kv_offload: bool = False,
continuous_batching: bool = False,
**kwargs,
self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False
):
prefill_seq_len = kwargs.get("prefill_seq_len")
if prefill_seq_len is None:
prefill_seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
prefill_seq_len = int(prefill_seq_len)
if vis_cfg := getattr(self.config, "vision_config", None):
img_size = getattr(vis_cfg, "image_size", 896)
else:
Expand All @@ -987,15 +979,15 @@ def get_dummy_inputs(
mm_tokens_per_image = getattr(self.config, "mm_tokens_per_image", 256)
# Define shapes
inputs_shapes = {}
inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, prefill_seq_len)
inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
inputs_shapes["vision_embeds"] = (
1, # constants.INTERN_NUM_PATCHES,
mm_tokens_per_image, # constants.INTERN_FEATURE_SIZE,
self.language_model.config.hidden_size, # 5120
)
inputs_shapes["position_ids"] = (
constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
prefill_seq_len,
constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
)
inputs_shapes["pixel_values"] = (
constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
Expand All @@ -1012,8 +1004,8 @@ def get_dummy_inputs(
lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=self.config.torch_dtype)
lang_inputs["position_ids"] = (
torch.arange(prefill_seq_len, dtype=torch.int64)
.view(1, prefill_seq_len)
torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
.view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
.repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
)
lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64)
Expand All @@ -1025,7 +1017,7 @@ def get_dummy_inputs(
lang_inputs["past_key_values"] = self.get_dummy_pkv_cache(
config=self.language_model.config,
batch_size=fbs if continuous_batching else bs,
seq_len=prefill_seq_len,
seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
)

if comp_ctx_lengths is not None:
Expand Down
20 changes: 6 additions & 14 deletions QEfficient/transformers/models/internvl/modeling_internvl.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,16 +273,8 @@ def get_output_names(self, kv_offload: bool = False):
return output_names

def get_dummy_inputs(
self,
comp_ctx_lengths: Optional[List[int]] = None,
kv_offload: bool = False,
continuous_batching: bool = False,
**kwargs,
self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False
):
prefill_seq_len = kwargs.get("prefill_seq_len")
if prefill_seq_len is None:
prefill_seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
prefill_seq_len = int(prefill_seq_len)
if vis_cfg := getattr(self.config, "vision_config", None):
img_size = getattr(vis_cfg, "image_size", constants.INTERN_IMG_SIZE)
else:
Expand All @@ -301,15 +293,15 @@ def get_dummy_inputs(

# Define shapes
inputs_shapes = {}
inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, prefill_seq_len)
inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
inputs_shapes["vision_embeds"] = (
1,
computed_feature_size * constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
self.language_model.config.hidden_size,
)
inputs_shapes["position_ids"] = (
constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
prefill_seq_len,
constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
)
inputs_shapes["pixel_values"] = (
constants.INTERN_NUM_PATCHES * constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
Expand All @@ -329,8 +321,8 @@ def get_dummy_inputs(
(inputs_shapes["vision_embeds"]), dtype=self.config.vision_config.torch_dtype
)
lang_inputs["position_ids"] = (
torch.arange(prefill_seq_len, dtype=torch.int64)
.view(1, prefill_seq_len)
torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
.view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
.repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
)
lang_inputs["image_idx"] = torch.zeros((1, 1), dtype=torch.int64)
Expand All @@ -342,7 +334,7 @@ def get_dummy_inputs(
kv_cache_shape = get_padding_shape_from_config(
config=self.language_model.config,
batch_size=fbs if continuous_batching else bs,
seq_len=prefill_seq_len,
seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
)

lang_inputs["past_key_values"] = [[] for _ in range(self.language_model.config.num_hidden_layers)]
Expand Down
20 changes: 6 additions & 14 deletions QEfficient/transformers/models/llama4/modeling_llama4.py
Original file line number Diff line number Diff line change
Expand Up @@ -1185,24 +1185,16 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len):
return past_key_values

def get_dummy_inputs(
self,
comp_ctx_lengths: Optional[List[int]] = None,
kv_offload: bool = False,
continuous_batching: bool = False,
**kwargs,
self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False
):
prefill_seq_len = kwargs.get("prefill_seq_len")
if prefill_seq_len is None:
prefill_seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
prefill_seq_len = int(prefill_seq_len)
if vis_cfg := getattr(self.config, "vision_config", None):
img_size = getattr(vis_cfg, "image_size", 336)
else:
img_size = 336

# Define shapes
inputs_shapes = {}
inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, prefill_seq_len)
inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
max_num_tiles = 17
downsample_ratio = int(round(1.0 / (self.config.vision_config.pixel_shuffle_ratio**2)))
num_features_per_tile = int(
Expand All @@ -1218,7 +1210,7 @@ def get_dummy_inputs(
)
inputs_shapes["position_ids"] = (
constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
prefill_seq_len,
constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
)
inputs_shapes["pixel_values"] = (
max_num_tiles, # constants.INTERN_NUM_PATCHES,
Expand All @@ -1234,8 +1226,8 @@ def get_dummy_inputs(
lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=self.config.torch_dtype)
lang_inputs["position_ids"] = (
torch.arange(prefill_seq_len, dtype=torch.int64)
.view(1, prefill_seq_len)
torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
.view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
.repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
)
lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64)
Expand All @@ -1247,7 +1239,7 @@ def get_dummy_inputs(
past_key_values = self.get_dummy_pkv_cache(
config=self.language_model.config,
batch_size=fbs if continuous_batching else bs,
seq_len=prefill_seq_len,
seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
)

lang_inputs["past_key_values"] = [[] for _ in range(self.language_model.config.num_hidden_layers)]
Expand Down
8 changes: 2 additions & 6 deletions QEfficient/transformers/models/llava/modeling_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,6 @@ def get_dummy_inputs(
continuous_batching: bool = False,
**kwargs,
):
prefill_seq_len = kwargs.get("prefill_seq_len")
if prefill_seq_len is None:
prefill_seq_len = SEQ_LEN
prefill_seq_len = int(prefill_seq_len)
num_layers = self.config.text_config.num_hidden_layers
num_key_value_heads = self.config.text_config.num_key_value_heads
head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads
Expand All @@ -186,11 +182,11 @@ def get_dummy_inputs(
"pixel_values": torch.zeros((BS, NUM_CHANNEL, img_size, img_size), dtype=self.config.torch_dtype),
}
lang_inputs = {
"input_ids": torch.ones((BS, prefill_seq_len), dtype=torch.int64),
"input_ids": torch.ones((BS, SEQ_LEN), dtype=torch.int64),
"vision_embeds": torch.ones(
(BS, vision_size, self.model.language_model.config.hidden_size), dtype=self.config.torch_dtype
),
"attention_mask": torch.ones((BS, prefill_seq_len), dtype=torch.int64),
"attention_mask": torch.ones((BS, SEQ_LEN), dtype=torch.int64),
"image_idx": torch.zeros((1, 1), dtype=torch.int64),
}
lang_inputs["position_ids"] = lang_inputs.pop("attention_mask").cumsum(1)
Expand Down
10 changes: 4 additions & 6 deletions QEfficient/transformers/models/llava_next/modeling_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,6 @@ def get_dummy_inputs(
continuous_batching: bool = False,
**kwargs,
):
prefill_seq_len = kwargs.get("prefill_seq_len")
if prefill_seq_len is None:
prefill_seq_len = constants.GRANITEVISION_SEQ_LEN
prefill_seq_len = int(prefill_seq_len)
num_layers = self.config.text_config.num_hidden_layers
num_key_value_heads = self.config.text_config.num_key_value_heads
head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads
Expand All @@ -225,9 +221,11 @@ def get_dummy_inputs(
),
}
lang_inputs = {
"input_ids": torch.ones((constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, prefill_seq_len), dtype=torch.int64),
"input_ids": torch.ones(
(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.GRANITEVISION_SEQ_LEN), dtype=torch.int64
),
"attention_mask": torch.ones(
(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, prefill_seq_len), dtype=torch.int64
(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.GRANITEVISION_SEQ_LEN), dtype=torch.int64
),
"vision_embeds": torch.ones(
(
Expand Down
14 changes: 5 additions & 9 deletions QEfficient/transformers/models/mistral3/modeling_mistral3.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,12 +346,8 @@ def get_dummy_inputs(
continuous_batching: bool = False,
**kwargs,
):
prefill_seq_len = kwargs.get("prefill_seq_len")
if prefill_seq_len is None:
prefill_seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
prefill_seq_len = int(prefill_seq_len)
inputs_shapes = {}
inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, prefill_seq_len)
inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
height = self.config.vision_config.image_size
width = self.config.vision_config.image_size
patch_size = self.config.vision_config.patch_size
Expand All @@ -367,7 +363,7 @@ def get_dummy_inputs(
)
inputs_shapes["position_ids"] = (
constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
prefill_seq_len,
constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
)
inputs_shapes["pixel_values"] = (
constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
Expand All @@ -384,8 +380,8 @@ def get_dummy_inputs(
lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=self.config.torch_dtype)
lang_inputs["position_ids"] = (
torch.arange(prefill_seq_len, dtype=torch.int64)
.view(1, prefill_seq_len)
torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
.view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
.repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
)
lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64)
Expand All @@ -397,7 +393,7 @@ def get_dummy_inputs(
kv_cache_shape = get_padding_shape_from_config(
config=self.model.config.text_config,
batch_size=fbs if continuous_batching else bs,
seq_len=prefill_seq_len,
seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
)

lang_inputs["past_key_values"] = [[] for _ in range(self.model.language_model.config.num_hidden_layers)]
Expand Down
7 changes: 2 additions & 5 deletions QEfficient/transformers/models/mllama/modeling_mllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,12 +924,9 @@ def forward(
logits = self.lm_head(hidden_states).float()
return logits, image_idx, outputs.past_key_values, pixel_values

def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, **kwargs):
def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False):
BS = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
seq_len = kwargs.get("prefill_seq_len")
if seq_len is None:
seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
SEQ_LEN = int(seq_len)
SEQ_LEN = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
CTX_LEN = constants.ONNX_EXPORT_CTX_LEN

txt_cfg = self.config.get_text_config()
Expand Down
12 changes: 7 additions & 5 deletions QEfficient/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -1268,6 +1268,9 @@ def __init__(
)
self.model = model
self.config = model.config
# Propagate qaic_config to the full model so helpers like _is_single_shot_mode
# can detect the mode when get_output_names/get_dummy_inputs are called on it.
model.qaic_config = qaic_config

self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs)
self.lang_model = QEffCausalLMForTextImageToTextModel(model, qaic_config=qaic_config, **kwargs)
Expand Down Expand Up @@ -1367,17 +1370,12 @@ def export(
List[str]
A list containing the paths to the generated ONNX graph files for both components.
"""
dummy_inputs_kwargs = {}
if prefill_seq_len is not None:
dummy_inputs_kwargs["prefill_seq_len"] = int(prefill_seq_len)

# TODO This is a temporary change as continous batching is enabled only for few models. Once support is added for all the models this exception handing can be removed.
try:
inputs = self.model.get_dummy_inputs(
kv_offload=True,
continuous_batching=self.continuous_batching,
comp_ctx_lengths=self.comp_ctx_lengths_decode,
**dummy_inputs_kwargs,
)
dynamic_axes = self.model.get_onnx_dynamic_axes(
kv_offload=True,
Expand Down Expand Up @@ -1678,6 +1676,10 @@ def compile(
elif prefill_seq_len == 1:
specializations = specializations["lang"][-1:]
qpc_key = "lang_decode_qpc_path"
elif prefill_seq_len is not None and ctx_len is not None and prefill_seq_len == ctx_len:
# Single-shot mode (e.g. reranker): no decode steps, only prefill kernel needed.
specializations = specializations["lang"][:1]
qpc_key = "lang_qpc_path"
else:
specializations = specializations["lang"]
qpc_key = "lang_qpc_path"
Expand Down
14 changes: 5 additions & 9 deletions QEfficient/transformers/models/molmo/modeling_molmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,13 +931,9 @@ def get_dummy_inputs(
continuous_batching: bool = False,
**kwargs,
):
prefill_seq_len = kwargs.get("prefill_seq_len")
if prefill_seq_len is None:
prefill_seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
prefill_seq_len = int(prefill_seq_len)
inputs_shapes = {}
inputs_shapes_lang = {}
inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, prefill_seq_len)
inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)

inputs_shapes["vision_embeds"] = (
constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
Expand All @@ -946,7 +942,7 @@ def get_dummy_inputs(
)
inputs_shapes["position_ids"] = (
constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
prefill_seq_len,
constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
)
inputs_shapes["pixel_values"] = (
constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
Expand Down Expand Up @@ -980,8 +976,8 @@ def get_dummy_inputs(
lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=self.config.torch_dtype)
lang_inputs["position_ids"] = (
torch.arange(prefill_seq_len, dtype=torch.int64)
.view(1, prefill_seq_len)
torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
.view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
.repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
)
lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64)
Expand All @@ -993,7 +989,7 @@ def get_dummy_inputs(
kv_cache_shape = get_padding_shape_from_config(
config=self.config,
batch_size=fbs if continuous_batching else bs,
seq_len=prefill_seq_len,
seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
)

lang_inputs["past_key_values"] = [[] for _ in range(self.model.config.n_layers)]
Expand Down
Loading
Loading