Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import gc
import inspect
import logging
import os
import shutil
import subprocess
import warnings
Expand Down Expand Up @@ -65,6 +64,7 @@ class QEFFBaseModel(ABC):
_start = 0
_end = 0
_total_layers = None
_layerwise_active = False
_pytorch_transforms: List[PytorchTransform]
_onnx_transforms = [BaseOnnxTransform]

Expand Down Expand Up @@ -502,6 +502,17 @@ def _export_layerwise(
self.onnx_path = onnx_path
return onnx_path

# Layer-wise reuse: if the merged final ONNX from a prior run exists
# under final_data/, skip per-window export entirely. The driver's
# stitch step picks up the same merged file, so re-running the same
# example without changes goes straight to the QPC compile.
final_data_dir = export_dir / "final_data"
if final_data_dir.is_dir():
cached_merged = sorted(final_data_dir.glob("merged_*.onnx"))
if cached_merged:
self.onnx_path = cached_merged[-1]
return self.onnx_path

# check if the model is in meta state or weights are offloaded
self._model_offloaded_check()

Expand Down Expand Up @@ -544,9 +555,15 @@ def _resolve_pkv_layers(pkv_obj):
z = example_inputs.pop("input_ids")
if is_vision:
hidden_size = self.model.language_model.config.hidden_size
embed_dtype = getattr(self.model.language_model.config, "torch_dtype", None)
else:
hidden_size = self.model.model.config.hidden_size
inputs_embeds = torch.rand(z.shape[0], z.shape[1], hidden_size, device=z.device)
embed_dtype = getattr(self.model.model.config, "torch_dtype", None)
# Match the model's dtype so per-window export does not introduce a
# float32/float16 mismatch when running through fp16 decoder layers.
if embed_dtype is None:
embed_dtype = next(self.model.parameters()).dtype
inputs_embeds = torch.rand(z.shape[0], z.shape[1], hidden_size, device=z.device, dtype=embed_dtype)
example_inputs["inputs_embeds"] = inputs_embeds
dynamic_axes["inputs_embeds"] = dynamic_axes.pop("input_ids")

Expand Down Expand Up @@ -757,7 +774,7 @@ def _compile(
**compiler_options,
)
onnx_path = Path(onnx_path)
if os.environ.get("LAYERWISE_EXPORT", "False") == "True":
if QEFFBaseModel._layerwise_active:
return onnx_path

compile_dir = Path(compile_dir or onnx_path.parent)
Expand Down
Loading
Loading