From f842a30878af2c8c56d548191427e60f56b433af Mon Sep 17 00:00:00 2001
From: vbaddi
Date: Fri, 5 Jun 2026 17:38:26 +0530
Subject: [PATCH 1/8] feat(0605): Encapsulate layerwise export behind .compile()/.export() flag

Move the layerwise export+stitch+compile orchestration loop into a single
internal driver gated by a new layerwise=True kwarg on .compile() and
.export(). The flag is opt-in; layerwise=False remains the default and the
non-layerwise compile path is unchanged byte-for-byte. The LAYERWISE_EXPORT
environment variable is removed entirely; control flows purely through the
API via a process-local QEFFBaseModel._layerwise_active flag toggled by an
internal context manager. Supported architectures are allowlisted
(qwen3_vl_moe, qwen3_5_moe, qwen3_moe); other model types raise
NotImplementedError when layerwise=True. Wired on
QEFFAutoModelForImageTextToText (dual-QPC) and QEFFAutoModelForCausalLM.
Five existing layerwise example scripts collapse from 200-330 lines to ~60
lines each. The encapsulation module is documented as provisional and emits
a one-shot DeprecationWarning. test_model_quickcheck.py: 121 -> 127 passed,
3 skipped (unchanged) with five new tests covering the windowing helpers,
the supported/unsupported guard, the env-var-not-leaked invariant, and the
context manager's class-flag toggle.

