diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 3bb05c7b98..ff44723a8e 100755 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -8,7 +8,6 @@ import gc import inspect import logging -import os import shutil import subprocess import warnings @@ -90,7 +89,23 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None: self.is_transformed: bool = False self._normalize_torch_dtype() - # Apply the transformations + # Apply the transformations. For layer-wise export the model arrives on + # the `meta` device and the data-mutating transforms must run on the + # real per-window weights instead, so application is deferred and the + # loop calls `apply_pytorch_transforms()` after streaming each window. + if not getattr(self, "_defer_pytorch_transforms", False): + self.apply_pytorch_transforms() + + if self.config.torch_dtype == torch.bfloat16: + logger.warning("BFloat16 dtype is not yet supported; converting to float16 precision!") + + def apply_pytorch_transforms(self) -> bool: + """Apply the class ``_pytorch_transforms`` to ``self.model`` in place. + + Returns ``True`` if any transform reported a change. Used both by the + normal init flow and by the layer-wise export loop (after streaming each + window's real weights into ``self.model``). + """ any_transformed = False for transform in self._pytorch_transforms: self.model, transformed = transform.apply(self.model) @@ -100,9 +115,7 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None: warnings.warn(f"No transforms applied to model: {self.model_name}. It may be an unsupported model!") else: logger.info(f"Pytorch transforms applied to model: {self.model_name}") - - if self.config.torch_dtype == torch.bfloat16: - logger.warning("BFloat16 dtype is not yet supported; converting to float16 precision!") + return any_transformed def _normalize_torch_dtype(self): """ @@ -488,8 +501,8 @@ def _export_layerwise( prefill_only: Optional[bool] = False, **export_kwargs, ) -> str: - idx = int(QEFFBaseModel._start) - end_idx = int(getattr(QEFFBaseModel, "_end", idx + 1)) + idx = int(getattr(self, "_start", 0)) + end_idx = int(getattr(self, "_end", idx + 1)) if end_idx <= idx: raise ValueError(f"Invalid export window: start={idx}, end={end_idx}") @@ -502,9 +515,6 @@ def _export_layerwise( self.onnx_path = onnx_path return onnx_path - # check if the model is in meta state or weights are offloaded - self._model_offloaded_check() - export_dir.mkdir(parents=True, exist_ok=True) # Setup temporary paths @@ -757,9 +767,6 @@ def _compile( **compiler_options, ) onnx_path = Path(onnx_path) - if os.environ.get("LAYERWISE_EXPORT", "False") == "True": - return onnx_path - compile_dir = Path(compile_dir or onnx_path.parent) qpc_path = compile_dir / "qpc" if not onnx_path.is_file(): diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 65b89d274f..a2125e9227 100755 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -77,6 +77,14 @@ get_padding_shape_from_config, ) from QEfficient.utils.check_ccl_specializations import process_ccl_specializations +from QEfficient.utils.custom_loader import CustomLoader +from QEfficient.utils.layerwise_utils import ( + build_layer_windows, + build_meta_model, + reset_window_state, + resolve_text_model, + set_window_state, +) from QEfficient.utils.logging_utils import logger from QEfficient.utils.sampler_utils import get_sampling_inputs_and_outputs @@ -1085,6 +1093,7 @@ def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs): Additional keyword arguments passed to the base class constructor. """ _configure_proxy_for_model(self, kwargs.pop("enable_proxy", False)) + self.layerwise = bool(kwargs.pop("layerwise", False)) super().__init__(model, **kwargs) self.model = model.get_qeff_language_decoder() self.model.qaic_config = qaic_config @@ -1156,14 +1165,21 @@ def export( self.hash_params["prefill_only"] = False self.__update_prefill_transform(False, retain_full_kv=kwargs.get("retain_full_kv", False)) - if os.environ.get("LAYERWISE_EXPORT", "False") == "True": + if self.layerwise: + use_onnx_subfunctions = kwargs.get("use_onnx_subfunctions", False) + if not use_onnx_subfunctions: + logger.warning( + "use_onnx_subfunctions is being set to True because layerwise=True; " + "the layer-wise export pipeline requires ONNX subfunctions." + ) + use_onnx_subfunctions = True return self._export_layerwise( inputs, output_names=output_names, dynamic_axes=dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights, - use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), + use_onnx_subfunctions=use_onnx_subfunctions, ) else: return self._export( @@ -1278,11 +1294,23 @@ def __init__( warnings.warn( "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2 ) + # Layer-wise export applies only to the language decoder; the vision + # encoder is always exported normally. + self.layerwise = bool(kwargs.pop("layerwise", False)) + self._layerwise_window_size = 1 + self._layerwise_total_layers = None + self._custom_loader = None + self._layerwise_qaic_config = qaic_config + self._layerwise_init_kwargs = dict(kwargs) + self._layerwise_pretrained_path = kwargs.get("pretrained_model_name_or_path", None) + self._layerwise_from_pretrained_kwargs = dict(kwargs.pop("_layerwise_from_pretrained_kwargs", {})) self.model = model self.config = model.config self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) - self.lang_model = QEffCausalLMForTextImageToTextModel(model, qaic_config=qaic_config, **kwargs) + self.lang_model = QEffCausalLMForTextImageToTextModel( + model, qaic_config=qaic_config, layerwise=self.layerwise, **kwargs + ) self.continuous_batching = continuous_batching self.ccl_enabled = False if qaic_config: @@ -1295,6 +1323,31 @@ def __init__( # previous transform function. self.lang_model.model, _ = SamplerTransform.apply(self.lang_model.model, qaic_config, **kwargs) + # ---Layer-wise export setup (language decoder only)--- + if self.layerwise: + if self._layerwise_pretrained_path is None: + raise ValueError( + "layerwise=True requires `pretrained_model_name_or_path` to locate the safetensors checkpoint." + ) + text_total_layers = self._resolve_text_total_layers(model.config) + # Window only the language decoder layers; keep vision/projector/edges. + self._custom_loader = CustomLoader( + hf_auto_class=self._hf_auto_class, + pretrained_model_name_or_path=self._layerwise_pretrained_path, + layer_prefix=("model.layers.", "model.language_model.layers."), + total_layers=text_total_layers, + from_pretrained_kwargs=self._layerwise_from_pretrained_kwargs, + ) + self._layerwise_text_total_layers = text_total_layers + + @staticmethod + def _resolve_text_total_layers(config) -> int: + text_config = getattr(config, "text_config", config) + total = getattr(text_config, "num_hidden_layers", None) + if total is None: + raise ValueError("Could not resolve `num_hidden_layers` from the (text) config for layer-wise export.") + return int(total) + @classmethod def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Optional[dict] = None, **kwargs): """ @@ -1416,17 +1469,32 @@ def export( qaic_config=self.lang_model.model.qaic_config, ) - layerwise_export = os.environ.get("LAYERWISE_EXPORT", "False") == "True" + if prefill_only and prefill_seq_len > 1: + offload_pt_weights = False # to keep weight for decode onnx + else: + offload_pt_weights = kwargs.get("offload_pt_weights", True) - should_export = not skip_vision and ( - not layerwise_export - or ( - layerwise_export - and QEfficient.base.modeling_qeff.QEFFBaseModel._end - == QEfficient.base.modeling_qeff.QEFFBaseModel._total_layers + if self.layerwise and not skip_lang: + # Run the per-window layer-wise loop for the language decoder. The + # vision encoder is exported once inside the loop using the first + # window's real (fully-loaded) VLM weights. + self._run_layerwise_lang_export( + inputs["lang"], + output_names["lang"], + dynamic_axes["lang"], + export_dir=export_dir, + use_onnx_subfunctions=use_onnx_subfunctions, + offload_pt_weights=offload_pt_weights, + prefill_only=prefill_only, + enable_chunking=enable_chunking, + prefill_seq_len=prefill_seq_len, + vision_inputs=None if skip_vision else inputs["vision"], + vision_output_names=None if skip_vision else output_names["vision"], + vision_dynamic_axes=None if skip_vision else dynamic_axes["vision"], ) - ) - if should_export: + return self.onnx_path + + if not skip_vision and self.vision_model.onnx_path is None: self.vision_model.export( inputs["vision"], output_names["vision"], @@ -1436,11 +1504,6 @@ def export( use_onnx_subfunctions=use_onnx_subfunctions, ) - if prefill_only and prefill_seq_len > 1: - offload_pt_weights = False # to keep weight for decode onnx - else: - offload_pt_weights = kwargs.get("offload_pt_weights", True) - if not skip_lang and self.lang_model.onnx_path is None: self.lang_model.export( inputs["lang"], @@ -1455,6 +1518,110 @@ def export( ) return self.onnx_path + def _run_layerwise_lang_export( + self, + lang_inputs, + lang_output_names, + lang_dynamic_axes, + export_dir=None, + use_onnx_subfunctions=False, + offload_pt_weights=True, + prefill_only=False, + enable_chunking=False, + prefill_seq_len=None, + vision_inputs=None, + vision_output_names=None, + vision_dynamic_axes=None, + ) -> str: + """Run the per-window layer-wise export for the language decoder. + + For each ``(start, end)`` window: reload a window-filtered VLM (vision + + projector + window language layers) via the CustomLoader, rebuild the + language sub-model, apply its transforms, set the window state, and + export that window. After all windows, merge them into one language ONNX. + """ + if self._custom_loader is None: + raise RuntimeError( + "Layer-wise export requested but no CustomLoader was set up. " + "Load the model with `from_pretrained(..., layerwise=True)`." + ) + + # The split/merge pipeline keys off ONNX subfunction nodes. + if not use_onnx_subfunctions: + logger.warning( + "use_onnx_subfunctions is being set to True because layerwise=True; " + "the layer-wise export pipeline requires ONNX subfunctions." + ) + use_onnx_subfunctions = True + + total_layers = self._layerwise_total_layers or self._layerwise_text_total_layers + if total_layers <= 1: + raise ValueError(f"Layer-wise export needs more than one decoder layer, got total_layers={total_layers}.") + windows = build_layer_windows(total_layers, self._layerwise_window_size) + qaic_config = self._layerwise_qaic_config + init_kwargs = {k: v for k, v in self._layerwise_init_kwargs.items() if k != "_layerwise_from_pretrained_kwargs"} + + first_window_onnx = None + for start, end in windows: + # 1) Reload a window-filtered VLM (vision + projector + window + # language layers) with real weights. + window_model = self._custom_loader.load_window_model(start, end) + window_model.config.use_cache = True + self.model = window_model + self.config = window_model.config + + # Export the vision encoder once, using the first window's real + # (fully-loaded) vision weights. + if vision_inputs is not None and self.vision_model.onnx_path is None: + self.vision_model = QEffVisionEncoderForTextImageToTextModel(window_model, **init_kwargs) + self.vision_model.export( + vision_inputs, + vision_output_names, + vision_dynamic_axes, + export_dir=export_dir, + offload_pt_weights=False, + use_onnx_subfunctions=use_onnx_subfunctions, + ) + + # 2) Rebuild the language sub-model on the real window weights and + # apply its transforms (deferred during meta init). + self.lang_model = QEffCausalLMForTextImageToTextModel( + window_model, qaic_config=qaic_config, layerwise=False, **init_kwargs + ) + self.lang_model.model, _ = SamplerTransform.apply(self.lang_model.model, qaic_config, **init_kwargs) + + # 3) Set the layer-window state on the language text model. + lang_text_model, _ = resolve_text_model(self.lang_model.model) + set_window_state(lang_text_model, start, end, total_layers, qeff_wrapper=self.lang_model) + + # 4) Export just this language window. + window_onnx = self.lang_model._export_layerwise( + dict(lang_inputs), + output_names=list(lang_output_names), + dynamic_axes=dict(lang_dynamic_axes), + export_dir=export_dir, + use_onnx_subfunctions=use_onnx_subfunctions, + offload_pt_weights=False, + prefill_only=prefill_only, + ) + if first_window_onnx is None: + first_window_onnx = Path(window_onnx) + + if first_window_onnx is None: + raise RuntimeError("No layer windows were exported during layer-wise language export.") + + parts = list(first_window_onnx.parts) + if "onnx_layerwise_tmp" in parts: + export_root = Path(*parts[: parts.index("onnx_layerwise_tmp")]) + else: + export_root = first_window_onnx.parent + + final_onnx_path = QEfficient.utils.layerwise_pipeline(str(export_root), num_layers=total_layers) + + reset_window_state(lang_text_model, total_layers, qeff_wrapper=self.lang_model) + self.lang_model.onnx_path = final_onnx_path + return final_onnx_path + def transform( self, ctx_len: Optional[int] = None, @@ -1506,6 +1673,8 @@ def compile( prefill_only=None, enable_chunking=False, qaic_config: Optional[dict] = None, + layerwise_window_size: int = 1, + total_layers: Optional[int] = None, **compiler_options, ) -> str: """ @@ -1565,6 +1734,22 @@ def compile( if skip_lang and skip_vision: raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False") + # Stash layer-wise controls; the per-window loop lives in export(). + self._layerwise_window_size = layerwise_window_size + if total_layers is not None: + if not isinstance(total_layers, int) or total_layers <= 1: + raise ValueError(f"`total_layers` must be an integer > 1, got {total_layers}.") + self._layerwise_total_layers = total_layers + if self.layerwise and self._custom_loader is not None: + self._custom_loader.total_layers = total_layers + cfg = self._custom_loader.from_pretrained_kwargs.get("config", None) + if cfg is None: + cfg = self.model.config + self._custom_loader.from_pretrained_kwargs["config"] = cfg + text_cfg = getattr(cfg, "text_config", None) + setattr(text_cfg if text_cfg is not None else cfg, "num_hidden_layers", total_layers) + self._layerwise_text_total_layers = total_layers + if self.continuous_batching and full_batch_size is None: raise TypeError("`full_batch_size` is required when `continuous_batching=True`.") @@ -2825,6 +3010,10 @@ def __new__( model, continuous_batching, qaic_config=qaic_config, **kwargs ) else: + # Layer-wise export is supported via the dual-QPC language decoder path; + # the single-QPC path ignores the flag for now. + kwargs.pop("layerwise", None) + kwargs.pop("_layerwise_from_pretrained_kwargs", None) return _QEFFAutoModelForImageTextToTextSingleQPC(model, qaic_config=qaic_config, **kwargs) @classmethod @@ -2835,6 +3024,7 @@ def from_pretrained( kv_offload: Optional[bool] = None, continuous_batching: bool = False, qaic_config: Optional[dict] = None, + layerwise: bool = False, **kwargs, ): """ @@ -2868,6 +3058,12 @@ def from_pretrained( """ enable_proxy = kwargs.pop("enable_proxy", False) + if layerwise and kv_offload is False: + raise NotImplementedError( + "layerwise=True is only supported with the dual-QPC path (kv_offload=True) " + "for image-text-to-text models." + ) + # TODO: add a check to see if kv_offload is allowed for given model by loading the config and checking architecture or type of config here. if continuous_batching and not kv_offload: NotImplementedError("Continuous batching is not supported for kv_offload = False") @@ -2881,7 +3077,23 @@ def from_pretrained( kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) _resolve_torch_dtype(kwargs) - model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) + if layerwise: + # Defer weight loading: build the VLM on `meta`; the language decoder + # weights are streamed per window during export. + model = build_meta_model(cls._hf_auto_class, pretrained_model_name_or_path, **kwargs) + layerwise_fp_kwargs = { + "attn_implementation": kwargs.get("attn_implementation", "eager"), + "low_cpu_mem_usage": kwargs.get("low_cpu_mem_usage", False), + } + if "torch_dtype" in kwargs: + layerwise_fp_kwargs["torch_dtype"] = kwargs["torch_dtype"] + if "config" in kwargs: + layerwise_fp_kwargs["config"] = kwargs["config"] + if "trust_remote_code" in kwargs: + layerwise_fp_kwargs["trust_remote_code"] = kwargs["trust_remote_code"] + else: + model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) + layerwise_fp_kwargs = {} kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {}) @@ -2891,6 +3103,8 @@ def from_pretrained( continuous_batching=continuous_batching, pretrained_model_name_or_path=pretrained_model_name_or_path, qaic_config=qaic_config, + layerwise=layerwise, + _layerwise_from_pretrained_kwargs=layerwise_fp_kwargs, **kwargs, ) @@ -3017,6 +3231,19 @@ def __init__( raise TypeError(f"Required pytorch module for CausalLM or LMHeadModel, got {model_class_name}") _configure_proxy_for_model(self, kwargs.pop("enable_proxy", False)) + # Layer-wise export defers weight loading; the model arrives on `meta` and + # weights are streamed per window during export via a CustomLoader. + self.layerwise = bool(kwargs.pop("layerwise", False)) + self._custom_loader = None + self._layerwise_pretrained_path = kwargs.get("pretrained_model_name_or_path", None) + self._layerwise_window_size = 1 + self._layerwise_total_layers = None + # When layer-wise, defer the (data-mutating) pytorch transforms in the + # base __init__; they are re-applied per window on the real weights. + self._defer_pytorch_transforms = self.layerwise + self._layerwise_from_pretrained_kwargs = dict(kwargs.get("_layerwise_from_pretrained_kwargs", {})) + kwargs.pop("_layerwise_from_pretrained_kwargs", None) + # TODO: remove from version 1.20 if kwargs.pop("full_batch_size", None): continuous_batching = True @@ -3061,6 +3288,22 @@ def __init__( if self.is_tlm: self.model.qaic_config["return_pdfs"] = True + # ---Layer-wise export setup--- + if self.layerwise: + self.hash_params["layerwise"] = True + text_model, layer_prefix = resolve_text_model(self.model) + if self._layerwise_pretrained_path is None: + raise ValueError( + "layerwise=True requires `pretrained_model_name_or_path` to locate the safetensors checkpoint." + ) + self._custom_loader = CustomLoader( + hf_auto_class=self._hf_auto_class, + pretrained_model_name_or_path=self._layerwise_pretrained_path, + layer_prefix=layer_prefix, + total_layers=self.num_layers, + from_pretrained_kwargs=self._layerwise_from_pretrained_kwargs, + ) + def __repr__(self) -> str: return self.__class__.__name__ + "\n" + self.model.__repr__() @@ -3072,6 +3315,7 @@ def from_pretrained( continuous_batching: bool = False, qaic_config: Optional[dict] = None, max_seq_len_cached: Optional[int] = None, + layerwise: bool = False, *args, **kwargs, ): @@ -3135,7 +3379,24 @@ def from_pretrained( kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) _resolve_torch_dtype(kwargs) - model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + if layerwise: + # Defer weight loading: build the model on the `meta` device from config + # only. Real weights are streamed one layer-window at a time during + # export via the attached CustomLoader. + model = build_meta_model(cls._hf_auto_class, pretrained_model_name_or_path, **kwargs) + # Kwargs the CustomLoader must reuse so each per-window load matches + # the non-layerwise load (dtype, eager attention, no low-cpu-mem). + layerwise_fp_kwargs = { + "attn_implementation": kwargs.get("attn_implementation", "eager"), + "low_cpu_mem_usage": kwargs.get("low_cpu_mem_usage", False), + } + if "torch_dtype" in kwargs: + layerwise_fp_kwargs["torch_dtype"] = kwargs["torch_dtype"] + if "trust_remote_code" in kwargs: + layerwise_fp_kwargs["trust_remote_code"] = kwargs["trust_remote_code"] + else: + model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + layerwise_fp_kwargs = {} if qaic_config is not None: qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path @@ -3148,6 +3409,8 @@ def from_pretrained( pretrained_model_name_or_path=pretrained_model_name_or_path, qaic_config=qaic_config, continuous_batching=continuous_batching, + layerwise=layerwise, + _layerwise_from_pretrained_kwargs=layerwise_fp_kwargs, **kwargs, ) return cls( @@ -3156,6 +3419,8 @@ def from_pretrained( qaic_config=qaic_config, pretrained_model_name_or_path=pretrained_model_name_or_path, max_seq_len_cached=max_seq_len_cached, + layerwise=layerwise, + _layerwise_from_pretrained_kwargs=layerwise_fp_kwargs, **kwargs, ) @@ -3241,6 +3506,7 @@ def export( prefill_seq_len: Optional[int] = None, num_cores: int = constants.DEFAULT_AIC_NUM_CORES, moe_prefill_packed_chunk_size: int = constants.MOE_PREFILL_PACKED_CHUNK_SIZE, + layerwise_window_size: int = 1, **kwargs, ) -> str: """ @@ -3506,15 +3772,25 @@ def _legacyify_cache(obj): self.model.forward = _qeff_patched_forward self.model._qeff_export_gemma3_cache_patch = True - if os.environ.get("LAYERWISE_EXPORT", "False") == "True": - return self._export_layerwise( + if self.layerwise: + # The layer-wise split/merge pipeline keys off the per-layer ONNX + # subfunction nodes, so subfunctions are mandatory here. + use_onnx_subfunctions = kwargs.get("use_onnx_subfunctions", False) + if not use_onnx_subfunctions: + logger.warning( + "use_onnx_subfunctions is being set to True because layerwise=True; " + "the layer-wise export pipeline requires ONNX subfunctions." + ) + use_onnx_subfunctions = True + return self._run_layerwise_export( example_inputs, output_names=output_names, dynamic_axes=dynamic_axes, export_dir=export_dir, - use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), + use_onnx_subfunctions=use_onnx_subfunctions, offload_pt_weights=kwargs.get("offload_pt_weights", True), prefill_only=prefill_only, + layerwise_window_size=layerwise_window_size or self._layerwise_window_size, ) else: return self._export( @@ -3527,6 +3803,86 @@ def _legacyify_cache(obj): prefill_only=prefill_only, ) + def _run_layerwise_export( + self, + example_inputs, + output_names, + dynamic_axes, + export_dir=None, + use_onnx_subfunctions=False, + offload_pt_weights=True, + prefill_only=False, + layerwise_window_size: int = 1, + ) -> str: + """Run the layer-wise export loop and return the merged ONNX path. + + For each ``(start, end)`` window: set window state on the text-model + class, stream that window's weights into the meta model via the + CustomLoader, then export the window. After all windows, run the + split -> add-prefix -> merge pipeline to produce one final graph + equivalent to a full-model export. + """ + if self._custom_loader is None: + raise RuntimeError( + "Layer-wise export requested but no CustomLoader was set up. " + "Load the model with `from_pretrained(..., layerwise=True)`." + ) + + total_layers = self.num_layers + if total_layers <= 1: + raise ValueError(f"Layer-wise export needs more than one decoder layer, got total_layers={total_layers}.") + windows = build_layer_windows(total_layers, layerwise_window_size) + qaic_config = getattr(self.model, "qaic_config", None) + + first_window_onnx = None + for start, end in windows: + # 1) Stream this window's real weights via HF (handles checkpoint -> + # module weight conversion such as fused-MoE experts). + window_model = self._custom_loader.load_window_model(start, end) + window_model.config.use_cache = True + self.model = window_model + self.config = window_model.config + + # 2) Re-apply the QEfficient pytorch transforms on the real weights. + self.apply_pytorch_transforms() + self.model.qaic_config = qaic_config + self.model, _ = SpDTransform.apply(self.model, qaic_config) + self.model, _ = SamplerTransform.apply(self.model, qaic_config) + + # 3) Set the layer-window state on the (post-transform) text model. + text_model, _ = resolve_text_model(self.model) + set_window_state(text_model, start, end, total_layers, qeff_wrapper=self) + + # 4) Export just this window. + window_onnx = self._export_layerwise( + dict(example_inputs), + output_names=list(output_names), + dynamic_axes=dict(dynamic_axes), + export_dir=export_dir, + use_onnx_subfunctions=use_onnx_subfunctions, + offload_pt_weights=False, + prefill_only=prefill_only, + ) + if first_window_onnx is None: + first_window_onnx = Path(window_onnx) + + if first_window_onnx is None: + raise RuntimeError("No layer windows were exported during layer-wise export.") + + # Resolve the export root that contains `onnx_layerwise_tmp/`. + parts = list(first_window_onnx.parts) + if "onnx_layerwise_tmp" in parts: + export_root = Path(*parts[: parts.index("onnx_layerwise_tmp")]) + else: + export_root = first_window_onnx.parent + + final_onnx_path = QEfficient.utils.layerwise_pipeline(str(export_root), num_layers=total_layers) + + # Restore full-model window state so any later full-graph operations behave normally. + reset_window_state(text_model, total_layers, qeff_wrapper=self) + self.onnx_path = final_onnx_path + return final_onnx_path + def build_prefill_specialization( self, prefill_seq_len: int = 32, @@ -3678,6 +4034,8 @@ def compile( enable_chunking: Optional[bool] = False, moe_prefill_packed_chunk_size: int = constants.MOE_PREFILL_PACKED_CHUNK_SIZE, retain_full_kv: Optional[bool] = None, + layerwise_window_size: int = 1, + total_layers: Optional[int] = None, **compiler_options, ) -> str: """ @@ -3771,6 +4129,38 @@ def compile( cache_compressed = mla_absorption.get("cache_compressed", False) else: cache_compressed = False + # Stash window size so the deferred export() loop can read it; the loop + # itself lives in export(), compile() only forwards the size. + self._layerwise_window_size = layerwise_window_size + # Optional override for the number of decoder layers exported layer-wise. + # Must be > 1 when provided. Lets users export a reduced layer count + # (e.g. for validation) without editing the model config. + if total_layers is not None: + if not isinstance(total_layers, int) or total_layers <= 1: + raise ValueError(f"`total_layers` must be an integer > 1, got {total_layers}.") + self._layerwise_total_layers = total_layers + # Apply the reduced-layer override up front so the deferred export() + # builds example inputs / KV IO for exactly this many layers, and the + # per-window loader materializes a model with this many layers. + if self.layerwise: + self.num_layers = total_layers + setattr(self.model.config, "num_hidden_layers", total_layers) + text_cfg = getattr(self.model.config, "text_config", None) + if text_cfg is not None: + setattr(text_cfg, "num_hidden_layers", total_layers) + if self._custom_loader is not None: + self._custom_loader.total_layers = total_layers + cfg = self._custom_loader.from_pretrained_kwargs.get("config", None) + if cfg is None: + # Ensure the per-window reload uses the overridden layer + # count by pinning a config on the loader. + cfg = self.model.config + self._custom_loader.from_pretrained_kwargs["config"] = cfg + if cfg is not None: + setattr(cfg, "num_hidden_layers", total_layers) + cfg_text = getattr(cfg, "text_config", None) + if cfg_text is not None: + setattr(cfg_text, "num_hidden_layers", total_layers) if ( self.model.qaic_config is not None and self.model.qaic_config.get("mla_absorption", None) is not None diff --git a/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py b/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py index ffb87e31eb..bf72dd03d5 100644 --- a/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/QEfficient/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -94,12 +94,14 @@ def from_legacy_cache( cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor, ...], ...]] = None, + start_layer: int = 0, ) -> "QEffQwen3_5DynamicCache": cache = cls(config) if past_key_values is None: return cache - for layer_idx, layer_state in enumerate(past_key_values): + for offset, layer_state in enumerate(past_key_values): + layer_idx = start_layer + offset if cache.layer_types[layer_idx] == "full_attention": key_states, value_states = layer_state layer = QEffDynamicLayer() @@ -944,15 +946,20 @@ def forward( if past_key_values is not None and not isinstance(past_key_values, QEffQwen3_5DynamicCache): return_legacy_cache = True - past_key_values = QEffQwen3_5DynamicCache.from_legacy_cache(self.config, past_key_values) + past_key_values = QEffQwen3_5DynamicCache.from_legacy_cache(self.config, past_key_values, start_layer=getattr(self, "_start", 0)) elif use_cache and past_key_values is None: past_key_values = QEffQwen3_5DynamicCache(self.config) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + start = getattr(self, "_start", 0) + end = getattr(self, "_end", 0) + if end == 0: + end = self.config.num_hidden_layers + if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_seq_length(layer_idx=start) if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -970,7 +977,9 @@ def forward( position_embeddings = self.rotary_emb(hidden_states, position_ids[1:]) # position_embeddings = None all_hidden_states = () if output_hidden_states else None - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for layer_idx, decoder_layer in enumerate(self.layers): + if layer_idx < start or layer_idx >= end: + continue if output_hidden_states: all_hidden_states += (hidden_states,) @@ -988,15 +997,17 @@ def forward( **kwargs, ) - # break - - hidden_states = self.norm(hidden_states) + if end == getattr(self, "_total_layers", self.config.num_hidden_layers): + hidden_states = self.norm(hidden_states) if output_hidden_states: all_hidden_states += (hidden_states,) if return_legacy_cache: past_key_values = past_key_values.to_legacy_cache() + if return_legacy_cache: + past_key_values = past_key_values[start] + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, @@ -1420,23 +1431,30 @@ def get_submodules_for_export(self) -> Type[nn.Module]: def forward( self, - input_ids, - vision_embeds, - position_ids, - image_idx, - past_key_values, + input_ids=None, + inputs_embeds=None, + vision_embeds=None, + position_ids=None, + image_idx=None, + past_key_values=None, batch_index: Optional[torch.LongTensor] = None, comp_ctx_lengths: Optional[List[int]] = None, ): - inputs_embeds = self.model.model.get_input_embeddings()(input_ids) - _, _, channel_size = inputs_embeds.shape - selected = input_ids == self.model.config.image_token_id - indices1 = selected.to(torch.int64).cumsum(1) - 1 - indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) - indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) - image_features_expanded = vision_embeds.reshape(-1, channel_size).unsqueeze(0)[indices0, indices1] - image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) - inputs_embeds = image_input_embeds + if inputs_embeds is None: + inputs_embeds = self.model.model.get_input_embeddings()(input_ids) + else: + inputs_embeds = inputs_embeds + + if getattr(self.language_model, "_start", 0) == 0: + _, _, channel_size = inputs_embeds.shape + selected = input_ids == self.model.config.image_token_id + indices1 = selected.to(torch.int64).cumsum(1) - 1 + indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) + indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) + image_features_expanded = vision_embeds.reshape(-1, channel_size).unsqueeze(0)[indices0, indices1] + image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) + inputs_embeds = image_input_embeds + outputs = self.language_model( inputs_embeds=inputs_embeds, position_ids=position_ids, @@ -1445,11 +1463,20 @@ def forward( batch_index=batch_index, use_cache=True, ) - logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True) - hidden_states = outputs.last_hidden_state[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index] - logits = self.model.lm_head(hidden_states) - image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) - return logits, vision_embeds, image_idx, outputs.past_key_values[: len(past_key_values)] + + total_layers = getattr(self.language_model, "_total_layers", len(self.language_model.layers)) + if getattr(self.language_model, "_end", total_layers) == total_layers: + logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs.last_hidden_state[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index] + logits = self.model.lm_head(hidden_states) + return logits, outputs.past_key_values + else: + logits = outputs.last_hidden_state + if getattr(self.language_model, "_start", 0) == 0: + image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) + return logits, vision_embeds, image_idx, outputs.past_key_values + return logits, outputs.past_key_values + class QEffQwen3_5ForConditionalGeneration(Qwen3_5ForConditionalGeneration): diff --git a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index c564b644a6..b64edcaaba 100644 --- a/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/QEfficient/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -104,13 +104,14 @@ def from_legacy_cache( cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor, ...], ...]] = None, + start_layer: int = 0, ) -> "QEffQwen3_5MoeDynamicCache": cache = cls(config) if past_key_values is None: return cache # for layer_idx, layer_state in enumerate(past_key_values): - layer_idx = QEffQwen3_5MoeTextModel._start + layer_idx = start_layer if cache.layer_types[layer_idx] == "full_attention": key_states, value_states = past_key_values[0] layer = QEffDynamicLayer() @@ -984,15 +985,15 @@ def forward( if past_key_values is not None and not isinstance(past_key_values, QEffQwen3_5MoeDynamicCache): return_legacy_cache = True - past_key_values = QEffQwen3_5MoeDynamicCache.from_legacy_cache(self.config, past_key_values) + past_key_values = QEffQwen3_5MoeDynamicCache.from_legacy_cache(self.config, past_key_values, start_layer=getattr(self, "_start", 0)) elif use_cache and past_key_values is None: past_key_values = QEffQwen3_5MoeDynamicCache(self.config) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - start = QEffQwen3_5MoeTextModel._start - end = QEffQwen3_5MoeTextModel._end + start = getattr(self, "_start", 0) + end = getattr(self, "_end", 0) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length(layer_idx=start) if past_key_values is not None else 0 cache_position = torch.arange( @@ -1038,7 +1039,7 @@ def forward( # break - if QEffQwen3_5MoeTextModel._end == QEffQwen3_5MoeTextModel._total_layers: + if end == getattr(self, "_total_layers", len(self.layers)): hidden_states = self.norm(hidden_states) if output_hidden_states: all_hidden_states += (hidden_states,) @@ -1046,7 +1047,7 @@ def forward( if return_legacy_cache: past_key_values = past_key_values.to_legacy_cache() - past_key_values = past_key_values[QEffQwen3_5MoeTextModel._start] + past_key_values = past_key_values[start] return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, @@ -1479,7 +1480,7 @@ def forward( inputs_embeds = self.model.model.get_input_embeddings()(input_ids) else: inputs_embeds = inputs_embeds - if QEffQwen3_5MoeTextModel._start == 0: + if getattr(self.language_model, "_start", 0) == 0: B, S, _ = inputs_embeds.shape input_ids = torch.zeros((B, S), dtype=torch.int64, device=inputs_embeds.device) _, _, channel_size = inputs_embeds.shape @@ -1507,7 +1508,7 @@ def forward( image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) return logits, vision_embeds, image_idx, outputs.past_key_values - elif QEffQwen3_5MoeTextModel._end == QEffQwen3_5MoeTextModel._total_layers: + elif getattr(self.language_model, "_end", 0) == getattr(self.language_model, "_total_layers", len(self.language_model.layers)): outputs = self.language_model( inputs_embeds=inputs_embeds, position_ids=position_ids, @@ -1866,7 +1867,7 @@ def get_dummy_inputs( lang_inputs["past_key_values"] = [[] for _ in range(self.model.config.text_config.num_hidden_layers)] # for i in range(self.model.config.text_config.num_hidden_layers): - i = QEffQwen3_5MoeModel._start + i = getattr(self.model.language_model, "_start", 0) if self.model.config.text_config.layer_types[i] == "full_attention": for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 7794e752ef..5dfc32bdd5 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -326,7 +326,7 @@ def forward( past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0 blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) - self.layer_idx = self.layer_idx - getattr(QEffQwen3MoeModel, "_start", 0) + self.layer_idx = self.layer_idx - getattr(self, "_start", 0) use_blocking = blocking_config is not None and (blocking_config.mode != BlockingMode.NONE) if use_blocking: attn_output, attn_weights = generic_blocked_attention_interface( @@ -467,13 +467,13 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - start = QEffQwen3MoeModel._start - end = QEffQwen3MoeModel._end + start = getattr(self, "_start", 0) + end = getattr(self, "_end", 0) - if QEffQwen3MoeModel._end == 0: + if end == 0: total_layers = end = self.config.num_hidden_layers - QEffQwen3MoeModel._end = total_layers - QEffQwen3MoeModel._total_layers = total_layers + self._end = total_layers + self._total_layers = total_layers past_key_values_length = 0 if past_key_values is not None: @@ -514,8 +514,8 @@ def forward( cos_cached=cos, ) - total_layers = getattr(QEffQwen3MoeModel, "_total_layers", len(self.layers)) - if QEffQwen3MoeModel._end == total_layers: + total_layers = getattr(self, "_total_layers", len(self.layers)) + if end == total_layers: hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer @@ -576,8 +576,8 @@ def forward( ) hidden_states = outputs.last_hidden_state - total_layers = getattr(QEffQwen3MoeModel, "_total_layers", len(self.model.layers)) - if QEffQwen3MoeModel._end < total_layers: + total_layers = getattr(self.model, "_total_layers", len(self.model.layers)) + if getattr(self.model, "_end", 0) < total_layers: logits = hidden_states else: logit_idx = position_ids.to(torch.int32).argmax(1, keepdim=True) diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 4a6259bf8d..ac3e695a40 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -386,7 +386,7 @@ def forward( key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) - self.layer_idx = self.layer_idx - getattr(QEffQwen3VLMoeTextModel, "_start", 0) + self.layer_idx = self.layer_idx - getattr(self, "_start", 0) past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0 blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) use_blocking = blocking_config is not None and (blocking_config.mode != BlockingMode.NONE) @@ -576,8 +576,8 @@ def forward( all_self_attns = () if output_attentions else None layer_idx = 0 - start = QEffQwen3VLMoeTextModel._start - end = QEffQwen3VLMoeTextModel._end + start = getattr(self, "_start", 0) + end = getattr(self, "_end", 0) layer_indices_to_run = kwargs.get("layer_indices_to_run", None) for layer_idx, decoder_layer in enumerate(self.layers): @@ -617,7 +617,7 @@ def forward( ) layer_idx += 1 - if QEffQwen3VLMoeTextModel._end == QEffQwen3VLMoeTextModel._total_layers: + if end == getattr(self, "_total_layers", len(self.layers)): hidden_states = self.norm(hidden_states) if output_hidden_states: all_hidden_states += (hidden_states,) @@ -814,7 +814,7 @@ def forward( else: inputs_embeds = inputs_embeds - if QEffQwen3VLMoeTextModel._start == 0: + if getattr(self.language_model, "_start", 0) == 0: B, N, C = inputs_embeds.shape selected = input_ids == self.model.config.image_token_id indices1 = selected.to(torch.int64).cumsum(1) - 1 @@ -857,7 +857,7 @@ def forward( image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) return logits, vision_embeds, deepstack_features, image_idx, outputs.past_key_values - elif QEffQwen3VLMoeTextModel._end == QEffQwen3VLMoeTextModel._total_layers: + elif getattr(self.language_model, "_end", 0) == getattr(self.language_model, "_total_layers", None): outputs = self.language_model( inputs_embeds=inputs_embeds, position_ids=position_ids, diff --git a/QEfficient/utils/custom_loader.py b/QEfficient/utils/custom_loader.py new file mode 100644 index 0000000000..6b4b66a727 --- /dev/null +++ b/QEfficient/utils/custom_loader.py @@ -0,0 +1,146 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +"""Generic per-window model loader for layer-wise ONNX export. + +:class:`CustomLoader` materializes a *real* PyTorch model that contains the +weights for only a requested window of decoder layers. It delegates the actual +checkpoint -> module weight conversion to HuggingFace ``from_pretrained`` (so +model-specific restructuring such as fused-MoE experts is handled correctly), +while transparently restricting the sharded checkpoint to the window's layers so +that arbitrarily large models can be exported one window at a time without ever +materializing the full set of weights. +""" + +import contextlib +import functools +from typing import Optional, Sequence, Tuple, Union + +import torch + +from QEfficient.utils.logging_utils import logger + + +class CustomLoader: + """Load a window of decoder layers as a real PyTorch model. + + Parameters + ---------- + hf_auto_class : type + The HuggingFace auto class used to load the model (e.g. + ``AutoModelForCausalLM``). + pretrained_model_name_or_path : str + HuggingFace hub id or local path to the model directory. + layer_prefix : str + State-dict key prefix(es) used for the repeated decoder layers, e.g. + ``"model.layers."`` (or a sequence such as + ``("model.layers.", "model.language_model.layers.")`` for multimodal + models). Keys matching ``f"{prefix}{i}."`` belong to decoder layer ``i``; + all other keys are always loaded (vision encoder, projector, embeddings, + final norm, lm_head, ...). + total_layers : int + Total number of decoder layers in the model. + from_pretrained_kwargs : dict, optional + Keyword arguments forwarded to ``hf_auto_class.from_pretrained`` (e.g. + ``torch_dtype``, ``attn_implementation``, ``config``). + """ + + def __init__( + self, + hf_auto_class, + pretrained_model_name_or_path: str, + layer_prefix: Union[str, Sequence[str]], + total_layers: int, + from_pretrained_kwargs: Optional[dict] = None, + ) -> None: + self.hf_auto_class = hf_auto_class + self.pretrained_model_name_or_path = pretrained_model_name_or_path + self.layer_prefixes: Tuple[str, ...] = (layer_prefix,) if isinstance(layer_prefix, str) else tuple(layer_prefix) + self.total_layers = int(total_layers) + self.from_pretrained_kwargs = dict(from_pretrained_kwargs or {}) + + # ------------------------------------------------------------------ + # Window weight-map filtering + # ------------------------------------------------------------------ + def _selected_layer_prefixes(self, start: int, end: int) -> Tuple[str, ...]: + return tuple(f"{prefix}{i}." for prefix in self.layer_prefixes for i in range(start, end)) + + @contextlib.contextmanager + def _shard_filter(self, start: int, end: int): + """Restrict sharded checkpoints to the window's layers during load. + + Patches ``transformers.modeling_utils.get_checkpoint_shard_files`` so + only the shards containing the window's decoder-layer weights (plus all + non-layer / edge weights) are returned. This is a no-op for single-file + (non-sharded) checkpoints, which are loaded in full. + """ + import transformers + + original = transformers.modeling_utils.get_checkpoint_shard_files + selected_prefixes = self._selected_layer_prefixes(start, end) + layer_prefixes = self.layer_prefixes + + @functools.wraps(original) + def patched(*args, **kwargs): + shard_files, metadata = original(*args, **kwargs) + weight_map = metadata.get("weight_map") if isinstance(metadata, dict) else None + if not weight_map: + return shard_files, metadata + + filtered_weight_map = {} + for checkpoint_key, shard_name in weight_map.items(): + if checkpoint_key.startswith(layer_prefixes): + if checkpoint_key.startswith(selected_prefixes): + filtered_weight_map[checkpoint_key] = shard_name + continue + # Non-layer / edge weight (embeddings, final norm, lm_head, ...). + filtered_weight_map[checkpoint_key] = shard_name + + if not filtered_weight_map: + return shard_files, metadata + + shard_name_to_path = {path.split("/")[-1]: path for path in shard_files} + filtered_shard_names = sorted(set(filtered_weight_map.values())) + filtered_shard_files = [ + shard_name_to_path[name] for name in filtered_shard_names if name in shard_name_to_path + ] + if not filtered_shard_files: + return shard_files, metadata + + metadata = dict(metadata) + metadata["weight_map"] = filtered_weight_map + metadata["all_checkpoint_keys"] = list(filtered_weight_map.keys()) + return filtered_shard_files, metadata + + transformers.modeling_utils.get_checkpoint_shard_files = patched + try: + yield + finally: + transformers.modeling_utils.get_checkpoint_shard_files = original + + # ------------------------------------------------------------------ + # Loading + # ------------------------------------------------------------------ + def load_window_model(self, start: int, end: int) -> torch.nn.Module: + """Load and return a real PyTorch model for layers ``[start, end)``. + + Uses HuggingFace ``from_pretrained`` (which applies any checkpoint -> + module weight conversion) while restricting sharded checkpoints to the + window's layers. Decoder layers outside the window remain + un-materialized (left on ``meta`` by the loader) and are skipped by the + model ``forward`` via the ``_start/_end`` window contract. + """ + if end <= start: + raise ValueError(f"Invalid window: start={start}, end={end}") + + with self._shard_filter(start, end): + model = self.hf_auto_class.from_pretrained( + self.pretrained_model_name_or_path, + **self.from_pretrained_kwargs, + ) + logger.info(f"Loaded model weights for layer window [{start}, {end})") + return model diff --git a/QEfficient/utils/layerwise_utils.py b/QEfficient/utils/layerwise_utils.py new file mode 100644 index 0000000000..692b2a8693 --- /dev/null +++ b/QEfficient/utils/layerwise_utils.py @@ -0,0 +1,139 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +"""Helpers for the generic layer-wise ONNX export flow. + +These utilities keep model-specific resolution logic (which submodule holds the +repeated decoder layers, what the checkpoint key prefix is, how window state is +threaded into the forward) out of the model wrapper classes. +""" + +from typing import List, Tuple + +import torch.nn as nn + + +def build_layer_windows(total_layers: int, window_size: int) -> List[Tuple[int, int]]: + """Tile ``[0, total_layers)`` into descending ``(start, end)`` windows. + + Matches the tiling used by the original layer-wise example scripts: windows + are produced from the last layer towards layer 0. + """ + if total_layers <= 0: + raise ValueError(f"Invalid total_layers={total_layers}. Expected total_layers > 0.") + if window_size <= 0: + raise ValueError(f"Invalid window_size={window_size}. Expected window_size > 0.") + + windows: List[Tuple[int, int]] = [] + end = total_layers + while end > 0: + start = max(0, end - window_size) + windows.append((start, end)) + end = start + return windows + + +def _find_layers_container(module: nn.Module): + """Return the submodule that owns ``.layers`` (the repeated decoder stack). + + Tries the common nesting patterns used across CausalLM and VLM language + models. Returns ``(owner_module, layer_prefix)`` where ``layer_prefix`` is + the state-dict key prefix for the repeated layers, or ``(None, None)``. + """ + candidates = [ + ("model.language_model", getattr(getattr(module, "model", None), "language_model", None)), + ("language_model", getattr(module, "language_model", None)), + ("model", getattr(module, "model", None)), + ("", module), + ] + for prefix, owner in candidates: + if owner is not None and hasattr(owner, "layers") and isinstance(owner.layers, nn.ModuleList): + layer_prefix = f"{prefix}.layers." if prefix else "layers." + return owner, layer_prefix + return None, None + + +def resolve_text_model(model: nn.Module): + """Resolve the text/decoder model that carries the repeated layers. + + Returns ``(text_model_module, layer_prefix)``. ``text_model_module`` is the + module whose class holds the ``_start/_end/_total_layers`` window contract + (i.e. the module that owns ``.layers``). + """ + owner, layer_prefix = _find_layers_container(model) + if owner is None: + raise RuntimeError( + "Could not locate the repeated decoder-layer container (`.layers`) on the model; " + "layer-wise export is not supported for this architecture." + ) + return owner, layer_prefix + + +def set_window_state(text_model: nn.Module, start: int, end: int, total_layers: int, qeff_wrapper=None) -> None: + """Set the layer-window state as instance attributes on the text-model and wrapper. + + The model ``forward`` reads ``_start/_end/_total_layers`` as class + attributes (inert defaults). This helper sets them as instance attributes + on ``text_model`` (and its child attention modules) so that no global class + state is mutated. + + When ``qeff_wrapper`` is provided (the :class:`QEFFBaseModel` subclass + instance driving the export), the window is also mirrored onto it so that + ``_export_layerwise`` can read ``self._start / self._end``. + """ + text_model._start = int(start) + text_model._end = int(end) + text_model._total_layers = int(total_layers) + + # Propagate _start to child attention modules that need the layer offset. + if hasattr(text_model, "layers"): + for layer in text_model.layers: + if hasattr(layer, "self_attn"): + layer.self_attn._start = int(start) + + if qeff_wrapper is not None: + qeff_wrapper._start = int(start) + qeff_wrapper._end = int(end) + qeff_wrapper._total_layers = int(total_layers) + + +def reset_window_state(text_model: nn.Module, total_layers: int, qeff_wrapper=None) -> None: + """Reset window state to cover the full model (``[0, total_layers)``).""" + set_window_state(text_model, 0, total_layers, total_layers, qeff_wrapper=qeff_wrapper) + + +def build_meta_model(hf_auto_class, pretrained_model_name_or_path: str, **kwargs): + """Instantiate a model on the ``meta`` device from its config only. + + Used for layer-wise export where weights are streamed per window instead of + loaded up front. ``kwargs`` mirror those passed to ``from_pretrained`` + (e.g. ``torch_dtype``, ``attn_implementation``); only config-relevant ones + are forwarded to ``from_config``. + """ + import torch + from transformers import AutoConfig + + config = kwargs.pop("config", None) + if config is None: + config = AutoConfig.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=kwargs.get("trust_remote_code", False) + ) + + torch_dtype = kwargs.get("torch_dtype", torch.float32) + if torch_dtype is not None: + config.torch_dtype = torch_dtype + config.use_cache = True + + from_config_kwargs = {} + if "attn_implementation" in kwargs: + from_config_kwargs["attn_implementation"] = kwargs["attn_implementation"] + if "trust_remote_code" in kwargs: + from_config_kwargs["trust_remote_code"] = kwargs["trust_remote_code"] + + with torch.device("meta"): + model = hf_auto_class.from_config(config, **from_config_kwargs) + return model diff --git a/examples/disagg_serving/qwen3moe_layerwise.py b/examples/disagg_serving/qwen3moe_layerwise.py index ea29e11747..0acb4073de 100644 --- a/examples/disagg_serving/qwen3moe_layerwise.py +++ b/examples/disagg_serving/qwen3moe_layerwise.py @@ -5,210 +5,44 @@ # # ----------------------------------------------------------------------------- -import functools -import os +"""Layer-wise ONNX export + compile for a large Qwen3-MoE causal LM. + +The layer-wise flow loads only one window of decoder layers at a time, exports +that window, then splits/prefixes/merges all windows into a single ONNX graph +that is equivalent to a full-model export. This keeps peak host memory bounded +by a single window's weights, enabling export of models that do not fit in +memory all at once. + +Usage is identical to a normal QEFFAutoModelForCausalLM export, except: + * pass ``layerwise=True`` to ``from_pretrained`` (the model is built on the + ``meta`` device; weights are streamed per window during export), and + * optionally pass ``layerwise_window_size`` to ``compile`` (defaults to 1). +""" + import time -from pathlib import Path -import transformers from transformers import AutoConfig, AutoTokenizer -import QEfficient from QEfficient import QEFFAutoModelForCausalLM -model_id = "Qwen/Qwen3-235B-A22B-Instruct-2507" # weights are not required to convert to fp32 -# model_id = "yujiepan/qwen3-moe-tiny-random" -prompt = """ -Explain quantum computing in simple terms. -""" +model_id = "yujiepan/qwen3-moe-tiny-random" +# model_id = "Qwen/Qwen3-235B-A22B-Instruct-2507" # weights are not required to convert to fp32 + config = AutoConfig.from_pretrained(model_id) tokenizer = AutoTokenizer.from_pretrained(model_id) -config = AutoConfig.from_pretrained(model_id) -tokenizer = AutoTokenizer.from_pretrained(model_id) PREFILL_SEQ_LEN = 4 CTX_LEN = 128 +# Build the model on `meta`; per-window weights are streamed during export. +qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, config=config, layerwise=True) -def _ensure_pretrained_window_attrs(): - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_start"): - transformers.modeling_utils.PreTrainedModel._start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_end"): - transformers.modeling_utils.PreTrainedModel._end = 0 - - -def _build_layer_windows(total_layers: int, window_size: int): - if total_layers <= 0: - raise ValueError(f"Invalid total_layers={total_layers}. Expected: total_layers > 0.") - if window_size <= 0: - raise ValueError(f"Invalid window_size={window_size}. Expected: window_size > 0.") - - windows = [] - end = total_layers - while end > 0: - start = max(0, end - window_size) - windows.append((start, end)) - end = start - - return windows - - -def _null_outside_window_layers(model): - start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0)) - end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0)) - layers = getattr(getattr(model, "model", None), "layers", None) - if layers is None: - return - for idx, _ in enumerate(layers): - if idx < start or idx >= end: - layers[idx] = None - - -def _install_window_patch(model_cls): - if getattr(model_cls, "_window_patch_installed", False): - return - - original_init = model_cls.__init__ - - @functools.wraps(original_init) - def patched_init(self, *args, **kwargs): - original_init(self, *args, **kwargs) - _null_outside_window_layers(self) - - model_cls.__init__ = patched_init - model_cls._window_patch_installed = True - - -def _resolve_export_root(onnx_path: Path) -> Path: - parts = list(onnx_path.parts) - if "onnx_layerwise_tmp" in parts: - marker_idx = parts.index("onnx_layerwise_tmp") - return Path(*parts[:marker_idx]) - return onnx_path.parent - - -def _install_shard_window_patch(): - if getattr(transformers.modeling_utils, "_window_shard_patch_installed", False): - return - - original_get_checkpoint_shard_files = transformers.modeling_utils.get_checkpoint_shard_files - - @functools.wraps(original_get_checkpoint_shard_files) - def patched_get_checkpoint_shard_files(*args, **kwargs): - shard_files, metadata = original_get_checkpoint_shard_files(*args, **kwargs) - weight_map = metadata.get("weight_map") - if not weight_map: - return shard_files, metadata - - start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0)) - end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0)) - if end <= start: - return shard_files, metadata - - selected_prefixes = tuple(f"model.layers.{layer_idx}." for layer_idx in range(start, end)) - filtered_weight_map = {} - for checkpoint_key, shard_name in weight_map.items(): - if checkpoint_key.startswith("model.layers."): - if checkpoint_key.startswith(selected_prefixes): - filtered_weight_map[checkpoint_key] = shard_name - continue - filtered_weight_map[checkpoint_key] = shard_name - - if not filtered_weight_map: - return shard_files, metadata - - shard_name_to_path = {path.split("/")[-1]: path for path in shard_files} - filtered_shard_names = sorted(set(filtered_weight_map.values())) - filtered_shard_files = [shard_name_to_path[name] for name in filtered_shard_names if name in shard_name_to_path] - if not filtered_shard_files: - return shard_files, metadata - - metadata["weight_map"] = filtered_weight_map - metadata["all_checkpoint_keys"] = list(filtered_weight_map.keys()) - return filtered_shard_files, metadata - - transformers.modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files - transformers.modeling_utils._window_shard_patch_installed = True - - -_ensure_pretrained_window_attrs() -_install_shard_window_patch() -text_config = getattr(config, "text_config", config) -resolved_total_layers = getattr(text_config, "num_hidden_layers", None) -if resolved_total_layers is None: - raise ValueError("Could not resolve `num_hidden_layers` from config.") - -# Layerwise window size. `1` keeps only one decoder layer active per window. -window_size = 1 -total_layers = 2 # resolved_total_layers # config.num_hidden_layers = 1 -windows = _build_layer_windows(total_layers=total_layers, window_size=window_size) -qeff_model = None -first_onnx_path = None export_start = time.perf_counter() -os.environ["LAYERWISE_EXPORT"] = "True" -for start, end in windows: - transformers.modeling_utils.PreTrainedModel._start = start - transformers.modeling_utils.PreTrainedModel._end = end - transformers.modeling_utils.PreTrainedModel._total_layers = total_layers - QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe.QEffQwen3MoeModel._start = start - QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe.QEffQwen3MoeModel._end = end - QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe.QEffQwen3MoeModel._total_layers = total_layers - QEfficient.base.modeling_qeff.QEFFBaseModel._start = start - QEfficient.base.modeling_qeff.QEFFBaseModel._end = end - QEfficient.base.modeling_qeff.QEFFBaseModel._total_layers = total_layers - _install_window_patch(transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeForCausalLM) - qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, config=config) - if hasattr(qeff_model, "model"): - _null_outside_window_layers(qeff_model.model) - - # Following command errors out by default, the user is supposed to run the printed command and provide the generated qpc path as prefill_qpc_path commenting out lines 55-68 - - # prefill_qpc_path = "" - ################################# prefill - - onnx_path = qeff_model.compile( - prefill_seq_len=PREFILL_SEQ_LEN, - ctx_len=CTX_LEN, - num_cores=16, - mxfp6_matmul=True, - mxint8_kv_cache=True, - num_devices=1, - split_retained_state_io=True, - mos=1, - aic_enable_depth_first=True, - num_speculative_tokens=None, - prefill_only=True, - enable_chunking=True, - use_onnx_subfunctions=True, - ) - - ################################# decode - # onnx_path = qeff_model.compile( - # prefill_seq_len=PREFILL_SEQ_LEN, - # ctx_len=CTX_LEN, - # num_cores=16, - # mxfp6_matmul=True, - # mxint8_kv_cache=True, - # num_devices=1, - # split_retained_state_io=True, - # mos=1, - # aic_enable_depth_first=True, - # num_speculative_tokens=None, - # prefill_only=False, - # use_onnx_subfunctions=True, - # ) - if first_onnx_path is None: - first_onnx_path = Path(onnx_path) - -if first_onnx_path is None: - raise RuntimeError("No ONNX path produced during compilation.") -export_root = _resolve_export_root(first_onnx_path) -final_onnx_path = QEfficient.utils.layerwise_pipeline(str(export_root)) -print(f"Layer-wise language export completed. Final artifact/root: {final_onnx_path}") -os.environ["LAYERWISE_EXPORT"] = "False" +# A single compile call drives the full layer-wise loop internally: +# per window -> load weights -> apply transforms -> export window +# then split -> add prefix -> merge into one final ONNX, then compile. qpc_path = qeff_model.compile( - onnx_path=final_onnx_path, prefill_seq_len=PREFILL_SEQ_LEN, ctx_len=CTX_LEN, num_cores=16, @@ -222,97 +56,8 @@ def patched_get_checkpoint_shard_files(*args, **kwargs): prefill_only=True, enable_chunking=True, use_onnx_subfunctions=True, + layerwise_window_size=1, ) +print(f"Layer-wise export + compile completed in {time.perf_counter() - export_start:.2f}s") print(f"QPC path: {qpc_path}") - -# inputs = tokenizer(prompt, return_tensors="np", padding=True) -# position_ids = inputs["attention_mask"].sum(1, keepdims=True) -# generation_len = CTX_LEN - position_ids.max() -# padded_len = inputs["input_ids"].shape[1] -# num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float -# padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len -# inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) -# inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) -# inputs.pop("token_type_ids", None) -# inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} -# inputs.pop("past_key_values", None) -# inputs = {k: v.detach().numpy() for k, v in inputs.items()} - - -# prefill_session = QAICInferenceSession(prefill_qpc_path) - - -# all_outputs = [] -# for i in range(num_chunks): -# chunk_inputs = inputs.copy() -# chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] -# chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] -# ins = time.time() -# qpc_out = prefill_session.run(chunk_inputs) -# print(f"time for this run={time.time() - ins}") -# for i in range(config.num_hidden_layers): -# inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] -# inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] - -# all_outputs.append(np.argmax(qpc_out["logits"])) -# print(all_outputs) -# print(">>>>>>>> export for prefill is done <<<<<<<<<<<") -# ########################### - -# decode_qpc_path = qeff_model.compile( -# prefill_seq_len=1, -# ctx_len=CTX_LEN, -# num_cores=16, -# mxfp6_matmul=True, -# mxint8_kv_cache=True, -# num_devices=1, -# mos=1, -# aic_enable_depth_first=True, -# num_speculative_tokens=None, -# offload_pt_weights=False, # Need the weights in memory for prefill-model export/compilation in the next step -# retain_full_kv=True, -# ) -# decode_session = QAICInferenceSession(decode_qpc_path) - -# decode_inputs = { -# "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1), -# "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, -# } -# for i in range(config.num_hidden_layers): -# decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] -# decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] - -# st = time.time() -# decode_out = decode_session.run(decode_inputs) -# print(f"time for first run of decode with KV as input = {time.time() - st} sec\n") -# all_outputs.append(np.argmax(decode_out["logits"])) -# pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1 -# loop_decode_inputs = { -# "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), -# "position_ids": pos_id, -# } - -# for i in range(config.num_hidden_layers): -# loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] -# loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] - -# st = time.time() -# for i in range(generation_len - 2): -# decode_out = decode_session.run(loop_decode_inputs) -# all_outputs.append(np.argmax(decode_out["logits"])) -# pos_id += 1 -# for i in range(config.num_hidden_layers): -# loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] -# loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] - -# loop_decode_inputs.update( -# { -# "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), -# "position_ids": pos_id, -# } -# ) -# ft = time.time() - -# print(f"decode tok/sec={(generation_len - 2) / (ft - st)}") -# print(f"input\n{prompt}\noutput\n{tokenizer.decode(all_outputs)}") diff --git a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise.py b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise.py index cdf25ee369..a66f332ede 100644 --- a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise.py +++ b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise.py @@ -5,297 +5,54 @@ # # ----------------------------------------------------------------------------- -import functools -import os -from pathlib import Path +"""Layer-wise ONNX export + compile for Qwen3.5-MoE (dual-QPC). + +The language decoder is exported one window of layers at a time and merged into +a single ONNX graph equivalent to a full export; the vision encoder is exported +once. This keeps peak host memory bounded by a single window's language weights. + +Flow: + * from_pretrained(..., layerwise=True, kv_offload=True) builds the VLM on the + `meta` device; language weights are streamed per window during export. + * compile(..., layerwise_window_size=1[, total_layers=N]) drives the loop + internally and compiles the merged language ONNX + vision ONNX. +""" import torch -import transformers from transformers import AutoConfig -import QEfficient from QEfficient import QEFFAutoModelForImageTextToText -MODEL_ID = "Qwen/Qwen3.5-397B-A17B" +MODEL_ID = "Qwen/Qwen3.5-0.8B" +# MODEL_ID = "Qwen/Qwen3.5-397B-A17B" PREFILL_SEQ_LEN = 32 CTX_LEN = 4096 -TEXT_WINDOW_SIZE = 1 - -# For quick local validation only (keep disabled for real export) -# TEST_TEXT_LAYERS = 4 - -# Export controls BATCH_SIZE = 1 NUM_CORES = 16 NUM_DEVICES = 1 HEIGHT = 354 WIDTH = 536 +TEXT_WINDOW_SIZE = 1 +# Optional: export only the first N (>1) text layers for quick validation. +TOTAL_TEXT_LAYERS = 2 +# TOTAL_TEXT_LAYERS = None -def _ensure_pretrained_window_attrs(): - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_start"): - transformers.modeling_utils.PreTrainedModel._start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_end"): - transformers.modeling_utils.PreTrainedModel._end = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_total_layers"): - transformers.modeling_utils.PreTrainedModel._total_layers = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_start"): - transformers.modeling_utils.PreTrainedModel._text_start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_end"): - transformers.modeling_utils.PreTrainedModel._text_end = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_total_layers"): - transformers.modeling_utils.PreTrainedModel._text_total_layers = 0 - - -def _build_layer_windows(total_layers: int, window_size: int): - if total_layers <= 0: - raise ValueError(f"Invalid total_layers={total_layers}. Expected: total_layers > 0.") - if window_size <= 0: - raise ValueError(f"Invalid window_size={window_size}. Expected: window_size > 0.") - - windows = [] - end = total_layers - while end > 0: - start = max(0, end - window_size) - windows.append((start, end)) - end = start - return windows - - -def _get_text_layers_container(model): - # VLM path first - if ( - hasattr(model, "model") - and hasattr(model.model, "language_model") - and hasattr(model.model.language_model, "layers") - ): - return model.model.language_model.layers - # LLM-compatible fallbacks - if hasattr(model, "model") and hasattr(model.model, "layers"): - return model.model.layers - if hasattr(model, "language_model") and hasattr(model.language_model, "layers"): - return model.language_model.layers - if hasattr(model, "layers"): - return model.layers - return None - - -def _null_outside_window_layers(model, apply_text: bool = True): - if apply_text: - text_start = int( - getattr( - transformers.modeling_utils.PreTrainedModel, - "_text_start", - getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0), - ) - ) - text_end = int( - getattr( - transformers.modeling_utils.PreTrainedModel, - "_text_end", - getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0), - ) - ) - text_layers = _get_text_layers_container(model) - if text_layers is not None and text_end > text_start: - for idx, _ in enumerate(text_layers): - if idx < text_start or idx >= text_end: - text_layers[idx] = None - - -def _install_window_patch(model_cls): - if getattr(model_cls, "_window_patch_installed", False): - return - - original_init = model_cls.__init__ - - @functools.wraps(original_init) - def patched_init(self, *args, **kwargs): - original_init(self, *args, **kwargs) - _null_outside_window_layers(self, apply_text=True) - - model_cls.__init__ = patched_init - model_cls._window_patch_installed = True - - -def _resolve_export_root(onnx_path: Path) -> Path: - parts = list(onnx_path.parts) - if "onnx_layerwise_tmp" in parts: - marker_idx = parts.index("onnx_layerwise_tmp") - return Path(*parts[:marker_idx]) - return onnx_path.parent - - -def _install_shard_window_patch(): - if getattr(transformers.modeling_utils, "_window_shard_patch_installed", False): - return - - original_get_checkpoint_shard_files = transformers.modeling_utils.get_checkpoint_shard_files - - @functools.wraps(original_get_checkpoint_shard_files) - def patched_get_checkpoint_shard_files(*args, **kwargs): - shard_files, metadata = original_get_checkpoint_shard_files(*args, **kwargs) - weight_map = metadata.get("weight_map") - if not weight_map: - return shard_files, metadata - - start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0)) - end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0)) - text_start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_text_start", start)) - text_end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_text_end", end)) - has_text_window = text_end > text_start - if not has_text_window: - return shard_files, metadata - - selected_text_prefixes = tuple( - [f"model.layers.{layer_idx}." for layer_idx in range(text_start, text_end)] - + [f"model.language_model.layers.{layer_idx}." for layer_idx in range(text_start, text_end)] - ) - filtered_weight_map = {} - for checkpoint_key, shard_name in weight_map.items(): - if checkpoint_key.startswith("model.layers.") or checkpoint_key.startswith("model.language_model.layers."): - if not has_text_window or checkpoint_key.startswith(selected_text_prefixes): - filtered_weight_map[checkpoint_key] = shard_name - continue - filtered_weight_map[checkpoint_key] = shard_name - - if not filtered_weight_map: - return shard_files, metadata - - shard_name_to_path = {path.split("/")[-1]: path for path in shard_files} - filtered_shard_names = sorted(set(filtered_weight_map.values())) - filtered_shard_files = [shard_name_to_path[name] for name in filtered_shard_names if name in shard_name_to_path] - if not filtered_shard_files: - return shard_files, metadata - - metadata["weight_map"] = filtered_weight_map - metadata["all_checkpoint_keys"] = list(filtered_weight_map.keys()) - return filtered_shard_files, metadata - - transformers.modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files - transformers.modeling_utils._window_shard_patch_installed = True - - -def _set_layer_windows( - text_start: int, - text_end: int, - text_total_layers: int, -): - transformers.modeling_utils.PreTrainedModel._start = text_start - transformers.modeling_utils.PreTrainedModel._end = text_end - transformers.modeling_utils.PreTrainedModel._total_layers = text_total_layers - transformers.modeling_utils.PreTrainedModel._text_start = text_start - transformers.modeling_utils.PreTrainedModel._text_end = text_end - transformers.modeling_utils.PreTrainedModel._text_total_layers = text_total_layers - - qeff_mod = QEfficient.transformers.models.qwen3_5_moe.modeling_qwen3_5_moe - qeff_mod.QEffQwen3_5MoeTextModel._start = text_start - qeff_mod.QEffQwen3_5MoeTextModel._end = text_end - qeff_mod.QEffQwen3_5MoeTextModel._total_layers = text_total_layers - - QEfficient.base.modeling_qeff.QEFFBaseModel._start = text_start - QEfficient.base.modeling_qeff.QEFFBaseModel._end = text_end - QEfficient.base.modeling_qeff.QEFFBaseModel._total_layers = text_total_layers - - -def _stitch_layerwise_if_available(export_root: Path): - # Some branches expose this helper; fall back gracefully when unavailable. - pipeline_fn = getattr(QEfficient.utils, "layerwise_pipeline", None) - if callable(pipeline_fn): - return pipeline_fn(str(export_root)) - print(f"layerwise_pipeline() not found. Layer-wise ONNX shards kept under: {export_root / 'onnx_layerwise_tmp'}") - return str(export_root / "onnx_layerwise_tmp") +def main(): + config = AutoConfig.from_pretrained(MODEL_ID) + config.torch_dtype = "float32" -def _new_qeff_model(model_id: str, config): - return QEFFAutoModelForImageTextToText.from_pretrained( - model_id, + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + MODEL_ID, attn_implementation="eager", kv_offload=True, config=config, torch_dtype=torch.float32, + layerwise=True, ) - -def main(): - config = AutoConfig.from_pretrained(MODEL_ID) - config.torch_dtype = "float32" - - # if TEST_TEXT_LAYERS: - # config.text_config.num_hidden_layers = TEST_TEXT_LAYERS - - text_config = getattr(config, "text_config", config) - # config.vision_config.depth = 3 - text_total_layers = getattr(text_config, "num_hidden_layers", None) - if text_total_layers is None: - raise ValueError("Could not resolve `num_hidden_layers` from config.text_config.") - config.text_config.num_hidden_layers = text_total_layers - _ensure_pretrained_window_attrs() - _install_shard_window_patch() - - hf_qwen_mod = transformers.models.qwen3_5_moe.modeling_qwen3_5_moe - _install_window_patch(hf_qwen_mod.Qwen3_5MoeForConditionalGeneration) - _install_window_patch(hf_qwen_mod.Qwen3_5MoeForCausalLM) - - text_windows = _build_layer_windows(total_layers=text_total_layers, window_size=TEXT_WINDOW_SIZE) - # Keep layerwise only on text path in this loop. - num_windows = len(text_windows) - first_onnx_path = None - os.environ["LAYERWISE_EXPORT"] = "True" - for window_idx in range(num_windows): - text_start, text_end = text_windows[window_idx] if window_idx < len(text_windows) else (0, 0) - skip_lang_for_window = window_idx >= len(text_windows) - - _set_layer_windows( - text_start=text_start, - text_end=text_end, - text_total_layers=text_total_layers, - ) - print( - f"Exporting window {window_idx + 1}/{num_windows} " - f"text=[{text_start},{text_end})/{text_total_layers} " - f"skip_lang={skip_lang_for_window}" - ) - - qeff_model = _new_qeff_model(MODEL_ID, config) - if hasattr(qeff_model, "model"): - _null_outside_window_layers( - qeff_model.model, - apply_text=not skip_lang_for_window, - ) - - onnx_path = qeff_model.compile( - batch_size=BATCH_SIZE, - prefill_seq_len=PREFILL_SEQ_LEN, - ctx_len=CTX_LEN, - num_cores=NUM_CORES, - num_devices=NUM_DEVICES, - height=HEIGHT, - width=WIDTH, - mxfp6_matmul=False, - aic_enable_depth_first=True, - skip_vision=True, - skip_lang=skip_lang_for_window, - prefill_only=True, - use_onnx_subfunctions=True, - enable_chunking=True, - mos=1, - user_tiled=True, - ) - - if first_onnx_path is None: - first_onnx_path = Path(str(onnx_path["lang_prefill_qpc_path"])) - - if first_onnx_path is None: - raise RuntimeError("No ONNX path produced during layer-wise language export.") - - export_root = _resolve_export_root(first_onnx_path) - final_artifact = _stitch_layerwise_if_available(export_root) - print(f"Layer-wise language export completed. Final artifact/root: {final_artifact}") - - os.environ["LAYERWISE_EXPORT"] = "False" - qpc_path = qeff_model.compile( - lang_onnx_path=final_artifact, + compile_kwargs = dict( batch_size=BATCH_SIZE, prefill_seq_len=PREFILL_SEQ_LEN, ctx_len=CTX_LEN, @@ -305,19 +62,18 @@ def main(): width=WIDTH, mxfp6_matmul=False, aic_enable_depth_first=True, - skip_vision=True, - skip_lang=skip_lang_for_window, prefill_only=True, - use_onnx_subfunctions=True, enable_chunking=True, + use_onnx_subfunctions=True, mos=1, + layerwise_window_size=TEXT_WINDOW_SIZE, ) + if TOTAL_TEXT_LAYERS is not None: + compile_kwargs["total_layers"] = TOTAL_TEXT_LAYERS - print(f"Final QPC path: {qpc_path}") + qpc_path = qeff_model.compile(**compile_kwargs) + print(f"Layer-wise export + compile completed. QPC paths: {qpc_path}") if __name__ == "__main__": main() - - -# /opt/qti-aic/exec/qaic-compile -aic-hw -aic-hw-version=ai100 -m=/home/abhishek/.cache/qeff_models/Qwen3_5MoeForConditionalGeneration/Qwen3_5MoeDecoderWrapper-61a4400d63d1b0bb/final_data/merged_0-2.onnx -retained-state -convert-to-fp16 -aic-num-cores=16 -aic-enable-depth-first -mos=1 -network-specialization-config=/home/abhishek/.cache/qeff_models/Qwen3_5MoeForConditionalGeneration/Qwen3_5MoeDecoderWrapper-61a4400d63d1b0bb/final_data/specializations.json -custom-IO-list-file=/home/abhishek/.cache/qeff_models/Qwen3_5MoeForConditionalGeneration/Qwen3_5MoeDecoderWrapper-61a4400d63d1b0bb/final_data/qpc_binaries/custom_io.yaml -aic-binary-dir=/home/abhishek/.cache/qeff_models/Qwen3_5MoeForConditionalGeneration/Qwen3_5MoeDecoderWrapper-61a4400d63d1b0bb/final_data/qpc_binaries/qpc diff --git a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise_decode.py b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise_decode.py index a5b7475f73..db056010c8 100644 --- a/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise_decode.py +++ b/examples/image_text_to_text/models/qwen3_5_moe/qwen3_5_moe_layerwise_decode.py @@ -5,296 +5,48 @@ # # ----------------------------------------------------------------------------- -import functools -import os -from pathlib import Path +"""Layer-wise ONNX export + compile (decode) for qwen3_5_moe (dual-QPC). + +Same layer-wise language flow as the prefill script, but compiles the decode +specialization. The language decoder is exported one window of layers at a time +and merged into a single ONNX graph; the vision encoder is exported once. +""" import torch -import transformers from transformers import AutoConfig -import QEfficient from QEfficient import QEFFAutoModelForImageTextToText -MODEL_ID = "Qwen/Qwen3.5-397B-A17B" -PREFILL_SEQ_LEN = 1 +MODEL_ID = "Qwen/Qwen3.5-0.8B" +# MODEL_ID = "Qwen/Qwen3.5-397B-A17B" +PREFILL_SEQ_LEN = 32 CTX_LEN = 4096 -TEXT_WINDOW_SIZE = 1 - -# For quick local validation only (keep disabled for real export) -# TEST_TEXT_LAYERS = 4 - -# Export controls BATCH_SIZE = 1 NUM_CORES = 16 NUM_DEVICES = 1 HEIGHT = 354 WIDTH = 536 +TEXT_WINDOW_SIZE = 1 +# Optional: export only the first N (>1) text layers for quick validation. +TOTAL_TEXT_LAYERS = 2 +# TOTAL_TEXT_LAYERS = None -def _ensure_pretrained_window_attrs(): - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_start"): - transformers.modeling_utils.PreTrainedModel._start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_end"): - transformers.modeling_utils.PreTrainedModel._end = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_total_layers"): - transformers.modeling_utils.PreTrainedModel._total_layers = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_start"): - transformers.modeling_utils.PreTrainedModel._text_start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_end"): - transformers.modeling_utils.PreTrainedModel._text_end = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_total_layers"): - transformers.modeling_utils.PreTrainedModel._text_total_layers = 0 - - -def _build_layer_windows(total_layers: int, window_size: int): - if total_layers <= 0: - raise ValueError(f"Invalid total_layers={total_layers}. Expected: total_layers > 0.") - if window_size <= 0: - raise ValueError(f"Invalid window_size={window_size}. Expected: window_size > 0.") - - windows = [] - end = total_layers - while end > 0: - start = max(0, end - window_size) - windows.append((start, end)) - end = start - return windows - - -def _get_text_layers_container(model): - # VLM path first - if ( - hasattr(model, "model") - and hasattr(model.model, "language_model") - and hasattr(model.model.language_model, "layers") - ): - return model.model.language_model.layers - # LLM-compatible fallbacks - if hasattr(model, "model") and hasattr(model.model, "layers"): - return model.model.layers - if hasattr(model, "language_model") and hasattr(model.language_model, "layers"): - return model.language_model.layers - if hasattr(model, "layers"): - return model.layers - return None - - -def _null_outside_window_layers(model, apply_text: bool = True): - if apply_text: - text_start = int( - getattr( - transformers.modeling_utils.PreTrainedModel, - "_text_start", - getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0), - ) - ) - text_end = int( - getattr( - transformers.modeling_utils.PreTrainedModel, - "_text_end", - getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0), - ) - ) - text_layers = _get_text_layers_container(model) - if text_layers is not None and text_end > text_start: - for idx, _ in enumerate(text_layers): - if idx < text_start or idx >= text_end: - text_layers[idx] = None - - -def _install_window_patch(model_cls): - if getattr(model_cls, "_window_patch_installed", False): - return - - original_init = model_cls.__init__ - - @functools.wraps(original_init) - def patched_init(self, *args, **kwargs): - original_init(self, *args, **kwargs) - _null_outside_window_layers(self, apply_text=True) - - model_cls.__init__ = patched_init - model_cls._window_patch_installed = True - - -def _resolve_export_root(onnx_path: Path) -> Path: - parts = list(onnx_path.parts) - if "onnx_layerwise_tmp" in parts: - marker_idx = parts.index("onnx_layerwise_tmp") - return Path(*parts[:marker_idx]) - return onnx_path.parent - - -def _install_shard_window_patch(): - if getattr(transformers.modeling_utils, "_window_shard_patch_installed", False): - return - - original_get_checkpoint_shard_files = transformers.modeling_utils.get_checkpoint_shard_files - - @functools.wraps(original_get_checkpoint_shard_files) - def patched_get_checkpoint_shard_files(*args, **kwargs): - shard_files, metadata = original_get_checkpoint_shard_files(*args, **kwargs) - weight_map = metadata.get("weight_map") - if not weight_map: - return shard_files, metadata - - start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0)) - end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0)) - text_start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_text_start", start)) - text_end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_text_end", end)) - has_text_window = text_end > text_start - if not has_text_window: - return shard_files, metadata - - selected_text_prefixes = tuple( - [f"model.layers.{layer_idx}." for layer_idx in range(text_start, text_end)] - + [f"model.language_model.layers.{layer_idx}." for layer_idx in range(text_start, text_end)] - ) - filtered_weight_map = {} - for checkpoint_key, shard_name in weight_map.items(): - if checkpoint_key.startswith("model.layers.") or checkpoint_key.startswith("model.language_model.layers."): - if not has_text_window or checkpoint_key.startswith(selected_text_prefixes): - filtered_weight_map[checkpoint_key] = shard_name - continue - filtered_weight_map[checkpoint_key] = shard_name - - if not filtered_weight_map: - return shard_files, metadata - - shard_name_to_path = {path.split("/")[-1]: path for path in shard_files} - filtered_shard_names = sorted(set(filtered_weight_map.values())) - filtered_shard_files = [shard_name_to_path[name] for name in filtered_shard_names if name in shard_name_to_path] - if not filtered_shard_files: - return shard_files, metadata - - metadata["weight_map"] = filtered_weight_map - metadata["all_checkpoint_keys"] = list(filtered_weight_map.keys()) - return filtered_shard_files, metadata - - transformers.modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files - transformers.modeling_utils._window_shard_patch_installed = True - - -def _set_layer_windows( - text_start: int, - text_end: int, - text_total_layers: int, -): - transformers.modeling_utils.PreTrainedModel._start = text_start - transformers.modeling_utils.PreTrainedModel._end = text_end - transformers.modeling_utils.PreTrainedModel._total_layers = text_total_layers - transformers.modeling_utils.PreTrainedModel._text_start = text_start - transformers.modeling_utils.PreTrainedModel._text_end = text_end - transformers.modeling_utils.PreTrainedModel._text_total_layers = text_total_layers - - qeff_mod = QEfficient.transformers.models.qwen3_5_moe.modeling_qwen3_5_moe - qeff_mod.QEffQwen3_5MoeTextModel._start = text_start - qeff_mod.QEffQwen3_5MoeTextModel._end = text_end - qeff_mod.QEffQwen3_5MoeTextModel._total_layers = text_total_layers - - QEfficient.base.modeling_qeff.QEFFBaseModel._start = text_start - QEfficient.base.modeling_qeff.QEFFBaseModel._end = text_end - QEfficient.base.modeling_qeff.QEFFBaseModel._total_layers = text_total_layers - - -def _stitch_layerwise_if_available(export_root: Path): - # Some branches expose this helper; fall back gracefully when unavailable. - pipeline_fn = getattr(QEfficient.utils, "layerwise_pipeline", None) - if callable(pipeline_fn): - return pipeline_fn(str(export_root)) - print(f"layerwise_pipeline() not found. Layer-wise ONNX shards kept under: {export_root / 'onnx_layerwise_tmp'}") - return str(export_root / "onnx_layerwise_tmp") +def main(): + config = AutoConfig.from_pretrained(MODEL_ID) + config.torch_dtype = "float32" -def _new_qeff_model(model_id: str, config): - return QEFFAutoModelForImageTextToText.from_pretrained( - model_id, + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + MODEL_ID, attn_implementation="eager", kv_offload=True, config=config, torch_dtype=torch.float32, + layerwise=True, ) - -def main(): - config = AutoConfig.from_pretrained(MODEL_ID) - config.torch_dtype = "float32" - - # if TEST_TEXT_LAYERS: - # config.text_config.num_hidden_layers = TEST_TEXT_LAYERS - - text_config = getattr(config, "text_config", config) - # config.vision_config.depth = 3 - text_total_layers = getattr(text_config, "num_hidden_layers", None) - if text_total_layers is None: - raise ValueError("Could not resolve `num_hidden_layers` from config.text_config.") - config.text_config.num_hidden_layers = text_total_layers - _ensure_pretrained_window_attrs() - _install_shard_window_patch() - - hf_qwen_mod = transformers.models.qwen3_5_moe.modeling_qwen3_5_moe - _install_window_patch(hf_qwen_mod.Qwen3_5MoeForConditionalGeneration) - _install_window_patch(hf_qwen_mod.Qwen3_5MoeForCausalLM) - - text_windows = _build_layer_windows(total_layers=text_total_layers, window_size=TEXT_WINDOW_SIZE) - # Keep layerwise only on text path in this loop. - num_windows = len(text_windows) - first_onnx_path = None - os.environ["LAYERWISE_EXPORT"] = "True" - for window_idx in range(num_windows): - text_start, text_end = text_windows[window_idx] if window_idx < len(text_windows) else (0, 0) - skip_lang_for_window = window_idx >= len(text_windows) - - _set_layer_windows( - text_start=text_start, - text_end=text_end, - text_total_layers=text_total_layers, - ) - print( - f"Exporting window {window_idx + 1}/{num_windows} " - f"text=[{text_start},{text_end})/{text_total_layers} " - f"skip_lang={skip_lang_for_window}" - ) - - qeff_model = _new_qeff_model(MODEL_ID, config) - if hasattr(qeff_model, "model"): - _null_outside_window_layers( - qeff_model.model, - apply_text=not skip_lang_for_window, - ) - - onnx_path = qeff_model.compile( - batch_size=BATCH_SIZE, - prefill_seq_len=PREFILL_SEQ_LEN, - ctx_len=CTX_LEN, - num_cores=NUM_CORES, - num_devices=NUM_DEVICES, - height=HEIGHT, - width=WIDTH, - mxfp6_matmul=False, - aic_enable_depth_first=True, - skip_vision=True, - skip_lang=skip_lang_for_window, - use_onnx_subfunctions=True, - enable_chunking=True, - mos=1, - user_tiled=True, - ) - - if first_onnx_path is None: - first_onnx_path = Path(str(onnx_path["lang_decode_qpc_path"])) - - if first_onnx_path is None: - raise RuntimeError("No ONNX path produced during layer-wise language export.") - - export_root = _resolve_export_root(first_onnx_path) - final_artifact = _stitch_layerwise_if_available(export_root) - print(f"Layer-wise language export completed. Final artifact/root: {final_artifact}") - - os.environ["LAYERWISE_EXPORT"] = "False" - qpc_path = qeff_model.compile( - lang_onnx_path=final_artifact, + compile_kwargs = dict( batch_size=BATCH_SIZE, prefill_seq_len=PREFILL_SEQ_LEN, ctx_len=CTX_LEN, @@ -302,20 +54,19 @@ def main(): num_devices=NUM_DEVICES, height=HEIGHT, width=WIDTH, - mxfp6_matmul=False, + mxfp6_matmul=True, aic_enable_depth_first=True, - skip_vision=True, - skip_lang=skip_lang_for_window, + split_retained_state_io=True, use_onnx_subfunctions=True, - enable_chunking=True, mos=1, + layerwise_window_size=TEXT_WINDOW_SIZE, ) + if TOTAL_TEXT_LAYERS is not None: + compile_kwargs["total_layers"] = TOTAL_TEXT_LAYERS - print(f"Final QPC path: {qpc_path}") + qpc_path = qeff_model.compile(**compile_kwargs) + print(f"Layer-wise decode export + compile completed. QPC paths: {qpc_path}") if __name__ == "__main__": main() - - -# /opt/qti-aic/exec/qaic-compile -aic-hw -aic-hw-version=ai100 -m=/home/abhishek/.cache/qeff_models/Qwen3_5MoeForConditionalGeneration/Qwen3_5MoeDecoderWrapper-61a4400d63d1b0bb/final_data/merged_0-2.onnx -retained-state -convert-to-fp16 -aic-num-cores=16 -aic-enable-depth-first -mos=1 -network-specialization-config=/home/abhishek/.cache/qeff_models/Qwen3_5MoeForConditionalGeneration/Qwen3_5MoeDecoderWrapper-61a4400d63d1b0bb/final_data/specializations.json -custom-IO-list-file=/home/abhishek/.cache/qeff_models/Qwen3_5MoeForConditionalGeneration/Qwen3_5MoeDecoderWrapper-61a4400d63d1b0bb/final_data/qpc_binaries/custom_io.yaml -aic-binary-dir=/home/abhishek/.cache/qeff_models/Qwen3_5MoeForConditionalGeneration/Qwen3_5MoeDecoderWrapper-61a4400d63d1b0bb/final_data/qpc_binaries/qpc diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise.py index 990369fd9d..26d3f3b609 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise.py @@ -5,309 +5,54 @@ # # ----------------------------------------------------------------------------- -import functools -import os -from pathlib import Path +"""Layer-wise ONNX export + compile for Qwen3-VL-MoE (dual-QPC). + +The language decoder is exported one window of layers at a time and merged into +a single ONNX graph equivalent to a full export; the vision encoder is exported +once. This keeps peak host memory bounded by a single window's language weights. + +Flow: + * from_pretrained(..., layerwise=True, kv_offload=True) builds the VLM on the + `meta` device; language weights are streamed per window during export. + * compile(..., layerwise_window_size=1[, total_layers=N]) drives the loop + internally and compiles the merged language ONNX + vision ONNX. +""" import torch -import transformers from transformers import AutoConfig -import QEfficient from QEfficient import QEFFAutoModelForImageTextToText -MODEL_ID = "Qwen/Qwen3-VL-235B-A22B-Instruct" +MODEL_ID = "tiny-random/qwen3-vl-moe" +# MODEL_ID = "Qwen/Qwen3-VL-30B-A3B-Instruct" PREFILL_SEQ_LEN = 32 CTX_LEN = 4096 -TEXT_WINDOW_SIZE = 1 - -# For quick local validation only (keep disabled for real export) -# TEST_TEXT_LAYERS = 4 - -# Export controls BATCH_SIZE = 1 NUM_CORES = 16 -NUM_DEVICES = 4 +NUM_DEVICES = 1 HEIGHT = 354 WIDTH = 536 +TEXT_WINDOW_SIZE = 1 +# Optional: export only the first N (>1) text layers for quick validation. +# TOTAL_TEXT_LAYERS = 2 +TOTAL_TEXT_LAYERS = None -def _ensure_pretrained_window_attrs(): - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_start"): - transformers.modeling_utils.PreTrainedModel._start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_end"): - transformers.modeling_utils.PreTrainedModel._end = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_total_layers"): - transformers.modeling_utils.PreTrainedModel._total_layers = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_start"): - transformers.modeling_utils.PreTrainedModel._text_start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_end"): - transformers.modeling_utils.PreTrainedModel._text_end = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_total_layers"): - transformers.modeling_utils.PreTrainedModel._text_total_layers = 0 - - -def _build_layer_windows(total_layers: int, window_size: int): - if total_layers <= 0: - raise ValueError(f"Invalid total_layers={total_layers}. Expected: total_layers > 0.") - if window_size <= 0: - raise ValueError(f"Invalid window_size={window_size}. Expected: window_size > 0.") - - windows = [] - start = 0 - while start < total_layers: - end = min(total_layers, start + window_size) - windows.append((start, end)) - start = end - return windows - - -def _get_text_layers_container(model): - # VLM path first - if ( - hasattr(model, "model") - and hasattr(model.model, "language_model") - and hasattr(model.model.language_model, "layers") - ): - return model.model.language_model.layers - # LLM-compatible fallbacks - if hasattr(model, "model") and hasattr(model.model, "layers"): - return model.model.layers - if hasattr(model, "language_model") and hasattr(model.language_model, "layers"): - return model.language_model.layers - if hasattr(model, "layers"): - return model.layers - return None - - -def _null_outside_window_layers(model, apply_text: bool = True): - if apply_text: - text_start = int( - getattr( - transformers.modeling_utils.PreTrainedModel, - "_text_start", - getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0), - ) - ) - text_end = int( - getattr( - transformers.modeling_utils.PreTrainedModel, - "_text_end", - getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0), - ) - ) - text_layers = _get_text_layers_container(model) - if text_layers is not None and text_end > text_start: - for idx, _ in enumerate(text_layers): - if idx < text_start or idx >= text_end: - text_layers[idx] = None - - -def _install_window_patch(model_cls): - if getattr(model_cls, "_window_patch_installed", False): - return - - original_init = model_cls.__init__ - - @functools.wraps(original_init) - def patched_init(self, *args, **kwargs): - original_init(self, *args, **kwargs) - _null_outside_window_layers(self, apply_text=True) - - model_cls.__init__ = patched_init - model_cls._window_patch_installed = True - - -def _resolve_export_root(onnx_path: Path) -> Path: - parts = list(onnx_path.parts) - if "onnx_layerwise_tmp" in parts: - marker_idx = parts.index("onnx_layerwise_tmp") - return Path(*parts[:marker_idx]) - return onnx_path.parent - - -def _install_shard_window_patch(): - if getattr(transformers.modeling_utils, "_window_shard_patch_installed", False): - return - - original_get_checkpoint_shard_files = transformers.modeling_utils.get_checkpoint_shard_files - - @functools.wraps(original_get_checkpoint_shard_files) - def patched_get_checkpoint_shard_files(*args, **kwargs): - shard_files, metadata = original_get_checkpoint_shard_files(*args, **kwargs) - weight_map = metadata.get("weight_map") - if not weight_map: - return shard_files, metadata - - start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0)) - end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0)) - text_start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_text_start", start)) - text_end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_text_end", end)) - has_text_window = text_end > text_start - if not has_text_window: - return shard_files, metadata - - selected_text_prefixes = tuple( - [f"model.layers.{layer_idx}." for layer_idx in range(text_start, text_end)] - + [f"model.language_model.layers.{layer_idx}." for layer_idx in range(text_start, text_end)] - ) - filtered_weight_map = {} - for checkpoint_key, shard_name in weight_map.items(): - if checkpoint_key.startswith("model.layers.") or checkpoint_key.startswith("model.language_model.layers."): - if not has_text_window or checkpoint_key.startswith(selected_text_prefixes): - filtered_weight_map[checkpoint_key] = shard_name - continue - filtered_weight_map[checkpoint_key] = shard_name - - if not filtered_weight_map: - return shard_files, metadata - - shard_name_to_path = {path.split("/")[-1]: path for path in shard_files} - filtered_shard_names = sorted(set(filtered_weight_map.values())) - filtered_shard_files = [shard_name_to_path[name] for name in filtered_shard_names if name in shard_name_to_path] - if not filtered_shard_files: - return shard_files, metadata - - metadata["weight_map"] = filtered_weight_map - metadata["all_checkpoint_keys"] = list(filtered_weight_map.keys()) - return filtered_shard_files, metadata - - transformers.modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files - transformers.modeling_utils._window_shard_patch_installed = True - - -def _set_layer_windows( - text_start: int, - text_end: int, - text_total_layers: int, -): - transformers.modeling_utils.PreTrainedModel._start = text_start - transformers.modeling_utils.PreTrainedModel._end = text_end - transformers.modeling_utils.PreTrainedModel._total_layers = text_total_layers - transformers.modeling_utils.PreTrainedModel._text_start = text_start - transformers.modeling_utils.PreTrainedModel._text_end = text_end - transformers.modeling_utils.PreTrainedModel._text_total_layers = text_total_layers - - # Qwen3-VL-MoE model code still checks QEffQwen3_5MoeTextModel window attrs - # in a few places. Set both classes to keep layer-wise behavior consistent. - qeff_vl_mod = QEfficient.transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe - qeff_vl_mod.QEffQwen3VLMoeTextModel._start = text_start - qeff_vl_mod.QEffQwen3VLMoeTextModel._end = text_end - qeff_vl_mod.QEffQwen3VLMoeTextModel._total_layers = text_total_layers - - qeff_35_mod = getattr(QEfficient.transformers.models, "qwen3_5_moe", None) - if qeff_35_mod is not None: - qeff_35_text_model = getattr(qeff_35_mod.modeling_qwen3_5_moe, "QEffQwen3_5MoeTextModel", None) - if qeff_35_text_model is not None: - qeff_35_text_model._start = text_start - qeff_35_text_model._end = text_end - qeff_35_text_model._total_layers = text_total_layers - - QEfficient.base.modeling_qeff.QEFFBaseModel._start = text_start - QEfficient.base.modeling_qeff.QEFFBaseModel._end = text_end - QEfficient.base.modeling_qeff.QEFFBaseModel._total_layers = text_total_layers - - -def _stitch_layerwise_if_available(export_root: Path): - # Some branches expose this helper; fall back gracefully when unavailable. - pipeline_fn = getattr(QEfficient.utils, "layerwise_pipeline", None) - if callable(pipeline_fn): - return pipeline_fn(str(export_root)) - print(f"layerwise_pipeline() not found. Layer-wise ONNX shards kept under: {export_root / 'onnx_layerwise_tmp'}") - return str(export_root / "onnx_layerwise_tmp") +def main(): + config = AutoConfig.from_pretrained(MODEL_ID) + config.torch_dtype = "float32" -def _new_qeff_model(model_id: str, config): - return QEFFAutoModelForImageTextToText.from_pretrained( - model_id, + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + MODEL_ID, attn_implementation="eager", kv_offload=True, config=config, torch_dtype=torch.float32, + layerwise=True, ) - -def main(): - config = AutoConfig.from_pretrained(MODEL_ID) - config.torch_dtype = "float32" - # config.vision_config.depth = 9 - # config.text_config.num_hidden_layers = 2 - config.vision_config.deepstack_visual_indexes = [8, 16, 24] - - # if TEST_TEXT_LAYERS: - # config.text_config.num_hidden_layers = TEST_TEXT_LAYERS - - text_config = getattr(config, "text_config", config) - text_total_layers = getattr(text_config, "num_hidden_layers", None) - if text_total_layers is None: - raise ValueError("Could not resolve `num_hidden_layers` from config.text_config.") - config.text_config.num_hidden_layers = text_total_layers - _ensure_pretrained_window_attrs() - _install_shard_window_patch() - - hf_qwen_mod = transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe - _install_window_patch(hf_qwen_mod.Qwen3VLMoeForConditionalGeneration) - if hasattr(hf_qwen_mod, "Qwen3VLMoeForCausalLM"): - _install_window_patch(hf_qwen_mod.Qwen3VLMoeForCausalLM) - - text_windows = _build_layer_windows(total_layers=text_total_layers, window_size=TEXT_WINDOW_SIZE) - # Keep layerwise only on text path in this loop. - num_windows = len(text_windows) - first_onnx_path = None - os.environ["LAYERWISE_EXPORT"] = "True" - for window_idx in range(num_windows): - text_start, text_end = text_windows[window_idx] if window_idx < len(text_windows) else (0, 0) - skip_lang_for_window = window_idx >= len(text_windows) - - _set_layer_windows( - text_start=text_start, - text_end=text_end, - text_total_layers=text_total_layers, - ) - print( - f"Exporting window {window_idx + 1}/{num_windows} " - f"text=[{text_start},{text_end})/{text_total_layers} " - f"skip_lang={skip_lang_for_window}" - ) - - qeff_model = _new_qeff_model(MODEL_ID, config) - if hasattr(qeff_model, "model"): - _null_outside_window_layers( - qeff_model.model, - apply_text=not skip_lang_for_window, - ) - - onnx_path = qeff_model.compile( - batch_size=BATCH_SIZE, - prefill_seq_len=PREFILL_SEQ_LEN, - ctx_len=CTX_LEN, - num_cores=NUM_CORES, - num_devices=NUM_DEVICES, - height=HEIGHT, - width=WIDTH, - mxfp6_matmul=True, - aic_enable_depth_first=True, - skip_vision=True, - skip_lang=skip_lang_for_window, - split_retained_state_io=True, - use_onnx_subfunctions=True, - prefill_only=True, - mos=1, - ) - - if first_onnx_path is None: - first_onnx_path = Path(str(onnx_path["lang_prefill_qpc_path"])) - - if first_onnx_path is None: - raise RuntimeError("No ONNX path produced during layer-wise language export.") - - export_root = _resolve_export_root(first_onnx_path) - final_artifact = _stitch_layerwise_if_available(export_root) - print(f"Layer-wise language export completed. Final artifact/root: {final_artifact}") - - os.environ["LAYERWISE_EXPORT"] = "False" - qpc_path = qeff_model.compile( - lang_onnx_path=final_artifact, + compile_kwargs = dict( batch_size=BATCH_SIZE, prefill_seq_len=PREFILL_SEQ_LEN, ctx_len=CTX_LEN, @@ -315,17 +60,19 @@ def main(): num_devices=NUM_DEVICES, height=HEIGHT, width=WIDTH, - mxfp6_matmul=True, + mxfp6_matmul=False, aic_enable_depth_first=True, - skip_vision=True, - skip_lang=skip_lang_for_window, - split_retained_state_io=True, - use_onnx_subfunctions=True, prefill_only=True, + enable_chunking=True, + use_onnx_subfunctions=True, mos=1, + layerwise_window_size=TEXT_WINDOW_SIZE, ) + if TOTAL_TEXT_LAYERS is not None: + compile_kwargs["total_layers"] = TOTAL_TEXT_LAYERS - print(f"Final QPC path: {qpc_path}") + qpc_path = qeff_model.compile(**compile_kwargs) + print(f"Layer-wise export + compile completed. QPC paths: {qpc_path}") if __name__ == "__main__": diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py index 18a61f6c44..89a8767087 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_moe_layerwise_decode.py @@ -5,308 +5,48 @@ # # ----------------------------------------------------------------------------- -import functools -import os -from pathlib import Path +"""Layer-wise ONNX export + compile (decode) for qwen3_vl_moe (dual-QPC). + +Same layer-wise language flow as the prefill script, but compiles the decode +specialization. The language decoder is exported one window of layers at a time +and merged into a single ONNX graph; the vision encoder is exported once. +""" import torch -import transformers from transformers import AutoConfig -import QEfficient from QEfficient import QEFFAutoModelForImageTextToText -MODEL_ID = "Qwen/Qwen3-VL-235B-A22B-Instruct" -PREFILL_SEQ_LEN = 1 +MODEL_ID = "tiny-random/qwen3-vl-moe" +# MODEL_ID = "Qwen/Qwen3-VL-30B-A3B-Instruct" +PREFILL_SEQ_LEN = 32 CTX_LEN = 4096 -TEXT_WINDOW_SIZE = 1 - -# For quick local validation only (keep disabled for real export) -# TEST_TEXT_LAYERS = 4 - -# Export controls BATCH_SIZE = 1 NUM_CORES = 16 -NUM_DEVICES = 4 +NUM_DEVICES = 1 HEIGHT = 354 WIDTH = 536 +TEXT_WINDOW_SIZE = 1 +# Optional: export only the first N (>1) text layers for quick validation. +# TOTAL_TEXT_LAYERS = 2 +TOTAL_TEXT_LAYERS = None -def _ensure_pretrained_window_attrs(): - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_start"): - transformers.modeling_utils.PreTrainedModel._start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_end"): - transformers.modeling_utils.PreTrainedModel._end = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_total_layers"): - transformers.modeling_utils.PreTrainedModel._total_layers = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_start"): - transformers.modeling_utils.PreTrainedModel._text_start = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_end"): - transformers.modeling_utils.PreTrainedModel._text_end = 0 - if not hasattr(transformers.modeling_utils.PreTrainedModel, "_text_total_layers"): - transformers.modeling_utils.PreTrainedModel._text_total_layers = 0 - - -def _build_layer_windows(total_layers: int, window_size: int): - if total_layers <= 0: - raise ValueError(f"Invalid total_layers={total_layers}. Expected: total_layers > 0.") - if window_size <= 0: - raise ValueError(f"Invalid window_size={window_size}. Expected: window_size > 0.") - - windows = [] - start = 0 - while start < total_layers: - end = min(total_layers, start + window_size) - windows.append((start, end)) - start = end - return windows - - -def _get_text_layers_container(model): - # VLM path first - if ( - hasattr(model, "model") - and hasattr(model.model, "language_model") - and hasattr(model.model.language_model, "layers") - ): - return model.model.language_model.layers - # LLM-compatible fallbacks - if hasattr(model, "model") and hasattr(model.model, "layers"): - return model.model.layers - if hasattr(model, "language_model") and hasattr(model.language_model, "layers"): - return model.language_model.layers - if hasattr(model, "layers"): - return model.layers - return None - - -def _null_outside_window_layers(model, apply_text: bool = True): - if apply_text: - text_start = int( - getattr( - transformers.modeling_utils.PreTrainedModel, - "_text_start", - getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0), - ) - ) - text_end = int( - getattr( - transformers.modeling_utils.PreTrainedModel, - "_text_end", - getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0), - ) - ) - text_layers = _get_text_layers_container(model) - if text_layers is not None and text_end > text_start: - for idx, _ in enumerate(text_layers): - if idx < text_start or idx >= text_end: - text_layers[idx] = None - - -def _install_window_patch(model_cls): - if getattr(model_cls, "_window_patch_installed", False): - return - - original_init = model_cls.__init__ - - @functools.wraps(original_init) - def patched_init(self, *args, **kwargs): - original_init(self, *args, **kwargs) - _null_outside_window_layers(self, apply_text=True) - - model_cls.__init__ = patched_init - model_cls._window_patch_installed = True - - -def _resolve_export_root(onnx_path: Path) -> Path: - parts = list(onnx_path.parts) - if "onnx_layerwise_tmp" in parts: - marker_idx = parts.index("onnx_layerwise_tmp") - return Path(*parts[:marker_idx]) - return onnx_path.parent - - -def _install_shard_window_patch(): - if getattr(transformers.modeling_utils, "_window_shard_patch_installed", False): - return - - original_get_checkpoint_shard_files = transformers.modeling_utils.get_checkpoint_shard_files - - @functools.wraps(original_get_checkpoint_shard_files) - def patched_get_checkpoint_shard_files(*args, **kwargs): - shard_files, metadata = original_get_checkpoint_shard_files(*args, **kwargs) - weight_map = metadata.get("weight_map") - if not weight_map: - return shard_files, metadata - - start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_start", 0)) - end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_end", 0)) - text_start = int(getattr(transformers.modeling_utils.PreTrainedModel, "_text_start", start)) - text_end = int(getattr(transformers.modeling_utils.PreTrainedModel, "_text_end", end)) - has_text_window = text_end > text_start - if not has_text_window: - return shard_files, metadata - - selected_text_prefixes = tuple( - [f"model.layers.{layer_idx}." for layer_idx in range(text_start, text_end)] - + [f"model.language_model.layers.{layer_idx}." for layer_idx in range(text_start, text_end)] - ) - filtered_weight_map = {} - for checkpoint_key, shard_name in weight_map.items(): - if checkpoint_key.startswith("model.layers.") or checkpoint_key.startswith("model.language_model.layers."): - if not has_text_window or checkpoint_key.startswith(selected_text_prefixes): - filtered_weight_map[checkpoint_key] = shard_name - continue - filtered_weight_map[checkpoint_key] = shard_name - - if not filtered_weight_map: - return shard_files, metadata - - shard_name_to_path = {path.split("/")[-1]: path for path in shard_files} - filtered_shard_names = sorted(set(filtered_weight_map.values())) - filtered_shard_files = [shard_name_to_path[name] for name in filtered_shard_names if name in shard_name_to_path] - if not filtered_shard_files: - return shard_files, metadata - - metadata["weight_map"] = filtered_weight_map - metadata["all_checkpoint_keys"] = list(filtered_weight_map.keys()) - return filtered_shard_files, metadata - - transformers.modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files - transformers.modeling_utils._window_shard_patch_installed = True - - -def _set_layer_windows( - text_start: int, - text_end: int, - text_total_layers: int, -): - transformers.modeling_utils.PreTrainedModel._start = text_start - transformers.modeling_utils.PreTrainedModel._end = text_end - transformers.modeling_utils.PreTrainedModel._total_layers = text_total_layers - transformers.modeling_utils.PreTrainedModel._text_start = text_start - transformers.modeling_utils.PreTrainedModel._text_end = text_end - transformers.modeling_utils.PreTrainedModel._text_total_layers = text_total_layers - - # Qwen3-VL-MoE model code still checks QEffQwen3_5MoeTextModel window attrs - # in a few places. Set both classes to keep layer-wise behavior consistent. - qeff_vl_mod = QEfficient.transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe - qeff_vl_mod.QEffQwen3VLMoeTextModel._start = text_start - qeff_vl_mod.QEffQwen3VLMoeTextModel._end = text_end - qeff_vl_mod.QEffQwen3VLMoeTextModel._total_layers = text_total_layers - - qeff_35_mod = getattr(QEfficient.transformers.models, "qwen3_5_moe", None) - if qeff_35_mod is not None: - qeff_35_text_model = getattr(qeff_35_mod.modeling_qwen3_5_moe, "QEffQwen3_5MoeTextModel", None) - if qeff_35_text_model is not None: - qeff_35_text_model._start = text_start - qeff_35_text_model._end = text_end - qeff_35_text_model._total_layers = text_total_layers - - QEfficient.base.modeling_qeff.QEFFBaseModel._start = text_start - QEfficient.base.modeling_qeff.QEFFBaseModel._end = text_end - QEfficient.base.modeling_qeff.QEFFBaseModel._total_layers = text_total_layers - - -def _stitch_layerwise_if_available(export_root: Path): - # Some branches expose this helper; fall back gracefully when unavailable. - pipeline_fn = getattr(QEfficient.utils, "layerwise_pipeline", None) - if callable(pipeline_fn): - return pipeline_fn(str(export_root)) - print(f"layerwise_pipeline() not found. Layer-wise ONNX shards kept under: {export_root / 'onnx_layerwise_tmp'}") - return str(export_root / "onnx_layerwise_tmp") +def main(): + config = AutoConfig.from_pretrained(MODEL_ID) + config.torch_dtype = "float32" -def _new_qeff_model(model_id: str, config): - return QEFFAutoModelForImageTextToText.from_pretrained( - model_id, + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + MODEL_ID, attn_implementation="eager", kv_offload=True, config=config, torch_dtype=torch.float32, + layerwise=True, ) - -def main(): - config = AutoConfig.from_pretrained(MODEL_ID) - config.torch_dtype = "float32" - # config.vision_config.depth = 9 - # config.text_config.num_hidden_layers = 2 - config.vision_config.deepstack_visual_indexes = [8, 27, 36] - - # if TEST_TEXT_LAYERS: - # config.text_config.num_hidden_layers = TEST_TEXT_LAYERS - - text_config = getattr(config, "text_config", config) - text_total_layers = getattr(text_config, "num_hidden_layers", None) - if text_total_layers is None: - raise ValueError("Could not resolve `num_hidden_layers` from config.text_config.") - config.text_config.num_hidden_layers = text_total_layers - _ensure_pretrained_window_attrs() - _install_shard_window_patch() - - hf_qwen_mod = transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe - _install_window_patch(hf_qwen_mod.Qwen3VLMoeForConditionalGeneration) - if hasattr(hf_qwen_mod, "Qwen3VLMoeForCausalLM"): - _install_window_patch(hf_qwen_mod.Qwen3VLMoeForCausalLM) - - text_windows = _build_layer_windows(total_layers=text_total_layers, window_size=TEXT_WINDOW_SIZE) - # Keep layerwise only on text path in this loop. - num_windows = len(text_windows) - first_onnx_path = None - os.environ["LAYERWISE_EXPORT"] = "True" - for window_idx in range(num_windows): - text_start, text_end = text_windows[window_idx] if window_idx < len(text_windows) else (0, 0) - skip_lang_for_window = window_idx >= len(text_windows) - - _set_layer_windows( - text_start=text_start, - text_end=text_end, - text_total_layers=text_total_layers, - ) - print( - f"Exporting window {window_idx + 1}/{num_windows} " - f"text=[{text_start},{text_end})/{text_total_layers} " - f"skip_lang={skip_lang_for_window}" - ) - - qeff_model = _new_qeff_model(MODEL_ID, config) - if hasattr(qeff_model, "model"): - _null_outside_window_layers( - qeff_model.model, - apply_text=not skip_lang_for_window, - ) - - onnx_path = qeff_model.compile( - batch_size=BATCH_SIZE, - prefill_seq_len=PREFILL_SEQ_LEN, - ctx_len=CTX_LEN, - num_cores=NUM_CORES, - num_devices=NUM_DEVICES, - height=HEIGHT, - width=WIDTH, - mxfp6_matmul=True, - aic_enable_depth_first=True, - skip_vision=True, - skip_lang=skip_lang_for_window, - split_retained_state_io=True, - use_onnx_subfunctions=True, - mos=1, - ) - - if first_onnx_path is None: - first_onnx_path = Path(str(onnx_path["lang_decode_qpc_path"])) - - if first_onnx_path is None: - raise RuntimeError("No ONNX path produced during layer-wise language export.") - - export_root = _resolve_export_root(first_onnx_path) - final_artifact = _stitch_layerwise_if_available(export_root) - print(f"Layer-wise language export completed. Final artifact/root: {final_artifact}") - - os.environ["LAYERWISE_EXPORT"] = "False" - qpc_path = qeff_model.compile( - lang_onnx_path=final_artifact, + compile_kwargs = dict( batch_size=BATCH_SIZE, prefill_seq_len=PREFILL_SEQ_LEN, ctx_len=CTX_LEN, @@ -316,14 +56,16 @@ def main(): width=WIDTH, mxfp6_matmul=True, aic_enable_depth_first=True, - skip_vision=True, - skip_lang=skip_lang_for_window, split_retained_state_io=True, use_onnx_subfunctions=True, mos=1, + layerwise_window_size=TEXT_WINDOW_SIZE, ) + if TOTAL_TEXT_LAYERS is not None: + compile_kwargs["total_layers"] = TOTAL_TEXT_LAYERS - print(f"Final QPC path: {qpc_path}") + qpc_path = qeff_model.compile(**compile_kwargs) + print(f"Layer-wise decode export + compile completed. QPC paths: {qpc_path}") if __name__ == "__main__": diff --git a/tests/utils/test_layerwise_utils.py b/tests/utils/test_layerwise_utils.py new file mode 100644 index 0000000000..f873f8b73b --- /dev/null +++ b/tests/utils/test_layerwise_utils.py @@ -0,0 +1,131 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import pytest + +from QEfficient.utils.custom_loader import CustomLoader +from QEfficient.utils.layerwise_utils import build_layer_windows + + +def test_build_layer_windows_divisible(): + assert build_layer_windows(4, 1) == [(3, 4), (2, 3), (1, 2), (0, 1)] + assert build_layer_windows(4, 2) == [(2, 4), (0, 2)] + + +def test_build_layer_windows_non_divisible(): + # Last (lowest) window is the smaller remainder. + assert build_layer_windows(5, 2) == [(3, 5), (1, 3), (0, 1)] + + +def test_build_layer_windows_invalid(): + with pytest.raises(ValueError): + build_layer_windows(0, 1) + with pytest.raises(ValueError): + build_layer_windows(4, 0) + + +def _make_loader(): + # Build without touching the network: bypass __init__ download. + loader = CustomLoader.__new__(CustomLoader) + loader.hf_auto_class = None + loader.pretrained_model_name_or_path = "dummy" + loader.layer_prefixes = ("model.layers.",) + loader.total_layers = 4 + loader.from_pretrained_kwargs = {} + return loader + + +def test_shard_filter_keeps_window_layers_and_edges(): + loader = _make_loader() + + weight_map = { + "model.embed_tokens.weight": "shard_a.safetensors", + "model.layers.0.self_attn.q_proj.weight": "shard_a.safetensors", + "model.layers.1.self_attn.q_proj.weight": "shard_b.safetensors", + "model.layers.2.self_attn.q_proj.weight": "shard_c.safetensors", + "model.layers.3.self_attn.q_proj.weight": "shard_d.safetensors", + "model.norm.weight": "shard_d.safetensors", + "lm_head.weight": "shard_d.safetensors", + } + shard_files = [f"/tmp/{name}" for name in sorted(set(weight_map.values()))] + + import transformers + + original = transformers.modeling_utils.get_checkpoint_shard_files + try: + transformers.modeling_utils.get_checkpoint_shard_files = lambda *a, **k: ( + list(shard_files), + {"weight_map": dict(weight_map)}, + ) + # Window [1, 2): keep layer 1 + all non-layer (edge) keys, drop layers 0/2/3. + with loader._shard_filter(1, 2): + files, meta = transformers.modeling_utils.get_checkpoint_shard_files() + finally: + transformers.modeling_utils.get_checkpoint_shard_files = original + + kept = set(meta["weight_map"].keys()) + assert "model.layers.1.self_attn.q_proj.weight" in kept + assert "model.layers.0.self_attn.q_proj.weight" not in kept + assert "model.layers.2.self_attn.q_proj.weight" not in kept + # Non-layer (edge) weights are always kept; HF decides where to place them. + assert {"model.embed_tokens.weight", "model.norm.weight", "lm_head.weight"} <= kept + + +def test_shard_filter_noop_without_weight_map(): + loader = _make_loader() + import transformers + + original = transformers.modeling_utils.get_checkpoint_shard_files + try: + transformers.modeling_utils.get_checkpoint_shard_files = lambda *a, **k: (["/tmp/x.safetensors"], {}) + with loader._shard_filter(0, 1): + files, meta = transformers.modeling_utils.get_checkpoint_shard_files() + finally: + transformers.modeling_utils.get_checkpoint_shard_files = original + + assert files == ["/tmp/x.safetensors"] + + +def test_custom_loader_multi_prefix_selection(): + loader = CustomLoader.__new__(CustomLoader) + loader.layer_prefixes = ("model.layers.", "model.language_model.layers.") + loader.total_layers = 4 + + weight_map = { + "model.vision_model.encoder.0.weight": "v.safetensors", + "model.multi_modal_projector.weight": "v.safetensors", + "model.language_model.embed_tokens.weight": "a.safetensors", + "model.language_model.layers.0.x": "a.safetensors", + "model.language_model.layers.1.x": "b.safetensors", + "model.language_model.layers.2.x": "c.safetensors", + "model.language_model.norm.weight": "c.safetensors", + "lm_head.weight": "c.safetensors", + } + shard_files = [f"/tmp/{name}" for name in sorted(set(weight_map.values()))] + + import transformers + + original = transformers.modeling_utils.get_checkpoint_shard_files + try: + transformers.modeling_utils.get_checkpoint_shard_files = lambda *a, **k: ( + list(shard_files), + {"weight_map": dict(weight_map)}, + ) + with loader._shard_filter(1, 2): + _, meta = transformers.modeling_utils.get_checkpoint_shard_files() + finally: + transformers.modeling_utils.get_checkpoint_shard_files = original + + kept = set(meta["weight_map"].keys()) + # Window [1,2): keep language layer 1, drop language layers 0/2. + assert "model.language_model.layers.1.x" in kept + assert "model.language_model.layers.0.x" not in kept + assert "model.language_model.layers.2.x" not in kept + # Vision encoder + projector + edges are always kept. + assert "model.vision_model.encoder.0.weight" in kept + assert "model.multi_modal_projector.weight" in kept + assert {"model.language_model.embed_tokens.weight", "model.language_model.norm.weight", "lm_head.weight"} <= kept