diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 3bb05c7b9..2285ceae1 100755 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -36,6 +36,7 @@ ReplicateKVHeadTransform, ) from QEfficient.utils import ( + align_kv_input_names_to_retained_outputs, constants, create_json, create_model_params, @@ -372,6 +373,16 @@ def _resolve_pkv_layers(pkv_obj): else: input_names.append(param) + # When retained-state outputs carry an injected KV-cache prefix + # (past_key.0__RetainedState), rename the matching KV inputs (past_key.0 -> + # past_key.0_) so the compiler pairs and retains them, and carry the dynamic axes over + # to the renamed inputs. No-op without a prefix. + aligned_input_names = align_kv_input_names_to_retained_outputs(input_names, output_names) + if aligned_input_names != input_names: + rename_map = {old: new for old, new in zip(input_names, aligned_input_names) if old != new} + dynamic_axes = {rename_map.get(k, k): v for k, v in dynamic_axes.items()} + input_names = aligned_input_names + try: torch.onnx.export( self.model, @@ -431,6 +442,7 @@ def get_onnx_path( retain_full_kv: Optional[bool] = False, qaic_config: Optional[dict] = None, moe_prefill_packed_chunk_size: Optional[int] = None, + kv_cache_prefix: Optional[str] = None, **compiler_options, ): kwargs = { @@ -438,6 +450,8 @@ def get_onnx_path( "use_onnx_subfunctions": use_onnx_subfunctions, "retain_full_kv": retain_full_kv, } + if kv_cache_prefix: + kwargs["kv_cache_prefix"] = kv_cache_prefix if prefill_only: kwargs.update( @@ -709,6 +723,7 @@ def _compile( retain_full_kv: Optional[bool] = None, qaic_config: Optional[dict] = None, specialization_module_name: Optional[str] = None, + kv_cache_prefix: Optional[str] = None, **compiler_options, ) -> str: """ @@ -754,6 +769,7 @@ def _compile( num_devices=mdp_ts_num_devices, qaic_config=qaic_config, moe_prefill_packed_chunk_size=moe_prefill_packed_chunk_size, + kv_cache_prefix=kv_cache_prefix, **compiler_options, ) onnx_path = Path(onnx_path) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 65b89d274..aa7d99b8b 100755 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -73,8 +73,10 @@ Mxfp4GptOssExpertDequantizeTransform, ) from QEfficient.utils import ( + apply_kv_cache_prefix, constants, get_padding_shape_from_config, + validate_kv_cache_prefix, ) from QEfficient.utils.check_ccl_specializations import process_ccl_specializations from QEfficient.utils.logging_utils import logger @@ -1357,6 +1359,7 @@ def export( prefill_seq_len: Optional[int] = None, prefill_only: bool = False, enable_chunking: bool = False, + kv_cache_prefix: Optional[str] = None, **kwargs, ) -> str: """ @@ -1402,6 +1405,8 @@ def export( kv_offload=True, comp_ctx_lengths=self.comp_ctx_lengths_decode ) output_names = self.model.get_output_names(kv_offload=True) + # Prefix only the language-side KV-cache retained buffers (vision buffers are untouched). + output_names = apply_kv_cache_prefix(output_names, validate_kv_cache_prefix(kv_cache_prefix)) if self.lang_model.model.qaic_config is not None and self.lang_model.model.qaic_config.get( "include_sampler", False ): @@ -1506,6 +1511,7 @@ def compile( prefill_only=None, enable_chunking=False, qaic_config: Optional[dict] = None, + kv_cache_prefix: Optional[str] = None, **compiler_options, ) -> str: """ @@ -1577,7 +1583,11 @@ def compile( # Infer kv_cache_batch_size if not provided kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size + kv_cache_prefix = validate_kv_cache_prefix(kv_cache_prefix) output_names = self.model.get_output_names(kv_offload=True) + # Prefix only the language-side KV-cache retained buffers (vision buffers are untouched) so the + # derived custom_io_lang keys match the prefixed names written into the exported graph. + output_names = apply_kv_cache_prefix(output_names, kv_cache_prefix) # if ccl_enabled is True read Compute-Context-Length lists if self.ccl_enabled: @@ -1643,6 +1653,7 @@ def compile( prefill_only=prefill_only, enable_chunking=enable_chunking, prefill_seq_len=prefill_seq_len, + kv_cache_prefix=kv_cache_prefix, ) if hasattr(self.model, "generate_npi_file") and "node_precision_info" in compiler_options: @@ -2270,6 +2281,7 @@ def export( prefill_seq_len: Optional[int] = None, prefill_only: bool = False, enable_chunking: bool = False, + kv_cache_prefix: Optional[str] = None, **kwargs, ) -> str: """ @@ -2302,6 +2314,8 @@ def export( inputs = self.model.get_dummy_inputs(comp_ctx_lengths=self.comp_ctx_lengths_decode) dynamic_axes = self.model.get_onnx_dynamic_axes(comp_ctx_lengths=self.comp_ctx_lengths_decode) output_names = self.model.get_output_names() + # Prefix only the LLM KV-cache retained buffers (vision/multimodal buffers untouched). + output_names = apply_kv_cache_prefix(output_names, validate_kv_cache_prefix(kv_cache_prefix)) return self._export( inputs, output_names=output_names, @@ -2330,6 +2344,7 @@ def compile( num_speculative_tokens: Optional[int] = None, use_onnx_subfunctions: bool = False, qaic_config: Optional[dict] = None, + kv_cache_prefix: Optional[str] = None, **compiler_options, ) -> str: """ @@ -2388,7 +2403,11 @@ def compile( # Infer kv_cache_batch_size if not provided kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size + kv_cache_prefix = validate_kv_cache_prefix(kv_cache_prefix) output_names = self.model.get_output_names() + # Prefix only the LLM KV-cache retained buffers so the derived custom_io (and the names baked + # into the exported graph) stay consistent; vision/multimodal buffers are untouched. + output_names = apply_kv_cache_prefix(output_names, kv_cache_prefix) # if ccl_enabled is True read Compute-Context-Length lists if self.ccl_enabled: @@ -2462,6 +2481,7 @@ def compile( aic_num_cores=num_cores, mxint8_kv_cache=mxint8_kv_cache, use_onnx_subfunctions=use_onnx_subfunctions, + kv_cache_prefix=kv_cache_prefix, **compiler_options, ) return self.qpc_path @@ -3241,6 +3261,7 @@ def export( prefill_seq_len: Optional[int] = None, num_cores: int = constants.DEFAULT_AIC_NUM_CORES, moe_prefill_packed_chunk_size: int = constants.MOE_PREFILL_PACKED_CHUNK_SIZE, + kv_cache_prefix: Optional[str] = None, **kwargs, ) -> str: """ @@ -3506,6 +3527,14 @@ def _legacyify_cache(obj): self.model.forward = _qeff_patched_forward self.model._qeff_export_gemma3_cache_patch = True + # Optionally inject a user-provided infix token into the LLM KV-cache retained-state names + # (e.g. past_key.0_RetainedState -> past_key.0__RetainedState) so downstream consumers + # (vLLM disaggregated serving) can regex-select only the LLM KV buffers for transfer. + kv_cache_prefix = validate_kv_cache_prefix(kv_cache_prefix) + if kv_cache_prefix: + output_names = apply_kv_cache_prefix(output_names, kv_cache_prefix) + self.hash_params["kv_cache_prefix"] = kv_cache_prefix + if os.environ.get("LAYERWISE_EXPORT", "False") == "True": return self._export_layerwise( example_inputs, @@ -3678,6 +3707,7 @@ def compile( enable_chunking: Optional[bool] = False, moe_prefill_packed_chunk_size: int = constants.MOE_PREFILL_PACKED_CHUNK_SIZE, retain_full_kv: Optional[bool] = None, + kv_cache_prefix: Optional[str] = None, **compiler_options, ) -> str: """ @@ -3956,17 +3986,22 @@ def compile( target_dtype = getattr(self.model.config, "torch_dtype", torch.float32) kv_cache_dtype = "mxint8" if mxint8_kv_cache else CUSTOM_IO_DTYPE_MAP[target_dtype] # --- Compilation --- + # When a KV-cache prefix is requested, the exported buffers are named + # past_key.{i}_ (input) and past_key.{i}__RetainedState (output); the custom_io + # keys must match those names so the compiler pairs and retains them correctly. + kv_cache_prefix = validate_kv_cache_prefix(kv_cache_prefix) + kv_infix = f"_{kv_cache_prefix}" if kv_cache_prefix else "" custom_io = {} if not cache_compressed: for suffix in ["", "_RetainedState"]: for i in range(self.num_layers): for kv in ["key", "value"]: - custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype + custom_io[f"past_{kv}.{i}{kv_infix}{suffix}"] = kv_cache_dtype else: for suffix in ["", "_RetainedState"]: for i in range(self.num_layers): - custom_io[f"compressed_kv.{i}{suffix}"] = kv_cache_dtype - custom_io[f"k_pe.{i}{suffix}"] = kv_cache_dtype + custom_io[f"compressed_kv.{i}{kv_infix}{suffix}"] = kv_cache_dtype + custom_io[f"k_pe.{i}{kv_infix}{suffix}"] = kv_cache_dtype def filter_custom_io(custom_io_lang, onnx_path): # Extract filename @@ -4017,6 +4052,7 @@ def filter_custom_io(custom_io_lang, onnx_path): offload_pt_weights=offload_pt_weights, enable_chunking=enable_chunking, retain_full_kv=retain_full_kv, + kv_cache_prefix=kv_cache_prefix, **compiler_options, ) diff --git a/QEfficient/utils/__init__.py b/QEfficient/utils/__init__.py index c1c8fd777..278076fe6 100755 --- a/QEfficient/utils/__init__.py +++ b/QEfficient/utils/__init__.py @@ -11,6 +11,8 @@ ) from QEfficient.utils._utils import ( # noqa: F401 LRUCache, + align_kv_input_names_to_retained_outputs, + apply_kv_cache_prefix, check_and_assign_cache_dir, create_json, create_model_params, @@ -37,6 +39,7 @@ qpc_exists, require_value, to_named_specializations, + validate_kv_cache_prefix, ) from QEfficient.utils.compile_layerwise import ( # noqa: F401 run_compile_layerwise, diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index 24ab88aa0..5d4a6ada7 100755 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -30,6 +30,102 @@ from QEfficient.utils.hash_utils import json_serializable from QEfficient.utils.logging_utils import logger +# Retained-state buffer name stems that correspond to the LLM KV cache. Only these are eligible for +# the optional vLLM KV-cache prefix; vision/multimodal retained buffers (vision_embeds, pixel_values, +# image_idx, deepstack_features, ...) are intentionally excluded. +_KV_RETAINED_STEMS = ("past_key.", "past_value.", "compressed_kv.", "k_pe.") +_RETAINED_STATE_SUFFIX = "_RetainedState" + + +def validate_kv_cache_prefix(kv_cache_prefix: Optional[str]) -> Optional[str]: + """ + Validate the optional KV-cache buffer-name prefix. + + The prefix is injected as an infix token into KV retained-state names + (``past_key.0_RetainedState`` -> ``past_key.0__RetainedState``), so it must be a plain + alphanumeric token. Disallowing ``.`` and ``_`` keeps the ``past_key.{layer}_{prefix}`` structure + unambiguous for downstream regex matching. + + Returns the prefix unchanged when valid, or ``None`` when not provided. + """ + if kv_cache_prefix is None: + return None + if not isinstance(kv_cache_prefix, str) or not kv_cache_prefix.isalnum(): + raise ValueError( + "kv_cache_prefix must be a non-empty alphanumeric string (no '.', '_' or whitespace); " + f"got {kv_cache_prefix!r}" + ) + return kv_cache_prefix + + +def _infix_kv_prefix(name: str, kv_cache_prefix: str) -> str: + """Insert ``_`` before the ``_RetainedState`` suffix for LLM KV-cache buffers only.""" + if not name.endswith(_RETAINED_STATE_SUFFIX): + return name + stem = name[: -len(_RETAINED_STATE_SUFFIX)] + if not any(stem.startswith(kv_stem) for kv_stem in _KV_RETAINED_STEMS): + return name + return f"{stem}_{kv_cache_prefix}{_RETAINED_STATE_SUFFIX}" + + +def apply_kv_cache_prefix(output_names, kv_cache_prefix: Optional[str]): + """ + Insert an infix token into LLM KV-cache retained-state output names. + + ``past_key.0_RetainedState`` -> ``past_key.0__RetainedState`` (and likewise for + ``past_value`` / ``compressed_kv`` / ``k_pe``). The matching device input buffer is named by the + compiler by stripping ``_RetainedState`` (``past_key.0_``), so KV retention pairing is + preserved. Vision/multimodal retained buffers are left untouched. + + Accepts either a flat ``List[str]`` (CausalLM / single-QPC VLM) or the + ``{"vision": [...], "lang": [...]}`` dict (dual-QPC VLM); for the dict form only the ``lang`` list + is rewritten. No-op when ``kv_cache_prefix`` is falsy. The input is not mutated in place. + """ + if not kv_cache_prefix: + return output_names + validate_kv_cache_prefix(kv_cache_prefix) + + if isinstance(output_names, dict): + result = dict(output_names) + if result.get("lang") is not None: + result["lang"] = [_infix_kv_prefix(name, kv_cache_prefix) for name in result["lang"]] + return result + return [_infix_kv_prefix(name, kv_cache_prefix) for name in output_names] + + +def align_kv_input_names_to_retained_outputs(input_names, output_names): + """ + Rename KV-cache *input* buffers so each pairs with its retained-state *output*. + + The AIC compiler retains a KV buffer by matching an output ``X_RetainedState`` to the input named + ``X`` (suffix stripped). When the retained outputs carry an injected prefix + (``past_key.0__RetainedState``), the corresponding input must be renamed from + ``past_key.0`` to ``past_key.0_`` for the pairing to hold. + + This derives the rename purely from ``output_names`` (which already carry any prefix), so callers + that build prefixed outputs do not need to thread the prefix separately. It is a no-op for inputs + that already match a retained output exactly, and for non-KV inputs. ``input_names`` is not mutated. + """ + # Stripped target names from retained KV outputs, e.g. {"past_key.0_VLLM", "past_value.0_VLLM"}. + retained_targets = [] + for name in output_names: + if not name.endswith(_RETAINED_STATE_SUFFIX): + continue + stem = name[: -len(_RETAINED_STATE_SUFFIX)] + if any(stem.startswith(kv_stem) for kv_stem in _KV_RETAINED_STEMS): + retained_targets.append(stem) + retained_set = set(retained_targets) + + aligned = [] + for name in input_names: + if not any(name.startswith(stem) for stem in _KV_RETAINED_STEMS) or name in retained_set: + aligned.append(name) + continue + # Find a retained target that is this input with an extra "_" infix. + match = next((t for t in retained_targets if t == name or t.startswith(name + "_")), None) + aligned.append(match if match is not None else name) + return aligned + class LRUCache: """Simple LRU cache with size limit for vision outputs""" diff --git a/tests/unit_test/models/test_model_quickcheck.py b/tests/unit_test/models/test_model_quickcheck.py index 4b7ed6f17..a0594019d 100644 --- a/tests/unit_test/models/test_model_quickcheck.py +++ b/tests/unit_test/models/test_model_quickcheck.py @@ -1282,3 +1282,249 @@ def test_no_tag_falls_back_to_lm_rules(self): result = to_named_specializations(flat) assert result[0]["name"] == "Prefill" assert result[1]["name"] == "Decode" + + +# --------------------------------------------------------------------------- +# Tests for the optional KV-cache buffer-name prefix (vLLM disaggregated KV transfer) +# --------------------------------------------------------------------------- + + +def _retained_state_outputs(onnx_path: Path) -> Set[str]: + onnx_model = onnx.load(onnx_path, load_external_data=False) + return {out.name for out in onnx_model.graph.output if out.name.endswith("_RetainedState")} + + +def _kv_input_names(onnx_path: Path) -> Set[str]: + onnx_model = onnx.load(onnx_path, load_external_data=False) + return { + inp.name + for inp in onnx_model.graph.input + if inp.name.startswith(("past_key.", "past_value.", "compressed_kv.", "k_pe.")) + } + + +class TestApplyKvCachePrefixHelper: + """Unit tests for the apply_kv_cache_prefix / validate_kv_cache_prefix helpers.""" + + def test_flat_list_kv_only(self): + from QEfficient.utils import apply_kv_cache_prefix + + names = ["logits", "past_key.0_RetainedState", "past_value.0_RetainedState", "vision_embeds_RetainedState"] + result = apply_kv_cache_prefix(names, "VLLM") + assert result == [ + "logits", + "past_key.0_VLLM_RetainedState", + "past_value.0_VLLM_RetainedState", + "vision_embeds_RetainedState", # vision buffer untouched + ] + + def test_compressed_kv_and_k_pe(self): + from QEfficient.utils import apply_kv_cache_prefix + + names = ["compressed_kv.0_RetainedState", "k_pe.0_RetainedState"] + assert apply_kv_cache_prefix(names, "P") == ["compressed_kv.0_P_RetainedState", "k_pe.0_P_RetainedState"] + + def test_dict_form_lang_only(self): + from QEfficient.utils import apply_kv_cache_prefix + + names = { + "vision": ["vision_embeds", "past_key.0_RetainedState"], # vision side never rewritten + "lang": ["logits", "vision_embeds_RetainedState", "past_key.0_RetainedState"], + } + result = apply_kv_cache_prefix(names, "VLLM") + assert result["vision"] == ["vision_embeds", "past_key.0_RetainedState"] + assert result["lang"] == ["logits", "vision_embeds_RetainedState", "past_key.0_VLLM_RetainedState"] + + def test_noop_when_prefix_absent(self): + from QEfficient.utils import apply_kv_cache_prefix + + names = ["logits", "past_key.0_RetainedState"] + assert apply_kv_cache_prefix(names, None) == names + assert apply_kv_cache_prefix(names, "") == names + + def test_align_inputs_to_retained_outputs(self): + from QEfficient.utils import align_kv_input_names_to_retained_outputs + + input_names = ["input_ids", "past_key.0", "past_value.0"] + output_names = ["logits", "past_key.0_VLLM_RetainedState", "past_value.0_VLLM_RetainedState"] + assert align_kv_input_names_to_retained_outputs(input_names, output_names) == [ + "input_ids", + "past_key.0_VLLM", + "past_value.0_VLLM", + ] + + def test_align_inputs_noop_without_prefix(self): + from QEfficient.utils import align_kv_input_names_to_retained_outputs + + input_names = ["input_ids", "past_key.0", "past_value.0"] + output_names = ["logits", "past_key.0_RetainedState", "past_value.0_RetainedState"] + assert align_kv_input_names_to_retained_outputs(input_names, output_names) == input_names + + @pytest.mark.parametrize("bad", ["", "a_b", "a.b", "a b", 123, "past-key"]) + def test_validation_rejects_bad_prefix(self, bad): + from QEfficient.utils import validate_kv_cache_prefix + + with pytest.raises(ValueError): + validate_kv_cache_prefix(bad) + + def test_validation_accepts_alnum_and_none(self): + from QEfficient.utils import validate_kv_cache_prefix + + assert validate_kv_cache_prefix("VLLM") == "VLLM" + assert validate_kv_cache_prefix("vllm0") == "vllm0" + assert validate_kv_cache_prefix(None) is None + + +@pytest.mark.llm_model +def test_causal_export_with_kv_cache_prefix(tmp_path): + model_id = CAUSAL_RUNTIME_MODEL_IDS["gpt2"] + model_hf = AutoModelForCausalLM.from_pretrained( + model_id, **MODEL_KWARGS, low_cpu_mem_usage=False, torch_dtype=torch.float32 + ) + model_hf.eval() + qeff_model = QEFFAutoModelForCausalLM(model_hf) + + onnx_path = _exported_onnx_path(qeff_model.export(tmp_path / "prefixed", kv_cache_prefix="VLLM")) + + retained = _retained_state_outputs(onnx_path) + assert retained, "expected retained-state outputs" + # Every KV retained output carries the infix; the suffix is preserved. + kv_retained = {name for name in retained if name.startswith(("past_key.", "past_value."))} + assert kv_retained + assert all(name.endswith("_VLLM_RetainedState") for name in kv_retained) + + # The matching device input buffer exists (output minus _RetainedState). + kv_inputs = _kv_input_names(onnx_path) + for out_name in kv_retained: + stripped = out_name[: -len("_RetainedState")] + assert stripped in kv_inputs, f"missing paired input buffer for {out_name}" + assert stripped.endswith("_VLLM") + + +@pytest.mark.llm_model +def test_causal_export_default_names_unchanged(tmp_path): + """Without the flag, retained-state names must remain byte-for-byte identical to today.""" + model_id = CAUSAL_RUNTIME_MODEL_IDS["gpt2"] + model_hf = AutoModelForCausalLM.from_pretrained( + model_id, **MODEL_KWARGS, low_cpu_mem_usage=False, torch_dtype=torch.float32 + ) + model_hf.eval() + qeff_model = QEFFAutoModelForCausalLM(model_hf) + + onnx_path = _exported_onnx_path(qeff_model.export(tmp_path / "default")) + retained = _retained_state_outputs(onnx_path) + kv_retained = {name for name in retained if name.startswith(("past_key.", "past_value."))} + assert kv_retained + assert all(name.endswith("_RetainedState") and "_VLLM_" not in name for name in kv_retained) + # Inputs use the plain names. + assert "past_key.0" in _kv_input_names(onnx_path) + + +@pytest.mark.llm_model +def test_causal_export_prefix_changes_hash_dir(tmp_path): + """Prefixed and unprefixed exports must land in distinct hashed dirs (no cache collision).""" + model_id = CAUSAL_RUNTIME_MODEL_IDS["gpt2"] + model_hf = AutoModelForCausalLM.from_pretrained( + model_id, **MODEL_KWARGS, low_cpu_mem_usage=False, torch_dtype=torch.float32 + ) + model_hf.eval() + + plain_path = _exported_onnx_path(QEFFAutoModelForCausalLM(model_hf).export(tmp_path / "p")) + + model_hf2 = AutoModelForCausalLM.from_pretrained( + model_id, **MODEL_KWARGS, low_cpu_mem_usage=False, torch_dtype=torch.float32 + ) + model_hf2.eval() + prefixed_path = _exported_onnx_path( + QEFFAutoModelForCausalLM(model_hf2).export(tmp_path / "p", kv_cache_prefix="VLLM") + ) + + assert plain_path.parent != prefixed_path.parent + + +@pytest.mark.llm_model +def test_causal_compile_custom_io_carries_prefix(tmp_path, monkeypatch): + """The compile custom_io must pair prefixed input/output KV buffers.""" + try: + qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + CAUSAL_RUNTIME_MODEL_IDS["llama"], + num_hidden_layers=2, + **MODEL_KWARGS, + ) + except Exception as exc: + _skip_on_model_fetch_error(exc, CAUSAL_RUNTIME_MODEL_IDS["llama"]) + + captured = {} + + def fake_compile(**kwargs): + captured.update(kwargs) + return tmp_path / "qpc" + + monkeypatch.setattr(qeff_model, "_compile", fake_compile) + qeff_model.compile(prefill_seq_len=8, ctx_len=32, compile_dir=str(tmp_path), kv_cache_prefix="VLLM") + + custom_io = captured["custom_io"] + assert custom_io, "expected non-empty custom_io" + assert "past_key.0_VLLM" in custom_io # input buffer + assert "past_key.0_VLLM_RetainedState" in custom_io # paired retained output + assert all("_VLLM" in k for k in custom_io if k.startswith(("past_key.", "past_value."))) + assert captured.get("kv_cache_prefix") == "VLLM" + + # Critical safety check: kv_cache_prefix must NOT appear in compiler_options. + # The _compile signature has kv_cache_prefix as an explicit named param so Python never places + # it in **compiler_options — if it did, the compiler would see "-kv-cache-prefix=VLLM" and fail. + # We verify by reconstructing the known explicit params and confirming the remainder (what would + # become **compiler_options in the real _compile) does not contain kv_cache_prefix. + _known_explicit_params = { + "onnx_path", + "compile_dir", + "mxint8_kv_cache", + "specializations", + "custom_io", + "mdp_ts_num_devices", + "num_speculative_tokens", + "enable_qnn", + "qnn_config", + "use_onnx_subfunctions", + "prefill_only", + "offload_pt_weights", + "enable_chunking", + "retain_full_kv", + "qaic_config", + "specialization_module_name", + "kv_cache_prefix", + "retained_state", + "convert_to_fp16", + "mxfp6_matmul", + # compile-time args added by causal compile(): + "aic_num_cores", + "moe_prefill_packed_chunk_size", + } + implicit_compiler_options = {k: v for k, v in captured.items() if k not in _known_explicit_params} + assert "kv_cache_prefix" not in implicit_compiler_options, ( + "kv_cache_prefix leaked into compiler_options — would produce an invalid compiler flag" + ) + + +@pytest.mark.llm_model +def test_vlm_export_prefix_lang_only(tmp_path): + """VLM export with prefix: lang KV buffers prefixed, vision retained buffers untouched.""" + try: + vlm_model = QEFFAutoModelForImageTextToText.from_pretrained( + VLM_TEXT_RUNTIME_MODEL_ID, trust_remote_code=True, kv_offload=True + ) + except Exception as exc: + _skip_on_model_fetch_error(exc, VLM_TEXT_RUNTIME_MODEL_ID) + + vlm_model.export(tmp_path / "vlm-prefixed", kv_cache_prefix="VLLM") + lang_onnx = Path(vlm_model.lang_model.onnx_path) + retained = _retained_state_outputs(lang_onnx) + + kv_retained = {name for name in retained if name.startswith(("past_key.", "past_value."))} + assert kv_retained + assert all(name.endswith("_VLLM_RetainedState") for name in kv_retained) + + # Vision/multimodal retained buffers on the lang graph must NOT be prefixed. + for name in retained: + if name.startswith(("vision_embeds", "pixel_values", "deepstack_features")): + assert "_VLLM_" not in name