Signed-off-by: vbaddi --- QEfficient/base/modeling_qeff.py | 3 +- QEfficient/transformers/models/_layerwise.py | 396 ++++++++++++++++++ .../transformers/models/modeling_auto.py | 213 +++++++++- examples/disagg_serving/qwen3moe_layerwise.py | 296 +------------ .../qwen3_5_moe/qwen3_5_moe_layerwise.py | 315 ++------------ .../qwen3_5_moe_layerwise_decode.py | 312 ++------------ .../qwen3_vl_moe/qwen3_vl_moe_layerwise.py | 322 ++------------ .../qwen3_vl_moe_layerwise_decode.py | 322 ++------------ .../unit_test/models/test_model_quickcheck.py | 115 +++++ 9 files changed, 841 insertions(+), 1453 deletions(-) create mode 100644 QEfficient/transformers/models/_layerwise.py diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 3bb05c7b9..455b54593 100755 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -65,6 +65,7 @@ class QEFFBaseModel(ABC): _start = 0 _end = 0 _total_layers = None + _layerwise_active = False _pytorch_transforms: List[PytorchTransform] _onnx_transforms = [BaseOnnxTransform] @@ -757,7 +758,7 @@ def _compile( **compiler_options, ) onnx_path = Path(onnx_path) - if os.environ.get("LAYERWISE_EXPORT", "False") == "True": + if QEFFBaseModel._layerwise_active: return onnx_path compile_dir = Path(compile_dir or onnx_path.parent) diff --git a/QEfficient/transformers/models/_layerwise.py b/QEfficient/transformers/models/_layerwise.py new file mode 100644 index 000000000..075a8865e --- /dev/null +++ b/QEfficient/transformers/models/_layerwise.py @@ -0,0 +1,396 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Layer-wise export/compile orchestration for selected MoE architectures. + +Provisional API. Scheduled for removal once first-class multi-window export lands. 
+Today only ``qwen3_vl_moe``, ``qwen3_5_moe`` and ``qwen3_moe`` carry the +windowing hooks (``_start``/``_end`` class attributes) that this driver relies +on; calling :func:`run_layerwise` for any other architecture will raise. +""" + +from __future__ import annotations + +import functools +import warnings +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import transformers + +import QEfficient + +# Architectures whose modeling files declare _start/_end class attributes the +# layer-wise driver pokes. Keep this list narrow on purpose - adding a new +# architecture must come with the corresponding modeling-file hooks. +_LAYERWISE_SUPPORTED_MODEL_TYPES = frozenset( + { + "qwen3_vl_moe", + "qwen3_vl_moe_text", + "qwen3_5_moe", + "qwen3_moe", + } +) + +_DEPRECATION_WARNED = False + + +def _maybe_warn_deprecation() -> None: + global _DEPRECATION_WARNED + if _DEPRECATION_WARNED: + return + warnings.warn( + "layerwise=True is a provisional API and will be deprecated once " + "first-class multi-window export lands. Use only for the supported " + f"model types: {sorted(_LAYERWISE_SUPPORTED_MODEL_TYPES)}.", + DeprecationWarning, + stacklevel=3, + ) + _DEPRECATION_WARNED = True + + +def _resolve_model_type(config) -> str: + text_config = getattr(config, "text_config", None) + if text_config is not None and getattr(text_config, "model_type", None): + return text_config.model_type + return getattr(config, "model_type", "") or "" + + +def assert_layerwise_supported(config) -> str: + """Raise a clear error if the model architecture has no layerwise hooks.""" + top_type = getattr(config, "model_type", "") or "" + text_type = _resolve_model_type(config) + if top_type in _LAYERWISE_SUPPORTED_MODEL_TYPES or text_type in _LAYERWISE_SUPPORTED_MODEL_TYPES: + return text_type or top_type + raise NotImplementedError( + "layerwise=True is only supported for model types: " + f"{sorted(_LAYERWISE_SUPPORTED_MODEL_TYPES)}. 
Got model_type=" + f"'{top_type}' (text_config.model_type='{text_type}'). " + "Run with layerwise=False (the default) for this model." + ) + + +# --------------------------------------------------------------------------- +# Internal helpers (lifted from the legacy example script) +# --------------------------------------------------------------------------- + + +def _ensure_pretrained_window_attrs() -> None: + pt = transformers.modeling_utils.PreTrainedModel + for attr in ("_start", "_end", "_total_layers", "_text_start", "_text_end", "_text_total_layers"): + if not hasattr(pt, attr): + setattr(pt, attr, 0) + + +def _build_layer_windows(total_layers: int, window_size: int) -> List[Tuple[int, int]]: + if total_layers <= 0: + raise ValueError(f"Invalid total_layers={total_layers}; expected > 0.") + if window_size <= 0: + raise ValueError(f"Invalid window_size={window_size}; expected > 0.") + windows: List[Tuple[int, int]] = [] + start = 0 + while start < total_layers: + end = min(total_layers, start + window_size) + windows.append((start, end)) + start = end + return windows + + +def _get_text_layers_container(model): + if ( + hasattr(model, "model") + and hasattr(model.model, "language_model") + and hasattr(model.model.language_model, "layers") + ): + return model.model.language_model.layers + if hasattr(model, "model") and hasattr(model.model, "layers"): + return model.model.layers + if hasattr(model, "language_model") and hasattr(model.language_model, "layers"): + return model.language_model.layers + if hasattr(model, "layers"): + return model.layers + return None + + +def _null_outside_window_layers(model, *, apply_text: bool = True) -> None: + if not apply_text: + return + pt = transformers.modeling_utils.PreTrainedModel + text_start = int(getattr(pt, "_text_start", getattr(pt, "_start", 0))) + text_end = int(getattr(pt, "_text_end", getattr(pt, "_end", 0))) + text_layers = _get_text_layers_container(model) + if text_layers is not None and text_end > 
text_start: + for idx, _ in enumerate(text_layers): + if idx < text_start or idx >= text_end: + text_layers[idx] = None + + +def _install_window_patch(model_cls) -> None: + if getattr(model_cls, "_window_patch_installed", False): + return + original_init = model_cls.__init__ + + @functools.wraps(original_init) + def patched_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + _null_outside_window_layers(self, apply_text=True) + + model_cls.__init__ = patched_init + model_cls._window_patch_installed = True + + +def _install_shard_window_patch() -> None: + if getattr(transformers.modeling_utils, "_window_shard_patch_installed", False): + return + + original_get_checkpoint_shard_files = transformers.modeling_utils.get_checkpoint_shard_files + + @functools.wraps(original_get_checkpoint_shard_files) + def patched_get_checkpoint_shard_files(*args, **kwargs): + shard_files, metadata = original_get_checkpoint_shard_files(*args, **kwargs) + weight_map = metadata.get("weight_map") + if not weight_map: + return shard_files, metadata + + pt = transformers.modeling_utils.PreTrainedModel + start = int(getattr(pt, "_start", 0)) + end = int(getattr(pt, "_end", 0)) + text_start = int(getattr(pt, "_text_start", start)) + text_end = int(getattr(pt, "_text_end", end)) + if text_end <= text_start: + return shard_files, metadata + + selected_text_prefixes = tuple( + [f"model.layers.{i}." for i in range(text_start, text_end)] + + [f"model.language_model.layers.{i}." 
for i in range(text_start, text_end)] + ) + filtered_weight_map: Dict[str, str] = {} + for checkpoint_key, shard_name in weight_map.items(): + if checkpoint_key.startswith("model.layers.") or checkpoint_key.startswith("model.language_model.layers."): + if checkpoint_key.startswith(selected_text_prefixes): + filtered_weight_map[checkpoint_key] = shard_name + continue + filtered_weight_map[checkpoint_key] = shard_name + + if not filtered_weight_map: + return shard_files, metadata + + shard_name_to_path = {path.split("/")[-1]: path for path in shard_files} + filtered_shard_names = sorted(set(filtered_weight_map.values())) + filtered_shard_files = [shard_name_to_path[name] for name in filtered_shard_names if name in shard_name_to_path] + if not filtered_shard_files: + return shard_files, metadata + + metadata["weight_map"] = filtered_weight_map + metadata["all_checkpoint_keys"] = list(filtered_weight_map.keys()) + return filtered_shard_files, metadata + + transformers.modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files + transformers.modeling_utils._window_shard_patch_installed = True + + +def _set_layer_windows(text_start: int, text_end: int, text_total_layers: int) -> None: + pt = transformers.modeling_utils.PreTrainedModel + pt._start = text_start + pt._end = text_end + pt._total_layers = text_total_layers + pt._text_start = text_start + pt._text_end = text_end + pt._text_total_layers = text_total_layers + + qeff_vl_mod = getattr(QEfficient.transformers.models, "qwen3_vl_moe", None) + if qeff_vl_mod is not None: + cls = getattr(qeff_vl_mod.modeling_qwen3_vl_moe, "QEffQwen3VLMoeTextModel", None) + if cls is not None: + cls._start = text_start + cls._end = text_end + cls._total_layers = text_total_layers + + qeff_35_mod = getattr(QEfficient.transformers.models, "qwen3_5_moe", None) + if qeff_35_mod is not None: + cls = getattr(qeff_35_mod.modeling_qwen3_5_moe, "QEffQwen3_5MoeTextModel", None) + if cls is not None: + cls._start = text_start + 
cls._end = text_end + cls._total_layers = text_total_layers + + qeff_3_mod = getattr(QEfficient.transformers.models, "qwen3_moe", None) + if qeff_3_mod is not None: + cls = getattr(qeff_3_mod.modeling_qwen3_moe, "QEffQwen3MoeModel", None) + if cls is not None: + cls._start = text_start + cls._end = text_end + cls._total_layers = text_total_layers + + QEfficient.base.modeling_qeff.QEFFBaseModel._start = text_start + QEfficient.base.modeling_qeff.QEFFBaseModel._end = text_end + QEfficient.base.modeling_qeff.QEFFBaseModel._total_layers = text_total_layers + + +def _reset_layer_windows() -> None: + _set_layer_windows(0, 0, 0) + + +def _resolve_export_root(onnx_path: Path) -> Path: + parts = list(onnx_path.parts) + if "onnx_layerwise_tmp" in parts: + return Path(*parts[: parts.index("onnx_layerwise_tmp")]) + return onnx_path.parent + + +def _stitch_layerwise_if_available(export_root: Path) -> str: + pipeline_fn = getattr(QEfficient.utils, "layerwise_pipeline", None) + if callable(pipeline_fn): + return pipeline_fn(str(export_root)) + return str(export_root / "onnx_layerwise_tmp") + + +def _install_window_patches_for(model_type: str) -> None: + """Install the HF __init__/shard patches needed for the given model_type.""" + _install_shard_window_patch() + if "qwen3_vl_moe" in model_type: + hf_mod = transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe + _install_window_patch(hf_mod.Qwen3VLMoeForConditionalGeneration) + if hasattr(hf_mod, "Qwen3VLMoeForCausalLM"): + _install_window_patch(hf_mod.Qwen3VLMoeForCausalLM) + + +@contextmanager +def _layerwise_export_env(): + """Toggle the layerwise-active flag on QEFFBaseModel for the inner block. + + No environment variables are touched - the flag is a pure in-process class + attribute, which keeps the API self-contained and lets concurrent Python + interpreters (e.g. test workers) operate independently. 
+ """ + base = QEfficient.base.modeling_qeff.QEFFBaseModel + prev = getattr(base, "_layerwise_active", False) + base._layerwise_active = True + try: + yield + finally: + base._layerwise_active = prev + + +def _resolve_text_total_layers(config) -> int: + text_config = getattr(config, "text_config", config) + total = getattr(text_config, "num_hidden_layers", None) + if total is None: + raise ValueError("Could not resolve `num_hidden_layers` from config.text_config / config.") + return int(total) + + +# --------------------------------------------------------------------------- +# Public driver +# --------------------------------------------------------------------------- + + +def run_layerwise( + *, + model_id: str, + config, + qeff_factory, + compile_kwargs: Dict[str, Any], + window_size: int = 1, + final_compile: bool = True, +) -> Any: + """Drive the per-window export loop and (optionally) the final stitched compile. + + Parameters + ---------- + model_id : str + HF id / path passed to ``qeff_factory(model_id, config)`` to (re)build a + QEff model fresh per window. + config : transformers.PretrainedConfig + Already-mutated config (the caller is responsible for any vision tweaks + like ``deepstack_visual_indexes``). + qeff_factory : Callable[[str, PretrainedConfig], QEffModel] + Factory invoked once per window to materialize a fresh QEff model that + only loads the active window's weights. + compile_kwargs : dict + Forwarded verbatim to ``qeff_model.compile(...)`` per window. The driver + injects ``skip_lang`` per-window and ``lang_onnx_path`` for the final + stitched compile. + window_size : int + Number of text-decoder layers per window. ``1`` matches the legacy + example. + final_compile : bool + When True (compile path), do the final QPC compile on the merged ONNX + and return ``qpc_paths``. When False (export path), return the merged + ONNX path only. 
+ + Returns + ------- + Either the qpc_paths dict (final_compile=True) or the merged ONNX path + (final_compile=False). + """ + _maybe_warn_deprecation() + model_type = assert_layerwise_supported(config) + + text_total_layers = _resolve_text_total_layers(config) + text_cfg = getattr(config, "text_config", config) + text_cfg.num_hidden_layers = text_total_layers + + _ensure_pretrained_window_attrs() + _install_window_patches_for(model_type) + + windows = _build_layer_windows(text_total_layers, window_size) + first_onnx_path: Optional[Path] = None + last_qeff_model = None + + with _layerwise_export_env(): + for window_idx, (text_start, text_end) in enumerate(windows): + _set_layer_windows(text_start, text_end, text_total_layers) + + qeff_model = qeff_factory(model_id, config) + last_qeff_model = qeff_model + if hasattr(qeff_model, "model"): + _null_outside_window_layers(qeff_model.model, apply_text=True) + + window_kwargs = dict(compile_kwargs) + # skip_lang is a VLM-only kwarg; only inject when present in caller's kwargs. 
+ if "skip_lang" in window_kwargs: + window_kwargs["skip_lang"] = False + onnx_path = qeff_model.compile(**window_kwargs) + if first_onnx_path is None: + if isinstance(onnx_path, dict): + lang_key = next( + ( + k + for k in ( + "lang_decode_qpc_path", + "lang_prefill_qpc_path", + "lang_qpc_path", + ) + if k in onnx_path + ), + None, + ) + if lang_key is None: + raise RuntimeError(f"Layer-wise window produced no lang_*_qpc_path: keys={list(onnx_path)}") + lang_path = onnx_path[lang_key] + else: + lang_path = onnx_path + first_onnx_path = Path(str(lang_path)) + + if first_onnx_path is None: + raise RuntimeError("Layer-wise export produced no ONNX shards.") + + export_root = _resolve_export_root(first_onnx_path) + final_artifact = _stitch_layerwise_if_available(export_root) + + _reset_layer_windows() + + if not final_compile: + return final_artifact + + final_kwargs = dict(compile_kwargs) + final_kwargs["lang_onnx_path"] = final_artifact + final_kwargs.setdefault("skip_lang", False) + return last_qeff_model.compile(**final_kwargs) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 65b89d274..108b3e5fd 100755 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1156,7 +1156,7 @@ def export( self.hash_params["prefill_only"] = False self.__update_prefill_transform(False, retain_full_kv=kwargs.get("retain_full_kv", False)) - if os.environ.get("LAYERWISE_EXPORT", "False") == "True": + if QEfficient.base.modeling_qeff.QEFFBaseModel._layerwise_active: return self._export_layerwise( inputs, output_names=output_names, @@ -1280,6 +1280,7 @@ def __init__( ) self.model = model self.config = model.config + self._pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) self.lang_model = QEffCausalLMForTextImageToTextModel(model, qaic_config=qaic_config, 
**kwargs) @@ -1357,6 +1358,8 @@ def export( prefill_seq_len: Optional[int] = None, prefill_only: bool = False, enable_chunking: bool = False, + layerwise: bool = False, + layerwise_window_size: int = 1, **kwargs, ) -> str: """ @@ -1379,6 +1382,18 @@ def export( List[str] A list containing the paths to the generated ONNX graph files for both components. """ + if layerwise: + return self._run_layerwise_export( + export_dir=export_dir, + use_onnx_subfunctions=use_onnx_subfunctions, + skip_vision=skip_vision, + skip_lang=skip_lang, + prefill_seq_len=prefill_seq_len, + prefill_only=prefill_only, + enable_chunking=enable_chunking, + layerwise_window_size=layerwise_window_size, + **kwargs, + ) dummy_inputs_kwargs = {} if prefill_seq_len is not None: dummy_inputs_kwargs["prefill_seq_len"] = int(prefill_seq_len) @@ -1416,7 +1431,7 @@ def export( qaic_config=self.lang_model.model.qaic_config, ) - layerwise_export = os.environ.get("LAYERWISE_EXPORT", "False") == "True" + layerwise_export = QEFFBaseModel._layerwise_active should_export = not skip_vision and ( not layerwise_export @@ -1482,6 +1497,96 @@ def transform( **compiler_options, ) + def _layerwise_factory_kwargs(self): + """Reproduce the from_pretrained kwargs needed to rebuild this wrapper per window.""" + # Mirror the dual-QPC from_pretrained surface; the layerwise driver passes + # config explicitly per call, so we only carry torch_dtype + attn here. 
+ torch_dtype = getattr(self.config, "torch_dtype", None) + return { + "attn_implementation": "eager", + "kv_offload": True, + "torch_dtype": torch_dtype, + } + + def _build_layerwise_factory(self): + """Return a callable suitable for layerwise driver's qeff_factory hook.""" + from QEfficient import QEFFAutoModelForImageTextToText + + base_kwargs = self._layerwise_factory_kwargs() + + def _factory(model_id, config): + return QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + config=config, + **base_kwargs, + ) + + return _factory + + def _run_layerwise_export( + self, + *, + export_dir, + use_onnx_subfunctions, + skip_vision, + skip_lang, + prefill_seq_len, + prefill_only, + enable_chunking, + layerwise_window_size, + **kwargs, + ): + from QEfficient.transformers.models import _layerwise + + model_id = self._pretrained_model_name_or_path + if model_id is None: + raise RuntimeError( + "layerwise=True requires the QEff model to be built via " + "QEFFAutoModelForImageTextToText.from_pretrained(...). " + "Direct __init__ does not preserve the model id needed for per-window reload." + ) + compile_kwargs = dict( + export_dir=export_dir, + use_onnx_subfunctions=use_onnx_subfunctions, + skip_vision=skip_vision, + prefill_seq_len=prefill_seq_len, + prefill_only=prefill_only, + enable_chunking=enable_chunking, + **kwargs, + ) + return _layerwise.run_layerwise( + model_id=model_id, + config=self.config, + qeff_factory=self._build_layerwise_factory(), + compile_kwargs=compile_kwargs, + window_size=layerwise_window_size, + final_compile=False, + ) + + def _run_layerwise_compile( + self, + *, + layerwise_window_size, + **compile_kwargs, + ): + from QEfficient.transformers.models import _layerwise + + model_id = self._pretrained_model_name_or_path + if model_id is None: + raise RuntimeError( + "layerwise=True requires the QEff model to be built via " + "QEFFAutoModelForImageTextToText.from_pretrained(...). 
" + "Direct __init__ does not preserve the model id needed for per-window reload." + ) + return _layerwise.run_layerwise( + model_id=model_id, + config=self.config, + qeff_factory=self._build_layerwise_factory(), + compile_kwargs=compile_kwargs, + window_size=layerwise_window_size, + final_compile=True, + ) + def compile( self, img_size: Optional[int] = None, @@ -1506,6 +1611,8 @@ def compile( prefill_only=None, enable_chunking=False, qaic_config: Optional[dict] = None, + layerwise: bool = False, + layerwise_window_size: int = 1, **compiler_options, ) -> str: """ @@ -1565,6 +1672,33 @@ def compile( if skip_lang and skip_vision: raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False") + if layerwise: + return self._run_layerwise_compile( + img_size=img_size, + vision_onnx_path=vision_onnx_path, + lang_onnx_path=lang_onnx_path, + compile_dir=compile_dir, + prefill_seq_len=prefill_seq_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + batch_size=batch_size, + full_batch_size=full_batch_size, + kv_cache_batch_size=kv_cache_batch_size, + num_devices=num_devices, + num_cores=num_cores, + mxfp6_matmul=mxfp6_matmul, + mxint8_kv_cache=mxint8_kv_cache, + skip_vision=skip_vision, + skip_lang=skip_lang, + use_onnx_subfunctions=use_onnx_subfunctions, + prefill_only=prefill_only, + enable_chunking=enable_chunking, + qaic_config=qaic_config, + layerwise_window_size=layerwise_window_size, + **compiler_options, + ) + if self.continuous_batching and full_batch_size is None: raise TypeError("`full_batch_size` is required when `continuous_batching=True`.") @@ -3234,6 +3368,37 @@ def get_seq_len_and_handle_specialized_prefill_model( else constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN ) + def _run_layerwise(self, *, final_compile: bool, layerwise_window_size: int, **forward_kwargs): + """Drive the layer-wise export/compile loop for CausalLM models.""" + from QEfficient.transformers.models 
import _layerwise + + model_id = getattr(self.model, "pretrained_path", None) + if model_id is None: + raise RuntimeError( + "layerwise=True requires the QEff model to be built via " + "QEFFAutoModelForCausalLM.from_pretrained(...). " + "Direct __init__ does not preserve the model id needed for per-window reload." + ) + config = getattr(self.model, "config", None) + torch_dtype = getattr(config, "torch_dtype", None) + + def _factory(model_id, config): + return QEFFAutoModelForCausalLM.from_pretrained( + model_id, + config=config, + torch_dtype=torch_dtype, + continuous_batching=self.continuous_batching, + ) + + return _layerwise.run_layerwise( + model_id=model_id, + config=config, + qeff_factory=_factory, + compile_kwargs=forward_kwargs, + window_size=layerwise_window_size, + final_compile=final_compile, + ) + def export( self, export_dir: Optional[str] = None, @@ -3241,6 +3406,8 @@ def export( prefill_seq_len: Optional[int] = None, num_cores: int = constants.DEFAULT_AIC_NUM_CORES, moe_prefill_packed_chunk_size: int = constants.MOE_PREFILL_PACKED_CHUNK_SIZE, + layerwise: bool = False, + layerwise_window_size: int = 1, **kwargs, ) -> str: """ @@ -3268,6 +3435,18 @@ def export( "Use the default non-prefill export path for standard CausalLM decode graphs." 
) + if layerwise: + return self._run_layerwise( + final_compile=False, + layerwise_window_size=layerwise_window_size, + export_dir=export_dir, + prefill_only=prefill_only, + prefill_seq_len=prefill_seq_len, + num_cores=num_cores, + moe_prefill_packed_chunk_size=moe_prefill_packed_chunk_size, + **kwargs, + ) + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN @@ -3506,7 +3685,7 @@ def _legacyify_cache(obj): self.model.forward = _qeff_patched_forward self.model._qeff_export_gemma3_cache_patch = True - if os.environ.get("LAYERWISE_EXPORT", "False") == "True": + if QEFFBaseModel._layerwise_active: return self._export_layerwise( example_inputs, output_names=output_names, @@ -3678,6 +3857,8 @@ def compile( enable_chunking: Optional[bool] = False, moe_prefill_packed_chunk_size: int = constants.MOE_PREFILL_PACKED_CHUNK_SIZE, retain_full_kv: Optional[bool] = None, + layerwise: bool = False, + layerwise_window_size: int = 1, **compiler_options, ) -> str: """ @@ -3766,6 +3947,32 @@ def compile( If `prefill_seq_len` is less than `num_speculative_tokens + 1` for TLM models. 
""" + if layerwise: + return self._run_layerwise( + final_compile=True, + layerwise_window_size=layerwise_window_size, + onnx_path=onnx_path, + compile_dir=compile_dir, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + batch_size=batch_size, + full_batch_size=full_batch_size, + kv_cache_batch_size=kv_cache_batch_size, + num_devices=num_devices, + num_cores=num_cores, + mxfp6_matmul=mxfp6_matmul, + mxint8_kv_cache=mxint8_kv_cache, + num_speculative_tokens=num_speculative_tokens, + prefill_only=prefill_only, + use_onnx_subfunctions=use_onnx_subfunctions, + offload_pt_weights=offload_pt_weights, + enable_chunking=enable_chunking, + moe_prefill_packed_chunk_size=moe_prefill_packed_chunk_size, + retain_full_kv=retain_full_kv, + **compiler_options, + ) if self.model.qaic_config is not None and self.model.qaic_config.get("mla_absorption", None) is not None: mla_absorption = self.model.qaic_config["mla_absorption"] cache_compressed = mla_absorption.get("cache_compressed", False) diff --git a/examples/disagg_serving/qwen3moe_layerwise.py b/examples/disagg_serving/qwen3moe_layerwise.py index ea29e1174..762df044e 100644 --- a/examples/disagg_serving/qwen3moe_layerwise.py +++ b/examples/disagg_serving/qwen3moe_layerwise.py @@ -5,210 +5,31 @@ # # ----------------------------------------------------------------------------- -import functools -import os -import time -from pathlib import Path +"""Layerwise prefill compile example for Qwen3-MoE (disaggregated serving). + +The orchestration loop that previously lived in this script has been moved +behind the ``layerwise=True`` flag on ``.compile()`` / ``.export()``. + +Note: ``layerwise=True`` is a provisional API and is scheduled for deprecation +once first-class multi-window export lands. Supported model types: +``qwen3_vl_moe``, ``qwen3_5_moe``, ``qwen3_moe``. 
+""" -import transformers from transformers import AutoConfig, AutoTokenizer -import QEfficient from QEfficient import QEFFAutoModelForCausalLM model_id = "Qwen/Qwen3-235B-A22B-Instruct-2507" # weights are not required to convert to fp32 -# model_id = "yujiepan/qwen3-moe-tiny-random" -prompt = """ -Explain quantum computing in simple terms. -""" -config = AutoConfig.from_pretrained(model_id) -tokenizer = AutoTokenizer.from_pretrained(model_id) -config = AutoConfig.from_pretrained(model_id) - -tokenizer = AutoTokenizer.from_pretrained(model_id) PREFILL_SEQ_LEN = 4 CTX_LEN = 128 +config = AutoConfig.from_pretrained(model_id) +tokenizer = AutoTokenizer.from_pretrained(model_id) -def _ensure_pretrained_window_attrs(): - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_start"): - transformers.modeling_utils.PreTrainedModel._start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_end"): - transformers.modeling_utils.PreTrainedModel._end = 0 - - -def _build_layer_windows(total_layers: int, window_size: int): - if total_layers <= 0: - raise ValueError(f"Invalid total_layers={total_layers}. Expected: total_layers > 0.") - if window_size <= 0: - raise ValueError(f"Invalid window_size={window_size}. 
Expected: window_size > 0.") - - windows = [] - end = total_layers - while end > 0: - start = max(0, end - window_size) - windows.append((start, end)) - end = start - - return windows - - -def _null_outside_window_layers(model): - start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0)) - end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0)) - layers = getattr(getattr(model, "model", None), "layers", None) - if layers is None: - return - for idx, _ in enumerate(layers): - if idx < start or idx >= end: - layers[idx] = None - - -def _install_window_patch(model_cls): - if getattr(model_cls, "_window_patch_installed", False): - return - - original_init = model_cls.__init__ - - @functools.wraps(original_init) - def patched_init(self, *args, **kwargs): - original_init(self, *args, **kwargs) - _null_outside_window_layers(self) - - model_cls.__init__ = patched_init - model_cls._window_patch_installed = True - - -def _resolve_export_root(onnx_path: Path) -> Path: - parts = list(onnx_path.parts) - if "onnx_layerwise_tmp" in parts: - marker_idx = parts.index("onnx_layerwise_tmp") - return Path(*parts[:marker_idx]) - return onnx_path.parent - - -def _install_shard_window_patch(): - if getattr(transformers.modeling_utils, "_window_shard_patch_installed", False): - return - - original_get_checkpoint_shard_files = transformers.modeling_utils.get_checkpoint_shard_files - - @functools.wraps(original_get_checkpoint_shard_files) - def patched_get_checkpoint_shard_files(*args, **kwargs): - shard_files, metadata = original_get_checkpoint_shard_files(*args, **kwargs) - weight_map = metadata.get("weight_map") - if not weight_map: - return shard_files, metadata - - start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0)) - end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0)) - if end <= start: - return shard_files, metadata - - selected_prefixes = tuple(f"model.layers.{layer_idx}." 
for layer_idx in range(start, end)) - filtered_weight_map = {} - for checkpoint_key, shard_name in weight_map.items(): - if checkpoint_key.startswith("model.layers."): - if checkpoint_key.startswith(selected_prefixes): - filtered_weight_map[checkpoint_key] = shard_name - continue - filtered_weight_map[checkpoint_key] = shard_name - - if not filtered_weight_map: - return shard_files, metadata - - shard_name_to_path = {path.split("/")[-1]: path for path in shard_files} - filtered_shard_names = sorted(set(filtered_weight_map.values())) - filtered_shard_files = [shard_name_to_path[name] for name in filtered_shard_names if name in shard_name_to_path] - if not filtered_shard_files: - return shard_files, metadata - - metadata["weight_map"] = filtered_weight_map - metadata["all_checkpoint_keys"] = list(filtered_weight_map.keys()) - return filtered_shard_files, metadata - - transformers.modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files - transformers.modeling_utils._window_shard_patch_installed = True - - -_ensure_pretrained_window_attrs() -_install_shard_window_patch() -text_config = getattr(config, "text_config", config) -resolved_total_layers = getattr(text_config, "num_hidden_layers", None) -if resolved_total_layers is None: - raise ValueError("Could not resolve `num_hidden_layers` from config.") - -# Layerwise window size. `1` keeps only one decoder layer active per window. 
-window_size = 1 -total_layers = 2 # resolved_total_layers # config.num_hidden_layers = 1 -windows = _build_layer_windows(total_layers=total_layers, window_size=window_size) -qeff_model = None -first_onnx_path = None -export_start = time.perf_counter() - -os.environ["LAYERWISE_EXPORT"] = "True" -for start, end in windows: - transformers.modeling_utils.PreTrainedModel._start = start - transformers.modeling_utils.PreTrainedModel._end = end - transformers.modeling_utils.PreTrainedModel._total_layers = total_layers - QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe.QEffQwen3MoeModel._start = start - QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe.QEffQwen3MoeModel._end = end - QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe.QEffQwen3MoeModel._total_layers = total_layers - QEfficient.base.modeling_qeff.QEFFBaseModel._start = start - QEfficient.base.modeling_qeff.QEFFBaseModel._end = end - QEfficient.base.modeling_qeff.QEFFBaseModel._total_layers = total_layers - _install_window_patch(transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeForCausalLM) - qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, config=config) - if hasattr(qeff_model, "model"): - _null_outside_window_layers(qeff_model.model) - - # Following command errors out by default, the user is supposed to run the printed command and provide the generated qpc path as prefill_qpc_path commenting out lines 55-68 - - # prefill_qpc_path = "" - ################################# prefill - - onnx_path = qeff_model.compile( - prefill_seq_len=PREFILL_SEQ_LEN, - ctx_len=CTX_LEN, - num_cores=16, - mxfp6_matmul=True, - mxint8_kv_cache=True, - num_devices=1, - split_retained_state_io=True, - mos=1, - aic_enable_depth_first=True, - num_speculative_tokens=None, - prefill_only=True, - enable_chunking=True, - use_onnx_subfunctions=True, - ) - ################################# decode - # onnx_path = qeff_model.compile( - # prefill_seq_len=PREFILL_SEQ_LEN, - # ctx_len=CTX_LEN, 
- # num_cores=16, - # mxfp6_matmul=True, - # mxint8_kv_cache=True, - # num_devices=1, - # split_retained_state_io=True, - # mos=1, - # aic_enable_depth_first=True, - # num_speculative_tokens=None, - # prefill_only=False, - # use_onnx_subfunctions=True, - # ) - if first_onnx_path is None: - first_onnx_path = Path(onnx_path) +qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, config=config) -if first_onnx_path is None: - raise RuntimeError("No ONNX path produced during compilation.") -export_root = _resolve_export_root(first_onnx_path) -final_onnx_path = QEfficient.utils.layerwise_pipeline(str(export_root)) -print(f"Layer-wise language export completed. Final artifact/root: {final_onnx_path}") -os.environ["LAYERWISE_EXPORT"] = "False" qpc_path = qeff_model.compile( - onnx_path=final_onnx_path, prefill_seq_len=PREFILL_SEQ_LEN, ctx_len=CTX_LEN, num_cores=16, @@ -222,97 +43,8 @@ def patched_get_checkpoint_shard_files(*args, **kwargs): prefill_only=True, enable_chunking=True, use_onnx_subfunctions=True, + layerwise=True, + layerwise_window_size=1, ) print(f"QPC path: {qpc_path}") - -# inputs = tokenizer(prompt, return_tensors="np", padding=True) -# position_ids = inputs["attention_mask"].sum(1, keepdims=True) -# generation_len = CTX_LEN - position_ids.max() -# padded_len = inputs["input_ids"].shape[1] -# num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float -# padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len -# inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) -# inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) -# inputs.pop("token_type_ids", None) -# inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} -# inputs.pop("past_key_values", None) -# inputs = {k: v.detach().numpy() for k, v in inputs.items()} - - -# prefill_session = QAICInferenceSession(prefill_qpc_path) - - -# all_outputs = [] -# for i in 
range(num_chunks): -# chunk_inputs = inputs.copy() -# chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] -# chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] -# ins = time.time() -# qpc_out = prefill_session.run(chunk_inputs) -# print(f"time for this run={time.time() - ins}") -# for i in range(config.num_hidden_layers): -# inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] -# inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] - -# all_outputs.append(np.argmax(qpc_out["logits"])) -# print(all_outputs) -# print(">>>>>>>> export for prefill is done <<<<<<<<<<<") -# ########################### - -# decode_qpc_path = qeff_model.compile( -# prefill_seq_len=1, -# ctx_len=CTX_LEN, -# num_cores=16, -# mxfp6_matmul=True, -# mxint8_kv_cache=True, -# num_devices=1, -# mos=1, -# aic_enable_depth_first=True, -# num_speculative_tokens=None, -# offload_pt_weights=False, # Need the weights in memory for prefill-model export/compilation in the next step -# retain_full_kv=True, -# ) -# decode_session = QAICInferenceSession(decode_qpc_path) - -# decode_inputs = { -# "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1), -# "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, -# } -# for i in range(config.num_hidden_layers): -# decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] -# decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] - -# st = time.time() -# decode_out = decode_session.run(decode_inputs) -# print(f"time for first run of decode with KV as input = {time.time() - st} sec\n") -# all_outputs.append(np.argmax(decode_out["logits"])) -# pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1 -# loop_decode_inputs = { -# "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), -# "position_ids": pos_id, -# } - -# for i in range(config.num_hidden_layers): -# 
loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] -# loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] - -# st = time.time() -# for i in range(generation_len - 2): -# decode_out = decode_session.run(loop_decode_inputs) -# all_outputs.append(np.argmax(decode_out["logits"])) -# pos_id += 1 -# for i in range(config.num_hidden_layers): -# loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] -# loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] - -# loop_decode_inputs.update( -# { -# "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), -# "position_ids": pos_id, -# } -# ) -# ft = time.time() - -# print(f"decode tok/sec={(generation_len - 2) / (ft - st)}") -# print(f"input\n{prompt}\noutput\n{tokenizer.decode(all_outputs)}") diff --git a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise.py b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise.py index cdf25ee36..e00209c57 100644 --- a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise.py +++ b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise.py @@ -5,319 +5,56 @@ # # ----------------------------------------------------------------------------- -import functools -import os -from pathlib import Path +"""Layerwise prefill compile example for Qwen3.5-MoE. + +The orchestration loop that previously lived in this script has been moved +behind the ``layerwise=True`` flag on ``.compile()`` / ``.export()``. + +Note: ``layerwise=True`` is a provisional API and is scheduled for deprecation +once first-class multi-window export lands. Supported model types: +``qwen3_vl_moe``, ``qwen3_5_moe``, ``qwen3_moe``. 
+""" import torch -import transformers from transformers import AutoConfig -import QEfficient from QEfficient import QEFFAutoModelForImageTextToText MODEL_ID = "Qwen/Qwen3.5-397B-A17B" -PREFILL_SEQ_LEN = 32 -CTX_LEN = 4096 -TEXT_WINDOW_SIZE = 1 - -# For quick local validation only (keep disabled for real export) -# TEST_TEXT_LAYERS = 4 - -# Export controls -BATCH_SIZE = 1 -NUM_CORES = 16 -NUM_DEVICES = 1 -HEIGHT = 354 -WIDTH = 536 - - -def _ensure_pretrained_window_attrs(): - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_start"): - transformers.modeling_utils.PreTrainedModel._start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_end"): - transformers.modeling_utils.PreTrainedModel._end = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_total_layers"): - transformers.modeling_utils.PreTrainedModel._total_layers = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_start"): - transformers.modeling_utils.PreTrainedModel._text_start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_end"): - transformers.modeling_utils.PreTrainedModel._text_end = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_total_layers"): - transformers.modeling_utils.PreTrainedModel._text_total_layers = 0 - - -def _build_layer_windows(total_layers: int, window_size: int): - if total_layers <= 0: - raise ValueError(f"Invalid total_layers={total_layers}. Expected: total_layers > 0.") - if window_size <= 0: - raise ValueError(f"Invalid window_size={window_size}. 
Expected: window_size > 0.") - - windows = [] - end = total_layers - while end > 0: - start = max(0, end - window_size) - windows.append((start, end)) - end = start - return windows - - -def _get_text_layers_container(model): - # VLM path first - if ( - hasattr(model, "model") - and hasattr(model.model, "language_model") - and hasattr(model.model.language_model, "layers") - ): - return model.model.language_model.layers - # LLM-compatible fallbacks - if hasattr(model, "model") and hasattr(model.model, "layers"): - return model.model.layers - if hasattr(model, "language_model") and hasattr(model.language_model, "layers"): - return model.language_model.layers - if hasattr(model, "layers"): - return model.layers - return None - - -def _null_outside_window_layers(model, apply_text: bool = True): - if apply_text: - text_start = int( - getattr( - transformers.modeling_utils.PreTrainedModel, - "_text_start", - getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0), - ) - ) - text_end = int( - getattr( - transformers.modeling_utils.PreTrainedModel, - "_text_end", - getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0), - ) - ) - text_layers = _get_text_layers_container(model) - if text_layers is not None and text_end > text_start: - for idx, _ in enumerate(text_layers): - if idx < text_start or idx >= text_end: - text_layers[idx] = None - - -def _install_window_patch(model_cls): - if getattr(model_cls, "_window_patch_installed", False): - return - - original_init = model_cls.__init__ - - @functools.wraps(original_init) - def patched_init(self, *args, **kwargs): - original_init(self, *args, **kwargs) - _null_outside_window_layers(self, apply_text=True) - - model_cls.__init__ = patched_init - model_cls._window_patch_installed = True - - -def _resolve_export_root(onnx_path: Path) -> Path: - parts = list(onnx_path.parts) - if "onnx_layerwise_tmp" in parts: - marker_idx = parts.index("onnx_layerwise_tmp") - return Path(*parts[:marker_idx]) - return 
onnx_path.parent - - -def _install_shard_window_patch(): - if getattr(transformers.modeling_utils, "_window_shard_patch_installed", False): - return - - original_get_checkpoint_shard_files = transformers.modeling_utils.get_checkpoint_shard_files - - @functools.wraps(original_get_checkpoint_shard_files) - def patched_get_checkpoint_shard_files(*args, **kwargs): - shard_files, metadata = original_get_checkpoint_shard_files(*args, **kwargs) - weight_map = metadata.get("weight_map") - if not weight_map: - return shard_files, metadata - - start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0)) - end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0)) - text_start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_text_start", start)) - text_end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_text_end", end)) - has_text_window = text_end > text_start - if not has_text_window: - return shard_files, metadata - - selected_text_prefixes = tuple( - [f"model.layers.{layer_idx}." for layer_idx in range(text_start, text_end)] - + [f"model.language_model.layers.{layer_idx}." 
for layer_idx in range(text_start, text_end)] - ) - filtered_weight_map = {} - for checkpoint_key, shard_name in weight_map.items(): - if checkpoint_key.startswith("model.layers.") or checkpoint_key.startswith("model.language_model.layers."): - if not has_text_window or checkpoint_key.startswith(selected_text_prefixes): - filtered_weight_map[checkpoint_key] = shard_name - continue - filtered_weight_map[checkpoint_key] = shard_name - - if not filtered_weight_map: - return shard_files, metadata - - shard_name_to_path = {path.split("/")[-1]: path for path in shard_files} - filtered_shard_names = sorted(set(filtered_weight_map.values())) - filtered_shard_files = [shard_name_to_path[name] for name in filtered_shard_names if name in shard_name_to_path] - if not filtered_shard_files: - return shard_files, metadata - - metadata["weight_map"] = filtered_weight_map - metadata["all_checkpoint_keys"] = list(filtered_weight_map.keys()) - return filtered_shard_files, metadata - transformers.modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files - transformers.modeling_utils._window_shard_patch_installed = True - - -def _set_layer_windows( - text_start: int, - text_end: int, - text_total_layers: int, -): - transformers.modeling_utils.PreTrainedModel._start = text_start - transformers.modeling_utils.PreTrainedModel._end = text_end - transformers.modeling_utils.PreTrainedModel._total_layers = text_total_layers - transformers.modeling_utils.PreTrainedModel._text_start = text_start - transformers.modeling_utils.PreTrainedModel._text_end = text_end - transformers.modeling_utils.PreTrainedModel._text_total_layers = text_total_layers - - qeff_mod = QEfficient.transformers.models.qwen3_5_moe.modeling_qwen3_5_moe - qeff_mod.QEffQwen3_5MoeTextModel._start = text_start - qeff_mod.QEffQwen3_5MoeTextModel._end = text_end - qeff_mod.QEffQwen3_5MoeTextModel._total_layers = text_total_layers - - QEfficient.base.modeling_qeff.QEFFBaseModel._start = text_start - 
QEfficient.base.modeling_qeff.QEFFBaseModel._end = text_end - QEfficient.base.modeling_qeff.QEFFBaseModel._total_layers = text_total_layers - - -def _stitch_layerwise_if_available(export_root: Path): - # Some branches expose this helper; fall back gracefully when unavailable. - pipeline_fn = getattr(QEfficient.utils, "layerwise_pipeline", None) - if callable(pipeline_fn): - return pipeline_fn(str(export_root)) - print(f"layerwise_pipeline() not found. Layer-wise ONNX shards kept under: {export_root / 'onnx_layerwise_tmp'}") - return str(export_root / "onnx_layerwise_tmp") +def main(): + config = AutoConfig.from_pretrained(MODEL_ID) + config.torch_dtype = "float32" -def _new_qeff_model(model_id: str, config): - return QEFFAutoModelForImageTextToText.from_pretrained( - model_id, + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + MODEL_ID, attn_implementation="eager", kv_offload=True, config=config, torch_dtype=torch.float32, ) - -def main(): - config = AutoConfig.from_pretrained(MODEL_ID) - config.torch_dtype = "float32" - - # if TEST_TEXT_LAYERS: - # config.text_config.num_hidden_layers = TEST_TEXT_LAYERS - - text_config = getattr(config, "text_config", config) - # config.vision_config.depth = 3 - text_total_layers = getattr(text_config, "num_hidden_layers", None) - if text_total_layers is None: - raise ValueError("Could not resolve `num_hidden_layers` from config.text_config.") - config.text_config.num_hidden_layers = text_total_layers - _ensure_pretrained_window_attrs() - _install_shard_window_patch() - - hf_qwen_mod = transformers.models.qwen3_5_moe.modeling_qwen3_5_moe - _install_window_patch(hf_qwen_mod.Qwen3_5MoeForConditionalGeneration) - _install_window_patch(hf_qwen_mod.Qwen3_5MoeForCausalLM) - - text_windows = _build_layer_windows(total_layers=text_total_layers, window_size=TEXT_WINDOW_SIZE) - # Keep layerwise only on text path in this loop. 
- num_windows = len(text_windows) - first_onnx_path = None - os.environ["LAYERWISE_EXPORT"] = "True" - for window_idx in range(num_windows): - text_start, text_end = text_windows[window_idx] if window_idx < len(text_windows) else (0, 0) - skip_lang_for_window = window_idx >= len(text_windows) - - _set_layer_windows( - text_start=text_start, - text_end=text_end, - text_total_layers=text_total_layers, - ) - print( - f"Exporting window {window_idx + 1}/{num_windows} " - f"text=[{text_start},{text_end})/{text_total_layers} " - f"skip_lang={skip_lang_for_window}" - ) - - qeff_model = _new_qeff_model(MODEL_ID, config) - if hasattr(qeff_model, "model"): - _null_outside_window_layers( - qeff_model.model, - apply_text=not skip_lang_for_window, - ) - - onnx_path = qeff_model.compile( - batch_size=BATCH_SIZE, - prefill_seq_len=PREFILL_SEQ_LEN, - ctx_len=CTX_LEN, - num_cores=NUM_CORES, - num_devices=NUM_DEVICES, - height=HEIGHT, - width=WIDTH, - mxfp6_matmul=False, - aic_enable_depth_first=True, - skip_vision=True, - skip_lang=skip_lang_for_window, - prefill_only=True, - use_onnx_subfunctions=True, - enable_chunking=True, - mos=1, - user_tiled=True, - ) - - if first_onnx_path is None: - first_onnx_path = Path(str(onnx_path["lang_prefill_qpc_path"])) - - if first_onnx_path is None: - raise RuntimeError("No ONNX path produced during layer-wise language export.") - - export_root = _resolve_export_root(first_onnx_path) - final_artifact = _stitch_layerwise_if_available(export_root) - print(f"Layer-wise language export completed. 
Final artifact/root: {final_artifact}") - - os.environ["LAYERWISE_EXPORT"] = "False" qpc_path = qeff_model.compile( - lang_onnx_path=final_artifact, - batch_size=BATCH_SIZE, - prefill_seq_len=PREFILL_SEQ_LEN, - ctx_len=CTX_LEN, - num_cores=NUM_CORES, - num_devices=NUM_DEVICES, - height=HEIGHT, - width=WIDTH, - mxfp6_matmul=False, + batch_size=1, + prefill_seq_len=32, + ctx_len=4096, + num_cores=16, + num_devices=1, + height=354, + width=536, + mxfp6_matmul=True, aic_enable_depth_first=True, skip_vision=True, - skip_lang=skip_lang_for_window, - prefill_only=True, + split_retained_state_io=True, use_onnx_subfunctions=True, - enable_chunking=True, + prefill_only=True, mos=1, + layerwise=True, + layerwise_window_size=1, ) - print(f"Final QPC path: {qpc_path}") if __name__ == "__main__": main() - - -# /opt/qti-aic/exec/qaic-compile -aic-hw -aic-hw-version=ai100 -m=/home/abhishek/.cache/qeff_models/Qwen3_5MoeForConditionalGeneration/Qwen3_5MoeDecoderWrapper-61a4400d63d1b0bb/final_data/merged_0-2.onnx -retained-state -convert-to-fp16 -aic-num-cores=16 -aic-enable-depth-first -mos=1 -network-specialization-config=/home/abhishek/.cache/qeff_models/Qwen3_5MoeForConditionalGeneration/Qwen3_5MoeDecoderWrapper-61a4400d63d1b0bb/final_data/specializations.json -custom-IO-list-file=/home/abhishek/.cache/qeff_models/Qwen3_5MoeForConditionalGeneration/Qwen3_5MoeDecoderWrapper-61a4400d63d1b0bb/final_data/qpc_binaries/custom_io.yaml -aic-binary-dir=/home/abhishek/.cache/qeff_models/Qwen3_5MoeForConditionalGeneration/Qwen3_5MoeDecoderWrapper-61a4400d63d1b0bb/final_data/qpc_binaries/qpc diff --git a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise_decode.py b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise_decode.py index a5b7475f7..e2c9d77e7 100644 --- a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise_decode.py +++ b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise_decode.py @@ -5,317 +5,55 @@ # # 
----------------------------------------------------------------------------- -import functools -import os -from pathlib import Path +"""Layerwise decode compile example for Qwen3.5-MoE. + +The orchestration loop that previously lived in this script has been moved +behind the ``layerwise=True`` flag on ``.compile()`` / ``.export()``. + +Note: ``layerwise=True`` is a provisional API and is scheduled for deprecation +once first-class multi-window export lands. Supported model types: +``qwen3_vl_moe``, ``qwen3_5_moe``, ``qwen3_moe``. +""" import torch -import transformers from transformers import AutoConfig -import QEfficient from QEfficient import QEFFAutoModelForImageTextToText MODEL_ID = "Qwen/Qwen3.5-397B-A17B" -PREFILL_SEQ_LEN = 1 -CTX_LEN = 4096 -TEXT_WINDOW_SIZE = 1 - -# For quick local validation only (keep disabled for real export) -# TEST_TEXT_LAYERS = 4 - -# Export controls -BATCH_SIZE = 1 -NUM_CORES = 16 -NUM_DEVICES = 1 -HEIGHT = 354 -WIDTH = 536 - - -def _ensure_pretrained_window_attrs(): - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_start"): - transformers.modeling_utils.PreTrainedModel._start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_end"): - transformers.modeling_utils.PreTrainedModel._end = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_total_layers"): - transformers.modeling_utils.PreTrainedModel._total_layers = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_start"): - transformers.modeling_utils.PreTrainedModel._text_start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_end"): - transformers.modeling_utils.PreTrainedModel._text_end = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_total_layers"): - transformers.modeling_utils.PreTrainedModel._text_total_layers = 0 - - -def _build_layer_windows(total_layers: int, window_size: int): - if total_layers <= 0: - raise ValueError(f"Invalid total_layers={total_layers}. 
Expected: total_layers > 0.") - if window_size <= 0: - raise ValueError(f"Invalid window_size={window_size}. Expected: window_size > 0.") - - windows = [] - end = total_layers - while end > 0: - start = max(0, end - window_size) - windows.append((start, end)) - end = start - return windows - - -def _get_text_layers_container(model): - # VLM path first - if ( - hasattr(model, "model") - and hasattr(model.model, "language_model") - and hasattr(model.model.language_model, "layers") - ): - return model.model.language_model.layers - # LLM-compatible fallbacks - if hasattr(model, "model") and hasattr(model.model, "layers"): - return model.model.layers - if hasattr(model, "language_model") and hasattr(model.language_model, "layers"): - return model.language_model.layers - if hasattr(model, "layers"): - return model.layers - return None - - -def _null_outside_window_layers(model, apply_text: bool = True): - if apply_text: - text_start = int( - getattr( - transformers.modeling_utils.PreTrainedModel, - "_text_start", - getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0), - ) - ) - text_end = int( - getattr( - transformers.modeling_utils.PreTrainedModel, - "_text_end", - getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0), - ) - ) - text_layers = _get_text_layers_container(model) - if text_layers is not None and text_end > text_start: - for idx, _ in enumerate(text_layers): - if idx < text_start or idx >= text_end: - text_layers[idx] = None - - -def _install_window_patch(model_cls): - if getattr(model_cls, "_window_patch_installed", False): - return - - original_init = model_cls.__init__ - - @functools.wraps(original_init) - def patched_init(self, *args, **kwargs): - original_init(self, *args, **kwargs) - _null_outside_window_layers(self, apply_text=True) - - model_cls.__init__ = patched_init - model_cls._window_patch_installed = True - - -def _resolve_export_root(onnx_path: Path) -> Path: - parts = list(onnx_path.parts) - if "onnx_layerwise_tmp" in 
parts: - marker_idx = parts.index("onnx_layerwise_tmp") - return Path(*parts[:marker_idx]) - return onnx_path.parent - - -def _install_shard_window_patch(): - if getattr(transformers.modeling_utils, "_window_shard_patch_installed", False): - return - - original_get_checkpoint_shard_files = transformers.modeling_utils.get_checkpoint_shard_files - - @functools.wraps(original_get_checkpoint_shard_files) - def patched_get_checkpoint_shard_files(*args, **kwargs): - shard_files, metadata = original_get_checkpoint_shard_files(*args, **kwargs) - weight_map = metadata.get("weight_map") - if not weight_map: - return shard_files, metadata - - start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0)) - end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0)) - text_start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_text_start", start)) - text_end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_text_end", end)) - has_text_window = text_end > text_start - if not has_text_window: - return shard_files, metadata - - selected_text_prefixes = tuple( - [f"model.layers.{layer_idx}." for layer_idx in range(text_start, text_end)] - + [f"model.language_model.layers.{layer_idx}." 
for layer_idx in range(text_start, text_end)] - ) - filtered_weight_map = {} - for checkpoint_key, shard_name in weight_map.items(): - if checkpoint_key.startswith("model.layers.") or checkpoint_key.startswith("model.language_model.layers."): - if not has_text_window or checkpoint_key.startswith(selected_text_prefixes): - filtered_weight_map[checkpoint_key] = shard_name - continue - filtered_weight_map[checkpoint_key] = shard_name - - if not filtered_weight_map: - return shard_files, metadata - - shard_name_to_path = {path.split("/")[-1]: path for path in shard_files} - filtered_shard_names = sorted(set(filtered_weight_map.values())) - filtered_shard_files = [shard_name_to_path[name] for name in filtered_shard_names if name in shard_name_to_path] - if not filtered_shard_files: - return shard_files, metadata - metadata["weight_map"] = filtered_weight_map - metadata["all_checkpoint_keys"] = list(filtered_weight_map.keys()) - return filtered_shard_files, metadata - - transformers.modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files - transformers.modeling_utils._window_shard_patch_installed = True - - -def _set_layer_windows( - text_start: int, - text_end: int, - text_total_layers: int, -): - transformers.modeling_utils.PreTrainedModel._start = text_start - transformers.modeling_utils.PreTrainedModel._end = text_end - transformers.modeling_utils.PreTrainedModel._total_layers = text_total_layers - transformers.modeling_utils.PreTrainedModel._text_start = text_start - transformers.modeling_utils.PreTrainedModel._text_end = text_end - transformers.modeling_utils.PreTrainedModel._text_total_layers = text_total_layers - - qeff_mod = QEfficient.transformers.models.qwen3_5_moe.modeling_qwen3_5_moe - qeff_mod.QEffQwen3_5MoeTextModel._start = text_start - qeff_mod.QEffQwen3_5MoeTextModel._end = text_end - qeff_mod.QEffQwen3_5MoeTextModel._total_layers = text_total_layers - - QEfficient.base.modeling_qeff.QEFFBaseModel._start = text_start - 
QEfficient.base.modeling_qeff.QEFFBaseModel._end = text_end - QEfficient.base.modeling_qeff.QEFFBaseModel._total_layers = text_total_layers - - -def _stitch_layerwise_if_available(export_root: Path): - # Some branches expose this helper; fall back gracefully when unavailable. - pipeline_fn = getattr(QEfficient.utils, "layerwise_pipeline", None) - if callable(pipeline_fn): - return pipeline_fn(str(export_root)) - print(f"layerwise_pipeline() not found. Layer-wise ONNX shards kept under: {export_root / 'onnx_layerwise_tmp'}") - return str(export_root / "onnx_layerwise_tmp") +def main(): + config = AutoConfig.from_pretrained(MODEL_ID) + config.torch_dtype = "float32" -def _new_qeff_model(model_id: str, config): - return QEFFAutoModelForImageTextToText.from_pretrained( - model_id, + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + MODEL_ID, attn_implementation="eager", kv_offload=True, config=config, torch_dtype=torch.float32, ) - -def main(): - config = AutoConfig.from_pretrained(MODEL_ID) - config.torch_dtype = "float32" - - # if TEST_TEXT_LAYERS: - # config.text_config.num_hidden_layers = TEST_TEXT_LAYERS - - text_config = getattr(config, "text_config", config) - # config.vision_config.depth = 3 - text_total_layers = getattr(text_config, "num_hidden_layers", None) - if text_total_layers is None: - raise ValueError("Could not resolve `num_hidden_layers` from config.text_config.") - config.text_config.num_hidden_layers = text_total_layers - _ensure_pretrained_window_attrs() - _install_shard_window_patch() - - hf_qwen_mod = transformers.models.qwen3_5_moe.modeling_qwen3_5_moe - _install_window_patch(hf_qwen_mod.Qwen3_5MoeForConditionalGeneration) - _install_window_patch(hf_qwen_mod.Qwen3_5MoeForCausalLM) - - text_windows = _build_layer_windows(total_layers=text_total_layers, window_size=TEXT_WINDOW_SIZE) - # Keep layerwise only on text path in this loop. 
- num_windows = len(text_windows) - first_onnx_path = None - os.environ["LAYERWISE_EXPORT"] = "True" - for window_idx in range(num_windows): - text_start, text_end = text_windows[window_idx] if window_idx < len(text_windows) else (0, 0) - skip_lang_for_window = window_idx >= len(text_windows) - - _set_layer_windows( - text_start=text_start, - text_end=text_end, - text_total_layers=text_total_layers, - ) - print( - f"Exporting window {window_idx + 1}/{num_windows} " - f"text=[{text_start},{text_end})/{text_total_layers} " - f"skip_lang={skip_lang_for_window}" - ) - - qeff_model = _new_qeff_model(MODEL_ID, config) - if hasattr(qeff_model, "model"): - _null_outside_window_layers( - qeff_model.model, - apply_text=not skip_lang_for_window, - ) - - onnx_path = qeff_model.compile( - batch_size=BATCH_SIZE, - prefill_seq_len=PREFILL_SEQ_LEN, - ctx_len=CTX_LEN, - num_cores=NUM_CORES, - num_devices=NUM_DEVICES, - height=HEIGHT, - width=WIDTH, - mxfp6_matmul=False, - aic_enable_depth_first=True, - skip_vision=True, - skip_lang=skip_lang_for_window, - use_onnx_subfunctions=True, - enable_chunking=True, - mos=1, - user_tiled=True, - ) - - if first_onnx_path is None: - first_onnx_path = Path(str(onnx_path["lang_decode_qpc_path"])) - - if first_onnx_path is None: - raise RuntimeError("No ONNX path produced during layer-wise language export.") - - export_root = _resolve_export_root(first_onnx_path) - final_artifact = _stitch_layerwise_if_available(export_root) - print(f"Layer-wise language export completed. 
Final artifact/root: {final_artifact}") - - os.environ["LAYERWISE_EXPORT"] = "False" qpc_path = qeff_model.compile( - lang_onnx_path=final_artifact, - batch_size=BATCH_SIZE, - prefill_seq_len=PREFILL_SEQ_LEN, - ctx_len=CTX_LEN, - num_cores=NUM_CORES, - num_devices=NUM_DEVICES, - height=HEIGHT, - width=WIDTH, - mxfp6_matmul=False, + batch_size=1, + prefill_seq_len=1, + ctx_len=4096, + num_cores=16, + num_devices=1, + height=354, + width=536, + mxfp6_matmul=True, aic_enable_depth_first=True, skip_vision=True, - skip_lang=skip_lang_for_window, + split_retained_state_io=True, use_onnx_subfunctions=True, - enable_chunking=True, mos=1, + layerwise=True, + layerwise_window_size=1, ) - print(f"Final QPC path: {qpc_path}") if __name__ == "__main__": main() - - -# /opt/qti-aic/exec/qaic-compile -aic-hw -aic-hw-version=ai100 -m=/home/abhishek/.cache/qeff_models/Qwen3_5MoeForConditionalGeneration/Qwen3_5MoeDecoderWrapper-61a4400d63d1b0bb/final_data/merged_0-2.onnx -retained-state -convert-to-fp16 -aic-num-cores=16 -aic-enable-depth-first -mos=1 -network-specialization-config=/home/abhishek/.cache/qeff_models/Qwen3_5MoeForConditionalGeneration/Qwen3_5MoeDecoderWrapper-61a4400d63d1b0bb/final_data/specializations.json -custom-IO-list-file=/home/abhishek/.cache/qeff_models/Qwen3_5MoeForConditionalGeneration/Qwen3_5MoeDecoderWrapper-61a4400d63d1b0bb/final_data/qpc_binaries/custom_io.yaml -aic-binary-dir=/home/abhishek/.cache/qeff_models/Qwen3_5MoeForConditionalGeneration/Qwen3_5MoeDecoderWrapper-61a4400d63d1b0bb/final_data/qpc_binaries/qpc diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise.py index 990369fd9..4973ec325 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise.py @@ -5,326 +5,56 @@ # # 
----------------------------------------------------------------------------- -import functools -import os -from pathlib import Path +"""Layerwise prefill compile example for Qwen3-VL-MoE. + +The orchestration loop that previously lived in this script has been moved +behind the ``layerwise=True`` flag on ``.compile()`` / ``.export()``. + +Note: ``layerwise=True`` is a provisional API and is scheduled for deprecation +once first-class multi-window export lands. Supported model types: +``qwen3_vl_moe``, ``qwen3_5_moe``, ``qwen3_moe``. +""" import torch -import transformers from transformers import AutoConfig -import QEfficient from QEfficient import QEFFAutoModelForImageTextToText -MODEL_ID = "Qwen/Qwen3-VL-235B-A22B-Instruct" -PREFILL_SEQ_LEN = 32 -CTX_LEN = 4096 -TEXT_WINDOW_SIZE = 1 - -# For quick local validation only (keep disabled for real export) -# TEST_TEXT_LAYERS = 4 - -# Export controls -BATCH_SIZE = 1 -NUM_CORES = 16 -NUM_DEVICES = 4 -HEIGHT = 354 -WIDTH = 536 - - -def _ensure_pretrained_window_attrs(): - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_start"): - transformers.modeling_utils.PreTrainedModel._start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_end"): - transformers.modeling_utils.PreTrainedModel._end = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_total_layers"): - transformers.modeling_utils.PreTrainedModel._total_layers = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_start"): - transformers.modeling_utils.PreTrainedModel._text_start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_end"): - transformers.modeling_utils.PreTrainedModel._text_end = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_total_layers"): - transformers.modeling_utils.PreTrainedModel._text_total_layers = 0 - - -def _build_layer_windows(total_layers: int, window_size: int): - if total_layers <= 0: - raise ValueError(f"Invalid 
total_layers={total_layers}. Expected: total_layers > 0.") - if window_size <= 0: - raise ValueError(f"Invalid window_size={window_size}. Expected: window_size > 0.") - - windows = [] - start = 0 - while start < total_layers: - end = min(total_layers, start + window_size) - windows.append((start, end)) - start = end - return windows - - -def _get_text_layers_container(model): - # VLM path first - if ( - hasattr(model, "model") - and hasattr(model.model, "language_model") - and hasattr(model.model.language_model, "layers") - ): - return model.model.language_model.layers - # LLM-compatible fallbacks - if hasattr(model, "model") and hasattr(model.model, "layers"): - return model.model.layers - if hasattr(model, "language_model") and hasattr(model.language_model, "layers"): - return model.language_model.layers - if hasattr(model, "layers"): - return model.layers - return None - - -def _null_outside_window_layers(model, apply_text: bool = True): - if apply_text: - text_start = int( - getattr( - transformers.modeling_utils.PreTrainedModel, - "_text_start", - getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0), - ) - ) - text_end = int( - getattr( - transformers.modeling_utils.PreTrainedModel, - "_text_end", - getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0), - ) - ) - text_layers = _get_text_layers_container(model) - if text_layers is not None and text_end > text_start: - for idx, _ in enumerate(text_layers): - if idx < text_start or idx >= text_end: - text_layers[idx] = None - - -def _install_window_patch(model_cls): - if getattr(model_cls, "_window_patch_installed", False): - return - - original_init = model_cls.__init__ - - @functools.wraps(original_init) - def patched_init(self, *args, **kwargs): - original_init(self, *args, **kwargs) - _null_outside_window_layers(self, apply_text=True) - - model_cls.__init__ = patched_init - model_cls._window_patch_installed = True - - -def _resolve_export_root(onnx_path: Path) -> Path: - parts = 
list(onnx_path.parts) - if "onnx_layerwise_tmp" in parts: - marker_idx = parts.index("onnx_layerwise_tmp") - return Path(*parts[:marker_idx]) - return onnx_path.parent - - -def _install_shard_window_patch(): - if getattr(transformers.modeling_utils, "_window_shard_patch_installed", False): - return - - original_get_checkpoint_shard_files = transformers.modeling_utils.get_checkpoint_shard_files - - @functools.wraps(original_get_checkpoint_shard_files) - def patched_get_checkpoint_shard_files(*args, **kwargs): - shard_files, metadata = original_get_checkpoint_shard_files(*args, **kwargs) - weight_map = metadata.get("weight_map") - if not weight_map: - return shard_files, metadata - - start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0)) - end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0)) - text_start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_text_start", start)) - text_end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_text_end", end)) - has_text_window = text_end > text_start - if not has_text_window: - return shard_files, metadata - - selected_text_prefixes = tuple( - [f"model.layers.{layer_idx}." for layer_idx in range(text_start, text_end)] - + [f"model.language_model.layers.{layer_idx}." 
for layer_idx in range(text_start, text_end)] - ) - filtered_weight_map = {} - for checkpoint_key, shard_name in weight_map.items(): - if checkpoint_key.startswith("model.layers.") or checkpoint_key.startswith("model.language_model.layers."): - if not has_text_window or checkpoint_key.startswith(selected_text_prefixes): - filtered_weight_map[checkpoint_key] = shard_name - continue - filtered_weight_map[checkpoint_key] = shard_name - - if not filtered_weight_map: - return shard_files, metadata - - shard_name_to_path = {path.split("/")[-1]: path for path in shard_files} - filtered_shard_names = sorted(set(filtered_weight_map.values())) - filtered_shard_files = [shard_name_to_path[name] for name in filtered_shard_names if name in shard_name_to_path] - if not filtered_shard_files: - return shard_files, metadata +# MODEL_ID = "Qwen/Qwen3-VL-235B-A22B-Instruct" +MODEL_ID = "tiny-random/qwen3-vl-moe" - metadata["weight_map"] = filtered_weight_map - metadata["all_checkpoint_keys"] = list(filtered_weight_map.keys()) - return filtered_shard_files, metadata - - transformers.modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files - transformers.modeling_utils._window_shard_patch_installed = True - - -def _set_layer_windows( - text_start: int, - text_end: int, - text_total_layers: int, -): - transformers.modeling_utils.PreTrainedModel._start = text_start - transformers.modeling_utils.PreTrainedModel._end = text_end - transformers.modeling_utils.PreTrainedModel._total_layers = text_total_layers - transformers.modeling_utils.PreTrainedModel._text_start = text_start - transformers.modeling_utils.PreTrainedModel._text_end = text_end - transformers.modeling_utils.PreTrainedModel._text_total_layers = text_total_layers - - # Qwen3-VL-MoE model code still checks QEffQwen3_5MoeTextModel window attrs - # in a few places. Set both classes to keep layer-wise behavior consistent. 
- qeff_vl_mod = QEfficient.transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe - qeff_vl_mod.QEffQwen3VLMoeTextModel._start = text_start - qeff_vl_mod.QEffQwen3VLMoeTextModel._end = text_end - qeff_vl_mod.QEffQwen3VLMoeTextModel._total_layers = text_total_layers - - qeff_35_mod = getattr(QEfficient.transformers.models, "qwen3_5_moe", None) - if qeff_35_mod is not None: - qeff_35_text_model = getattr(qeff_35_mod.modeling_qwen3_5_moe, "QEffQwen3_5MoeTextModel", None) - if qeff_35_text_model is not None: - qeff_35_text_model._start = text_start - qeff_35_text_model._end = text_end - qeff_35_text_model._total_layers = text_total_layers - - QEfficient.base.modeling_qeff.QEFFBaseModel._start = text_start - QEfficient.base.modeling_qeff.QEFFBaseModel._end = text_end - QEfficient.base.modeling_qeff.QEFFBaseModel._total_layers = text_total_layers - - -def _stitch_layerwise_if_available(export_root: Path): - # Some branches expose this helper; fall back gracefully when unavailable. - pipeline_fn = getattr(QEfficient.utils, "layerwise_pipeline", None) - if callable(pipeline_fn): - return pipeline_fn(str(export_root)) - print(f"layerwise_pipeline() not found. 
Layer-wise ONNX shards kept under: {export_root / 'onnx_layerwise_tmp'}") - return str(export_root / "onnx_layerwise_tmp") +def main(): + config = AutoConfig.from_pretrained(MODEL_ID) + config.torch_dtype = "float32" + config.vision_config.deepstack_visual_indexes = [8, 27, 36] -def _new_qeff_model(model_id: str, config): - return QEFFAutoModelForImageTextToText.from_pretrained( - model_id, + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + MODEL_ID, attn_implementation="eager", kv_offload=True, config=config, torch_dtype=torch.float32, ) - -def main(): - config = AutoConfig.from_pretrained(MODEL_ID) - config.torch_dtype = "float32" - # config.vision_config.depth = 9 - # config.text_config.num_hidden_layers = 2 - config.vision_config.deepstack_visual_indexes = [8, 16, 24] - - # if TEST_TEXT_LAYERS: - # config.text_config.num_hidden_layers = TEST_TEXT_LAYERS - - text_config = getattr(config, "text_config", config) - text_total_layers = getattr(text_config, "num_hidden_layers", None) - if text_total_layers is None: - raise ValueError("Could not resolve `num_hidden_layers` from config.text_config.") - config.text_config.num_hidden_layers = text_total_layers - _ensure_pretrained_window_attrs() - _install_shard_window_patch() - - hf_qwen_mod = transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe - _install_window_patch(hf_qwen_mod.Qwen3VLMoeForConditionalGeneration) - if hasattr(hf_qwen_mod, "Qwen3VLMoeForCausalLM"): - _install_window_patch(hf_qwen_mod.Qwen3VLMoeForCausalLM) - - text_windows = _build_layer_windows(total_layers=text_total_layers, window_size=TEXT_WINDOW_SIZE) - # Keep layerwise only on text path in this loop. 
- num_windows = len(text_windows) - first_onnx_path = None - os.environ["LAYERWISE_EXPORT"] = "True" - for window_idx in range(num_windows): - text_start, text_end = text_windows[window_idx] if window_idx < len(text_windows) else (0, 0) - skip_lang_for_window = window_idx >= len(text_windows) - - _set_layer_windows( - text_start=text_start, - text_end=text_end, - text_total_layers=text_total_layers, - ) - print( - f"Exporting window {window_idx + 1}/{num_windows} " - f"text=[{text_start},{text_end})/{text_total_layers} " - f"skip_lang={skip_lang_for_window}" - ) - - qeff_model = _new_qeff_model(MODEL_ID, config) - if hasattr(qeff_model, "model"): - _null_outside_window_layers( - qeff_model.model, - apply_text=not skip_lang_for_window, - ) - - onnx_path = qeff_model.compile( - batch_size=BATCH_SIZE, - prefill_seq_len=PREFILL_SEQ_LEN, - ctx_len=CTX_LEN, - num_cores=NUM_CORES, - num_devices=NUM_DEVICES, - height=HEIGHT, - width=WIDTH, - mxfp6_matmul=True, - aic_enable_depth_first=True, - skip_vision=True, - skip_lang=skip_lang_for_window, - split_retained_state_io=True, - use_onnx_subfunctions=True, - prefill_only=True, - mos=1, - ) - - if first_onnx_path is None: - first_onnx_path = Path(str(onnx_path["lang_prefill_qpc_path"])) - - if first_onnx_path is None: - raise RuntimeError("No ONNX path produced during layer-wise language export.") - - export_root = _resolve_export_root(first_onnx_path) - final_artifact = _stitch_layerwise_if_available(export_root) - print(f"Layer-wise language export completed. 
Final artifact/root: {final_artifact}") - - os.environ["LAYERWISE_EXPORT"] = "False" qpc_path = qeff_model.compile( - lang_onnx_path=final_artifact, - batch_size=BATCH_SIZE, - prefill_seq_len=PREFILL_SEQ_LEN, - ctx_len=CTX_LEN, - num_cores=NUM_CORES, - num_devices=NUM_DEVICES, - height=HEIGHT, - width=WIDTH, + batch_size=1, + prefill_seq_len=32, + ctx_len=4096, + num_cores=16, + num_devices=4, + height=354, + width=536, mxfp6_matmul=True, aic_enable_depth_first=True, skip_vision=True, - skip_lang=skip_lang_for_window, split_retained_state_io=True, use_onnx_subfunctions=True, prefill_only=True, mos=1, + layerwise=True, + layerwise_window_size=1, ) - print(f"Final QPC path: {qpc_path}") diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py index 18a61f6c4..5f5147957 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py @@ -5,324 +5,56 @@ # # ----------------------------------------------------------------------------- -import functools -import os -from pathlib import Path +""" +Layerwise compile example for Qwen3-VL-MoE. + +The orchestration loop that previously lived in this script has been moved +behind the ``layerwise=True`` flag on ``.compile()`` / ``.export()``. + +Note: ``layerwise=True`` is a provisional API and is scheduled for deprecation +once first-class multi-window export lands. It is currently only supported for +``qwen3_vl_moe``, ``qwen3_5_moe`` and ``qwen3_moe``. 
+""" import torch -import transformers from transformers import AutoConfig -import QEfficient from QEfficient import QEFFAutoModelForImageTextToText -MODEL_ID = "Qwen/Qwen3-VL-235B-A22B-Instruct" -PREFILL_SEQ_LEN = 1 -CTX_LEN = 4096 -TEXT_WINDOW_SIZE = 1 - -# For quick local validation only (keep disabled for real export) -# TEST_TEXT_LAYERS = 4 - -# Export controls -BATCH_SIZE = 1 -NUM_CORES = 16 -NUM_DEVICES = 4 -HEIGHT = 354 -WIDTH = 536 - - -def _ensure_pretrained_window_attrs(): - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_start"): - transformers.modeling_utils.PreTrainedModel._start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_end"): - transformers.modeling_utils.PreTrainedModel._end = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_total_layers"): - transformers.modeling_utils.PreTrainedModel._total_layers = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_start"): - transformers.modeling_utils.PreTrainedModel._text_start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_end"): - transformers.modeling_utils.PreTrainedModel._text_end = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_total_layers"): - transformers.modeling_utils.PreTrainedModel._text_total_layers = 0 - - -def _build_layer_windows(total_layers: int, window_size: int): - if total_layers <= 0: - raise ValueError(f"Invalid total_layers={total_layers}. Expected: total_layers > 0.") - if window_size <= 0: - raise ValueError(f"Invalid window_size={window_size}. 
Expected: window_size > 0.") - - windows = [] - start = 0 - while start < total_layers: - end = min(total_layers, start + window_size) - windows.append((start, end)) - start = end - return windows - - -def _get_text_layers_container(model): - # VLM path first - if ( - hasattr(model, "model") - and hasattr(model.model, "language_model") - and hasattr(model.model.language_model, "layers") - ): - return model.model.language_model.layers - # LLM-compatible fallbacks - if hasattr(model, "model") and hasattr(model.model, "layers"): - return model.model.layers - if hasattr(model, "language_model") and hasattr(model.language_model, "layers"): - return model.language_model.layers - if hasattr(model, "layers"): - return model.layers - return None - - -def _null_outside_window_layers(model, apply_text: bool = True): - if apply_text: - text_start = int( - getattr( - transformers.modeling_utils.PreTrainedModel, - "_text_start", - getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0), - ) - ) - text_end = int( - getattr( - transformers.modeling_utils.PreTrainedModel, - "_text_end", - getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0), - ) - ) - text_layers = _get_text_layers_container(model) - if text_layers is not None and text_end > text_start: - for idx, _ in enumerate(text_layers): - if idx < text_start or idx >= text_end: - text_layers[idx] = None - - -def _install_window_patch(model_cls): - if getattr(model_cls, "_window_patch_installed", False): - return - - original_init = model_cls.__init__ - - @functools.wraps(original_init) - def patched_init(self, *args, **kwargs): - original_init(self, *args, **kwargs) - _null_outside_window_layers(self, apply_text=True) - - model_cls.__init__ = patched_init - model_cls._window_patch_installed = True - - -def _resolve_export_root(onnx_path: Path) -> Path: - parts = list(onnx_path.parts) - if "onnx_layerwise_tmp" in parts: - marker_idx = parts.index("onnx_layerwise_tmp") - return Path(*parts[:marker_idx]) - 
return onnx_path.parent - - -def _install_shard_window_patch(): - if getattr(transformers.modeling_utils, "_window_shard_patch_installed", False): - return - - original_get_checkpoint_shard_files = transformers.modeling_utils.get_checkpoint_shard_files - - @functools.wraps(original_get_checkpoint_shard_files) - def patched_get_checkpoint_shard_files(*args, **kwargs): - shard_files, metadata = original_get_checkpoint_shard_files(*args, **kwargs) - weight_map = metadata.get("weight_map") - if not weight_map: - return shard_files, metadata - - start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0)) - end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0)) - text_start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_text_start", start)) - text_end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_text_end", end)) - has_text_window = text_end > text_start - if not has_text_window: - return shard_files, metadata - - selected_text_prefixes = tuple( - [f"model.layers.{layer_idx}." for layer_idx in range(text_start, text_end)] - + [f"model.language_model.layers.{layer_idx}." 
for layer_idx in range(text_start, text_end)] - ) - filtered_weight_map = {} - for checkpoint_key, shard_name in weight_map.items(): - if checkpoint_key.startswith("model.layers.") or checkpoint_key.startswith("model.language_model.layers."): - if not has_text_window or checkpoint_key.startswith(selected_text_prefixes): - filtered_weight_map[checkpoint_key] = shard_name - continue - filtered_weight_map[checkpoint_key] = shard_name - - if not filtered_weight_map: - return shard_files, metadata - - shard_name_to_path = {path.split("/")[-1]: path for path in shard_files} - filtered_shard_names = sorted(set(filtered_weight_map.values())) - filtered_shard_files = [shard_name_to_path[name] for name in filtered_shard_names if name in shard_name_to_path] - if not filtered_shard_files: - return shard_files, metadata - - metadata["weight_map"] = filtered_weight_map - metadata["all_checkpoint_keys"] = list(filtered_weight_map.keys()) - return filtered_shard_files, metadata +# MODEL_ID = "Qwen/Qwen3-VL-235B-A22B-Instruct" +MODEL_ID = "tiny-random/qwen3-vl-moe" - transformers.modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files - transformers.modeling_utils._window_shard_patch_installed = True - - -def _set_layer_windows( - text_start: int, - text_end: int, - text_total_layers: int, -): - transformers.modeling_utils.PreTrainedModel._start = text_start - transformers.modeling_utils.PreTrainedModel._end = text_end - transformers.modeling_utils.PreTrainedModel._total_layers = text_total_layers - transformers.modeling_utils.PreTrainedModel._text_start = text_start - transformers.modeling_utils.PreTrainedModel._text_end = text_end - transformers.modeling_utils.PreTrainedModel._text_total_layers = text_total_layers - - # Qwen3-VL-MoE model code still checks QEffQwen3_5MoeTextModel window attrs - # in a few places. Set both classes to keep layer-wise behavior consistent. 
- qeff_vl_mod = QEfficient.transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe - qeff_vl_mod.QEffQwen3VLMoeTextModel._start = text_start - qeff_vl_mod.QEffQwen3VLMoeTextModel._end = text_end - qeff_vl_mod.QEffQwen3VLMoeTextModel._total_layers = text_total_layers - - qeff_35_mod = getattr(QEfficient.transformers.models, "qwen3_5_moe", None) - if qeff_35_mod is not None: - qeff_35_text_model = getattr(qeff_35_mod.modeling_qwen3_5_moe, "QEffQwen3_5MoeTextModel", None) - if qeff_35_text_model is not None: - qeff_35_text_model._start = text_start - qeff_35_text_model._end = text_end - qeff_35_text_model._total_layers = text_total_layers - - QEfficient.base.modeling_qeff.QEFFBaseModel._start = text_start - QEfficient.base.modeling_qeff.QEFFBaseModel._end = text_end - QEfficient.base.modeling_qeff.QEFFBaseModel._total_layers = text_total_layers - - -def _stitch_layerwise_if_available(export_root: Path): - # Some branches expose this helper; fall back gracefully when unavailable. - pipeline_fn = getattr(QEfficient.utils, "layerwise_pipeline", None) - if callable(pipeline_fn): - return pipeline_fn(str(export_root)) - print(f"layerwise_pipeline() not found. 
Layer-wise ONNX shards kept under: {export_root / 'onnx_layerwise_tmp'}") - return str(export_root / "onnx_layerwise_tmp") +def main(): + config = AutoConfig.from_pretrained(MODEL_ID) + config.torch_dtype = "float32" + config.vision_config.deepstack_visual_indexes = [8, 27, 36] -def _new_qeff_model(model_id: str, config): - return QEFFAutoModelForImageTextToText.from_pretrained( - model_id, + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + MODEL_ID, attn_implementation="eager", kv_offload=True, config=config, torch_dtype=torch.float32, ) - -def main(): - config = AutoConfig.from_pretrained(MODEL_ID) - config.torch_dtype = "float32" - # config.vision_config.depth = 9 - # config.text_config.num_hidden_layers = 2 - config.vision_config.deepstack_visual_indexes = [8, 27, 36] - - # if TEST_TEXT_LAYERS: - # config.text_config.num_hidden_layers = TEST_TEXT_LAYERS - - text_config = getattr(config, "text_config", config) - text_total_layers = getattr(text_config, "num_hidden_layers", None) - if text_total_layers is None: - raise ValueError("Could not resolve `num_hidden_layers` from config.text_config.") - config.text_config.num_hidden_layers = text_total_layers - _ensure_pretrained_window_attrs() - _install_shard_window_patch() - - hf_qwen_mod = transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe - _install_window_patch(hf_qwen_mod.Qwen3VLMoeForConditionalGeneration) - if hasattr(hf_qwen_mod, "Qwen3VLMoeForCausalLM"): - _install_window_patch(hf_qwen_mod.Qwen3VLMoeForCausalLM) - - text_windows = _build_layer_windows(total_layers=text_total_layers, window_size=TEXT_WINDOW_SIZE) - # Keep layerwise only on text path in this loop. 
- num_windows = len(text_windows) - first_onnx_path = None - os.environ["LAYERWISE_EXPORT"] = "True" - for window_idx in range(num_windows): - text_start, text_end = text_windows[window_idx] if window_idx < len(text_windows) else (0, 0) - skip_lang_for_window = window_idx >= len(text_windows) - - _set_layer_windows( - text_start=text_start, - text_end=text_end, - text_total_layers=text_total_layers, - ) - print( - f"Exporting window {window_idx + 1}/{num_windows} " - f"text=[{text_start},{text_end})/{text_total_layers} " - f"skip_lang={skip_lang_for_window}" - ) - - qeff_model = _new_qeff_model(MODEL_ID, config) - if hasattr(qeff_model, "model"): - _null_outside_window_layers( - qeff_model.model, - apply_text=not skip_lang_for_window, - ) - - onnx_path = qeff_model.compile( - batch_size=BATCH_SIZE, - prefill_seq_len=PREFILL_SEQ_LEN, - ctx_len=CTX_LEN, - num_cores=NUM_CORES, - num_devices=NUM_DEVICES, - height=HEIGHT, - width=WIDTH, - mxfp6_matmul=True, - aic_enable_depth_first=True, - skip_vision=True, - skip_lang=skip_lang_for_window, - split_retained_state_io=True, - use_onnx_subfunctions=True, - mos=1, - ) - - if first_onnx_path is None: - first_onnx_path = Path(str(onnx_path["lang_decode_qpc_path"])) - - if first_onnx_path is None: - raise RuntimeError("No ONNX path produced during layer-wise language export.") - - export_root = _resolve_export_root(first_onnx_path) - final_artifact = _stitch_layerwise_if_available(export_root) - print(f"Layer-wise language export completed. 
Final artifact/root: {final_artifact}") - - os.environ["LAYERWISE_EXPORT"] = "False" qpc_path = qeff_model.compile( - lang_onnx_path=final_artifact, - batch_size=BATCH_SIZE, - prefill_seq_len=PREFILL_SEQ_LEN, - ctx_len=CTX_LEN, - num_cores=NUM_CORES, - num_devices=NUM_DEVICES, - height=HEIGHT, - width=WIDTH, + batch_size=1, + prefill_seq_len=1, + ctx_len=4096, + num_cores=16, + num_devices=4, + height=354, + width=536, mxfp6_matmul=True, aic_enable_depth_first=True, skip_vision=True, - skip_lang=skip_lang_for_window, split_retained_state_io=True, use_onnx_subfunctions=True, mos=1, + layerwise=True, + layerwise_window_size=1, ) - print(f"Final QPC path: {qpc_path}") diff --git a/tests/unit_test/models/test_model_quickcheck.py b/tests/unit_test/models/test_model_quickcheck.py index 4b7ed6f17..e8eb1f09c 100644 --- a/tests/unit_test/models/test_model_quickcheck.py +++ b/tests/unit_test/models/test_model_quickcheck.py @@ -1282,3 +1282,118 @@ def test_no_tag_falls_back_to_lm_rules(self): result = to_named_specializations(flat) assert result[0]["name"] == "Prefill" assert result[1]["name"] == "Decode" + + +# --------------------------------------------------------------------------- +# Layer-wise export (provisional, scheduled for deprecation) +# --------------------------------------------------------------------------- + +LAYERWISE_TINY_MODEL_ID = "tiny-random/qwen3-vl-moe" +LAYERWISE_TINY_MODEL_IDS = { + "qwen3_vl_moe": "tiny-random/qwen3-vl-moe", + "qwen3_5_moe": "tiny-random/qwen3.5-moe", + "qwen3_moe": "tiny-random/qwen3-moe", +} + + +@pytest.mark.llm_model +def test_layerwise_window_helpers(): + """Pure-Python coverage of the windowing helpers - no model load required.""" + from QEfficient.transformers.models import _layerwise + + assert _layerwise._build_layer_windows(4, 1) == [(0, 1), (1, 2), (2, 3), (3, 4)] + assert _layerwise._build_layer_windows(5, 2) == [(0, 2), (2, 4), (4, 5)] + with pytest.raises(ValueError): + _layerwise._build_layer_windows(0, 1) + with 
pytest.raises(ValueError): + _layerwise._build_layer_windows(4, 0) + + +@pytest.mark.llm_model +def test_layerwise_supported_guard_rejects_unrelated_model(): + """layerwise=True must hard-fail on architectures without windowing hooks.""" + from QEfficient.transformers.models import _layerwise + + config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM") + with pytest.raises(NotImplementedError, match="layerwise=True is only supported"): + _layerwise.assert_layerwise_supported(config) + + +@pytest.mark.llm_model +def test_layerwise_supported_guard_accepts_qwen3_vl_moe(): + from QEfficient.transformers.models import _layerwise + + try: + config = AutoConfig.from_pretrained(LAYERWISE_TINY_MODEL_ID) + except Exception as exc: + _skip_on_model_fetch_error(exc, LAYERWISE_TINY_MODEL_ID) + resolved = _layerwise.assert_layerwise_supported(config) + assert resolved in {"qwen3_vl_moe", "qwen3_vl_moe_text"} + + +@pytest.mark.llm_model +@pytest.mark.parametrize( + ("arch", "model_id"), + sorted(LAYERWISE_TINY_MODEL_IDS.items()), + ids=sorted(LAYERWISE_TINY_MODEL_IDS), +) +def test_layerwise_supported_guard_accepts_all_supported(arch, model_id): + """Guard must accept each architecture in the layerwise allowlist.""" + from QEfficient.transformers.models import _layerwise + + try: + config = AutoConfig.from_pretrained(model_id) + except Exception as exc: + _skip_on_model_fetch_error(exc, model_id) + resolved = _layerwise.assert_layerwise_supported(config) + assert arch in resolved or resolved.startswith(arch) + + +@pytest.mark.llm_model +def test_layerwise_off_does_not_set_env_var(tmp_path): + """Backward compat: layerwise must be controlled purely via the API, + never via environment variables, and must be off by default.""" + from QEfficient.base.modeling_qeff import QEFFBaseModel + from QEfficient.transformers.models import _layerwise # noqa: F401 + + assert os.environ.get("LAYERWISE_EXPORT") is None + assert QEFFBaseModel._layerwise_active is 
False + + +@pytest.mark.llm_model +def test_layerwise_context_manager_toggles_class_flag(): + """The driver's context manager must flip the class flag and restore it, + even on exception, with no env-var side-effects.""" + from QEfficient.base.modeling_qeff import QEFFBaseModel + from QEfficient.transformers.models import _layerwise + + assert QEFFBaseModel._layerwise_active is False + with _layerwise._layerwise_export_env(): + assert QEFFBaseModel._layerwise_active is True + assert "LAYERWISE_EXPORT" not in os.environ + assert QEFFBaseModel._layerwise_active is False + + try: + with _layerwise._layerwise_export_env(): + raise RuntimeError("boom") + except RuntimeError: + pass + assert QEFFBaseModel._layerwise_active is False + + +@pytest.mark.llm_model +def test_layerwise_compile_rejects_unsupported_model(): + """End-to-end smoke: invoking layerwise=True on llama bubbles the guard error.""" + try: + qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + "hf-internal-testing/tiny-random-LlamaForCausalLM", + ) + except Exception as exc: + _skip_on_model_fetch_error(exc, "tiny-random/llama") + # CausalLM does not expose a layerwise= kwarg today; only DualQPC VLM does. + # So this test guards via the helper directly to make the contract explicit + # for future surface expansion. + from QEfficient.transformers.models import _layerwise + + with pytest.raises(NotImplementedError): + _layerwise.assert_layerwise_supported(qeff_model.model.config) From 6f8873a29aa4652968877436a32826c4d29c380a Mon Sep 17 00:00:00 2001 From: vbaddi Date: Sat, 6 Jun 2026 01:04:50 +0530 Subject: [PATCH 2/8] fix: shrink layerwise shards, support fp16, suppress qconfig noise - Slim per-window export: truncate sin_cached/cos_cached to ctx_len and null embed_tokens / lm_head when unreached. - Fix fp16 layerwise export: _export_layerwise synthesized inputs_embeds via torch.rand without a dtype. 
- Suppress confusing "An unexpected error occurred while dumping the qconfig" message when compile short-circuits without producing a QPC (e.g. layerwise per-window export). dump_qconfig now skips when qpc_path is None and demotes real failures to logger.debug. Signed-off-by: vbaddi --- QEfficient/base/modeling_qeff.py | 9 +- QEfficient/transformers/models/_layerwise.py | 98 +++++++++++++++++++ QEfficient/utils/_utils.py | 11 ++- .../qwen3_vl_moe_layerwise_decode.py | 9 +- 4 files changed, 119 insertions(+), 8 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 455b54593..aa890ee99 100755 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -8,7 +8,6 @@ import gc import inspect import logging -import os import shutil import subprocess import warnings @@ -545,9 +544,15 @@ def _resolve_pkv_layers(pkv_obj): z = example_inputs.pop("input_ids") if is_vision: hidden_size = self.model.language_model.config.hidden_size + embed_dtype = getattr(self.model.language_model.config, "torch_dtype", None) else: hidden_size = self.model.model.config.hidden_size - inputs_embeds = torch.rand(z.shape[0], z.shape[1], hidden_size, device=z.device) + embed_dtype = getattr(self.model.model.config, "torch_dtype", None) + # Match the model's dtype so per-window export does not introduce a + # float32/float16 mismatch when running through fp16 decoder layers. 
+ if embed_dtype is None: + embed_dtype = next(self.model.parameters()).dtype + inputs_embeds = torch.rand(z.shape[0], z.shape[1], hidden_size, device=z.device, dtype=embed_dtype) example_inputs["inputs_embeds"] = inputs_embeds dynamic_axes["inputs_embeds"] = dynamic_axes.pop("input_ids") diff --git a/QEfficient/transformers/models/_layerwise.py b/QEfficient/transformers/models/_layerwise.py index 075a8865e..a897fc97e 100644 --- a/QEfficient/transformers/models/_layerwise.py +++ b/QEfficient/transformers/models/_layerwise.py @@ -131,6 +131,103 @@ def _null_outside_window_layers(model, *, apply_text: bool = True) -> None: text_layers[idx] = None +def _find_language_model(model): + """Locate the inner language_model that owns sin_cached / cos_cached / embed_tokens.""" + candidates = [] + if hasattr(model, "model") and hasattr(model.model, "language_model"): + candidates.append(model.model.language_model) + if hasattr(model, "language_model"): + candidates.append(model.language_model) + if hasattr(model, "model"): + candidates.append(model.model) + candidates.append(model) + for cand in candidates: + if any(hasattr(cand, attr) for attr in ("sin_cached", "cos_cached", "embed_tokens")): + return cand + return None + + +def _slim_for_window_export(qeff_model, *, ctx_len: Optional[int]) -> None: + """Shrink top-level params that are unused (or oversized) for this window. + + Without this, every per-window ONNX shard re-bakes the full RoPE base table + (``sin_cached``/``cos_cached`` of shape ``[max_position_embeddings, head_dim]``, + typically tens of MB in fp32) plus the full vocab embedding, blowing each + layer-window shard up by 1-2 orders of magnitude over the actual layer + weight footprint. Each top-level param is touched in-place; the next + window rebuilds the model from scratch via the factory so there is no + leakage across windows. 
+ """ + import torch + + inner = qeff_model.model if hasattr(qeff_model, "model") else qeff_model + lm = _find_language_model(inner) + if lm is None: + return + + pt = transformers.modeling_utils.PreTrainedModel + text_start = int(getattr(pt, "_text_start", getattr(pt, "_start", 0))) + text_end = int(getattr(pt, "_text_end", getattr(pt, "_end", 0))) + text_total = int(getattr(pt, "_text_total_layers", getattr(pt, "_total_layers", 0) or 0)) + + # 1) Truncate sin_cached / cos_cached to the rows actually addressable at + # inference time (ctx_len). The original tables are sized to + # max_position_embeddings (often 256k+) which is dead weight for export. + if ctx_len: + for attr in ("sin_cached", "cos_cached"): + param = getattr(lm, attr, None) + if param is None or not hasattr(param, "shape") or param.dim() < 2: + continue + cur_rows = int(param.shape[0]) + target_rows = max(1, int(ctx_len)) + if cur_rows <= target_rows: + continue + with torch.no_grad(): + truncated = param.detach()[:target_rows].clone().contiguous() + new_param = torch.nn.Parameter(truncated, requires_grad=False) + setattr(lm, attr, new_param) + + # 2) Drop the vocab embedding for windows that don't run input-id lookup. + # The first window (text_start == 0) is the only one that calls + # ``get_input_embeddings()(input_ids)``; later windows take + # ``inputs_embeds`` directly so the embedding matrix is unreached but + # still serialized. Replace its weight with a tiny placeholder of the + # same dtype/device so module attributes stay valid. + if text_start > 0 and hasattr(lm, "embed_tokens"): + embed = getattr(lm, "embed_tokens", None) + weight = getattr(embed, "weight", None) + if weight is not None and weight.dim() == 2 and weight.shape[0] > 1: + with torch.no_grad(): + tiny = torch.zeros((1, weight.shape[1]), dtype=weight.dtype, device=weight.device) + embed.weight = torch.nn.Parameter(tiny, requires_grad=False) + + # 3) Drop the lm_head for windows that aren't the last one. 
Only the final + # window applies ``self.model.lm_head(hidden_states)``. + outer = qeff_model.model if hasattr(qeff_model, "model") else qeff_model + lm_head = getattr(outer, "lm_head", None) + if ( + lm_head is not None + and text_total > 0 + and text_end < text_total + and hasattr(lm_head, "weight") + and lm_head.weight is not None + and lm_head.weight.dim() == 2 + and lm_head.weight.shape[0] > 1 + ): + with torch.no_grad(): + tiny = torch.zeros( + (1, lm_head.weight.shape[1]), + dtype=lm_head.weight.dtype, + device=lm_head.weight.device, + ) + lm_head.weight = torch.nn.Parameter(tiny, requires_grad=False) + if getattr(lm_head, "bias", None) is not None: + lm_head.bias = torch.nn.Parameter( + torch.zeros((1,), dtype=lm_head.bias.dtype, device=lm_head.bias.device), + requires_grad=False, + ) + + def _install_window_patch(model_cls) -> None: if getattr(model_cls, "_window_patch_installed", False): return @@ -352,6 +449,7 @@ def run_layerwise( last_qeff_model = qeff_model if hasattr(qeff_model, "model"): _null_outside_window_layers(qeff_model.model, apply_text=True) + _slim_for_window_export(qeff_model, ctx_len=compile_kwargs.get("ctx_len")) window_kwargs = dict(compile_kwargs) # skip_lang is a VLM-only kwarg; only inject when present in caller's kwargs. diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index 24ab88aa0..dc1dcbc0c 100755 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -706,9 +706,16 @@ def __repr__(self): def dump_qconfig(func): def wrapper(self, *args, **kwargs): result = func(self, *args, **kwargs) + # Skip qconfig dumping when no QPC was actually produced (e.g. the + # layer-wise export short-circuits compile to return an ONNX path). + # Without this guard we'd hit a TypeError inside create_and_dump_qconfigs + # and surface a confusing user-facing message. 
+ qpc_path = getattr(self, "qpc_path", None) + if qpc_path is None: + return result try: create_and_dump_qconfigs( - self.qpc_path, + qpc_path, self.onnx_path, self.get_model_config, [cls.__name__ for cls in self._pytorch_transforms], @@ -724,7 +731,7 @@ def wrapper(self, *args, **kwargs): }, ) except Exception as e: - print(f"An unexpected error occurred while dumping the qconfig: {e}") + logger.debug("Skipping qconfig dump: %s", e) return result return wrapper diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py index 5f5147957..4e51e5aef 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py @@ -22,20 +22,20 @@ from QEfficient import QEFFAutoModelForImageTextToText # MODEL_ID = "Qwen/Qwen3-VL-235B-A22B-Instruct" -MODEL_ID = "tiny-random/qwen3-vl-moe" +# MODEL_ID = "tiny-random/qwen3-vl-moe" +MODEL_ID = "Qwen/Qwen3-VL-30B-A3B-Instruct" def main(): config = AutoConfig.from_pretrained(MODEL_ID) - config.torch_dtype = "float32" - config.vision_config.deepstack_visual_indexes = [8, 27, 36] + config.torch_dtype = "float16" qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( MODEL_ID, attn_implementation="eager", kv_offload=True, config=config, - torch_dtype=torch.float32, + torch_dtype=torch.float16, ) qpc_path = qeff_model.compile( @@ -47,6 +47,7 @@ def main(): height=354, width=536, mxfp6_matmul=True, + mxint8_kv_cache=True, aic_enable_depth_first=True, skip_vision=True, split_retained_state_io=True, From 336195566e1964afca3b1116996876a5f5941c7a Mon Sep 17 00:00:00 2001 From: vbaddi Date: Sat, 6 Jun 2026 12:55:59 +0530 Subject: [PATCH 3/8] fix(0606): meta-init outer model, scoped state, cache reuse - Add layerwise=True to from_pretrained (VLM + CausalLM). 
When set, the outer model is built on the meta device via from_config, so the caller's load no longer pulls full checkpoint weights into RAM. - Stop polluting transformers.modeling_utils.PreTrainedModel with class vars. Window state lives in a module-local _LAYERWISE_STATE dict; the patched HF hooks (shard filter, init nuller) close over it and behave as no-ops when layerwise is inactive. - Cache layerwise ONNX between runs: _export_layerwise short-circuits when final_data/merged_*.onnx already exists, and the stitch step reuses it. - WIP: Hard-cap RoPE rows at 32K for now (was ctx_len), so changing ctx_len does not invalidate the export hash. - Respect explicit low_cpu_mem_usage=True in from_pretrained for VLM and CausalLM (was unconditionally forced False); used by the layerwise factory for window-only weight materialization on sharded checkpoints. Signed-off-by: vbaddi --- QEfficient/base/modeling_qeff.py | 11 ++ QEfficient/transformers/models/_layerwise.py | 174 +++++++++++++----- .../transformers/models/modeling_auto.py | 104 +++++++++-- examples/disagg_serving/qwen3moe_layerwise.py | 2 +- .../qwen3_5_moe/qwen3_5_moe_layerwise.py | 1 + .../qwen3_5_moe_layerwise_decode.py | 1 + .../models/qwen3_vl_moe/qwen3_vl_moe.py | 3 +- .../qwen3_vl_moe/qwen3_vl_moe_layerwise.py | 1 + .../qwen3_vl_moe_layerwise_decode.py | 1 + 9 files changed, 230 insertions(+), 68 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index aa890ee99..996ec3d71 100755 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -502,6 +502,17 @@ def _export_layerwise( self.onnx_path = onnx_path return onnx_path + # Layer-wise reuse: if the merged final ONNX from a prior run exists + # under final_data/, skip per-window export entirely. The driver's + # stitch step picks up the same merged file, so re-running the same + # example without changes goes straight to the QPC compile.
+ final_data_dir = export_dir / "final_data" + if final_data_dir.is_dir(): + cached_merged = sorted(final_data_dir.glob("merged_*.onnx")) + if cached_merged: + self.onnx_path = cached_merged[-1] + return self.onnx_path + # check if the model is in meta state or weights are offloaded self._model_offloaded_check() diff --git a/QEfficient/transformers/models/_layerwise.py b/QEfficient/transformers/models/_layerwise.py index a897fc97e..02e53abae 100644 --- a/QEfficient/transformers/models/_layerwise.py +++ b/QEfficient/transformers/models/_layerwise.py @@ -38,6 +38,29 @@ } ) +# Hard cap on the RoPE base table rows we serialize per window. Chosen as a +# constant (not a function of ctx_len) on purpose: changing ctx_len at compile +# time should re-use the cached ONNX and only re-run the QPC compile. Any +# inference-time position id is bounded by ctx_len, and ctx_len in practice +# stays well under 32K for the supported MoE families today, so dropping +# the unreachable rows past 32K is lossless. +_LAYERWISE_ROPE_MAX_POSITIONS = 32768 + +# Process-local layer-wise window state. We deliberately avoid setting class +# attributes on transformers.modeling_utils.PreTrainedModel - those would leak +# to every HF model in the process and survive past the layer-wise run. The +# patched HF hooks (shard filter, model-init nuller) close over this dict so +# they can be installed once and behave as no-ops whenever ``active`` is False. 
+_LAYERWISE_STATE: Dict[str, int] = { + "active": 0, + "start": 0, + "end": 0, + "total_layers": 0, + "text_start": 0, + "text_end": 0, + "text_total_layers": 0, +} + _DEPRECATION_WARNED = False @@ -82,10 +105,13 @@ def assert_layerwise_supported(config) -> str: def _ensure_pretrained_window_attrs() -> None: - pt = transformers.modeling_utils.PreTrainedModel - for attr in ("_start", "_end", "_total_layers", "_text_start", "_text_end", "_text_total_layers"): - if not hasattr(pt, attr): - setattr(pt, attr, 0) + """No-op kept for compatibility with prior call sites. + + Layer-wise window state lives in the module-local ``_LAYERWISE_STATE`` + dict (see top of file); we no longer pollute ``PreTrainedModel`` with + class attributes. + """ + return def _build_layer_windows(total_layers: int, window_size: int) -> List[Tuple[int, int]]: @@ -121,9 +147,8 @@ def _get_text_layers_container(model): def _null_outside_window_layers(model, *, apply_text: bool = True) -> None: if not apply_text: return - pt = transformers.modeling_utils.PreTrainedModel - text_start = int(getattr(pt, "_text_start", getattr(pt, "_start", 0))) - text_end = int(getattr(pt, "_text_end", getattr(pt, "_end", 0))) + text_start = int(_LAYERWISE_STATE["text_start"] or _LAYERWISE_STATE["start"]) + text_end = int(_LAYERWISE_STATE["text_end"] or _LAYERWISE_STATE["end"]) text_layers = _get_text_layers_container(model) if text_layers is not None and text_end > text_start: for idx, _ in enumerate(text_layers): @@ -165,27 +190,29 @@ def _slim_for_window_export(qeff_model, *, ctx_len: Optional[int]) -> None: if lm is None: return - pt = transformers.modeling_utils.PreTrainedModel - text_start = int(getattr(pt, "_text_start", getattr(pt, "_start", 0))) - text_end = int(getattr(pt, "_text_end", getattr(pt, "_end", 0))) - text_total = int(getattr(pt, "_text_total_layers", getattr(pt, "_total_layers", 0) or 0)) - - # 1) Truncate sin_cached / cos_cached to the rows actually addressable at - # inference time (ctx_len). 
The original tables are sized to - # max_position_embeddings (often 256k+) which is dead weight for export. - if ctx_len: - for attr in ("sin_cached", "cos_cached"): - param = getattr(lm, attr, None) - if param is None or not hasattr(param, "shape") or param.dim() < 2: - continue - cur_rows = int(param.shape[0]) - target_rows = max(1, int(ctx_len)) - if cur_rows <= target_rows: - continue - with torch.no_grad(): - truncated = param.detach()[:target_rows].clone().contiguous() - new_param = torch.nn.Parameter(truncated, requires_grad=False) - setattr(lm, attr, new_param) + text_start = int(_LAYERWISE_STATE["text_start"] or _LAYERWISE_STATE["start"]) + text_end = int(_LAYERWISE_STATE["text_end"] or _LAYERWISE_STATE["end"]) + text_total = int(_LAYERWISE_STATE["text_total_layers"] or _LAYERWISE_STATE["total_layers"] or 0) + + # 1) Truncate sin_cached / cos_cached to a fixed cap (32K rows) instead + # of ctx_len. A constant cap keeps the export hash invariant when the + # user changes ctx_len, so re-compiling at a different context length + # only re-runs the QPC compile - the cached layer-wise ONNX is reused. + # Inference-time position ids are bounded by ctx_len which is well + # under the cap for the supported MoE families today. + del ctx_len # signature retained for forward compat, value intentionally unused + rope_cap = _LAYERWISE_ROPE_MAX_POSITIONS + for attr in ("sin_cached", "cos_cached"): + param = getattr(lm, attr, None) + if param is None or not hasattr(param, "shape") or param.dim() < 2: + continue + cur_rows = int(param.shape[0]) + if cur_rows <= rope_cap: + continue + with torch.no_grad(): + truncated = param.detach()[:rope_cap].clone().contiguous() + new_param = torch.nn.Parameter(truncated, requires_grad=False) + setattr(lm, attr, new_param) # 2) Drop the vocab embedding for windows that don't run input-id lookup. 
# The first window (text_start == 0) is the only one that calls @@ -236,7 +263,10 @@ def _install_window_patch(model_cls) -> None: @functools.wraps(original_init) def patched_init(self, *args, **kwargs): original_init(self, *args, **kwargs) - _null_outside_window_layers(self, apply_text=True) + # Only nullify decoder layers when the layer-wise driver is actively + # exporting; idle calls to from_pretrained must behave normally. + if _LAYERWISE_STATE["active"]: + _null_outside_window_layers(self, apply_text=True) model_cls.__init__ = patched_init model_cls._window_patch_installed = True @@ -251,15 +281,19 @@ def _install_shard_window_patch() -> None: @functools.wraps(original_get_checkpoint_shard_files) def patched_get_checkpoint_shard_files(*args, **kwargs): shard_files, metadata = original_get_checkpoint_shard_files(*args, **kwargs) + # Honor the module-local state instead of polluting PreTrainedModel. + # When layerwise is not active this reduces to a no-op for any HF + # caller that happens to load checkpoint shards in this process. 
+ if not _LAYERWISE_STATE["active"]: + return shard_files, metadata weight_map = metadata.get("weight_map") if not weight_map: return shard_files, metadata - pt = transformers.modeling_utils.PreTrainedModel - start = int(getattr(pt, "_start", 0)) - end = int(getattr(pt, "_end", 0)) - text_start = int(getattr(pt, "_text_start", start)) - text_end = int(getattr(pt, "_text_end", end)) + start = int(_LAYERWISE_STATE["start"]) + end = int(_LAYERWISE_STATE["end"]) + text_start = int(_LAYERWISE_STATE["text_start"] or start) + text_end = int(_LAYERWISE_STATE["text_end"] or end) if text_end <= text_start: return shard_files, metadata @@ -293,14 +327,19 @@ def patched_get_checkpoint_shard_files(*args, **kwargs): def _set_layer_windows(text_start: int, text_end: int, text_total_layers: int) -> None: - pt = transformers.modeling_utils.PreTrainedModel - pt._start = text_start - pt._end = text_end - pt._total_layers = text_total_layers - pt._text_start = text_start - pt._text_end = text_end - pt._text_total_layers = text_total_layers - + # Update the module-local state used by patched HF hooks. We deliberately + # do NOT set class attributes on transformers.modeling_utils.PreTrainedModel + # here - that would leak to every HF model in the process. + _LAYERWISE_STATE["start"] = text_start + _LAYERWISE_STATE["end"] = text_end + _LAYERWISE_STATE["total_layers"] = text_total_layers + _LAYERWISE_STATE["text_start"] = text_start + _LAYERWISE_STATE["text_end"] = text_end + _LAYERWISE_STATE["text_total_layers"] = text_total_layers + + # The QEff modeling classes themselves expose _start/_end/_total_layers as + # part of their windowing contract (they are read inside their forward + # implementations). Those are ours to set. 
qeff_vl_mod = getattr(QEfficient.transformers.models, "qwen3_vl_moe", None) if qeff_vl_mod is not None: cls = getattr(qeff_vl_mod.modeling_qwen3_vl_moe, "QEffQwen3VLMoeTextModel", None) @@ -338,10 +377,22 @@ def _resolve_export_root(onnx_path: Path) -> Path: parts = list(onnx_path.parts) if "onnx_layerwise_tmp" in parts: return Path(*parts[: parts.index("onnx_layerwise_tmp")]) + if "final_data" in parts: + return Path(*parts[: parts.index("final_data")]) return onnx_path.parent +def _is_cached_merged(onnx_path: Path) -> bool: + return "final_data" in onnx_path.parts and onnx_path.name.startswith("merged_") + + def _stitch_layerwise_if_available(export_root: Path) -> str: + # If a prior run already produced the merged ONNX, just return it. + cached_dir = export_root / "final_data" + if cached_dir.is_dir(): + cached_merged = sorted(cached_dir.glob("merged_*.onnx")) + if cached_merged: + return str(cached_merged[-1]) pipeline_fn = getattr(QEfficient.utils, "layerwise_pipeline", None) if callable(pipeline_fn): return pipeline_fn(str(export_root)) @@ -349,13 +400,33 @@ def _stitch_layerwise_if_available(export_root: Path) -> str: def _install_window_patches_for(model_type: str) -> None: - """Install the HF __init__/shard patches needed for the given model_type.""" + """Install the HF __init__/shard patches needed for the given model_type. + + The shard-file patch makes ``from_pretrained`` skip checkpoint shards that + only contain weights for layers outside the active window, so loading an + N-layer model in a window of size 1 reads ~1/N of the disk. The init + patch nulls the unused decoder layers right after the model is built so + the full layer list is never instantiated in memory. 
+ """ _install_shard_window_patch() - if "qwen3_vl_moe" in model_type: - hf_mod = transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe - _install_window_patch(hf_mod.Qwen3VLMoeForConditionalGeneration) - if hasattr(hf_mod, "Qwen3VLMoeForCausalLM"): - _install_window_patch(hf_mod.Qwen3VLMoeForCausalLM) + candidates = [] + qwen3_vl_moe_mod = getattr(getattr(transformers.models, "qwen3_vl_moe", None), "modeling_qwen3_vl_moe", None) + if qwen3_vl_moe_mod is not None and "qwen3_vl_moe" in model_type: + candidates.extend( + cls + for name in ("Qwen3VLMoeForConditionalGeneration", "Qwen3VLMoeForCausalLM") + if (cls := getattr(qwen3_vl_moe_mod, name, None)) is not None + ) + qwen3_moe_mod = getattr(getattr(transformers.models, "qwen3_moe", None), "modeling_qwen3_moe", None) + if qwen3_moe_mod is not None and model_type in {"qwen3_moe"}: + candidates.extend( + cls for name in ("Qwen3MoeForCausalLM",) if (cls := getattr(qwen3_moe_mod, name, None)) is not None + ) + # qwen3_5_moe shares the qwen3_vl_moe HF classes today; the QEff modeling + # file overrides behavior. Install the shard patch (above) and rely on + # _null_outside_window_layers running post-init in the driver loop. + for cls in candidates: + _install_window_patch(cls) @contextmanager @@ -367,12 +438,15 @@ def _layerwise_export_env(): interpreters (e.g. test workers) operate independently. 
""" base = QEfficient.base.modeling_qeff.QEFFBaseModel - prev = getattr(base, "_layerwise_active", False) + prev_active = getattr(base, "_layerwise_active", False) + prev_state_active = _LAYERWISE_STATE["active"] base._layerwise_active = True + _LAYERWISE_STATE["active"] = 1 try: yield finally: - base._layerwise_active = prev + base._layerwise_active = prev_active + _LAYERWISE_STATE["active"] = prev_state_active def _resolve_text_total_layers(config) -> int: diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 108b3e5fd..066749b9d 100755 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -123,6 +123,29 @@ def _resolve_torch_dtype(kwargs: dict) -> None: kwargs["torch_dtype"] = torch.float32 +def _build_meta_model(hf_auto_class, pretrained_model_name_or_path, kwargs): + """Construct an HF model on the meta device for layer-wise mode. + + Avoids materializing checkpoint weights at the outer ``from_pretrained`` + call site. The wrapper still has a fully-typed ``nn.Module`` (so config, + architectures, and module structure are all real), but every parameter + and buffer is a meta tensor — zero RAM. The layer-wise driver later + rebuilds a real per-window model when ``compile()``/``export()`` runs. + """ + from transformers import AutoConfig + + config = kwargs.get("config", None) + if config is None: + config_kwargs = { + k: kwargs[k] for k in ("trust_remote_code", "revision", "token", "subfolder", "cache_dir") if k in kwargs + } + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **config_kwargs) + torch_dtype = kwargs.get("torch_dtype", torch.float32) + with torch.device("meta"): + model = hf_auto_class.from_config(config, torch_dtype=torch_dtype) + return model + + class QEFFTransformersBase(QEFFBaseModel): """ Base class for QEfficient wrappers around HuggingFace transformer models. 
@@ -1320,10 +1343,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Option if kwargs.get("attn_implementation", None) not in {None, "eager"}: logger.warning('Updating attn_implementation="eager"') - if kwargs.get("low_cpu_mem_usage", None): - logger.warning("Updating low_cpu_mem_usage=False") - - kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + # Respect an explicit low_cpu_mem_usage=True from the caller (used by the + # layer-wise driver to keep RAM bounded by one window's weights via meta + # init + sharded materialization). For everyone else, force False to + # match prior behavior. + explicit_low_cpu = kwargs.get("low_cpu_mem_usage", None) is True + kwargs.update( + { + "attn_implementation": "eager", + "low_cpu_mem_usage": True if explicit_low_cpu else False, + } + ) _resolve_torch_dtype(kwargs) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) @@ -1501,11 +1531,14 @@ def _layerwise_factory_kwargs(self): """Reproduce the from_pretrained kwargs needed to rebuild this wrapper per window.""" # Mirror the dual-QPC from_pretrained surface; the layerwise driver passes # config explicitly per call, so we only carry torch_dtype + attn here. + # low_cpu_mem_usage=True works with the shard-window patch to allocate + # only the active window's weights instead of the full model. 
torch_dtype = getattr(self.config, "torch_dtype", None) return { "attn_implementation": "eager", "kv_offload": True, "torch_dtype": torch_dtype, + "low_cpu_mem_usage": True, } def _build_layerwise_factory(self): @@ -2969,6 +3002,7 @@ def from_pretrained( kv_offload: Optional[bool] = None, continuous_batching: bool = False, qaic_config: Optional[dict] = None, + layerwise: bool = False, **kwargs, ): """ @@ -3009,17 +3043,31 @@ def from_pretrained( if kwargs.get("attn_implementation", None) not in {None, "eager"}: logger.warning('Updating attn_implementation="eager"') - if kwargs.get("low_cpu_mem_usage", None): - logger.warning("Updating low_cpu_mem_usage=False") - - kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + # Respect explicit low_cpu_mem_usage=True (used by the layer-wise driver + # to materialize only the active window's weights via meta init + + # sharded materialization). Default behavior remains False. + explicit_low_cpu = kwargs.get("low_cpu_mem_usage", None) is True + kwargs.update( + { + "attn_implementation": "eager", + "low_cpu_mem_usage": True if explicit_low_cpu else False, + } + ) _resolve_torch_dtype(kwargs) - model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) + if layerwise: + # Layer-wise mode: build the outer model on the meta device so the + # caller's ``from_pretrained`` does not pull the full checkpoint + # into RAM. compile()/export() rebuilds a real per-window model + # internally via the layer-wise driver, so the outer instance is + # only used as a config holder. 
+ model = _build_meta_model(cls._hf_auto_class, pretrained_model_name_or_path, kwargs) + else: + model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {}) - return cls( + instance = cls( model, kv_offload=kv_offload, continuous_batching=continuous_batching, @@ -3027,6 +3075,12 @@ def from_pretrained( qaic_config=qaic_config, **kwargs, ) + # Mark the wrapper so its compile() can default ``layerwise=True`` if + # the user forgot to pass it (the meta model cannot be exported any + # other way) and so the driver knows weights still need to be loaded. + if layerwise: + instance._layerwise_outer_meta = True + return instance MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = { @@ -3206,6 +3260,7 @@ def from_pretrained( continuous_batching: bool = False, qaic_config: Optional[dict] = None, max_seq_len_cached: Optional[int] = None, + layerwise: bool = False, *args, **kwargs, ): @@ -3261,15 +3316,28 @@ def from_pretrained( if kwargs.get("attn_implementation", None) not in {None, "eager"}: logger.warning('Updating attn_implementation="eager"') - if kwargs.get("low_cpu_mem_usage", None): - logger.warning("Updating low_cpu_mem_usage=False") - kv_offload = kwargs.pop("kv_offload", None) - kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + # Respect explicit low_cpu_mem_usage=True (used by the layer-wise driver + # to materialize only the active window's weights via meta init + + # sharded materialization). Default behavior remains False. + explicit_low_cpu = kwargs.get("low_cpu_mem_usage", None) is True + kwargs.update( + { + "attn_implementation": "eager", + "low_cpu_mem_usage": True if explicit_low_cpu else False, + } + ) _resolve_torch_dtype(kwargs) - model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + if layerwise: + # Layer-wise mode: build the outer model on the meta device. 
The + # caller still gets a typed wrapper, but no checkpoint weights are + # pulled into RAM. compile()/export() rebuilds a real per-window + # model internally via the layer-wise driver. + model = _build_meta_model(cls._hf_auto_class, pretrained_model_name_or_path, kwargs) + else: + model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) if qaic_config is not None: qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path @@ -3284,7 +3352,7 @@ def from_pretrained( continuous_batching=continuous_batching, **kwargs, ) - return cls( + instance = cls( model, continuous_batching=continuous_batching, qaic_config=qaic_config, @@ -3292,6 +3360,9 @@ def from_pretrained( max_seq_len_cached=max_seq_len_cached, **kwargs, ) + if layerwise: + instance._layerwise_outer_meta = True + return instance @property def get_model_config(self) -> dict: @@ -3388,6 +3459,7 @@ def _factory(model_id, config): config=config, torch_dtype=torch_dtype, continuous_batching=self.continuous_batching, + low_cpu_mem_usage=True, ) return _layerwise.run_layerwise( diff --git a/examples/disagg_serving/qwen3moe_layerwise.py b/examples/disagg_serving/qwen3moe_layerwise.py index 762df044e..3c9339b30 100644 --- a/examples/disagg_serving/qwen3moe_layerwise.py +++ b/examples/disagg_serving/qwen3moe_layerwise.py @@ -27,7 +27,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id) -qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, config=config) +qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, config=config, layerwise=True) qpc_path = qeff_model.compile( prefill_seq_len=PREFILL_SEQ_LEN, diff --git a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise.py b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise.py index e00209c57..e1924d837 100644 --- a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise.py +++ 
b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise.py @@ -33,6 +33,7 @@ def main(): kv_offload=True, config=config, torch_dtype=torch.float32, + layerwise=True, ) qpc_path = qeff_model.compile( diff --git a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise_decode.py b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise_decode.py index e2c9d77e7..1e402438d 100644 --- a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise_decode.py +++ b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise_decode.py @@ -33,6 +33,7 @@ def main(): kv_offload=True, config=config, torch_dtype=torch.float32, + layerwise=True, ) qpc_path = qeff_model.compile( diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe.py index 67f199f0c..35dd6b5ee 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe.py @@ -13,7 +13,8 @@ from QEfficient import QEFFAutoModelForImageTextToText -model_id = "Qwen/Qwen3-VL-30B-A3B-Instruct" +# model_id = "Qwen/Qwen3-VL-30B-A3B-Instruct" +model_id = "tiny-random/qwen3-vl-moe" config = AutoConfig.from_pretrained(model_id) # For faster execution user can run with lesser layers, For Testing Purpose Only. 
Please ensure to use the configuration given below as random configurations may fail due to deepstack diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise.py index 4973ec325..0912738b6 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise.py @@ -35,6 +35,7 @@ def main(): kv_offload=True, config=config, torch_dtype=torch.float32, + layerwise=True, ) qpc_path = qeff_model.compile( diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py index 4e51e5aef..d87ded2c0 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py @@ -36,6 +36,7 @@ def main(): kv_offload=True, config=config, torch_dtype=torch.float16, + layerwise=True, ) qpc_path = qeff_model.compile( From 8bc683869bcf3fe6199e3bc43ab4ad4cb2aa14b4 Mon Sep 17 00:00:00 2001 From: vbaddi Date: Sat, 6 Jun 2026 19:27:49 +0530 Subject: [PATCH 4/8] fix(layerwise): harden cache reuse and retained-state compatibility - Refactor layerwise cache probing to avoid per-window model loads when complete cached artifacts already exist. - Ensure cache misses export all layer windows instead of stopping after the first shard. - Reuse only complete merged layerwise ONNX graphs (merged_0-N.onnx) and preserve existing QPC cache behavior. - Fix retained-state custom I/O naming for ONNX subfunction exports and runtime aliases for prefixed/internal outputs. - Hydrate outer VLM wrapper QPC paths after layerwise compile so generate() works immediately. - Add focused regressions for cache probing. 
Signed-off-by: vbaddi --- QEfficient/base/modeling_qeff.py | 28 ++- QEfficient/generation/cloud_infer.py | 25 ++- QEfficient/transformers/models/_layerwise.py | 57 +++++- .../transformers/models/modeling_auto.py | 192 ++++++++++-------- QEfficient/utils/export_utils.py | 6 +- .../qwen3_vl_moe/qwen3_vl_disagg_mode.py | 3 +- .../qwen3_vl_moe/qwen3_vl_moe_layerwise.py | 6 +- .../qwen3_vl_moe_layerwise_decode.py | 50 ++++- .../unit_test/models/test_model_quickcheck.py | 153 ++++++++++++++ 9 files changed, 399 insertions(+), 121 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 996ec3d71..a7c63ad0e 100755 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -488,6 +488,7 @@ def _export_layerwise( prefill_only: Optional[bool] = False, **export_kwargs, ) -> str: + cache_probe = export_kwargs.pop("_layerwise_cache_probe", False) idx = int(QEFFBaseModel._start) end_idx = int(getattr(QEFFBaseModel, "_end", idx + 1)) if end_idx <= idx: @@ -508,10 +509,13 @@ def _export_layerwise( # example without changes goes straight to the QPC compile. final_data_dir = export_dir / "final_data" if final_data_dir.is_dir(): - cached_merged = sorted(final_data_dir.glob("merged_*.onnx")) - if cached_merged: - self.onnx_path = cached_merged[-1] + total_layers = int(getattr(QEFFBaseModel, "_total_layers", 0) or 0) + cached_merged = final_data_dir / f"merged_0-{total_layers}.onnx" + if total_layers > 0 and cached_merged.is_file(): + self.onnx_path = cached_merged return self.onnx_path + if cache_probe: + return None # check if the model is in meta state or weights are offloaded self._model_offloaded_check() @@ -752,6 +756,7 @@ def _compile( For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored. 
""" + layerwise_cache_probe = compiler_options.pop("_layerwise_cache_probe", False) moe_prefill_packed_chunk_size = compiler_options.pop("moe_prefill_packed_chunk_size", None) if onnx_path is None: # If weights were offloaded after export, compiling must use the existing @@ -771,11 +776,15 @@ def _compile( num_devices=mdp_ts_num_devices, qaic_config=qaic_config, moe_prefill_packed_chunk_size=moe_prefill_packed_chunk_size, + _layerwise_cache_probe=layerwise_cache_probe, **compiler_options, ) - onnx_path = Path(onnx_path) if QEFFBaseModel._layerwise_active: + if onnx_path is None: + return None + onnx_path = Path(onnx_path) return onnx_path + onnx_path = Path(onnx_path) compile_dir = Path(compile_dir or onnx_path.parent) qpc_path = compile_dir / "qpc" @@ -858,14 +867,13 @@ def _compile( compile_dir = qpc_path.with_name(qpc_path.name + "-" + compile_hash) qpc_path = compile_dir / "qpc" - qpc_path.mkdir(parents=True, exist_ok=True) - + if (qpc_path / "programqpc.bin").is_file(): + self.qpc_path = qpc_path + return qpc_path if qpc_path.is_dir(): - if (qpc_path / "programqpc.bin").is_file(): - self.qpc_path = qpc_path - return qpc_path - # Probably compilation failure last time, delete directory to start over + # Probably compilation failure last time, delete directory to start over. 
shutil.rmtree(qpc_path) + compile_dir.mkdir(parents=True, exist_ok=True) # Write the generated MDP partition config file (not if user provided it) diff --git a/QEfficient/generation/cloud_infer.py b/QEfficient/generation/cloud_infer.py index 47703930a..88d1357b6 100644 --- a/QEfficient/generation/cloud_infer.py +++ b/QEfficient/generation/cloud_infer.py @@ -13,6 +13,21 @@ import numpy as np + +def _public_retained_state_name(output_name: str) -> Optional[str]: + """Map internal subfunction retained-state outputs to public runtime names.""" + suffix = "_InternalRetainedState" + if output_name.endswith(suffix): + return output_name[: -len(suffix)] + "_RetainedState" + return None + + +def _add_basename_binding_aliases(binding_index_map: Dict[str, int], bindings) -> None: + """Allow callers to use unprefixed I/O names for prefixed ONNX graphs.""" + for binding in bindings: + binding_index_map.setdefault(binding.name.rsplit("/", 1)[-1], binding.index) + + try: import qaicrt @@ -101,6 +116,7 @@ def __init__( ] self.bindings = iodesc.selected_set.bindings self.binding_index_map = {binding.name: binding.index for binding in self.bindings} + _add_basename_binding_aliases(self.binding_index_map, self.bindings) # Create and load Program prog_properties = qaicrt.QAicProgramProperties() prog_properties.dataPathTimeoutMs = 60_000 @@ -226,8 +242,15 @@ def run(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: buffer_index = self.binding_index_map[output_name] if self.qbuffers[buffer_index].size == 0: continue - outputs[output_name] = np.frombuffer( + output = np.frombuffer( bytes(output_qbuffers[buffer_index]), self.aic_to_np_dtype_mapping[self.bindings[buffer_index].type], ).reshape(self.buf_dims[buffer_index][1]) + outputs[output_name] = output + output_basename = output_name.rsplit("/", 1)[-1] + outputs.setdefault(output_basename, output) + public_name = _public_retained_state_name(output_name) + if public_name is not None: + outputs[public_name] = output + 
outputs.setdefault(public_name.rsplit("/", 1)[-1], output) return outputs diff --git a/QEfficient/transformers/models/_layerwise.py b/QEfficient/transformers/models/_layerwise.py index 02e53abae..50cc94fcc 100644 --- a/QEfficient/transformers/models/_layerwise.py +++ b/QEfficient/transformers/models/_layerwise.py @@ -386,19 +386,47 @@ def _is_cached_merged(onnx_path: Path) -> bool: return "final_data" in onnx_path.parts and onnx_path.name.startswith("merged_") -def _stitch_layerwise_if_available(export_root: Path) -> str: - # If a prior run already produced the merged ONNX, just return it. +def _cached_merged_onnx(export_root: Path, total_layers: Optional[int] = None) -> Optional[Path]: + """Return the complete cached merged ONNX, if it exists.""" cached_dir = export_root / "final_data" - if cached_dir.is_dir(): - cached_merged = sorted(cached_dir.glob("merged_*.onnx")) - if cached_merged: - return str(cached_merged[-1]) + if not cached_dir.is_dir(): + return None + if total_layers is not None: + expected = cached_dir / f"merged_0-{total_layers}.onnx" + if expected.is_file(): + return expected + cached_merged = sorted(cached_dir.glob("merged_0-*.onnx")) + return cached_merged[-1] if cached_merged else None + + +def _stitch_layerwise_if_available(export_root: Path, total_layers: Optional[int] = None) -> str: + # If a prior run already produced the complete merged ONNX, just return it. 
+ cached_merged = _cached_merged_onnx(export_root, total_layers) + if cached_merged is not None: + return str(cached_merged) pipeline_fn = getattr(QEfficient.utils, "layerwise_pipeline", None) if callable(pipeline_fn): return pipeline_fn(str(export_root)) return str(export_root / "onnx_layerwise_tmp") +def _cached_layerwise_onnx_path(qeff_model, compile_kwargs: Dict[str, Any]) -> Optional[Path]: + """Return a cached merged ONNX path without exporting or loading weights.""" + probe_kwargs = dict(compile_kwargs) + probe_kwargs["_layerwise_cache_probe"] = True + cached = qeff_model.compile(**probe_kwargs) + if isinstance(cached, dict): + cached = next( + ( + cached.get(key) + for key in ("lang_decode_qpc_path", "lang_prefill_qpc_path", "lang_qpc_path") + if cached.get(key) is not None + ), + None, + ) + return Path(cached) if cached is not None else None + + def _install_window_patches_for(model_type: str) -> None: """Install the HF __init__/shard patches needed for the given model_type. @@ -468,6 +496,7 @@ def run_layerwise( config, qeff_factory, compile_kwargs: Dict[str, Any], + probe_qeff_model=None, window_size: int = 1, final_compile: bool = True, ) -> Any: @@ -488,6 +517,9 @@ def run_layerwise( Forwarded verbatim to ``qeff_model.compile(...)`` per window. The driver injects ``skip_lang`` per-window and ``lang_onnx_path`` for the final stitched compile. + probe_qeff_model : QEffModel, optional + Existing wrapper used for cache probing before any per-window weights are + loaded. In normal layerwise use this is the outer meta wrapper. window_size : int Number of text-decoder layers per window. ``1`` matches the legacy example. 
@@ -516,7 +548,16 @@ def run_layerwise( last_qeff_model = None with _layerwise_export_env(): - for window_idx, (text_start, text_end) in enumerate(windows): + _set_layer_windows(0, min(window_size, text_total_layers), text_total_layers) + cached_probe = probe_qeff_model or qeff_factory(model_id, config) + cached_onnx_path = _cached_layerwise_onnx_path(cached_probe, compile_kwargs) + if cached_onnx_path is not None: + first_onnx_path = cached_onnx_path + last_qeff_model = cached_probe + + for text_start, text_end in windows: + if cached_onnx_path is not None: + break _set_layer_windows(text_start, text_end, text_total_layers) qeff_model = qeff_factory(model_id, config) @@ -555,7 +596,7 @@ def run_layerwise( raise RuntimeError("Layer-wise export produced no ONNX shards.") export_root = _resolve_export_root(first_onnx_path) - final_artifact = _stitch_layerwise_if_available(export_root) + final_artifact = _stitch_layerwise_if_available(export_root, text_total_layers) _reset_layer_windows() diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 066749b9d..7bca55f9d 100755 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -13,6 +13,7 @@ from typing import List, Optional, Union import numpy as np +import onnx import torch import torch.nn as nn from transformers import ( @@ -146,6 +147,62 @@ def _build_meta_model(hf_auto_class, pretrained_model_name_or_path, kwargs): return model +def _compile_io_name(name: str, *, use_onnx_subfunctions: bool) -> str: + """Return the compiler-visible name for retained-state ONNX outputs.""" + if not use_onnx_subfunctions or not name.endswith("_RetainedState"): + return name + if any(token in name for token in ("key", "value", "compressed_kv", "k_pe")): + return name[: -len("_RetainedState")] + "_InternalRetainedState" + return name + + +def _state_input_name(output_name: str) -> str: + """Map a retained-state output 
name to its matching state input name.""" + for suffix in ("_InternalRetainedState", "_RetainedState"): + if output_name.endswith(suffix): + return output_name[: -len(suffix)] + return output_name + + +def _filter_custom_io_for_onnx(custom_io: dict, onnx_path: Optional[Union[str, Path]]) -> dict: + """Keep custom-IO entries that exist in the ONNX graph. + + Layerwise stitched graphs may prefix I/O names (for example ``layer_0/``) + and may expose public retained-state names even when subfunction export used + internal names during per-window export. Matching by basename keeps compiler + custom-IO compatible with both graph shapes. + """ + if onnx_path is None: + return custom_io + try: + model = onnx.load(onnx_path, load_external_data=False) + except Exception: + return custom_io + + io_names = {value.name for value in list(model.graph.input) + list(model.graph.output)} + basename_to_name = {name.rsplit("/", 1)[-1]: name for name in io_names} + + def resolve_name(name: str) -> Optional[str]: + candidates = [name] + if name.endswith("_InternalRetainedState"): + candidates.append(name[: -len("_InternalRetainedState")] + "_RetainedState") + elif name.endswith("_RetainedState"): + candidates.append(name[: -len("_RetainedState")] + "_InternalRetainedState") + for candidate in candidates: + if candidate in io_names: + return candidate + if candidate in basename_to_name: + return basename_to_name[candidate] + return None + + filtered = {} + for name, dtype in custom_io.items(): + resolved_name = resolve_name(name) + if resolved_name is not None: + filtered[resolved_name] = dtype + return filtered + + class QEFFTransformersBase(QEFFBaseModel): """ Base class for QEfficient wrappers around HuggingFace transformer models. 
@@ -1187,6 +1244,7 @@ def export( export_dir=export_dir, offload_pt_weights=offload_pt_weights, use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), + _layerwise_cache_probe=kwargs.get("_layerwise_cache_probe", False), ) else: return self._export( @@ -1412,6 +1470,7 @@ def export( List[str] A list containing the paths to the generated ONNX graph files for both components. """ + layerwise_cache_probe = kwargs.pop("_layerwise_cache_probe", False) if layerwise: return self._run_layerwise_export( export_dir=export_dir, @@ -1471,7 +1530,7 @@ def export( == QEfficient.base.modeling_qeff.QEFFBaseModel._total_layers ) ) - if should_export: + if should_export and not layerwise_cache_probe: self.vision_model.export( inputs["vision"], output_names["vision"], @@ -1497,6 +1556,7 @@ def export( prefill_only=prefill_only, enable_chunking=enable_chunking, prefill_seq_len=prefill_seq_len, + _layerwise_cache_probe=layerwise_cache_probe, ) return self.onnx_path @@ -1592,6 +1652,7 @@ def _run_layerwise_export( config=self.config, qeff_factory=self._build_layerwise_factory(), compile_kwargs=compile_kwargs, + probe_qeff_model=self, window_size=layerwise_window_size, final_compile=False, ) @@ -1611,14 +1672,25 @@ def _run_layerwise_compile( "QEFFAutoModelForImageTextToText.from_pretrained(...). " "Direct __init__ does not preserve the model id needed for per-window reload." 
) - return _layerwise.run_layerwise( + qpc_paths = _layerwise.run_layerwise( model_id=model_id, config=self.config, qeff_factory=self._build_layerwise_factory(), compile_kwargs=compile_kwargs, + probe_qeff_model=self, window_size=layerwise_window_size, final_compile=True, ) + self.qpc_paths = qpc_paths + if isinstance(qpc_paths, dict): + self.vision_model.qpc_path = qpc_paths.get("vision_qpc_path") or self.vision_model.qpc_path + self.lang_model.qpc_path = ( + qpc_paths.get("lang_decode_qpc_path") + or qpc_paths.get("lang_prefill_qpc_path") + or qpc_paths.get("lang_qpc_path") + or self.lang_model.qpc_path + ) + return qpc_paths def compile( self, @@ -1740,6 +1812,7 @@ def compile( "KV caching requires continuous batching. Please set `full_batch_size` and " "enable `continuous_batching=True` in `from_pretrained`." ) + layerwise_cache_probe = compiler_options.pop("_layerwise_cache_probe", False) # Infer kv_cache_batch_size if not provided kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size @@ -1810,7 +1883,10 @@ def compile( prefill_only=prefill_only, enable_chunking=enable_chunking, prefill_seq_len=prefill_seq_len, + _layerwise_cache_probe=layerwise_cache_probe, ) + if layerwise_cache_probe: + return self.lang_model.onnx_path if hasattr(self.model, "generate_npi_file") and "node_precision_info" in compiler_options: if self.lang_model.onnx_path is None and not skip_lang: @@ -1845,54 +1921,24 @@ def compile( if not skip_lang: custom_io_lang = {} - # Inputs for output_name in output_names["lang"]: if output_name.endswith("_RetainedState"): - custom_io_lang[output_name[: -len("_RetainedState")]] = ( + compiler_output_name = _compile_io_name( + output_name, + use_onnx_subfunctions=use_onnx_subfunctions, + ) + custom_io_lang[_state_input_name(compiler_output_name)] = ( CUSTOM_IO_DTYPE_MAP[target_dtype] if ("vision_embeds" in output_name or "deepstack_features" in output_name) else kv_cache_dtype ) - - # outputs - for output_name in 
output_names["lang"]: - if output_name.endswith("_RetainedState"): - custom_io_lang[output_name] = ( + custom_io_lang[compiler_output_name] = ( CUSTOM_IO_DTYPE_MAP[target_dtype] if ("vision_embeds" in output_name or "deepstack_features" in output_name) else kv_cache_dtype ) - def filter_custom_io_lang(custom_io_lang, onnx_path): - # Extract filename - filename = os.path.basename(onnx_path) - - # Extract range from "merged_0-2.onnx" - match = re.search(r"merged_(\d+)-(\d+)\.onnx", filename) - if not match: - return custom_io_lang # no filtering if pattern not found - - start, end = map(int, match.groups()) # e.g. 0, 2 - - filtered = {} - - for k, v in custom_io_lang.items(): - # Keep everything that is NOT KV cache - if ("past_key." not in k) and ("past_value." not in k): - filtered[k] = v - continue - - # Extract layer index - layer_match = re.search(r"past_(?:key|value)\.(\d+)", k) - if layer_match: - idx = int(layer_match.group(1)) - if start <= idx < end: - filtered[k] = v - - return filtered - - if self.lang_model.onnx_path is not None and "merged" in str(self.lang_model.onnx_path): - custom_io_lang = filter_custom_io_lang(custom_io_lang, self.lang_model.onnx_path) + custom_io_lang = _filter_custom_io_for_onnx(custom_io_lang, self.lang_model.onnx_path) if prefill_only: specializations = specializations["lang"][:1] @@ -2599,18 +2645,14 @@ def compile( custom_io = {} target_dtype = getattr(self.model.config, "torch_dtype", torch.float32) kv_cache_dtype = "mxint8" if mxint8_kv_cache else CUSTOM_IO_DTYPE_MAP[target_dtype] - # inputs for input_name in output_names: if input_name.endswith("_RetainedState"): - custom_io[input_name[: -len("_RetainedState")]] = ( + compiler_output_name = _compile_io_name(input_name, use_onnx_subfunctions=use_onnx_subfunctions) + custom_io[_state_input_name(compiler_output_name)] = ( CUSTOM_IO_DTYPE_MAP[target_dtype] if "pixel_values" in input_name else kv_cache_dtype ) - - # outputs - for output_name in output_names: - if 
output_name.endswith("_RetainedState"): - custom_io[output_name] = ( - CUSTOM_IO_DTYPE_MAP[target_dtype] if "pixel_values" in output_name else kv_cache_dtype + custom_io[compiler_output_name] = ( + CUSTOM_IO_DTYPE_MAP[target_dtype] if "pixel_values" in input_name else kv_cache_dtype ) # TODO this hould be removed once the continous batching is supported for all the models. @@ -3467,6 +3509,7 @@ def _factory(model_id, config): config=config, qeff_factory=_factory, compile_kwargs=forward_kwargs, + probe_qeff_model=self, window_size=layerwise_window_size, final_compile=final_compile, ) @@ -4240,43 +4283,22 @@ def compile( for suffix in ["", "_RetainedState"]: for i in range(self.num_layers): for kv in ["key", "value"]: - custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype + name = _compile_io_name( + f"past_{kv}.{i}{suffix}", + use_onnx_subfunctions=use_onnx_subfunctions, + ) + custom_io[name] = kv_cache_dtype else: for suffix in ["", "_RetainedState"]: for i in range(self.num_layers): - custom_io[f"compressed_kv.{i}{suffix}"] = kv_cache_dtype - custom_io[f"k_pe.{i}{suffix}"] = kv_cache_dtype - - def filter_custom_io(custom_io_lang, onnx_path): - # Extract filename - filename = os.path.basename(onnx_path) - - # Extract range from "merged_0-2.onnx" - match = re.search(r"merged_(\d+)-(\d+)\.onnx", filename) - if not match: - return custom_io_lang # no filtering if pattern not found - - start, end = map(int, match.groups()) # e.g. 0, 2 - - filtered = {} - - for k, v in custom_io_lang.items(): - # Keep everything that is NOT KV cache - if ("past_key." not in k) and ("past_value." 
not in k): - filtered[k] = v - continue - - # Extract layer index - layer_match = re.search(r"past_(?:key|value)\.(\d+)", k) - if layer_match: - idx = int(layer_match.group(1)) - if start <= idx < end: - filtered[k] = v - - return filtered + for prefix in ("compressed_kv", "k_pe"): + name = _compile_io_name( + f"{prefix}.{i}{suffix}", + use_onnx_subfunctions=use_onnx_subfunctions, + ) + custom_io[name] = kv_cache_dtype - if onnx_path is not None and "merged" in str(onnx_path): - custom_io = filter_custom_io(custom_io, onnx_path) + custom_io = _filter_custom_io_for_onnx(custom_io, onnx_path) qpc_path = self._compile( onnx_path=onnx_path, @@ -4633,15 +4655,11 @@ def compile( custom_io["input_features"] = kv_cache_dtype - # Slice output_names to get input names - for output_name in output_names: - if output_name.endswith("_RetainedState"): - custom_io[output_name[: -len("_RetainedState")]] = kv_cache_dtype - - # Get output names for output_name in output_names: if output_name.endswith("_RetainedState"): - custom_io[output_name] = kv_cache_dtype + compiler_output_name = _compile_io_name(output_name, use_onnx_subfunctions=use_onnx_subfunctions) + custom_io[_state_input_name(compiler_output_name)] = kv_cache_dtype + custom_io[compiler_output_name] = kv_cache_dtype return self._compile( onnx_path=onnx_path, diff --git a/QEfficient/utils/export_utils.py b/QEfficient/utils/export_utils.py index 901484e72..e1c7dc794 100755 --- a/QEfficient/utils/export_utils.py +++ b/QEfficient/utils/export_utils.py @@ -40,6 +40,7 @@ def export_wrapper(func): """ def wrapper(self, *args, **kwargs): + cache_probe = kwargs.pop("_layerwise_cache_probe", False) # 1. 
Setup ONNX subfunctions if requested if use_onnx_subfunctions := kwargs.pop("use_onnx_subfunctions", False): args, kwargs = _setup_onnx_subfunctions(self, args, kwargs) @@ -52,12 +53,15 @@ def wrapper(self, *args, **kwargs): export_dir = export_dir.with_name(export_dir.name + "-" + export_hash) kwargs["export_dir"] = export_dir self.export_hash = export_hash + if cache_probe: + kwargs["_layerwise_cache_probe"] = True # 4. Execute the actual export onnx_path = func(self, *args, **kwargs) # 5. Save export metadata - _save_export_metadata(export_dir, filtered_hash_params) + if not cache_probe: + _save_export_metadata(export_dir, filtered_hash_params) # 6. Always cleanup subfunctions if they were setup if use_onnx_subfunctions: diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py index 585a532fa..839852a78 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py @@ -18,7 +18,8 @@ from QEfficient import QEFFAutoModelForImageTextToText from QEfficient.generation.cloud_infer import QAICInferenceSession -model_id = "Qwen/Qwen3-VL-30B-A3B-Instruct" +# model_id = "Qwen/Qwen3-VL-30B-A3B-Instruct" +model_id = "tiny-random/qwen3-vl-moe" config = AutoConfig.from_pretrained(model_id) # For faster execution user can run with lesser layers, For Testing Purpose Only diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise.py index 0912738b6..fe120ed9d 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise.py @@ -26,15 +26,15 @@ def main(): config = AutoConfig.from_pretrained(MODEL_ID) - config.torch_dtype = "float32" - config.vision_config.deepstack_visual_indexes = 
[8, 27, 36] + config.torch_dtype = "float16" + # config.vision_config.deepstack_visual_indexes = [8, 27, 36] qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( MODEL_ID, attn_implementation="eager", kv_offload=True, config=config, - torch_dtype=torch.float32, + torch_dtype=torch.float16, layerwise=True, ) diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py index d87ded2c0..6caeeca08 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py @@ -5,30 +5,35 @@ # # ----------------------------------------------------------------------------- -""" -Layerwise compile example for Qwen3-VL-MoE. +"""Layerwise prefill compile example for Qwen3-VL-MoE. The orchestration loop that previously lived in this script has been moved behind the ``layerwise=True`` flag on ``.compile()`` / ``.export()``. Note: ``layerwise=True`` is a provisional API and is scheduled for deprecation -once first-class multi-window export lands. It is currently only supported for -``qwen3_vl_moe``, ``qwen3_5_moe`` and ``qwen3_moe``. +once first-class multi-window export lands. Supported model types: +``qwen3_vl_moe``, ``qwen3_5_moe``, ``qwen3_moe``. 
""" +import requests import torch -from transformers import AutoConfig +import transformers +from PIL import Image +from qwen_vl_utils import process_vision_info +from transformers import AutoConfig, AutoProcessor, TextStreamer from QEfficient import QEFFAutoModelForImageTextToText # MODEL_ID = "Qwen/Qwen3-VL-235B-A22B-Instruct" -# MODEL_ID = "tiny-random/qwen3-vl-moe" MODEL_ID = "Qwen/Qwen3-VL-30B-A3B-Instruct" +# MODEL_ID = "tiny-random/qwen3-vl-moe" def main(): config = AutoConfig.from_pretrained(MODEL_ID) config.torch_dtype = "float16" + tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_ID) + processor = AutoProcessor.from_pretrained(MODEL_ID) qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( MODEL_ID, @@ -36,9 +41,9 @@ def main(): kv_offload=True, config=config, torch_dtype=torch.float16, - layerwise=True, + # layerwise=True, ) - + batch_size = 1 qpc_path = qeff_model.compile( batch_size=1, prefill_seq_len=1, @@ -54,11 +59,36 @@ def main(): split_retained_state_io=True, use_onnx_subfunctions=True, mos=1, - layerwise=True, - layerwise_window_size=1, + # layerwise=True, + # layerwise_window_size=1, ) print(f"Final QPC path: {qpc_path}") + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Tell me about yourself."}, + ], + }, + ] + + messages = [messages] * batch_size + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=100) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + if __name__ == "__main__": main() diff --git a/tests/unit_test/models/test_model_quickcheck.py b/tests/unit_test/models/test_model_quickcheck.py index e8eb1f09c..7ab9b78df 100644 --- 
a/tests/unit_test/models/test_model_quickcheck.py +++ b/tests/unit_test/models/test_model_quickcheck.py @@ -1381,6 +1381,159 @@ def test_layerwise_context_manager_toggles_class_flag(): assert QEFFBaseModel._layerwise_active is False +@pytest.mark.llm_model +def test_layerwise_uses_probe_model_for_cached_export(monkeypatch, tmp_path): + """A cached merged ONNX must avoid rebuilding per-window models.""" + from QEfficient.transformers.models import _layerwise + + class DummyConfig: + model_type = "qwen3_moe" + num_hidden_layers = 2 + + class ProbeModel: + def __init__(self, cached_path): + self.cached_path = cached_path + + def compile(self, **kwargs): + assert kwargs.pop("_layerwise_cache_probe") is True + return self.cached_path + + cached_path = tmp_path / "Model-hash" / "final_data" / "merged_0-2.onnx" + cached_path.parent.mkdir(parents=True) + cached_path.touch() + factory_called = False + + def fail_factory(*args, **kwargs): + nonlocal factory_called + factory_called = True + raise AssertionError("factory must not run when merged ONNX is cached") + + monkeypatch.setattr(_layerwise, "_install_window_patches_for", lambda model_type: None) + result = _layerwise.run_layerwise( + model_id="dummy", + config=DummyConfig(), + qeff_factory=fail_factory, + compile_kwargs={}, + probe_qeff_model=ProbeModel(cached_path), + final_compile=False, + ) + + assert result == str(cached_path) + assert factory_called is False + + +@pytest.mark.llm_model +def test_layerwise_cache_miss_exports_all_windows(monkeypatch, tmp_path): + from QEfficient.transformers.models import _layerwise + + class DummyConfig: + model_type = "qwen3_moe" + num_hidden_layers = 3 + + class ProbeModel: + def compile(self, **kwargs): + assert kwargs.pop("_layerwise_cache_probe") is True + return None + + exported_windows = [] + + class WindowModel: + def __init__(self): + self.model = object() + + def compile(self, **kwargs): + start = _layerwise._LAYERWISE_STATE["text_start"] + end = 
_layerwise._LAYERWISE_STATE["text_end"] + exported_windows.append((start, end)) + shard = tmp_path / "onnx_layerwise_tmp" / f"layer_{start}_{end}" / f"model_layer_tmp_{start}_{end}.onnx" + shard.parent.mkdir(parents=True, exist_ok=True) + shard.touch() + return str(shard) + + monkeypatch.setattr(_layerwise, "_install_window_patches_for", lambda model_type: None) + monkeypatch.setattr(_layerwise, "_null_outside_window_layers", lambda *args, **kwargs: None) + monkeypatch.setattr(_layerwise, "_slim_for_window_export", lambda *args, **kwargs: None) + monkeypatch.setattr( + _layerwise, + "_stitch_layerwise_if_available", + lambda export_root, total_layers=None: str(export_root / "merged.onnx"), + ) + + result = _layerwise.run_layerwise( + model_id="dummy", + config=DummyConfig(), + qeff_factory=lambda *args, **kwargs: WindowModel(), + compile_kwargs={}, + probe_qeff_model=ProbeModel(), + window_size=1, + final_compile=False, + ) + + assert exported_windows == [(0, 1), (1, 2), (2, 3)] + assert result.endswith("merged.onnx") + + +@pytest.mark.llm_model +def test_layerwise_cached_merged_prefers_complete_graph(tmp_path): + from QEfficient.transformers.models import _layerwise + + final_data = tmp_path / "final_data" + final_data.mkdir() + partial = final_data / "merged_9-48.onnx" + complete = final_data / "merged_0-48.onnx" + partial.touch() + complete.touch() + + assert _layerwise._cached_merged_onnx(tmp_path, total_layers=48) == complete + + +@pytest.mark.llm_model +def test_subfunction_compile_io_names_use_internal_retained_state(): + from QEfficient.transformers.models.modeling_auto import _compile_io_name, _state_input_name + + output_name = _compile_io_name("past_value.1_RetainedState", use_onnx_subfunctions=True) + + assert output_name == "past_value.1_InternalRetainedState" + assert _state_input_name(output_name) == "past_value.1" + assert _compile_io_name("vision_embeds_RetainedState", use_onnx_subfunctions=True) == "vision_embeds_RetainedState" + + 
+@pytest.mark.llm_model +def test_runtime_aliases_internal_retained_state_outputs(): + from QEfficient.generation.cloud_infer import _add_basename_binding_aliases, _public_retained_state_name + + assert _public_retained_state_name("past_key.0_InternalRetainedState") == "past_key.0_RetainedState" + assert _public_retained_state_name("past_value.1_InternalRetainedState") == "past_value.1_RetainedState" + assert _public_retained_state_name("logits") is None + + binding_map = {"layer_0/input_ids": 3} + bindings = [type("Binding", (), {"name": "layer_0/input_ids", "index": 3})()] + _add_basename_binding_aliases(binding_map, bindings) + assert binding_map["input_ids"] == 3 + + +@pytest.mark.llm_model +def test_layerwise_compile_hydrates_outer_qpc_paths(monkeypatch, tmp_path): + from QEfficient.transformers.models import _layerwise + from QEfficient.transformers.models.modeling_auto import _QEffAutoModelForImageTextToTextDualQPC + + qpc_path = tmp_path / "qpc" + model = object.__new__(_QEffAutoModelForImageTextToTextDualQPC) + model._pretrained_model_name_or_path = "dummy" + model.config = object() + model.vision_model = type("Vision", (), {"qpc_path": None})() + model.lang_model = type("Lang", (), {"qpc_path": None})() + model._build_layerwise_factory = lambda: None + + monkeypatch.setattr(_layerwise, "run_layerwise", lambda **kwargs: {"lang_decode_qpc_path": qpc_path}) + + result = model._run_layerwise_compile(layerwise_window_size=1) + + assert result == {"lang_decode_qpc_path": qpc_path} + assert model.qpc_paths == result + assert model.lang_model.qpc_path == qpc_path + + @pytest.mark.llm_model def test_layerwise_compile_rejects_unsupported_model(): """End-to-end smoke: invoking layerwise=True on llama bubbles the guard error.""" From a3d928551d3aa39ee83658dd5f79419ca45c8ba4 Mon Sep 17 00:00:00 2001 From: vbaddi Date: Sat, 6 Jun 2026 20:27:39 +0530 Subject: [PATCH 5/8] nit: Fix layerwise cache isolation and QPC reuse Signed-off-by: vbaddi --- 
QEfficient/transformers/models/_layerwise.py | 91 ++++++++++--------- QEfficient/utils/export_utils.py | 4 + .../qwen3_vl_moe/qwen3_vl_disagg_mode.py | 8 +- .../qwen3_vl_moe_layerwise_decode.py | 11 +-- .../unit_test/models/test_model_quickcheck.py | 36 ++++++++ 5 files changed, 96 insertions(+), 54 deletions(-) diff --git a/QEfficient/transformers/models/_layerwise.py b/QEfficient/transformers/models/_layerwise.py index 50cc94fcc..22f741616 100644 --- a/QEfficient/transformers/models/_layerwise.py +++ b/QEfficient/transformers/models/_layerwise.py @@ -547,50 +547,53 @@ def run_layerwise( first_onnx_path: Optional[Path] = None last_qeff_model = None - with _layerwise_export_env(): - _set_layer_windows(0, min(window_size, text_total_layers), text_total_layers) - cached_probe = probe_qeff_model or qeff_factory(model_id, config) - cached_onnx_path = _cached_layerwise_onnx_path(cached_probe, compile_kwargs) - if cached_onnx_path is not None: - first_onnx_path = cached_onnx_path - last_qeff_model = cached_probe - - for text_start, text_end in windows: + try: + with _layerwise_export_env(): + _set_layer_windows(0, min(window_size, text_total_layers), text_total_layers) + cached_probe = probe_qeff_model or qeff_factory(model_id, config) + cached_onnx_path = _cached_layerwise_onnx_path(cached_probe, compile_kwargs) if cached_onnx_path is not None: - break - _set_layer_windows(text_start, text_end, text_total_layers) - - qeff_model = qeff_factory(model_id, config) - last_qeff_model = qeff_model - if hasattr(qeff_model, "model"): - _null_outside_window_layers(qeff_model.model, apply_text=True) - _slim_for_window_export(qeff_model, ctx_len=compile_kwargs.get("ctx_len")) - - window_kwargs = dict(compile_kwargs) - # skip_lang is a VLM-only kwarg; only inject when present in caller's kwargs. 
- if "skip_lang" in window_kwargs: - window_kwargs["skip_lang"] = False - onnx_path = qeff_model.compile(**window_kwargs) - if first_onnx_path is None: - if isinstance(onnx_path, dict): - lang_key = next( - ( - k - for k in ( - "lang_decode_qpc_path", - "lang_prefill_qpc_path", - "lang_qpc_path", - ) - if k in onnx_path - ), - None, - ) - if lang_key is None: - raise RuntimeError(f"Layer-wise window produced no lang_*_qpc_path: keys={list(onnx_path)}") - lang_path = onnx_path[lang_key] - else: - lang_path = onnx_path - first_onnx_path = Path(str(lang_path)) + first_onnx_path = cached_onnx_path + last_qeff_model = cached_probe + + for text_start, text_end in windows: + if cached_onnx_path is not None: + break + _set_layer_windows(text_start, text_end, text_total_layers) + + qeff_model = qeff_factory(model_id, config) + last_qeff_model = qeff_model + if hasattr(qeff_model, "model"): + _null_outside_window_layers(qeff_model.model, apply_text=True) + _slim_for_window_export(qeff_model, ctx_len=compile_kwargs.get("ctx_len")) + + window_kwargs = dict(compile_kwargs) + # skip_lang is a VLM-only kwarg; only inject when present in caller's kwargs. 
+ if "skip_lang" in window_kwargs: + window_kwargs["skip_lang"] = False + onnx_path = qeff_model.compile(**window_kwargs) + if first_onnx_path is None: + if isinstance(onnx_path, dict): + lang_key = next( + ( + k + for k in ( + "lang_decode_qpc_path", + "lang_prefill_qpc_path", + "lang_qpc_path", + ) + if k in onnx_path + ), + None, + ) + if lang_key is None: + raise RuntimeError(f"Layer-wise window produced no lang_*_qpc_path: keys={list(onnx_path)}") + lang_path = onnx_path[lang_key] + else: + lang_path = onnx_path + first_onnx_path = Path(str(lang_path)) + finally: + _reset_layer_windows() if first_onnx_path is None: raise RuntimeError("Layer-wise export produced no ONNX shards.") @@ -598,8 +601,6 @@ def run_layerwise( export_root = _resolve_export_root(first_onnx_path) final_artifact = _stitch_layerwise_if_available(export_root, text_total_layers) - _reset_layer_windows() - if not final_compile: return final_artifact diff --git a/QEfficient/utils/export_utils.py b/QEfficient/utils/export_utils.py index e1c7dc794..6ed7f97cc 100755 --- a/QEfficient/utils/export_utils.py +++ b/QEfficient/utils/export_utils.py @@ -112,6 +112,10 @@ def _generate_export_hash(qeff_model, args, kwargs, func): bound_args = new_sig.bind(*args, **kwargs) bound_args.apply_defaults() all_args = bound_args.arguments + if func.__name__ == "_export_layerwise": + export_kwargs = dict(all_args.get("export_kwargs") or {}) + export_kwargs["_qeff_layerwise_export"] = True + all_args["export_kwargs"] = export_kwargs # Use the model's current configuration for hashing to ensure any post-load modifications are captured # TODO: Replace with get_model_config property of modeling classes and remove the if-else diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py index 839852a78..26a825791 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py +++ 
b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py @@ -18,8 +18,8 @@ from QEfficient import QEFFAutoModelForImageTextToText from QEfficient.generation.cloud_infer import QAICInferenceSession -# model_id = "Qwen/Qwen3-VL-30B-A3B-Instruct" -model_id = "tiny-random/qwen3-vl-moe" +model_id = "Qwen/Qwen3-VL-30B-A3B-Instruct" +# model_id = "tiny-random/qwen3-vl-moe" config = AutoConfig.from_pretrained(model_id) # For faster execution user can run with lesser layers, For Testing Purpose Only @@ -74,6 +74,8 @@ enable_chunking=True, skip_vision=True, use_onnx_subfunctions=True, + layerwise=True, + layerwise_window_size=1, ) @@ -93,6 +95,8 @@ prefill_only=False, skip_vision=True, use_onnx_subfunctions=True, + layerwise=True, + layerwise_window_size=1, ) lang_prefill_session = QAICInferenceSession(prefill_qpc_path.get("lang_prefill_qpc_path")) diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py index 6caeeca08..9f5ea23d6 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py @@ -15,11 +15,8 @@ ``qwen3_vl_moe``, ``qwen3_5_moe``, ``qwen3_moe``. 
""" -import requests import torch import transformers -from PIL import Image -from qwen_vl_utils import process_vision_info from transformers import AutoConfig, AutoProcessor, TextStreamer from QEfficient import QEFFAutoModelForImageTextToText @@ -41,7 +38,7 @@ def main(): kv_offload=True, config=config, torch_dtype=torch.float16, - # layerwise=True, + layerwise=True, ) batch_size = 1 qpc_path = qeff_model.compile( @@ -59,8 +56,8 @@ def main(): split_retained_state_io=True, use_onnx_subfunctions=True, mos=1, - # layerwise=True, - # layerwise_window_size=1, + layerwise=True, + layerwise_window_size=1, ) print(f"Final QPC path: {qpc_path}") @@ -83,7 +80,7 @@ def main(): return_tensors="pt", ) inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) - streamer = TextStreamer(tokenizer) + # streamer = TextStreamer(tokenizer) output = qeff_model.generate(inputs=inputs, generation_len=100) print(output.generated_ids) print(tokenizer.batch_decode(output.generated_ids)) diff --git a/tests/unit_test/models/test_model_quickcheck.py b/tests/unit_test/models/test_model_quickcheck.py index 7ab9b78df..03ac6ce27 100644 --- a/tests/unit_test/models/test_model_quickcheck.py +++ b/tests/unit_test/models/test_model_quickcheck.py @@ -1487,6 +1487,42 @@ def test_layerwise_cached_merged_prefers_complete_graph(tmp_path): assert _layerwise._cached_merged_onnx(tmp_path, total_layers=48) == complete +@pytest.mark.llm_model +def test_layerwise_export_hash_is_separate_from_default(monkeypatch): + from QEfficient.utils import export_utils + + class DummyConfig: + def to_diff_dict(self): + return {"model_type": "dummy"} + + class DummyModel: + config = DummyConfig() + + class DummyQEffModel: + model = DummyModel() + hash_params = {} + + def normal_export(self, example_inputs, output_names, dynamic_axes, **export_kwargs): + pass + + def _export_layerwise(self, example_inputs, output_names, dynamic_axes, **export_kwargs): + pass + + 
captured_export_kwargs = [] + + def fake_create_export_hash(**kwargs): + captured_export_kwargs.append(kwargs.get("export_kwargs")) + return "hash", {} + + monkeypatch.setattr(export_utils, "create_export_hash", fake_create_export_hash) + + export_utils._generate_export_hash(DummyQEffModel(), ({}, ["logits"], {}), {}, normal_export) + export_utils._generate_export_hash(DummyQEffModel(), ({}, ["logits"], {}), {}, _export_layerwise) + + assert captured_export_kwargs[0] in (None, {}) + assert captured_export_kwargs[1]["_qeff_layerwise_export"] is True + + @pytest.mark.llm_model def test_subfunction_compile_io_names_use_internal_retained_state(): from QEfficient.transformers.models.modeling_auto import _compile_io_name, _state_input_name From c6e374c0789e01fff4187819baa1c3c4febc3d04 Mon Sep 17 00:00:00 2001 From: vbaddi Date: Sat, 6 Jun 2026 20:29:36 +0530 Subject: [PATCH 6/8] nit: fix some imports Signed-off-by: vbaddi --- QEfficient/transformers/models/modeling_auto.py | 1 - .../models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 7bca55f9d..3e4670a52 100755 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -6,7 +6,6 @@ # ---------------------------------------------------------------------------- import os -import re import warnings from pathlib import Path from time import perf_counter diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py index 9f5ea23d6..fef43de62 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py @@ -17,7 +17,7 @@ import torch import transformers -from transformers 
import AutoConfig, AutoProcessor, TextStreamer +from transformers import AutoConfig, AutoProcessor from QEfficient import QEFFAutoModelForImageTextToText @@ -80,7 +80,6 @@ def main(): return_tensors="pt", ) inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) - # streamer = TextStreamer(tokenizer) output = qeff_model.generate(inputs=inputs, generation_len=100) print(output.generated_ids) print(tokenizer.batch_decode(output.generated_ids)) From 81407060b4223c7d331aac3457ab205ce03f5e7d Mon Sep 17 00:00:00 2001 From: vbaddi Date: Sat, 6 Jun 2026 22:54:48 +0530 Subject: [PATCH 7/8] nit: qwen3.5 fixes and logits buffer dtype mismatch, WIP(revert to previous commit if breaks) Signed-off-by: vbaddi --- QEfficient/base/modeling_qeff.py | 25 ++++ QEfficient/generation/vlm_generation.py | 9 +- QEfficient/transformers/models/_layerwise.py | 21 ++-- .../transformers/models/modeling_auto.py | 119 +++++++++++++++++- .../qwen3_5_moe/modeling_qwen3_5_moe.py | 12 +- QEfficient/utils/layerwise_pipeline.py | 27 ++-- .../qwen3_5_moe_layerwise_decode.py | 40 +++++- .../qwen3_vl_moe/qwen3_vl_disagg_mode.py | 8 +- .../qwen3_vl_moe_layerwise_decode.py | 8 +- .../unit_test/models/test_model_quickcheck.py | 30 +++++ 10 files changed, 257 insertions(+), 42 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index a7c63ad0e..49e8331a6 100755 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -850,6 +850,31 @@ def _compile( continue command.append(f"{option}={value}") + # Final custom-IO normalization against ONNX I/O names. + # This only rewrites retained-state aliases: + # *_InternalRetainedState <-> *_RetainedState. + # Any other custom-IO key is preserved as-is for backward compatibility. 
+ if custom_io is not None and onnx_path is not None: + try: + model = onnx.load(onnx_path, load_external_data=False) + io_names = {value.name for value in list(model.graph.input) + list(model.graph.output)} + normalized_custom_io = {} + for io_name, dtype in custom_io.items(): + resolved_name = io_name + if io_name not in io_names: + if io_name.endswith("_InternalRetainedState"): + candidate = io_name[: -len("_InternalRetainedState")] + "_RetainedState" + if candidate in io_names: + resolved_name = candidate + elif io_name.endswith("_RetainedState"): + candidate = io_name[: -len("_RetainedState")] + "_InternalRetainedState" + if candidate in io_names: + resolved_name = candidate + normalized_custom_io[resolved_name] = dtype + custom_io = normalized_custom_io + except Exception: + pass + if use_onnx_subfunctions: logger.info("Using ONNX subfunctions for compilation.") command.append("-sub-functions") diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py index 2af89a861..138cc2da6 100755 --- a/QEfficient/generation/vlm_generation.py +++ b/QEfficient/generation/vlm_generation.py @@ -279,6 +279,8 @@ def update_decode_inputs_qwen_vl(self, outputs, position_ids, generation_len, de next_token_id (array): The next token ID. """ next_token_id = self._fetch_next_token_id(outputs) + if next_token_id.ndim == 2 and next_token_id.shape[1] > 1: + next_token_id = next_token_id[:, -1:] # Store the generated values. self.decode_input_ids[decode_batch_id or slice(None)] = next_token_id @@ -306,9 +308,6 @@ def _execute_chunked_prefill( Returns: Final prefill outputs """ - # Set output buffers - self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1) - # Skip buffers for dual-QPC coordination self._session.skip_buffers(self._lang_skip_buffers) @@ -336,6 +335,10 @@ def _execute_chunked_prefill( ..., i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len ] + # Prefill specializations can return one logit per prompt token. 
+ # Size the output buffer to the active chunk instead of the decode shape. + self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=input_ids_slice.shape[1]) + chunk_inputs = { "input_ids": input_ids_slice, "position_ids": position_ids_slice, diff --git a/QEfficient/transformers/models/_layerwise.py b/QEfficient/transformers/models/_layerwise.py index 22f741616..a603b8b37 100644 --- a/QEfficient/transformers/models/_layerwise.py +++ b/QEfficient/transformers/models/_layerwise.py @@ -350,11 +350,12 @@ def _set_layer_windows(text_start: int, text_end: int, text_total_layers: int) - qeff_35_mod = getattr(QEfficient.transformers.models, "qwen3_5_moe", None) if qeff_35_mod is not None: - cls = getattr(qeff_35_mod.modeling_qwen3_5_moe, "QEffQwen3_5MoeTextModel", None) - if cls is not None: - cls._start = text_start - cls._end = text_end - cls._total_layers = text_total_layers + for class_name in ("QEffQwen3_5MoeTextModel", "QEffQwen3_5MoeModel"): + cls = getattr(qeff_35_mod.modeling_qwen3_5_moe, class_name, None) + if cls is not None: + cls._start = text_start + cls._end = text_end + cls._total_layers = text_total_layers qeff_3_mod = getattr(QEfficient.transformers.models, "qwen3_moe", None) if qeff_3_mod is not None: @@ -445,14 +446,18 @@ def _install_window_patches_for(model_type: str) -> None: for name in ("Qwen3VLMoeForConditionalGeneration", "Qwen3VLMoeForCausalLM") if (cls := getattr(qwen3_vl_moe_mod, name, None)) is not None ) + qwen3_5_moe_mod = getattr(getattr(transformers.models, "qwen3_5_moe", None), "modeling_qwen3_5_moe", None) + if qwen3_5_moe_mod is not None and model_type in {"qwen3_5_moe", "qwen3_5_moe_text"}: + candidates.extend( + cls + for name in ("Qwen3_5MoeForConditionalGeneration", "Qwen3_5MoeForCausalLM") + if (cls := getattr(qwen3_5_moe_mod, name, None)) is not None + ) qwen3_moe_mod = getattr(getattr(transformers.models, "qwen3_moe", None), "modeling_qwen3_moe", None) if qwen3_moe_mod is not None and model_type in {"qwen3_moe"}: 
candidates.extend( cls for name in ("Qwen3MoeForCausalLM",) if (cls := getattr(qwen3_moe_mod, name, None)) is not None ) - # qwen3_5_moe shares the qwen3_vl_moe HF classes today; the QEff modeling - # file overrides behavior. Install the shard patch (above) and rely on - # _null_outside_window_layers running post-init in the driver loop. for cls in candidates: _install_window_patch(cls) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 3e4670a52..0793023e2 100755 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -123,6 +123,39 @@ def _resolve_torch_dtype(kwargs: dict) -> None: kwargs["torch_dtype"] = torch.float32 +def _build_layerwise_vision_export_model(hf_auto_class, pretrained_model_name_or_path, kwargs): + """Load a VLM with vision weights and only the first language window. + + This opt-in path keeps vision export usable while avoiding materializing + every decoder layer up front. Language ONNX/QPC export still goes through + the regular layerwise driver, which reloads each window independently. 
+ """ + from QEfficient.transformers.models import _layerwise + + config = kwargs.get("config", None) + if config is None: + from transformers import AutoConfig + + config_kwargs = { + key: kwargs[key] + for key in ("trust_remote_code", "revision", "token", "subfolder", "cache_dir") + if key in kwargs + } + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **config_kwargs) + kwargs["config"] = config + model_type = _layerwise.assert_layerwise_supported(config) + total_layers = _layerwise._resolve_text_total_layers(config) + _layerwise._ensure_pretrained_window_attrs() + _layerwise._install_window_patches_for(model_type) + + with _layerwise._layerwise_export_env(): + _layerwise._set_layer_windows(0, min(1, total_layers), total_layers) + try: + return hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) + finally: + _layerwise._reset_layer_windows() + + def _build_meta_model(hf_auto_class, pretrained_model_name_or_path, kwargs): """Construct an HF model on the meta device for layer-wise mode. 
@@ -1615,6 +1648,27 @@ def _factory(model_id, config): return _factory + def _build_layerwise_vision_wrapper(self): + """Materialize vision weights while keeping language bounded to one window.""" + model_id = self._pretrained_model_name_or_path + if model_id is None: + raise RuntimeError("layerwise=True requires a model loaded via from_pretrained(...).") + torch_dtype = getattr(self.config, "torch_dtype", None) + kwargs = { + "config": self.config, + "attn_implementation": "eager", + "torch_dtype": torch_dtype, + "low_cpu_mem_usage": True, + } + _resolve_torch_dtype(kwargs) + self.config.torch_dtype = kwargs["torch_dtype"] + model = _build_layerwise_vision_export_model(self._hf_auto_class, model_id, kwargs) + return self.__class__( + model, + continuous_batching=self.continuous_batching, + pretrained_model_name_or_path=model_id, + ) + def _run_layerwise_export( self, *, @@ -1777,6 +1831,36 @@ def compile( raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False") if layerwise: + if skip_lang and not skip_vision: + vision_wrapper = self._build_layerwise_vision_wrapper() + qpc_paths = vision_wrapper.compile( + img_size=img_size, + vision_onnx_path=vision_onnx_path, + lang_onnx_path=lang_onnx_path, + compile_dir=compile_dir, + prefill_seq_len=prefill_seq_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + batch_size=batch_size, + full_batch_size=full_batch_size, + kv_cache_batch_size=kv_cache_batch_size, + num_devices=num_devices, + num_cores=num_cores, + mxfp6_matmul=mxfp6_matmul, + mxint8_kv_cache=mxint8_kv_cache, + skip_vision=skip_vision, + skip_lang=skip_lang, + use_onnx_subfunctions=use_onnx_subfunctions, + prefill_only=prefill_only, + enable_chunking=enable_chunking, + qaic_config=qaic_config, + **compiler_options, + ) + self.vision_model.onnx_path = vision_wrapper.vision_model.onnx_path + self.vision_model.qpc_path = vision_wrapper.vision_model.qpc_path + 
self.qpc_paths = qpc_paths + return qpc_paths return self._run_layerwise_compile( img_size=img_size, vision_onnx_path=vision_onnx_path, @@ -1874,7 +1958,10 @@ def compile( if lang_onnx_path: self.lang_model.onnx_path = lang_onnx_path - if vision_onnx_path is None or lang_onnx_path is None: + needs_vision_export = not skip_vision and vision_onnx_path is None + needs_lang_export = not skip_lang and lang_onnx_path is None + + if needs_vision_export or needs_lang_export: self.export( use_onnx_subfunctions=use_onnx_subfunctions, skip_vision=skip_vision, @@ -2218,6 +2305,23 @@ def kv_offload_generate( lang_start = perf_counter() # Run prefill chunk_inputs = lang_inputs.copy() + logits_vocab_size = None + logits_dtype = np.float32 + logits_binding_idx = lang_session.binding_index_map.get("logits") + if logits_binding_idx is not None and getattr(lang_session, "allowed_shapes", None): + allowed_vocab_sizes = [ + int(shape_spec[logits_binding_idx][1][-1]) + for shape_spec in lang_session.allowed_shapes + if len(shape_spec) > logits_binding_idx and len(shape_spec[logits_binding_idx][1]) >= 3 + ] + if allowed_vocab_sizes: + logits_vocab_size = max(allowed_vocab_sizes) + if logits_vocab_size is None and logits_binding_idx is not None: + logits_dims = tuple(lang_session.bindings[logits_binding_idx].dims) + if logits_dims: + logits_vocab_size = int(logits_dims[-1]) + if logits_binding_idx is not None: + logits_dtype = lang_session.aic_to_np_dtype_mapping[lang_session.bindings[logits_binding_idx].type] for i in range(num_chunks): if ( self.comp_ctx_lengths_prefill is not None @@ -2242,6 +2346,13 @@ def kv_offload_generate( chunk_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"][ :, i * prefill_seq_len : (i + 1) * prefill_seq_len, :, : ] + if logits_vocab_size is not None: + logits_shape = ( + chunk_inputs["input_ids"].shape[0], + chunk_inputs["input_ids"].shape[1], + logits_vocab_size, + ) + lang_session.set_buffers({"logits": np.zeros(logits_shape, 
dtype=logits_dtype)}) outputs = lang_session.run(chunk_inputs) chunk_inputs["image_idx"] = outputs["image_idx_output"] @@ -2261,6 +2372,8 @@ def kv_offload_generate( lang_session.skip_buffers(vision_outputs.keys()) # Get first token lang_inputs["input_ids"] = outputs["logits"].argmax(2) + if lang_inputs["input_ids"].ndim == 2 and lang_inputs["input_ids"].shape[1] > 1: + lang_inputs["input_ids"] = lang_inputs["input_ids"][:, -1:] lang_inputs["position_ids"] = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1 if "mm_token_type_ids" in lang_inputs: @@ -2290,6 +2403,10 @@ def kv_offload_generate( break lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id] + if logits_vocab_size is not None: + decode_logits_shape = (lang_inputs["input_ids"].shape[0], 1, logits_vocab_size) + lang_session.set_buffers({"logits": np.zeros(decode_logits_shape, dtype=logits_dtype)}) + decode_start = perf_counter() for num_token in range(1, generation_len): if self.comp_ctx_lengths_decode is not None: diff --git a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index c564b644a..c5b98acb7 100644 --- a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -1055,6 +1055,9 @@ def forward( class QEffQwen3_5MoeForCausalLM(Qwen3_5MoeForCausalLM): def get_submodules_for_export(self) -> Type[nn.Module]: + layer_types = getattr(self.config, "layer_types", None) + if layer_types and len(set(layer_types)) > 1: + return set() return {QEffQwen3_5MoeDecoderLayer} @staticmethod @@ -1170,6 +1173,10 @@ def forward( class QEffQwen3_5MoeModel(Qwen3_5MoeModel): + _start = 0 + _end = 0 + _total_layers = None + def forward( self, input_ids: torch.LongTensor = None, @@ -1462,6 +1469,9 @@ def __init__(self, model): self.config = model.config def get_submodules_for_export(self) -> Type[nn.Module]: + layer_types 
= getattr(self.config.text_config, "layer_types", None) + if layer_types and len(set(layer_types)) > 1: + return set() return {QEffQwen3_5MoeDecoderLayer} def forward( @@ -1866,7 +1876,7 @@ def get_dummy_inputs( lang_inputs["past_key_values"] = [[] for _ in range(self.model.config.text_config.num_hidden_layers)] # for i in range(self.model.config.text_config.num_hidden_layers): - i = QEffQwen3_5MoeModel._start + i = QEffQwen3_5MoeTextModel._start if self.model.config.text_config.layer_types[i] == "full_attention": for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) diff --git a/QEfficient/utils/layerwise_pipeline.py b/QEfficient/utils/layerwise_pipeline.py index fd1e054dc..0bd377d92 100644 --- a/QEfficient/utils/layerwise_pipeline.py +++ b/QEfficient/utils/layerwise_pipeline.py @@ -82,17 +82,6 @@ def split_layer_graph( model = onnx.load(onnx_path, load_external_data=False) - decoder_input = None - decoder_output = None - for node in model.graph.node: - if "DecoderLayer" in node.name: - decoder_input = list(node.input) - decoder_output = list(node.output) - break - - if decoder_input is None or decoder_output is None: - raise RuntimeError(f"DecoderLayer not found in layer window {layer_start}_{layer_end}") - model_ir = onnx_ir.load(onnx_path) graph_inputs = [v.name for v in model.graph.input] @@ -442,13 +431,17 @@ def run_merge_pipeline( m1_pref = onnx.load(m1_path, load_external_data=False) m2_pref = onnx.load(m2_path, load_external_data=False) - decoder_nodes = [n for n in m1_pref.graph.node if "DecoderLayer" in n.name] - if not decoder_nodes: - raise RuntimeError(f"DecoderLayer node not found in {m1_path}") - decoder_output = list(decoder_nodes[-1].output) - selected_output = next((x for x in decoder_output if "RetainedState" not in x), None) + graph_outputs = [output.name for output in m1_pref.graph.output] + selected_output = next( + ( + name + for name in graph_outputs + if "RetainedState" not 
in name and not name.endswith("position_ids") and "image_idx" not in name + ), + None, + ) if selected_output is None: - raise RuntimeError(f"No decoder output found without 'RetainedState'. Outputs: {decoder_output}") + raise RuntimeError(f"No mergeable decoder output found in {m1_path}. Outputs: {graph_outputs}") merged_model = merge_models( m1_pref, diff --git a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise_decode.py b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise_decode.py index 1e402438d..7d92ca695 100644 --- a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise_decode.py +++ b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise_decode.py @@ -16,23 +16,28 @@ """ import torch -from transformers import AutoConfig +import transformers +from transformers import AutoConfig, AutoProcessor from QEfficient import QEFFAutoModelForImageTextToText -MODEL_ID = "Qwen/Qwen3.5-397B-A17B" +# MODEL_ID = "Qwen/Qwen3.5-397B-A17B" +MODEL_ID = "tiny-random/qwen3.6-moe" +# MODEL_ID = "Qwen/Qwen3.6-35B-A3B" def main(): config = AutoConfig.from_pretrained(MODEL_ID) config.torch_dtype = "float32" + tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_ID) + processor = AutoProcessor.from_pretrained(MODEL_ID) qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( MODEL_ID, attn_implementation="eager", kv_offload=True, config=config, - torch_dtype=torch.float32, + dtype=torch.float32, layerwise=True, ) @@ -41,20 +46,45 @@ def main(): prefill_seq_len=1, ctx_len=4096, num_cores=16, - num_devices=1, + num_devices=4, height=354, width=536, mxfp6_matmul=True, + mxint8_kv_cache=True, aic_enable_depth_first=True, skip_vision=True, split_retained_state_io=True, - use_onnx_subfunctions=True, + use_onnx_subfunctions=False, mos=1, layerwise=True, layerwise_window_size=1, ) print(f"Final QPC path: {qpc_path}") + batch_size = 1 + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": 
"Tell me about yourself."}, + ], + }, + ] + messages = [messages] * batch_size + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) + output = qeff_model.generate(inputs=inputs, generation_len=100) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + if __name__ == "__main__": main() diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py index 26a825791..951fcfb19 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py @@ -18,9 +18,10 @@ from QEfficient import QEFFAutoModelForImageTextToText from QEfficient.generation.cloud_infer import QAICInferenceSession -model_id = "Qwen/Qwen3-VL-30B-A3B-Instruct" -# model_id = "tiny-random/qwen3-vl-moe" +# model_id = "Qwen/Qwen3-VL-30B-A3B-Instruct" +model_id = "tiny-random/qwen3-vl-moe" config = AutoConfig.from_pretrained(model_id) +config.dtype = "float16" # For faster execution user can run with lesser layers, For Testing Purpose Only # config.vision_config.depth = 9 @@ -28,7 +29,7 @@ # config.vision_config.deepstack_visual_indexes = [8] qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( - model_id, attn_implementation="eager", kv_offload=True, config=config + model_id, attn_implementation="eager", kv_offload=True, config=config, dtype=torch.float16, layerwise=True ) tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) processor = AutoProcessor.from_pretrained(model_id) @@ -54,6 +55,7 @@ split_model_io=True, skip_lang=True, use_onnx_subfunctions=True, + layerwise=True, ) prefill_qpc_path = qeff_model.compile( diff --git 
a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py index fef43de62..9406260cb 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py @@ -22,13 +22,13 @@ from QEfficient import QEFFAutoModelForImageTextToText # MODEL_ID = "Qwen/Qwen3-VL-235B-A22B-Instruct" -MODEL_ID = "Qwen/Qwen3-VL-30B-A3B-Instruct" -# MODEL_ID = "tiny-random/qwen3-vl-moe" +# MODEL_ID = "Qwen/Qwen3-VL-30B-A3B-Instruct" +MODEL_ID = "tiny-random/qwen3-vl-moe" def main(): config = AutoConfig.from_pretrained(MODEL_ID) - config.torch_dtype = "float16" + config.dtype = "float16" tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_ID) processor = AutoProcessor.from_pretrained(MODEL_ID) @@ -37,7 +37,7 @@ def main(): attn_implementation="eager", kv_offload=True, config=config, - torch_dtype=torch.float16, + dtype=torch.float16, layerwise=True, ) batch_size = 1 diff --git a/tests/unit_test/models/test_model_quickcheck.py b/tests/unit_test/models/test_model_quickcheck.py index 03ac6ce27..ad101d250 100644 --- a/tests/unit_test/models/test_model_quickcheck.py +++ b/tests/unit_test/models/test_model_quickcheck.py @@ -1360,6 +1360,36 @@ def test_layerwise_off_does_not_set_env_var(tmp_path): assert QEFFBaseModel._layerwise_active is False +@pytest.mark.llm_model +def test_layerwise_vision_wrapper_keeps_only_first_text_window(): + try: + config = AutoConfig.from_pretrained(LAYERWISE_TINY_MODEL_ID) + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + LAYERWISE_TINY_MODEL_ID, + kv_offload=True, + config=config, + layerwise=True, + ) + vision_wrapper = qeff_model._build_layerwise_vision_wrapper() + except Exception as exc: + _skip_on_model_fetch_error(exc, LAYERWISE_TINY_MODEL_ID) + + layers = vision_wrapper.model.model.language_model.layers + + assert 
getattr(qeff_model, "_layerwise_outer_meta", False) is True + assert layers[0] is not None + assert sum(layer is not None for layer in layers) == 1 + assert next(vision_wrapper.model.model.visual.parameters()).device.type != "meta" + + default_model = QEFFAutoModelForImageTextToText.from_pretrained( + LAYERWISE_TINY_MODEL_ID, + kv_offload=True, + config=config, + ) + default_layers = default_model.model.model.language_model.layers + assert sum(layer is not None for layer in default_layers) == len(default_layers) + + @pytest.mark.llm_model def test_layerwise_context_manager_toggles_class_flag(): """The driver's context manager must flip the class flag and restore it, From 0763861956633deaecd6df3de0e9344ab46e346a Mon Sep 17 00:00:00 2001 From: vbaddi Date: Sun, 7 Jun 2026 00:49:40 +0530 Subject: [PATCH 8/8] fix(qwen-moe): gate layerwise logic and restore float16 export - Gate Qwen3.x MoE windowing behind is_layerwise_active() so the default (layerwise=False) path matches the pre-layerwise baseline exactly. - Restore Qwen3.5 float16 export: fix float32 leaks in the MoE block, GatedDeltaNet norm, and dummy export inputs; honor the v5 dtype alias. 
Signed-off-by: vbaddi --- QEfficient/transformers/models/_layerwise.py | 39 ++++++ .../transformers/models/modeling_auto.py | 13 ++ .../qwen3_5_moe/modeling_qwen3_5_moe.py | 109 ++++++++++++---- .../models/qwen3_moe/modeling_qwen3_moe.py | 19 +-- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 56 +++++++- .../qwen3_5_moe_layerwise_decode.py | 10 +- .../qwen3_vl_moe/qwen3_vl_disagg_mode.py | 12 +- .../unit_test/models/test_model_quickcheck.py | 121 ++++++++++++++++++ 8 files changed, 325 insertions(+), 54 deletions(-) diff --git a/QEfficient/transformers/models/_layerwise.py b/QEfficient/transformers/models/_layerwise.py index a603b8b37..af5b5c67f 100644 --- a/QEfficient/transformers/models/_layerwise.py +++ b/QEfficient/transformers/models/_layerwise.py @@ -99,6 +99,45 @@ def assert_layerwise_supported(config) -> str: ) +def is_layerwise_active() -> bool: + """True only while the layer-wise export driver is running. + + The driver flips this on inside :func:`_layerwise_export_env`. Outside that + scope (the default, non-layerwise path) it is always False, which lets the + modeling forwards short-circuit every window branch and behave exactly as + they did before layer-wise support was added. + """ + return bool(_LAYERWISE_STATE["active"]) + + +def resolve_layer_window(model_cls, total_layers: int) -> Tuple[int, int]: + """Return the ``[start, end)`` decoder-layer window to run this forward. + + When layer-wise export is inactive this always returns ``(0, total_layers)`` + regardless of any ``_start``/``_end`` class attributes, so the default path + is independent of (possibly stale) window state left on the modeling class. + When the driver is active it honors the window it poked onto ``model_cls``. 
+ """ + if not is_layerwise_active(): + return 0, total_layers + start = int(getattr(model_cls, "_start", 0) or 0) + end = getattr(model_cls, "_end", 0) or 0 + end = int(end) if end else total_layers + return start, end + + +def is_last_layer_window(model_cls, total_layers: int) -> bool: + """True if this forward owns the final decoder window (applies final norm / lm_head). + + Always True on the default path; on the layer-wise path it is True only for + the window whose ``_end`` reaches the total layer count. + """ + if not is_layerwise_active(): + return True + _, end = resolve_layer_window(model_cls, total_layers) + return end >= total_layers + + # --------------------------------------------------------------------------- # Internal helpers (lifted from the legacy example script) # --------------------------------------------------------------------------- diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 0793023e2..efacab6ba 100755 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -110,8 +110,17 @@ def _resolve_torch_dtype(kwargs: dict) -> None: * If torch_dtype is not set at all, default to float32 so that models whose config.json declares bfloat16 are still loaded in a dtype that the ai100 compiler accepts. + + Transformers v5 renamed the ``torch_dtype`` argument to ``dtype``. To keep + backward compatibility for callers (and examples) that pass either name, + a caller-supplied ``dtype`` is folded into ``torch_dtype`` here and the two + are kept in sync on the way out so the loaded model and its config agree. """ aic_hw_version = constants.DEFAULT_AIC_HW_VERSION + # Normalize the transformers-v5 ``dtype`` alias onto ``torch_dtype`` so a + # single code path governs the HW dtype policy below. 
+ if kwargs.get("torch_dtype", None) is None and kwargs.get("dtype", None) is not None: + kwargs["torch_dtype"] = kwargs["dtype"] current_dtype = kwargs.get("torch_dtype", None) if (current_dtype is None or current_dtype == torch.bfloat16) and aic_hw_version != "ai200": @@ -122,6 +131,10 @@ def _resolve_torch_dtype(kwargs: dict) -> None: ) kwargs["torch_dtype"] = torch.float32 + # Keep the v5 alias in sync so HF from_pretrained and config see one dtype. + if "dtype" in kwargs: + kwargs["dtype"] = kwargs["torch_dtype"] + def _build_layerwise_vision_export_model(hf_auto_class, pretrained_model_name_or_path, kwargs): """Load a VLM with vision weights and only the first language window. diff --git a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index c5b98acb7..8a0bc396e 100644 --- a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -54,6 +54,11 @@ QEffDynamicLayer, ) from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +from QEfficient.transformers.models._layerwise import ( + is_last_layer_window, + is_layerwise_active, + resolve_layer_window, +) from QEfficient.utils import constants from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE @@ -69,11 +74,12 @@ class QEffQwen3_5MoeGatedDeltaNetCustomRMSNormAIC(nn.Module): """ def forward(self, hidden_states, gate): - return ( - CustomRMSNormFunc.apply( - hidden_states, self.weight, self.variance_epsilon if hasattr(self, "variance_epsilon") else self.eps - ) - ) * F.silu(gate.to(torch.float32)) + normed = CustomRMSNormFunc.apply( + hidden_states, self.weight, self.variance_epsilon if hasattr(self, "variance_epsilon") else self.eps + ) + # silu is computed in float32 for numerical parity, then cast back so the + # gated output 
keeps the module dtype (e.g. float16) and matches out_proj. + return normed * F.silu(gate.to(torch.float32)).to(normed.dtype) class QEffQwen3_5MoeDynamicCache(Cache): @@ -109,6 +115,21 @@ def from_legacy_cache( if past_key_values is None: return cache + if not is_layerwise_active(): + # Default path: restore every layer, matching pre-layerwise behavior. + for layer_idx, layer_state in enumerate(past_key_values): + if cache.layer_types[layer_idx] == "full_attention": + key_states, value_states = layer_state + layer = QEffDynamicLayer() + layer.keys = key_states + layer.values = value_states + cache.kv_layers[layer_idx] = layer + else: + conv_state, recurrent_state = layer_state + cache.conv_states[layer_idx] = conv_state + cache.recurrent_states[layer_idx] = recurrent_state + return cache + # for layer_idx, layer_state in enumerate(past_key_values): layer_idx = QEffQwen3_5MoeTextModel._start if cache.layer_types[layer_idx] == "full_attention": @@ -991,8 +1012,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - start = QEffQwen3_5MoeTextModel._start - end = QEffQwen3_5MoeTextModel._end + start, end = resolve_layer_window(QEffQwen3_5MoeTextModel, len(self.layers)) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length(layer_idx=start) if past_key_values is not None else 0 cache_position = torch.arange( @@ -1038,7 +1058,7 @@ def forward( # break - if QEffQwen3_5MoeTextModel._end == QEffQwen3_5MoeTextModel._total_layers: + if is_last_layer_window(QEffQwen3_5MoeTextModel, len(self.layers)): hidden_states = self.norm(hidden_states) if output_hidden_states: all_hidden_states += (hidden_states,) @@ -1046,7 +1066,8 @@ def forward( if return_legacy_cache: past_key_values = past_key_values.to_legacy_cache() - past_key_values = past_key_values[QEffQwen3_5MoeTextModel._start] + if is_layerwise_active(): + past_key_values = past_key_values[QEffQwen3_5MoeTextModel._start] return BaseModelOutputWithPast( 
last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, @@ -1095,12 +1116,13 @@ def get_onnx_retained_state_specs( "dynamic_axes": {}, } + kv_dtype = getattr(self.config, "torch_dtype", torch.float32) for layer_idx, layer_type in enumerate(self.config.layer_types): if layer_type == "full_attention": layer_names = [f"past_key.{layer_idx}", f"past_value.{layer_idx}"] layer_tensors = [ - torch.zeros(tuple(kv_cache_shape), dtype=torch.float32), - torch.zeros(tuple(kv_cache_shape), dtype=torch.float32), + torch.zeros(tuple(kv_cache_shape), dtype=kv_dtype), + torch.zeros(tuple(kv_cache_shape), dtype=kv_dtype), ] layer_axes = [ {0: batch_axis_name, 2: "ctx_len"}, @@ -1112,8 +1134,8 @@ def get_onnx_retained_state_specs( recurrent_shape = (batch_size, layer.num_v_heads, layer.head_k_dim, layer.head_v_dim) layer_names = [f"conv_state.{layer_idx}", f"recurrent_state.{layer_idx}"] layer_tensors = [ - torch.zeros(conv_shape, dtype=torch.float32), - torch.zeros(recurrent_shape, dtype=torch.float32), + torch.zeros(conv_shape, dtype=kv_dtype), + torch.zeros(recurrent_shape, dtype=kv_dtype), ] layer_axes = [{0: batch_axis_name}, {0: batch_axis_name}] @@ -1489,6 +1511,33 @@ def forward( inputs_embeds = self.model.model.get_input_embeddings()(input_ids) else: inputs_embeds = inputs_embeds + + if not is_layerwise_active(): + # Default (non-layerwise) path: image merge + full decoder + lm_head in + # a single forward, identical to the pre-layerwise behavior/output contract. 
+ _, _, channel_size = inputs_embeds.shape + selected = input_ids == self.model.config.image_token_id + indices1 = selected.to(torch.int64).cumsum(1) - 1 + indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) + indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) + image_features_expanded = vision_embeds.reshape(-1, channel_size).unsqueeze(0)[indices0, indices1] + image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) + inputs_embeds = image_input_embeds + + outputs = self.language_model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + use_cache=True, + ) + logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs.last_hidden_state[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index] + logits = self.model.lm_head(hidden_states) + image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) + return logits, vision_embeds, image_idx, outputs.past_key_values[: len(past_key_values)] + if QEffQwen3_5MoeTextModel._start == 0: B, S, _ = inputs_embeds.shape input_ids = torch.zeros((B, S), dtype=torch.int64, device=inputs_embeds.device) @@ -1848,10 +1897,12 @@ def get_dummy_inputs( vision_inputs = {} lang_inputs = {} - vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32) + # Float inputs follow the model dtype so float16 export traces cleanly. 
+ float_dtype = getattr(self.model.config, "torch_dtype", torch.float32) + vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=float_dtype) vision_inputs["image_grid_thw"] = torch.zeros((inputs_shapes["image_grid_thw"]), dtype=torch.int64) lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64) - lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32) + lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=float_dtype) lang_inputs["position_ids"] = ( ( torch.arange(dummy_seq_len, dtype=torch.int64) @@ -1875,17 +1926,24 @@ def get_dummy_inputs( linear_batch_size = fbs if continuous_batching else bs lang_inputs["past_key_values"] = [[] for _ in range(self.model.config.text_config.num_hidden_layers)] - # for i in range(self.model.config.text_config.num_hidden_layers): - i = QEffQwen3_5MoeTextModel._start - if self.model.config.text_config.layer_types[i] == "full_attention": - for kv in ["key", "value"]: - lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + # Default path exports all layers; layerwise exports only the active window's layer. + if is_layerwise_active(): + window_layers = [QEffQwen3_5MoeTextModel._start] else: - layer = self.model.language_model.layers[i].linear_attn - conv_shape = (linear_batch_size, layer.conv_dim, layer.conv_kernel_size) - recurrent_shape = (linear_batch_size, layer.num_v_heads, layer.head_k_dim, layer.head_v_dim) - lang_inputs["past_key_values"][i].append(torch.zeros(conv_shape, dtype=torch.float32)) - lang_inputs["past_key_values"][i].append(torch.zeros(recurrent_shape, dtype=torch.float32)) + window_layers = range(self.model.config.text_config.num_hidden_layers) + # KV/state dummy dtype follows the model dtype so the export trace works + # for float16 as well as float32 (matches the qwen3_vl_moe export path). 
+ kv_dtype = getattr(self.model.config, "torch_dtype", torch.float32) + for i in window_layers: + if self.model.config.text_config.layer_types[i] == "full_attention": + for kv in ["key", "value"]: + lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=kv_dtype)) + else: + layer = self.model.language_model.layers[i].linear_attn + conv_shape = (linear_batch_size, layer.conv_dim, layer.conv_kernel_size) + recurrent_shape = (linear_batch_size, layer.num_v_heads, layer.head_k_dim, layer.head_v_dim) + lang_inputs["past_key_values"][i].append(torch.zeros(conv_shape, dtype=kv_dtype)) + lang_inputs["past_key_values"][i].append(torch.zeros(recurrent_shape, dtype=kv_dtype)) # if continuous_batching: @@ -1971,6 +2029,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens T = B * S x = hidden_states.view(T, H) prob, top_w, top_i = self.gate(hidden_states) + top_w = top_w.to(x.dtype) idx = top_i.reshape(-1) w_up = self.experts.gate_up_proj[idx.flatten()] diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 7794e752e..864270fea 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -39,6 +39,7 @@ ) from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +from QEfficient.transformers.models._layerwise import is_last_layer_window, is_layerwise_active, resolve_layer_window from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE @@ -326,7 +327,8 @@ def forward( past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0 blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) - self.layer_idx = self.layer_idx - getattr(QEffQwen3MoeModel, "_start", 0) + if is_layerwise_active(): + 
self.layer_idx = self.layer_idx - getattr(QEffQwen3MoeModel, "_start", 0) use_blocking = blocking_config is not None and (blocking_config.mode != BlockingMode.NONE) if use_blocking: attn_output, attn_weights = generic_blocked_attention_interface( @@ -467,13 +469,8 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - start = QEffQwen3MoeModel._start - end = QEffQwen3MoeModel._end - - if QEffQwen3MoeModel._end == 0: - total_layers = end = self.config.num_hidden_layers - QEffQwen3MoeModel._end = total_layers - QEffQwen3MoeModel._total_layers = total_layers + total_layers = len(self.layers) + start, end = resolve_layer_window(QEffQwen3MoeModel, total_layers) past_key_values_length = 0 if past_key_values is not None: @@ -514,8 +511,7 @@ def forward( cos_cached=cos, ) - total_layers = getattr(QEffQwen3MoeModel, "_total_layers", len(self.layers)) - if QEffQwen3MoeModel._end == total_layers: + if is_last_layer_window(QEffQwen3MoeModel, len(self.layers)): hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer @@ -576,8 +572,7 @@ def forward( ) hidden_states = outputs.last_hidden_state - total_layers = getattr(QEffQwen3MoeModel, "_total_layers", len(self.model.layers)) - if QEffQwen3MoeModel._end < total_layers: + if not is_last_layer_window(QEffQwen3MoeModel, len(self.model.layers)): logits = hidden_states else: logit_idx = position_ids.to(torch.int32).argmax(1, keepdim=True) diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 4a6259bf8..4b826f6df 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -40,6 +40,11 @@ ) from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +from 
QEfficient.transformers.models._layerwise import ( + is_last_layer_window, + is_layerwise_active, + resolve_layer_window, +) from QEfficient.utils import constants from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE @@ -386,7 +391,8 @@ def forward( key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) - self.layer_idx = self.layer_idx - getattr(QEffQwen3VLMoeTextModel, "_start", 0) + if is_layerwise_active(): + self.layer_idx = self.layer_idx - getattr(QEffQwen3VLMoeTextModel, "_start", 0) past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0 blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) use_blocking = blocking_config is not None and (blocking_config.mode != BlockingMode.NONE) @@ -576,8 +582,7 @@ def forward( all_self_attns = () if output_attentions else None layer_idx = 0 - start = QEffQwen3VLMoeTextModel._start - end = QEffQwen3VLMoeTextModel._end + start, end = resolve_layer_window(QEffQwen3VLMoeTextModel, len(self.layers)) layer_indices_to_run = kwargs.get("layer_indices_to_run", None) for layer_idx, decoder_layer in enumerate(self.layers): @@ -609,15 +614,15 @@ def forward( if output_attentions: all_self_attns += (layer_outputs[1],) - if deepstack_visual_embeds is not None and start in range(deepstack_visual_embeds.shape[0]): + if deepstack_visual_embeds is not None and layer_idx in range(deepstack_visual_embeds.shape[0]): hidden_states = self._deepstack_process( hidden_states, visual_pos_masks, - deepstack_visual_embeds[start], + deepstack_visual_embeds[layer_idx], ) layer_idx += 1 - if QEffQwen3VLMoeTextModel._end == QEffQwen3VLMoeTextModel._total_layers: + if 
is_last_layer_window(QEffQwen3VLMoeTextModel, len(self.layers)): hidden_states = self.norm(hidden_states) if output_hidden_states: all_hidden_states += (hidden_states,) @@ -814,6 +819,45 @@ def forward( else: inputs_embeds = inputs_embeds + if not is_layerwise_active(): + # Default (non-layerwise) path: image merge + full decoder + lm_head in + # a single forward, identical to the pre-layerwise behavior/output contract. + B, N, C = inputs_embeds.shape + selected = input_ids == self.model.config.image_token_id + indices1 = selected.to(torch.int64).cumsum(1) - 1 + indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) + indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) + image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1] + + num_features, bs, split_size, C = deepstack_features.shape + x = deepstack_features.reshape(num_features, bs * split_size, C) + deepstack_features_expanded = x[:, indices1, :] + image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) + inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) + + image_mask = selected.clone() + visual_pos_masks = None + deepstack_visual_embeds = None + if image_mask is not None: + visual_pos_masks = image_mask + deepstack_visual_embeds = deepstack_features_expanded + + outputs = self.language_model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + use_cache=True, + visual_pos_masks=visual_pos_masks, + deepstack_visual_embeds=deepstack_visual_embeds, + ) + logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs.last_hidden_state[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index] + logits = self.model.lm_head(hidden_states) + image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) + return 
logits, vision_embeds, deepstack_features, image_idx, outputs.past_key_values + if QEffQwen3VLMoeTextModel._start == 0: B, N, C = inputs_embeds.shape selected = input_ids == self.model.config.image_token_id diff --git a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise_decode.py b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise_decode.py index 7d92ca695..c89705ee5 100644 --- a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise_decode.py +++ b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise_decode.py @@ -28,7 +28,7 @@ def main(): config = AutoConfig.from_pretrained(MODEL_ID) - config.torch_dtype = "float32" + config.torch_dtype = "float16" tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_ID) processor = AutoProcessor.from_pretrained(MODEL_ID) @@ -37,8 +37,8 @@ def main(): attn_implementation="eager", kv_offload=True, config=config, - dtype=torch.float32, - layerwise=True, + dtype=torch.float16, + # layerwise=True, ) qpc_path = qeff_model.compile( @@ -56,8 +56,8 @@ def main(): split_retained_state_io=True, use_onnx_subfunctions=False, mos=1, - layerwise=True, - layerwise_window_size=1, + # layerwise=True, + # layerwise_window_size=1, ) print(f"Final QPC path: {qpc_path}") diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py index 951fcfb19..74eeff5ff 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py @@ -29,7 +29,7 @@ # config.vision_config.deepstack_visual_indexes = [8] qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( - model_id, attn_implementation="eager", kv_offload=True, config=config, dtype=torch.float16, layerwise=True + model_id, attn_implementation="eager", kv_offload=True, config=config, dtype=torch.float16 ) tokenizer = 
transformers.AutoTokenizer.from_pretrained(model_id) processor = AutoProcessor.from_pretrained(model_id) @@ -55,7 +55,7 @@ split_model_io=True, skip_lang=True, use_onnx_subfunctions=True, - layerwise=True, + # layerwise=True, ) prefill_qpc_path = qeff_model.compile( @@ -76,8 +76,8 @@ enable_chunking=True, skip_vision=True, use_onnx_subfunctions=True, - layerwise=True, - layerwise_window_size=1, + # layerwise=True, + # layerwise_window_size=1, ) @@ -97,8 +97,8 @@ prefill_only=False, skip_vision=True, use_onnx_subfunctions=True, - layerwise=True, - layerwise_window_size=1, + # layerwise=True, + # layerwise_window_size=1, ) lang_prefill_session = QAICInferenceSession(prefill_qpc_path.get("lang_prefill_qpc_path")) diff --git a/tests/unit_test/models/test_model_quickcheck.py b/tests/unit_test/models/test_model_quickcheck.py index ad101d250..e68252969 100644 --- a/tests/unit_test/models/test_model_quickcheck.py +++ b/tests/unit_test/models/test_model_quickcheck.py @@ -1319,6 +1319,127 @@ def test_layerwise_supported_guard_rejects_unrelated_model(): _layerwise.assert_layerwise_supported(config) +def test_resolve_torch_dtype_normalizes_dtype_alias(): + """transformers-v5 ``dtype`` alias must be honored and kept in sync with ``torch_dtype``. + + Regression guard: passing ``dtype=float16`` used to be ignored, leaving the + model forced to float32 - breaking the Qwen3.5 float16 export path. + """ + from QEfficient.transformers.models.modeling_auto import _resolve_torch_dtype + + # float16 supplied via the v5 alias must survive and populate torch_dtype. + kwargs = {"dtype": torch.float16} + _resolve_torch_dtype(kwargs) + assert kwargs["torch_dtype"] == torch.float16 + assert kwargs["dtype"] == torch.float16 + + # float16 via the legacy name still works. + kwargs = {"torch_dtype": torch.float16} + _resolve_torch_dtype(kwargs) + assert kwargs["torch_dtype"] == torch.float16 + + # bfloat16 is downgraded to float32 on ai100 regardless of which name is used. 
+ kwargs = {"dtype": torch.bfloat16} + _resolve_torch_dtype(kwargs) + assert kwargs["torch_dtype"] == torch.float32 + assert kwargs["dtype"] == torch.float32 + + +def test_qwen3_5_moe_gated_norm_preserves_float16(): + """GatedDeltaNet RMSNorm must keep the input dtype so the gated output feeds + the float16 out_proj without a dtype mismatch (Qwen3.5 float16 export).""" + import torch.nn as nn + + from QEfficient.transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( + QEffQwen3_5MoeGatedDeltaNetCustomRMSNormAIC, + ) + + norm = QEffQwen3_5MoeGatedDeltaNetCustomRMSNormAIC() + norm.weight = nn.Parameter(torch.ones(16, dtype=torch.float16)) + norm.eps = 1e-6 + + out = norm(torch.randn(4, 16, dtype=torch.float16), torch.randn(4, 16, dtype=torch.float16)) + assert out.dtype == torch.float16 + + +def test_layerwise_matches_default_path_for_qwen3_moe(): + """Without-layerwise and with-layerwise forwards must produce identical output. + This is the core backward-compatibility contract: running every decoder layer + in a single forward (default path) must match running the same layers one + window at a time and chaining the hidden states (layerwise path), bit for bit. 
+ """ + from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeConfig, Qwen3MoeForCausalLM + + import QEfficient + from QEfficient.transformers.models import _layerwise + + cfg = Qwen3MoeConfig( + hidden_size=64, + intermediate_size=128, + moe_intermediate_size=64, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=16, + num_experts=4, + num_experts_per_tok=2, + vocab_size=128, + max_position_embeddings=128, + decoder_sparse_step=1, + norm_topk_prob=True, + ) + torch.manual_seed(0) + hf = Qwen3MoeForCausalLM(cfg).eval() + qeff_model = QEfficient.QEFFAutoModelForCausalLM(hf, continuous_batching=False) + inner = qeff_model.model.model + + B, S, ctx, num_layers = 1, 8, 16, cfg.num_hidden_layers + n_kv, head_dim = cfg.num_key_value_heads, cfg.head_dim + ids = torch.randint(0, cfg.vocab_size, (B, S)) + position_ids = torch.arange(S).view(1, -1) + + def fresh_pkv(): + return tuple( + (torch.zeros(B, n_kv, ctx, head_dim), torch.zeros(B, n_kv, ctx, head_dim)) for _ in range(num_layers) + ) + + # Default (non-layerwise) path: all layers in one shot. + with torch.no_grad(): + default_out = inner( + input_ids=ids, + position_ids=position_ids, + past_key_values=fresh_pkv(), + cache_position=torch.arange(S), + ) + + # Layerwise path: window size 1, chaining hidden states across windows. 
+ hidden_states = inner.embed_tokens(ids) + try: + with _layerwise._layerwise_export_env(): + for window in range(num_layers): + _layerwise._set_layer_windows(window, window + 1, num_layers) + with torch.no_grad(): + out = inner( + inputs_embeds=hidden_states, + position_ids=position_ids, + past_key_values=fresh_pkv(), + cache_position=torch.arange(S), + ) + hidden_states = out.last_hidden_state + finally: + _layerwise._reset_layer_windows() + + assert torch.equal(default_out.last_hidden_state, hidden_states), ( + "layerwise windowed forward diverged from the default single-shot forward" + ) + + # The default path must leave the mutable window state untouched. + from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import QEffQwen3MoeModel + + assert QEffQwen3MoeModel._start == 0 + assert QEffQwen3MoeModel._end == 0 + + @pytest.mark.llm_model def test_layerwise_supported_guard_accepts_qwen3_vl_moe(): from QEfficient.transformers.models import _layerwise