Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
ReplicateKVHeadTransform,
)
from QEfficient.utils import (
align_kv_input_names_to_retained_outputs,
constants,
create_json,
create_model_params,
Expand Down Expand Up @@ -372,6 +373,16 @@ def _resolve_pkv_layers(pkv_obj):
else:
input_names.append(param)

# When retained-state outputs carry an injected KV-cache prefix
# (past_key.0_<prefix>_RetainedState), rename the matching KV inputs (past_key.0 ->
# past_key.0_<prefix>) so the compiler pairs and retains them, and carry the dynamic axes over
# to the renamed inputs. No-op without a prefix.
aligned_input_names = align_kv_input_names_to_retained_outputs(input_names, output_names)
if aligned_input_names != input_names:
rename_map = {old: new for old, new in zip(input_names, aligned_input_names) if old != new}
dynamic_axes = {rename_map.get(k, k): v for k, v in dynamic_axes.items()}
input_names = aligned_input_names

try:
torch.onnx.export(
self.model,
Expand Down Expand Up @@ -431,13 +442,16 @@ def get_onnx_path(
retain_full_kv: Optional[bool] = False,
qaic_config: Optional[dict] = None,
moe_prefill_packed_chunk_size: Optional[int] = None,
kv_cache_prefix: Optional[str] = None,
**compiler_options,
):
kwargs = {
"offload_pt_weights": offload_pt_weights,
"use_onnx_subfunctions": use_onnx_subfunctions,
"retain_full_kv": retain_full_kv,
}
if kv_cache_prefix:
kwargs["kv_cache_prefix"] = kv_cache_prefix

if prefill_only:
kwargs.update(
Expand Down Expand Up @@ -709,6 +723,7 @@ def _compile(
retain_full_kv: Optional[bool] = None,
qaic_config: Optional[dict] = None,
specialization_module_name: Optional[str] = None,
kv_cache_prefix: Optional[str] = None,
**compiler_options,
) -> str:
"""
Expand Down Expand Up @@ -754,6 +769,7 @@ def _compile(
num_devices=mdp_ts_num_devices,
qaic_config=qaic_config,
moe_prefill_packed_chunk_size=moe_prefill_packed_chunk_size,
kv_cache_prefix=kv_cache_prefix,
**compiler_options,
)
onnx_path = Path(onnx_path)
Expand Down
42 changes: 39 additions & 3 deletions QEfficient/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,10 @@
Mxfp4GptOssExpertDequantizeTransform,
)
from QEfficient.utils import (
apply_kv_cache_prefix,
constants,
get_padding_shape_from_config,
validate_kv_cache_prefix,
)
from QEfficient.utils.check_ccl_specializations import process_ccl_specializations
from QEfficient.utils.logging_utils import logger
Expand Down Expand Up @@ -1357,6 +1359,7 @@ def export(
prefill_seq_len: Optional[int] = None,
prefill_only: bool = False,
enable_chunking: bool = False,
kv_cache_prefix: Optional[str] = None,
**kwargs,
) -> str:
"""
Expand Down Expand Up @@ -1402,6 +1405,8 @@ def export(
kv_offload=True, comp_ctx_lengths=self.comp_ctx_lengths_decode
)
output_names = self.model.get_output_names(kv_offload=True)
# Prefix only the language-side KV-cache retained buffers (vision buffers are untouched).
output_names = apply_kv_cache_prefix(output_names, validate_kv_cache_prefix(kv_cache_prefix))
if self.lang_model.model.qaic_config is not None and self.lang_model.model.qaic_config.get(
"include_sampler", False
):
Expand Down Expand Up @@ -1506,6 +1511,7 @@ def compile(
prefill_only=None,
enable_chunking=False,
qaic_config: Optional[dict] = None,
kv_cache_prefix: Optional[str] = None,
**compiler_options,
) -> str:
"""
Expand Down Expand Up @@ -1577,7 +1583,11 @@ def compile(
# Infer kv_cache_batch_size if not provided
kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size

kv_cache_prefix = validate_kv_cache_prefix(kv_cache_prefix)
output_names = self.model.get_output_names(kv_offload=True)
# Prefix only the language-side KV-cache retained buffers (vision buffers are untouched) so the
# derived custom_io_lang keys match the prefixed names written into the exported graph.
output_names = apply_kv_cache_prefix(output_names, kv_cache_prefix)

# if ccl_enabled is True read Compute-Context-Length lists
if self.ccl_enabled:
Expand Down Expand Up @@ -1643,6 +1653,7 @@ def compile(
prefill_only=prefill_only,
enable_chunking=enable_chunking,
prefill_seq_len=prefill_seq_len,
kv_cache_prefix=kv_cache_prefix,
)

if hasattr(self.model, "generate_npi_file") and "node_precision_info" in compiler_options:
Expand Down Expand Up @@ -2270,6 +2281,7 @@ def export(
prefill_seq_len: Optional[int] = None,
prefill_only: bool = False,
enable_chunking: bool = False,
kv_cache_prefix: Optional[str] = None,
**kwargs,
) -> str:
"""
Expand Down Expand Up @@ -2302,6 +2314,8 @@ def export(
inputs = self.model.get_dummy_inputs(comp_ctx_lengths=self.comp_ctx_lengths_decode)
dynamic_axes = self.model.get_onnx_dynamic_axes(comp_ctx_lengths=self.comp_ctx_lengths_decode)
output_names = self.model.get_output_names()
# Prefix only the LLM KV-cache retained buffers (vision/multimodal buffers untouched).
output_names = apply_kv_cache_prefix(output_names, validate_kv_cache_prefix(kv_cache_prefix))
return self._export(
inputs,
output_names=output_names,
Expand Down Expand Up @@ -2330,6 +2344,7 @@ def compile(
num_speculative_tokens: Optional[int] = None,
use_onnx_subfunctions: bool = False,
qaic_config: Optional[dict] = None,
kv_cache_prefix: Optional[str] = None,
**compiler_options,
) -> str:
"""
Expand Down Expand Up @@ -2388,7 +2403,11 @@ def compile(

# Infer kv_cache_batch_size if not provided
kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size
kv_cache_prefix = validate_kv_cache_prefix(kv_cache_prefix)
output_names = self.model.get_output_names()
# Prefix only the LLM KV-cache retained buffers so the derived custom_io (and the names baked
# into the exported graph) stay consistent; vision/multimodal buffers are untouched.
output_names = apply_kv_cache_prefix(output_names, kv_cache_prefix)

# if ccl_enabled is True read Compute-Context-Length lists
if self.ccl_enabled:
Expand Down Expand Up @@ -2462,6 +2481,7 @@ def compile(
aic_num_cores=num_cores,
mxint8_kv_cache=mxint8_kv_cache,
use_onnx_subfunctions=use_onnx_subfunctions,
kv_cache_prefix=kv_cache_prefix,
**compiler_options,
)
return self.qpc_path
Expand Down Expand Up @@ -3241,6 +3261,7 @@ def export(
prefill_seq_len: Optional[int] = None,
num_cores: int = constants.DEFAULT_AIC_NUM_CORES,
moe_prefill_packed_chunk_size: int = constants.MOE_PREFILL_PACKED_CHUNK_SIZE,
kv_cache_prefix: Optional[str] = None,
**kwargs,
) -> str:
"""
Expand Down Expand Up @@ -3506,6 +3527,14 @@ def _legacyify_cache(obj):
self.model.forward = _qeff_patched_forward
self.model._qeff_export_gemma3_cache_patch = True

# Optionally inject a user-provided infix token into the LLM KV-cache retained-state names
# (e.g. past_key.0_RetainedState -> past_key.0_<prefix>_RetainedState) so downstream consumers
# (vLLM disaggregated serving) can regex-select only the LLM KV buffers for transfer.
kv_cache_prefix = validate_kv_cache_prefix(kv_cache_prefix)
if kv_cache_prefix:
output_names = apply_kv_cache_prefix(output_names, kv_cache_prefix)
self.hash_params["kv_cache_prefix"] = kv_cache_prefix

if os.environ.get("LAYERWISE_EXPORT", "False") == "True":
return self._export_layerwise(
example_inputs,
Expand Down Expand Up @@ -3678,6 +3707,7 @@ def compile(
enable_chunking: Optional[bool] = False,
moe_prefill_packed_chunk_size: int = constants.MOE_PREFILL_PACKED_CHUNK_SIZE,
retain_full_kv: Optional[bool] = None,
kv_cache_prefix: Optional[str] = None,
**compiler_options,
) -> str:
"""
Expand Down Expand Up @@ -3956,17 +3986,22 @@ def compile(
target_dtype = getattr(self.model.config, "torch_dtype", torch.float32)
kv_cache_dtype = "mxint8" if mxint8_kv_cache else CUSTOM_IO_DTYPE_MAP[target_dtype]
# --- Compilation ---
# When a KV-cache prefix is requested, the exported buffers are named
# past_key.{i}_<prefix> (input) and past_key.{i}_<prefix>_RetainedState (output); the custom_io
# keys must match those names so the compiler pairs and retains them correctly.
kv_cache_prefix = validate_kv_cache_prefix(kv_cache_prefix)
kv_infix = f"_{kv_cache_prefix}" if kv_cache_prefix else ""
custom_io = {}
if not cache_compressed:
for suffix in ["", "_RetainedState"]:
for i in range(self.num_layers):
for kv in ["key", "value"]:
custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype
custom_io[f"past_{kv}.{i}{kv_infix}{suffix}"] = kv_cache_dtype
else:
for suffix in ["", "_RetainedState"]:
for i in range(self.num_layers):
custom_io[f"compressed_kv.{i}{suffix}"] = kv_cache_dtype
custom_io[f"k_pe.{i}{suffix}"] = kv_cache_dtype
custom_io[f"compressed_kv.{i}{kv_infix}{suffix}"] = kv_cache_dtype
custom_io[f"k_pe.{i}{kv_infix}{suffix}"] = kv_cache_dtype

def filter_custom_io(custom_io_lang, onnx_path):
# Extract filename
Expand Down Expand Up @@ -4017,6 +4052,7 @@ def filter_custom_io(custom_io_lang, onnx_path):
offload_pt_weights=offload_pt_weights,
enable_chunking=enable_chunking,
retain_full_kv=retain_full_kv,
kv_cache_prefix=kv_cache_prefix,
**compiler_options,
)

Expand Down
3 changes: 3 additions & 0 deletions QEfficient/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
)
from QEfficient.utils._utils import ( # noqa: F401
LRUCache,
align_kv_input_names_to_retained_outputs,
apply_kv_cache_prefix,
check_and_assign_cache_dir,
create_json,
create_model_params,
Expand All @@ -37,6 +39,7 @@
qpc_exists,
require_value,
to_named_specializations,
validate_kv_cache_prefix,
)
from QEfficient.utils.compile_layerwise import ( # noqa: F401
run_compile_layerwise,
Expand Down
96 changes: 96 additions & 0 deletions QEfficient/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,102 @@
from QEfficient.utils.hash_utils import json_serializable
from QEfficient.utils.logging_utils import logger

# Retained-state buffer name stems that correspond to the LLM KV cache. Only these are eligible for
# the optional vLLM KV-cache prefix; vision/multimodal retained buffers (vision_embeds, pixel_values,
# image_idx, deepstack_features, ...) are intentionally excluded.
_KV_RETAINED_STEMS = ("past_key.", "past_value.", "compressed_kv.", "k_pe.")
_RETAINED_STATE_SUFFIX = "_RetainedState"


def validate_kv_cache_prefix(kv_cache_prefix: Optional[str]) -> Optional[str]:
"""
Validate the optional KV-cache buffer-name prefix.

The prefix is injected as an infix token into KV retained-state names
(``past_key.0_RetainedState`` -> ``past_key.0_<prefix>_RetainedState``), so it must be a plain
alphanumeric token. Disallowing ``.`` and ``_`` keeps the ``past_key.{layer}_{prefix}`` structure
unambiguous for downstream regex matching.

Returns the prefix unchanged when valid, or ``None`` when not provided.
"""
if kv_cache_prefix is None:
return None
if not isinstance(kv_cache_prefix, str) or not kv_cache_prefix.isalnum():
raise ValueError(
"kv_cache_prefix must be a non-empty alphanumeric string (no '.', '_' or whitespace); "
f"got {kv_cache_prefix!r}"
)
return kv_cache_prefix


def _infix_kv_prefix(name: str, kv_cache_prefix: str) -> str:
"""Insert ``_<prefix>`` before the ``_RetainedState`` suffix for LLM KV-cache buffers only."""
if not name.endswith(_RETAINED_STATE_SUFFIX):
return name
stem = name[: -len(_RETAINED_STATE_SUFFIX)]
if not any(stem.startswith(kv_stem) for kv_stem in _KV_RETAINED_STEMS):
return name
return f"{stem}_{kv_cache_prefix}{_RETAINED_STATE_SUFFIX}"


def apply_kv_cache_prefix(output_names, kv_cache_prefix: Optional[str]):
"""
Insert an infix token into LLM KV-cache retained-state output names.

``past_key.0_RetainedState`` -> ``past_key.0_<prefix>_RetainedState`` (and likewise for
``past_value`` / ``compressed_kv`` / ``k_pe``). The matching device input buffer is named by the
compiler by stripping ``_RetainedState`` (``past_key.0_<prefix>``), so KV retention pairing is
preserved. Vision/multimodal retained buffers are left untouched.

Accepts either a flat ``List[str]`` (CausalLM / single-QPC VLM) or the
``{"vision": [...], "lang": [...]}`` dict (dual-QPC VLM); for the dict form only the ``lang`` list
is rewritten. No-op when ``kv_cache_prefix`` is falsy. The input is not mutated in place.
"""
if not kv_cache_prefix:
return output_names
validate_kv_cache_prefix(kv_cache_prefix)

if isinstance(output_names, dict):
result = dict(output_names)
if result.get("lang") is not None:
result["lang"] = [_infix_kv_prefix(name, kv_cache_prefix) for name in result["lang"]]
return result
return [_infix_kv_prefix(name, kv_cache_prefix) for name in output_names]


def align_kv_input_names_to_retained_outputs(input_names, output_names):
"""
Rename KV-cache *input* buffers so each pairs with its retained-state *output*.

The AIC compiler retains a KV buffer by matching an output ``X_RetainedState`` to the input named
``X`` (suffix stripped). When the retained outputs carry an injected prefix
(``past_key.0_<prefix>_RetainedState``), the corresponding input must be renamed from
``past_key.0`` to ``past_key.0_<prefix>`` for the pairing to hold.

This derives the rename purely from ``output_names`` (which already carry any prefix), so callers
that build prefixed outputs do not need to thread the prefix separately. It is a no-op for inputs
that already match a retained output exactly, and for non-KV inputs. ``input_names`` is not mutated.
"""
# Stripped target names from retained KV outputs, e.g. {"past_key.0_VLLM", "past_value.0_VLLM"}.
retained_targets = []
for name in output_names:
if not name.endswith(_RETAINED_STATE_SUFFIX):
continue
stem = name[: -len(_RETAINED_STATE_SUFFIX)]
if any(stem.startswith(kv_stem) for kv_stem in _KV_RETAINED_STEMS):
retained_targets.append(stem)
retained_set = set(retained_targets)

aligned = []
for name in input_names:
if not any(name.startswith(stem) for stem in _KV_RETAINED_STEMS) or name in retained_set:
aligned.append(name)
continue
# Find a retained target that is this input with an extra "_<prefix>" infix.
match = next((t for t in retained_targets if t == name or t.startswith(name + "_")), None)
aligned.append(match if match is not None else name)
return aligned


class LRUCache:
"""Simple LRU cache with size limit for vision outputs"""
Expand Down
Loading
Loading