diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 3bb05c7b9..a7c63ad0e 100755 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -8,7 +8,6 @@ import gc import inspect import logging -import os import shutil import subprocess import warnings @@ -65,6 +64,7 @@ class QEFFBaseModel(ABC): _start = 0 _end = 0 _total_layers = None + _layerwise_active = False _pytorch_transforms: List[PytorchTransform] _onnx_transforms = [BaseOnnxTransform] @@ -488,6 +488,7 @@ def _export_layerwise( prefill_only: Optional[bool] = False, **export_kwargs, ) -> str: + cache_probe = export_kwargs.pop("_layerwise_cache_probe", False) idx = int(QEFFBaseModel._start) end_idx = int(getattr(QEFFBaseModel, "_end", idx + 1)) if end_idx <= idx: @@ -502,6 +503,20 @@ def _export_layerwise( self.onnx_path = onnx_path return onnx_path + # Layer-wise reuse: if the merged final ONNX from a prior run exists + # under final_data/, skip per-window export entirely. The driver's + # stitch step picks up the same merged file, so re-running the same + # example without changes goes straight to the QPC compile. + final_data_dir = export_dir / "final_data" + if final_data_dir.is_dir(): + total_layers = int(getattr(QEFFBaseModel, "_total_layers", 0) or 0) + cached_merged = final_data_dir / f"merged_0-{total_layers}.onnx" + if total_layers > 0 and cached_merged.is_file(): + self.onnx_path = cached_merged + return self.onnx_path + if cache_probe: + return None + # check if the model is in meta state or weights are offloaded self._model_offloaded_check() @@ -544,9 +559,15 @@ def _resolve_pkv_layers(pkv_obj): z = example_inputs.pop("input_ids") if is_vision: hidden_size = self.model.language_model.config.hidden_size + embed_dtype = getattr(self.model.language_model.config, "torch_dtype", None) else: hidden_size = self.model.model.config.hidden_size - inputs_embeds = torch.rand(z.shape[0], z.shape[1], hidden_size, device=z.device) + embed_dtype = getattr(self.model.model.config, "torch_dtype", None) + # Match the model's dtype so per-window export does not introduce a + # float32/float16 mismatch when running through fp16 decoder layers. + if embed_dtype is None: + embed_dtype = next(self.model.parameters()).dtype + inputs_embeds = torch.rand(z.shape[0], z.shape[1], hidden_size, device=z.device, dtype=embed_dtype) example_inputs["inputs_embeds"] = inputs_embeds dynamic_axes["inputs_embeds"] = dynamic_axes.pop("input_ids") @@ -735,6 +756,7 @@ def _compile( For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored. """ + layerwise_cache_probe = compiler_options.pop("_layerwise_cache_probe", False) moe_prefill_packed_chunk_size = compiler_options.pop("moe_prefill_packed_chunk_size", None) if onnx_path is None: # If weights were offloaded after export, compiling must use the existing @@ -754,11 +776,15 @@ def _compile( num_devices=mdp_ts_num_devices, qaic_config=qaic_config, moe_prefill_packed_chunk_size=moe_prefill_packed_chunk_size, + _layerwise_cache_probe=layerwise_cache_probe, **compiler_options, ) - onnx_path = Path(onnx_path) - if os.environ.get("LAYERWISE_EXPORT", "False") == "True": + if QEFFBaseModel._layerwise_active: + if onnx_path is None: + return None + onnx_path = Path(onnx_path) return onnx_path + onnx_path = Path(onnx_path) compile_dir = Path(compile_dir or onnx_path.parent) qpc_path = compile_dir / "qpc" @@ -841,14 +867,13 @@ def _compile( compile_dir = qpc_path.with_name(qpc_path.name + "-" + compile_hash) qpc_path = compile_dir / "qpc" - qpc_path.mkdir(parents=True, exist_ok=True) - + if (qpc_path / "programqpc.bin").is_file(): + self.qpc_path = qpc_path + return qpc_path if qpc_path.is_dir(): - if (qpc_path / "programqpc.bin").is_file(): - self.qpc_path = qpc_path - return qpc_path - # Probably compilation failure last time, delete directory to start over + # Probably compilation failure last time, delete directory to start over. shutil.rmtree(qpc_path) + compile_dir.mkdir(parents=True, exist_ok=True) # Write the generated MDP partition config file (not if user provided it) diff --git a/QEfficient/generation/cloud_infer.py b/QEfficient/generation/cloud_infer.py index 47703930a..88d1357b6 100644 --- a/QEfficient/generation/cloud_infer.py +++ b/QEfficient/generation/cloud_infer.py @@ -13,6 +13,21 @@ import numpy as np + +def _public_retained_state_name(output_name: str) -> Optional[str]: + """Map internal subfunction retained-state outputs to public runtime names.""" + suffix = "_InternalRetainedState" + if output_name.endswith(suffix): + return output_name[: -len(suffix)] + "_RetainedState" + return None + + +def _add_basename_binding_aliases(binding_index_map: Dict[str, int], bindings) -> None: + """Allow callers to use unprefixed I/O names for prefixed ONNX graphs.""" + for binding in bindings: + binding_index_map.setdefault(binding.name.rsplit("/", 1)[-1], binding.index) + + try: import qaicrt @@ -101,6 +116,7 @@ def __init__( ] self.bindings = iodesc.selected_set.bindings self.binding_index_map = {binding.name: binding.index for binding in self.bindings} + _add_basename_binding_aliases(self.binding_index_map, self.bindings) # Create and load Program prog_properties = qaicrt.QAicProgramProperties() prog_properties.dataPathTimeoutMs = 60_000 @@ -226,8 +242,15 @@ def run(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: buffer_index = self.binding_index_map[output_name] if self.qbuffers[buffer_index].size == 0: continue - outputs[output_name] = np.frombuffer( + output = np.frombuffer( bytes(output_qbuffers[buffer_index]), self.aic_to_np_dtype_mapping[self.bindings[buffer_index].type], ).reshape(self.buf_dims[buffer_index][1]) + outputs[output_name] = output + output_basename = output_name.rsplit("/", 1)[-1] + outputs.setdefault(output_basename, output) + public_name = _public_retained_state_name(output_name) + if public_name is not None: + outputs[public_name] = output + outputs.setdefault(public_name.rsplit("/", 1)[-1], output) return outputs diff --git a/QEfficient/transformers/models/_layerwise.py b/QEfficient/transformers/models/_layerwise.py new file mode 100644 index 000000000..22f741616 --- /dev/null +++ b/QEfficient/transformers/models/_layerwise.py @@ -0,0 +1,610 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Layer-wise export/compile orchestration for selected MoE architectures. + +Provisional API. Scheduled for removal once first-class multi-window export lands. +Today only ``qwen3_vl_moe``, ``qwen3_5_moe`` and ``qwen3_moe`` carry the +windowing hooks (``_start``/``_end`` class attributes) that this driver relies +on; calling :func:`run_layerwise` for any other architecture will raise. +""" + +from __future__ import annotations + +import functools +import warnings +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import transformers + +import QEfficient + +# Architectures whose modeling files declare _start/_end class attributes the +# layer-wise driver pokes. Keep this list narrow on purpose - adding a new +# architecture must come with the corresponding modeling-file hooks. +_LAYERWISE_SUPPORTED_MODEL_TYPES = frozenset( + { + "qwen3_vl_moe", + "qwen3_vl_moe_text", + "qwen3_5_moe", + "qwen3_moe", + } +) + +# Hard cap on the RoPE base table rows we serialize per window. Chosen as a +# constant (not a function of ctx_len) on purpose: changing ctx_len at compile +# time should re-use the cached ONNX and only re-run the QPC compile. Any +# inference-time position id is bounded by ctx_len, and ctx_len in practice +# stays well under 32K for the supported MoE families today, so dropping +# the unreachable rows past 32K is lossless. +_LAYERWISE_ROPE_MAX_POSITIONS = 32768 + +# Process-local layer-wise window state. We deliberately avoid setting class +# attributes on transformers.modeling_utils.PreTrainedModel - those would leak +# to every HF model in the process and survive past the layer-wise run. The +# patched HF hooks (shard filter, model-init nuller) close over this dict so +# they can be installed once and behave as no-ops whenever ``active`` is False. +_LAYERWISE_STATE: Dict[str, int] = { + "active": 0, + "start": 0, + "end": 0, + "total_layers": 0, + "text_start": 0, + "text_end": 0, + "text_total_layers": 0, +} + +_DEPRECATION_WARNED = False + + +def _maybe_warn_deprecation() -> None: + global _DEPRECATION_WARNED + if _DEPRECATION_WARNED: + return + warnings.warn( + "layerwise=True is a provisional API and will be deprecated once " + "first-class multi-window export lands. Use only for the supported " + f"model types: {sorted(_LAYERWISE_SUPPORTED_MODEL_TYPES)}.", + DeprecationWarning, + stacklevel=3, + ) + _DEPRECATION_WARNED = True + + +def _resolve_model_type(config) -> str: + text_config = getattr(config, "text_config", None) + if text_config is not None and getattr(text_config, "model_type", None): + return text_config.model_type + return getattr(config, "model_type", "") or "" + + +def assert_layerwise_supported(config) -> str: + """Raise a clear error if the model architecture has no layerwise hooks.""" + top_type = getattr(config, "model_type", "") or "" + text_type = _resolve_model_type(config) + if top_type in _LAYERWISE_SUPPORTED_MODEL_TYPES or text_type in _LAYERWISE_SUPPORTED_MODEL_TYPES: + return text_type or top_type + raise NotImplementedError( + "layerwise=True is only supported for model types: " + f"{sorted(_LAYERWISE_SUPPORTED_MODEL_TYPES)}. Got model_type=" + f"'{top_type}' (text_config.model_type='{text_type}'). " + "Run with layerwise=False (the default) for this model." + ) + + +# --------------------------------------------------------------------------- +# Internal helpers (lifted from the legacy example script) +# --------------------------------------------------------------------------- + + +def _ensure_pretrained_window_attrs() -> None: + """No-op kept for compatibility with prior call sites. + + Layer-wise window state lives in the module-local ``_LAYERWISE_STATE`` + dict (see top of file); we no longer pollute ``PreTrainedModel`` with + class attributes. + """ + return + + +def _build_layer_windows(total_layers: int, window_size: int) -> List[Tuple[int, int]]: + if total_layers <= 0: + raise ValueError(f"Invalid total_layers={total_layers}; expected > 0.") + if window_size <= 0: + raise ValueError(f"Invalid window_size={window_size}; expected > 0.") + windows: List[Tuple[int, int]] = [] + start = 0 + while start < total_layers: + end = min(total_layers, start + window_size) + windows.append((start, end)) + start = end + return windows + + +def _get_text_layers_container(model): + if ( + hasattr(model, "model") + and hasattr(model.model, "language_model") + and hasattr(model.model.language_model, "layers") + ): + return model.model.language_model.layers + if hasattr(model, "model") and hasattr(model.model, "layers"): + return model.model.layers + if hasattr(model, "language_model") and hasattr(model.language_model, "layers"): + return model.language_model.layers + if hasattr(model, "layers"): + return model.layers + return None + + +def _null_outside_window_layers(model, *, apply_text: bool = True) -> None: + if not apply_text: + return + text_start = int(_LAYERWISE_STATE["text_start"] or _LAYERWISE_STATE["start"]) + text_end = int(_LAYERWISE_STATE["text_end"] or _LAYERWISE_STATE["end"]) + text_layers = _get_text_layers_container(model) + if text_layers is not None and text_end > text_start: + for idx, _ in enumerate(text_layers): + if idx < text_start or idx >= text_end: + text_layers[idx] = None + + +def _find_language_model(model): + """Locate the inner language_model that owns sin_cached / cos_cached / embed_tokens.""" + candidates = [] + if hasattr(model, "model") and hasattr(model.model, "language_model"): + candidates.append(model.model.language_model) + if hasattr(model, "language_model"): + candidates.append(model.language_model) + if hasattr(model, "model"): + candidates.append(model.model) + candidates.append(model) + for cand in candidates: + if any(hasattr(cand, attr) for attr in ("sin_cached", "cos_cached", "embed_tokens")): + return cand + return None + + +def _slim_for_window_export(qeff_model, *, ctx_len: Optional[int]) -> None: + """Shrink top-level params that are unused (or oversized) for this window. + + Without this, every per-window ONNX shard re-bakes the full RoPE base table + (``sin_cached``/``cos_cached`` of shape ``[max_position_embeddings, head_dim]``, + typically tens of MB in fp32) plus the full vocab embedding, blowing each + layer-window shard up by 1-2 orders of magnitude over the actual layer + weight footprint. Each top-level param is touched in-place; the next + window rebuilds the model from scratch via the factory so there is no + leakage across windows. + """ + import torch + + inner = qeff_model.model if hasattr(qeff_model, "model") else qeff_model + lm = _find_language_model(inner) + if lm is None: + return + + text_start = int(_LAYERWISE_STATE["text_start"] or _LAYERWISE_STATE["start"]) + text_end = int(_LAYERWISE_STATE["text_end"] or _LAYERWISE_STATE["end"]) + text_total = int(_LAYERWISE_STATE["text_total_layers"] or _LAYERWISE_STATE["total_layers"] or 0) + + # 1) Truncate sin_cached / cos_cached to a fixed cap (32K rows) instead + # of ctx_len. A constant cap keeps the export hash invariant when the + # user changes ctx_len, so re-compiling at a different context length + # only re-runs the QPC compile - the cached layer-wise ONNX is reused. + # Inference-time position ids are bounded by ctx_len which is well + # under the cap for the supported MoE families today. + del ctx_len # signature retained for forward compat, value intentionally unused + rope_cap = _LAYERWISE_ROPE_MAX_POSITIONS + for attr in ("sin_cached", "cos_cached"): + param = getattr(lm, attr, None) + if param is None or not hasattr(param, "shape") or param.dim() < 2: + continue + cur_rows = int(param.shape[0]) + if cur_rows <= rope_cap: + continue + with torch.no_grad(): + truncated = param.detach()[:rope_cap].clone().contiguous() + new_param = torch.nn.Parameter(truncated, requires_grad=False) + setattr(lm, attr, new_param) + + # 2) Drop the vocab embedding for windows that don't run input-id lookup. + # The first window (text_start == 0) is the only one that calls + # ``get_input_embeddings()(input_ids)``; later windows take + # ``inputs_embeds`` directly so the embedding matrix is unreached but + # still serialized. Replace its weight with a tiny placeholder of the + # same dtype/device so module attributes stay valid. + if text_start > 0 and hasattr(lm, "embed_tokens"): + embed = getattr(lm, "embed_tokens", None) + weight = getattr(embed, "weight", None) + if weight is not None and weight.dim() == 2 and weight.shape[0] > 1: + with torch.no_grad(): + tiny = torch.zeros((1, weight.shape[1]), dtype=weight.dtype, device=weight.device) + embed.weight = torch.nn.Parameter(tiny, requires_grad=False) + + # 3) Drop the lm_head for windows that aren't the last one. Only the final + # window applies ``self.model.lm_head(hidden_states)``. + outer = qeff_model.model if hasattr(qeff_model, "model") else qeff_model + lm_head = getattr(outer, "lm_head", None) + if ( + lm_head is not None + and text_total > 0 + and text_end < text_total + and hasattr(lm_head, "weight") + and lm_head.weight is not None + and lm_head.weight.dim() == 2 + and lm_head.weight.shape[0] > 1 + ): + with torch.no_grad(): + tiny = torch.zeros( + (1, lm_head.weight.shape[1]), + dtype=lm_head.weight.dtype, + device=lm_head.weight.device, + ) + lm_head.weight = torch.nn.Parameter(tiny, requires_grad=False) + if getattr(lm_head, "bias", None) is not None: + lm_head.bias = torch.nn.Parameter( + torch.zeros((1,), dtype=lm_head.bias.dtype, device=lm_head.bias.device), + requires_grad=False, + ) + + +def _install_window_patch(model_cls) -> None: + if getattr(model_cls, "_window_patch_installed", False): + return + original_init = model_cls.__init__ + + @functools.wraps(original_init) + def patched_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + # Only nullify decoder layers when the layer-wise driver is actively + # exporting; idle calls to from_pretrained must behave normally. + if _LAYERWISE_STATE["active"]: + _null_outside_window_layers(self, apply_text=True) + + model_cls.__init__ = patched_init + model_cls._window_patch_installed = True + + +def _install_shard_window_patch() -> None: + if getattr(transformers.modeling_utils, "_window_shard_patch_installed", False): + return + + original_get_checkpoint_shard_files = transformers.modeling_utils.get_checkpoint_shard_files + + @functools.wraps(original_get_checkpoint_shard_files) + def patched_get_checkpoint_shard_files(*args, **kwargs): + shard_files, metadata = original_get_checkpoint_shard_files(*args, **kwargs) + # Honor the module-local state instead of polluting PreTrainedModel. + # When layerwise is not active this reduces to a no-op for any HF + # caller that happens to load checkpoint shards in this process. + if not _LAYERWISE_STATE["active"]: + return shard_files, metadata + weight_map = metadata.get("weight_map") + if not weight_map: + return shard_files, metadata + + start = int(_LAYERWISE_STATE["start"]) + end = int(_LAYERWISE_STATE["end"]) + text_start = int(_LAYERWISE_STATE["text_start"] or start) + text_end = int(_LAYERWISE_STATE["text_end"] or end) + if text_end <= text_start: + return shard_files, metadata + + selected_text_prefixes = tuple( + [f"model.layers.{i}." for i in range(text_start, text_end)] + + [f"model.language_model.layers.{i}." for i in range(text_start, text_end)] + ) + filtered_weight_map: Dict[str, str] = {} + for checkpoint_key, shard_name in weight_map.items(): + if checkpoint_key.startswith("model.layers.") or checkpoint_key.startswith("model.language_model.layers."): + if checkpoint_key.startswith(selected_text_prefixes): + filtered_weight_map[checkpoint_key] = shard_name + continue + filtered_weight_map[checkpoint_key] = shard_name + + if not filtered_weight_map: + return shard_files, metadata + + shard_name_to_path = {path.split("/")[-1]: path for path in shard_files} + filtered_shard_names = sorted(set(filtered_weight_map.values())) + filtered_shard_files = [shard_name_to_path[name] for name in filtered_shard_names if name in shard_name_to_path] + if not filtered_shard_files: + return shard_files, metadata + + metadata["weight_map"] = filtered_weight_map + metadata["all_checkpoint_keys"] = list(filtered_weight_map.keys()) + return filtered_shard_files, metadata + + transformers.modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files + transformers.modeling_utils._window_shard_patch_installed = True + + +def _set_layer_windows(text_start: int, text_end: int, text_total_layers: int) -> None: + # Update the module-local state used by patched HF hooks. We deliberately + # do NOT set class attributes on transformers.modeling_utils.PreTrainedModel + # here - that would leak to every HF model in the process. + _LAYERWISE_STATE["start"] = text_start + _LAYERWISE_STATE["end"] = text_end + _LAYERWISE_STATE["total_layers"] = text_total_layers + _LAYERWISE_STATE["text_start"] = text_start + _LAYERWISE_STATE["text_end"] = text_end + _LAYERWISE_STATE["text_total_layers"] = text_total_layers + + # The QEff modeling classes themselves expose _start/_end/_total_layers as + # part of their windowing contract (they are read inside their forward + # implementations). Those are ours to set. + qeff_vl_mod = getattr(QEfficient.transformers.models, "qwen3_vl_moe", None) + if qeff_vl_mod is not None: + cls = getattr(qeff_vl_mod.modeling_qwen3_vl_moe, "QEffQwen3VLMoeTextModel", None) + if cls is not None: + cls._start = text_start + cls._end = text_end + cls._total_layers = text_total_layers + + qeff_35_mod = getattr(QEfficient.transformers.models, "qwen3_5_moe", None) + if qeff_35_mod is not None: + cls = getattr(qeff_35_mod.modeling_qwen3_5_moe, "QEffQwen3_5MoeTextModel", None) + if cls is not None: + cls._start = text_start + cls._end = text_end + cls._total_layers = text_total_layers + + qeff_3_mod = getattr(QEfficient.transformers.models, "qwen3_moe", None) + if qeff_3_mod is not None: + cls = getattr(qeff_3_mod.modeling_qwen3_moe, "QEffQwen3MoeModel", None) + if cls is not None: + cls._start = text_start + cls._end = text_end + cls._total_layers = text_total_layers + + QEfficient.base.modeling_qeff.QEFFBaseModel._start = text_start + QEfficient.base.modeling_qeff.QEFFBaseModel._end = text_end + QEfficient.base.modeling_qeff.QEFFBaseModel._total_layers = text_total_layers + + +def _reset_layer_windows() -> None: + _set_layer_windows(0, 0, 0) + + +def _resolve_export_root(onnx_path: Path) -> Path: + parts = list(onnx_path.parts) + if "onnx_layerwise_tmp" in parts: + return Path(*parts[: parts.index("onnx_layerwise_tmp")]) + if "final_data" in parts: + return Path(*parts[: parts.index("final_data")]) + return onnx_path.parent + + +def _is_cached_merged(onnx_path: Path) -> bool: + return "final_data" in onnx_path.parts and onnx_path.name.startswith("merged_") + + +def _cached_merged_onnx(export_root: Path, total_layers: Optional[int] = None) -> Optional[Path]: + """Return the complete cached merged ONNX, if it exists.""" + cached_dir = export_root / "final_data" + if not cached_dir.is_dir(): + return None + if total_layers is not None: + expected = cached_dir / f"merged_0-{total_layers}.onnx" + if expected.is_file(): + return expected + cached_merged = sorted(cached_dir.glob("merged_0-*.onnx")) + return cached_merged[-1] if cached_merged else None + + +def _stitch_layerwise_if_available(export_root: Path, total_layers: Optional[int] = None) -> str: + # If a prior run already produced the complete merged ONNX, just return it. + cached_merged = _cached_merged_onnx(export_root, total_layers) + if cached_merged is not None: + return str(cached_merged) + pipeline_fn = getattr(QEfficient.utils, "layerwise_pipeline", None) + if callable(pipeline_fn): + return pipeline_fn(str(export_root)) + return str(export_root / "onnx_layerwise_tmp") + + +def _cached_layerwise_onnx_path(qeff_model, compile_kwargs: Dict[str, Any]) -> Optional[Path]: + """Return a cached merged ONNX path without exporting or loading weights.""" + probe_kwargs = dict(compile_kwargs) + probe_kwargs["_layerwise_cache_probe"] = True + cached = qeff_model.compile(**probe_kwargs) + if isinstance(cached, dict): + cached = next( + ( + cached.get(key) + for key in ("lang_decode_qpc_path", "lang_prefill_qpc_path", "lang_qpc_path") + if cached.get(key) is not None + ), + None, + ) + return Path(cached) if cached is not None else None + + +def _install_window_patches_for(model_type: str) -> None: + """Install the HF __init__/shard patches needed for the given model_type. + + The shard-file patch makes ``from_pretrained`` skip checkpoint shards that + only contain weights for layers outside the active window, so loading an + N-layer model in a window of size 1 reads ~1/N of the disk. The init + patch nulls the unused decoder layers right after the model is built so + the full layer list is never instantiated in memory. + """ + _install_shard_window_patch() + candidates = [] + qwen3_vl_moe_mod = getattr(getattr(transformers.models, "qwen3_vl_moe", None), "modeling_qwen3_vl_moe", None) + if qwen3_vl_moe_mod is not None and "qwen3_vl_moe" in model_type: + candidates.extend( + cls + for name in ("Qwen3VLMoeForConditionalGeneration", "Qwen3VLMoeForCausalLM") + if (cls := getattr(qwen3_vl_moe_mod, name, None)) is not None + ) + qwen3_moe_mod = getattr(getattr(transformers.models, "qwen3_moe", None), "modeling_qwen3_moe", None) + if qwen3_moe_mod is not None and model_type in {"qwen3_moe"}: + candidates.extend( + cls for name in ("Qwen3MoeForCausalLM",) if (cls := getattr(qwen3_moe_mod, name, None)) is not None + ) + # qwen3_5_moe shares the qwen3_vl_moe HF classes today; the QEff modeling + # file overrides behavior. Install the shard patch (above) and rely on + # _null_outside_window_layers running post-init in the driver loop. + for cls in candidates: + _install_window_patch(cls) + + +@contextmanager +def _layerwise_export_env(): + """Toggle the layerwise-active flag on QEFFBaseModel for the inner block. + + No environment variables are touched - the flag is a pure in-process class + attribute, which keeps the API self-contained and lets concurrent Python + interpreters (e.g. test workers) operate independently. + """ + base = QEfficient.base.modeling_qeff.QEFFBaseModel + prev_active = getattr(base, "_layerwise_active", False) + prev_state_active = _LAYERWISE_STATE["active"] + base._layerwise_active = True + _LAYERWISE_STATE["active"] = 1 + try: + yield + finally: + base._layerwise_active = prev_active + _LAYERWISE_STATE["active"] = prev_state_active + + +def _resolve_text_total_layers(config) -> int: + text_config = getattr(config, "text_config", config) + total = getattr(text_config, "num_hidden_layers", None) + if total is None: + raise ValueError("Could not resolve `num_hidden_layers` from config.text_config / config.") + return int(total) + + +# --------------------------------------------------------------------------- +# Public driver +# --------------------------------------------------------------------------- + + +def run_layerwise( + *, + model_id: str, + config, + qeff_factory, + compile_kwargs: Dict[str, Any], + probe_qeff_model=None, + window_size: int = 1, + final_compile: bool = True, +) -> Any: + """Drive the per-window export loop and (optionally) the final stitched compile. + + Parameters + ---------- + model_id : str + HF id / path passed to ``qeff_factory(model_id, config)`` to (re)build a + QEff model fresh per window. + config : transformers.PretrainedConfig + Already-mutated config (the caller is responsible for any vision tweaks + like ``deepstack_visual_indexes``). + qeff_factory : Callable[[str, PretrainedConfig], QEffModel] + Factory invoked once per window to materialize a fresh QEff model that + only loads the active window's weights. + compile_kwargs : dict + Forwarded verbatim to ``qeff_model.compile(...)`` per window. The driver + injects ``skip_lang`` per-window and ``lang_onnx_path`` for the final + stitched compile. + probe_qeff_model : QEffModel, optional + Existing wrapper used for cache probing before any per-window weights are + loaded. In normal layerwise use this is the outer meta wrapper. + window_size : int + Number of text-decoder layers per window. ``1`` matches the legacy + example. + final_compile : bool + When True (compile path), do the final QPC compile on the merged ONNX + and return ``qpc_paths``. When False (export path), return the merged + ONNX path only. + + Returns + ------- + Either the qpc_paths dict (final_compile=True) or the merged ONNX path + (final_compile=False). + """ + _maybe_warn_deprecation() + model_type = assert_layerwise_supported(config) + + text_total_layers = _resolve_text_total_layers(config) + text_cfg = getattr(config, "text_config", config) + text_cfg.num_hidden_layers = text_total_layers + + _ensure_pretrained_window_attrs() + _install_window_patches_for(model_type) + + windows = _build_layer_windows(text_total_layers, window_size) + first_onnx_path: Optional[Path] = None + last_qeff_model = None + + try: + with _layerwise_export_env(): + _set_layer_windows(0, min(window_size, text_total_layers), text_total_layers) + cached_probe = probe_qeff_model or qeff_factory(model_id, config) + cached_onnx_path = _cached_layerwise_onnx_path(cached_probe, compile_kwargs) + if cached_onnx_path is not None: + first_onnx_path = cached_onnx_path + last_qeff_model = cached_probe + + for text_start, text_end in windows: + if cached_onnx_path is not None: + break + _set_layer_windows(text_start, text_end, text_total_layers) + + qeff_model = qeff_factory(model_id, config) + last_qeff_model = qeff_model + if hasattr(qeff_model, "model"): + _null_outside_window_layers(qeff_model.model, apply_text=True) + _slim_for_window_export(qeff_model, ctx_len=compile_kwargs.get("ctx_len")) + + window_kwargs = dict(compile_kwargs) + # skip_lang is a VLM-only kwarg; only inject when present in caller's kwargs. + if "skip_lang" in window_kwargs: + window_kwargs["skip_lang"] = False + onnx_path = qeff_model.compile(**window_kwargs) + if first_onnx_path is None: + if isinstance(onnx_path, dict): + lang_key = next( + ( + k + for k in ( + "lang_decode_qpc_path", + "lang_prefill_qpc_path", + "lang_qpc_path", + ) + if k in onnx_path + ), + None, + ) + if lang_key is None: + raise RuntimeError(f"Layer-wise window produced no lang_*_qpc_path: keys={list(onnx_path)}") + lang_path = onnx_path[lang_key] + else: + lang_path = onnx_path + first_onnx_path = Path(str(lang_path)) + finally: + _reset_layer_windows() + + if first_onnx_path is None: + raise RuntimeError("Layer-wise export produced no ONNX shards.") + + export_root = _resolve_export_root(first_onnx_path) + final_artifact = _stitch_layerwise_if_available(export_root, text_total_layers) + + if not final_compile: + return final_artifact + + final_kwargs = dict(compile_kwargs) + final_kwargs["lang_onnx_path"] = final_artifact + final_kwargs.setdefault("skip_lang", False) + return last_qeff_model.compile(**final_kwargs) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 65b89d274..7bca55f9d 100755 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -13,6 +13,7 @@ from typing import List, Optional, Union import numpy as np +import onnx import torch import torch.nn as nn from transformers import ( @@ -123,6 +124,85 @@ def _resolve_torch_dtype(kwargs: dict) -> None: kwargs["torch_dtype"] = torch.float32 +def _build_meta_model(hf_auto_class, pretrained_model_name_or_path, kwargs): + """Construct an HF model on the meta device for layer-wise mode. + + Avoids materializing checkpoint weights at the outer ``from_pretrained`` + call site. The wrapper still has a fully-typed ``nn.Module`` (so config, + architectures, and module structure are all real), but every parameter + and buffer is a meta tensor — zero RAM. The layer-wise driver later + rebuilds a real per-window model when ``compile()``/``export()`` runs. + """ + from transformers import AutoConfig + + config = kwargs.get("config", None) + if config is None: + config_kwargs = { + k: kwargs[k] for k in ("trust_remote_code", "revision", "token", "subfolder", "cache_dir") if k in kwargs + } + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **config_kwargs) + torch_dtype = kwargs.get("torch_dtype", torch.float32) + with torch.device("meta"): + model = hf_auto_class.from_config(config, torch_dtype=torch_dtype) + return model + + +def _compile_io_name(name: str, *, use_onnx_subfunctions: bool) -> str: + """Return the compiler-visible name for retained-state ONNX outputs.""" + if not use_onnx_subfunctions or not name.endswith("_RetainedState"): + return name + if any(token in name for token in ("key", "value", "compressed_kv", "k_pe")): + return name[: -len("_RetainedState")] + "_InternalRetainedState" + return name + + +def _state_input_name(output_name: str) -> str: + """Map a retained-state output name to its matching state input name.""" + for suffix in ("_InternalRetainedState", "_RetainedState"): + if output_name.endswith(suffix): + return output_name[: -len(suffix)] + return output_name + + +def _filter_custom_io_for_onnx(custom_io: dict, onnx_path: Optional[Union[str, Path]]) -> dict: + """Keep custom-IO entries that exist in the ONNX graph. + + Layerwise stitched graphs may prefix I/O names (for example ``layer_0/``) + and may expose public retained-state names even when subfunction export used + internal names during per-window export. Matching by basename keeps compiler + custom-IO compatible with both graph shapes. + """ + if onnx_path is None: + return custom_io + try: + model = onnx.load(onnx_path, load_external_data=False) + except Exception: + return custom_io + + io_names = {value.name for value in list(model.graph.input) + list(model.graph.output)} + basename_to_name = {name.rsplit("/", 1)[-1]: name for name in io_names} + + def resolve_name(name: str) -> Optional[str]: + candidates = [name] + if name.endswith("_InternalRetainedState"): + candidates.append(name[: -len("_InternalRetainedState")] + "_RetainedState") + elif name.endswith("_RetainedState"): + candidates.append(name[: -len("_RetainedState")] + "_InternalRetainedState") + for candidate in candidates: + if candidate in io_names: + return candidate + if candidate in basename_to_name: + return basename_to_name[candidate] + return None + + filtered = {} + for name, dtype in custom_io.items(): + resolved_name = resolve_name(name) + if resolved_name is not None: + filtered[resolved_name] = dtype + return filtered + + class QEFFTransformersBase(QEFFBaseModel): """ Base class for QEfficient wrappers around HuggingFace transformer models. @@ -1156,7 +1236,7 @@ def export( self.hash_params["prefill_only"] = False self.__update_prefill_transform(False, retain_full_kv=kwargs.get("retain_full_kv", False)) - if os.environ.get("LAYERWISE_EXPORT", "False") == "True": + if QEfficient.base.modeling_qeff.QEFFBaseModel._layerwise_active: return self._export_layerwise( inputs, output_names=output_names, @@ -1164,6 +1244,7 @@ def export( export_dir=export_dir, offload_pt_weights=offload_pt_weights, use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), + _layerwise_cache_probe=kwargs.get("_layerwise_cache_probe", False), ) else: return self._export( @@ -1280,6 +1361,7 @@ def __init__( ) self.model = model self.config = model.config + self._pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) self.lang_model = QEffCausalLMForTextImageToTextModel(model, qaic_config=qaic_config, **kwargs) @@ -1319,10 +1401,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Option if kwargs.get("attn_implementation", None) not in {None, "eager"}: logger.warning('Updating attn_implementation="eager"') - if kwargs.get("low_cpu_mem_usage", None): - logger.warning("Updating low_cpu_mem_usage=False") - - kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + # Respect an explicit low_cpu_mem_usage=True from the caller (used by the + # layer-wise driver to keep RAM bounded by one window's weights via meta + # init + sharded materialization). For everyone else, force False to + # match prior behavior. + explicit_low_cpu = kwargs.get("low_cpu_mem_usage", None) is True + kwargs.update( + { + "attn_implementation": "eager", + "low_cpu_mem_usage": True if explicit_low_cpu else False, + } + ) _resolve_torch_dtype(kwargs) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) @@ -1357,6 +1446,8 @@ def export( prefill_seq_len: Optional[int] = None, prefill_only: bool = False, enable_chunking: bool = False, + layerwise: bool = False, + layerwise_window_size: int = 1, **kwargs, ) -> str: """ @@ -1379,6 +1470,19 @@ def export( List[str] A list containing the paths to the generated ONNX graph files for both components. """ + layerwise_cache_probe = kwargs.pop("_layerwise_cache_probe", False) + if layerwise: + return self._run_layerwise_export( + export_dir=export_dir, + use_onnx_subfunctions=use_onnx_subfunctions, + skip_vision=skip_vision, + skip_lang=skip_lang, + prefill_seq_len=prefill_seq_len, + prefill_only=prefill_only, + enable_chunking=enable_chunking, + layerwise_window_size=layerwise_window_size, + **kwargs, + ) dummy_inputs_kwargs = {} if prefill_seq_len is not None: dummy_inputs_kwargs["prefill_seq_len"] = int(prefill_seq_len) @@ -1416,7 +1520,7 @@ def export( qaic_config=self.lang_model.model.qaic_config, ) - layerwise_export = os.environ.get("LAYERWISE_EXPORT", "False") == "True" + layerwise_export = QEFFBaseModel._layerwise_active should_export = not skip_vision and ( not layerwise_export @@ -1426,7 +1530,7 @@ def export( == QEfficient.base.modeling_qeff.QEFFBaseModel._total_layers ) ) - if should_export: + if should_export and not layerwise_cache_probe: self.vision_model.export( inputs["vision"], output_names["vision"], @@ -1452,6 +1556,7 @@ def export( prefill_only=prefill_only, enable_chunking=enable_chunking, prefill_seq_len=prefill_seq_len, + _layerwise_cache_probe=layerwise_cache_probe, ) return self.onnx_path @@ -1482,6 +1587,111 @@ def transform( **compiler_options, ) + def _layerwise_factory_kwargs(self): + """Reproduce the from_pretrained kwargs needed to rebuild this wrapper per window.""" + # Mirror the dual-QPC from_pretrained surface; the layerwise driver passes + # config explicitly per call, so we only carry torch_dtype + attn here. + # low_cpu_mem_usage=True works with the shard-window patch to allocate + # only the active window's weights instead of the full model. + torch_dtype = getattr(self.config, "torch_dtype", None) + return { + "attn_implementation": "eager", + "kv_offload": True, + "torch_dtype": torch_dtype, + "low_cpu_mem_usage": True, + } + + def _build_layerwise_factory(self): + """Return a callable suitable for layerwise driver's qeff_factory hook.""" + from QEfficient import QEFFAutoModelForImageTextToText + + base_kwargs = self._layerwise_factory_kwargs() + + def _factory(model_id, config): + return QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + config=config, + **base_kwargs, + ) + + return _factory + + def _run_layerwise_export( + self, + *, + export_dir, + use_onnx_subfunctions, + skip_vision, + skip_lang, + prefill_seq_len, + prefill_only, + enable_chunking, + layerwise_window_size, + **kwargs, + ): + from QEfficient.transformers.models import _layerwise + + model_id = self._pretrained_model_name_or_path + if model_id is None: + raise RuntimeError( + "layerwise=True requires the QEff model to be built via " + "QEFFAutoModelForImageTextToText.from_pretrained(...). " + "Direct __init__ does not preserve the model id needed for per-window reload." + ) + compile_kwargs = dict( + export_dir=export_dir, + use_onnx_subfunctions=use_onnx_subfunctions, + skip_vision=skip_vision, + prefill_seq_len=prefill_seq_len, + prefill_only=prefill_only, + enable_chunking=enable_chunking, + **kwargs, + ) + return _layerwise.run_layerwise( + model_id=model_id, + config=self.config, + qeff_factory=self._build_layerwise_factory(), + compile_kwargs=compile_kwargs, + probe_qeff_model=self, + window_size=layerwise_window_size, + final_compile=False, + ) + + def _run_layerwise_compile( + self, + *, + layerwise_window_size, + **compile_kwargs, + ): + from QEfficient.transformers.models import _layerwise + + model_id = self._pretrained_model_name_or_path + if model_id is None: + raise RuntimeError( + "layerwise=True requires the QEff model to be built via " + "QEFFAutoModelForImageTextToText.from_pretrained(...). " + "Direct __init__ does not preserve the model id needed for per-window reload." + ) + qpc_paths = _layerwise.run_layerwise( + model_id=model_id, + config=self.config, + qeff_factory=self._build_layerwise_factory(), + compile_kwargs=compile_kwargs, + probe_qeff_model=self, + window_size=layerwise_window_size, + final_compile=True, + ) + self.qpc_paths = qpc_paths + if isinstance(qpc_paths, dict): + self.vision_model.qpc_path = qpc_paths.get("vision_qpc_path") or self.vision_model.qpc_path + self.lang_model.qpc_path = ( + qpc_paths.get("lang_decode_qpc_path") + or qpc_paths.get("lang_prefill_qpc_path") + or qpc_paths.get("lang_qpc_path") + or self.lang_model.qpc_path + ) + return qpc_paths + def compile( self, img_size: Optional[int] = None, @@ -1506,6 +1716,8 @@ def compile( prefill_only=None, enable_chunking=False, qaic_config: Optional[dict] = None, + layerwise: bool = False, + layerwise_window_size: int = 1, **compiler_options, ) -> str: """ @@ -1565,6 +1777,33 @@ def compile( if skip_lang and skip_vision: raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False") + if layerwise: + return self._run_layerwise_compile( + img_size=img_size, + vision_onnx_path=vision_onnx_path, + lang_onnx_path=lang_onnx_path, + compile_dir=compile_dir, + prefill_seq_len=prefill_seq_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + batch_size=batch_size, + full_batch_size=full_batch_size, + kv_cache_batch_size=kv_cache_batch_size, + num_devices=num_devices, + num_cores=num_cores, + mxfp6_matmul=mxfp6_matmul, + mxint8_kv_cache=mxint8_kv_cache, + skip_vision=skip_vision, + skip_lang=skip_lang, + use_onnx_subfunctions=use_onnx_subfunctions, + prefill_only=prefill_only, + enable_chunking=enable_chunking, + qaic_config=qaic_config, + layerwise_window_size=layerwise_window_size, + **compiler_options, + ) + if self.continuous_batching and full_batch_size is None: raise TypeError("`full_batch_size` is required when `continuous_batching=True`.") @@ -1573,6 +1812,7 @@ def compile( "KV caching requires continuous batching. Please set `full_batch_size` and " "enable `continuous_batching=True` in `from_pretrained`." ) + layerwise_cache_probe = compiler_options.pop("_layerwise_cache_probe", False) # Infer kv_cache_batch_size if not provided kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size @@ -1643,7 +1883,10 @@ def compile( prefill_only=prefill_only, enable_chunking=enable_chunking, prefill_seq_len=prefill_seq_len, + _layerwise_cache_probe=layerwise_cache_probe, ) + if layerwise_cache_probe: + return self.lang_model.onnx_path if hasattr(self.model, "generate_npi_file") and "node_precision_info" in compiler_options: if self.lang_model.onnx_path is None and not skip_lang: @@ -1678,54 +1921,24 @@ def compile( if not skip_lang: custom_io_lang = {} - # Inputs for output_name in output_names["lang"]: if output_name.endswith("_RetainedState"): - custom_io_lang[output_name[: -len("_RetainedState")]] = ( + compiler_output_name = _compile_io_name( + output_name, + use_onnx_subfunctions=use_onnx_subfunctions, + ) + custom_io_lang[_state_input_name(compiler_output_name)] = ( CUSTOM_IO_DTYPE_MAP[target_dtype] if ("vision_embeds" in output_name or "deepstack_features" in output_name) else kv_cache_dtype ) - - # outputs - for output_name in output_names["lang"]: - if output_name.endswith("_RetainedState"): - custom_io_lang[output_name] = ( + custom_io_lang[compiler_output_name] = ( CUSTOM_IO_DTYPE_MAP[target_dtype] if ("vision_embeds" in output_name or "deepstack_features" in output_name) else kv_cache_dtype ) - def filter_custom_io_lang(custom_io_lang, onnx_path): - # Extract filename - filename = os.path.basename(onnx_path) - - # Extract range from "merged_0-2.onnx" - match = re.search(r"merged_(\d+)-(\d+)\.onnx", filename) - if not match: - return custom_io_lang # no filtering if pattern not found - - start, end = map(int, match.groups()) # e.g. 0, 2 - - filtered = {} - - for k, v in custom_io_lang.items(): - # Keep everything that is NOT KV cache - if ("past_key." not in k) and ("past_value." not in k): - filtered[k] = v - continue - - # Extract layer index - layer_match = re.search(r"past_(?:key|value)\.(\d+)", k) - if layer_match: - idx = int(layer_match.group(1)) - if start <= idx < end: - filtered[k] = v - - return filtered - - if self.lang_model.onnx_path is not None and "merged" in str(self.lang_model.onnx_path): - custom_io_lang = filter_custom_io_lang(custom_io_lang, self.lang_model.onnx_path) + custom_io_lang = _filter_custom_io_for_onnx(custom_io_lang, self.lang_model.onnx_path) if prefill_only: specializations = specializations["lang"][:1] @@ -2432,18 +2645,14 @@ def compile( custom_io = {} target_dtype = getattr(self.model.config, "torch_dtype", torch.float32) kv_cache_dtype = "mxint8" if mxint8_kv_cache else CUSTOM_IO_DTYPE_MAP[target_dtype] - # inputs for input_name in output_names: if input_name.endswith("_RetainedState"): - custom_io[input_name[: -len("_RetainedState")]] = ( + compiler_output_name = _compile_io_name(input_name, use_onnx_subfunctions=use_onnx_subfunctions) + custom_io[_state_input_name(compiler_output_name)] = ( CUSTOM_IO_DTYPE_MAP[target_dtype] if "pixel_values" in input_name else kv_cache_dtype ) - - # outputs - for output_name in output_names: - if output_name.endswith("_RetainedState"): - custom_io[output_name] = ( - CUSTOM_IO_DTYPE_MAP[target_dtype] if "pixel_values" in output_name else kv_cache_dtype + custom_io[compiler_output_name] = ( + CUSTOM_IO_DTYPE_MAP[target_dtype] if "pixel_values" in input_name else kv_cache_dtype ) # TODO this hould be removed once the continous batching is supported for all the models. @@ -2835,6 +3044,7 @@ def from_pretrained( kv_offload: Optional[bool] = None, continuous_batching: bool = False, qaic_config: Optional[dict] = None, + layerwise: bool = False, **kwargs, ): """ @@ -2875,17 +3085,31 @@ def from_pretrained( if kwargs.get("attn_implementation", None) not in {None, "eager"}: logger.warning('Updating attn_implementation="eager"') - if kwargs.get("low_cpu_mem_usage", None): - logger.warning("Updating low_cpu_mem_usage=False") - - kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + # Respect explicit low_cpu_mem_usage=True (used by the layer-wise driver + # to materialize only the active window's weights via meta init + + # sharded materialization). Default behavior remains False. + explicit_low_cpu = kwargs.get("low_cpu_mem_usage", None) is True + kwargs.update( + { + "attn_implementation": "eager", + "low_cpu_mem_usage": True if explicit_low_cpu else False, + } + ) _resolve_torch_dtype(kwargs) - model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) + if layerwise: + # Layer-wise mode: build the outer model on the meta device so the + # caller's ``from_pretrained`` does not pull the full checkpoint + # into RAM. compile()/export() rebuilds a real per-window model + # internally via the layer-wise driver, so the outer instance is + # only used as a config holder. + model = _build_meta_model(cls._hf_auto_class, pretrained_model_name_or_path, kwargs) + else: + model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {}) - return cls( + instance = cls( model, kv_offload=kv_offload, continuous_batching=continuous_batching, @@ -2893,6 +3117,12 @@ def from_pretrained( qaic_config=qaic_config, **kwargs, ) + # Mark the wrapper so its compile() can default ``layerwise=True`` if + # the user forgot to pass it (the meta model cannot be exported any + # other way) and so the driver knows weights still need to be loaded. + if layerwise: + instance._layerwise_outer_meta = True + return instance MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = { @@ -3072,6 +3302,7 @@ def from_pretrained( continuous_batching: bool = False, qaic_config: Optional[dict] = None, max_seq_len_cached: Optional[int] = None, + layerwise: bool = False, *args, **kwargs, ): @@ -3127,15 +3358,28 @@ def from_pretrained( if kwargs.get("attn_implementation", None) not in {None, "eager"}: logger.warning('Updating attn_implementation="eager"') - if kwargs.get("low_cpu_mem_usage", None): - logger.warning("Updating low_cpu_mem_usage=False") - kv_offload = kwargs.pop("kv_offload", None) - kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + # Respect explicit low_cpu_mem_usage=True (used by the layer-wise driver + # to materialize only the active window's weights via meta init + + # sharded materialization). Default behavior remains False. + explicit_low_cpu = kwargs.get("low_cpu_mem_usage", None) is True + kwargs.update( + { + "attn_implementation": "eager", + "low_cpu_mem_usage": True if explicit_low_cpu else False, + } + ) _resolve_torch_dtype(kwargs) - model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + if layerwise: + # Layer-wise mode: build the outer model on the meta device. The + # caller still gets a typed wrapper, but no checkpoint weights are + # pulled into RAM. compile()/export() rebuilds a real per-window + # model internally via the layer-wise driver. + model = _build_meta_model(cls._hf_auto_class, pretrained_model_name_or_path, kwargs) + else: + model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) if qaic_config is not None: qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path @@ -3150,7 +3394,7 @@ def from_pretrained( continuous_batching=continuous_batching, **kwargs, ) - return cls( + instance = cls( model, continuous_batching=continuous_batching, qaic_config=qaic_config, @@ -3158,6 +3402,9 @@ def from_pretrained( max_seq_len_cached=max_seq_len_cached, **kwargs, ) + if layerwise: + instance._layerwise_outer_meta = True + return instance @property def get_model_config(self) -> dict: @@ -3234,6 +3481,39 @@ def get_seq_len_and_handle_specialized_prefill_model( else constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN ) + def _run_layerwise(self, *, final_compile: bool, layerwise_window_size: int, **forward_kwargs): + """Drive the layer-wise export/compile loop for CausalLM models.""" + from QEfficient.transformers.models import _layerwise + + model_id = getattr(self.model, "pretrained_path", None) + if model_id is None: + raise RuntimeError( + "layerwise=True requires the QEff model to be built via " + "QEFFAutoModelForCausalLM.from_pretrained(...). " + "Direct __init__ does not preserve the model id needed for per-window reload." + ) + config = getattr(self.model, "config", None) + torch_dtype = getattr(config, "torch_dtype", None) + + def _factory(model_id, config): + return QEFFAutoModelForCausalLM.from_pretrained( + model_id, + config=config, + torch_dtype=torch_dtype, + continuous_batching=self.continuous_batching, + low_cpu_mem_usage=True, + ) + + return _layerwise.run_layerwise( + model_id=model_id, + config=config, + qeff_factory=_factory, + compile_kwargs=forward_kwargs, + probe_qeff_model=self, + window_size=layerwise_window_size, + final_compile=final_compile, + ) + def export( self, export_dir: Optional[str] = None, @@ -3241,6 +3521,8 @@ def export( prefill_seq_len: Optional[int] = None, num_cores: int = constants.DEFAULT_AIC_NUM_CORES, moe_prefill_packed_chunk_size: int = constants.MOE_PREFILL_PACKED_CHUNK_SIZE, + layerwise: bool = False, + layerwise_window_size: int = 1, **kwargs, ) -> str: """ @@ -3268,6 +3550,18 @@ def export( "Use the default non-prefill export path for standard CausalLM decode graphs." ) + if layerwise: + return self._run_layerwise( + final_compile=False, + layerwise_window_size=layerwise_window_size, + export_dir=export_dir, + prefill_only=prefill_only, + prefill_seq_len=prefill_seq_len, + num_cores=num_cores, + moe_prefill_packed_chunk_size=moe_prefill_packed_chunk_size, + **kwargs, + ) + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN @@ -3506,7 +3800,7 @@ def _legacyify_cache(obj): self.model.forward = _qeff_patched_forward self.model._qeff_export_gemma3_cache_patch = True - if os.environ.get("LAYERWISE_EXPORT", "False") == "True": + if QEFFBaseModel._layerwise_active: return self._export_layerwise( example_inputs, output_names=output_names, @@ -3678,6 +3972,8 @@ def compile( enable_chunking: Optional[bool] = False, moe_prefill_packed_chunk_size: int = constants.MOE_PREFILL_PACKED_CHUNK_SIZE, retain_full_kv: Optional[bool] = None, + layerwise: bool = False, + layerwise_window_size: int = 1, **compiler_options, ) -> str: """ @@ -3766,6 +4062,32 @@ def compile( If `prefill_seq_len` is less than `num_speculative_tokens + 1` for TLM models. """ + if layerwise: + return self._run_layerwise( + final_compile=True, + layerwise_window_size=layerwise_window_size, + onnx_path=onnx_path, + compile_dir=compile_dir, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + batch_size=batch_size, + full_batch_size=full_batch_size, + kv_cache_batch_size=kv_cache_batch_size, + num_devices=num_devices, + num_cores=num_cores, + mxfp6_matmul=mxfp6_matmul, + mxint8_kv_cache=mxint8_kv_cache, + num_speculative_tokens=num_speculative_tokens, + prefill_only=prefill_only, + use_onnx_subfunctions=use_onnx_subfunctions, + offload_pt_weights=offload_pt_weights, + enable_chunking=enable_chunking, + moe_prefill_packed_chunk_size=moe_prefill_packed_chunk_size, + retain_full_kv=retain_full_kv, + **compiler_options, + ) if self.model.qaic_config is not None and self.model.qaic_config.get("mla_absorption", None) is not None: mla_absorption = self.model.qaic_config["mla_absorption"] cache_compressed = mla_absorption.get("cache_compressed", False) @@ -3961,43 +4283,22 @@ def compile( for suffix in ["", "_RetainedState"]: for i in range(self.num_layers): for kv in ["key", "value"]: - custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype + name = _compile_io_name( + f"past_{kv}.{i}{suffix}", + use_onnx_subfunctions=use_onnx_subfunctions, + ) + custom_io[name] = kv_cache_dtype else: for suffix in ["", "_RetainedState"]: for i in range(self.num_layers): - custom_io[f"compressed_kv.{i}{suffix}"] = kv_cache_dtype - custom_io[f"k_pe.{i}{suffix}"] = kv_cache_dtype - - def filter_custom_io(custom_io_lang, onnx_path): - # Extract filename - filename = os.path.basename(onnx_path) - - # Extract range from "merged_0-2.onnx" - match = re.search(r"merged_(\d+)-(\d+)\.onnx", filename) - if not match: - return custom_io_lang # no filtering if pattern not found - - start, end = map(int, match.groups()) # e.g. 0, 2 - - filtered = {} - - for k, v in custom_io_lang.items(): - # Keep everything that is NOT KV cache - if ("past_key." not in k) and ("past_value." not in k): - filtered[k] = v - continue - - # Extract layer index - layer_match = re.search(r"past_(?:key|value)\.(\d+)", k) - if layer_match: - idx = int(layer_match.group(1)) - if start <= idx < end: - filtered[k] = v - - return filtered + for prefix in ("compressed_kv", "k_pe"): + name = _compile_io_name( + f"{prefix}.{i}{suffix}", + use_onnx_subfunctions=use_onnx_subfunctions, + ) + custom_io[name] = kv_cache_dtype - if onnx_path is not None and "merged" in str(onnx_path): - custom_io = filter_custom_io(custom_io, onnx_path) + custom_io = _filter_custom_io_for_onnx(custom_io, onnx_path) qpc_path = self._compile( onnx_path=onnx_path, @@ -4354,15 +4655,11 @@ def compile( custom_io["input_features"] = kv_cache_dtype - # Slice output_names to get input names - for output_name in output_names: - if output_name.endswith("_RetainedState"): - custom_io[output_name[: -len("_RetainedState")]] = kv_cache_dtype - - # Get output names for output_name in output_names: if output_name.endswith("_RetainedState"): - custom_io[output_name] = kv_cache_dtype + compiler_output_name = _compile_io_name(output_name, use_onnx_subfunctions=use_onnx_subfunctions) + custom_io[_state_input_name(compiler_output_name)] = kv_cache_dtype + custom_io[compiler_output_name] = kv_cache_dtype return self._compile( onnx_path=onnx_path, diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index 24ab88aa0..dc1dcbc0c 100755 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -706,9 +706,16 @@ def __repr__(self): def dump_qconfig(func): def wrapper(self, *args, **kwargs): result = func(self, *args, **kwargs) + # Skip qconfig dumping when no QPC was actually produced (e.g. the + # layer-wise export short-circuits compile to return an ONNX path). + # Without this guard we'd hit a TypeError inside create_and_dump_qconfigs + # and surface a confusing user-facing message. + qpc_path = getattr(self, "qpc_path", None) + if qpc_path is None: + return result try: create_and_dump_qconfigs( - self.qpc_path, + qpc_path, self.onnx_path, self.get_model_config, [cls.__name__ for cls in self._pytorch_transforms], @@ -724,7 +731,7 @@ def wrapper(self, *args, **kwargs): }, ) except Exception as e: - print(f"An unexpected error occurred while dumping the qconfig: {e}") + logger.debug("Skipping qconfig dump: %s", e) return result return wrapper diff --git a/QEfficient/utils/export_utils.py b/QEfficient/utils/export_utils.py index 901484e72..6ed7f97cc 100755 --- a/QEfficient/utils/export_utils.py +++ b/QEfficient/utils/export_utils.py @@ -40,6 +40,7 @@ def export_wrapper(func): """ def wrapper(self, *args, **kwargs): + cache_probe = kwargs.pop("_layerwise_cache_probe", False) # 1. Setup ONNX subfunctions if requested if use_onnx_subfunctions := kwargs.pop("use_onnx_subfunctions", False): args, kwargs = _setup_onnx_subfunctions(self, args, kwargs) @@ -52,12 +53,15 @@ def wrapper(self, *args, **kwargs): export_dir = export_dir.with_name(export_dir.name + "-" + export_hash) kwargs["export_dir"] = export_dir self.export_hash = export_hash + if cache_probe: + kwargs["_layerwise_cache_probe"] = True # 4. Execute the actual export onnx_path = func(self, *args, **kwargs) # 5. Save export metadata - _save_export_metadata(export_dir, filtered_hash_params) + if not cache_probe: + _save_export_metadata(export_dir, filtered_hash_params) # 6. Always cleanup subfunctions if they were setup if use_onnx_subfunctions: @@ -108,6 +112,10 @@ def _generate_export_hash(qeff_model, args, kwargs, func): bound_args = new_sig.bind(*args, **kwargs) bound_args.apply_defaults() all_args = bound_args.arguments + if func.__name__ == "_export_layerwise": + export_kwargs = dict(all_args.get("export_kwargs") or {}) + export_kwargs["_qeff_layerwise_export"] = True + all_args["export_kwargs"] = export_kwargs # Use the model's current configuration for hashing to ensure any post-load modifications are captured # TODO: Replace with get_model_config property of modeling classes and remove the if-else diff --git a/examples/disagg_serving/qwen3moe_layerwise.py b/examples/disagg_serving/qwen3moe_layerwise.py index ea29e1174..3c9339b30 100644 --- a/examples/disagg_serving/qwen3moe_layerwise.py +++ b/examples/disagg_serving/qwen3moe_layerwise.py @@ -5,210 +5,31 @@ # # ----------------------------------------------------------------------------- -import functools -import os -import time -from pathlib import Path +"""Layerwise prefill compile example for Qwen3-MoE (disaggregated serving). + +The orchestration loop that previously lived in this script has been moved +behind the ``layerwise=True`` flag on ``.compile()`` / ``.export()``. + +Note: ``layerwise=True`` is a provisional API and is scheduled for deprecation +once first-class multi-window export lands. Supported model types: +``qwen3_vl_moe``, ``qwen3_5_moe``, ``qwen3_moe``. +""" -import transformers from transformers import AutoConfig, AutoTokenizer -import QEfficient from QEfficient import QEFFAutoModelForCausalLM model_id = "Qwen/Qwen3-235B-A22B-Instruct-2507" # weights are not required to convert to fp32 -# model_id = "yujiepan/qwen3-moe-tiny-random" -prompt = """ -Explain quantum computing in simple terms. -""" -config = AutoConfig.from_pretrained(model_id) -tokenizer = AutoTokenizer.from_pretrained(model_id) -config = AutoConfig.from_pretrained(model_id) - -tokenizer = AutoTokenizer.from_pretrained(model_id) PREFILL_SEQ_LEN = 4 CTX_LEN = 128 +config = AutoConfig.from_pretrained(model_id) +tokenizer = AutoTokenizer.from_pretrained(model_id) -def _ensure_pretrained_window_attrs(): - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_start"): - transformers.modeling_utils.PreTrainedModel._start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_end"): - transformers.modeling_utils.PreTrainedModel._end = 0 - - -def _build_layer_windows(total_layers: int, window_size: int): - if total_layers <= 0: - raise ValueError(f"Invalid total_layers={total_layers}. Expected: total_layers > 0.") - if window_size <= 0: - raise ValueError(f"Invalid window_size={window_size}. Expected: window_size > 0.") - - windows = [] - end = total_layers - while end > 0: - start = max(0, end - window_size) - windows.append((start, end)) - end = start - - return windows - - -def _null_outside_window_layers(model): - start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0)) - end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0)) - layers = getattr(getattr(model, "model", None), "layers", None) - if layers is None: - return - for idx, _ in enumerate(layers): - if idx < start or idx >= end: - layers[idx] = None - - -def _install_window_patch(model_cls): - if getattr(model_cls, "_window_patch_installed", False): - return - - original_init = model_cls.__init__ - - @functools.wraps(original_init) - def patched_init(self, *args, **kwargs): - original_init(self, *args, **kwargs) - _null_outside_window_layers(self) - - model_cls.__init__ = patched_init - model_cls._window_patch_installed = True - - -def _resolve_export_root(onnx_path: Path) -> Path: - parts = list(onnx_path.parts) - if "onnx_layerwise_tmp" in parts: - marker_idx = parts.index("onnx_layerwise_tmp") - return Path(*parts[:marker_idx]) - return onnx_path.parent - - -def _install_shard_window_patch(): - if getattr(transformers.modeling_utils, "_window_shard_patch_installed", False): - return - - original_get_checkpoint_shard_files = transformers.modeling_utils.get_checkpoint_shard_files - - @functools.wraps(original_get_checkpoint_shard_files) - def patched_get_checkpoint_shard_files(*args, **kwargs): - shard_files, metadata = original_get_checkpoint_shard_files(*args, **kwargs) - weight_map = metadata.get("weight_map") - if not weight_map: - return shard_files, metadata - - start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0)) - end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0)) - if end <= start: - return shard_files, metadata - - selected_prefixes = tuple(f"model.layers.{layer_idx}." for layer_idx in range(start, end)) - filtered_weight_map = {} - for checkpoint_key, shard_name in weight_map.items(): - if checkpoint_key.startswith("model.layers."): - if checkpoint_key.startswith(selected_prefixes): - filtered_weight_map[checkpoint_key] = shard_name - continue - filtered_weight_map[checkpoint_key] = shard_name - - if not filtered_weight_map: - return shard_files, metadata - - shard_name_to_path = {path.split("/")[-1]: path for path in shard_files} - filtered_shard_names = sorted(set(filtered_weight_map.values())) - filtered_shard_files = [shard_name_to_path[name] for name in filtered_shard_names if name in shard_name_to_path] - if not filtered_shard_files: - return shard_files, metadata - - metadata["weight_map"] = filtered_weight_map - metadata["all_checkpoint_keys"] = list(filtered_weight_map.keys()) - return filtered_shard_files, metadata - - transformers.modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files - transformers.modeling_utils._window_shard_patch_installed = True - - -_ensure_pretrained_window_attrs() -_install_shard_window_patch() -text_config = getattr(config, "text_config", config) -resolved_total_layers = getattr(text_config, "num_hidden_layers", None) -if resolved_total_layers is None: - raise ValueError("Could not resolve `num_hidden_layers` from config.") - -# Layerwise window size. `1` keeps only one decoder layer active per window. -window_size = 1 -total_layers = 2 # resolved_total_layers # config.num_hidden_layers = 1 -windows = _build_layer_windows(total_layers=total_layers, window_size=window_size) -qeff_model = None -first_onnx_path = None -export_start = time.perf_counter() - -os.environ["LAYERWISE_EXPORT"] = "True" -for start, end in windows: - transformers.modeling_utils.PreTrainedModel._start = start - transformers.modeling_utils.PreTrainedModel._end = end - transformers.modeling_utils.PreTrainedModel._total_layers = total_layers - QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe.QEffQwen3MoeModel._start = start - QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe.QEffQwen3MoeModel._end = end - QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe.QEffQwen3MoeModel._total_layers = total_layers - QEfficient.base.modeling_qeff.QEFFBaseModel._start = start - QEfficient.base.modeling_qeff.QEFFBaseModel._end = end - QEfficient.base.modeling_qeff.QEFFBaseModel._total_layers = total_layers - _install_window_patch(transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeForCausalLM) - qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, config=config) - if hasattr(qeff_model, "model"): - _null_outside_window_layers(qeff_model.model) - - # Following command errors out by default, the user is supposed to run the printed command and provide the generated qpc path as prefill_qpc_path commenting out lines 55-68 - - # prefill_qpc_path = "" - ################################# prefill - - onnx_path = qeff_model.compile( - prefill_seq_len=PREFILL_SEQ_LEN, - ctx_len=CTX_LEN, - num_cores=16, - mxfp6_matmul=True, - mxint8_kv_cache=True, - num_devices=1, - split_retained_state_io=True, - mos=1, - aic_enable_depth_first=True, - num_speculative_tokens=None, - prefill_only=True, - enable_chunking=True, - use_onnx_subfunctions=True, - ) - ################################# decode - # onnx_path = qeff_model.compile( - # prefill_seq_len=PREFILL_SEQ_LEN, - # ctx_len=CTX_LEN, - # num_cores=16, - # mxfp6_matmul=True, - # mxint8_kv_cache=True, - # num_devices=1, - # split_retained_state_io=True, - # mos=1, - # aic_enable_depth_first=True, - # num_speculative_tokens=None, - # prefill_only=False, - # use_onnx_subfunctions=True, - # ) - if first_onnx_path is None: - first_onnx_path = Path(onnx_path) +qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, config=config, layerwise=True) -if first_onnx_path is None: - raise RuntimeError("No ONNX path produced during compilation.") -export_root = _resolve_export_root(first_onnx_path) -final_onnx_path = QEfficient.utils.layerwise_pipeline(str(export_root)) -print(f"Layer-wise language export completed. Final artifact/root: {final_onnx_path}") -os.environ["LAYERWISE_EXPORT"] = "False" qpc_path = qeff_model.compile( - onnx_path=final_onnx_path, prefill_seq_len=PREFILL_SEQ_LEN, ctx_len=CTX_LEN, num_cores=16, @@ -222,97 +43,8 @@ def patched_get_checkpoint_shard_files(*args, **kwargs): prefill_only=True, enable_chunking=True, use_onnx_subfunctions=True, + layerwise=True, + layerwise_window_size=1, ) print(f"QPC path: {qpc_path}") - -# inputs = tokenizer(prompt, return_tensors="np", padding=True) -# position_ids = inputs["attention_mask"].sum(1, keepdims=True) -# generation_len = CTX_LEN - position_ids.max() -# padded_len = inputs["input_ids"].shape[1] -# num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float -# padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len -# inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) -# inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) -# inputs.pop("token_type_ids", None) -# inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} -# inputs.pop("past_key_values", None) -# inputs = {k: v.detach().numpy() for k, v in inputs.items()} - - -# prefill_session = QAICInferenceSession(prefill_qpc_path) - - -# all_outputs = [] -# for i in range(num_chunks): -# chunk_inputs = inputs.copy() -# chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] -# chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] -# ins = time.time() -# qpc_out = prefill_session.run(chunk_inputs) -# print(f"time for this run={time.time() - ins}") -# for i in range(config.num_hidden_layers): -# inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] -# inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] - -# all_outputs.append(np.argmax(qpc_out["logits"])) -# print(all_outputs) -# print(">>>>>>>> export for prefill is done <<<<<<<<<<<") -# ########################### - -# decode_qpc_path = qeff_model.compile( -# prefill_seq_len=1, -# ctx_len=CTX_LEN, -# num_cores=16, -# mxfp6_matmul=True, -# mxint8_kv_cache=True, -# num_devices=1, -# mos=1, -# aic_enable_depth_first=True, -# num_speculative_tokens=None, -# offload_pt_weights=False, # Need the weights in memory for prefill-model export/compilation in the next step -# retain_full_kv=True, -# ) -# decode_session = QAICInferenceSession(decode_qpc_path) - -# decode_inputs = { -# "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1), -# "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, -# } -# for i in range(config.num_hidden_layers): -# decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] -# decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] - -# st = time.time() -# decode_out = decode_session.run(decode_inputs) -# print(f"time for first run of decode with KV as input = {time.time() - st} sec\n") -# all_outputs.append(np.argmax(decode_out["logits"])) -# pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1 -# loop_decode_inputs = { -# "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), -# "position_ids": pos_id, -# } - -# for i in range(config.num_hidden_layers): -# loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] -# loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] - -# st = time.time() -# for i in range(generation_len - 2): -# decode_out = decode_session.run(loop_decode_inputs) -# all_outputs.append(np.argmax(decode_out["logits"])) -# pos_id += 1 -# for i in range(config.num_hidden_layers): -# loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] -# loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] - -# loop_decode_inputs.update( -# { -# "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), -# "position_ids": pos_id, -# } -# ) -# ft = time.time() - -# print(f"decode tok/sec={(generation_len - 2) / (ft - st)}") -# print(f"input\n{prompt}\noutput\n{tokenizer.decode(all_outputs)}") diff --git a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise.py b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise.py index cdf25ee36..e1924d837 100644 --- a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise.py +++ b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise.py @@ -5,319 +5,57 @@ # # ----------------------------------------------------------------------------- -import functools -import os -from pathlib import Path +"""Layerwise prefill compile example for Qwen3.5-MoE. + +The orchestration loop that previously lived in this script has been moved +behind the ``layerwise=True`` flag on ``.compile()`` / ``.export()``. + +Note: ``layerwise=True`` is a provisional API and is scheduled for deprecation +once first-class multi-window export lands. Supported model types: +``qwen3_vl_moe``, ``qwen3_5_moe``, ``qwen3_moe``. +""" import torch -import transformers from transformers import AutoConfig -import QEfficient from QEfficient import QEFFAutoModelForImageTextToText MODEL_ID = "Qwen/Qwen3.5-397B-A17B" -PREFILL_SEQ_LEN = 32 -CTX_LEN = 4096 -TEXT_WINDOW_SIZE = 1 - -# For quick local validation only (keep disabled for real export) -# TEST_TEXT_LAYERS = 4 - -# Export controls -BATCH_SIZE = 1 -NUM_CORES = 16 -NUM_DEVICES = 1 -HEIGHT = 354 -WIDTH = 536 - - -def _ensure_pretrained_window_attrs(): - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_start"): - transformers.modeling_utils.PreTrainedModel._start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_end"): - transformers.modeling_utils.PreTrainedModel._end = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_total_layers"): - transformers.modeling_utils.PreTrainedModel._total_layers = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_start"): - transformers.modeling_utils.PreTrainedModel._text_start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_end"): - transformers.modeling_utils.PreTrainedModel._text_end = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_total_layers"): - transformers.modeling_utils.PreTrainedModel._text_total_layers = 0 - - -def _build_layer_windows(total_layers: int, window_size: int): - if total_layers <= 0: - raise ValueError(f"Invalid total_layers={total_layers}. Expected: total_layers > 0.") - if window_size <= 0: - raise ValueError(f"Invalid window_size={window_size}. Expected: window_size > 0.") - - windows = [] - end = total_layers - while end > 0: - start = max(0, end - window_size) - windows.append((start, end)) - end = start - return windows - - -def _get_text_layers_container(model): - # VLM path first - if ( - hasattr(model, "model") - and hasattr(model.model, "language_model") - and hasattr(model.model.language_model, "layers") - ): - return model.model.language_model.layers - # LLM-compatible fallbacks - if hasattr(model, "model") and hasattr(model.model, "layers"): - return model.model.layers - if hasattr(model, "language_model") and hasattr(model.language_model, "layers"): - return model.language_model.layers - if hasattr(model, "layers"): - return model.layers - return None - - -def _null_outside_window_layers(model, apply_text: bool = True): - if apply_text: - text_start = int( - getattr( - transformers.modeling_utils.PreTrainedModel, - "_text_start", - getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0), - ) - ) - text_end = int( - getattr( - transformers.modeling_utils.PreTrainedModel, - "_text_end", - getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0), - ) - ) - text_layers = _get_text_layers_container(model) - if text_layers is not None and text_end > text_start: - for idx, _ in enumerate(text_layers): - if idx < text_start or idx >= text_end: - text_layers[idx] = None - - -def _install_window_patch(model_cls): - if getattr(model_cls, "_window_patch_installed", False): - return - - original_init = model_cls.__init__ - - @functools.wraps(original_init) - def patched_init(self, *args, **kwargs): - original_init(self, *args, **kwargs) - _null_outside_window_layers(self, apply_text=True) - - model_cls.__init__ = patched_init - model_cls._window_patch_installed = True - - -def _resolve_export_root(onnx_path: Path) -> Path: - parts = list(onnx_path.parts) - if "onnx_layerwise_tmp" in parts: - marker_idx = parts.index("onnx_layerwise_tmp") - return Path(*parts[:marker_idx]) - return onnx_path.parent - - -def _install_shard_window_patch(): - if getattr(transformers.modeling_utils, "_window_shard_patch_installed", False): - return - - original_get_checkpoint_shard_files = transformers.modeling_utils.get_checkpoint_shard_files - - @functools.wraps(original_get_checkpoint_shard_files) - def patched_get_checkpoint_shard_files(*args, **kwargs): - shard_files, metadata = original_get_checkpoint_shard_files(*args, **kwargs) - weight_map = metadata.get("weight_map") - if not weight_map: - return shard_files, metadata - - start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0)) - end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0)) - text_start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_text_start", start)) - text_end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_text_end", end)) - has_text_window = text_end > text_start - if not has_text_window: - return shard_files, metadata - - selected_text_prefixes = tuple( - [f"model.layers.{layer_idx}." for layer_idx in range(text_start, text_end)] - + [f"model.language_model.layers.{layer_idx}." for layer_idx in range(text_start, text_end)] - ) - filtered_weight_map = {} - for checkpoint_key, shard_name in weight_map.items(): - if checkpoint_key.startswith("model.layers.") or checkpoint_key.startswith("model.language_model.layers."): - if not has_text_window or checkpoint_key.startswith(selected_text_prefixes): - filtered_weight_map[checkpoint_key] = shard_name - continue - filtered_weight_map[checkpoint_key] = shard_name - - if not filtered_weight_map: - return shard_files, metadata - - shard_name_to_path = {path.split("/")[-1]: path for path in shard_files} - filtered_shard_names = sorted(set(filtered_weight_map.values())) - filtered_shard_files = [shard_name_to_path[name] for name in filtered_shard_names if name in shard_name_to_path] - if not filtered_shard_files: - return shard_files, metadata - - metadata["weight_map"] = filtered_weight_map - metadata["all_checkpoint_keys"] = list(filtered_weight_map.keys()) - return filtered_shard_files, metadata - transformers.modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files - transformers.modeling_utils._window_shard_patch_installed = True - - -def _set_layer_windows( - text_start: int, - text_end: int, - text_total_layers: int, -): - transformers.modeling_utils.PreTrainedModel._start = text_start - transformers.modeling_utils.PreTrainedModel._end = text_end - transformers.modeling_utils.PreTrainedModel._total_layers = text_total_layers - transformers.modeling_utils.PreTrainedModel._text_start = text_start - transformers.modeling_utils.PreTrainedModel._text_end = text_end - transformers.modeling_utils.PreTrainedModel._text_total_layers = text_total_layers - - qeff_mod = QEfficient.transformers.models.qwen3_5_moe.modeling_qwen3_5_moe - qeff_mod.QEffQwen3_5MoeTextModel._start = text_start - qeff_mod.QEffQwen3_5MoeTextModel._end = text_end - qeff_mod.QEffQwen3_5MoeTextModel._total_layers = text_total_layers - - QEfficient.base.modeling_qeff.QEFFBaseModel._start = text_start - QEfficient.base.modeling_qeff.QEFFBaseModel._end = text_end - QEfficient.base.modeling_qeff.QEFFBaseModel._total_layers = text_total_layers - - -def _stitch_layerwise_if_available(export_root: Path): - # Some branches expose this helper; fall back gracefully when unavailable. - pipeline_fn = getattr(QEfficient.utils, "layerwise_pipeline", None) - if callable(pipeline_fn): - return pipeline_fn(str(export_root)) - print(f"layerwise_pipeline() not found. Layer-wise ONNX shards kept under: {export_root / 'onnx_layerwise_tmp'}") - return str(export_root / "onnx_layerwise_tmp") +def main(): + config = AutoConfig.from_pretrained(MODEL_ID) + config.torch_dtype = "float32" -def _new_qeff_model(model_id: str, config): - return QEFFAutoModelForImageTextToText.from_pretrained( - model_id, + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + MODEL_ID, attn_implementation="eager", kv_offload=True, config=config, torch_dtype=torch.float32, + layerwise=True, ) - -def main(): - config = AutoConfig.from_pretrained(MODEL_ID) - config.torch_dtype = "float32" - - # if TEST_TEXT_LAYERS: - # config.text_config.num_hidden_layers = TEST_TEXT_LAYERS - - text_config = getattr(config, "text_config", config) - # config.vision_config.depth = 3 - text_total_layers = getattr(text_config, "num_hidden_layers", None) - if text_total_layers is None: - raise ValueError("Could not resolve `num_hidden_layers` from config.text_config.") - config.text_config.num_hidden_layers = text_total_layers - _ensure_pretrained_window_attrs() - _install_shard_window_patch() - - hf_qwen_mod = transformers.models.qwen3_5_moe.modeling_qwen3_5_moe - _install_window_patch(hf_qwen_mod.Qwen3_5MoeForConditionalGeneration) - _install_window_patch(hf_qwen_mod.Qwen3_5MoeForCausalLM) - - text_windows = _build_layer_windows(total_layers=text_total_layers, window_size=TEXT_WINDOW_SIZE) - # Keep layerwise only on text path in this loop. - num_windows = len(text_windows) - first_onnx_path = None - os.environ["LAYERWISE_EXPORT"] = "True" - for window_idx in range(num_windows): - text_start, text_end = text_windows[window_idx] if window_idx < len(text_windows) else (0, 0) - skip_lang_for_window = window_idx >= len(text_windows) - - _set_layer_windows( - text_start=text_start, - text_end=text_end, - text_total_layers=text_total_layers, - ) - print( - f"Exporting window {window_idx + 1}/{num_windows} " - f"text=[{text_start},{text_end})/{text_total_layers} " - f"skip_lang={skip_lang_for_window}" - ) - - qeff_model = _new_qeff_model(MODEL_ID, config) - if hasattr(qeff_model, "model"): - _null_outside_window_layers( - qeff_model.model, - apply_text=not skip_lang_for_window, - ) - - onnx_path = qeff_model.compile( - batch_size=BATCH_SIZE, - prefill_seq_len=PREFILL_SEQ_LEN, - ctx_len=CTX_LEN, - num_cores=NUM_CORES, - num_devices=NUM_DEVICES, - height=HEIGHT, - width=WIDTH, - mxfp6_matmul=False, - aic_enable_depth_first=True, - skip_vision=True, - skip_lang=skip_lang_for_window, - prefill_only=True, - use_onnx_subfunctions=True, - enable_chunking=True, - mos=1, - user_tiled=True, - ) - - if first_onnx_path is None: - first_onnx_path = Path(str(onnx_path["lang_prefill_qpc_path"])) - - if first_onnx_path is None: - raise RuntimeError("No ONNX path produced during layer-wise language export.") - - export_root = _resolve_export_root(first_onnx_path) - final_artifact = _stitch_layerwise_if_available(export_root) - print(f"Layer-wise language export completed. Final artifact/root: {final_artifact}") - - os.environ["LAYERWISE_EXPORT"] = "False" qpc_path = qeff_model.compile( - lang_onnx_path=final_artifact, - batch_size=BATCH_SIZE, - prefill_seq_len=PREFILL_SEQ_LEN, - ctx_len=CTX_LEN, - num_cores=NUM_CORES, - num_devices=NUM_DEVICES, - height=HEIGHT, - width=WIDTH, - mxfp6_matmul=False, + batch_size=1, + prefill_seq_len=32, + ctx_len=4096, + num_cores=16, + num_devices=1, + height=354, + width=536, + mxfp6_matmul=True, aic_enable_depth_first=True, skip_vision=True, - skip_lang=skip_lang_for_window, - prefill_only=True, + split_retained_state_io=True, use_onnx_subfunctions=True, - enable_chunking=True, + prefill_only=True, mos=1, + layerwise=True, + layerwise_window_size=1, ) - print(f"Final QPC path: {qpc_path}") if __name__ == "__main__": main() - - -# /opt/qti-aic/exec/qaic-compile -aic-hw -aic-hw-version=ai100 -m=/home/abhishek/.cache/qeff_models/Qwen3_5MoeForConditionalGeneration/Qwen3_5MoeDecoderWrapper-61a4400d63d1b0bb/final_data/merged_0-2.onnx -retained-state -convert-to-fp16 -aic-num-cores=16 -aic-enable-depth-first -mos=1 -network-specialization-config=/home/abhishek/.cache/qeff_models/Qwen3_5MoeForConditionalGeneration/Qwen3_5MoeDecoderWrapper-61a4400d63d1b0bb/final_data/specializations.json -custom-IO-list-file=/home/abhishek/.cache/qeff_models/Qwen3_5MoeForConditionalGeneration/Qwen3_5MoeDecoderWrapper-61a4400d63d1b0bb/final_data/qpc_binaries/custom_io.yaml -aic-binary-dir=/home/abhishek/.cache/qeff_models/Qwen3_5MoeForConditionalGeneration/Qwen3_5MoeDecoderWrapper-61a4400d63d1b0bb/final_data/qpc_binaries/qpc diff --git a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise_decode.py b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise_decode.py index a5b7475f7..1e402438d 100644 --- a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise_decode.py +++ b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise_decode.py @@ -5,317 +5,56 @@ # # ----------------------------------------------------------------------------- -import functools -import os -from pathlib import Path +"""Layerwise decode compile example for Qwen3.5-MoE. + +The orchestration loop that previously lived in this script has been moved +behind the ``layerwise=True`` flag on ``.compile()`` / ``.export()``. + +Note: ``layerwise=True`` is a provisional API and is scheduled for deprecation +once first-class multi-window export lands. Supported model types: +``qwen3_vl_moe``, ``qwen3_5_moe``, ``qwen3_moe``. +""" import torch -import transformers from transformers import AutoConfig -import QEfficient from QEfficient import QEFFAutoModelForImageTextToText MODEL_ID = "Qwen/Qwen3.5-397B-A17B" -PREFILL_SEQ_LEN = 1 -CTX_LEN = 4096 -TEXT_WINDOW_SIZE = 1 - -# For quick local validation only (keep disabled for real export) -# TEST_TEXT_LAYERS = 4 - -# Export controls -BATCH_SIZE = 1 -NUM_CORES = 16 -NUM_DEVICES = 1 -HEIGHT = 354 -WIDTH = 536 - - -def _ensure_pretrained_window_attrs(): - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_start"): - transformers.modeling_utils.PreTrainedModel._start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_end"): - transformers.modeling_utils.PreTrainedModel._end = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_total_layers"): - transformers.modeling_utils.PreTrainedModel._total_layers = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_start"): - transformers.modeling_utils.PreTrainedModel._text_start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_end"): - transformers.modeling_utils.PreTrainedModel._text_end = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_total_layers"): - transformers.modeling_utils.PreTrainedModel._text_total_layers = 0 - - -def _build_layer_windows(total_layers: int, window_size: int): - if total_layers <= 0: - raise ValueError(f"Invalid total_layers={total_layers}. Expected: total_layers > 0.") - if window_size <= 0: - raise ValueError(f"Invalid window_size={window_size}. Expected: window_size > 0.") - - windows = [] - end = total_layers - while end > 0: - start = max(0, end - window_size) - windows.append((start, end)) - end = start - return windows - - -def _get_text_layers_container(model): - # VLM path first - if ( - hasattr(model, "model") - and hasattr(model.model, "language_model") - and hasattr(model.model.language_model, "layers") - ): - return model.model.language_model.layers - # LLM-compatible fallbacks - if hasattr(model, "model") and hasattr(model.model, "layers"): - return model.model.layers - if hasattr(model, "language_model") and hasattr(model.language_model, "layers"): - return model.language_model.layers - if hasattr(model, "layers"): - return model.layers - return None - - -def _null_outside_window_layers(model, apply_text: bool = True): - if apply_text: - text_start = int( - getattr( - transformers.modeling_utils.PreTrainedModel, - "_text_start", - getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0), - ) - ) - text_end = int( - getattr( - transformers.modeling_utils.PreTrainedModel, - "_text_end", - getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0), - ) - ) - text_layers = _get_text_layers_container(model) - if text_layers is not None and text_end > text_start: - for idx, _ in enumerate(text_layers): - if idx < text_start or idx >= text_end: - text_layers[idx] = None - - -def _install_window_patch(model_cls): - if getattr(model_cls, "_window_patch_installed", False): - return - - original_init = model_cls.__init__ - - @functools.wraps(original_init) - def patched_init(self, *args, **kwargs): - original_init(self, *args, **kwargs) - _null_outside_window_layers(self, apply_text=True) - - model_cls.__init__ = patched_init - model_cls._window_patch_installed = True - - -def _resolve_export_root(onnx_path: Path) -> Path: - parts = list(onnx_path.parts) - if "onnx_layerwise_tmp" in parts: - marker_idx = parts.index("onnx_layerwise_tmp") - return Path(*parts[:marker_idx]) - return onnx_path.parent - - -def _install_shard_window_patch(): - if getattr(transformers.modeling_utils, "_window_shard_patch_installed", False): - return - - original_get_checkpoint_shard_files = transformers.modeling_utils.get_checkpoint_shard_files - - @functools.wraps(original_get_checkpoint_shard_files) - def patched_get_checkpoint_shard_files(*args, **kwargs): - shard_files, metadata = original_get_checkpoint_shard_files(*args, **kwargs) - weight_map = metadata.get("weight_map") - if not weight_map: - return shard_files, metadata - - start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0)) - end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0)) - text_start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_text_start", start)) - text_end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_text_end", end)) - has_text_window = text_end > text_start - if not has_text_window: - return shard_files, metadata - - selected_text_prefixes = tuple( - [f"model.layers.{layer_idx}." for layer_idx in range(text_start, text_end)] - + [f"model.language_model.layers.{layer_idx}." for layer_idx in range(text_start, text_end)] - ) - filtered_weight_map = {} - for checkpoint_key, shard_name in weight_map.items(): - if checkpoint_key.startswith("model.layers.") or checkpoint_key.startswith("model.language_model.layers."): - if not has_text_window or checkpoint_key.startswith(selected_text_prefixes): - filtered_weight_map[checkpoint_key] = shard_name - continue - filtered_weight_map[checkpoint_key] = shard_name - - if not filtered_weight_map: - return shard_files, metadata - - shard_name_to_path = {path.split("/")[-1]: path for path in shard_files} - filtered_shard_names = sorted(set(filtered_weight_map.values())) - filtered_shard_files = [shard_name_to_path[name] for name in filtered_shard_names if name in shard_name_to_path] - if not filtered_shard_files: - return shard_files, metadata - metadata["weight_map"] = filtered_weight_map - metadata["all_checkpoint_keys"] = list(filtered_weight_map.keys()) - return filtered_shard_files, metadata - - transformers.modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files - transformers.modeling_utils._window_shard_patch_installed = True - - -def _set_layer_windows( - text_start: int, - text_end: int, - text_total_layers: int, -): - transformers.modeling_utils.PreTrainedModel._start = text_start - transformers.modeling_utils.PreTrainedModel._end = text_end - transformers.modeling_utils.PreTrainedModel._total_layers = text_total_layers - transformers.modeling_utils.PreTrainedModel._text_start = text_start - transformers.modeling_utils.PreTrainedModel._text_end = text_end - transformers.modeling_utils.PreTrainedModel._text_total_layers = text_total_layers - - qeff_mod = QEfficient.transformers.models.qwen3_5_moe.modeling_qwen3_5_moe - qeff_mod.QEffQwen3_5MoeTextModel._start = text_start - qeff_mod.QEffQwen3_5MoeTextModel._end = text_end - qeff_mod.QEffQwen3_5MoeTextModel._total_layers = text_total_layers - - QEfficient.base.modeling_qeff.QEFFBaseModel._start = text_start - QEfficient.base.modeling_qeff.QEFFBaseModel._end = text_end - QEfficient.base.modeling_qeff.QEFFBaseModel._total_layers = text_total_layers - - -def _stitch_layerwise_if_available(export_root: Path): - # Some branches expose this helper; fall back gracefully when unavailable. - pipeline_fn = getattr(QEfficient.utils, "layerwise_pipeline", None) - if callable(pipeline_fn): - return pipeline_fn(str(export_root)) - print(f"layerwise_pipeline() not found. Layer-wise ONNX shards kept under: {export_root / 'onnx_layerwise_tmp'}") - return str(export_root / "onnx_layerwise_tmp") +def main(): + config = AutoConfig.from_pretrained(MODEL_ID) + config.torch_dtype = "float32" -def _new_qeff_model(model_id: str, config): - return QEFFAutoModelForImageTextToText.from_pretrained( - model_id, + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + MODEL_ID, attn_implementation="eager", kv_offload=True, config=config, torch_dtype=torch.float32, + layerwise=True, ) - -def main(): - config = AutoConfig.from_pretrained(MODEL_ID) - config.torch_dtype = "float32" - - # if TEST_TEXT_LAYERS: - # config.text_config.num_hidden_layers = TEST_TEXT_LAYERS - - text_config = getattr(config, "text_config", config) - # config.vision_config.depth = 3 - text_total_layers = getattr(text_config, "num_hidden_layers", None) - if text_total_layers is None: - raise ValueError("Could not resolve `num_hidden_layers` from config.text_config.") - config.text_config.num_hidden_layers = text_total_layers - _ensure_pretrained_window_attrs() - _install_shard_window_patch() - - hf_qwen_mod = transformers.models.qwen3_5_moe.modeling_qwen3_5_moe - _install_window_patch(hf_qwen_mod.Qwen3_5MoeForConditionalGeneration) - _install_window_patch(hf_qwen_mod.Qwen3_5MoeForCausalLM) - - text_windows = _build_layer_windows(total_layers=text_total_layers, window_size=TEXT_WINDOW_SIZE) - # Keep layerwise only on text path in this loop. - num_windows = len(text_windows) - first_onnx_path = None - os.environ["LAYERWISE_EXPORT"] = "True" - for window_idx in range(num_windows): - text_start, text_end = text_windows[window_idx] if window_idx < len(text_windows) else (0, 0) - skip_lang_for_window = window_idx >= len(text_windows) - - _set_layer_windows( - text_start=text_start, - text_end=text_end, - text_total_layers=text_total_layers, - ) - print( - f"Exporting window {window_idx + 1}/{num_windows} " - f"text=[{text_start},{text_end})/{text_total_layers} " - f"skip_lang={skip_lang_for_window}" - ) - - qeff_model = _new_qeff_model(MODEL_ID, config) - if hasattr(qeff_model, "model"): - _null_outside_window_layers( - qeff_model.model, - apply_text=not skip_lang_for_window, - ) - - onnx_path = qeff_model.compile( - batch_size=BATCH_SIZE, - prefill_seq_len=PREFILL_SEQ_LEN, - ctx_len=CTX_LEN, - num_cores=NUM_CORES, - num_devices=NUM_DEVICES, - height=HEIGHT, - width=WIDTH, - mxfp6_matmul=False, - aic_enable_depth_first=True, - skip_vision=True, - skip_lang=skip_lang_for_window, - use_onnx_subfunctions=True, - enable_chunking=True, - mos=1, - user_tiled=True, - ) - - if first_onnx_path is None: - first_onnx_path = Path(str(onnx_path["lang_decode_qpc_path"])) - - if first_onnx_path is None: - raise RuntimeError("No ONNX path produced during layer-wise language export.") - - export_root = _resolve_export_root(first_onnx_path) - final_artifact = _stitch_layerwise_if_available(export_root) - print(f"Layer-wise language export completed. Final artifact/root: {final_artifact}") - - os.environ["LAYERWISE_EXPORT"] = "False" qpc_path = qeff_model.compile( - lang_onnx_path=final_artifact, - batch_size=BATCH_SIZE, - prefill_seq_len=PREFILL_SEQ_LEN, - ctx_len=CTX_LEN, - num_cores=NUM_CORES, - num_devices=NUM_DEVICES, - height=HEIGHT, - width=WIDTH, - mxfp6_matmul=False, + batch_size=1, + prefill_seq_len=1, + ctx_len=4096, + num_cores=16, + num_devices=1, + height=354, + width=536, + mxfp6_matmul=True, aic_enable_depth_first=True, skip_vision=True, - skip_lang=skip_lang_for_window, + split_retained_state_io=True, use_onnx_subfunctions=True, - enable_chunking=True, mos=1, + layerwise=True, + layerwise_window_size=1, ) - print(f"Final QPC path: {qpc_path}") if __name__ == "__main__": main() - - -# /opt/qti-aic/exec/qaic-compile -aic-hw -aic-hw-version=ai100 -m=/home/abhishek/.cache/qeff_models/Qwen3_5MoeForConditionalGeneration/Qwen3_5MoeDecoderWrapper-61a4400d63d1b0bb/final_data/merged_0-2.onnx -retained-state -convert-to-fp16 -aic-num-cores=16 -aic-enable-depth-first -mos=1 -network-specialization-config=/home/abhishek/.cache/qeff_models/Qwen3_5MoeForConditionalGeneration/Qwen3_5MoeDecoderWrapper-61a4400d63d1b0bb/final_data/specializations.json -custom-IO-list-file=/home/abhishek/.cache/qeff_models/Qwen3_5MoeForConditionalGeneration/Qwen3_5MoeDecoderWrapper-61a4400d63d1b0bb/final_data/qpc_binaries/custom_io.yaml -aic-binary-dir=/home/abhishek/.cache/qeff_models/Qwen3_5MoeForConditionalGeneration/Qwen3_5MoeDecoderWrapper-61a4400d63d1b0bb/final_data/qpc_binaries/qpc diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py index 585a532fa..26a825791 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py @@ -19,6 +19,7 @@ from QEfficient.generation.cloud_infer import QAICInferenceSession model_id = "Qwen/Qwen3-VL-30B-A3B-Instruct" +# model_id = "tiny-random/qwen3-vl-moe" config = AutoConfig.from_pretrained(model_id) # For faster execution user can run with lesser layers, For Testing Purpose Only @@ -73,6 +74,8 @@ enable_chunking=True, skip_vision=True, use_onnx_subfunctions=True, + layerwise=True, + layerwise_window_size=1, ) @@ -92,6 +95,8 @@ prefill_only=False, skip_vision=True, use_onnx_subfunctions=True, + layerwise=True, + layerwise_window_size=1, ) lang_prefill_session = QAICInferenceSession(prefill_qpc_path.get("lang_prefill_qpc_path")) diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe.py index 67f199f0c..35dd6b5ee 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe.py @@ -13,7 +13,8 @@ from QEfficient import QEFFAutoModelForImageTextToText -model_id = "Qwen/Qwen3-VL-30B-A3B-Instruct" +# model_id = "Qwen/Qwen3-VL-30B-A3B-Instruct" +model_id = "tiny-random/qwen3-vl-moe" config = AutoConfig.from_pretrained(model_id) # For faster execution user can run with lesser layers, For Testing Purpose Only. Please ensure to use the configuration given below as random configurations may fail due to deepstack diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise.py index 990369fd9..fe120ed9d 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise.py @@ -5,326 +5,57 @@ # # ----------------------------------------------------------------------------- -import functools -import os -from pathlib import Path +"""Layerwise prefill compile example for Qwen3-VL-MoE. + +The orchestration loop that previously lived in this script has been moved +behind the ``layerwise=True`` flag on ``.compile()`` / ``.export()``. + +Note: ``layerwise=True`` is a provisional API and is scheduled for deprecation +once first-class multi-window export lands. Supported model types: +``qwen3_vl_moe``, ``qwen3_5_moe``, ``qwen3_moe``. +""" import torch -import transformers from transformers import AutoConfig -import QEfficient from QEfficient import QEFFAutoModelForImageTextToText -MODEL_ID = "Qwen/Qwen3-VL-235B-A22B-Instruct" -PREFILL_SEQ_LEN = 32 -CTX_LEN = 4096 -TEXT_WINDOW_SIZE = 1 - -# For quick local validation only (keep disabled for real export) -# TEST_TEXT_LAYERS = 4 - -# Export controls -BATCH_SIZE = 1 -NUM_CORES = 16 -NUM_DEVICES = 4 -HEIGHT = 354 -WIDTH = 536 - - -def _ensure_pretrained_window_attrs(): - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_start"): - transformers.modeling_utils.PreTrainedModel._start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_end"): - transformers.modeling_utils.PreTrainedModel._end = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_total_layers"): - transformers.modeling_utils.PreTrainedModel._total_layers = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_start"): - transformers.modeling_utils.PreTrainedModel._text_start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_end"): - transformers.modeling_utils.PreTrainedModel._text_end = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_total_layers"): - transformers.modeling_utils.PreTrainedModel._text_total_layers = 0 - - -def _build_layer_windows(total_layers: int, window_size: int): - if total_layers <= 0: - raise ValueError(f"Invalid total_layers={total_layers}. Expected: total_layers > 0.") - if window_size <= 0: - raise ValueError(f"Invalid window_size={window_size}. Expected: window_size > 0.") - - windows = [] - start = 0 - while start < total_layers: - end = min(total_layers, start + window_size) - windows.append((start, end)) - start = end - return windows - - -def _get_text_layers_container(model): - # VLM path first - if ( - hasattr(model, "model") - and hasattr(model.model, "language_model") - and hasattr(model.model.language_model, "layers") - ): - return model.model.language_model.layers - # LLM-compatible fallbacks - if hasattr(model, "model") and hasattr(model.model, "layers"): - return model.model.layers - if hasattr(model, "language_model") and hasattr(model.language_model, "layers"): - return model.language_model.layers - if hasattr(model, "layers"): - return model.layers - return None - - -def _null_outside_window_layers(model, apply_text: bool = True): - if apply_text: - text_start = int( - getattr( - transformers.modeling_utils.PreTrainedModel, - "_text_start", - getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0), - ) - ) - text_end = int( - getattr( - transformers.modeling_utils.PreTrainedModel, - "_text_end", - getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0), - ) - ) - text_layers = _get_text_layers_container(model) - if text_layers is not None and text_end > text_start: - for idx, _ in enumerate(text_layers): - if idx < text_start or idx >= text_end: - text_layers[idx] = None - - -def _install_window_patch(model_cls): - if getattr(model_cls, "_window_patch_installed", False): - return - - original_init = model_cls.__init__ - - @functools.wraps(original_init) - def patched_init(self, *args, **kwargs): - original_init(self, *args, **kwargs) - _null_outside_window_layers(self, apply_text=True) - - model_cls.__init__ = patched_init - model_cls._window_patch_installed = True - - -def _resolve_export_root(onnx_path: Path) -> Path: - parts = list(onnx_path.parts) - if "onnx_layerwise_tmp" in parts: - marker_idx = parts.index("onnx_layerwise_tmp") - return Path(*parts[:marker_idx]) - return onnx_path.parent - - -def _install_shard_window_patch(): - if getattr(transformers.modeling_utils, "_window_shard_patch_installed", False): - return - - original_get_checkpoint_shard_files = transformers.modeling_utils.get_checkpoint_shard_files - - @functools.wraps(original_get_checkpoint_shard_files) - def patched_get_checkpoint_shard_files(*args, **kwargs): - shard_files, metadata = original_get_checkpoint_shard_files(*args, **kwargs) - weight_map = metadata.get("weight_map") - if not weight_map: - return shard_files, metadata - - start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0)) - end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0)) - text_start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_text_start", start)) - text_end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_text_end", end)) - has_text_window = text_end > text_start - if not has_text_window: - return shard_files, metadata - - selected_text_prefixes = tuple( - [f"model.layers.{layer_idx}." for layer_idx in range(text_start, text_end)] - + [f"model.language_model.layers.{layer_idx}." for layer_idx in range(text_start, text_end)] - ) - filtered_weight_map = {} - for checkpoint_key, shard_name in weight_map.items(): - if checkpoint_key.startswith("model.layers.") or checkpoint_key.startswith("model.language_model.layers."): - if not has_text_window or checkpoint_key.startswith(selected_text_prefixes): - filtered_weight_map[checkpoint_key] = shard_name - continue - filtered_weight_map[checkpoint_key] = shard_name - - if not filtered_weight_map: - return shard_files, metadata +# MODEL_ID = "Qwen/Qwen3-VL-235B-A22B-Instruct" +MODEL_ID = "tiny-random/qwen3-vl-moe" - shard_name_to_path = {path.split("/")[-1]: path for path in shard_files} - filtered_shard_names = sorted(set(filtered_weight_map.values())) - filtered_shard_files = [shard_name_to_path[name] for name in filtered_shard_names if name in shard_name_to_path] - if not filtered_shard_files: - return shard_files, metadata - - metadata["weight_map"] = filtered_weight_map - metadata["all_checkpoint_keys"] = list(filtered_weight_map.keys()) - return filtered_shard_files, metadata - - transformers.modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files - transformers.modeling_utils._window_shard_patch_installed = True - - -def _set_layer_windows( - text_start: int, - text_end: int, - text_total_layers: int, -): - transformers.modeling_utils.PreTrainedModel._start = text_start - transformers.modeling_utils.PreTrainedModel._end = text_end - transformers.modeling_utils.PreTrainedModel._total_layers = text_total_layers - transformers.modeling_utils.PreTrainedModel._text_start = text_start - transformers.modeling_utils.PreTrainedModel._text_end = text_end - transformers.modeling_utils.PreTrainedModel._text_total_layers = text_total_layers - - # Qwen3-VL-MoE model code still checks QEffQwen3_5MoeTextModel window attrs - # in a few places. Set both classes to keep layer-wise behavior consistent. - qeff_vl_mod = QEfficient.transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe - qeff_vl_mod.QEffQwen3VLMoeTextModel._start = text_start - qeff_vl_mod.QEffQwen3VLMoeTextModel._end = text_end - qeff_vl_mod.QEffQwen3VLMoeTextModel._total_layers = text_total_layers - - qeff_35_mod = getattr(QEfficient.transformers.models, "qwen3_5_moe", None) - if qeff_35_mod is not None: - qeff_35_text_model = getattr(qeff_35_mod.modeling_qwen3_5_moe, "QEffQwen3_5MoeTextModel", None) - if qeff_35_text_model is not None: - qeff_35_text_model._start = text_start - qeff_35_text_model._end = text_end - qeff_35_text_model._total_layers = text_total_layers - - QEfficient.base.modeling_qeff.QEFFBaseModel._start = text_start - QEfficient.base.modeling_qeff.QEFFBaseModel._end = text_end - QEfficient.base.modeling_qeff.QEFFBaseModel._total_layers = text_total_layers - - -def _stitch_layerwise_if_available(export_root: Path): - # Some branches expose this helper; fall back gracefully when unavailable. - pipeline_fn = getattr(QEfficient.utils, "layerwise_pipeline", None) - if callable(pipeline_fn): - return pipeline_fn(str(export_root)) - print(f"layerwise_pipeline() not found. Layer-wise ONNX shards kept under: {export_root / 'onnx_layerwise_tmp'}") - return str(export_root / "onnx_layerwise_tmp") +def main(): + config = AutoConfig.from_pretrained(MODEL_ID) + config.torch_dtype = "float16" + # config.vision_config.deepstack_visual_indexes = [8, 27, 36] -def _new_qeff_model(model_id: str, config): - return QEFFAutoModelForImageTextToText.from_pretrained( - model_id, + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + MODEL_ID, attn_implementation="eager", kv_offload=True, config=config, - torch_dtype=torch.float32, + torch_dtype=torch.float16, + layerwise=True, ) - -def main(): - config = AutoConfig.from_pretrained(MODEL_ID) - config.torch_dtype = "float32" - # config.vision_config.depth = 9 - # config.text_config.num_hidden_layers = 2 - config.vision_config.deepstack_visual_indexes = [8, 16, 24] - - # if TEST_TEXT_LAYERS: - # config.text_config.num_hidden_layers = TEST_TEXT_LAYERS - - text_config = getattr(config, "text_config", config) - text_total_layers = getattr(text_config, "num_hidden_layers", None) - if text_total_layers is None: - raise ValueError("Could not resolve `num_hidden_layers` from config.text_config.") - config.text_config.num_hidden_layers = text_total_layers - _ensure_pretrained_window_attrs() - _install_shard_window_patch() - - hf_qwen_mod = transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe - _install_window_patch(hf_qwen_mod.Qwen3VLMoeForConditionalGeneration) - if hasattr(hf_qwen_mod, "Qwen3VLMoeForCausalLM"): - _install_window_patch(hf_qwen_mod.Qwen3VLMoeForCausalLM) - - text_windows = _build_layer_windows(total_layers=text_total_layers, window_size=TEXT_WINDOW_SIZE) - # Keep layerwise only on text path in this loop. - num_windows = len(text_windows) - first_onnx_path = None - os.environ["LAYERWISE_EXPORT"] = "True" - for window_idx in range(num_windows): - text_start, text_end = text_windows[window_idx] if window_idx < len(text_windows) else (0, 0) - skip_lang_for_window = window_idx >= len(text_windows) - - _set_layer_windows( - text_start=text_start, - text_end=text_end, - text_total_layers=text_total_layers, - ) - print( - f"Exporting window {window_idx + 1}/{num_windows} " - f"text=[{text_start},{text_end})/{text_total_layers} " - f"skip_lang={skip_lang_for_window}" - ) - - qeff_model = _new_qeff_model(MODEL_ID, config) - if hasattr(qeff_model, "model"): - _null_outside_window_layers( - qeff_model.model, - apply_text=not skip_lang_for_window, - ) - - onnx_path = qeff_model.compile( - batch_size=BATCH_SIZE, - prefill_seq_len=PREFILL_SEQ_LEN, - ctx_len=CTX_LEN, - num_cores=NUM_CORES, - num_devices=NUM_DEVICES, - height=HEIGHT, - width=WIDTH, - mxfp6_matmul=True, - aic_enable_depth_first=True, - skip_vision=True, - skip_lang=skip_lang_for_window, - split_retained_state_io=True, - use_onnx_subfunctions=True, - prefill_only=True, - mos=1, - ) - - if first_onnx_path is None: - first_onnx_path = Path(str(onnx_path["lang_prefill_qpc_path"])) - - if first_onnx_path is None: - raise RuntimeError("No ONNX path produced during layer-wise language export.") - - export_root = _resolve_export_root(first_onnx_path) - final_artifact = _stitch_layerwise_if_available(export_root) - print(f"Layer-wise language export completed. Final artifact/root: {final_artifact}") - - os.environ["LAYERWISE_EXPORT"] = "False" qpc_path = qeff_model.compile( - lang_onnx_path=final_artifact, - batch_size=BATCH_SIZE, - prefill_seq_len=PREFILL_SEQ_LEN, - ctx_len=CTX_LEN, - num_cores=NUM_CORES, - num_devices=NUM_DEVICES, - height=HEIGHT, - width=WIDTH, + batch_size=1, + prefill_seq_len=32, + ctx_len=4096, + num_cores=16, + num_devices=4, + height=354, + width=536, mxfp6_matmul=True, aic_enable_depth_first=True, skip_vision=True, - skip_lang=skip_lang_for_window, split_retained_state_io=True, use_onnx_subfunctions=True, prefill_only=True, mos=1, + layerwise=True, + layerwise_window_size=1, ) - print(f"Final QPC path: {qpc_path}") diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py index 18a61f6c4..9f5ea23d6 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py @@ -5,326 +5,87 @@ # # ----------------------------------------------------------------------------- -import functools -import os -from pathlib import Path +"""Layerwise prefill compile example for Qwen3-VL-MoE. + +The orchestration loop that previously lived in this script has been moved +behind the ``layerwise=True`` flag on ``.compile()`` / ``.export()``. + +Note: ``layerwise=True`` is a provisional API and is scheduled for deprecation +once first-class multi-window export lands. Supported model types: +``qwen3_vl_moe``, ``qwen3_5_moe``, ``qwen3_moe``. +""" import torch import transformers -from transformers import AutoConfig +from transformers import AutoConfig, AutoProcessor, TextStreamer -import QEfficient from QEfficient import QEFFAutoModelForImageTextToText -MODEL_ID = "Qwen/Qwen3-VL-235B-A22B-Instruct" -PREFILL_SEQ_LEN = 1 -CTX_LEN = 4096 -TEXT_WINDOW_SIZE = 1 - -# For quick local validation only (keep disabled for real export) -# TEST_TEXT_LAYERS = 4 - -# Export controls -BATCH_SIZE = 1 -NUM_CORES = 16 -NUM_DEVICES = 4 -HEIGHT = 354 -WIDTH = 536 - - -def _ensure_pretrained_window_attrs(): - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_start"): - transformers.modeling_utils.PreTrainedModel._start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_end"): - transformers.modeling_utils.PreTrainedModel._end = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_total_layers"): - transformers.modeling_utils.PreTrainedModel._total_layers = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_start"): - transformers.modeling_utils.PreTrainedModel._text_start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_end"): - transformers.modeling_utils.PreTrainedModel._text_end = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_total_layers"): - transformers.modeling_utils.PreTrainedModel._text_total_layers = 0 - - -def _build_layer_windows(total_layers: int, window_size: int): - if total_layers <= 0: - raise ValueError(f"Invalid total_layers={total_layers}. Expected: total_layers > 0.") - if window_size <= 0: - raise ValueError(f"Invalid window_size={window_size}. Expected: window_size > 0.") - - windows = [] - start = 0 - while start < total_layers: - end = min(total_layers, start + window_size) - windows.append((start, end)) - start = end - return windows - - -def _get_text_layers_container(model): - # VLM path first - if ( - hasattr(model, "model") - and hasattr(model.model, "language_model") - and hasattr(model.model.language_model, "layers") - ): - return model.model.language_model.layers - # LLM-compatible fallbacks - if hasattr(model, "model") and hasattr(model.model, "layers"): - return model.model.layers - if hasattr(model, "language_model") and hasattr(model.language_model, "layers"): - return model.language_model.layers - if hasattr(model, "layers"): - return model.layers - return None - - -def _null_outside_window_layers(model, apply_text: bool = True): - if apply_text: - text_start = int( - getattr( - transformers.modeling_utils.PreTrainedModel, - "_text_start", - getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0), - ) - ) - text_end = int( - getattr( - transformers.modeling_utils.PreTrainedModel, - "_text_end", - getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0), - ) - ) - text_layers = _get_text_layers_container(model) - if text_layers is not None and text_end > text_start: - for idx, _ in enumerate(text_layers): - if idx < text_start or idx >= text_end: - text_layers[idx] = None - - -def _install_window_patch(model_cls): - if getattr(model_cls, "_window_patch_installed", False): - return - - original_init = model_cls.__init__ - - @functools.wraps(original_init) - def patched_init(self, *args, **kwargs): - original_init(self, *args, **kwargs) - _null_outside_window_layers(self, apply_text=True) - - model_cls.__init__ = patched_init - model_cls._window_patch_installed = True - - -def _resolve_export_root(onnx_path: Path) -> Path: - parts = list(onnx_path.parts) - if "onnx_layerwise_tmp" in parts: - marker_idx = parts.index("onnx_layerwise_tmp") - return Path(*parts[:marker_idx]) - return onnx_path.parent - - -def _install_shard_window_patch(): - if getattr(transformers.modeling_utils, "_window_shard_patch_installed", False): - return - - original_get_checkpoint_shard_files = transformers.modeling_utils.get_checkpoint_shard_files - - @functools.wraps(original_get_checkpoint_shard_files) - def patched_get_checkpoint_shard_files(*args, **kwargs): - shard_files, metadata = original_get_checkpoint_shard_files(*args, **kwargs) - weight_map = metadata.get("weight_map") - if not weight_map: - return shard_files, metadata - - start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0)) - end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0)) - text_start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_text_start", start)) - text_end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_text_end", end)) - has_text_window = text_end > text_start - if not has_text_window: - return shard_files, metadata - - selected_text_prefixes = tuple( - [f"model.layers.{layer_idx}." for layer_idx in range(text_start, text_end)] - + [f"model.language_model.layers.{layer_idx}." for layer_idx in range(text_start, text_end)] - ) - filtered_weight_map = {} - for checkpoint_key, shard_name in weight_map.items(): - if checkpoint_key.startswith("model.layers.") or checkpoint_key.startswith("model.language_model.layers."): - if not has_text_window or checkpoint_key.startswith(selected_text_prefixes): - filtered_weight_map[checkpoint_key] = shard_name - continue - filtered_weight_map[checkpoint_key] = shard_name - - if not filtered_weight_map: - return shard_files, metadata - - shard_name_to_path = {path.split("/")[-1]: path for path in shard_files} - filtered_shard_names = sorted(set(filtered_weight_map.values())) - filtered_shard_files = [shard_name_to_path[name] for name in filtered_shard_names if name in shard_name_to_path] - if not filtered_shard_files: - return shard_files, metadata +# MODEL_ID = "Qwen/Qwen3-VL-235B-A22B-Instruct" +MODEL_ID = "Qwen/Qwen3-VL-30B-A3B-Instruct" +# MODEL_ID = "tiny-random/qwen3-vl-moe" - metadata["weight_map"] = filtered_weight_map - metadata["all_checkpoint_keys"] = list(filtered_weight_map.keys()) - return filtered_shard_files, metadata - - transformers.modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files - transformers.modeling_utils._window_shard_patch_installed = True - - -def _set_layer_windows( - text_start: int, - text_end: int, - text_total_layers: int, -): - transformers.modeling_utils.PreTrainedModel._start = text_start - transformers.modeling_utils.PreTrainedModel._end = text_end - transformers.modeling_utils.PreTrainedModel._total_layers = text_total_layers - transformers.modeling_utils.PreTrainedModel._text_start = text_start - transformers.modeling_utils.PreTrainedModel._text_end = text_end - transformers.modeling_utils.PreTrainedModel._text_total_layers = text_total_layers - - # Qwen3-VL-MoE model code still checks QEffQwen3_5MoeTextModel window attrs - # in a few places. Set both classes to keep layer-wise behavior consistent. - qeff_vl_mod = QEfficient.transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe - qeff_vl_mod.QEffQwen3VLMoeTextModel._start = text_start - qeff_vl_mod.QEffQwen3VLMoeTextModel._end = text_end - qeff_vl_mod.QEffQwen3VLMoeTextModel._total_layers = text_total_layers - - qeff_35_mod = getattr(QEfficient.transformers.models, "qwen3_5_moe", None) - if qeff_35_mod is not None: - qeff_35_text_model = getattr(qeff_35_mod.modeling_qwen3_5_moe, "QEffQwen3_5MoeTextModel", None) - if qeff_35_text_model is not None: - qeff_35_text_model._start = text_start - qeff_35_text_model._end = text_end - qeff_35_text_model._total_layers = text_total_layers - - QEfficient.base.modeling_qeff.QEFFBaseModel._start = text_start - QEfficient.base.modeling_qeff.QEFFBaseModel._end = text_end - QEfficient.base.modeling_qeff.QEFFBaseModel._total_layers = text_total_layers - - -def _stitch_layerwise_if_available(export_root: Path): - # Some branches expose this helper; fall back gracefully when unavailable. - pipeline_fn = getattr(QEfficient.utils, "layerwise_pipeline", None) - if callable(pipeline_fn): - return pipeline_fn(str(export_root)) - print(f"layerwise_pipeline() not found. Layer-wise ONNX shards kept under: {export_root / 'onnx_layerwise_tmp'}") - return str(export_root / "onnx_layerwise_tmp") +def main(): + config = AutoConfig.from_pretrained(MODEL_ID) + config.torch_dtype = "float16" + tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_ID) + processor = AutoProcessor.from_pretrained(MODEL_ID) -def _new_qeff_model(model_id: str, config): - return QEFFAutoModelForImageTextToText.from_pretrained( - model_id, + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + MODEL_ID, attn_implementation="eager", kv_offload=True, config=config, - torch_dtype=torch.float32, + torch_dtype=torch.float16, + layerwise=True, ) - - -def main(): - config = AutoConfig.from_pretrained(MODEL_ID) - config.torch_dtype = "float32" - # config.vision_config.depth = 9 - # config.text_config.num_hidden_layers = 2 - config.vision_config.deepstack_visual_indexes = [8, 27, 36] - - # if TEST_TEXT_LAYERS: - # config.text_config.num_hidden_layers = TEST_TEXT_LAYERS - - text_config = getattr(config, "text_config", config) - text_total_layers = getattr(text_config, "num_hidden_layers", None) - if text_total_layers is None: - raise ValueError("Could not resolve `num_hidden_layers` from config.text_config.") - config.text_config.num_hidden_layers = text_total_layers - _ensure_pretrained_window_attrs() - _install_shard_window_patch() - - hf_qwen_mod = transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe - _install_window_patch(hf_qwen_mod.Qwen3VLMoeForConditionalGeneration) - if hasattr(hf_qwen_mod, "Qwen3VLMoeForCausalLM"): - _install_window_patch(hf_qwen_mod.Qwen3VLMoeForCausalLM) - - text_windows = _build_layer_windows(total_layers=text_total_layers, window_size=TEXT_WINDOW_SIZE) - # Keep layerwise only on text path in this loop. - num_windows = len(text_windows) - first_onnx_path = None - os.environ["LAYERWISE_EXPORT"] = "True" - for window_idx in range(num_windows): - text_start, text_end = text_windows[window_idx] if window_idx < len(text_windows) else (0, 0) - skip_lang_for_window = window_idx >= len(text_windows) - - _set_layer_windows( - text_start=text_start, - text_end=text_end, - text_total_layers=text_total_layers, - ) - print( - f"Exporting window {window_idx + 1}/{num_windows} " - f"text=[{text_start},{text_end})/{text_total_layers} " - f"skip_lang={skip_lang_for_window}" - ) - - qeff_model = _new_qeff_model(MODEL_ID, config) - if hasattr(qeff_model, "model"): - _null_outside_window_layers( - qeff_model.model, - apply_text=not skip_lang_for_window, - ) - - onnx_path = qeff_model.compile( - batch_size=BATCH_SIZE, - prefill_seq_len=PREFILL_SEQ_LEN, - ctx_len=CTX_LEN, - num_cores=NUM_CORES, - num_devices=NUM_DEVICES, - height=HEIGHT, - width=WIDTH, - mxfp6_matmul=True, - aic_enable_depth_first=True, - skip_vision=True, - skip_lang=skip_lang_for_window, - split_retained_state_io=True, - use_onnx_subfunctions=True, - mos=1, - ) - - if first_onnx_path is None: - first_onnx_path = Path(str(onnx_path["lang_decode_qpc_path"])) - - if first_onnx_path is None: - raise RuntimeError("No ONNX path produced during layer-wise language export.") - - export_root = _resolve_export_root(first_onnx_path) - final_artifact = _stitch_layerwise_if_available(export_root) - print(f"Layer-wise language export completed. Final artifact/root: {final_artifact}") - - os.environ["LAYERWISE_EXPORT"] = "False" + batch_size = 1 qpc_path = qeff_model.compile( - lang_onnx_path=final_artifact, - batch_size=BATCH_SIZE, - prefill_seq_len=PREFILL_SEQ_LEN, - ctx_len=CTX_LEN, - num_cores=NUM_CORES, - num_devices=NUM_DEVICES, - height=HEIGHT, - width=WIDTH, + batch_size=1, + prefill_seq_len=1, + ctx_len=4096, + num_cores=16, + num_devices=4, + height=354, + width=536, mxfp6_matmul=True, + mxint8_kv_cache=True, aic_enable_depth_first=True, skip_vision=True, - skip_lang=skip_lang_for_window, split_retained_state_io=True, use_onnx_subfunctions=True, mos=1, + layerwise=True, + layerwise_window_size=1, ) - print(f"Final QPC path: {qpc_path}") + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Tell me about yourself."}, + ], + }, + ] + + messages = [messages] * batch_size + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) + # streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=100) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + if __name__ == "__main__": main() diff --git a/tests/unit_test/models/test_model_quickcheck.py b/tests/unit_test/models/test_model_quickcheck.py index 4b7ed6f17..03ac6ce27 100644 --- a/tests/unit_test/models/test_model_quickcheck.py +++ b/tests/unit_test/models/test_model_quickcheck.py @@ -1282,3 +1282,307 @@ def test_no_tag_falls_back_to_lm_rules(self): result = to_named_specializations(flat) assert result[0]["name"] == "Prefill" assert result[1]["name"] == "Decode" + + +# --------------------------------------------------------------------------- +# Layer-wise export (provisional, scheduled for deprecation) +# --------------------------------------------------------------------------- + +LAYERWISE_TINY_MODEL_ID = "tiny-random/qwen3-vl-moe" +LAYERWISE_TINY_MODEL_IDS = { + "qwen3_vl_moe": "tiny-random/qwen3-vl-moe", + "qwen3_5_moe": "tiny-random/qwen3.5-moe", + "qwen3_moe": "tiny-random/qwen3-moe", +} + + +@pytest.mark.llm_model +def test_layerwise_window_helpers(): + """Pure-Python coverage of the windowing helpers - no model load required.""" + from QEfficient.transformers.models import _layerwise + + assert _layerwise._build_layer_windows(4, 1) == [(0, 1), (1, 2), (2, 3), (3, 4)] + assert _layerwise._build_layer_windows(5, 2) == [(0, 2), (2, 4), (4, 5)] + with pytest.raises(ValueError): + _layerwise._build_layer_windows(0, 1) + with pytest.raises(ValueError): + _layerwise._build_layer_windows(4, 0) + + +@pytest.mark.llm_model +def test_layerwise_supported_guard_rejects_unrelated_model(): + """layerwise=True must hard-fail on architectures without windowing hooks.""" + from QEfficient.transformers.models import _layerwise + + config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM") + with pytest.raises(NotImplementedError, match="layerwise=True is only supported"): + _layerwise.assert_layerwise_supported(config) + + +@pytest.mark.llm_model +def test_layerwise_supported_guard_accepts_qwen3_vl_moe(): + from QEfficient.transformers.models import _layerwise + + try: + config = AutoConfig.from_pretrained(LAYERWISE_TINY_MODEL_ID) + except Exception as exc: + _skip_on_model_fetch_error(exc, LAYERWISE_TINY_MODEL_ID) + resolved = _layerwise.assert_layerwise_supported(config) + assert resolved in {"qwen3_vl_moe", "qwen3_vl_moe_text"} + + +@pytest.mark.llm_model +@pytest.mark.parametrize( + ("arch", "model_id"), + sorted(LAYERWISE_TINY_MODEL_IDS.items()), + ids=sorted(LAYERWISE_TINY_MODEL_IDS), +) +def test_layerwise_supported_guard_accepts_all_supported(arch, model_id): + """Guard must accept each architecture in the layerwise allowlist.""" + from QEfficient.transformers.models import _layerwise + + try: + config = AutoConfig.from_pretrained(model_id) + except Exception as exc: + _skip_on_model_fetch_error(exc, model_id) + resolved = _layerwise.assert_layerwise_supported(config) + assert arch in resolved or resolved.startswith(arch) + + +@pytest.mark.llm_model +def test_layerwise_off_does_not_set_env_var(tmp_path): + """Backward compat: layerwise must be controlled purely via the API, + never via environment variables, and must be off by default.""" + from QEfficient.base.modeling_qeff import QEFFBaseModel + from QEfficient.transformers.models import _layerwise # noqa: F401 + + assert os.environ.get("LAYERWISE_EXPORT") is None + assert QEFFBaseModel._layerwise_active is False + + +@pytest.mark.llm_model +def test_layerwise_context_manager_toggles_class_flag(): + """The driver's context manager must flip the class flag and restore it, + even on exception, with no env-var side-effects.""" + from QEfficient.base.modeling_qeff import QEFFBaseModel + from QEfficient.transformers.models import _layerwise + + assert QEFFBaseModel._layerwise_active is False + with _layerwise._layerwise_export_env(): + assert QEFFBaseModel._layerwise_active is True + assert "LAYERWISE_EXPORT" not in os.environ + assert QEFFBaseModel._layerwise_active is False + + try: + with _layerwise._layerwise_export_env(): + raise RuntimeError("boom") + except RuntimeError: + pass + assert QEFFBaseModel._layerwise_active is False + + +@pytest.mark.llm_model +def test_layerwise_uses_probe_model_for_cached_export(monkeypatch, tmp_path): + """A cached merged ONNX must avoid rebuilding per-window models.""" + from QEfficient.transformers.models import _layerwise + + class DummyConfig: + model_type = "qwen3_moe" + num_hidden_layers = 2 + + class ProbeModel: + def __init__(self, cached_path): + self.cached_path = cached_path + + def compile(self, **kwargs): + assert kwargs.pop("_layerwise_cache_probe") is True + return self.cached_path + + cached_path = tmp_path / "Model-hash" / "final_data" / "merged_0-2.onnx" + cached_path.parent.mkdir(parents=True) + cached_path.touch() + factory_called = False + + def fail_factory(*args, **kwargs): + nonlocal factory_called + factory_called = True + raise AssertionError("factory must not run when merged ONNX is cached") + + monkeypatch.setattr(_layerwise, "_install_window_patches_for", lambda model_type: None) + result = _layerwise.run_layerwise( + model_id="dummy", + config=DummyConfig(), + qeff_factory=fail_factory, + compile_kwargs={}, + probe_qeff_model=ProbeModel(cached_path), + final_compile=False, + ) + + assert result == str(cached_path) + assert factory_called is False + + +@pytest.mark.llm_model +def test_layerwise_cache_miss_exports_all_windows(monkeypatch, tmp_path): + from QEfficient.transformers.models import _layerwise + + class DummyConfig: + model_type = "qwen3_moe" + num_hidden_layers = 3 + + class ProbeModel: + def compile(self, **kwargs): + assert kwargs.pop("_layerwise_cache_probe") is True + return None + + exported_windows = [] + + class WindowModel: + def __init__(self): + self.model = object() + + def compile(self, **kwargs): + start = _layerwise._LAYERWISE_STATE["text_start"] + end = _layerwise._LAYERWISE_STATE["text_end"] + exported_windows.append((start, end)) + shard = tmp_path / "onnx_layerwise_tmp" / f"layer_{start}_{end}" / f"model_layer_tmp_{start}_{end}.onnx" + shard.parent.mkdir(parents=True, exist_ok=True) + shard.touch() + return str(shard) + + monkeypatch.setattr(_layerwise, "_install_window_patches_for", lambda model_type: None) + monkeypatch.setattr(_layerwise, "_null_outside_window_layers", lambda *args, **kwargs: None) + monkeypatch.setattr(_layerwise, "_slim_for_window_export", lambda *args, **kwargs: None) + monkeypatch.setattr( + _layerwise, + "_stitch_layerwise_if_available", + lambda export_root, total_layers=None: str(export_root / "merged.onnx"), + ) + + result = _layerwise.run_layerwise( + model_id="dummy", + config=DummyConfig(), + qeff_factory=lambda *args, **kwargs: WindowModel(), + compile_kwargs={}, + probe_qeff_model=ProbeModel(), + window_size=1, + final_compile=False, + ) + + assert exported_windows == [(0, 1), (1, 2), (2, 3)] + assert result.endswith("merged.onnx") + + +@pytest.mark.llm_model +def test_layerwise_cached_merged_prefers_complete_graph(tmp_path): + from QEfficient.transformers.models import _layerwise + + final_data = tmp_path / "final_data" + final_data.mkdir() + partial = final_data / "merged_9-48.onnx" + complete = final_data / "merged_0-48.onnx" + partial.touch() + complete.touch() + + assert _layerwise._cached_merged_onnx(tmp_path, total_layers=48) == complete + + +@pytest.mark.llm_model +def test_layerwise_export_hash_is_separate_from_default(monkeypatch): + from QEfficient.utils import export_utils + + class DummyConfig: + def to_diff_dict(self): + return {"model_type": "dummy"} + + class DummyModel: + config = DummyConfig() + + class DummyQEffModel: + model = DummyModel() + hash_params = {} + + def normal_export(self, example_inputs, output_names, dynamic_axes, **export_kwargs): + pass + + def _export_layerwise(self, example_inputs, output_names, dynamic_axes, **export_kwargs): + pass + + captured_export_kwargs = [] + + def fake_create_export_hash(**kwargs): + captured_export_kwargs.append(kwargs.get("export_kwargs")) + return "hash", {} + + monkeypatch.setattr(export_utils, "create_export_hash", fake_create_export_hash) + + export_utils._generate_export_hash(DummyQEffModel(), ({}, ["logits"], {}), {}, normal_export) + export_utils._generate_export_hash(DummyQEffModel(), ({}, ["logits"], {}), {}, _export_layerwise) + + assert captured_export_kwargs[0] in (None, {}) + assert captured_export_kwargs[1]["_qeff_layerwise_export"] is True + + +@pytest.mark.llm_model +def test_subfunction_compile_io_names_use_internal_retained_state(): + from QEfficient.transformers.models.modeling_auto import _compile_io_name, _state_input_name + + output_name = _compile_io_name("past_value.1_RetainedState", use_onnx_subfunctions=True) + + assert output_name == "past_value.1_InternalRetainedState" + assert _state_input_name(output_name) == "past_value.1" + assert _compile_io_name("vision_embeds_RetainedState", use_onnx_subfunctions=True) == "vision_embeds_RetainedState" + + +@pytest.mark.llm_model +def test_runtime_aliases_internal_retained_state_outputs(): + from QEfficient.generation.cloud_infer import _add_basename_binding_aliases, _public_retained_state_name + + assert _public_retained_state_name("past_key.0_InternalRetainedState") == "past_key.0_RetainedState" + assert _public_retained_state_name("past_value.1_InternalRetainedState") == "past_value.1_RetainedState" + assert _public_retained_state_name("logits") is None + + binding_map = {"layer_0/input_ids": 3} + bindings = [type("Binding", (), {"name": "layer_0/input_ids", "index": 3})()] + _add_basename_binding_aliases(binding_map, bindings) + assert binding_map["input_ids"] == 3 + + +@pytest.mark.llm_model +def test_layerwise_compile_hydrates_outer_qpc_paths(monkeypatch, tmp_path): + from QEfficient.transformers.models import _layerwise + from QEfficient.transformers.models.modeling_auto import _QEffAutoModelForImageTextToTextDualQPC + + qpc_path = tmp_path / "qpc" + model = object.__new__(_QEffAutoModelForImageTextToTextDualQPC) + model._pretrained_model_name_or_path = "dummy" + model.config = object() + model.vision_model = type("Vision", (), {"qpc_path": None})() + model.lang_model = type("Lang", (), {"qpc_path": None})() + model._build_layerwise_factory = lambda: None + + monkeypatch.setattr(_layerwise, "run_layerwise", lambda **kwargs: {"lang_decode_qpc_path": qpc_path}) + + result = model._run_layerwise_compile(layerwise_window_size=1) + + assert result == {"lang_decode_qpc_path": qpc_path} + assert model.qpc_paths == result + assert model.lang_model.qpc_path == qpc_path + + +@pytest.mark.llm_model +def test_layerwise_compile_rejects_unsupported_model(): + """End-to-end smoke: invoking layerwise=True on llama bubbles the guard error.""" + try: + qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + "hf-internal-testing/tiny-random-LlamaForCausalLM", + ) + except Exception as exc: + _skip_on_model_fetch_error(exc, "tiny-random/llama") + # CausalLM does not expose a layerwise= kwarg today; only DualQPC VLM does. + # So this test guards via the helper directly to make the contract explicit + # for future surface expansion. + from QEfficient.transformers.models import _layerwise + + with pytest.raises(NotImplementedError): + _layerwise.assert_layerwise_supported(qeff_model.model.config)