Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions QEfficient/blocking/attention_blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def past_key_value_update(
position_ids: Optional[torch.LongTensor] = None,
sliding_window: Optional[int] = None,
):
cache_kwargs = {}
if past_key_value is not None:
cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
if sliding_window is not None:
Expand Down
22 changes: 16 additions & 6 deletions QEfficient/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -1379,17 +1379,12 @@ def export(
List[str]
A list containing the paths to the generated ONNX graph files for both components.
"""
dummy_inputs_kwargs = {}
if prefill_seq_len is not None:
dummy_inputs_kwargs["prefill_seq_len"] = int(prefill_seq_len)

# TODO: This is a temporary change, as continuous batching is enabled only for a few models. Once support is added for all models, this exception handling can be removed.
try:
inputs = self.model.get_dummy_inputs(
kv_offload=True,
continuous_batching=self.continuous_batching,
comp_ctx_lengths=self.comp_ctx_lengths_decode,
**dummy_inputs_kwargs,
)
dynamic_axes = self.model.get_onnx_dynamic_axes(
kv_offload=True,
Expand Down Expand Up @@ -1733,6 +1728,10 @@ def filter_custom_io_lang(custom_io_lang, onnx_path):
elif prefill_seq_len == 1:
specializations = specializations["lang"][-1:]
qpc_key = "lang_decode_qpc_path"
elif prefill_seq_len is not None and ctx_len is not None and prefill_seq_len == ctx_len:
# Single-shot mode (e.g. reranker): no decode steps, only prefill kernel needed.
specializations = specializations["lang"][:1]
qpc_key = "lang_qpc_path"
else:
specializations = specializations["lang"]
qpc_key = "lang_qpc_path"
Expand Down Expand Up @@ -2426,6 +2425,11 @@ def compile(
**compiler_options,
)

# Single-shot mode (reranker/embedding): no decode steps, only prefill kernel needed.
single_shot = prefill_seq_len is not None and ctx_len is not None and prefill_seq_len == ctx_len
if single_shot:
specializations = specializations[:1]

if hasattr(self.model, "get_npi_file") and "node_precision_info" not in compiler_options:
compiler_options["node_precision_info"] = self.model.get_npi_file(self.model.name_or_path)

Expand All @@ -2446,14 +2450,20 @@ def compile(
CUSTOM_IO_DTYPE_MAP[target_dtype] if "pixel_values" in output_name else kv_cache_dtype
)

# Single-shot mode has no retained state; pixel_values is a direct input so
# its dtype must still be set explicitly (float16 for hardware).
if single_shot:
custom_io["pixel_values"] = CUSTOM_IO_DTYPE_MAP[target_dtype]

# TODO: this should be removed once continuous batching is supported for all the models.
compiler_options.pop("continuous_batching", None)
compiler_options.pop("kv_cache_batch_size", None)
compiler_options.pop("full_batch_size", None)
self._compile(
onnx_path=onnx_path,
compile_dir=compile_dir,
retained_state=True,
# Single-shot (reranker/embedding): no decode, no need for retained-state enforcement.
retained_state=not single_shot,
specializations=specializations,
convert_to_fp16=(CUSTOM_IO_DTYPE_MAP[target_dtype] == "float16"),
mxfp6_matmul=mxfp6_matmul,
Expand Down
131 changes: 99 additions & 32 deletions QEfficient/transformers/models/qwen3_vl/_embedding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,7 @@ def _collect_contexts(self, inputs: List[Dict[str, Any]]):

return contexts, max_prompt_len, max_grid_h, max_grid_w

def get_compile_specs(
self, inputs: List[Dict[str, Any]], ctx_len: int, prefill_seq_len: int = None
) -> Dict[str, int]:
def get_compile_specs(self, inputs: List[Dict[str, Any]], prefill_seq_len: int = None) -> Dict[str, int]:
"""Compute compile-time spec values for the current input batch."""
_, max_prompt_len, max_grid_h, max_grid_w = self._collect_contexts(inputs)
if max_prompt_len == 0:
Expand All @@ -275,9 +273,10 @@ def get_compile_specs(
height = max_grid_h * patch_size
width = max_grid_w * patch_size

# ctx_len == prefill_seq_len always: embedding is single-shot prefill, no decode steps.
return {
"prefill_seq_len": target_prefill_seq_len,
"ctx_len": int(ctx_len),
"ctx_len": target_prefill_seq_len,
"img_size": max(height, width),
"height": height,
"width": width,
Expand Down Expand Up @@ -352,17 +351,71 @@ def _run_ai100_prefill(
embedding_output = embedding_output.reshape(embedding_output.shape[0], -1)
return embedding_output

@staticmethod
def _run_ai100_single_qpc_prefill(
    prepared_inputs: Dict[str, torch.Tensor],
    qpc_path,
) -> np.ndarray:
    """Execute single-QPC (vision+language fused) prefill and return the embedding row.

    Args:
        prepared_inputs: Mapping that must contain ``"input_ids"`` and
            ``"position_ids"`` tensors and may contain ``"pixel_values"``.
            The prefill length is taken from the last dimension of
            ``position_ids``.
        qpc_path: Location of the compiled QPC; it is converted with ``str()``
            before being handed to ``QAICInferenceSession``.

    Returns:
        The ``"embedding_output"`` array produced by the session. Outputs with
        more than two dimensions are flattened to ``(shape[0], -1)``.

    Raises:
        KeyError: If the session outputs do not contain ``"embedding_output"``.
    """
    prefill_len = prepared_inputs["position_ids"].shape[-1]

    # Bring input_ids to exactly `prefill_len` columns: right-pad when short,
    # truncate otherwise.
    input_ids = prepared_inputs["input_ids"]
    missing = prefill_len - input_ids.shape[1]
    if missing > 0:
        pad = torch.full(
            (input_ids.shape[0], missing),
            1,  # NOTE(review): literal fill value 1 — presumably the pad token id; confirm.
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        input_ids = torch.cat([input_ids, pad], dim=1)
    else:
        input_ids = input_ids[:, :prefill_len]

    position_ids = prepared_inputs["position_ids"][..., :prefill_len]

    session = QAICInferenceSession(str(qpc_path))
    try:
        run_inputs = {
            "input_ids": input_ids.detach().cpu().numpy().astype(np.int64),
            "position_ids": position_ids.detach().cpu().numpy().astype(np.int64),
            "image_idx": np.zeros((1, 1), dtype=np.int64),
        }

        if "pixel_values" in prepared_inputs:
            run_inputs["pixel_values"] = (
                prepared_inputs["pixel_values"].detach().cpu().numpy().astype(np.float16)
            )
        else:
            # No image supplied: feed zeros shaped like the session's
            # pixel_values binding.
            pv_idx = session.binding_index_map["pixel_values"]
            run_inputs["pixel_values"] = np.zeros(session.bindings[pv_idx].dims, dtype=np.float16)

        # Every "past_*" input is fed as zeros shaped like its binding.
        for name in session.input_names:
            if name.startswith("past_"):
                idx = session.binding_index_map[name]
                run_inputs[name] = np.zeros(session.bindings[idx].dims, dtype=np.float16)

        outputs = session.run(run_inputs)
    finally:
        # Always release the session, including when a binding lookup or the
        # run itself raises; otherwise the session would be left active.
        session.deactivate()

    if "embedding_output" not in outputs:
        raise KeyError(
            "Missing 'embedding_output' in single-QPC outputs. Ensure export_embedding is enabled in qaic_config."
        )

    embedding_output = outputs["embedding_output"]
    if embedding_output.ndim > 2:
        embedding_output = embedding_output.reshape(embedding_output.shape[0], -1)
    return embedding_output

def process(
self,
inputs: List[Dict[str, Any]],
qpc_paths: Dict[str, str],
qpc_paths,
prefill_seq_len: int,
normalize: bool = True,
) -> torch.Tensor:
"""Run AI100 embedding generation for all inputs and return stacked rows."""
if "vision_qpc_path" not in qpc_paths or "lang_qpc_path" not in qpc_paths:
raise ValueError("qpc_paths must contain 'vision_qpc_path' and 'lang_qpc_path'.")
"""Run AI100 embedding generation for all inputs and return stacked rows.

Supports both dual-QPC (qpc_paths is a dict with 'vision_qpc_path' and
'lang_qpc_path') and single-QPC (qpc_paths is a str/Path to the combined QPC).
"""
contexts, max_prompt_len, _, _ = self._collect_contexts(inputs)
if max_prompt_len == 0:
return torch.empty((0, 0), dtype=torch.float32)
Expand All @@ -374,7 +427,6 @@ def process(
)

prepared_contexts = []
vision_template = None
for ctx in contexts:
prepared_inputs, _ = self._prepare_qeff_inputs(
qeff_model=self.model,
Expand All @@ -383,32 +435,47 @@ def process(
)
prepared_contexts.append({"prepared_inputs": prepared_inputs})

if vision_template is None and "pixel_values" in prepared_inputs and "image_grid_thw" in prepared_inputs:
vision_template = self._run_ai100_vision(
vision_qpc_path=qpc_paths["vision_qpc_path"],
prepared_inputs=prepared_inputs,
)

if vision_template is None:
raise ValueError("At least one input with an image is required to initialize AI100 vision buffers.")

is_dual_qpc = isinstance(qpc_paths, dict)
embedding_rows = []
for ctx in prepared_contexts:
prepared_inputs = ctx["prepared_inputs"]
if "pixel_values" in prepared_inputs and "image_grid_thw" in prepared_inputs:
vision_outputs = self._run_ai100_vision(
vision_qpc_path=qpc_paths["vision_qpc_path"],

if is_dual_qpc:
if "vision_qpc_path" not in qpc_paths or "lang_qpc_path" not in qpc_paths:
raise ValueError("qpc_paths must contain 'vision_qpc_path' and 'lang_qpc_path'.")

vision_template = None
for ctx in prepared_contexts:
if vision_template is None and "pixel_values" in ctx["prepared_inputs"]:
vision_template = self._run_ai100_vision(
vision_qpc_path=qpc_paths["vision_qpc_path"],
prepared_inputs=ctx["prepared_inputs"],
)

if vision_template is None:
raise ValueError("At least one input with an image is required to initialize AI100 vision buffers.")

for ctx in prepared_contexts:
prepared_inputs = ctx["prepared_inputs"]
if "pixel_values" in prepared_inputs and "image_grid_thw" in prepared_inputs:
vision_outputs = self._run_ai100_vision(
vision_qpc_path=qpc_paths["vision_qpc_path"],
prepared_inputs=prepared_inputs,
)
else:
vision_outputs = self._zero_vision_outputs(vision_template)

embedding_output = self._run_ai100_prefill(
prepared_inputs=prepared_inputs,
vision_outputs=vision_outputs,
lang_qpc_path=qpc_paths["lang_qpc_path"],
)
else:
vision_outputs = self._zero_vision_outputs(vision_template)

embedding_output = self._run_ai100_prefill(
prepared_inputs=prepared_inputs,
vision_outputs=vision_outputs,
lang_qpc_path=qpc_paths["lang_qpc_path"],
)
embedding_rows.append(torch.from_numpy(embedding_output).to(torch.float32))
embedding_rows.append(torch.from_numpy(embedding_output).to(torch.float32))
else:
# Single QPC: vision + language fused in one compiled binary.
for ctx in prepared_contexts:
embedding_output = self._run_ai100_single_qpc_prefill(
prepared_inputs=ctx["prepared_inputs"], qpc_path=qpc_paths
)
embedding_rows.append(torch.from_numpy(embedding_output).to(torch.float32))

embeddings = torch.cat(embedding_rows, dim=0)
if normalize:
Expand Down
Loading
Loading