Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 8 additions & 13 deletions QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@

from QEfficient.base.onnx_transforms import (
BaseOnnxTransform,
FP16ClipTransform,
OnnxTransformPipeline,
SplitTensorsTransform,
)
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
Expand Down Expand Up @@ -54,8 +52,9 @@ class QEFFBaseModel(ABC):
_pytorch_transforms: List[PytorchTransform]
_onnx_transforms = [BaseOnnxTransform]

def _transform_names(self) -> List[str]:
return [x.__name__ for x in self._pytorch_transforms + self._onnx_transforms]
@classmethod
def _transform_names(cls) -> List[str]:
return [x.__name__ for x in cls._pytorch_transforms + cls._onnx_transforms]

def __init__(self, model: torch.nn.Module, **kwargs) -> None:
super().__init__()
Expand Down Expand Up @@ -243,6 +242,7 @@ def _export(
self.onnx_path = onnx_path
return onnx_path

tmp_onnx_path = onnx_path
# check if the model is in meta state or weights are offloaded
self._model_offloaded_check()

Expand Down Expand Up @@ -276,7 +276,7 @@ def _export(
torch.onnx.export(
self.model,
(example_inputs,),
str(onnx_path),
str(tmp_onnx_path),
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
Expand All @@ -285,13 +285,10 @@ def _export(
)
logger.info("PyTorch export successful")
_ = self._offload_model_weights(offload_pt_weights)
model = onnx.load(onnx_path, load_external_data=False)
model = onnx.load(tmp_onnx_path, load_external_data=False)

needs_external_tensor_data = any(
transform in self._onnx_transforms for transform in (FP16ClipTransform, SplitTensorsTransform)
)
transform_kwargs = {
"onnx_base_dir": str(export_dir) if needs_external_tensor_data else None,
"onnx_base_dir": str(export_dir),
"model_name": self.model_name,
}
if onnx_transform_kwargs is not None:
Expand All @@ -306,9 +303,7 @@ def _export(
)
logger.info("ONNX transforms applied")

onnx_path_tmp = onnx_path.with_suffix(onnx_path.suffix + ".tmp")
onnx.save(model, onnx_path_tmp)
onnx_path_tmp.replace(onnx_path)
onnx.save(model, onnx_path)
del model
gc.collect()
logger.info("Transformed ONNX saved")
Expand Down
3 changes: 1 addition & 2 deletions QEfficient/base/onnx_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,7 @@ def apply(
do_split = SplitTensorsTransform in requested
fp16_min, fp16_max = np.finfo(np.float16).min, np.finfo(np.float16).max
file_num_tracker = {"num": 0, "size": 0}
if onnx_base_dir is not None:
external_data_helper.load_external_data_for_model(model, onnx_base_dir)
external_data_helper.load_external_data_for_model(model, onnx_base_dir)

if do_fp16 or do_split:
for tensor in external_data_helper._get_all_tensors(model):
Expand Down
18 changes: 9 additions & 9 deletions QEfficient/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

import QEfficient
from QEfficient.base.modeling_qeff import QEFFBaseModel
from QEfficient.base.onnx_transforms import FP16ClipTransform
from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
from QEfficient.base.pytorch_transforms import SplitGateUpWeightsTransform
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.generation.text_generation_inference import (
Expand Down Expand Up @@ -231,7 +231,7 @@ class QEFFAutoModel(QEFFTransformersBase):

_hf_auto_class = AutoModel
_pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform]
_onnx_transforms = [FP16ClipTransform]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

def __init__(self, model: nn.Module, pooling=None, **kwargs):
"""
Expand Down Expand Up @@ -619,7 +619,7 @@ class QEFFAutoModelForSequenceClassification(QEFFTransformersBase):

_hf_auto_class = AutoModelForSequenceClassification
_pytorch_transforms = [CustomOpsTransform, TextClassificationTransform]
_onnx_transforms = []
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

def __init__(self, model: nn.Module, **kwargs):
"""
Expand Down Expand Up @@ -861,7 +861,7 @@ class QEffVisionEncoderForTextImageToTextModel(QEFFBaseModel):
KVCacheTransform,
KVCacheExternalModuleMapperTransform,
]
_onnx_transforms = []
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

def __init__(self, model: nn.modules, **kwargs):
"""
Expand Down Expand Up @@ -1002,7 +1002,7 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
VlmKVOffloadTransform,
SplitGateUpWeightsTransform,
]
_onnx_transforms = []
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs):
"""
Expand Down Expand Up @@ -1932,7 +1932,7 @@ class _QEFFAutoModelForImageTextToTextSingleQPC(QEFFTransformersBase, Multimodal
VlmNoKVOffloadTransform,
SplitGateUpWeightsTransform,
]
_onnx_transforms = []
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

def __init__(
self,
Expand Down Expand Up @@ -2684,7 +2684,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):
KVCacheExternalModuleMapperTransform,
]

_onnx_transforms = []
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

def prefill(
self,
Expand Down Expand Up @@ -3649,7 +3649,7 @@ class QEFFAutoModelForSpeechSeq2Seq(QEFFTransformersBase, MultimodalUtilityMixin

_hf_auto_class = AutoModelForSpeechSeq2Seq
_pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, KVCacheTransform]
_onnx_transforms = []
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

def __init__(self, model: nn.Module, **kwargs):
"""
Expand Down Expand Up @@ -4008,7 +4008,7 @@ class QEFFAutoModelForCTC(QEFFTransformersBase):

_hf_auto_class = AutoModelForCTC
_pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform]
_onnx_transforms = []
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

def __init__(self, model: nn.Module, **kwargs):
super().__init__(model, **kwargs)
Expand Down
Loading