diff --git a/QEfficient/blocking/attention_blocking.py b/QEfficient/blocking/attention_blocking.py index b75342013..cae884081 100644 --- a/QEfficient/blocking/attention_blocking.py +++ b/QEfficient/blocking/attention_blocking.py @@ -81,6 +81,7 @@ def past_key_value_update( position_ids: Optional[torch.LongTensor] = None, sliding_window: Optional[int] = None, ): + cache_kwargs = {} if past_key_value is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} if sliding_window is not None: diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 65b89d274..68a125c08 100755 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1379,17 +1379,12 @@ def export( List[str] A list containing the paths to the generated ONNX graph files for both components. """ - dummy_inputs_kwargs = {} - if prefill_seq_len is not None: - dummy_inputs_kwargs["prefill_seq_len"] = int(prefill_seq_len) - # TODO This is a temporary change as continous batching is enabled only for few models. Once support is added for all the models this exception handing can be removed. try: inputs = self.model.get_dummy_inputs( kv_offload=True, continuous_batching=self.continuous_batching, comp_ctx_lengths=self.comp_ctx_lengths_decode, - **dummy_inputs_kwargs, ) dynamic_axes = self.model.get_onnx_dynamic_axes( kv_offload=True, @@ -1733,6 +1728,10 @@ def filter_custom_io_lang(custom_io_lang, onnx_path): elif prefill_seq_len == 1: specializations = specializations["lang"][-1:] qpc_key = "lang_decode_qpc_path" + elif prefill_seq_len is not None and ctx_len is not None and prefill_seq_len == ctx_len: + # Single-shot mode (e.g. reranker): no decode steps, only prefill kernel needed. 
+ specializations = specializations["lang"][:1] + qpc_key = "lang_qpc_path" else: specializations = specializations["lang"] qpc_key = "lang_qpc_path" @@ -2426,6 +2425,11 @@ def compile( **compiler_options, ) + # Single-shot mode (reranker/embedding): no decode steps, only prefill kernel needed. + single_shot = prefill_seq_len is not None and ctx_len is not None and prefill_seq_len == ctx_len + if single_shot: + specializations = specializations[:1] + if hasattr(self.model, "get_npi_file") and "node_precision_info" not in compiler_options: compiler_options["node_precision_info"] = self.model.get_npi_file(self.model.name_or_path) @@ -2446,6 +2450,11 @@ def compile( CUSTOM_IO_DTYPE_MAP[target_dtype] if "pixel_values" in output_name else kv_cache_dtype ) + # Single-shot mode has no retained state; pixel_values is a direct input so + # its dtype must still be set explicitly (float16 for hardware). + if single_shot: + custom_io["pixel_values"] = CUSTOM_IO_DTYPE_MAP[target_dtype] + # TODO this hould be removed once the continous batching is supported for all the models. compiler_options.pop("continuous_batching", None) compiler_options.pop("kv_cache_batch_size", None) @@ -2453,7 +2462,8 @@ def compile( self._compile( onnx_path=onnx_path, compile_dir=compile_dir, - retained_state=True, + # Single-shot (reranker/embedding): no decode, no need for retained-state enforcement. 
+ retained_state=not single_shot, specializations=specializations, convert_to_fp16=(CUSTOM_IO_DTYPE_MAP[target_dtype] == "float16"), mxfp6_matmul=mxfp6_matmul, diff --git a/QEfficient/transformers/models/qwen3_vl/_embedding_utils.py b/QEfficient/transformers/models/qwen3_vl/_embedding_utils.py index ca0316371..5c99a867a 100644 --- a/QEfficient/transformers/models/qwen3_vl/_embedding_utils.py +++ b/QEfficient/transformers/models/qwen3_vl/_embedding_utils.py @@ -257,9 +257,7 @@ def _collect_contexts(self, inputs: List[Dict[str, Any]]): return contexts, max_prompt_len, max_grid_h, max_grid_w - def get_compile_specs( - self, inputs: List[Dict[str, Any]], ctx_len: int, prefill_seq_len: int = None - ) -> Dict[str, int]: + def get_compile_specs(self, inputs: List[Dict[str, Any]], prefill_seq_len: int = None) -> Dict[str, int]: """Compute compile-time spec values for the current input batch.""" _, max_prompt_len, max_grid_h, max_grid_w = self._collect_contexts(inputs) if max_prompt_len == 0: @@ -275,9 +273,10 @@ def get_compile_specs( height = max_grid_h * patch_size width = max_grid_w * patch_size + # ctx_len == prefill_seq_len always: embedding is single-shot prefill, no decode steps. 
return { "prefill_seq_len": target_prefill_seq_len, - "ctx_len": int(ctx_len), + "ctx_len": target_prefill_seq_len, "img_size": max(height, width), "height": height, "width": width, @@ -352,17 +351,71 @@ def _run_ai100_prefill( embedding_output = embedding_output.reshape(embedding_output.shape[0], -1) return embedding_output + @staticmethod + def _run_ai100_single_qpc_prefill( + prepared_inputs: Dict[str, torch.Tensor], + qpc_path, + ) -> np.ndarray: + """Execute single-QPC (vision+language fused) prefill and return the embedding row.""" + prefill_len = prepared_inputs["position_ids"].shape[-1] + input_ids = prepared_inputs["input_ids"] + if input_ids.shape[1] < prefill_len: + pad = torch.full( + (input_ids.shape[0], prefill_len - input_ids.shape[1]), + 1, + dtype=input_ids.dtype, + device=input_ids.device, + ) + input_ids = torch.cat([input_ids, pad], dim=1) + else: + input_ids = input_ids[:, :prefill_len] + + position_ids = prepared_inputs["position_ids"][..., :prefill_len] + + session = QAICInferenceSession(str(qpc_path)) + + run_inputs = { + "input_ids": input_ids.detach().cpu().numpy().astype(np.int64), + "position_ids": position_ids.detach().cpu().numpy().astype(np.int64), + "image_idx": np.zeros((1, 1), dtype=np.int64), + } + + if "pixel_values" in prepared_inputs: + run_inputs["pixel_values"] = prepared_inputs["pixel_values"].detach().cpu().numpy().astype(np.float16) + else: + pv_idx = session.binding_index_map["pixel_values"] + run_inputs["pixel_values"] = np.zeros(session.bindings[pv_idx].dims, dtype=np.float16) + + for name in session.input_names: + if name.startswith("past_"): + idx = session.binding_index_map[name] + run_inputs[name] = np.zeros(session.bindings[idx].dims, dtype=np.float16) + + outputs = session.run(run_inputs) + session.deactivate() + + if "embedding_output" not in outputs: + raise KeyError( + "Missing 'embedding_output' in single-QPC outputs. Ensure export_embedding is enabled in qaic_config." 
+ ) + + embedding_output = outputs["embedding_output"] + if embedding_output.ndim > 2: + embedding_output = embedding_output.reshape(embedding_output.shape[0], -1) + return embedding_output + def process( self, inputs: List[Dict[str, Any]], - qpc_paths: Dict[str, str], + qpc_paths, prefill_seq_len: int, normalize: bool = True, ) -> torch.Tensor: - """Run AI100 embedding generation for all inputs and return stacked rows.""" - if "vision_qpc_path" not in qpc_paths or "lang_qpc_path" not in qpc_paths: - raise ValueError("qpc_paths must contain 'vision_qpc_path' and 'lang_qpc_path'.") + """Run AI100 embedding generation for all inputs and return stacked rows. + Supports both dual-QPC (qpc_paths is a dict with 'vision_qpc_path' and + 'lang_qpc_path') and single-QPC (qpc_paths is a str/Path to the combined QPC). + """ contexts, max_prompt_len, _, _ = self._collect_contexts(inputs) if max_prompt_len == 0: return torch.empty((0, 0), dtype=torch.float32) @@ -374,7 +427,6 @@ def process( ) prepared_contexts = [] - vision_template = None for ctx in contexts: prepared_inputs, _ = self._prepare_qeff_inputs( qeff_model=self.model, @@ -383,32 +435,47 @@ def process( ) prepared_contexts.append({"prepared_inputs": prepared_inputs}) - if vision_template is None and "pixel_values" in prepared_inputs and "image_grid_thw" in prepared_inputs: - vision_template = self._run_ai100_vision( - vision_qpc_path=qpc_paths["vision_qpc_path"], - prepared_inputs=prepared_inputs, - ) - - if vision_template is None: - raise ValueError("At least one input with an image is required to initialize AI100 vision buffers.") - + is_dual_qpc = isinstance(qpc_paths, dict) embedding_rows = [] - for ctx in prepared_contexts: - prepared_inputs = ctx["prepared_inputs"] - if "pixel_values" in prepared_inputs and "image_grid_thw" in prepared_inputs: - vision_outputs = self._run_ai100_vision( - vision_qpc_path=qpc_paths["vision_qpc_path"], + + if is_dual_qpc: + if "vision_qpc_path" not in qpc_paths or "lang_qpc_path" 
not in qpc_paths: + raise ValueError("qpc_paths must contain 'vision_qpc_path' and 'lang_qpc_path'.") + + vision_template = None + for ctx in prepared_contexts: + if vision_template is None and "pixel_values" in ctx["prepared_inputs"]: + vision_template = self._run_ai100_vision( + vision_qpc_path=qpc_paths["vision_qpc_path"], + prepared_inputs=ctx["prepared_inputs"], + ) + + if vision_template is None: + raise ValueError("At least one input with an image is required to initialize AI100 vision buffers.") + + for ctx in prepared_contexts: + prepared_inputs = ctx["prepared_inputs"] + if "pixel_values" in prepared_inputs and "image_grid_thw" in prepared_inputs: + vision_outputs = self._run_ai100_vision( + vision_qpc_path=qpc_paths["vision_qpc_path"], + prepared_inputs=prepared_inputs, + ) + else: + vision_outputs = self._zero_vision_outputs(vision_template) + + embedding_output = self._run_ai100_prefill( prepared_inputs=prepared_inputs, + vision_outputs=vision_outputs, + lang_qpc_path=qpc_paths["lang_qpc_path"], ) - else: - vision_outputs = self._zero_vision_outputs(vision_template) - - embedding_output = self._run_ai100_prefill( - prepared_inputs=prepared_inputs, - vision_outputs=vision_outputs, - lang_qpc_path=qpc_paths["lang_qpc_path"], - ) - embedding_rows.append(torch.from_numpy(embedding_output).to(torch.float32)) + embedding_rows.append(torch.from_numpy(embedding_output).to(torch.float32)) + else: + # Single QPC: vision + language fused in one compiled binary. 
+ for ctx in prepared_contexts: + embedding_output = self._run_ai100_single_qpc_prefill( + prepared_inputs=ctx["prepared_inputs"], qpc_path=qpc_paths + ) + embedding_rows.append(torch.from_numpy(embedding_output).to(torch.float32)) embeddings = torch.cat(embedding_rows, dim=0) if normalize: diff --git a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 0f6ab210d..39faee754 100644 --- a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -57,6 +57,18 @@ def _should_export_embedding_output(module) -> bool: return False +def _is_single_shot_mode(module) -> bool: + """True when model is single-shot prefill only (reranker/embedding) — no KV cache needed.""" + for holder in (module, getattr(module, "model", None)): + if holder is None: + continue + qaic_config = getattr(holder, "qaic_config", None) + if isinstance(qaic_config, dict): + if qaic_config.get("no_kv_cache", False) or qaic_config.get("export_embedding", False): + return True + return False + + def qeff_apply_interleaved_mrope(freqs, mrope_section): """Apply interleaved MRoPE to 3D rotary embeddings. 
Reorganizes frequency layout from chunked [TTT...HHH...WWW] to @@ -549,7 +561,9 @@ def forward( ) -> Union[Tuple, BaseModelOutputWithPast]: if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if self.config.use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = False + effective_use_cache = use_cache if use_cache is not None else self.config.use_cache + if effective_use_cache and not isinstance(past_key_values, Cache): return_legacy_cache = True past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) @@ -567,7 +581,11 @@ def forward( elif position_ids.dim() == 2: position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) - target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else (past_seen_tokens if past_seen_tokens > 0 else inputs_embeds.shape[1]) + ) causal_mask = _create_causal_mask( position_ids=position_ids[0], target_length=target_length, sliding_window=None ) @@ -805,7 +823,7 @@ def forward( self, input_ids, position_ids, - past_key_values, + past_key_values=None, pixel_values: Optional[torch.FloatTensor] = None, image_idx: Optional[torch.LongTensor] = None, comp_ctx_lengths: Optional[List[int]] = None, @@ -819,23 +837,29 @@ def forward( selected = input_ids == self.model.config.image_token_id indices1 = selected.to(torch.int64).cumsum(1) - 1 indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) - indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) + indices0 = torch.arange(selected.shape[0], device=selected.device).view(-1, 1) image_features_expanded = image_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1] # TODO: deepstack_features are not processed for single QPC setup yet. Will do if required. 
image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) - outputs = self.language_model( + + single_shot = _is_single_shot_mode(self) + outputs = self.model.language_model( inputs_embeds=inputs_embeds, position_ids=position_ids, - past_key_values=past_key_values, + past_key_values=None if single_shot else past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, - use_cache=True, + use_cache=not single_shot, ) logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True) hidden_states = outputs.last_hidden_state[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index] logits = self.lm_head(hidden_states) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) + if single_shot: + if _should_export_embedding_output(self): + return logits, image_embeds, image_idx, hidden_states + return logits, image_embeds, image_idx if _should_export_embedding_output(self): return logits, image_embeds, image_idx, hidden_states, outputs.past_key_values return logits, image_embeds, image_idx, outputs.past_key_values @@ -847,13 +871,8 @@ def get_dummy_inputs( continuous_batching: bool = False, **kwargs, ): - prefill_seq_len = kwargs.get("prefill_seq_len", constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) - if prefill_seq_len is None: - prefill_seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN - prefill_seq_len = int(prefill_seq_len) - inputs_shapes = {} - inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, prefill_seq_len) + inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) # vision_size = 1024 vision_size = 187 inputs_shapes["vision_embeds"] = ( @@ -865,7 +884,7 @@ def get_dummy_inputs( inputs_shapes["position_ids"] = ( 3, constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, - prefill_seq_len, + constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) 
inputs_shapes["pixel_values"] = (748, 1536) inputs_shapes["image_idx"] = (1, 1) @@ -889,8 +908,8 @@ def get_dummy_inputs( ) lang_inputs["position_ids"] = ( ( - torch.arange(prefill_seq_len, dtype=torch.int64) - .view(1, prefill_seq_len) + torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64) + .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) ) .unsqueeze(0) @@ -908,7 +927,7 @@ def get_dummy_inputs( kv_cache_shape = get_padding_shape_from_config( config=self.model.config.text_config, batch_size=fbs if continuous_batching else bs, - seq_len=prefill_seq_len, + seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) lang_inputs["past_key_values"] = [[] for _ in range(self.model.config.text_config.num_hidden_layers)] @@ -929,6 +948,8 @@ def get_dummy_inputs( inputs["lang"] = lang_inputs else: lang_inputs.pop("vision_embeds") + if _is_single_shot_mode(self): + lang_inputs.pop("past_key_values") inputs = {**vision_inputs, **lang_inputs} return inputs @@ -1113,8 +1134,15 @@ def smart_resize( specializations["lang"] = lang return specializations, compiler_options else: - lang[0].pop("vision_size") - lang[1].pop("vision_size") + # Single QPC: pixel_values and image_grid_thw are direct inputs, + # so the compiler needs the vision spatial symbols in every spec. 
+ for lang_spec in lang: + lang_spec.pop("vision_size") + lang_spec["grid_height"] = grid_height + lang_spec["grid_width"] = grid_width + lang_spec["grid_h"] = grid_h + lang_spec["grid_w"] = grid_w + lang_spec["time"] = time return lang, compiler_options def get_onnx_dynamic_axes( @@ -1158,6 +1186,12 @@ def get_onnx_dynamic_axes( dynamic_axes["lang"] = lang_dynamic_axes else: lang_dynamic_axes.pop("vision_embeds") + # deepstack_features are computed internally by vision encoder in single QPC — not a direct input + vision_dynamic_axes.pop("deepstack_features") + if _is_single_shot_mode(self): + for i in range(num_layers): + lang_dynamic_axes.pop(f"past_key.{i}", None) + lang_dynamic_axes.pop(f"past_value.{i}", None) dynamic_axes = {**vision_dynamic_axes, **lang_dynamic_axes} return dynamic_axes @@ -1179,6 +1213,13 @@ def get_output_names(self, kv_offload: bool = False): output_names["vision"] = vision_output_names output_names["lang"] = lang_output_names else: + if _is_single_shot_mode(self): + # Single-shot forward returns: (logits, image_embeds, image_idx) + # embedding adds hidden_states: (logits, image_embeds, image_idx, hidden_states) + single_shot_outputs = ["logits", "image_embeds", "image_idx_output"] + if _should_export_embedding_output(self): + single_shot_outputs.append("embedding_output") + return single_shot_outputs lang_output_names.insert(1, "pixel_values_RetainedState") lang_output_names.insert(2, "image_idx_output") if _should_export_embedding_output(self): diff --git a/QEfficient/transformers/models/whisper/modeling_whisper.py b/QEfficient/transformers/models/whisper/modeling_whisper.py index 1bdcd07ad..4c3016628 100644 --- a/QEfficient/transformers/models/whisper/modeling_whisper.py +++ b/QEfficient/transformers/models/whisper/modeling_whisper.py @@ -792,10 +792,9 @@ def forward( def get_dummy_inputs( self, - **kwargs, ): bs = 1 - seq_len = int(kwargs.get("prefill_seq_len", ONNX_EXPORT_EXAMPLE_SEQ_LEN)) + seq_len = ONNX_EXPORT_EXAMPLE_SEQ_LEN 
encoder_seq_len = self.config.max_source_positions encoder_feature_count = self.config.num_mel_bins num_key_value_heads = self.config.decoder_attention_heads diff --git a/examples/embeddings/qwen3vl/README.md b/examples/embeddings/qwen3vl/README.md index cff14908c..6f89fade0 100644 --- a/examples/embeddings/qwen3vl/README.md +++ b/examples/embeddings/qwen3vl/README.md @@ -40,7 +40,6 @@ With compile parameters: ```bash python examples/embeddings/qwen3vl/qwen3_vl_embedding.py \ --model-name Qwen/Qwen3-VL-Embedding-8B \ - --ctx-len 2048 \ --num-cores 16 \ --num-devices 1 \ --compile-prefill-seq-len 4096 \ diff --git a/examples/embeddings/qwen3vl/qwen3_vl_embedding.py b/examples/embeddings/qwen3vl/qwen3_vl_embedding.py index bd707ffb0..076d3f058 100644 --- a/examples/embeddings/qwen3vl/qwen3_vl_embedding.py +++ b/examples/embeddings/qwen3vl/qwen3_vl_embedding.py @@ -24,7 +24,6 @@ from QEfficient.transformers.models.qwen3_vl._embedding_utils import configure_embedding_model_config DEFAULT_MODEL_NAME = "Qwen/Qwen3-VL-Embedding-8B" -DEFAULT_CTX_LEN = 2048 DEFAULT_NUM_CORES = 16 DEFAULT_NUM_DEVICES = 1 DEFAULT_NUM_HIDDEN_LAYERS = 36 @@ -36,7 +35,6 @@ def parse_args() -> argparse.Namespace: """Parse command-line arguments for AI100 compile/inference knobs.""" parser = argparse.ArgumentParser(description="Qwen3-VL embedding example.") parser.add_argument("--model-name", type=str, default=DEFAULT_MODEL_NAME) - parser.add_argument("--ctx-len", type=int, default=DEFAULT_CTX_LEN, help="Context length used at compile time.") parser.add_argument("--num-cores", type=int, default=DEFAULT_NUM_CORES, help="Number of AI100 cores.") parser.add_argument("--num-devices", type=int, default=DEFAULT_NUM_DEVICES, help="Number of AI100 devices.") parser.add_argument( @@ -107,7 +105,7 @@ def main() -> None: processor = AutoProcessor.from_pretrained(model_source, trust_remote_code=True, padding=True) model = QEFFAutoModelForImageTextToText.from_pretrained( model_source, - kv_offload=True, + 
kv_offload=False, trust_remote_code=True, config=config, qaic_config={"export_embedding": True}, @@ -121,7 +119,6 @@ def main() -> None: # 3) Derive compile requirements from current payload. compile_specs = embedder.get_compile_specs( inputs=model_inputs, - ctx_len=args.ctx_len, prefill_seq_len=args.compile_prefill_seq_len, ) diff --git a/examples/image_text_to_text/models/qwen3vl/qwen3_vl.py b/examples/image_text_to_text/models/qwen3vl/qwen3_vl.py index 6b86ea874..6ba5cfba3 100644 --- a/examples/image_text_to_text/models/qwen3vl/qwen3_vl.py +++ b/examples/image_text_to_text/models/qwen3vl/qwen3_vl.py @@ -16,9 +16,9 @@ model_id = "Qwen/Qwen3-VL-32B-Instruct" config = AutoConfig.from_pretrained(model_id) -# config.vision_config.depth = 9 -# config.text_config.num_hidden_layers = 1 -# config.vision_config.deepstack_visual_indexes = [8] +config.vision_config.depth = 9 +config.text_config.num_hidden_layers = 1 +config.vision_config.deepstack_visual_indexes = [8] qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( model_id, attn_implementation="eager", kv_offload=True, config=config diff --git a/examples/reranker/qwen3vl/README.md b/examples/reranker/qwen3vl/README.md index d9d96645a..74bc9d4a2 100644 --- a/examples/reranker/qwen3vl/README.md +++ b/examples/reranker/qwen3vl/README.md @@ -49,7 +49,6 @@ With compile parameters: ```bash python examples/reranker/qwen3vl/qwen3_vl_reranker.py \ --model-name Qwen/Qwen3-VL-Reranker-2B \ - --ctx-len 2048 \ --num-cores 16 \ --num-devices 1 \ --compile-prefill-seq-len 4096 \ diff --git a/examples/reranker/qwen3vl/qwen3_vl_reranker.py b/examples/reranker/qwen3vl/qwen3_vl_reranker.py index 01884d0d0..d6e0c5e63 100644 --- a/examples/reranker/qwen3vl/qwen3_vl_reranker.py +++ b/examples/reranker/qwen3vl/qwen3_vl_reranker.py @@ -30,7 +30,6 @@ def parse_args() -> argparse.Namespace: """Parse command-line arguments for AI100 compile/inference knobs.""" parser = argparse.ArgumentParser(description="Qwen3-VL reranker example.") 
parser.add_argument("--model-name", type=str, default="Qwen/Qwen3-VL-Reranker-2B") - parser.add_argument("--ctx-len", type=int, default=2048, help="Context length used at compile time.") parser.add_argument("--num-cores", type=int, default=16, help="Number of AI100 cores.") parser.add_argument("--num-devices", type=int, default=1, help="Number of AI100 devices.") parser.add_argument( @@ -94,9 +93,10 @@ def main() -> None: processor = AutoProcessor.from_pretrained(model_source, trust_remote_code=True) model = QEFFAutoModelForImageTextToText.from_pretrained( model_source, - kv_offload=True, + kv_offload=False, trust_remote_code=True, config=config, + qaic_config={"no_kv_cache": True}, ) # 2) Build reranker helper and reference payload. @@ -106,7 +106,6 @@ def main() -> None: # 3) Derive compile requirements from current payload. compile_specs = reranker.get_compile_specs( inputs=inputs, - ctx_len=args.ctx_len, prefill_seq_len=args.compile_prefill_seq_len, ) diff --git a/examples/reranker/qwen3vl/reranker_model.py b/examples/reranker/qwen3vl/reranker_model.py index 33e73b05f..b82143dd1 100644 --- a/examples/reranker/qwen3vl/reranker_model.py +++ b/examples/reranker/qwen3vl/reranker_model.py @@ -173,7 +173,7 @@ def _collect_contexts(self, inputs: Dict): return prepared_contexts, max_prompt_len, max_grid_h, max_grid_w - def get_compile_specs(self, inputs: Dict, ctx_len: int, prefill_seq_len: int = None) -> Dict[str, int]: + def get_compile_specs(self, inputs: Dict, prefill_seq_len: int = None) -> Dict[str, int]: """Return compile parameters required for this input batch.""" _, max_prompt_len, max_grid_h, max_grid_w = self._collect_contexts(inputs) if max_prompt_len == 0: @@ -189,9 +189,10 @@ def get_compile_specs(self, inputs: Dict, ctx_len: int, prefill_seq_len: int = N height = max_grid_h * patch_size width = max_grid_w * patch_size + # ctx_len == prefill_seq_len always: reranker is single-shot prefill, no decode steps. 
return { "prefill_seq_len": target_prefill_seq_len, - "ctx_len": int(ctx_len), + "ctx_len": target_prefill_seq_len, "img_size": max(height, width), "height": height, "width": width, @@ -264,8 +265,51 @@ def _run_ai100_prefill( lang_session.deactivate() return outputs["logits"] - def process(self, inputs: Dict, qpc_paths: Dict[str, str], prefill_seq_len: int) -> List[float]: - """Score all documents for one query on AI100 using precompiled QPCs.""" + @staticmethod + def _run_ai100_single_qpc_prefill(prepared_inputs: Dict, qpc_path: str) -> np.ndarray: + """Run single-QPC (vision+language fused) prefill and return logits.""" + prefill_len = prepared_inputs["position_ids"].shape[-1] + input_ids = prepared_inputs["input_ids"] + if input_ids.shape[1] < prefill_len: + pad = torch.full( + (input_ids.shape[0], prefill_len - input_ids.shape[1]), + 1, + dtype=input_ids.dtype, + device=input_ids.device, + ) + input_ids = torch.cat([input_ids, pad], dim=1) + else: + input_ids = input_ids[:, :prefill_len] + + position_ids = prepared_inputs["position_ids"][..., :prefill_len] + + session = QAICInferenceSession(str(qpc_path)) + + run_inputs = { + "input_ids": input_ids.detach().cpu().numpy().astype(np.int64), + "position_ids": position_ids.detach().cpu().numpy().astype(np.int64), + "image_idx": np.zeros((1, 1), dtype=np.int64), + } + + # image_grid_thw is baked as a constant during single-QPC ONNX tracing; only + # pixel_values remains as a dynamic input for the vision encoder. + if "pixel_values" in prepared_inputs: + run_inputs["pixel_values"] = prepared_inputs["pixel_values"].detach().cpu().numpy().astype(np.float16) + else: + # Text-only: pass zeros with the shape fixed at compile time. 
+ pv_idx = session.binding_index_map["pixel_values"] + run_inputs["pixel_values"] = np.zeros(session.bindings[pv_idx].dims, dtype=np.float16) + + outputs = session.run(run_inputs) + session.deactivate() + return outputs["logits"] + + def process(self, inputs: Dict, qpc_paths, prefill_seq_len: int) -> List[float]: + """Score all documents for one query on AI100 using precompiled QPCs. + + Supports both dual-QPC (qpc_paths is a dict with 'vision_qpc_path' and + 'lang_qpc_path') and single-QPC (qpc_paths is a str/Path to the combined QPC). + """ prepared_contexts, max_prompt_len, _, _ = self._collect_contexts(inputs) if max_prompt_len == 0: return [] @@ -276,30 +320,40 @@ def process(self, inputs: Dict, qpc_paths: Dict[str, str], prefill_seq_len: int) f"prefill_seq_len ({target_prefill_seq_len}) must be >= max runtime prompt length ({max_prompt_len})." ) - if "vision_qpc_path" not in qpc_paths or "lang_qpc_path" not in qpc_paths: - raise ValueError("qpc_paths must contain 'vision_qpc_path' and 'lang_qpc_path'.") - prepared_contexts_with_prefill = [] - vision_template = None for ctx in prepared_contexts: prepared_inputs = self._prepare_inputs(ctx["tokenized"], prefill_seq_len=target_prefill_seq_len) prepared_contexts_with_prefill.append({"prepared_inputs": prepared_inputs}) - if vision_template is None and "pixel_values" in prepared_inputs and "image_grid_thw" in prepared_inputs: - vision_template = self._run_ai100_vision(prepared_inputs, vision_qpc_path=qpc_paths["vision_qpc_path"]) - - if vision_template is None: - raise ValueError("At least one image document is required to initialize AI100 vision buffers.") - + is_dual_qpc = isinstance(qpc_paths, dict) scores = [] - for ctx in prepared_contexts_with_prefill: - logits = self._run_ai100_prefill( - ctx["prepared_inputs"], - vision_template=vision_template, - lang_qpc_path=qpc_paths["lang_qpc_path"], - vision_qpc_path=qpc_paths["vision_qpc_path"], - ) - score = self._score_from_logits(logits, self.yes_token_id, 
self.no_token_id) - scores.append(score) + + if is_dual_qpc: + if "vision_qpc_path" not in qpc_paths or "lang_qpc_path" not in qpc_paths: + raise ValueError("qpc_paths must contain 'vision_qpc_path' and 'lang_qpc_path'.") + + vision_template = None + for ctx in prepared_contexts_with_prefill: + if vision_template is None and "pixel_values" in ctx["prepared_inputs"]: + vision_template = self._run_ai100_vision( + ctx["prepared_inputs"], vision_qpc_path=qpc_paths["vision_qpc_path"] + ) + + if vision_template is None: + raise ValueError("At least one image document is required to initialize AI100 vision buffers.") + + for ctx in prepared_contexts_with_prefill: + logits = self._run_ai100_prefill( + ctx["prepared_inputs"], + vision_template=vision_template, + lang_qpc_path=qpc_paths["lang_qpc_path"], + vision_qpc_path=qpc_paths["vision_qpc_path"], + ) + scores.append(self._score_from_logits(logits, self.yes_token_id, self.no_token_id)) + else: + # Single QPC: vision + language fused in one compiled binary. 
+ for ctx in prepared_contexts_with_prefill: + logits = self._run_ai100_single_qpc_prefill(ctx["prepared_inputs"], qpc_path=qpc_paths) + scores.append(self._score_from_logits(logits, self.yes_token_id, self.no_token_id)) return scores diff --git a/tests/configs/image_text_model_configs.json b/tests/configs/image_text_model_configs.json index 85df55997..d98d0e08a 100644 --- a/tests/configs/image_text_model_configs.json +++ b/tests/configs/image_text_model_configs.json @@ -724,42 +724,11 @@ } } ], - "image_text_reranker_models": [ - { - "model_name": "Qwen/Qwen3-VL-Reranker-2B", - "model_type": "qwen3_vl", - "batch_size": 1, - "prompt_len": 128, - "ctx_len": 1024, - "img_size": 1540, - "img_url": "https://picsum.photos/id/237/536/354", - "instruction": "Retrieve candidates relevant to the query.", - "query_text": "A woman playing with her dog on a beach at sunset.", - "document_text": "A woman and her dog spend time together on a beach during sunset.", - "num_layers": 1, - "additional_params": {} - }, - { - "model_name": "Qwen/Qwen3-VL-Reranker-8B", - "model_type": "qwen3_vl", - "batch_size": 1, - "prompt_len": 128, - "ctx_len": 1024, - "img_size": 1540, - "img_url": "https://picsum.photos/id/237/536/354", - "instruction": "Retrieve candidates relevant to the query.", - "query_text": "A woman playing with her dog on a beach at sunset.", - "document_text": "A woman and her dog spend time together on a beach during sunset.", - "num_layers": 1, - "additional_params": {} - } - ], "image_text_embedding_models": [ { "model_name": "Qwen/Qwen3-VL-Embedding-8B", "model_type": "qwen3_vl", "batch_size": 1, - "ctx_len": 2048, "num_layers": 1, "vision_depth": 9, "deepstack_index": 8, diff --git a/tests/configs/reranker_model_configs.json b/tests/configs/reranker_model_configs.json new file mode 100644 index 000000000..4427b9da0 --- /dev/null +++ b/tests/configs/reranker_model_configs.json @@ -0,0 +1,28 @@ +[ + { + "model_name": "Qwen/Qwen3-VL-Reranker-2B", + "model_type": 
"qwen3_vl", + "batch_size": 1, + "prompt_len": 128, + "img_size": 1540, + "img_url": "https://picsum.photos/id/237/536/354", + "instruction": "Retrieve candidates relevant to the query.", + "query_text": "A woman playing with her dog on a beach at sunset.", + "document_text": "A woman and her dog spend time together on a beach during sunset.", + "num_layers": 1, + "additional_params": {} + }, + { + "model_name": "Qwen/Qwen3-VL-Reranker-8B", + "model_type": "qwen3_vl", + "batch_size": 1, + "prompt_len": 128, + "img_size": 1540, + "img_url": "https://picsum.photos/id/237/536/354", + "instruction": "Retrieve candidates relevant to the query.", + "query_text": "A woman playing with her dog on a beach at sunset.", + "document_text": "A woman and her dog spend time together on a beach during sunset.", + "num_layers": 1, + "additional_params": {} + } +] diff --git a/tests/transformers/models/embedding_models/test_qwen3vl_embedding_mad.py b/tests/transformers/models/embedding_models/test_qwen3vl_embedding_mad.py index d540593b8..885372355 100644 --- a/tests/transformers/models/embedding_models/test_qwen3vl_embedding_mad.py +++ b/tests/transformers/models/embedding_models/test_qwen3vl_embedding_mad.py @@ -108,7 +108,6 @@ def test_qwen3_vl_embedding_cpu_vs_ai100_mad_parity(model_name): model_inputs = EXAMPLE_QUERIES + EXAMPLE_DOCUMENTS compile_specs = embedder.get_compile_specs( inputs=model_inputs, - ctx_len=model_cfg["ctx_len"], prefill_seq_len=model_cfg.get("compile_prefill_seq_len", None), ) qpc_paths = qeff_model.compile( diff --git a/tests/transformers/models/reranker/test_reranker_mad.py b/tests/transformers/models/reranker/test_reranker_mad.py index 148935c5a..4677f9693 100644 --- a/tests/transformers/models/reranker/test_reranker_mad.py +++ b/tests/transformers/models/reranker/test_reranker_mad.py @@ -39,7 +39,7 @@ ) from QEfficient.utils.test_utils import load_vlm_model, set_num_layers_vlm -CONFIG_PATH = os.path.join(os.path.dirname(__file__), 
"../../../configs/image_text_model_configs.json") +CONFIG_PATH = os.path.join(os.path.dirname(__file__), "../../../configs/reranker_model_configs.json") PT_AI100_MAD_MAX = 5e-3 MAX_LENGTH = 8192 @@ -60,8 +60,7 @@ } with open(CONFIG_PATH, "r") as f: - config_data = json.load(f) - reranker_models = config_data["image_text_reranker_models"] + reranker_models = json.load(f) test_reranker_models = [model_config["model_name"] for model_config in reranker_models] reranker_model_config_dict = {model["model_name"]: model for model in reranker_models} @@ -298,7 +297,7 @@ def test_qwen3_vl_reranker_mad_parity(model_name): height=compile_height, width=compile_width, prefill_seq_len=max_prompt_len, - ctx_len=model_cfg["ctx_len"], + ctx_len=max_prompt_len, num_devices=1, num_cores=16, mxfp6_matmul=False, diff --git a/tests/unit_test/models/embedding/test_qwen3vl_embedding_unit.py b/tests/unit_test/models/embedding/test_qwen3vl_embedding_unit.py index ae7c88e83..a7e94eac8 100644 --- a/tests/unit_test/models/embedding/test_qwen3vl_embedding_unit.py +++ b/tests/unit_test/models/embedding/test_qwen3vl_embedding_unit.py @@ -118,8 +118,8 @@ def _fake_run_ai100_prefill(prepared_inputs, vision_outputs, lang_qpc_path): monkeypatch.setattr(QEffQwen3VLEmbedder, "_run_ai100_vision", staticmethod(_fake_run_ai100_vision)) monkeypatch.setattr(QEffQwen3VLEmbedder, "_run_ai100_prefill", staticmethod(_fake_run_ai100_prefill)) - compile_specs = embedder.get_compile_specs(inputs=[{}, {}], ctx_len=64, prefill_seq_len=12) - assert compile_specs == {"prefill_seq_len": 12, "ctx_len": 64, "img_size": 160, "height": 96, "width": 160} + compile_specs = embedder.get_compile_specs(inputs=[{}, {}], prefill_seq_len=12) + assert compile_specs == {"prefill_seq_len": 12, "ctx_len": 12, "img_size": 160, "height": 96, "width": 160} embeddings = embedder.process( inputs=[{}, {}], @@ -130,3 +130,48 @@ def _fake_run_ai100_prefill(prepared_inputs, vision_outputs, lang_qpc_path): assert tuple(embeddings.shape) == (2, 
4)
     norms = torch.linalg.norm(embeddings, dim=-1)
     assert torch.allclose(norms, torch.ones_like(norms), atol=1e-6)
+
+
+@pytest.mark.embedding
+def test_qwen3_vl_embedder_single_qpc_dispatch(monkeypatch):
+    """process() with a non-dict qpc_paths (str or pathlib.Path) uses the single-QPC path."""
+    from pathlib import Path
+
+    embedder = QEffQwen3VLEmbedder(processor=None, model=_DummyQEffModel())
+
+    contexts = [{"tokenized": {"kind": "image"}}, {"tokenized": {"kind": "text"}}]
+
+    def _fake_collect_contexts(_inputs):
+        return contexts, 8, 6, 10  # fixed stand-in values; meaning set by the real method (not shown here)
+
+    def _fake_prepare_qeff_inputs(qeff_model, tokenized_inputs, prefill_seq_len):
+        del qeff_model  # the fake does not use the model
+        prepared = {
+            "input_ids": torch.arange(8, dtype=torch.int64).unsqueeze(0),
+            "position_ids": torch.arange(prefill_seq_len, dtype=torch.int64).reshape(1, 1, prefill_seq_len),
+        }
+        if tokenized_inputs.get("kind") == "image":  # only the image context carries pixel_values
+            prepared["pixel_values"] = torch.ones((1, 3, 2, 2), dtype=torch.float32)
+        return prepared, 8
+
+    def _fake_single_qpc_prefill(prepared_inputs, qpc_path):
+        del qpc_path  # the fake never opens a QPC
+        if "pixel_values" in prepared_inputs:  # return a different row for image vs. text inputs
+            return np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
+        return np.array([[2.0, 1.0, 0.5, 1.0]], dtype=np.float32)
+
+    monkeypatch.setattr(embedder, "_collect_contexts", _fake_collect_contexts)
+    monkeypatch.setattr(QEffQwen3VLEmbedder, "_prepare_qeff_inputs", staticmethod(_fake_prepare_qeff_inputs))
+    monkeypatch.setattr(QEffQwen3VLEmbedder, "_run_ai100_single_qpc_prefill", staticmethod(_fake_single_qpc_prefill))
+
+    # str path with normalize=True: both rows must come back with unit L2 norm
+    embeddings = embedder.process(inputs=[{}, {}], qpc_paths="/path/to/single.qpc", prefill_seq_len=12, normalize=True)
+    assert tuple(embeddings.shape) == (2, 4)
+    norms = torch.linalg.norm(embeddings, dim=-1)
+    assert torch.allclose(norms, torch.ones_like(norms), atol=1e-6)
+
+    # pathlib.Path with normalize=False: only the output shape is checked
+    embeddings = embedder.process(
+        inputs=[{}, {}], qpc_paths=Path("/tmp/model.qpc"), prefill_seq_len=12, normalize=False
+    )
+    assert tuple(embeddings.shape) == (2, 4)
diff
--git a/tests/unit_test/models/reranker/test_reranker_models_unit.py b/tests/unit_test/models/reranker/test_reranker_models_unit.py
index f3036502e..800801f2d 100644
--- a/tests/unit_test/models/reranker/test_reranker_models_unit.py
+++ b/tests/unit_test/models/reranker/test_reranker_models_unit.py
@@ -8,27 +8,29 @@
 Generic unit coverage for image-text reranker model entries.
 
 This test is intentionally model-list driven:
-  - Add/remove reranker models only in tests/configs/image_text_model_configs.json
+  - Add/remove reranker models only in tests/configs/reranker_model_configs.json
   - The same unit checks run for every configured reranker model
 """
 
 import copy
 import json
-import os
+from pathlib import Path
 from typing import Dict, List
+from unittest.mock import MagicMock
 
+import numpy as np
 import pytest
+import torch
 from transformers import AutoConfig
 
 from QEfficient.utils.test_utils import set_num_layers_vlm
 
-CONFIG_PATH = os.path.join(os.path.dirname(__file__), "../../../configs/image_text_model_configs.json")
+CONFIG_PATH = Path(__file__).resolve().parents[3] / "configs" / "reranker_model_configs.json"  # not CWD-relative
 
 
 def _load_reranker_model_configs() -> List[Dict]:
     with open(CONFIG_PATH, "r", encoding="utf-8") as file:
-        config_data = json.load(file)
-    return config_data.get("image_text_reranker_models", [])
+        return json.load(file)
 
 
 RERANKER_MODEL_CONFIGS = _load_reranker_model_configs()
@@ -51,7 +53,7 @@ def _vision_num_layers(config) -> int:
 
 def test_reranker_model_list_is_present():
     assert RERANKER_MODEL_CONFIGS, (
-        "image_text_reranker_models is empty. Add reranker entries in tests/configs/image_text_model_configs.json."
+        "reranker_model_configs.json is empty. Add reranker entries in tests/configs/reranker_model_configs.json."
)
 
 
@@ -81,3 +83,131 @@ def test_reranker_config_reduction_keeps_valid_deepstack(model_cfg: Dict):
     assert max(deepstack_idxs) < _vision_num_layers(reduced_cfg), (
         f"{model_name}: deepstack indexes must be in [0, vision_num_layers)"
     )
+
+
+# ---------------------------------------------------------------------------
+# Tests: kv_offload=False (single QPC) runtime dispatch in QEffQwen3VLReranker
+# ---------------------------------------------------------------------------
+
+
+def _make_dummy_reranker():
+    """Load the example reranker module and return its QEffQwen3VLReranker class.
+
+    The module is executed from examples/reranker/qwen3vl/reranker_model.py with
+    MagicMock stand-ins for its heavy imports, so no hardware is needed.
+    """
+    # Import the reranker class from the examples directory via importlib
+    import importlib.util
+    import sys
+
+    spec = importlib.util.spec_from_file_location(
+        "reranker_model",
+        Path(__file__).parents[4] / "examples" / "reranker" / "qwen3vl" / "reranker_model.py",
+    )
+    mod = importlib.util.module_from_spec(spec)
+    # Stub heavy dependencies so the module loads without hardware. Modules that are
+    # already imported are left alone; the stubs added here are removed again once
+    # the module has executed so they cannot leak into other tests via sys.modules.
+    stubbed = [
+        name
+        for name in (
+            "QEfficient.generation.cloud_infer",
+            "QEfficient.transformers.models.qwen3_vl._reranker_utils",
+        )
+        if name not in sys.modules
+    ]
+    for name in stubbed:
+        sys.modules[name] = MagicMock()
+    try:
+        spec.loader.exec_module(mod)
+    finally:
+        for name in stubbed:
+            sys.modules.pop(name, None)
+    return mod.QEffQwen3VLReranker
+
+
+@pytest.fixture()
+def reranker_cls():
+    """Provide the QEffQwen3VLReranker class loaded with stubbed dependencies."""
+    return _make_dummy_reranker()
+
+
+def _fake_prepared_inputs(has_image: bool, prefill_len: int = 8):
+    """Return dummy input tensors; pixel_values and image_grid_thw are added only when has_image is true."""
+    inputs = {
+        "input_ids": torch.ones((1, prefill_len), dtype=torch.int64),
+        "position_ids": torch.arange(prefill_len).reshape(1, 1, prefill_len).expand(4, 1, -1),
+    }
+    if has_image:
+        inputs["pixel_values"] = torch.zeros((748, 1536), dtype=torch.float32)
+        inputs["image_grid_thw"] = torch.zeros((1, 1, 22, 34), dtype=torch.int64)
+    return inputs
+
+
+def test_reranker_process_dispatches_to_dual_qpc(reranker_cls, monkeypatch):
+    """process() with dict qpc_paths uses the dual-QPC path."""
+    reranker = object.__new__(reranker_cls)
+    reranker.yes_token_id = 0
+    reranker.no_token_id = 1
+
+    fake_logits = np.zeros((1, 1, 10), dtype=np.float32)
+    fake_logits[0, 0, 0] = 2.0  # yes logit > no logit → score > 0.5
+
+    monkeypatch.setattr(reranker, "_collect_contexts", lambda _: ([{"tokenized": {}}], 4, 22, 34))
+    monkeypatch.setattr(reranker, "_prepare_inputs", lambda tok, prefill_seq_len: _fake_prepared_inputs(True))
+    monkeypatch.setattr(
+        reranker_cls, "_run_ai100_vision", staticmethod(lambda pi, vision_qpc_path: {"v": np.zeros((1,))})
+    )
+    monkeypatch.setattr(
+        reranker_cls,
+        "_run_ai100_prefill",
+        staticmethod(lambda pi, vision_template, lang_qpc_path, vision_qpc_path: fake_logits),
+    )
+    monkeypatch.setattr(reranker_cls, "_score_from_logits", staticmethod(lambda logits, y, n: 0.88))
+
+    scores = reranker.process(
+        inputs={},
+        qpc_paths={"vision_qpc_path": "v.qpc", "lang_qpc_path": "l.qpc"},
+        prefill_seq_len=8,
+    )
+    assert scores == [0.88]
+
+
+def test_reranker_process_dispatches_to_single_qpc(reranker_cls, monkeypatch):
+    """process() with a non-dict qpc_paths uses the single-QPC path."""
+    reranker = object.__new__(reranker_cls)
+    reranker.yes_token_id = 0
+    reranker.no_token_id = 1
+
+    fake_logits = np.zeros((1, 1, 10), dtype=np.float32)
+
+    monkeypatch.setattr(reranker, "_collect_contexts", lambda _: ([{"tokenized": {}}], 4, 22, 34))
+    monkeypatch.setattr(reranker, "_prepare_inputs", lambda tok, prefill_seq_len: _fake_prepared_inputs(False))
+    monkeypatch.setattr(
+        reranker_cls,
+        "_run_ai100_single_qpc_prefill",
+        staticmethod(lambda pi, qpc_path: fake_logits),
+    )
+    monkeypatch.setattr(reranker_cls, "_score_from_logits", staticmethod(lambda logits, y, n: 0.72))
+
+    scores = reranker.process(inputs={}, qpc_paths="/path/to/single.qpc", prefill_seq_len=8)
+    assert scores == [0.72]
+
+
+def test_reranker_process_single_qpc_with_pathlib(reranker_cls, monkeypatch):
+    """Single QPC path also accepts a pathlib.Path object."""
+    reranker = object.__new__(reranker_cls)
+    reranker.yes_token_id = 0
+    reranker.no_token_id = 1
+
+    monkeypatch.setattr(reranker, "_collect_contexts", lambda _: ([{"tokenized": {}}], 4, 22, 34))
+    monkeypatch.setattr(reranker, "_prepare_inputs", lambda tok, prefill_seq_len: _fake_prepared_inputs(True))
+    monkeypatch.setattr(
+        reranker_cls,
+        "_run_ai100_single_qpc_prefill",
+        staticmethod(lambda pi, qpc_path: np.zeros((1, 1, 10), dtype=np.float32)),
+    )
+    monkeypatch.setattr(reranker_cls, "_score_from_logits", staticmethod(lambda logits, y, n: 0.5))
+
+    scores = reranker.process(inputs={}, qpc_paths=Path("/tmp/model.qpc"), prefill_seq_len=8)
+    assert scores == [0.5]