diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 6f22e867ef..b07411999f 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -20,9 +20,7 @@ from QEfficient.base.onnx_transforms import ( BaseOnnxTransform, - FP16ClipTransform, OnnxTransformPipeline, - SplitTensorsTransform, ) from QEfficient.base.pytorch_transforms import PytorchTransform from QEfficient.compile.qnn_compiler import compile as qnn_compile @@ -54,8 +52,9 @@ class QEFFBaseModel(ABC): _pytorch_transforms: List[PytorchTransform] _onnx_transforms = [BaseOnnxTransform] - def _transform_names(self) -> List[str]: - return [x.__name__ for x in self._pytorch_transforms + self._onnx_transforms] + @classmethod + def _transform_names(cls) -> List[str]: + return [x.__name__ for x in cls._pytorch_transforms + cls._onnx_transforms] def __init__(self, model: torch.nn.Module, **kwargs) -> None: super().__init__() @@ -243,6 +242,7 @@ def _export( self.onnx_path = onnx_path return onnx_path + tmp_onnx_path = onnx_path # check if the model is in meta state or weights are offloaded self._model_offloaded_check() @@ -276,7 +276,7 @@ def _export( torch.onnx.export( self.model, (example_inputs,), - str(onnx_path), + str(tmp_onnx_path), input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, @@ -285,13 +285,10 @@ def _export( ) logger.info("PyTorch export successful") _ = self._offload_model_weights(offload_pt_weights) - model = onnx.load(onnx_path, load_external_data=False) + model = onnx.load(tmp_onnx_path, load_external_data=False) - needs_external_tensor_data = any( - transform in self._onnx_transforms for transform in (FP16ClipTransform, SplitTensorsTransform) - ) transform_kwargs = { - "onnx_base_dir": str(export_dir) if needs_external_tensor_data else None, + "onnx_base_dir": str(export_dir), "model_name": self.model_name, } if onnx_transform_kwargs is not None: @@ -306,9 +303,7 @@ def _export( ) logger.info("ONNX transforms applied") - onnx_path_tmp = onnx_path.with_suffix(onnx_path.suffix + ".tmp") - onnx.save(model, onnx_path_tmp) - onnx_path_tmp.replace(onnx_path) + onnx.save(model, onnx_path) del model gc.collect() logger.info("Transformed ONNX saved") diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py index 2ba53829a4..3d3567f2e5 100644 --- a/QEfficient/base/onnx_transforms.py +++ b/QEfficient/base/onnx_transforms.py @@ -236,8 +236,7 @@ def apply( do_split = SplitTensorsTransform in requested fp16_min, fp16_max = np.finfo(np.float16).min, np.finfo(np.float16).max file_num_tracker = {"num": 0, "size": 0} - if onnx_base_dir is not None: - external_data_helper.load_external_data_for_model(model, onnx_base_dir) + external_data_helper.load_external_data_for_model(model, onnx_base_dir) if do_fp16 or do_split: for tensor in external_data_helper._get_all_tensors(model): diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 2caaa345de..02bb573343 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -29,7 +29,7 @@ import QEfficient from QEfficient.base.modeling_qeff import QEFFBaseModel -from QEfficient.base.onnx_transforms import FP16ClipTransform +from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform from QEfficient.base.pytorch_transforms import SplitGateUpWeightsTransform from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.generation.text_generation_inference import ( @@ -231,7 +231,7 @@ class QEFFAutoModel(QEFFTransformersBase): _hf_auto_class = AutoModel _pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform] - _onnx_transforms = [FP16ClipTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] def __init__(self, model: nn.Module, pooling=None, **kwargs): """ @@ -619,7 +619,7 @@ class QEFFAutoModelForSequenceClassification(QEFFTransformersBase): _hf_auto_class = AutoModelForSequenceClassification _pytorch_transforms = [CustomOpsTransform, TextClassificationTransform] - _onnx_transforms = [] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] def __init__(self, model: nn.Module, **kwargs): """ @@ -861,7 +861,7 @@ class QEffVisionEncoderForTextImageToTextModel(QEFFBaseModel): KVCacheTransform, KVCacheExternalModuleMapperTransform, ] - _onnx_transforms = [] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] def __init__(self, model: nn.modules, **kwargs): """ @@ -1002,7 +1002,7 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel): VlmKVOffloadTransform, SplitGateUpWeightsTransform, ] - _onnx_transforms = [] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs): """ @@ -1932,7 +1932,7 @@ class _QEFFAutoModelForImageTextToTextSingleQPC(QEFFTransformersBase, Multimodal VlmNoKVOffloadTransform, SplitGateUpWeightsTransform, ] - _onnx_transforms = [] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] def __init__( self, @@ -2684,7 +2684,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): KVCacheExternalModuleMapperTransform, ] - _onnx_transforms = [] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] def prefill( self, @@ -3649,7 +3649,7 @@ class QEFFAutoModelForSpeechSeq2Seq(QEFFTransformersBase, MultimodalUtilityMixin _hf_auto_class = AutoModelForSpeechSeq2Seq _pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, KVCacheTransform] - _onnx_transforms = [] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] def __init__(self, model: nn.Module, **kwargs): """ @@ -4008,7 +4008,7 @@ class QEFFAutoModelForCTC(QEFFTransformersBase): _hf_auto_class = AutoModelForCTC _pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform] - _onnx_transforms = [] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] def __init__(self, model: nn.Module, **kwargs): super().__init__(model, **kwargs)