Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import gc
import inspect
import logging
import os
import shutil
import subprocess
import warnings
Expand Down Expand Up @@ -90,7 +89,23 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
self.is_transformed: bool = False

self._normalize_torch_dtype()
# Apply the transformations
# Apply the transformations. For layer-wise export the model arrives on
# the `meta` device and the data-mutating transforms must run on the
# real per-window weights instead, so application is deferred and the
# loop calls `apply_pytorch_transforms()` after streaming each window.
if not getattr(self, "_defer_pytorch_transforms", False):
self.apply_pytorch_transforms()

if self.config.torch_dtype == torch.bfloat16:
logger.warning("BFloat16 dtype is not yet supported; converting to float16 precision!")

def apply_pytorch_transforms(self) -> bool:
"""Apply the class ``_pytorch_transforms`` to ``self.model`` in place.

Returns ``True`` if any transform reported a change. Used both by the
normal init flow and by the layer-wise export loop (after streaming each
window's real weights into ``self.model``).
"""
any_transformed = False
for transform in self._pytorch_transforms:
self.model, transformed = transform.apply(self.model)
Expand All @@ -100,9 +115,7 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
warnings.warn(f"No transforms applied to model: {self.model_name}. It may be an unsupported model!")
else:
logger.info(f"Pytorch transforms applied to model: {self.model_name}")

if self.config.torch_dtype == torch.bfloat16:
logger.warning("BFloat16 dtype is not yet supported; converting to float16 precision!")
return any_transformed

def _normalize_torch_dtype(self):
"""
Expand Down Expand Up @@ -488,8 +501,8 @@ def _export_layerwise(
prefill_only: Optional[bool] = False,
**export_kwargs,
) -> str:
idx = int(QEFFBaseModel._start)
end_idx = int(getattr(QEFFBaseModel, "_end", idx + 1))
idx = int(getattr(self, "_start", 0))
end_idx = int(getattr(self, "_end", idx + 1))
if end_idx <= idx:
raise ValueError(f"Invalid export window: start={idx}, end={end_idx}")

Expand All @@ -502,9 +515,6 @@ def _export_layerwise(
self.onnx_path = onnx_path
return onnx_path

# check if the model is in meta state or weights are offloaded
self._model_offloaded_check()

export_dir.mkdir(parents=True, exist_ok=True)

# Setup temporary paths
Expand Down Expand Up @@ -757,9 +767,6 @@ def _compile(
**compiler_options,
)
onnx_path = Path(onnx_path)
if os.environ.get("LAYERWISE_EXPORT", "False") == "True":
return onnx_path

compile_dir = Path(compile_dir or onnx_path.parent)
qpc_path = compile_dir / "qpc"
if not onnx_path.is_file():
Expand Down
Loading
Loading