diff --git a/QEfficient/blocking/attention_blocking.py b/QEfficient/blocking/attention_blocking.py index b75342013..cae884081 100644 --- a/QEfficient/blocking/attention_blocking.py +++ b/QEfficient/blocking/attention_blocking.py @@ -81,6 +81,7 @@ def past_key_value_update( position_ids: Optional[torch.LongTensor] = None, sliding_window: Optional[int] = None, ): + cache_kwargs = {} if past_key_value is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} if sliding_window is not None: diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index a3e9257a7..35d9c07cf 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -969,16 +969,8 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len, dtype=None): return past_key_values def get_dummy_inputs( - self, - comp_ctx_lengths: Optional[List[int]] = None, - kv_offload: bool = False, - continuous_batching: bool = False, - **kwargs, + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False ): - prefill_seq_len = kwargs.get("prefill_seq_len") - if prefill_seq_len is None: - prefill_seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN - prefill_seq_len = int(prefill_seq_len) if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", 896) else: @@ -987,7 +979,7 @@ def get_dummy_inputs( mm_tokens_per_image = getattr(self.config, "mm_tokens_per_image", 256) # Define shapes inputs_shapes = {} - inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, prefill_seq_len) + inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) inputs_shapes["vision_embeds"] = ( 1, # constants.INTERN_NUM_PATCHES, mm_tokens_per_image, # constants.INTERN_FEATURE_SIZE, @@ -995,7 +987,7 @@ def get_dummy_inputs( ) inputs_shapes["position_ids"] = ( constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, - prefill_seq_len, + constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) inputs_shapes["pixel_values"] = ( constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, @@ -1012,8 +1004,8 @@ def get_dummy_inputs( lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64) lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=self.config.torch_dtype) lang_inputs["position_ids"] = ( - torch.arange(prefill_seq_len, dtype=torch.int64) - .view(1, prefill_seq_len) + torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64) + .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) ) lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) @@ -1025,7 +1017,7 @@ def get_dummy_inputs( lang_inputs["past_key_values"] = self.get_dummy_pkv_cache( config=self.language_model.config, batch_size=fbs if continuous_batching else bs, - seq_len=prefill_seq_len, + seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) if comp_ctx_lengths is not None: diff --git a/QEfficient/transformers/models/internvl/modeling_internvl.py b/QEfficient/transformers/models/internvl/modeling_internvl.py index 563c42e25..821381ac0 100644 --- a/QEfficient/transformers/models/internvl/modeling_internvl.py +++ b/QEfficient/transformers/models/internvl/modeling_internvl.py @@ -273,16 +273,8 @@ def get_output_names(self, kv_offload: bool = False): return output_names def get_dummy_inputs( - self, - comp_ctx_lengths: Optional[List[int]] = None, - kv_offload: bool = False, - continuous_batching: bool = False, - **kwargs, + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False ): - prefill_seq_len = kwargs.get("prefill_seq_len") - if prefill_seq_len is None: - prefill_seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN - prefill_seq_len = int(prefill_seq_len) if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", constants.INTERN_IMG_SIZE) else: @@ -301,7 +293,7 @@ def get_dummy_inputs( # Define shapes inputs_shapes = {} - inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, prefill_seq_len) + inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) inputs_shapes["vision_embeds"] = ( 1, computed_feature_size * constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, @@ -309,7 +301,7 @@ def get_dummy_inputs( ) inputs_shapes["position_ids"] = ( constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, - prefill_seq_len, + constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) inputs_shapes["pixel_values"] = ( constants.INTERN_NUM_PATCHES * constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, @@ -329,8 +321,8 @@ def get_dummy_inputs( (inputs_shapes["vision_embeds"]), dtype=self.config.vision_config.torch_dtype ) lang_inputs["position_ids"] = ( - torch.arange(prefill_seq_len, dtype=torch.int64) - .view(1, prefill_seq_len) + torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64) + .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) ) lang_inputs["image_idx"] = torch.zeros((1, 1), dtype=torch.int64) @@ -342,7 +334,7 @@ def get_dummy_inputs( kv_cache_shape = get_padding_shape_from_config( config=self.language_model.config, batch_size=fbs if continuous_batching else bs, - seq_len=prefill_seq_len, + seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) lang_inputs["past_key_values"] = [[] for _ in range(self.language_model.config.num_hidden_layers)] diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py index 2cf5dbb2e..7f90262be 100644 --- a/QEfficient/transformers/models/llama4/modeling_llama4.py +++ b/QEfficient/transformers/models/llama4/modeling_llama4.py @@ -1185,16 +1185,8 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len): return past_key_values def get_dummy_inputs( - self, - comp_ctx_lengths: Optional[List[int]] = None, - kv_offload: bool = False, - continuous_batching: bool = False, - **kwargs, + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False ): - prefill_seq_len = kwargs.get("prefill_seq_len") - if prefill_seq_len is None: - prefill_seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN - prefill_seq_len = int(prefill_seq_len) if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", 336) else: @@ -1202,7 +1194,7 @@ def get_dummy_inputs( # Define shapes inputs_shapes = {} - inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, prefill_seq_len) + inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) max_num_tiles = 17 downsample_ratio = int(round(1.0 / (self.config.vision_config.pixel_shuffle_ratio**2))) num_features_per_tile = int( @@ -1218,7 +1210,7 @@ def get_dummy_inputs( ) inputs_shapes["position_ids"] = ( constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, - prefill_seq_len, + constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) inputs_shapes["pixel_values"] = ( max_num_tiles, # constants.INTERN_NUM_PATCHES, @@ -1234,8 +1226,8 @@ def get_dummy_inputs( lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64) lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=self.config.torch_dtype) lang_inputs["position_ids"] = ( - torch.arange(prefill_seq_len, dtype=torch.int64) - .view(1, prefill_seq_len) + torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64) + .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) ) lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) @@ -1247,7 +1239,7 @@ def get_dummy_inputs( past_key_values = self.get_dummy_pkv_cache( config=self.language_model.config, batch_size=fbs if continuous_batching else bs, - seq_len=prefill_seq_len, + seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) lang_inputs["past_key_values"] = [[] for _ in range(self.language_model.config.num_hidden_layers)] diff --git a/QEfficient/transformers/models/llava/modeling_llava.py b/QEfficient/transformers/models/llava/modeling_llava.py index 88bb5e102..3fdfd11b9 100644 --- a/QEfficient/transformers/models/llava/modeling_llava.py +++ b/QEfficient/transformers/models/llava/modeling_llava.py @@ -168,10 +168,6 @@ def get_dummy_inputs( continuous_batching: bool = False, **kwargs, ): - prefill_seq_len = kwargs.get("prefill_seq_len") - if prefill_seq_len is None: - prefill_seq_len = SEQ_LEN - prefill_seq_len = int(prefill_seq_len) num_layers = self.config.text_config.num_hidden_layers num_key_value_heads = self.config.text_config.num_key_value_heads head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads @@ -186,11 +182,11 @@ def get_dummy_inputs( "pixel_values": torch.zeros((BS, NUM_CHANNEL, img_size, img_size), dtype=self.config.torch_dtype), } lang_inputs = { - "input_ids": torch.ones((BS, prefill_seq_len), dtype=torch.int64), + "input_ids": torch.ones((BS, SEQ_LEN), dtype=torch.int64), "vision_embeds": torch.ones( (BS, vision_size, self.model.language_model.config.hidden_size), dtype=self.config.torch_dtype ), - "attention_mask": torch.ones((BS, prefill_seq_len), dtype=torch.int64), + "attention_mask": torch.ones((BS, SEQ_LEN), dtype=torch.int64), "image_idx": torch.zeros((1, 1), dtype=torch.int64), } lang_inputs["position_ids"] = lang_inputs.pop("attention_mask").cumsum(1) diff --git a/QEfficient/transformers/models/llava_next/modeling_llava_next.py b/QEfficient/transformers/models/llava_next/modeling_llava_next.py index 342269ce5..c2a913700 100755 --- a/QEfficient/transformers/models/llava_next/modeling_llava_next.py +++ b/QEfficient/transformers/models/llava_next/modeling_llava_next.py @@ -195,10 +195,6 @@ def get_dummy_inputs( continuous_batching: bool = False, **kwargs, ): - prefill_seq_len = kwargs.get("prefill_seq_len") - if prefill_seq_len is None: - prefill_seq_len = constants.GRANITEVISION_SEQ_LEN - prefill_seq_len = int(prefill_seq_len) num_layers = self.config.text_config.num_hidden_layers num_key_value_heads = self.config.text_config.num_key_value_heads head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads @@ -225,9 +221,11 @@ def get_dummy_inputs( ), } lang_inputs = { - "input_ids": torch.ones((constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, prefill_seq_len), dtype=torch.int64), + "input_ids": torch.ones( + (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.GRANITEVISION_SEQ_LEN), dtype=torch.int64 + ), "attention_mask": torch.ones( - (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, prefill_seq_len), dtype=torch.int64 + (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.GRANITEVISION_SEQ_LEN), dtype=torch.int64 ), "vision_embeds": torch.ones( ( diff --git a/QEfficient/transformers/models/mistral3/modeling_mistral3.py b/QEfficient/transformers/models/mistral3/modeling_mistral3.py index 628d1dee2..9c3735332 100644 --- a/QEfficient/transformers/models/mistral3/modeling_mistral3.py +++ b/QEfficient/transformers/models/mistral3/modeling_mistral3.py @@ -346,12 +346,8 @@ def get_dummy_inputs( continuous_batching: bool = False, **kwargs, ): - prefill_seq_len = kwargs.get("prefill_seq_len") - if prefill_seq_len is None: - prefill_seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN - prefill_seq_len = int(prefill_seq_len) inputs_shapes = {} - inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, prefill_seq_len) + inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) height = self.config.vision_config.image_size width = self.config.vision_config.image_size patch_size = self.config.vision_config.patch_size @@ -367,7 +363,7 @@ def get_dummy_inputs( ) inputs_shapes["position_ids"] = ( constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, - prefill_seq_len, + constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) inputs_shapes["pixel_values"] = ( constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, @@ -384,8 +380,8 @@ def get_dummy_inputs( lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64) lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=self.config.torch_dtype) lang_inputs["position_ids"] = ( - torch.arange(prefill_seq_len, dtype=torch.int64) - .view(1, prefill_seq_len) + torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64) + .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) ) lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) @@ -397,7 +393,7 @@ def get_dummy_inputs( kv_cache_shape = get_padding_shape_from_config( config=self.model.config.text_config, batch_size=fbs if continuous_batching else bs, - seq_len=prefill_seq_len, + seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) lang_inputs["past_key_values"] = [[] for _ in range(self.model.language_model.config.num_hidden_layers)] diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index 45649662a..d9310c02e 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -924,12 +924,9 @@ def forward( logits = self.lm_head(hidden_states).float() return logits, image_idx, outputs.past_key_values, pixel_values - def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, **kwargs): + def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): BS = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE - seq_len = kwargs.get("prefill_seq_len") - if seq_len is None: - seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN - SEQ_LEN = int(seq_len) + SEQ_LEN = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN CTX_LEN = constants.ONNX_EXPORT_CTX_LEN txt_cfg = self.config.get_text_config() diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 0b1e3702b..4f5ad61d3 100755 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1268,6 +1268,9 @@ def __init__( ) self.model = model self.config = model.config + # Propagate qaic_config to the full model so helpers like _is_single_shot_mode + # can detect the mode when get_output_names/get_dummy_inputs are called on it. + model.qaic_config = qaic_config self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) self.lang_model = QEffCausalLMForTextImageToTextModel(model, qaic_config=qaic_config, **kwargs) @@ -1367,17 +1370,12 @@ def export( List[str] A list containing the paths to the generated ONNX graph files for both components. """ - dummy_inputs_kwargs = {} - if prefill_seq_len is not None: - dummy_inputs_kwargs["prefill_seq_len"] = int(prefill_seq_len) - # TODO This is a temporary change as continous batching is enabled only for few models. Once support is added for all the models this exception handing can be removed. try: inputs = self.model.get_dummy_inputs( kv_offload=True, continuous_batching=self.continuous_batching, comp_ctx_lengths=self.comp_ctx_lengths_decode, - **dummy_inputs_kwargs, ) dynamic_axes = self.model.get_onnx_dynamic_axes( kv_offload=True, @@ -1678,6 +1676,10 @@ def compile( elif prefill_seq_len == 1: specializations = specializations["lang"][-1:] qpc_key = "lang_decode_qpc_path" + elif prefill_seq_len is not None and ctx_len is not None and prefill_seq_len == ctx_len: + # Single-shot mode (e.g. reranker): no decode steps, only prefill kernel needed. + specializations = specializations["lang"][:1] + qpc_key = "lang_qpc_path" else: specializations = specializations["lang"] qpc_key = "lang_qpc_path" diff --git a/QEfficient/transformers/models/molmo/modeling_molmo.py b/QEfficient/transformers/models/molmo/modeling_molmo.py index d59ca4e01..3eefba47f 100644 --- a/QEfficient/transformers/models/molmo/modeling_molmo.py +++ b/QEfficient/transformers/models/molmo/modeling_molmo.py @@ -931,13 +931,9 @@ def get_dummy_inputs( continuous_batching: bool = False, **kwargs, ): - prefill_seq_len = kwargs.get("prefill_seq_len") - if prefill_seq_len is None: - prefill_seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN - prefill_seq_len = int(prefill_seq_len) inputs_shapes = {} inputs_shapes_lang = {} - inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, prefill_seq_len) + inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) inputs_shapes["vision_embeds"] = ( constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, @@ -946,7 +942,7 @@ def get_dummy_inputs( ) inputs_shapes["position_ids"] = ( constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, - prefill_seq_len, + constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) inputs_shapes["pixel_values"] = ( constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, @@ -980,8 +976,8 @@ def get_dummy_inputs( lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64) lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=self.config.torch_dtype) lang_inputs["position_ids"] = ( - torch.arange(prefill_seq_len, dtype=torch.int64) - .view(1, prefill_seq_len) + torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64) + .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) ) lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) @@ -993,7 +989,7 @@ def get_dummy_inputs( kv_cache_shape = get_padding_shape_from_config( config=self.config, batch_size=fbs if continuous_batching else bs, - seq_len=prefill_seq_len, + seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) lang_inputs["past_key_values"] = [[] for _ in range(self.model.config.n_layers)] diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 357c4af16..dd70a31c9 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -831,12 +831,8 @@ def get_dummy_inputs( continuous_batching: bool = False, **kwargs, ): - prefill_seq_len = kwargs.get("prefill_seq_len", constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) - if prefill_seq_len is None: - prefill_seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN - prefill_seq_len = int(prefill_seq_len) inputs_shapes = {} - inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, prefill_seq_len) + inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) vision_size = 3577 inputs_shapes["vision_embeds"] = ( @@ -848,7 +844,7 @@ def get_dummy_inputs( inputs_shapes["position_ids"] = ( 3, constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, - prefill_seq_len, + constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) inputs_shapes["pixel_values"] = (14308, 1176) inputs_shapes["image_idx"] = (1, 1) @@ -862,8 +858,8 @@ def get_dummy_inputs( lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=self.config.torch_dtype) lang_inputs["position_ids"] = ( ( - torch.arange(prefill_seq_len, dtype=torch.int64) - .view(1, prefill_seq_len) + torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64) + .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) ) .unsqueeze(0) @@ -878,7 +874,7 @@ def get_dummy_inputs( kv_cache_shape = get_padding_shape_from_config( config=self.model.config.text_config, batch_size=fbs if continuous_batching else bs, - seq_len=prefill_seq_len, + seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) lang_inputs["past_key_values"] = [[] for _ in range(self.model.config.text_config.num_hidden_layers)] diff --git a/QEfficient/transformers/models/qwen3_vl/_embedding_utils.py b/QEfficient/transformers/models/qwen3_vl/_embedding_utils.py index ca0316371..bce751db9 100644 --- a/QEfficient/transformers/models/qwen3_vl/_embedding_utils.py +++ b/QEfficient/transformers/models/qwen3_vl/_embedding_utils.py @@ -257,9 +257,7 @@ def _collect_contexts(self, inputs: List[Dict[str, Any]]): return contexts, max_prompt_len, max_grid_h, max_grid_w - def get_compile_specs( - self, inputs: List[Dict[str, Any]], ctx_len: int, prefill_seq_len: int = None - ) -> Dict[str, int]: + def get_compile_specs(self, inputs: List[Dict[str, Any]], prefill_seq_len: int = None) -> Dict[str, int]: """Compute compile-time spec values for the current input batch.""" _, max_prompt_len, max_grid_h, max_grid_w = self._collect_contexts(inputs) if max_prompt_len == 0: @@ -275,9 +273,10 @@ def get_compile_specs( height = max_grid_h * patch_size width = max_grid_w * patch_size + # ctx_len == prefill_seq_len always: embedding is single-shot prefill, no decode steps. return { "prefill_seq_len": target_prefill_seq_len, - "ctx_len": int(ctx_len), + "ctx_len": target_prefill_seq_len, "img_size": max(height, width), "height": height, "width": width, diff --git a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 0f6ab210d..cd39a98f0 100644 --- a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -57,6 +57,18 @@ def _should_export_embedding_output(module) -> bool: return False +def _is_single_shot_mode(module) -> bool: + """True when model is single-shot prefill only (reranker/embedding) — no KV cache needed.""" + for holder in (module, getattr(module, "model", None)): + if holder is None: + continue + qaic_config = getattr(holder, "qaic_config", None) + if isinstance(qaic_config, dict): + if qaic_config.get("no_kv_cache", False) or qaic_config.get("export_embedding", False): + return True + return False + + def qeff_apply_interleaved_mrope(freqs, mrope_section): """Apply interleaved MRoPE to 3D rotary embeddings. Reorganizes frequency layout from chunked [TTT...HHH...WWW] to @@ -549,7 +561,9 @@ def forward( ) -> Union[Tuple, BaseModelOutputWithPast]: if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if self.config.use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = False + effective_use_cache = use_cache if use_cache is not None else self.config.use_cache + if effective_use_cache and not isinstance(past_key_values, Cache): return_legacy_cache = True past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) @@ -567,7 +581,11 @@ def forward( elif position_ids.dim() == 2: position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) - target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else (past_seen_tokens if past_seen_tokens > 0 else inputs_embeds.shape[1]) + ) causal_mask = _create_causal_mask( position_ids=position_ids[0], target_length=target_length, sliding_window=None ) @@ -696,7 +714,7 @@ def forward( deepstack_features, position_ids, image_idx, - past_key_values, + past_key_values=None, batch_index: Optional[torch.LongTensor] = None, comp_ctx_lengths: Optional[List[int]] = None, ): @@ -705,7 +723,7 @@ def forward( selected = input_ids == self.model.config.image_token_id indices1 = selected.to(torch.int64).cumsum(1) - 1 indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) - indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) + indices0 = torch.arange(selected.shape[0], device=selected.device).view(-1, 1) image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1] num_features, bs, split_size, C = deepstack_features.shape @@ -723,13 +741,14 @@ def forward( visual_pos_masks = image_mask deepstack_visual_embeds = deepstack_features_expanded + single_shot = _is_single_shot_mode(self) outputs = self.language_model( inputs_embeds=inputs_embeds, position_ids=position_ids, - past_key_values=past_key_values, + past_key_values=None if single_shot else past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, - use_cache=True, + use_cache=not single_shot, visual_pos_masks=visual_pos_masks, deepstack_visual_embeds=deepstack_visual_embeds, ) @@ -737,6 +756,10 @@ def forward( hidden_states = outputs.last_hidden_state[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index] logits = self.model.lm_head(hidden_states) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) + if single_shot: + if _should_export_embedding_output(self): + return logits, vision_embeds, deepstack_features, image_idx, hidden_states + return logits, vision_embeds, deepstack_features, image_idx if _should_export_embedding_output(self): return logits, vision_embeds, deepstack_features, image_idx, hidden_states, outputs.past_key_values return logits, vision_embeds, deepstack_features, image_idx, outputs.past_key_values @@ -847,13 +870,8 @@ def get_dummy_inputs( continuous_batching: bool = False, **kwargs, ): - prefill_seq_len = kwargs.get("prefill_seq_len", constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) - if prefill_seq_len is None: - prefill_seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN - prefill_seq_len = int(prefill_seq_len) - inputs_shapes = {} - inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, prefill_seq_len) + inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) # vision_size = 1024 vision_size = 187 inputs_shapes["vision_embeds"] = ( @@ -865,7 +883,7 @@ def get_dummy_inputs( inputs_shapes["position_ids"] = ( 3, constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, - prefill_seq_len, + constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) inputs_shapes["pixel_values"] = (748, 1536) inputs_shapes["image_idx"] = (1, 1) @@ -889,8 +907,8 @@ def get_dummy_inputs( ) lang_inputs["position_ids"] = ( ( - torch.arange(prefill_seq_len, dtype=torch.int64) - .view(1, prefill_seq_len) + torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64) + .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) ) .unsqueeze(0) @@ -908,7 +926,7 @@ def get_dummy_inputs( kv_cache_shape = get_padding_shape_from_config( config=self.model.config.text_config, batch_size=fbs if continuous_batching else bs, - seq_len=prefill_seq_len, + seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) lang_inputs["past_key_values"] = [[] for _ in range(self.model.config.text_config.num_hidden_layers)] @@ -925,6 +943,8 @@ def get_dummy_inputs( lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.int64) inputs = {} if kv_offload: + if _is_single_shot_mode(self): + lang_inputs.pop("past_key_values") inputs["vision"] = vision_inputs inputs["lang"] = lang_inputs else: @@ -1106,6 +1126,11 @@ def smart_resize( lang = [lang_prefill, lang_decode] + # Single-shot (reranker/embedding): no KV cache → ctx_len not referenced in ONNX + if _is_single_shot_mode(self): + for spec in lang: + spec.pop("ctx_len", None) + specializations = {} if kv_offload: @@ -1154,6 +1179,10 @@ def get_onnx_dynamic_axes( dynamic_axes = {} if kv_offload: + if _is_single_shot_mode(self): + for i in range(num_layers): + lang_dynamic_axes.pop(f"past_key.{i}", None) + lang_dynamic_axes.pop(f"past_value.{i}", None) dynamic_axes["vision"] = vision_dynamic_axes dynamic_axes["lang"] = lang_dynamic_axes else: @@ -1171,11 +1200,25 @@ def get_output_names(self, kv_offload: bool = False): output_names = {} if kv_offload: - lang_output_names.insert(1, "vision_embeds_RetainedState") - lang_output_names.insert(2, "image_idx_output") - lang_output_names.insert(2, "deepstack_features_RetainedState") - if _should_export_embedding_output(self): - lang_output_names.insert(4, "embedding_output") + if _is_single_shot_mode(self): + # Single-shot: keep vision/deepstack retained states, drop KV retained states. + # Order matches QEffQwen3VLDecoderWrapper.forward single-shot return: + # reranker: (logits, vision_embeds, deepstack_features, image_idx) + # embedding: (logits, vision_embeds, deepstack_features, image_idx, hidden_states) + lang_output_names = [ + "logits", + "vision_embeds_RetainedState", + "deepstack_features_RetainedState", + "image_idx_output", + ] + if _should_export_embedding_output(self): + lang_output_names.append("embedding_output") # hidden_states is output[4] + else: + lang_output_names.insert(1, "vision_embeds_RetainedState") + lang_output_names.insert(2, "image_idx_output") + lang_output_names.insert(2, "deepstack_features_RetainedState") + if _should_export_embedding_output(self): + lang_output_names.insert(4, "embedding_output") output_names["vision"] = vision_output_names output_names["lang"] = lang_output_names else: diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index dc741969a..317c5ee26 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -867,12 +867,8 @@ def get_dummy_inputs( continuous_batching: bool = False, **kwargs, ): - prefill_seq_len = kwargs.get("prefill_seq_len") - if prefill_seq_len is None: - prefill_seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN - prefill_seq_len = int(prefill_seq_len) inputs_shapes = {} - inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, prefill_seq_len) + inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) # vision_size = 1024 vision_size = 187 inputs_shapes["vision_embeds"] = ( @@ -884,7 +880,7 @@ def get_dummy_inputs( inputs_shapes["position_ids"] = ( 3, constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, - prefill_seq_len, + constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) inputs_shapes["pixel_values"] = (748, 1536) inputs_shapes["image_idx"] = (1, 1) @@ -908,8 +904,8 @@ def get_dummy_inputs( ) lang_inputs["position_ids"] = ( ( - torch.arange(prefill_seq_len, dtype=torch.int64) - .view(1, prefill_seq_len) + torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64) + .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) ) .unsqueeze(0) @@ -927,7 +923,7 @@ def get_dummy_inputs( kv_cache_shape = get_padding_shape_from_config( config=self.model.config.text_config, batch_size=fbs if continuous_batching else bs, - seq_len=prefill_seq_len, + seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) lang_inputs["past_key_values"] = [[] for _ in range(self.model.config.text_config.num_hidden_layers)] diff --git a/QEfficient/transformers/models/whisper/modeling_whisper.py b/QEfficient/transformers/models/whisper/modeling_whisper.py index 1bdcd07ad..4c3016628 100644 --- a/QEfficient/transformers/models/whisper/modeling_whisper.py +++ b/QEfficient/transformers/models/whisper/modeling_whisper.py @@ -792,10 +792,9 @@ def forward( def get_dummy_inputs( self, - **kwargs, ): bs = 1 - seq_len = int(kwargs.get("prefill_seq_len", ONNX_EXPORT_EXAMPLE_SEQ_LEN)) + seq_len = ONNX_EXPORT_EXAMPLE_SEQ_LEN encoder_seq_len = self.config.max_source_positions encoder_feature_count = self.config.num_mel_bins num_key_value_heads = self.config.decoder_attention_heads diff --git a/examples/embeddings/qwen3vl/README.md b/examples/embeddings/qwen3vl/README.md index cff14908c..6f89fade0 100644 --- a/examples/embeddings/qwen3vl/README.md +++ b/examples/embeddings/qwen3vl/README.md @@ -40,7 +40,6 @@ With compile parameters: ```bash python examples/embeddings/qwen3vl/qwen3_vl_embedding.py \ --model-name Qwen/Qwen3-VL-Embedding-8B \ - --ctx-len 2048 \ --num-cores 16 \ --num-devices 1 \ --compile-prefill-seq-len 4096 \ diff --git a/examples/embeddings/qwen3vl/qwen3_vl_embedding.py b/examples/embeddings/qwen3vl/qwen3_vl_embedding.py index bd707ffb0..b3124352a 100644 --- a/examples/embeddings/qwen3vl/qwen3_vl_embedding.py +++ b/examples/embeddings/qwen3vl/qwen3_vl_embedding.py @@ -24,7 +24,6 @@ from QEfficient.transformers.models.qwen3_vl._embedding_utils import configure_embedding_model_config DEFAULT_MODEL_NAME = "Qwen/Qwen3-VL-Embedding-8B" -DEFAULT_CTX_LEN = 2048 DEFAULT_NUM_CORES = 16 DEFAULT_NUM_DEVICES = 1 DEFAULT_NUM_HIDDEN_LAYERS = 36 @@ -36,7 +35,6 @@ def parse_args() -> argparse.Namespace: """Parse command-line arguments for AI100 compile/inference knobs.""" parser = argparse.ArgumentParser(description="Qwen3-VL embedding example.") parser.add_argument("--model-name", type=str, default=DEFAULT_MODEL_NAME) - parser.add_argument("--ctx-len", type=int, default=DEFAULT_CTX_LEN, help="Context length used at compile time.") parser.add_argument("--num-cores", type=int, default=DEFAULT_NUM_CORES, help="Number of AI100 cores.") parser.add_argument("--num-devices", type=int, default=DEFAULT_NUM_DEVICES, help="Number of AI100 devices.") parser.add_argument( @@ -121,7 +119,6 @@ def main() -> None: # 3) Derive compile requirements from current payload. compile_specs = embedder.get_compile_specs( inputs=model_inputs, - ctx_len=args.ctx_len, prefill_seq_len=args.compile_prefill_seq_len, ) diff --git a/examples/reranker/qwen3vl/README.md b/examples/reranker/qwen3vl/README.md index d9d96645a..74bc9d4a2 100644 --- a/examples/reranker/qwen3vl/README.md +++ b/examples/reranker/qwen3vl/README.md @@ -49,7 +49,6 @@ With compile parameters: ```bash python examples/reranker/qwen3vl/qwen3_vl_reranker.py \ --model-name Qwen/Qwen3-VL-Reranker-2B \ - --ctx-len 2048 \ --num-cores 16 \ --num-devices 1 \ --compile-prefill-seq-len 4096 \ diff --git a/examples/reranker/qwen3vl/qwen3_vl_reranker.py b/examples/reranker/qwen3vl/qwen3_vl_reranker.py index 01884d0d0..a3d05c3d2 100644 --- a/examples/reranker/qwen3vl/qwen3_vl_reranker.py +++ b/examples/reranker/qwen3vl/qwen3_vl_reranker.py @@ -30,7 +30,6 @@ def parse_args() -> argparse.Namespace: """Parse command-line arguments for AI100 compile/inference knobs.""" parser = argparse.ArgumentParser(description="Qwen3-VL reranker example.") parser.add_argument("--model-name", type=str, default="Qwen/Qwen3-VL-Reranker-2B") - parser.add_argument("--ctx-len", type=int, default=2048, help="Context length used at compile time.") parser.add_argument("--num-cores", type=int, default=16, help="Number of AI100 cores.") parser.add_argument("--num-devices", type=int, default=1, help="Number of AI100 devices.") parser.add_argument( @@ -97,6 +96,7 @@ def main() -> None: kv_offload=True, trust_remote_code=True, config=config, + qaic_config={"no_kv_cache": True}, ) # 2) Build reranker helper and reference payload. @@ -106,7 +106,6 @@ def main() -> None: # 3) Derive compile requirements from current payload. compile_specs = reranker.get_compile_specs( inputs=inputs, - ctx_len=args.ctx_len, prefill_seq_len=args.compile_prefill_seq_len, ) diff --git a/examples/reranker/qwen3vl/reranker_model.py b/examples/reranker/qwen3vl/reranker_model.py index 33e73b05f..32c4e65ea 100644 --- a/examples/reranker/qwen3vl/reranker_model.py +++ b/examples/reranker/qwen3vl/reranker_model.py @@ -173,7 +173,7 @@ def _collect_contexts(self, inputs: Dict): return prepared_contexts, max_prompt_len, max_grid_h, max_grid_w - def get_compile_specs(self, inputs: Dict, ctx_len: int, prefill_seq_len: int = None) -> Dict[str, int]: + def get_compile_specs(self, inputs: Dict, prefill_seq_len: int = None) -> Dict[str, int]: """Return compile parameters required for this input batch.""" _, max_prompt_len, max_grid_h, max_grid_w = self._collect_contexts(inputs) if max_prompt_len == 0: @@ -189,9 +189,10 @@ def get_compile_specs(self, inputs: Dict, ctx_len: int, prefill_seq_len: int = N height = max_grid_h * patch_size width = max_grid_w * patch_size + # ctx_len == prefill_seq_len always: reranker is single-shot prefill, no decode steps. return { "prefill_seq_len": target_prefill_seq_len, - "ctx_len": int(ctx_len), + "ctx_len": target_prefill_seq_len, "img_size": max(height, width), "height": height, "width": width, diff --git a/tests/configs/image_text_model_configs.json b/tests/configs/image_text_model_configs.json index 85df55997..d98d0e08a 100644 --- a/tests/configs/image_text_model_configs.json +++ b/tests/configs/image_text_model_configs.json @@ -724,42 +724,11 @@ } } ], - "image_text_reranker_models": [ - { - "model_name": "Qwen/Qwen3-VL-Reranker-2B", - "model_type": "qwen3_vl", - "batch_size": 1, - "prompt_len": 128, - "ctx_len": 1024, - "img_size": 1540, - "img_url": "https://picsum.photos/id/237/536/354", - "instruction": "Retrieve candidates relevant to the query.", - "query_text": "A woman playing with her dog on a beach at sunset.", - "document_text": "A woman and her dog spend time together on a beach during sunset.", - "num_layers": 1, - "additional_params": {} - }, - { - "model_name": "Qwen/Qwen3-VL-Reranker-8B", - "model_type": "qwen3_vl", - "batch_size": 1, - "prompt_len": 128, - "ctx_len": 1024, - "img_size": 1540, - "img_url": "https://picsum.photos/id/237/536/354", - "instruction": "Retrieve candidates relevant to the query.", - "query_text": "A woman playing with her dog on a beach at sunset.", - "document_text": "A woman and her dog spend time together on a beach during sunset.", - "num_layers": 1, - "additional_params": {} - } - ], "image_text_embedding_models": [ { "model_name": "Qwen/Qwen3-VL-Embedding-8B", "model_type": "qwen3_vl", "batch_size": 1, - "ctx_len": 2048, "num_layers": 1, "vision_depth": 9, "deepstack_index": 8, diff --git a/tests/configs/reranker_model_configs.json b/tests/configs/reranker_model_configs.json new file mode 100644 index 000000000..4427b9da0 --- /dev/null +++ b/tests/configs/reranker_model_configs.json @@ -0,0 +1,28 @@ +[ + { + "model_name": "Qwen/Qwen3-VL-Reranker-2B", + "model_type": "qwen3_vl", + "batch_size": 1, + "prompt_len": 128, + "img_size": 1540, + "img_url": "https://picsum.photos/id/237/536/354", + "instruction": "Retrieve candidates relevant to the query.", + "query_text": "A woman playing with her dog on a beach at sunset.", + "document_text": "A woman and her dog spend time together on a beach during sunset.", + "num_layers": 1, + "additional_params": {} + }, + { + "model_name": "Qwen/Qwen3-VL-Reranker-8B", + "model_type": "qwen3_vl", + "batch_size": 1, + "prompt_len": 128, + "img_size": 1540, + "img_url": "https://picsum.photos/id/237/536/354", + "instruction": "Retrieve candidates relevant to the query.", + "query_text": "A woman playing with her dog on a beach at sunset.", + "document_text": "A woman and her dog spend time together on a beach during sunset.", + "num_layers": 1, + "additional_params": {} + } +] diff --git a/tests/transformers/models/embedding_models/test_qwen3vl_embedding_mad.py b/tests/transformers/models/embedding_models/test_qwen3vl_embedding_mad.py index d540593b8..885372355 100644 --- a/tests/transformers/models/embedding_models/test_qwen3vl_embedding_mad.py +++ b/tests/transformers/models/embedding_models/test_qwen3vl_embedding_mad.py @@ -108,7 +108,6 @@ def test_qwen3_vl_embedding_cpu_vs_ai100_mad_parity(model_name): model_inputs = EXAMPLE_QUERIES + EXAMPLE_DOCUMENTS compile_specs = embedder.get_compile_specs( inputs=model_inputs, - ctx_len=model_cfg["ctx_len"], prefill_seq_len=model_cfg.get("compile_prefill_seq_len", None), ) qpc_paths = qeff_model.compile( diff --git a/tests/transformers/models/reranker/test_reranker_mad.py b/tests/transformers/models/reranker/test_reranker_mad.py index 148935c5a..4677f9693 100644 --- a/tests/transformers/models/reranker/test_reranker_mad.py +++ b/tests/transformers/models/reranker/test_reranker_mad.py @@ -39,7 +39,7 @@ ) from QEfficient.utils.test_utils import load_vlm_model, set_num_layers_vlm -CONFIG_PATH = os.path.join(os.path.dirname(__file__), "../../../configs/image_text_model_configs.json") +CONFIG_PATH = os.path.join(os.path.dirname(__file__), "../../../configs/reranker_model_configs.json") PT_AI100_MAD_MAX = 5e-3 MAX_LENGTH = 8192 @@ -60,8 +60,7 @@ } with open(CONFIG_PATH, "r") as f: - config_data = json.load(f) - reranker_models = config_data["image_text_reranker_models"] + reranker_models = json.load(f) test_reranker_models = [model_config["model_name"] for model_config in reranker_models] reranker_model_config_dict = {model["model_name"]: model for model in reranker_models} @@ -298,7 +297,7 @@ def test_qwen3_vl_reranker_mad_parity(model_name): height=compile_height, width=compile_width, prefill_seq_len=max_prompt_len, - ctx_len=model_cfg["ctx_len"], + ctx_len=max_prompt_len, num_devices=1, num_cores=16, mxfp6_matmul=False, diff --git a/tests/unit_test/models/embedding/test_qwen3vl_embedding_unit.py b/tests/unit_test/models/embedding/test_qwen3vl_embedding_unit.py index ae7c88e83..a602a0f7d 100644 --- a/tests/unit_test/models/embedding/test_qwen3vl_embedding_unit.py +++ b/tests/unit_test/models/embedding/test_qwen3vl_embedding_unit.py @@ -118,8 +118,8 @@ def _fake_run_ai100_prefill(prepared_inputs, vision_outputs, lang_qpc_path): monkeypatch.setattr(QEffQwen3VLEmbedder, "_run_ai100_vision", staticmethod(_fake_run_ai100_vision)) monkeypatch.setattr(QEffQwen3VLEmbedder, "_run_ai100_prefill", staticmethod(_fake_run_ai100_prefill)) - compile_specs = embedder.get_compile_specs(inputs=[{}, {}], ctx_len=64, prefill_seq_len=12) - assert compile_specs == {"prefill_seq_len": 12, "ctx_len": 64, "img_size": 160, "height": 96, "width": 160} + compile_specs = embedder.get_compile_specs(inputs=[{}, {}], prefill_seq_len=12) + assert compile_specs == {"prefill_seq_len": 12, "ctx_len": 12, "img_size": 160, "height": 96, "width": 160} embeddings = embedder.process( inputs=[{}, {}], diff --git a/tests/unit_test/models/reranker/test_reranker_models_unit.py b/tests/unit_test/models/reranker/test_reranker_models_unit.py index f3036502e..7d1321a98 100644 --- a/tests/unit_test/models/reranker/test_reranker_models_unit.py +++ b/tests/unit_test/models/reranker/test_reranker_models_unit.py @@ -8,13 +8,12 @@ Generic unit coverage for image-text reranker model entries. This test is intentionally model-list driven: - - Add/remove reranker models only in tests/configs/image_text_model_configs.json + - Add/remove reranker models only in tests/configs/reranker_model_configs.json - The same unit checks run for every configured reranker model """ import copy import json -import os from typing import Dict, List import pytest @@ -22,13 +21,12 @@ from QEfficient.utils.test_utils import set_num_layers_vlm -CONFIG_PATH = os.path.join(os.path.dirname(__file__), "../../../configs/image_text_model_configs.json") +CONFIG_PATH = "tests/configs/reranker_model_configs.json" def _load_reranker_model_configs() -> List[Dict]: with open(CONFIG_PATH, "r", encoding="utf-8") as file: - config_data = json.load(file) - return config_data.get("image_text_reranker_models", []) + return json.load(file) RERANKER_MODEL_CONFIGS = _load_reranker_model_configs() @@ -51,7 +49,7 @@ def _vision_num_layers(config) -> int: def test_reranker_model_list_is_present(): assert RERANKER_MODEL_CONFIGS, ( - "image_text_reranker_models is empty. Add reranker entries in tests/configs/image_text_model_configs.json." + "reranker_model_configs.json is empty. Add reranker entries in tests/configs/reranker_model_configs.json." )