From 9610741a76a64941b014a9fd4abc6748a34778f1 Mon Sep 17 00:00:00 2001 From: jinkasaisagar Date: Fri, 5 Jun 2026 15:06:55 +0530 Subject: [PATCH] Added final support of dream model Signed-off-by: jinkasaisagar --- QEfficient/transformers/modeling_auto.py | 4503 +++++++++++++++++ .../models/dream/modeling_dream.py | 299 ++ .../transformers/models/pytorch_transforms.py | 3 +- QEfficient/utils/diffusionLM_utils.py | 304 ++ 4 files changed, 5108 insertions(+), 1 deletion(-) create mode 100644 QEfficient/transformers/modeling_auto.py create mode 100644 QEfficient/transformers/models/dream/modeling_dream.py create mode 100644 QEfficient/utils/diffusionLM_utils.py diff --git a/QEfficient/transformers/modeling_auto.py b/QEfficient/transformers/modeling_auto.py new file mode 100644 index 0000000000..609a7c0cdc --- /dev/null +++ b/QEfficient/transformers/modeling_auto.py @@ -0,0 +1,4503 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import os +import warnings +from pathlib import Path +from time import perf_counter +from typing import List, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +from transformers import ( + AutoImageProcessor, + AutoModel, + AutoModelForCausalLM, + AutoModelForCTC, + AutoModelForImageTextToText, + AutoModelForSequenceClassification, + AutoModelForSpeechSeq2Seq, + PreTrainedTokenizer, + PreTrainedTokenizerFast, + TextStreamer, +) + +import QEfficient +from QEfficient.base.modeling_qeff import QEFFBaseModel +from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform +from QEfficient.base.pytorch_transforms import SplitGateUpWeightsTransform +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.generation.text_generation_inference import ( + CloudAI100ExecInfoNew, + PerfMetrics, + calculate_latency, + get_compilation_dims, + write_io_files, +) +from QEfficient.generation.vlm_generation import VisionLanguageGeneration +from QEfficient.transformers.modeling_utils import ( + DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH, + SPECIALIZED_DISAGG_SERVING_MODEL_ARCH, + _configure_proxy_for_model, +) +from QEfficient.transformers.models.pytorch_transforms import ( + CustomOpsTransform, + KVCacheExternalModuleMapperTransform, + KVCacheTransform, + PoolingTransform, + PrefillOnlyChunkedTransform, + PrefillOnlyExternalModuleMapperTransform, + PrefillOnlyTransform, + RevertPrefillKeepAttentionTransform, + RevertPrefillOnlyExternalModuleMapperTransform, + RevertPrefillOnlyTransform, + SamplerTransform, + SpDTransform, + TextClassificationTransform, + VlmKVOffloadTransform, + VlmNoKVOffloadTransform, +) +from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING, with_replaced_quantizers +from QEfficient.transformers.quantizers.quant_transforms import ( + 
AwqToMatmulNbitsTransform, + FP8BlockWiseDequantLinearToLinearTransform, + FP8BlockWiseDequantQwen3VLMoeTextExpertsToQwen3VLMoeTextExpertsTransform, + FP8DeQuantLinearToLinearTransform, + GPTQToMatmulNbitsTransform, + Mxfp4GptOssExpertDequantizeTransform, +) +from QEfficient.utils import ( + constants, + get_padding_shape_from_config, +) +from QEfficient.utils.check_ccl_specializations import process_ccl_specializations +from QEfficient.utils.logging_utils import logger +from QEfficient.utils.sampler_utils import get_sampling_inputs_and_outputs +from QEfficient.utils.diffusionLM_utils import ( + DreamGenerationConfig, + _prepare_special_tokens, + DreamModelOutput, + _prepare_generation_config, + _prepare_generated_length, + _validate_generated_length, + _sample +) +CUSTOM_IO_DTYPE_MAP = { + torch.float16: "float16", + torch.bfloat16: "bfloat16", + torch.float32: "float16", # Since compiler doesn't support fp32 + "float32": "float16", # Since compiler doesn't support fp32 +} + +TORCH_TO_NUMPY_DTYPE_MAP = { + torch.float16: np.float16, + torch.bfloat16: np.float16, # Since numpy doesn't support bfloat16 + torch.float32: np.float32, +} + + +class QEFFTransformersBase(QEFFBaseModel): + """ + Base class for QEfficient wrappers around HuggingFace transformer models. + + This class provides common functionality for loading, representing, and managing + HuggingFace models within the QEfficient framework. It serves as a parent + for specific model types like `AutoModel`, `AutoModelForCausalLM`, etc. 
+ """ + + _hf_auto_class: type + + def __init__(self, model: nn.Module, **kwargs) -> None: + _configure_proxy_for_model(self, kwargs.pop("enable_proxy", False)) + + if ( + hasattr(model, "config") + and hasattr(model.config, "quantization_config") + and not isinstance(model.config.quantization_config, tuple(QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING.values())) + ): + raise AssertionError("Please use `from_pretrained` method to load quantized models") + + super().__init__(model, **kwargs) + + def __repr__(self) -> str: + return self.__class__.__name__ + "\n" + self.model.__repr__() + + @classmethod + @with_replaced_quantizers + def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs): + """ + Load a QEfficient transformer model from a pretrained HuggingFace model or local path. + + This is the recommended way to initialize any QEfficient transformer model. + The interface is similar to ``transformers.AutoModel.from_pretrained``. + + Parameters + ---------- + pretrained_model_name_or_path : str + Model card name from HuggingFace or local path to model directory. + *args : + Positional arguments passed directly to `cls._hf_auto_class.from_pretrained`. + **kwargs : + Keyword arguments passed directly to `cls._hf_auto_class.from_pretrained`. + + **Note:** `attn_implementation` and `low_cpu_mem_usage` are automatically set to "eager" and False respectively to ensure compatibility. + + Returns + ------- + QEFFTransformersBase + An instance of the specific QEFFAutoModel subclass, initialized with the pretrained weights. 
class MultimodalUtilityMixin:
    """
    Mixin for multimodal models providing utilities like input auto-correction.

    The host class must expose ``self.model.get_inputs_info()``; every entry it
    returns is read through its ``name`` and ``datatype`` attributes and used to
    validate and filter user-supplied inputs.
    """

    def __new__(cls, *args, **kwargs):
        # The mixin is only meaningful when combined with a concrete model class.
        if cls is MultimodalUtilityMixin:
            raise TypeError(f"only children of '{cls.__name__}' may be instantiated")
        return object.__new__(cls)

    def auto_correct_inputs(self, inputs):
        """
        Validate model inputs against ``get_inputs_info`` and drop unexpected keys.

        Parameters
        ----------
        inputs : Dict[str, torch.Tensor]
            Mapping of input name to a value exposing ``dtype`` and ``shape``.

        Returns
        -------
        Dict[str, torch.Tensor]
            The entries of ``inputs`` whose key is one of the expected input names,
            in the iteration order of ``inputs``.

        Raises
        ------
        RuntimeError
            If an expected input is missing or its ``dtype`` differs from the
            expected ``datatype``.
        """
        inputs_info = self.model.get_inputs_info()

        mismatch = any(
            info.name not in inputs or inputs[info.name].dtype != info.datatype for info in inputs_info
        )
        if mismatch:
            err_str: str = (
                "Expected following input names and shapes to be passed\n"
                + "\n".join([val.__repr__() for val in inputs_info])
                + "\ngot"
                + f"{[(k, v.shape, v.dtype) for k, v in inputs.items()]}"
            )
            raise RuntimeError(err_str)

        # Build the lookup once instead of re-creating a list for every input key.
        expected_names = {info.name for info in inputs_info}
        return {k: v for k, v in inputs.items() if k in expected_names}
@classmethod
@with_replaced_quantizers
def from_pretrained(cls, pretrained_model_name_or_path, pooling=None, *args, **kwargs):
    """
    Load a QEfficient transformer model from a pretrained HuggingFace model or local path.

    The interface is similar to ``transformers.AutoModel.from_pretrained``.

    Parameters
    ----------
    pretrained_model_name_or_path : str
        Model card name from HuggingFace or local path to model directory.
    pooling : str or Callable, optional
        Pooling method forwarded to the ``QEFFAutoModel`` constructor. Default is None.
    *args :
        Positional arguments passed directly to `cls._hf_auto_class.from_pretrained`.
    **kwargs :
        Keyword arguments passed to `cls._hf_auto_class.from_pretrained`, except
        ``enable_proxy`` and ``kv_offload`` which are consumed here.

        **Note:** `attn_implementation` and `low_cpu_mem_usage` are always
        overridden to "eager" and False respectively.

    Returns
    -------
    QEFFAutoModel
        An instance wrapping the loaded model, or an instance of the class registered in
        ``MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP`` for the loaded model's class name.
    """
    # QEfficient-only options: strip them *before* delegating to HuggingFace so they
    # are not forwarded to the HF loader / model constructor as unknown keyword arguments.
    enable_proxy = kwargs.pop("enable_proxy", False)
    kv_offload = kwargs.pop("kv_offload", None)

    if kwargs.get("attn_implementation", None) not in {None, "eager"}:
        logger.warning('Updating attn_implementation="eager"')

    if kwargs.get("low_cpu_mem_usage", None):
        logger.warning("Updating low_cpu_mem_usage=False")

    kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})

    model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

    if enable_proxy:
        kwargs["enable_proxy"] = enable_proxy

    # Some models belong to a different auto class, but transformers loads them through
    # this one; hand those over to the QEfficient class registered for them.
    model_class_name = model.__class__.__name__
    if model_class_name in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP:
        return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model_class_name](
            model, kv_offload=kv_offload, **kwargs
        )

    return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path, pooling=pooling, **kwargs)
Defaults to False + + Returns + ------- + str + Path to the generated ONNX graph file. + """ + bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + + example_inputs = { + "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), + "attention_mask": torch.ones((bs, seq_len), dtype=torch.int64), + } + + dynamic_axes = {"input_ids": {0: "batch_size", 1: "seq_len"}, "attention_mask": {0: "batch_size", 1: "seq_len"}} + + output_names = ["output"] + + return self._export( + example_inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + export_dir=export_dir, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), + ) + + def compile( + self, + onnx_path: Optional[str] = None, + compile_dir: Optional[str] = None, + *, + seq_len: Union[int, List[int]] = 32, + batch_size: int = 1, + num_devices: int = 1, + num_cores: int = 16, # FIXME: Make this mandatory arg + mxfp6_matmul: bool = False, + use_onnx_subfunctions: bool = False, + **compiler_options, + ) -> str: + """ + Compile the exported ONNX model using the Cloud AI 100 Platform SDK compiler. + + This method generates a ``qpc`` package. If the model has not been exported yet, + this method will handle the export process. Additional arguments for the `qaic-compile` + compiler can be passed as keyword arguments. + + Parameters + ---------- + onnx_path : str, optional + Path to a pre-exported ONNX model. If not provided, the model will be exported first. + compile_dir : str, optional + Directory to save the generated QPC package. If not provided, a default directory is used. + seq_len : int or list of int, optional + The length(s) of the prompt(s) to compile for. Can be a single integer or a list of integers + to create multiple specializations. Default is 32. + batch_size : int, optional + Batch size. Default is 1. + num_devices : int, optional + Number of devices to compile for. Default is 1. 
+ num_cores : int, optional + Number of cores to use for compilation. + mxfp6_matmul : bool, optional + Use MXFP6 compression for weights. Default is False. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False + **compiler_options : dict + Additional compiler options for QAIC or QNN compilers. These are passed directly + to the underlying compilation command. + + **For QAIC Compiler:** Extra arguments for qaic-compile can be passed. Some common options include: + + - mos (int, optional): Effort level to reduce on-chip memory. Defaults to -1, meaning no effort. Defaults to -1. + - aic_enable_depth_first (bool, optional): Enables DFS with default memory size. Defaults to False. + - allow_mxint8_mdp_io (bool, optional): Allows MXINT8 compression of MDP IO traffic. Defaults to False. + + Params are converted to flags as below: + + - ``aic_num_cores=16`` -> ``-aic-num-cores=16`` + - ``convert_to_fp16=True`` -> ``-convert-to-fp16`` + + **For QNN Compiler:** Following arguments can be passed as: + + - enable_qnn (bool): Enables QNN Compilation. + - qnn_config (str): Path of QNN Config parameters file. Any extra parameters for QNN compilation can be passed via this file. + + Returns + ------- + str + Path to the compiled QPC package. 
def generate(
    self,
    inputs: torch.Tensor,
    device_ids: List[int] = None,
    runtime_ai100: bool = True,
    write_io: bool = False,
    dtype: Optional[torch.dtype] = torch.float32,
) -> Union[torch.Tensor, np.ndarray]:
    """
    Run the model either through the compiled QPC or through the PyTorch module.

    Parameters
    ----------
    inputs : dict-like
        Model inputs. The AI 100 path reads ``inputs["input_ids"]`` and
        ``inputs["attention_mask"]``; the PyTorch path calls ``self.model(**inputs)``.
    device_ids : list of int, optional
        Forwarded unchanged to ``cloud_ai_100_feature_generate``.
    runtime_ai100 : bool, optional
        Use the compiled QPC when True (default), otherwise the PyTorch model.
    write_io : bool, optional
        When True, inputs/outputs are dumped into an ``io_dir`` folder next to ``self.onnx_path``.
    dtype : torch.dtype, optional
        Torch dtype used to pick the numpy dtype of the pre-allocated output buffer
        on the AI 100 path. Default is ``torch.float32``.

    Returns
    -------
    Whatever the selected runtime returns.

    Raises
    ------
    TypeError
        If ``runtime_ai100`` is True and ``self.qpc_path`` is not a ``Path``
        (i.e. ``compile`` has not produced a QPC yet).
    """
    self._write_io_dir = os.path.join(os.path.dirname(self.onnx_path), "io_dir") if write_io else None

    if not runtime_ai100:
        return self.pytorch_feature_generate(model=self.model, inputs=inputs)

    if not isinstance(self.qpc_path, Path):
        raise TypeError("Please run compile API first!")

    # `dtype` used to be accepted here but silently dropped; pass it through.
    return self.cloud_ai_100_feature_generate(inputs=inputs, device_ids=device_ids, dtype=dtype)


def cloud_ai_100_feature_generate(
    self,
    inputs: torch.Tensor,
    device_ids: List[int] = [0],  # default list is only read, never mutated
    dtype: Optional[torch.dtype] = torch.float32,
) -> np.ndarray:
    """
    Run one batch through the compiled QPC.

    ``input_ids`` and ``attention_mask`` are right-padded with zeros up to the
    compiled sequence length selected for this call, then fed to the session.

    Parameters
    ----------
    inputs : dict-like
        Must contain ``input_ids`` and ``attention_mask`` tensors.
    device_ids : List[int], optional
        Passed to ``QAICInferenceSession`` when the session is first created. Defaults to [0].
    dtype : torch.dtype, optional
        Looked up in ``TORCH_TO_NUMPY_DTYPE_MAP`` for the output buffer; ``None`` is
        treated as ``torch.float32``.

    Returns
    -------
    The object returned by ``self.qpc_session.run``.
    """
    if dtype is None:
        dtype = torch.float32

    if self.qpc_session is None:
        self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids)
        self.batch_size = self.qpc_session.bindings[0].dims[0]

    input_ids_len = inputs["input_ids"].shape[1]

    # Resolve the sequence length on every call. Previously a value cached on
    # `self` from an earlier call was reused whenever no specialization fitted.
    seq_len = next(
        (shape[1][1][1] for shape in self.qpc_session.allowed_shapes if shape[1][1][1] >= input_ids_len),
        None,
    )
    if seq_len is None:
        # Single specialization (no allowed shapes reported) or nothing large enough.
        seq_len = self.qpc_session.bindings[0].dims[1]
        if seq_len < input_ids_len:
            # Negative padding below crops the tensors; make that visible.
            warnings.warn(
                f"Input length {input_ids_len} exceeds the compiled sequence length {seq_len}; "
                "inputs will be truncated."
            )
    self.seq_len = seq_len

    input_ids = np.array(
        torch.nn.functional.pad(inputs["input_ids"], (0, seq_len - input_ids_len), "constant", 0)
    )
    attention_mask = np.array(
        torch.nn.functional.pad(
            inputs["attention_mask"], (0, seq_len - inputs["attention_mask"].size(1)), "constant", 0
        )
    )

    inputs = dict(input_ids=input_ids, attention_mask=attention_mask)

    # TODO: Remove try and catch after compiler fix
    try:
        outputs = {
            "output": np.random.randn(*list(self.qpc_session.bindings[2].dims)).astype(
                TORCH_TO_NUMPY_DTYPE_MAP[dtype]
            ),
        }
        self.qpc_session.set_buffers(outputs)
        outputs = self.qpc_session.run(inputs)
    except Exception:
        # Retry with an explicitly shaped (batch, seq_len, dims[1]) buffer.
        outputs = {
            "output": np.random.randn(self.batch_size, self.seq_len, self.qpc_session.bindings[2].dims[1]).astype(
                TORCH_TO_NUMPY_DTYPE_MAP[dtype]
            ),
        }
        self.qpc_session.set_buffers(outputs)
        outputs = self.qpc_session.run(inputs)

    # `_write_io_dir` is only assigned by `generate`; tolerate direct calls.
    write_io_dir = getattr(self, "_write_io_dir", None)
    if write_io_dir is not None:
        write_io_files(inputs, outputs, write_io_dir, "output", "aic_batch_io", True, False)

    return outputs
@torch.no_grad()
def diffusion_generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[DreamGenerationConfig] = None,
    **kwargs,
) -> "tuple[Union[DreamModelOutput, torch.LongTensor], float]":
    """
    Run diffusion-style generation against a compiled QPC and time it.

    Parameters
    ----------
    inputs : torch.Tensor
        Prompt token ids. Required despite the ``None`` default.
    generation_config : DreamGenerationConfig, optional
        Passed through ``_prepare_generation_config`` together with ``kwargs``. The
        resulting config must provide ``qpc_path``, ``device_ids``, ``steps`` and
        ``num_return_sequences``.
    **kwargs :
        ``attention_mask``, ``generation_tokens_hook_func`` and
        ``generation_logits_hook_func`` are consumed here; ``max_length`` is inspected.

    Returns
    -------
    tuple
        ``(result, average_time_per_iteration)`` where ``result`` is whatever ``_sample``
        returns and the second item is the wall-clock time of ``_sample`` divided by
        ``generation_config.steps`` (the total time if ``steps`` is 0).

    Raises
    ------
    ValueError
        If ``inputs`` is None.
    """
    # 1. Handle `generation_config` and kwargs that might update it.
    generation_config = _prepare_generation_config(generation_config, **kwargs)
    generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
    generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)

    # 2. Define model inputs. An explicit check survives `python -O`, unlike `assert`.
    if inputs is None:
        raise ValueError("`inputs` must be provided to `diffusion_generate`.")
    input_ids = inputs
    device = input_ids.device
    attention_mask = kwargs.pop("attention_mask", None)
    _prepare_special_tokens(generation_config, device=device)

    # 3. Prepare `max_length`.
    input_ids_length = input_ids.shape[-1]
    has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
    generation_config = _prepare_generated_length(
        generation_config=generation_config,
        has_default_max_length=has_default_max_length,
        input_ids_length=input_ids_length,
    )
    _validate_generated_length(generation_config, input_ids_length, has_default_max_length)

    # `input_ids == None` evaluates to a plain bool, which `torch.any` rejects,
    # so only compare when a pad token id is actually configured.
    pad_token_id = getattr(generation_config, "pad_token_id", None)
    if attention_mask is None and pad_token_id is not None and torch.any(input_ids == pad_token_id):
        warnings.warn(
            "Padding was detected but no attention mask is passed here. For correct "
            "generation results, please set `attention_mask` when batch-padding inputs.",
            UserWarning,
        )

    # NOTE(review): `_expand_inputs_for_generation` is not defined in the visible part of
    # this class - confirm a base class provides it.
    input_ids, attention_mask = self._expand_inputs_for_generation(
        expand_size=generation_config.num_return_sequences,
        input_ids=input_ids,
        attention_mask=attention_mask,
    )

    # 4. Open a session on the compiled QPC and register the output buffer.
    qpc_session = QAICInferenceSession(str(generation_config.qpc_path), device_ids=generation_config.device_ids)
    # NOTE(review): the buffer named "logits" is allocated as int64 - confirm this matches
    # the dtype of the compiled graph's output binding.
    logits_buffer = np.random.randn(*list(qpc_session.bindings[1].dims)).astype(np.int64)
    qpc_session.set_buffers({"logits": logits_buffer})

    # 5. Sample and time the whole loop.
    start_time = perf_counter()
    result = _sample(
        start_time,
        qpc_session,
        input_ids,
        attention_mask=attention_mask,
        generation_config=generation_config,
        generation_tokens_hook_func=generation_tokens_hook_func,
        generation_logits_hook_func=generation_logits_hook_func,
    )
    total_time = perf_counter() - start_time

    steps = generation_config.steps
    # Guard against a zero step count instead of raising ZeroDivisionError after the work is done.
    average_time_per_iteration = total_time / steps if steps else total_time
    return result, average_time_per_iteration
@classmethod
@with_replaced_quantizers
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
    """
    Build a QEfficient sequence-classification wrapper from a HuggingFace checkpoint.

    Mirrors ``transformers.AutoModelForSequenceClassification.from_pretrained``.

    Parameters
    ----------
    pretrained_model_name_or_path : str
        Model card name from HuggingFace or local path to model directory.
    *args :
        Positional arguments forwarded to `cls._hf_auto_class.from_pretrained`.
    **kwargs :
        Keyword arguments forwarded to `cls._hf_auto_class.from_pretrained`;
        ``enable_proxy`` is consumed here and handed to the wrapper constructor only.

        **Note:** `attn_implementation` is forced to "eager" and `low_cpu_mem_usage` to False.

    Returns
    -------
    QEFFAutoModelForSequenceClassification
        Wrapper around the loaded model.
    """
    proxy_requested = kwargs.pop("enable_proxy", False)

    # Tell the user when one of their choices is about to be overridden.
    if kwargs.get("attn_implementation", None) not in {None, "eager"}:
        logger.warning('Updating attn_implementation="eager"')
    if kwargs.get("low_cpu_mem_usage", None):
        logger.warning("Updating low_cpu_mem_usage=False")

    kwargs["attn_implementation"] = "eager"
    kwargs["low_cpu_mem_usage"] = False

    hf_model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

    # Only forward the proxy flag to the wrapper when it was actually enabled.
    if proxy_requested:
        kwargs["enable_proxy"] = proxy_requested

    return cls(hf_model, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False + + Returns + ------- + str + Path to the generated ONNX graph file. + """ + bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + + example_inputs = { + "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), + "attention_mask": torch.ones((bs, seq_len), dtype=torch.int64), + } + + dynamic_axes = {"input_ids": {0: "batch_size", 1: "seq_len"}, "attention_mask": {0: "batch_size", 1: "seq_len"}} + + output_names = ["logits"] + + return self._export( + example_inputs, + output_names, + dynamic_axes, + export_dir=export_dir, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), + ) + + def compile( + self, + onnx_path: Optional[str] = None, + compile_dir: Optional[str] = None, + *, + seq_len: Union[int, List[int]] = 32, + batch_size: int = 1, + num_devices: int = 1, + num_cores: int = 16, + mxfp6_matmul: bool = False, + use_onnx_subfunctions: bool = False, + **compiler_options, + ) -> str: + """ + Compile the exported ONNX model using the Cloud AI 100 Platform SDK compiler. + + This method generates a ``qpc`` package. If the model has not been exported yet, + this method will handle the export process. + + Parameters + ---------- + onnx_path : str, optional + Path to a pre-exported ONNX model. If not provided, the model will be exported first. + compile_dir : str, optional + Directory to save the generated QPC package. If not provided, a default directory is used. + seq_len : int or list of int, optional + The length(s) of the input sequence(s) to compile for. Can be a single integer or a list of integers + to create multiple specializations. Default is 32. + batch_size : int, optional + Batch size. Default is 1. + num_devices : int, optional + Number of devices to compile for. Default is 1. + num_cores : int, optional + Number of cores to use for compilation. 
def generate(
    self,
    inputs: torch.Tensor,
    device_ids: List[int] = None,
) -> dict:
    """
    Generate classification output using the Cloud AI 100 hardware runtime.

    ``input_ids`` and ``attention_mask`` are right-padded with zeros to the compiled
    sequence length selected for this call before being sent to the session.

    Parameters
    ----------
    inputs : dict-like
        Must contain ``input_ids`` and ``attention_mask`` tensors.
    device_ids : List[int], optional
        Forwarded unchanged (including ``None``) to ``QAICInferenceSession`` when the
        session is first created.

    Returns
    -------
    dict
        ``{"logits": torch.Tensor}`` built from the session's ``"logits"`` output.
    """
    if self.qpc_session is None:
        self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids)
        self.batch_size = self.qpc_session.bindings[0].dims[0]

    input_ids_len = inputs["input_ids"].shape[1]

    # Resolve the sequence length on every call. Previously a value cached on
    # `self` from an earlier call was reused whenever no specialization fitted.
    seq_len = next(
        (shape[1][1][1] for shape in self.qpc_session.allowed_shapes if shape[1][1][1] >= input_ids_len),
        None,
    )
    if seq_len is None:
        # Single specialization (no allowed shapes reported) or nothing large enough.
        seq_len = self.qpc_session.bindings[0].dims[1]
        if seq_len < input_ids_len:
            # Negative padding below crops the tensors; make that visible.
            warnings.warn(
                f"Input length {input_ids_len} exceeds the compiled sequence length {seq_len}; "
                "inputs will be truncated."
            )
    self.seq_len = seq_len

    input_ids = np.array(
        torch.nn.functional.pad(inputs["input_ids"], (0, seq_len - input_ids_len), "constant", 0)
    )
    attention_mask = np.array(
        torch.nn.functional.pad(
            inputs["attention_mask"], (0, seq_len - inputs["attention_mask"].size(1)), "constant", 0
        )
    )

    inputs_np = dict(input_ids=input_ids, attention_mask=attention_mask)
    outputs = self.qpc_session.run(inputs_np)

    return {"logits": torch.from_numpy(outputs["logits"])}
+ """ + _configure_proxy_for_model(self, kwargs.pop("enable_proxy", False)) + super().__init__(model, **kwargs) + self.model = model.get_qeff_vision_encoder() + self.hash_params["qeff_auto_class"] = self.__class__.__name__ + + def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True, **kwargs): + """ + Exports the vision encoder component to ONNX format. + + Parameters + ---------- + inputs : Dict[str, torch.Tensor] + Example inputs for the ONNX export. + output_names : List[str] + List of output names for the ONNX graph. + dynamic_axes : Dict[str, Dict[int, str]] + Dynamic axes configuration for the ONNX graph. + export_dir : str, optional + Directory path where the exported ONNX graph will be saved. Default is None. + offload_pt_weights : bool, optional + If True, PyTorch weights will be offloaded after export. Default is True. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False + + Returns + ------- + str + Path to the generated ONNX graph file for the vision encoder. + """ + return self._export( + inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + export_dir=export_dir, + offload_pt_weights=offload_pt_weights, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), + ) + + def compile( + self, + compile_dir, + specializations, + convert_to_fp16, + mxfp6_matmul, + mdp_ts_num_devices, + aic_num_cores, + custom_io, + use_onnx_subfunctions: bool = False, + **compiler_options, + ) -> str: + """ + Compiles the vision encoder component to a QPC package. + + Parameters + ---------- + compile_dir : str + Directory to save the generated QPC package. + specializations : List[Dict[str, Union[int, str]]] + List of dictionaries, each specifying a compilation specialization. 
+ convert_to_fp16 : bool + If True, converts model to FP16 precision during compilation. + mxfp6_matmul : bool + If True, uses MXFP6 compression for MatMul weights. + mdp_ts_num_devices : int + Number of devices for multi-device (tensor slicing) compilation. + aic_num_cores : int + Number of cores to use for compilation. + custom_io : Dict[str, str] + Custom I/O configurations for the compiler. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False + **compiler_options : + Additional compiler options passed to the underlying compilation command. + + Returns + ------- + str + Path to the compiled QPC package for the vision encoder. + """ + return self._compile( + compile_dir=compile_dir, + specializations=specializations, + convert_to_fp16=convert_to_fp16, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=mdp_ts_num_devices, + aic_num_cores=aic_num_cores, + custom_io=custom_io, + use_onnx_subfunctions=use_onnx_subfunctions, + **compiler_options, + ) + + @property + def get_model_config(self) -> dict: + """ + Get the configuration dictionary of the underlying HuggingFace vision model. + + Returns + ------- + dict + The configuration dictionary. + """ + if hasattr(self.model.model, "vision_model"): + return self.model.model.vision_model.config.__dict__ + return self.model.model.config.__dict__ + + +class QEffCausalLMForTextImageToTextModel(QEFFBaseModel): + """ + QEfficient wrapper for the Causal Language Model (decoder) component of a Text-to-Image-to-Text model. + + This class handles the export and compilation of the language decoder part + of multimodal models for optimal performance on Cloud AI 100 hardware. 
+ """ + + _pytorch_transforms = [ + AwqToMatmulNbitsTransform, + GPTQToMatmulNbitsTransform, + FP8BlockWiseDequantQwen3VLMoeTextExpertsToQwen3VLMoeTextExpertsTransform, + FP8BlockWiseDequantLinearToLinearTransform, + CustomOpsTransform, + KVCacheTransform, + VlmKVOffloadTransform, + SplitGateUpWeightsTransform, + ] + _onnx_transforms = [] + + def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs): + """ + Initializes the language decoder component for multimodal models. + + Parameters + ---------- + model : nn.Module + The full HuggingFace multimodal model from which the language decoder is extracted. + qaic_config : dict, optional + A dictionary for QAIC-specific configurations. Supported keys include: + - **num_kv_blocks** (int): Number of K/V blocks for BlockedKV attention implementation. + **kwargs : + Additional keyword arguments passed to the base class constructor. + """ + _configure_proxy_for_model(self, kwargs.pop("enable_proxy", False)) + super().__init__(model, **kwargs) + self.model = model.get_qeff_language_decoder() + self.model.qaic_config = qaic_config + self.hash_params["qeff_auto_class"] = self.__class__.__name__ + + def __update_prefill_transform( + self, + enable: Optional[bool] = True, + enable_chunking: Optional[bool] = False, + retain_full_kv: Optional[bool] = False, + ): + if enable: + if enable_chunking: + self.model, tf = PrefillOnlyChunkedTransform.apply(self.model) + else: + self.model, tf = PrefillOnlyTransform.apply(self.model) + + else: + if retain_full_kv: + self.model, tf = RevertPrefillKeepAttentionTransform.apply(self.model) + else: + self.model, tf = RevertPrefillOnlyTransform.apply(self.model) + + def export( + self, + inputs, + output_names, + dynamic_axes, + export_dir=None, + offload_pt_weights=True, + prefill_seq_len: Optional[int] = None, + prefill_only: bool = False, + enable_chunking: bool = False, + **kwargs, + ): + """ + Exports the language decoder component to ONNX format. 
+ + Parameters + ---------- + inputs : Dict[str, torch.Tensor] + Example inputs for the ONNX export. + output_names : List[str] + List of output names for the ONNX graph. + dynamic_axes : Dict[str, Dict[int, str]] + Dynamic axes configuration for the ONNX graph. + export_dir : str, optional + Directory path where the exported ONNX graph will be saved. Default is None. + offload_pt_weights : bool, optional + If True, PyTorch weights will be offloaded after export. Default is True. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False + + Returns + ------- + str + Path to the generated ONNX graph file for the language decoder. + """ + if prefill_only: + assert prefill_seq_len > 1 + if not enable_chunking and self.continuous_batching: + raise NotImplementedError( + "Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!" + ) + self.hash_params["prefill_only"] = True + self.__update_prefill_transform(enable=True, enable_chunking=enable_chunking) + else: + self.hash_params["prefill_only"] = False + self.__update_prefill_transform(False, retain_full_kv=kwargs.get("retain_full_kv", False)) + + return self._export( + inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + export_dir=export_dir, + offload_pt_weights=offload_pt_weights, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), + ) + + def compile( + self, + compile_dir, + specializations, + convert_to_fp16, + mxfp6_matmul, + mdp_ts_num_devices, + aic_num_cores, + custom_io, + use_onnx_subfunctions: bool = False, + **compiler_options, + ) -> str: + """ + Compiles the language decoder component to a QPC package. + + Parameters + ---------- + compile_dir : str + Directory to save the generated QPC package. 
+ specializations : List[Dict[str, Union[int, str]]] + List of dictionaries, each specifying a compilation specialization. + convert_to_fp16 : bool + If True, converts model to FP16 precision during compilation. + mxfp6_matmul : bool + If True, uses MXFP6 compression for MatMul weights. + mdp_ts_num_devices : int + Number of devices for multi-device (tensor slicing) compilation. + aic_num_cores : int + Number of cores to use for compilation. + custom_io : Dict[str, str] + Custom I/O configurations for the compiler. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False + **compiler_options : + Additional compiler options passed to the underlying compilation command. + + Returns + ------- + str + Path to the compiled QPC package for the language decoder. + """ + return self._compile( + compile_dir=compile_dir, + specializations=specializations, + convert_to_fp16=convert_to_fp16, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=mdp_ts_num_devices, + aic_num_cores=aic_num_cores, + custom_io=custom_io, + use_onnx_subfunctions=use_onnx_subfunctions, + **compiler_options, + ) + + @property + def get_model_config(self) -> dict: + """ + Get the configuration dictionary of the underlying HuggingFace language model. + + Returns + ------- + dict + The configuration dictionary. + """ + if hasattr(self.model, "language_model"): + return self.model.language_model.config.__dict__ + return self.model.config.__dict__ + + +class _QEffAutoModelForImageTextToTextDualQPC: + """ + Internal class handling multimodal image-text-to-text models using a dual QPC approach. + + In this approach, the vision encoder and language model decoder are compiled + into separate QPC packages. The vision encoder's KV cache might be offloaded + to CPU or managed differently from the language model's KV cache. 
+ """ + + _hf_auto_class = AutoModelForImageTextToText + + def __init__( + self, + model: nn.Module, + continuous_batching: bool = False, + qaic_config: Optional[dict] = None, + **kwargs, + ): + """ + Initializes the dual QPC multimodal model wrapper. + + Parameters + ---------- + model : nn.Module + The full HuggingFace multimodal model. + qaic_config : dict, optional + A dictionary for QAIC-specific configurations. + **kwargs : + Additional keyword arguments. + """ + if kwargs.pop("full_batch_size", None): + continuous_batching = True + warnings.warn( + "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2 + ) + self.model = model + self.config = model.config + + self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) + self.lang_model = QEffCausalLMForTextImageToTextModel(model, qaic_config=qaic_config, **kwargs) + self.continuous_batching = continuous_batching + self.ccl_enabled = False + if qaic_config: + self.ccl_enabled = qaic_config.get("ccl_enabled", False) + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None + self.input_shapes, self.output_names = None, None + # ---Sampling--- + # Note: SamplerTransform should be applied after all other transforms + # are done. The role of the sampler is to just add nodes at the output of the + # previous transform function. + self.lang_model.model, _ = SamplerTransform.apply(self.lang_model.model, qaic_config, **kwargs) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Optional[dict] = None, **kwargs): + """ + Load a QEfficient multimodal model for dual QPC from a pretrained HuggingFace model or local path. + + Parameters + ---------- + pretrained_model_name_or_path : str + Model card name from HuggingFace or local path to model directory. + **kwargs : + Additional keyword arguments passed directly to `cls._hf_auto_class.from_pretrained`. 
+ Note: `attn_implementation` and `low_cpu_mem_usage` are automatically + set to "eager" and False respectively to ensure compatibility. + + Returns + ------- + _QEffAutoModelForImageTextToTextDualQPC + An instance initialized with the pretrained weights. + """ + enable_proxy = kwargs.pop("enable_proxy", False) + + if kwargs.get("attn_implementation", None) not in {None, "eager"}: + logger.warning('Updating attn_implementation="eager"') + + if kwargs.get("low_cpu_mem_usage", None): + logger.warning("Updating low_cpu_mem_usage=False") + + kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + + model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) + + kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {}) + + return cls( + model, + pretrained_model_name_or_path=pretrained_model_name_or_path, + qaic_config=qaic_config, + **kwargs, + ) + + @property + def onnx_path(self): + """ + Get the ONNX paths for the vision and language model components. + + Returns + ------- + List[str] + A list containing the ONNX paths of the vision model and the language model. + """ + return [self.vision_model.onnx_path, self.lang_model.onnx_path] + + def export( + self, + export_dir: Optional[str] = None, + use_onnx_subfunctions: bool = False, + skip_vision: Optional[bool] = False, + skip_lang: Optional[bool] = False, + prefill_seq_len: Optional[int] = None, + prefill_only: bool = False, + enable_chunking: bool = False, + **kwargs, + ) -> str: + """ + Exports both the vision encoder and language decoder components to ONNX format. + + This method exports the vision component (optionally without offloading PyTorch weights) + and the language component (with offloading PyTorch weights). + + Parameters + ---------- + export_dir : str, optional + Directory path where the exported ONNX graphs will be saved. Default is None. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. 
Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False + **kwargs : + Additional keyword arguments. + + Returns + ------- + List[str] + A list containing the paths to the generated ONNX graph files for both components. + """ + # TODO This is a temporary change as continuous batching is enabled only for a few models. Once support is added for all the models this exception handling can be removed. + try: + inputs = self.model.get_dummy_inputs( + kv_offload=True, + continuous_batching=self.continuous_batching, + comp_ctx_lengths=self.comp_ctx_lengths_decode, + ) + dynamic_axes = self.model.get_onnx_dynamic_axes( + kv_offload=True, + continuous_batching=self.continuous_batching, + comp_ctx_lengths=self.comp_ctx_lengths_decode, + ) + except TypeError: + inputs = self.model.get_dummy_inputs(kv_offload=True, comp_ctx_lengths=self.comp_ctx_lengths_decode) + dynamic_axes = self.model.get_onnx_dynamic_axes( + kv_offload=True, comp_ctx_lengths=self.comp_ctx_lengths_decode + ) + output_names = self.model.get_output_names(kv_offload=True) + if self.lang_model.model.qaic_config is not None and self.lang_model.model.qaic_config.get( + "include_sampler", False + ): + logits_index = output_names["lang"].index("logits") + output_names["lang"][logits_index] = "next_tokens" + inputs["lang"], output_names["lang"], dynamic_axes["lang"] = get_sampling_inputs_and_outputs( + example_inputs=inputs["lang"], + output_names=output_names["lang"], + dynamic_axes=dynamic_axes["lang"], + continuous_batching=self.continuous_batching, + vocab_size=self.model.language_model.config.vocab_size, + qaic_config=self.lang_model.model.qaic_config, + ) + if not skip_vision: + self.vision_model.export( + inputs["vision"], + output_names["vision"], + dynamic_axes["vision"], + export_dir=export_dir, + offload_pt_weights=False, + use_onnx_subfunctions=use_onnx_subfunctions, + ) + + if prefill_only and prefill_seq_len > 1: + offload_pt_weights = False # to
keep weight for decode onnx + else: + offload_pt_weights = kwargs.get("offload_pt_weights", True) + + if not skip_lang: + self.lang_model.export( + inputs["lang"], + output_names["lang"], + dynamic_axes["lang"], + export_dir=export_dir, + offload_pt_weights=offload_pt_weights, + use_onnx_subfunctions=use_onnx_subfunctions, + prefill_only=prefill_only, + enable_chunking=enable_chunking, + prefill_seq_len=prefill_seq_len, + ) + return self.onnx_path + + def transform( + self, + ctx_len: Optional[int] = None, + seq_len: Optional[int] = None, + bs: Optional[int] = 1, + num_devices: int = 1, + qaic_config: Optional[dict] = None, + **compiler_options, + ): + self.vision_model.transform( + ctx_len=ctx_len, + seq_len=seq_len, + bs=bs, + num_devices=num_devices, + qaic_config=qaic_config, + **compiler_options, + ) + + self.lang_model.transform( + ctx_len=ctx_len, + seq_len=seq_len, + bs=bs, + num_devices=num_devices, + qaic_config=qaic_config, + **compiler_options, + ) + + def compile( + self, + img_size: Optional[int] = None, + vision_onnx_path: Optional[str] = None, + lang_onnx_path: Optional[str] = None, + compile_dir: Optional[str] = None, + *, + prefill_seq_len: Optional[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, + ctx_len: Optional[int] = None, + batch_size: int = 1, + full_batch_size: Optional[int] = None, + kv_cache_batch_size: Optional[int] = None, + num_devices: int = 1, + num_cores: int = 16, # FIXME: Make this mandatory arg + mxfp6_matmul: bool = False, + mxint8_kv_cache: bool = False, + skip_vision: Optional[bool] = False, + skip_lang: Optional[bool] = False, + use_onnx_subfunctions: bool = False, + prefill_only=None, + enable_chunking=False, + qaic_config: Optional[dict] = None, + **compiler_options, + ) -> str: + """ + Compiles both the vision encoder and language decoder components into QPC packages. 
+ + Parameters + ---------- + img_size : int, optional + The image size to compile the vision model for. Default is None. + vision_onnx_path : str, optional + Path to a pre-exported ONNX file for the vision encoder. If None, it will be exported. + lang_onnx_path : str, optional + Path to a pre-exported ONNX file for the language decoder. If None, it will be exported. + compile_dir : str, optional + Directory to save the generated QPC packages. + prefill_seq_len : int, optional + Length of the prefill prompt for the language model. Default is None. + ctx_len : int, optional + Maximum context length for the language model. Default is None. + batch_size : int, optional + Batch size. Default is 1. + full_batch_size : int, optional + Full batch size for continuous batching. Required when the model was created with `continuous_batching=True`. Default is None. + kv_cache_batch_size : int, optional + Batch size of the KV cache; requires `full_batch_size`. If None, falls back to `full_batch_size`, then to `batch_size`. + num_devices : int, optional + Number of devices to compile for. Default is 1. + num_cores : int, optional + Number of cores to use for compilation. + mxfp6_matmul : bool, optional + Use MXFP6 compression for weights in the language model. Default is False. + mxint8_kv_cache : bool, optional + Use MXINT8 compression for KV cache. Default is False. + num_speculative_tokens : int, optional + Not supported for this model; must be None. + skip_vision : bool, optional + If True, skips compilation of the vision encoder. Default is False. + skip_lang : bool, optional + If True, skips compilation of the language decoder. Default is False. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False + **compiler_options : dict + Additional compiler options for QAIC or QNN compilers. + + Returns + ------- + dict + Paths of the compiled QPC packages keyed by component: `vision_qpc_path` (unless `skip_vision`) and, unless + `skip_lang`, one of `lang_prefill_qpc_path`, `lang_decode_qpc_path` or `lang_qpc_path`.
+ + Raises + ------ + ValueError + If `kv_cache_batch_size` is set without `full_batch_size`, or if both `skip_lang` and `skip_vision` are True. + (A TypeError is raised instead when `continuous_batching=True` and `full_batch_size` is None.) + """ + if skip_lang and skip_vision: + raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False") + + if self.continuous_batching and full_batch_size is None: + raise TypeError("`full_batch_size` is required when `continuous_batching=True`.") + + if kv_cache_batch_size and not full_batch_size: + raise ValueError( + "KV caching requires continuous batching. Please set `full_batch_size` and " + "enable `continuous_batching=True` in `from_pretrained`." + ) + + # Infer kv_cache_batch_size if not provided + kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size + + output_names = self.model.get_output_names(kv_offload=True) + + # if ccl_enabled is True read Compute-Context-Length lists + if self.ccl_enabled: + if comp_ctx_lengths_prefill is None and comp_ctx_lengths_decode is None: + logger.info("Auto-generating CCL-prefill and CCL-decode lists based on Context Length (CL).") + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations( + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len + ) + # For supporting VLLM and Disaggregated with CCL + elif comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None: + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations( + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len + ) + + # Apply compile-dependent transforms like blocking transform + self.transform( + ctx_len=ctx_len, + seq_len=prefill_seq_len, + batch_size=batch_size, + num_devices=num_devices, + qaic_config=qaic_config, + aic_num_cores=num_cores, + ) + + specializations, compiler_options = self.model.get_specializations( + batch_size=batch_size, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, +
comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=self.comp_ctx_lengths_decode, + img_size=img_size, + kv_offload=True, + continuous_batching=self.continuous_batching, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, + **compiler_options, + ) + + custom_io_vision = {} + target_dtype = getattr(self.model.config, "torch_dtype", torch.float32) + kv_cache_dtype = "mxint8" if mxint8_kv_cache else CUSTOM_IO_DTYPE_MAP[target_dtype] + molmo = hasattr(self.model.config, "model_type") and self.model.config.model_type == "molmo" + if molmo: + custom_io_vision["image_masks"] = CUSTOM_IO_DTYPE_MAP[target_dtype] + custom_io_vision["pixel_values"] = CUSTOM_IO_DTYPE_MAP[target_dtype] + + for output_name in output_names["vision"]: + if output_name.startswith("past_"): + custom_io_vision[output_name] = kv_cache_dtype + else: + custom_io_vision[output_name] = CUSTOM_IO_DTYPE_MAP[target_dtype] + + if vision_onnx_path: + self.vision_model.onnx_path = vision_onnx_path + if lang_onnx_path: + self.lang_model.onnx_path = lang_onnx_path + + if vision_onnx_path is None or lang_onnx_path is None: + self.export( + use_onnx_subfunctions=use_onnx_subfunctions, + skip_vision=skip_vision, + skip_lang=skip_lang, + prefill_only=prefill_only, + enable_chunking=enable_chunking, + prefill_seq_len=prefill_seq_len, + ) + + # TODO this should be removed once continuous batching is supported for all the models.
+ compiler_options.pop("continuous_batching", None) + compiler_options.pop("kv_cache_batch_size", None) + compiler_options.pop("full_batch_size", None) + self.qpc_paths = {} + if not skip_vision: + vision_qpc_path = self.vision_model._compile( + compile_dir=compile_dir, + specializations=specializations["vision"], + specialization_module_name="Vision", + convert_to_fp16=(CUSTOM_IO_DTYPE_MAP[target_dtype] == "float16"), + mxfp6_matmul=constants.VISION_MXFP6_MATMUL, + mdp_ts_num_devices=num_devices, + aic_num_cores=num_cores, + custom_io=custom_io_vision, + mxint8_kv_cache=mxint8_kv_cache, + use_onnx_subfunctions=use_onnx_subfunctions, + **compiler_options, + ) + self.qpc_paths["vision_qpc_path"] = vision_qpc_path + + # Custom NPI file options + if hasattr(self.model, "get_npi_file") and "node_precision_info" not in compiler_options: + compiler_options["node_precision_info"] = self.model.get_npi_file(self.model.name_or_path) + + if not skip_lang: + custom_io_lang = {} + # Inputs + for output_name in output_names["lang"]: + if output_name.endswith("_RetainedState"): + custom_io_lang[output_name[: -len("_RetainedState")]] = ( + CUSTOM_IO_DTYPE_MAP[target_dtype] + if ("vision_embeds" in output_name or "deepstack_features" in output_name) + else kv_cache_dtype + ) + + # outputs + for output_name in output_names["lang"]: + if output_name.endswith("_RetainedState"): + custom_io_lang[output_name] = ( + CUSTOM_IO_DTYPE_MAP[target_dtype] + if ("vision_embeds" in output_name or "deepstack_features" in output_name) + else kv_cache_dtype + ) + if prefill_only: + specializations = specializations["lang"][:1] + qpc_key = "lang_prefill_qpc_path" + elif prefill_seq_len == 1: + specializations = specializations["lang"][-1:] + qpc_key = "lang_decode_qpc_path" + else: + specializations = specializations["lang"] + qpc_key = "lang_qpc_path" + + lang_qpc_path = self.lang_model._compile( + compile_dir=compile_dir, + retained_state=True, + specializations=specializations, + 
convert_to_fp16=(CUSTOM_IO_DTYPE_MAP[target_dtype] == "float16"), + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=num_devices, + aic_num_cores=num_cores, + custom_io=custom_io_lang, + mxint8_kv_cache=mxint8_kv_cache, + use_onnx_subfunctions=use_onnx_subfunctions, + **compiler_options, + ) + self.qpc_paths.update({qpc_key: lang_qpc_path}) + return self.qpc_paths + + def generate( + self, + inputs: Optional[torch.Tensor] = None, + tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer] = None, + processor: Optional[AutoImageProcessor] = None, + images: List[str] = None, + prompts: List[str] = None, + streamer: Optional[TextStreamer] = None, + device_ids: List[int] = None, + runtime_ai100: bool = True, + generation_len: Optional[int] = None, + image_height: Optional[int] = None, + image_width: Optional[int] = None, + **kwargs, + ) -> Union[torch.Tensor, np.ndarray]: + """ + Generates output by executing the compiled QPC(s) on Cloud AI 100 Hardware cards. + + This method coordinates inference between the vision encoder and language model decoder. + + Parameters + ---------- + inputs : Dict[str, Union[torch.Tensor, np.ndarray]] + Inputs to run the execution, typically includes `pixel_values`, `input_ids`, + `attention_mask`, etc. + tokenizer : PreTrainedTokenizer or PreTrainedTokenizerFast, optional + Tokenizer for the model. Used when images and prompts are provided. + processor : AutoImageProcessor, optional + Processor for the model. Used when images and prompts are provided. + images : List[str], optional + List of image paths or PIL images to process. + prompts : List[str], optional + List of text prompts corresponding to the images. + streamer : TextStreamer, optional + A streamer object to display generated tokens in real-time. Default is None. + device_ids : List[int], optional + IDs of devices for running the QPC. E.g., `[0]` for a single device or + `[0, 1, 2, 3]` for tensor slicing. Defaults to `[0]` if not specified. 
+ runtime_ai100 : bool, optional + If True, uses the AI 100 runtime. PyTorch runtime is not supported for this model. + Default is True. + generation_len : int, optional + The maximum number of tokens to generate. If None, it's inferred from `ctx_len`. + + Returns + ------- + CloudAI100ExecInfoNew or np.ndarray + Output from the AI 100 runtime, including generated IDs and performance metrics. + + Raises + ------ + NotImplementedError + If `runtime_ai100` is False. + """ + if not runtime_ai100: + raise NotImplementedError("PyTorch execution is not supported yet for this model!") + + write_io = kwargs.pop("write_io", False) + self._write_io_dir = os.path.join(os.path.dirname(self.onnx_path[1]), "io_dir") if write_io else None + + # Use VisionLanguageGeneration for image-prompt pairs + if (processor and images) or (tokenizer and prompts): + # Create VisionLanguageGeneration instance + batch_size_comp, ctx_len_comp, fbs = get_compilation_dims(self.lang_model.qpc_path) + vlm_gen = VisionLanguageGeneration( + qeff_model=self, + lang_qpc_path=self.lang_model.qpc_path, + vision_qpc_path=self.vision_model.qpc_path, + tokenizer=tokenizer, + processor=processor, + device_id=device_ids, # if device_ids is not None else [0], + ctx_len=ctx_len_comp, + full_batch_size=fbs, + comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=self.comp_ctx_lengths_decode, + image_height=image_height, + image_width=image_width, + write_io_dir=self._write_io_dir, + **kwargs, + ) + + # Call generate method + return vlm_gen.generate( + images=images, + prompts=prompts, + generation_len=generation_len, + stream=streamer is not None, + ) + + # Fallback to kv_offload_generate for direct inputs (backward compatibility) + return self.kv_offload_generate( + inputs=inputs, device_ids=device_ids, streamer=streamer, generation_len=generation_len + ) + + def kv_offload_generate( + self, + inputs: List[str] = None, + streamer: Optional[TextStreamer] = None, + device_ids: List[int] 
= None, + generation_len: int = None, + ): + """ + Performs generation for multimodal models with KV offloading to CPU. + + This method orchestrates the inference by running the vision encoder (if compiled) + and then iteratively running the language decoder, managing KV cache states. + + Parameters + ---------- + inputs : Dict[str, Union[torch.Tensor, np.ndarray]] + Input tensors for the multimodal model. + streamer : TextStreamer, optional + A streamer object to display generated tokens in real-time. Default is None. + device_ids : List[int], optional + IDs of devices for running the QPC. Defaults to `[0]` if not specified. + generation_len : int, optional + The maximum number of tokens to generate. If None, it's inferred from `ctx_len`. + + Returns + ------- + CloudAI100ExecInfoNew + Execution information including generated IDs and performance metrics. + + Raises + ------ + TypeError + If the language model QPC is not compiled. + AssertionError + If `generation_len` is not greater than zero. 
+ """ + if not self.lang_model.qpc_path: + raise TypeError("Please run compile API for language model first!") + + lang_session = QAICInferenceSession(self.lang_model.qpc_path, device_ids, activate=False) + + if self.vision_model.qpc_path: + vision_session = QAICInferenceSession(self.vision_model.qpc_path, device_ids) + + batch_size, ctx_len, fbs = get_compilation_dims(self.lang_model.qpc_path) + + pad_token_id = 1 + + # Skip inputs/outputs + lang_session.skip_buffers( + [ + x + for x in lang_session.input_names + lang_session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] + ) + + # Read prompt and ctx len from session + batch_size = max( + [x[lang_session.binding_index_map["input_ids"]][1][0] for x in lang_session.allowed_shapes] + + [lang_session.bindings[lang_session.binding_index_map["input_ids"]].dims[0]] + ) + + prefill_seq_len = max( + [x[lang_session.binding_index_map["input_ids"]][1][1] for x in lang_session.allowed_shapes] + + [lang_session.bindings[lang_session.binding_index_map["input_ids"]].dims[1]] + ) + input_len = inputs["attention_mask"].sum(1, keepdims=True) + input_ids_length = inputs["input_ids"].shape[1] + num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float + padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len + + if generation_len is None: + generation_len = ctx_len - input_len.max() + assert generation_len > 0, "generation length should be greater than zero" + generated_ids = np.full((batch_size, generation_len + 1), pad_token_id) + + inputs["input_ids"] = torch.nn.functional.pad( + inputs["input_ids"], + (0, padded_len - input_ids_length), + "constant", + pad_token_id, + ) + inputs["attention_mask"] = torch.nn.functional.pad( + inputs["attention_mask"], (0, padded_len - input_ids_length), "constant", 0 + ) + if "cross_attention_mask" in inputs: + inputs["cross_attention_mask"] = torch.nn.functional.pad( + inputs["cross_attention_mask"], (0, 0, 0, 0, 0, 
padded_len - input_ids_length) + ) + + for k, v in inputs.items(): + inputs[k] = np.array(v) + + vision_inputs = { + k: v + for k, v in inputs.items() + if k + in {"pixel_values", "image_masks", "image_input_idx", "valid_idx", "aspect_ratio_ids", "aspect_ratio_mask"} + } + + vision_inputs_fp16 = {"pixel_values", "image_masks"} + vision_inputs.update({k: vision_inputs[k].astype("float16") for k in vision_inputs_fp16 if k in vision_inputs}) + + vision_start = perf_counter() + + vision_outputs = {} + if vision_inputs: + vision_outputs = vision_session.run(vision_inputs) + vision_end = perf_counter() + + lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} + if "position_ids" in inputs: + lang_inputs["position_ids"] = inputs["position_ids"] + lang_inputs.pop("attention_mask") + else: + lang_inputs["position_ids"] = np.where( + lang_inputs.pop("attention_mask"), np.arange(padded_len), -1 + ) # Need to use -1 as position_ids for invalid tokens + + not_mllama = hasattr(self.model.config, "model_type") and self.model.config.model_type != "mllama" + if not_mllama: + lang_inputs["image_idx"] = np.array([[0]]) + if self.vision_model.qpc_path: + vision_session.deactivate() + lang_session.activate() + + lang_session.set_buffers(vision_outputs) + + if self.comp_ctx_lengths_prefill is not None: + list_of_comp_ctx_lengths_prefill = [ + np.zeros(length, dtype=np.int8) for length in self.comp_ctx_lengths_prefill + ] + prefill_ccl_id = 0 + lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + + lang_start = perf_counter() + # Run prefill + chunk_inputs = lang_inputs.copy() + for i in range(num_chunks): + if ( + self.comp_ctx_lengths_prefill is not None + and (i + 1) * prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id] + ): + prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1) + chunk_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + + chunk_inputs["input_ids"] = 
lang_inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] + chunk_inputs["position_ids"] = lang_inputs["position_ids"][ + ..., i * prefill_seq_len : (i + 1) * prefill_seq_len + ] + outputs = lang_session.run(chunk_inputs) + chunk_inputs["image_idx"] = outputs["image_idx_output"] + + if self._write_io_dir is not None: + write_io_files(lang_inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False) + + prefill_time = perf_counter() - lang_start + vision_end - vision_start + # Skip inputs/outputs again + lang_session.skip_buffers( + [ + x + for x in lang_session.input_names + lang_session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] + ) + if not_mllama: + lang_session.skip_buffers(vision_outputs.keys()) + # Get first token + lang_inputs["input_ids"] = outputs["logits"].argmax(2) + lang_inputs["position_ids"] = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1 + if "cross_attention_mask" in lang_inputs: + bs, _, num_images, img_tiles = lang_inputs["cross_attention_mask"].shape + lang_inputs["cross_attention_mask"] = torch.ones((bs, 1, num_images, img_tiles), dtype=torch.int64).numpy() + generated_ids[:, 0] = lang_inputs["input_ids"].squeeze(1) + + if streamer: + streamer.put(lang_inputs["input_ids"][0]) + + # Decode loop + if self.comp_ctx_lengths_decode is not None: + max_ccl_id = len(self.comp_ctx_lengths_decode) - 1 + list_of_comp_ctx_lengths_decode = [ + np.zeros(length, dtype=np.int8) for length in self.comp_ctx_lengths_decode + ] + max_position_id = np.max(lang_inputs["position_ids"]) + ccl_id_initial = 0 + ccl_id = ccl_id_initial + for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)): + if max_position_id < self.comp_ctx_lengths_decode[i]: + ccl_id = i + break + lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id] + + decode_start = perf_counter() + for num_token in range(1, generation_len): + if self.comp_ctx_lengths_decode is not None: + if 
max_position_id >= self.comp_ctx_lengths_decode[ccl_id] - 1: + ccl_id = min(ccl_id + 1, max_ccl_id) + lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id] + + outputs = lang_session.run(lang_inputs) + if self._write_io_dir is not None: + write_io_files(lang_inputs, outputs, self._write_io_dir, "decode", "aic_batch_io", True, False) + self._write_io_dir = None + + # Prepare inputs for next iteration + lang_inputs["input_ids"] = outputs["logits"].argmax(2) + lang_inputs["position_ids"] += 1 + generated_ids[:, num_token] = lang_inputs["input_ids"].squeeze(1) + if streamer: + streamer.put(lang_inputs["input_ids"][0]) + + decode_end = perf_counter() + if streamer: + streamer.end() + + decode_perf = (num_token - 1) / (decode_end - decode_start) + total_time = decode_end - decode_start + prefill_time + total_perf = num_token / total_time + + return CloudAI100ExecInfoNew( + batch_size=batch_size, + generated_ids=generated_ids, + perf_metrics=PerfMetrics( + prefill_time=prefill_time, decode_perf=decode_perf, total_perf=total_perf, total_time=total_time + ), + ) + + +class _QEFFAutoModelForImageTextToTextSingleQPC(QEFFTransformersBase, MultimodalUtilityMixin): + """ + Internal class handling multimodal image-text-to-text models using a single QPC approach. + + In this approach, the entire multimodal model (vision encoder + language model decoder) + is compiled into a single QPC package. + """ + + _hf_auto_class = AutoModelForImageTextToText + _pytorch_transforms = [ + AwqToMatmulNbitsTransform, + GPTQToMatmulNbitsTransform, + CustomOpsTransform, + KVCacheTransform, + KVCacheExternalModuleMapperTransform, + VlmNoKVOffloadTransform, + SplitGateUpWeightsTransform, + ] + _onnx_transforms = [] + + def __init__( + self, + model: nn.Module, + qaic_config: Optional[dict] = None, + **kwargs, + ): + """ + Initializes the single QPC multimodal model wrapper. + + Parameters + ---------- + model : nn.Module + The full HuggingFace multimodal model. 
+ qaic_config : dict, optional + A dictionary for QAIC-specific configurations. Supported keys include: + - **num_kv_blocks** (int): Number of K/V blocks for BlockedKV attention implementation. + **kwargs : + Additional keyword arguments. `full_batch_size` is not supported here. + + Raises + ------ + NotImplementedError + If `full_batch_size` is provided or `include_sampler` is True. + """ + if kwargs.pop("full_batch_size", None): + warnings.warn( + "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2 + ) + raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") + if qaic_config is not None and qaic_config.pop("include_sampler", False): + raise NotImplementedError("On-device sampling is not supported for single QPC multimodal models yet.") + + super().__init__(model, **kwargs) + + self.model.qaic_config = qaic_config + + # to handle internvl models + if hasattr(self.model.config, "llm_config") and hasattr(self.model.config, "vision_config"): + self.model.config.llm_config.use_cache = True + self.model.config.llm_config._attn_implementation = "eager" + self.model.config.vision_config.use_flash_attn = "false" + else: + if hasattr(self.model.config, "text_config"): + self.model.config.text_config.use_cache = True + else: + self.model.config.use_cache = True + self.hash_params["qeff_auto_class"] = self.__class__.__name__ + self.ccl_enabled = False + if qaic_config: + self.ccl_enabled = qaic_config.get("ccl_enabled", False) + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path, + qaic_config: Optional[dict] = None, + *args, + **kwargs, + ): + """ + Load a QEfficient multimodal model for single QPC from a pretrained HuggingFace model or local path. 
+ + Parameters + ---------- + pretrained_model_name_or_path : str + Model card name from HuggingFace or local path to model directory. + *args : + Positional arguments passed directly to `cls._hf_auto_class.from_pretrained`. + **kwargs : + Additional keyword arguments passed directly to `cls._hf_auto_class.from_pretrained`. + Note: `attn_implementation` and `low_cpu_mem_usage` are automatically + set to "eager" and False respectively to ensure compatibility. + Also, `_attn_implementation` and `use_flash_attn` are configured for VLM models. + + Returns + ------- + _QEFFAutoModelForImageTextToTextSingleQPC + An instance initialized with the pretrained weights. + """ + enable_proxy = kwargs.pop("enable_proxy", False) + + if kwargs.get("attn_implementation", None) not in {None, "eager"}: + logger.warning('Updating attn_implementation="eager"') + + if kwargs.get("low_cpu_mem_usage", None): + logger.warning("Updating low_cpu_mem_usage=False") + + kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + + from transformers import AutoConfig + + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True) + config._attn_implementation = "eager" + config.vision_config.use_flash_attn = "false" + model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, config, *args, **kwargs) + + kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {}) + + return cls( + model, + pretrained_model_name_or_path=pretrained_model_name_or_path, + qaic_config=qaic_config, + **kwargs, + ) + + def export( + self, + export_dir: Optional[str] = None, + use_onnx_subfunctions: bool = False, + **kwargs, + ) -> str: + """ + Exports the entire multimodal model to ONNX format. + + Parameters + ---------- + export_dir : str, optional + Directory path where the exported ONNX graph will be saved. Default is None. + **kwargs : + Additional keyword arguments. + + Returns + ------- + str + Path to the generated ONNX graph file. 
+ """ + inputs = self.model.get_dummy_inputs(comp_ctx_lengths=self.comp_ctx_lengths_decode) + dynamic_axes = self.model.get_onnx_dynamic_axes(comp_ctx_lengths=self.comp_ctx_lengths_decode) + output_names = self.model.get_output_names() + return self._export( + inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + export_dir=export_dir, + use_onnx_subfunctions=use_onnx_subfunctions, + ) + + def compile( + self, + onnx_path: Optional[str] = None, + img_size: Optional[int] = None, + compile_dir: Optional[str] = None, + *, + prefill_seq_len: Optional[int] = None, + ctx_len: Optional[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, + batch_size: int = 1, + full_batch_size: Optional[int] = None, + kv_cache_batch_size: Optional[int] = None, + num_devices: int = 1, + num_cores: int = 16, # FIXME: Make this mandatory arg + mxfp6_matmul: bool = False, + mxint8_kv_cache: bool = False, + num_speculative_tokens: Optional[int] = None, + use_onnx_subfunctions: bool = False, + qaic_config: Optional[dict] = None, + **compiler_options, + ) -> str: + """ + Compiles the exported ONNX model (single QPC) using the Cloud AI 100 Platform SDK compiler. + + This method generates a single ``qpc`` package for the entire multimodal model. + + Parameters + ---------- + onnx_path : str, optional + Path to a pre-exported ONNX model. If not provided, the model will be exported first. + img_size : int, optional + The image size to compile the vision part of the model for. Default is None. + compile_dir : str, optional + Directory to save the generated QPC package. + prefill_seq_len : int, optional + Length of the prefill prompt. Default is None. + ctx_len : int, optional + Maximum context length the compiled model can remember. Default is None. + batch_size : int, optional + Batch size. Default is 1. + full_batch_size : int, optional + Not supported for this model; must be None. 
+ kv_cache_batch_size : int, optional + Not supported for this model; must be None. + num_devices : int, optional + Number of devices to compile for. Default is 1. + num_cores : int, optional + Number of cores to use for compilation. + mxfp6_matmul : bool, optional + Use MXFP6 compression for weights. Default is False. + mxint8_kv_cache : bool, optional + Use MXINT8 compression for KV cache. Default is False. + num_speculative_tokens : int, optional + Not supported for this model; must be None. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False + **compiler_options : dict + Additional compiler options for QAIC or QNN compilers. + + Returns + ------- + str + Path to the compiled QPC package. + + Raises + ------ + ValueError + If `full_batch_size`, `kv_cache_batch_size`, or `num_speculative_tokens` are not None. + """ + if any(param is not None for param in [full_batch_size, kv_cache_batch_size, num_speculative_tokens]): + raise ValueError( + f"Expected 'full_batch_size', 'kv_cache_batch_size', 'num_speculative_tokens' to be None but got: " + f"full_batch_size={full_batch_size}, kv_cache_batch_size={kv_cache_batch_size}, num_speculative_tokens={num_speculative_tokens}, " + ) + + # Infer kv_cache_batch_size if not provided + kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size + output_names = self.model.get_output_names() + + # if ccl_enabled is True read Compute-Context-Length lists + if self.ccl_enabled: + if comp_ctx_lengths_prefill is None and comp_ctx_lengths_decode is None: + logger.info("Auto-generating CCL-prefill and CCL-decode lists based on Context Length (CL).") + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations( + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len + ) + # For supporting VLLM and 
Disaggregated with CCL + elif comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None: + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations( + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len + ) + + # Apply compile-dependent transforms like blocking transform + self.transform( + ctx_len=ctx_len, + seq_len=prefill_seq_len, + batch_size=batch_size, + num_devices=num_devices, + qaic_config=qaic_config, + aic_num_cores=num_cores, + ) + + # Get specializations from modelling file + # TODO: expose this via the auto class as well + specializations, compiler_options = self.model.get_specializations( + batch_size=batch_size, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=self.comp_ctx_lengths_decode, + kv_cache_batch_size=kv_cache_batch_size, + img_size=img_size, + **compiler_options, + ) + + if hasattr(self.model, "get_npi_file") and "node_precision_info" not in compiler_options: + compiler_options["node_precision_info"] = self.model.get_npi_file(self.model.name_or_path) + + custom_io = {} + target_dtype = getattr(self.model.config, "torch_dtype", torch.float32) + kv_cache_dtype = "mxint8" if mxint8_kv_cache else CUSTOM_IO_DTYPE_MAP[target_dtype] + # inputs + for input_name in output_names: + if input_name.endswith("_RetainedState"): + custom_io[input_name[: -len("_RetainedState")]] = ( + CUSTOM_IO_DTYPE_MAP[target_dtype] if "pixel_values" in input_name else kv_cache_dtype + ) + + # outputs + for output_name in output_names: + if output_name.endswith("_RetainedState"): + custom_io[output_name] = ( + CUSTOM_IO_DTYPE_MAP[target_dtype] if "pixel_values" in output_name else kv_cache_dtype + ) + + # TODO: this should be removed once continuous batching is supported for all the models. 
+ compiler_options.pop("continuous_batching", None) + compiler_options.pop("kv_cache_batch_size", None) + compiler_options.pop("full_batch_size", None) + self._compile( + onnx_path=onnx_path, + compile_dir=compile_dir, + retained_state=True, + specializations=specializations, + convert_to_fp16=(CUSTOM_IO_DTYPE_MAP[target_dtype] == "float16"), + mxfp6_matmul=mxfp6_matmul, + custom_io=custom_io, + mdp_ts_num_devices=num_devices, + aic_num_cores=num_cores, + mxint8_kv_cache=mxint8_kv_cache, + use_onnx_subfunctions=use_onnx_subfunctions, + **compiler_options, + ) + return self.qpc_path + + def get_onnx_dynamic_axes(self): + """ + Retrieves the dynamic axes configuration for ONNX export for this model. + + Returns + ------- + Dict[str, Dict[int, str]] + A dictionary specifying the dynamic axes for inputs. + """ + return self.model.get_onnx_dynamic_axes() + + def generate( + self, + inputs: torch.Tensor, + streamer: Optional[TextStreamer] = None, + device_ids: List[int] = None, + runtime_ai100: bool = True, + generation_len: Optional[int] = None, + write_io: bool = False, + ) -> Union[torch.Tensor, np.ndarray]: + """ + Generates output by executing the compiled single QPC on Cloud AI 100 Hardware cards. + + Parameters + ---------- + inputs : Dict[str, Union[torch.Tensor, np.ndarray]] + Inputs to run the execution, typically includes `pixel_values`, `input_ids`, + `attention_mask`, etc. + streamer : TextStreamer, optional + A streamer object to display generated tokens in real-time. Default is None. + device_ids : List[int], optional + IDs of devices for running the QPC. E.g., `[0]` for a single device or + `[0, 1, 2, 3]` for tensor slicing. Defaults to `[0]` if not specified. + runtime_ai100 : bool, optional + If True, uses the AI 100 runtime. PyTorch runtime is not supported for this model. + Default is True. + generation_len : int, optional + The maximum number of tokens to generate. If None, it's inferred from `ctx_len`. 
+ + Returns + ------- + CloudAI100ExecInfoNew or np.ndarray + Output from the AI 100 runtime, including generated IDs and performance metrics. + + Raises + ------ + NotImplementedError + If `runtime_ai100` is False. + """ + if not runtime_ai100: + raise NotImplementedError("PyTorch execution is not supported yet for this model!") + + self._write_io_dir = os.path.join(os.path.dirname(self.onnx_path), "io_dir") if write_io else None + + return self.cloud_ai_100_generate( + inputs=inputs, device_ids=device_ids, generation_len=generation_len, streamer=streamer + ) + + def cloud_ai_100_generate( + self, + inputs: torch.Tensor, + device_ids: List[int], + enable_debug_logs: bool = False, + generation_len: int = None, + streamer: Optional[TextStreamer] = None, + ) -> np.ndarray: + """ + Performs generation for multimodal models using a single QPC on Cloud AI 100 hardware. + + Parameters + ---------- + inputs : Dict[str, Union[torch.Tensor, np.ndarray]] + Input tensors for the multimodal model. + device_ids : List[int] + IDs of devices for running the QPC. + enable_debug_logs : bool, optional + If True, enables debug logging for the QAIC inference session. Default is False. + generation_len : int, optional + The maximum number of tokens to generate. If None, it's inferred from `ctx_len`. + streamer : TextStreamer, optional + A streamer object to display generated tokens in real-time. Default is None. + + Returns + ------- + CloudAI100ExecInfoNew + Execution information including generated IDs and performance metrics. + + Raises + ------ + AssertionError + If `generation_len` is not greater than zero. 
+ """ + inputs = self.auto_correct_inputs(inputs) + qpc_session = QAICInferenceSession( + self.qpc_path, device_ids, enable_debug_logs=enable_debug_logs, activate=False + ) + batch_size, ctx_len, fbs = get_compilation_dims(self.qpc_path) + pad_token_id = 1 + # Skip inputs/outputs + qpc_session.skip_buffers( + [ + x + for x in qpc_session.input_names + qpc_session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] + ) + + # Read prompt and ctx len from session + batch_size = max( + [x[qpc_session.binding_index_map["input_ids"]][1][0] for x in qpc_session.allowed_shapes] + + [qpc_session.bindings[qpc_session.binding_index_map["input_ids"]].dims[0]] + ) + + prefill_seq_len = max( + [x[qpc_session.binding_index_map["input_ids"]][1][1] for x in qpc_session.allowed_shapes] + + [qpc_session.bindings[qpc_session.binding_index_map["input_ids"]].dims[1]] + ) + + input_len = inputs["attention_mask"].sum(1, keepdims=True) + input_ids_length = inputs["input_ids"].shape[1] + + num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float + + padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len + if generation_len is None: + generation_len = ctx_len - input_len.max() + + assert generation_len > 0, "generation length should be greater than zero" + generated_ids = np.full((batch_size, generation_len + 1), pad_token_id) + + # Prepare inputs for prefill + inputs["input_ids"] = torch.nn.functional.pad( + inputs["input_ids"], + (0, padded_len - input_ids_length), + "constant", + pad_token_id, + ) + inputs["attention_mask"] = torch.nn.functional.pad( + inputs["attention_mask"], (0, padded_len - input_ids_length), "constant", 0 + ) + if "cross_attention_mask" in inputs: + inputs["cross_attention_mask"] = torch.nn.functional.pad( + inputs["cross_attention_mask"], (0, 0, 0, 0, 0, padded_len - input_ids_length) + ) + for k, v in inputs.items(): + inputs[k] = np.array(v) + + if "pixel_values_RetainedState" in 
qpc_session.output_names: + inputs["pixel_values"] = inputs["pixel_values"].astype("float16") + + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) + inputs["image_idx"] = np.array([[0]]) + + if self.comp_ctx_lengths_prefill is not None: + list_of_comp_ctx_lengths_prefill = [ + np.zeros(length, dtype=np.int8) for length in self.comp_ctx_lengths_prefill + ] + prefill_ccl_id = 0 + inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + + qpc_session.activate() + chunk_inputs = inputs.copy() + prefill_start = perf_counter() + + # Run prefill + for i in range(num_chunks): + if ( + self.comp_ctx_lengths_prefill is not None + and (i + 1) * prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id] + ): + prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1) + chunk_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + + chunk_inputs["input_ids"] = inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] + chunk_inputs["position_ids"] = inputs["position_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] + outputs = qpc_session.run(chunk_inputs) + + if self._write_io_dir is not None: + write_io_files(chunk_inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False) + + chunk_inputs["image_idx"] = outputs["image_idx_output"] + + prefill_time = perf_counter() - prefill_start + # Get first token + inputs["input_ids"] = outputs["logits"].argmax(2) + inputs["position_ids"] = input_len.numpy() + + if "cross_attention_mask" in inputs: + bs, _, num_images, img_tiles = inputs["cross_attention_mask"].shape + inputs["cross_attention_mask"] = torch.ones((bs, 1, num_images, img_tiles), dtype=torch.int64).numpy() + + generated_ids[:, 0] = inputs["input_ids"].squeeze(1) + if streamer: + streamer.put(inputs["input_ids"][0]) + + if "pixel_values_RetainedState" in qpc_session.output_names: + 
qpc_session.skip_buffers(["pixel_values"]) + inputs.pop("pixel_values") + + # Decode loop + if self.comp_ctx_lengths_decode is not None: + list_of_comp_ctx_lengths_decode = [ + np.zeros(length, dtype=np.int8) for length in self.comp_ctx_lengths_decode + ] + max_ccl_id = len(self.comp_ctx_lengths_decode) - 1 + max_position_id = np.max(inputs["position_ids"]) + ccl_id_initial = 0 + ccl_id = ccl_id_initial + for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)): + if max_position_id < self.comp_ctx_lengths_decode[i]: + ccl_id = i + break + inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id] + + decode_start = perf_counter() + for num_token in range(1, generation_len): + if self.comp_ctx_lengths_decode is not None: + if max_position_id >= self.comp_ctx_lengths_decode[ccl_id] - 1: + ccl_id = min(ccl_id + 1, max_ccl_id) + inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id] + + outputs = qpc_session.run(inputs) + if self._write_io_dir is not None: + write_io_files(inputs, outputs, self._write_io_dir, "decode", "aic_batch_io", True, False) + self._write_io_dir = None + + # Prepare inputs for next iteration + inputs["input_ids"] = outputs["logits"].argmax(2) + inputs["position_ids"] += 1 + generated_ids[:, num_token] = inputs["input_ids"].squeeze(1) + if streamer: + streamer.put(inputs["input_ids"][0]) + + decode_end = perf_counter() + if streamer: + streamer.end() + + decode_perf = (num_token - 1) / (decode_end - decode_start) + total_time = decode_end - prefill_start + total_perf = num_token / total_time + + return CloudAI100ExecInfoNew( + batch_size=batch_size, + generated_ids=generated_ids, + perf_metrics=PerfMetrics( + prefill_time=prefill_time, decode_perf=decode_perf, total_perf=total_perf, total_time=total_time + ), + ) + + @property + def get_model_config(self) -> dict: + """ + Get the configuration dictionary of the underlying HuggingFace model. + + Returns + ------- + dict + The configuration dictionary. 
+ """ + return self.model.config.__dict__ + + +class QEFFAutoModelForImageTextToText: + """ + QEfficient class for multimodal (image-text-to-text) models from the HuggingFace hub. + + This class supports both single and dual QPC (Quantized Package Compilation) approaches for efficient deployment on Cloud AI 100 hardware. + It is recommended to use the ``from_pretrained`` method for initialization. + + Example + ------- + .. code-block:: python + + import requests + from PIL import Image + from transformers import AutoProcessor, TextStreamer + from QEfficient import QEFFAutoModelForImageTextToText + + HF_TOKEN = "" # Your HuggingFace token if needed + model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct" + query = "Describe this image." + image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg" + + # STEP 1: Load processor and model + processor = AutoProcessor.from_pretrained(model_name, token=HF_TOKEN) + model = QEFFAutoModelForImageTextToText.from_pretrained( + model_name, token=HF_TOKEN, attn_implementation="eager", kv_offload=False # kv_offload=False for single QPC + ) + + # STEP 2: Export & Compile + model.compile( + prefill_seq_len=32, + ctx_len=512, + img_size=560, + num_cores=16, + num_devices=1, + mxfp6_matmul=False, + ) + + # STEP 3: Prepare inputs + image = Image.open(requests.get(image_url, stream=True).raw) + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": query}, + ], + } + ] + input_text = [processor.apply_chat_template(messages, add_generation_prompt=True)] + inputs = processor( + text=input_text, + images=image, + return_tensors="pt", + add_special_tokens=False, + padding="max_length", # Consider padding strategy if max_length is crucial + max_length=32, + ) + + # STEP 4: Run inference + streamer = TextStreamer(processor.tokenizer) + model.generate(inputs=inputs, streamer=streamer, generation_len=512) + """ 
+ + _hf_auto_class = AutoModelForImageTextToText + + def __new__( + self, + model: nn.Module, + kv_offload: Optional[bool] = True, + continuous_batching: bool = False, + qaic_config: Optional[dict] = None, + **kwargs, + ): + """ + Instantiate the appropriate internal class for single or dual QPC mode. + + Parameters + ---------- + model : nn.Module + The loaded HuggingFace multimodal model. + kv_offload : bool, optional + If True, uses the dual QPC approach (vision encoder KV offloaded). + If False, uses the single QPC approach (entire model in one QPC). + Default is True. + **kwargs : + Additional keyword arguments passed to the constructor of the selected internal class. + + Returns + ------- + Union[_QEffAutoModelForImageTextToTextDualQPC, _QEFFAutoModelForImageTextToTextSingleQPC] + The wrapped model instance, configured for either dual or single QPC. + """ + if kv_offload: + return _QEffAutoModelForImageTextToTextDualQPC( + model, continuous_batching, qaic_config=qaic_config, **kwargs + ) + else: + return _QEFFAutoModelForImageTextToTextSingleQPC(model, qaic_config=qaic_config, **kwargs) + + @classmethod + @with_replaced_quantizers + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + kv_offload: Optional[bool] = None, + continuous_batching: bool = False, + qaic_config: Optional[dict] = None, + **kwargs, + ): + """ + Load a QEfficient image-text-to-text model from a pretrained HuggingFace model or local path. + + Parameters + ---------- + pretrained_model_name_or_path : str + Model card name from HuggingFace or local path to model directory. + kv_offload : bool, optional + If True, uses the dual QPC approach (vision encoder KV offloaded). + If False, uses the single QPC approach (entire model in one QPC). + If None, the default behavior of the internal classes is used (typically dual QPC). + qaic_config : dict, optional + A dictionary for QAIC-specific configurations. 
+ **kwargs : + Additional arguments passed to HuggingFace's ``from_pretrained``. + + **Note:** `attn_implementation` and `low_cpu_mem_usage` are automatically set to "eager" and False respectively to ensure compatibility. + `continuous_batching` is not supported for image-text-to-text models. + + Returns + ------- + QEFFAutoModelForImageTextToText + An instance initialized with the pretrained weights, wrapped for QEfficient. + + Raises + ------ + NotImplementedError + If `continuous_batching` is provided as True. + """ + enable_proxy = kwargs.pop("enable_proxy", False) + + # TODO: add a check to see if kv_offload is allowed for given model by loading the config and checking architecture or type of config here. + if continuous_batching and not kv_offload: + NotImplementedError("Continuous batching is not supported for kv_offload = False") + + if kwargs.get("attn_implementation", None) not in {None, "eager"}: + logger.warning('Updating attn_implementation="eager"') + + if kwargs.get("low_cpu_mem_usage", None): + logger.warning("Updating low_cpu_mem_usage=False") + + kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) + + kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {}) + + return cls( + model, + kv_offload=kv_offload, + continuous_batching=continuous_batching, + pretrained_model_name_or_path=pretrained_model_name_or_path, + qaic_config=qaic_config, + **kwargs, + ) + + +MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = { + "InternVLChatModel": QEFFAutoModelForImageTextToText, + "MolmoForCausalLM": QEFFAutoModelForImageTextToText, +} + + +class QEFFAutoModelForCausalLM(QEFFBaseModel): + """ + QEfficient class for Causal Language Models from the HuggingFace hub (e.g., GPT-2, Llama). + + This class provides a unified interface for loading, exporting, compiling, and generating + text with causal language models on Cloud AI 100 hardware. 
It supports features like + continuous batching, speculative decoding (TLM), and on-device sampling. + + Example + ------- + .. code-block:: python + + from QEfficient import QEFFAutoModelForCausalLM + from transformers import AutoTokenizer + + model = QEFFAutoModelForCausalLM.from_pretrained("gpt2") + model.compile(num_cores=16) + tokenizer = AutoTokenizer.from_pretrained("gpt2") + model.generate(prompts=["Hi there!!"], tokenizer=tokenizer) + """ + + _hf_auto_class = AutoModelForCausalLM + _pytorch_transforms = [ + AwqToMatmulNbitsTransform, + GPTQToMatmulNbitsTransform, + FP8DeQuantLinearToLinearTransform, + Mxfp4GptOssExpertDequantizeTransform, + CustomOpsTransform, + KVCacheTransform, + SplitGateUpWeightsTransform, + KVCacheExternalModuleMapperTransform, + ] + + _onnx_transforms = [] + + def prefill( + self, + enable: Optional[bool] = True, + enable_chunking: Optional[bool] = False, + retain_full_kv: Optional[bool] = False, + ): + if enable: + self.model, tf = PrefillOnlyExternalModuleMapperTransform.apply(self.model) + if enable_chunking: + self.model, tf = PrefillOnlyChunkedTransform.apply(self.model) + else: + self.model, tf = PrefillOnlyTransform.apply(self.model) + + else: + self.model, tf = RevertPrefillOnlyExternalModuleMapperTransform.apply(self.model) + if retain_full_kv: + self.model, tf = RevertPrefillKeepAttentionTransform.apply(self.model) + else: + self.model, tf = RevertPrefillOnlyTransform.apply(self.model) + + def __update_prefill_transform( + self, + enable: Optional[bool] = True, + enable_chunking: Optional[bool] = False, + retain_full_kv: Optional[bool] = False, + ): + if enable: + self.model, tf = PrefillOnlyExternalModuleMapperTransform.apply(self.model) + if enable_chunking: + self.model, tf = PrefillOnlyChunkedTransform.apply(self.model) + else: + self.model, tf = PrefillOnlyTransform.apply(self.model) + + else: + self.model, tf = RevertPrefillOnlyExternalModuleMapperTransform.apply(self.model) + if retain_full_kv: + self.model, tf = 
def __init__(
    self,
    model: nn.Module,
    continuous_batching: bool = False,
    qaic_config: Optional[dict] = None,
    max_seq_len_cached: Optional[int] = None,
    **kwargs,
):
    """
    Wrap a HuggingFace causal language model for export/compile on Cloud AI 100.

    Parameters
    ----------
    model : nn.Module
        The underlying HuggingFace PyTorch Causal Language Model. Its class name must end
        with ``ForCausalLM`` or ``LMHeadModel``.
    continuous_batching : bool, optional
        If True, enables continuous batching mode for future compilation and execution.
        This setting must be consistent across `from_pretrained` and `compile` calls. Default is False.
    qaic_config : dict, optional
        A dictionary for QAIC-specific configurations. Supported keys include:
        - **speculative_model_type** (str): Specifies the type of Speculative Decoding model (e.g., "target").
        - **include_sampler** (bool): If True, enables on-device sampling of next tokens.
        - **return_pdfs** (bool): If True, returns probability distributions along with sampled tokens.
          For Speculative Decoding Target Language Models, this is always True.
        - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
        - **include_guided_decoding** (bool): If True, enables guided token-level filtering
          during decoding. Only works when include_sampler=True.
        - **num_kv_blocks** (int): Number of K/V blocks for BlockedKV attention implementation.
        This method additionally reads the ``ccl_enabled`` and ``mla_absorption`` keys.
    max_seq_len_cached : int, optional
        Stored on ``model.config.max_seq_len_cached`` and recorded in ``hash_params``.
    **kwargs :
        Additional keyword arguments passed to the base class constructor. ``enable_proxy``
        and the deprecated ``full_batch_size`` are consumed here before that call.

    Raises
    ------
    TypeError
        If the provided `model` is not a CausalLM or LMHeadModel type.
    """
    hf_class_name = model.__class__.__name__
    if not hf_class_name.endswith(("ForCausalLM", "LMHeadModel")):
        raise TypeError(f"Required pytorch module for CausalLM or LMHeadModel, got {hf_class_name}")
    _configure_proxy_for_model(self, kwargs.pop("enable_proxy", False))

    # TODO: remove from version 1.20
    if kwargs.pop("full_batch_size", None):
        continuous_batching = True
        warnings.warn(
            "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2
        )

    # A quantization config of an unrecognised type means the quantizers were not swapped in.
    if hasattr(model.config, "quantization_config"):
        recognised_quant_configs = tuple(QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING.values())
        if not isinstance(model.config.quantization_config, recognised_quant_configs):
            logger.warning(
                "Please use `from_pretrained` method to load quantized models, might give unexpected results"
            )

    # Set use_cache=True to get KV values as output during ONNX export
    model.config.use_cache = True
    model.config.max_seq_len_cached = max_seq_len_cached

    super().__init__(model, qaic_config=qaic_config, **kwargs)
    self.num_layers = model.config.num_hidden_layers
    self.continuous_batching = continuous_batching
    self.model.qaic_config = qaic_config
    self.model.pretrained_path = kwargs.pop("pretrained_model_name_or_path", None)
    # SpDTransform reports whether the model was turned into a Target Language Model.
    self.model, self.is_tlm = SpDTransform.apply(self.model, qaic_config, **kwargs)

    self.hash_params["qeff_auto_class"] = self.__class__.__name__
    self.ccl_enabled = qaic_config.get("ccl_enabled", False) if qaic_config else False
    mla_absorption = qaic_config.get("mla_absorption", None) if qaic_config else None
    if mla_absorption:
        self.hash_params["mla_absorption"] = mla_absorption
        self.model.mla_absorption = mla_absorption
    self.comp_ctx_lengths_prefill = None
    self.comp_ctx_lengths_decode = None
    self.hash_params["max_seq_len_cached"] = max_seq_len_cached

    # ---Sampling---
    # Note: SamplerTransform should be applied after all other transforms
    # are done. The role of the sampler is to just add nodes at the output of the
    # previous transform function.
    self.model, _ = SamplerTransform.apply(self.model, qaic_config, **kwargs)
    # TODO : Update in qaic_config isn't updated in the hash due to SpDTransforms. Need to move
    # SpDTransforms to PytorchTransforms.
    if self.is_tlm:
        self.model.qaic_config["return_pdfs"] = True
+ - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. + The values provided in ``top_ks`` tensor must be less than this maximum limit. + - **include_guided_decoding** (bool): If True, enables guided token-level filtering + during decoding. Only works when include_sampler=True. + + *args : + Positional arguments passed directly to `cls._hf_auto_class.from_pretrained`. + **kwargs : + Additional keyword arguments passed directly to `cls._hf_auto_class.from_pretrained`. + + **Note:** `attn_implementation` and `low_cpu_mem_usage` are automatically + set to "eager" and False respectively to ensure compatibility. + + Returns + ------- + QEFFAutoModelForCausalLM + An instance initialized with the pretrained weights. + """ + enable_proxy = kwargs.pop("enable_proxy", False) + if kwargs.pop("full_batch_size", None): + continuous_batching = True + warnings.warn( + "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2 + ) + + if kwargs.get("attn_implementation", None) not in {None, "eager"}: + logger.warning('Updating attn_implementation="eager"') + + if kwargs.get("low_cpu_mem_usage", None): + logger.warning("Updating low_cpu_mem_usage=False") + + kv_offload = kwargs.pop("kv_offload", None) + + kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + if qaic_config is not None: + qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path + + # This is support models that should be classified to in a different auto class but transformers load them via this class + kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {}) + if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP: + return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__]( + model, + kv_offload=kv_offload, + 
def get_seq_len_and_handle_specialized_prefill_model(
    self, prefill_seq_len: Optional[int] = None, enable_chunking=False
) -> int:
    """
    Choose the export sequence length for a prefill-only model and record the choice in ``hash_params``.

    Parameters
    ----------
    prefill_seq_len : int, optional
        Prefill length requested by the caller. Required (and must be a positive multiple of
        ``constants.GPT_OSS_PREFILL_Q_BLOCK_SIZE``) when no Q-block count is available from
        the blocking settings stored in ``hash_params``.
    enable_chunking : bool, optional
        If True, only the ``prefill_only``/``chunking`` hash entries are set and the default
        example sequence length is returned.

    Returns
    -------
    int
        The larger of ``constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN`` and the block-derived minimum.

    Raises
    ------
    ValueError
        If ``prefill_seq_len`` is missing/not block-aligned when it is needed, or if the
        chosen length is not divisible by both block counts.
    """
    self.hash_params["prefill_only"] = True
    if enable_chunking:
        self.hash_params["chunking"] = True
        return constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN

    q_block_size = constants.GPT_OSS_PREFILL_Q_BLOCK_SIZE

    # The guard used to test "blocking_kwargs" but index "blocking_config", which raises
    # KeyError whenever only the former is stored. Prefer "blocking_config" when present
    # (previous behaviour) and otherwise read the object the guard actually checked.
    num_q_blocks = None
    blocking_kwargs = self.hash_params.get("blocking_kwargs", None)
    if blocking_kwargs:
        blocking = self.hash_params.get("blocking_config", blocking_kwargs)
        num_q_blocks = getattr(blocking, "num_q_blocks", None)

    if num_q_blocks is None:
        if prefill_seq_len is None or prefill_seq_len % q_block_size != 0 or prefill_seq_len < q_block_size:
            raise ValueError(
                f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={q_block_size}. "
                f"Or set `NUM_Q_BLOCKS` ENV variable. "
                f"Received: prefill_seq_len={prefill_seq_len}"
            )

        num_q_blocks = prefill_seq_len // q_block_size
        logger.warning(
            f"Setting NUM_Q_BLOCKS={num_q_blocks} used in attention Q-blocking for prefill_only model, please pass `NUM_Q_BLOCKS` in qaic_config to override"
        )
    num_q_blocks = int(num_q_blocks)

    # An unset or empty NUM_FFN_BLOCKS is kept as-is (None / "") and treated as "no FFN blocking".
    num_ffn_blocks = os.environ.get("NUM_FFN_BLOCKS", None)
    num_ffn_blocks = int(num_ffn_blocks) if num_ffn_blocks else num_ffn_blocks
    min_seq_len = max(num_q_blocks, num_ffn_blocks) if num_ffn_blocks else num_q_blocks
    if (num_ffn_blocks and min_seq_len % num_ffn_blocks != 0) or min_seq_len % num_q_blocks != 0:
        raise ValueError(
            f"Got NUM_FFN_BLOCKS={num_ffn_blocks} and NUM_Q_BLOCKS={num_q_blocks}, tried to set seq_len={min_seq_len} for export but "
            "seq_len is not divisible by either num_ffn_blocks or num_q_blocks, try changing the values."
        )

    self.hash_params["NUM_Q_BLOCKS"] = num_q_blocks
    self.hash_params["NUM_FFN_BLOCKS"] = num_ffn_blocks
    self.hash_params["ENABLE_OPT_SWA"] = os.environ.get("ENABLE_OPT_SWA", "0")
    return max(min_seq_len, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
+ use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False + Returns + ------- + str + Path to the generated ONNX graph file. + """ + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + + # increase seq_len if using a larger number of blocks + if self.hash_params.get("blocking_kwargs", None): + max_blocks = -1 + for num_blocks in self.hash_params.get("blocking_kwargs").__dict__.values(): + if isinstance(num_blocks, int): + max_blocks = max(max_blocks, num_blocks) + block_size = -(-seq_len // max_blocks) + seq_len = block_size * max_blocks + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + + kv_cache_shape = get_padding_shape_from_config( + self.model.config, fbs if self.continuous_batching else bs, seq_len + ) + enable_chunking = kwargs.get("enable_chunking", False) + + # TODO: move this to a DA Serving utility class + if self.model.config.model_type in SPECIALIZED_DISAGG_SERVING_MODEL_ARCH: + if prefill_only: + if not enable_chunking and self.continuous_batching: + raise NotImplementedError( + "Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!" 
+ ) + self.__update_prefill_transform(enable=True, enable_chunking=enable_chunking) + self.hash_params.pop("retain_full_kv", None) + if "DeepseekV3ForCausalLM" not in (getattr(self.model.config, "architectures", None) or []): + seq_len = self.get_seq_len_and_handle_specialized_prefill_model( + prefill_seq_len=prefill_seq_len, enable_chunking=enable_chunking + ) + kv_cache_shape[2] = ( + seq_len + + (self.model.config.sliding_window if self.model.config.sliding_window is not None else 0) + if enable_chunking + else seq_len + ) + else: + self.__update_prefill_transform(False, retain_full_kv=kwargs.get("retain_full_kv", False)) + self.hash_params.pop("prefill_only", None) + self.hash_params.pop("NUM_Q_BLOCKS", None) + self.hash_params.pop("NUM_FFN_BLOCKS", None) + self.hash_params.pop("ENABLE_OPT_SWA", None) + self.hash_params.pop("chunking", None) + if kwargs.get("retain_full_kv", False): + kv_cache_shape[2] = seq_len + ( + self.model.config.sliding_window if self.model.config.sliding_window is not None else 0 + ) + self.hash_params["retain_full_kv"] = True + + example_inputs = { + "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), + "position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1), + "past_key_values": [[] for _ in range(self.num_layers)], + } + dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "position_ids": {0: "batch_size", 1: "seq_len"}, + } + if self.ccl_enabled: + example_inputs["comp_ctx_lengths"] = torch.randint(0, 127, (512,), dtype=torch.int8) + dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + + if len(kv_cache_shape) == 3: # For GPTBigCode arch the pkv is 3d + pkv_dynamic_axes = { + 0: "full_batch_size" if self.continuous_batching else "batch_size", + 1: "ctx_len", + } + else: # pkv is 4d + pkv_dynamic_axes = { + 0: "full_batch_size" if self.continuous_batching else "batch_size", + 2: "ctx_len", + } + output_names = [] + if self.model.qaic_config is not None and 
self.model.qaic_config.get("include_sampler", False): + if self.model.qaic_config.get("return_pdfs", False): + output_names.append("probs") + output_names.append("next_tokens") + else: + output_names.append("logits") + + # TODO Update the get_padding_shape_from_config method to handle the case when the model config has attention_chunk_size or sliding_window and it should return a list of shapes for each layer + if ( + hasattr(self.model.config, "model_type") + and self.model.config.model_type in DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH + ): + pkv_cache = self.model.get_dummy_pkv_cache( + self.model.config, fbs if self.continuous_batching else bs, seq_len + ) + for i in range(self.num_layers): + for kv in ["key", "value"]: + example_inputs["past_key_values"][i].append( + torch.zeros(pkv_cache[0][0].shape, dtype=self.model.config.torch_dtype) + ) + dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes + output_names.append(f"past_{kv}.{i}_RetainedState") + + else: + # HACK: create common function for this including above if condition code + pkv_dynamic_axes = ( + self.model.get_pkv_dynamic_axes( + retain_full_kv=kwargs.get("retain_full_kv", False) + or (prefill_only and kwargs.get("enable_chunking", False)), + continuous_batching=self.continuous_batching, + ) + if hasattr(self.model, "get_pkv_dynamic_axes") + else pkv_dynamic_axes + ) + pkv_dynamic_axes = ( + [pkv_dynamic_axes] * self.model.config.num_hidden_layers + if isinstance(pkv_dynamic_axes, dict) + else pkv_dynamic_axes + ) + + for i in range(self.num_layers): + for kv in ["key", "value"]: + example_inputs["past_key_values"][i].append( + torch.zeros(kv_cache_shape, dtype=self.model.config.torch_dtype) + ) + dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i] + output_names.append(f"past_{kv}.{i}_RetainedState") + + if "DeepseekV3ForCausalLM" in (getattr(self.model.config, "architectures", None) or []): + if self.model.qaic_config is not None and self.model.qaic_config.get("mla_absorption", None) is not None: + 
mla_absorption = self.model.qaic_config["mla_absorption"] + cache_compressed = mla_absorption.get("cache_compressed", False) + else: + cache_compressed = False + pkv_cache = self.model.get_dummy_pkv_cache( + self.model.config, fbs if self.continuous_batching else bs, seq_len + ) + if cache_compressed: + example_inputs = {k: v for k, v in example_inputs.items() if "past" not in k} + dynamic_axes = {k: v for k, v in dynamic_axes.items() if "past" not in k} + output_names = [v for v in output_names if "past" not in v] + example_inputs["compressed_kvs"] = [[] for _ in range(self.num_layers)] + for i in range(self.num_layers): + example_inputs["compressed_kvs"][i].append( + torch.zeros(pkv_cache[0][0].shape, dtype=self.model.config.torch_dtype) + ) + example_inputs["compressed_kvs"][i].append( + torch.zeros(pkv_cache[0][1].shape, dtype=self.model.config.torch_dtype) + ) + dynamic_axes[f"compressed_kv.{i}"] = {0: "batch_size", 2: "ctx_len"} + dynamic_axes[f"k_pe.{i}"] = {0: "batch_size", 2: "ctx_len"} + output_names.append(f"compressed_kv.{i}_RetainedState") + output_names.append(f"k_pe.{i}_RetainedState") + else: + example_inputs["past_key_values"] = [[] for _ in range(self.num_layers)] + for i in range(self.num_layers): + example_inputs["past_key_values"][i].append( + torch.zeros(pkv_cache[0][0].shape, dtype=self.model.config.torch_dtype) + ) + example_inputs["past_key_values"][i].append( + torch.zeros(pkv_cache[0][1].shape, dtype=self.model.config.torch_dtype) + ) + + if self.continuous_batching: + example_inputs["batch_index"] = torch.arange(bs).view(bs, 1) + dynamic_axes["batch_index"] = {0: "batch_size"} + + if self.is_tlm: + nlk = constants.ONNX_EXPORT_EXAMPLE_NLK # Number of Logits to Keep + example_inputs["num_logits_to_keep"] = torch.arange(nlk).view(nlk, 1) + dynamic_axes["num_logits_to_keep"] = {0: "num_logits_to_keep"} + + if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False): + example_inputs, output_names, 
def build_prefill_specialization(
    self,
    prefill_seq_len: int = 32,
    ctx_len: int = 128,
    comp_ctx_lengths: Optional[int] = None,
    batch_size: int = 1,
    kv_cache_batch_size: Optional[int] = None,
    full_batch_size: Optional[int] = None,
    **kwargs,
):
    """
    Builds a dictionary representing a compilation specialization for the prefill phase.

    Parameters
    ----------
    prefill_seq_len : int, optional
        Length of the prefill prompt. Default is 32.
    ctx_len : int, optional
        Maximum context length the compiled model can remember. Default is 128.
    comp_ctx_lengths : int, optional
        A single compute-context-length value; added to the specialization when not None.
    batch_size : int, optional
        Batch size for the prefill. Default is 1.
    kv_cache_batch_size : int, optional
        Batch size for KV cache. If not provided, it falls back to `full_batch_size` and then
        `batch_size` (the same order ``compile`` uses).
    full_batch_size : int, optional
        Continuous batching batch size. Used if `continuous_batching` is enabled. Default is None.
    **kwargs :
        Forwarded to ``self.model.get_specializations`` when the model defines it; otherwise unused.

    Returns
    -------
    Dict[str, Union[int, str]]
        A dictionary defining the prefill specialization. Entries whose value is None are
        omitted and ``"_graph_name"`` is set to ``"Prefill"``.
    """
    # Previously a missing kv_cache_batch_size was written as None and then filtered out,
    # silently removing the batch_size / full_batch_size entry from the specialization.
    kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size

    # With continuous batching, prefill executes one request at a time unless the
    # "prefill" graph is really a decode graph (prefill_seq_len == 1).
    if not self.continuous_batching:
        exec_batch_size = batch_size
    elif prefill_seq_len == 1:
        exec_batch_size = full_batch_size
    else:
        exec_batch_size = 1

    if hasattr(self.model, "get_specializations"):
        spec = self.model.get_specializations(
            batch_size=exec_batch_size,
            prefill_seq_len=prefill_seq_len,
            ctx_len=ctx_len,
            **kwargs,
        )[0]
    else:
        spec = {
            "batch_size": exec_batch_size,
            "seq_len": prefill_seq_len,
            "ctx_len": ctx_len,
        }
    if comp_ctx_lengths is not None:
        spec["comp_ctx_lengths"] = comp_ctx_lengths
    spec["num_logits_to_keep"] = 1 if self.is_tlm else None
    if self.continuous_batching:
        spec["full_batch_size"] = kv_cache_batch_size
    else:
        spec["batch_size"] = kv_cache_batch_size
    # TODO: remove this; not required
    if full_batch_size:
        spec["full_batch_exec_size"] = exec_batch_size

    result = {key: value for key, value in spec.items() if value is not None}
    result["_graph_name"] = "Prefill"
    return result
def compile(
    self,
    onnx_path: Optional[str] = None,
    compile_dir: Optional[str] = None,
    *,
    prefill_seq_len: int = 32,
    ctx_len: int = 128,
    comp_ctx_lengths_prefill: Optional[List[int]] = None,
    comp_ctx_lengths_decode: Optional[List[int]] = None,
    batch_size: int = 1,
    full_batch_size: Optional[int] = None,
    kv_cache_batch_size: Optional[int] = None,
    num_devices: int = 1,
    num_cores: int = 16,  # FIXME: Make this mandatory arg
    mxfp6_matmul: bool = False,
    mxint8_kv_cache: bool = False,
    num_speculative_tokens: Optional[int] = None,
    prefill_only: Optional[bool] = None,
    use_onnx_subfunctions: bool = False,
    offload_pt_weights: Optional[bool] = True,
    enable_chunking: Optional[bool] = False,
    retain_full_kv: Optional[bool] = None,
    **compiler_options,
) -> str:
    """
    Compile the exported ONNX model using the Cloud AI 100 Platform SDK compiler.

    This method generates a ``qpc`` package. If the model has not been exported yet,
    this method will handle the export process. Additional arguments for the `qaic-compile`
    compiler can be passed as keyword arguments.

    Parameters
    ----------
    onnx_path : str, optional
        Path to a pre-exported ONNX model. If not provided, the model will be exported first.
    compile_dir : str, optional
        Directory to save the generated QPC package. If not provided, a default directory is used.
    prefill_seq_len : int, optional
        Length of the prefill prompt. Default is 32.
    ctx_len : int, optional
        Maximum context length the compiled model can remember. Default is 128.
    comp_ctx_lengths_prefill, comp_ctx_lengths_decode : list of int or str, optional
        Compute-context-length lists for the prefill and decode graphs. Either may also be
        given as a list-like string, which is parsed with ``ast.literal_eval``. Passing either
        one turns ``self.ccl_enabled`` on; when ``ccl_enabled`` is already set, both are handed
        to ``process_ccl_specializations`` (which may also adjust ``ctx_len``).
    batch_size : int, optional
        Batch size. Default is 1.
    full_batch_size : int, optional
        Continuous batching batch size. Required if `continuous_batching=True` was
        set during `from_pretrained` and the compile is not prefill-only.
    kv_cache_batch_size : int, optional
        Batch size for KV cache. If not provided, it defaults to `full_batch_size` and then `batch_size`.
    num_devices : int, optional
        Number of devices to compile for. Default is 1.
    num_cores : int, optional
        Number of cores to use for compilation.
    mxfp6_matmul : bool, optional
        Use MXFP6 compression for weights. Default is False.
    mxint8_kv_cache : bool, optional
        Use MXINT8 compression for KV cache. Default is False.
    num_speculative_tokens : int, optional
        Number of speculative tokens for Speculative Decoding Target Language Model.
        Required if the model is configured as a Target Language Model (`is_tlm=True`).
    prefill_only : bool, optional
        If True, compiles only for the prefill stage. If False, compiles only for
        the decode stage. If None, compiles for both stages. Default is None.
    use_onnx_subfunctions: bool, optional
        whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with
        modules as subfunctions helps to reduce export/compile time. Defaults to False
    offload_pt_weights : bool, optional
        Forwarded unchanged to ``self._compile``. Default is True.
    enable_chunking : bool, optional
        Forwarded to ``process_ccl_specializations``, to the non-CCL prefill specialization and
        to ``self._compile``. Default is False.
    retain_full_kv : bool, optional
        Forwarded unchanged to ``self._compile``. Default is None.
    **compiler_options : dict
        Additional compiler options for QAIC or QNN compilers. A ``specializations`` entry, if
        present and truthy, replaces the specializations built by this method.

        **For QAIC Compiler:** Extra arguments for qaic-compile can be passed. Some common options include:

        - mos (int, optional): Effort level to reduce on-chip memory. Defaults to -1, meaning no effort.
        - aic_enable_depth_first (bool, optional): Enables DFS with default memory size. Defaults to False.
        - allow_mxint8_mdp_io (bool, optional): Allows MXINT8 compression of MDP IO traffic. Defaults to False.

        Params are converted to flags as below:

        - ``aic_num_cores=16`` -> ``-aic-num-cores=16``
        - ``convert_to_fp16=True`` -> ``-convert-to-fp16``

        **For QNN Compiler:** Following arguments can be passed as:

        - enable_qnn (bool): Enables QNN Compilation.
        - qnn_config (str): Path of QNN Config parameters file. Any extra parameters for QNN compilation can be passed via this file.

    Returns
    -------
    str
        Path to the compiled QPC package.

    Raises
    ------
    TypeError
        If `prefill_only` is neither None nor a boolean.
        If `full_batch_size` is None when `continuous_batching` is True and the compile is not prefill-only.
    ValueError
        If `continuous_batching` is True, `prefill_only` is True and neither `kv_cache_batch_size`
        nor `full_batch_size` is given.
        If a string CCL list cannot be parsed.
        If `include_sampler` is True and `num_speculative_tokens` is greater than 0.

    For TLM models, ``check_and_get_num_speculative_tokens`` may raise additional errors.
    """
    import ast

    # --- Validation ---
    # Checked first: `prefill_only` steers every branch below, so a wrong type must not get that far.
    if prefill_only is not None and not isinstance(prefill_only, bool):
        raise TypeError("`prefill_only` must be a boolean.")

    qaic_config = self.model.qaic_config
    mla_absorption = qaic_config.get("mla_absorption", None) if qaic_config is not None else None
    cache_compressed = mla_absorption.get("cache_compressed", False) if mla_absorption is not None else False
    if mla_absorption is not None and not cache_compressed:
        logger.warning("mla_absorption will be ignored as cache_compressed is set to False")

    if (kv_cache_batch_size or full_batch_size) and not self.continuous_batching:
        logger.warning(
            "`kv_cache_batch_size` or `full_batch_size` is being passed. "
            "This will be ignored as `continuous_batching` is set to `False` in `from_pretrained`"
        )

    if self.continuous_batching:
        if not prefill_only:
            if full_batch_size is None:
                raise TypeError("`full_batch_size` is required when `continuous_batching=True`.")
        elif kv_cache_batch_size is None and full_batch_size is None:
            raise ValueError(
                "Please pass valid integer for kv_cache_batch_size or full_batch_size, both have same meaning, as continuous_batching is enabled for prefill-only model"
            )

    # Infer kv_cache_batch_size if not provided
    kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size

    # --- Compute-Context-Length (CCL) lists ---
    if self.ccl_enabled:
        if comp_ctx_lengths_prefill is None and comp_ctx_lengths_decode is None:
            logger.info("Auto-generating CCL-prefill and CCL-decode lists based on Context Length (CL).")
        self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations(
            comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len, enable_chunking
        )
    # For supporting VLLM and Disaggregated with CCL
    elif comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
        self.ccl_enabled = True
        try:
            # Each list is parsed on its own: previously a string prefill list forced
            # literal_eval on the decode list even when that one was None or already a list.
            if isinstance(comp_ctx_lengths_prefill, str):
                comp_ctx_lengths_prefill = ast.literal_eval(comp_ctx_lengths_prefill)
            if isinstance(comp_ctx_lengths_decode, str):
                comp_ctx_lengths_decode = ast.literal_eval(comp_ctx_lengths_decode)
        except (ValueError, SyntaxError) as err:
            raise ValueError("Invalid format for comp_ctx_lengths. Expected a list-like string.") from err

        self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations(
            comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len, enable_chunking
        )

    if self.is_tlm:
        num_speculative_tokens = self.check_and_get_num_speculative_tokens(num_speculative_tokens, prefill_seq_len)

    if (
        qaic_config is not None
        and qaic_config.get("include_sampler", False)
        and num_speculative_tokens is not None
        and num_speculative_tokens > 0
    ):
        raise ValueError("Currently, sampler does not support `num_speculative_tokens` > 0.")

    # --- Specializations ---
    specializations = []
    has_ccl = self.comp_ctx_lengths_prefill is not None or self.comp_ctx_lengths_decode is not None

    if prefill_only is None or prefill_only or prefill_seq_len == 1:
        # TODO: we are handling decode-only case inside prefill call which is utterly mis-leading
        if has_ccl:
            ccl_lengths = self.comp_ctx_lengths_decode if prefill_seq_len == 1 else self.comp_ctx_lengths_prefill
            # One prefill specialization per compute-context length.
            # NOTE(review): unlike the non-CCL branch, `prefill_only`/`enable_chunking` are not
            # forwarded here; kept as-is - confirm whether that is intentional.
            for ccl in ccl_lengths:
                specializations.append(
                    self.build_prefill_specialization(
                        prefill_seq_len=prefill_seq_len,
                        ctx_len=ctx_len,
                        comp_ctx_lengths=ccl,
                        batch_size=batch_size,
                        kv_cache_batch_size=kv_cache_batch_size,
                        full_batch_size=full_batch_size,
                    )
                )
        else:
            specializations.append(
                self.build_prefill_specialization(
                    prefill_seq_len=prefill_seq_len,
                    ctx_len=ctx_len,
                    batch_size=batch_size,
                    kv_cache_batch_size=kv_cache_batch_size,
                    full_batch_size=full_batch_size,
                    prefill_only=prefill_only,
                    enable_chunking=enable_chunking,
                )
            )

    if not prefill_only and prefill_seq_len != 1:
        if self.comp_ctx_lengths_decode is not None:
            # One decode specialization per compute-context length.
            for ccl in self.comp_ctx_lengths_decode:
                decode_spec = self.build_decode_specialization(
                    prefill_seq_len=prefill_seq_len,
                    ctx_len=ctx_len,
                    comp_ctx_lengths=ccl,
                    batch_size=batch_size,
                    kv_cache_batch_size=kv_cache_batch_size,
                    full_batch_size=full_batch_size,
                    num_speculative_tokens=num_speculative_tokens,
                )
                if decode_spec:
                    specializations.append(decode_spec)
        else:
            decode_spec = self.build_decode_specialization(
                prefill_seq_len=prefill_seq_len,
                ctx_len=ctx_len,
                batch_size=batch_size,
                kv_cache_batch_size=kv_cache_batch_size,
                full_batch_size=full_batch_size,
                num_speculative_tokens=num_speculative_tokens,
                prefill_only=prefill_only,
            )
            if decode_spec:
                specializations.append(decode_spec)

    # Caller-supplied specializations take precedence over everything built above.
    if kw_spec := compiler_options.pop("specializations", None):
        specializations = kw_spec

    target_dtype = getattr(self.model.config, "torch_dtype", torch.float32)
    kv_cache_dtype = "mxint8" if mxint8_kv_cache else CUSTOM_IO_DTYPE_MAP[target_dtype]

    # --- Compilation ---
    # Every cache tensor (input and its retained-state output) gets the KV-cache dtype.
    cache_io_names = ("compressed_kv", "k_pe") if cache_compressed else ("past_key", "past_value")
    custom_io = {}
    for suffix in ("", "_RetainedState"):
        for layer_idx in range(self.num_layers):
            for io_name in cache_io_names:
                custom_io[f"{io_name}.{layer_idx}{suffix}"] = kv_cache_dtype

    qpc_path = self._compile(
        onnx_path=onnx_path,
        compile_dir=compile_dir,
        retained_state=True,
        specializations=specializations,
        convert_to_fp16=(CUSTOM_IO_DTYPE_MAP[target_dtype] == "float16"),
        mxfp6_matmul=mxfp6_matmul,
        custom_io=custom_io,
        mdp_ts_num_devices=num_devices,
        num_speculative_tokens=num_speculative_tokens,
        aic_num_cores=num_cores,
        mxint8_kv_cache=mxint8_kv_cache,
        use_onnx_subfunctions=use_onnx_subfunctions,
        prefill_only=prefill_only,
        offload_pt_weights=offload_pt_weights,
        enable_chunking=enable_chunking,
        retain_full_kv=retain_full_kv,
        **compiler_options,
    )

    return qpc_path
def check_and_get_num_speculative_tokens(self, num_speculative_tokens: Optional[int], prefill_seq_len: int):
    """
    Validate and resolve the number of speculative tokens for TLM models.

    When the model config carries a ``speculative_config`` entry, its
    ``num_speculative_tokens`` value always wins and a user-supplied value is
    ignored with a warning.

    Parameters
    ----------
    num_speculative_tokens : int, optional
        The number of speculative tokens provided by the user.
    prefill_seq_len : int
        The prefill sequence length.

    Returns
    -------
    int or None
        ``None`` when ``self.is_tlm`` is falsy, otherwise the resolved number of
        speculative tokens.

    Raises
    ------
    TypeError
        If `num_speculative_tokens` is None, `is_tlm` is True and the model config
        has no ``speculative_config``.
    ValueError
        If the resolved `num_speculative_tokens` is a truthy non-integer value.
        If `prefill_seq_len` is less than `num_speculative_tokens + 1`.
    """
    if not self.is_tlm:
        return None

    if hasattr(self.model.config, "speculative_config"):
        # The model dictates the value; a conflicting user value is only reported.
        fixed_num_tokens = self.model.config.speculative_config["num_speculative_tokens"]
        if num_speculative_tokens is not None:
            logger.warning(
                f"arg `num_speculative_tokens` is a fixed value of {fixed_num_tokens} for this model."
                f" Passed value of {num_speculative_tokens} will be ignored."
            )
        num_speculative_tokens = fixed_num_tokens
    elif num_speculative_tokens is None:
        raise TypeError("missing required argument `num_speculative_tokens` as `is_tlm` instance variable is True.")

    # The exception used to be constructed without `raise`, so this check never
    # fired. The condition itself is unchanged: falsy values still fall through.
    if not isinstance(num_speculative_tokens, int) and num_speculative_tokens:
        raise ValueError(
            f"`num_speculative_tokens` arg should be an integer greater than 1, got {num_speculative_tokens}"
        )

    num_logits_to_keep = num_speculative_tokens + 1
    if prefill_seq_len < num_logits_to_keep:
        raise ValueError(
            f"sequence length ({prefill_seq_len}) must be at least `num_speculative_tokens+1` ({num_logits_to_keep})"
        )
    return num_speculative_tokens
code-block:: python + + from datasets import load_dataset + from transformers import AutoProcessor + from QEfficient import QEFFAutoModelForSpeechSeq2Seq + + base_model_name = "openai/whisper-tiny" + ## STEP 1 -- load audio sample, using a standard english dataset, can load specific files if longer audio needs to be tested; also load initial processor + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + data = ds[0]["audio"]["array"] + # reshape to so shape corresponds to data with batch size 1 + data = data.reshape(-1) + sample_rate = ds[0]["audio"]["sampling_rate"] + processor = AutoProcessor.from_pretrained(base_model_name) + + ## STEP 2 -- init base model + qeff_model = QEFFAutoModelForSpeechSeq2Seq.from_pretrained(base_model_name) + + ## STEP 3 -- export and compile model + qeff_model.compile() + + ## STEP 4 -- generate output for loaded input and processor + exec_info = qeff_model.generate(inputs=processor(data, sampling_rate=sample_rate, return_tensors="pt"), generation_len=25) + + ## STEP 5 (optional) -- use processor to decode output + print(processor.batch_decode(exec_info.generated_ids)[0]) + """ + + _hf_auto_class = AutoModelForSpeechSeq2Seq + _pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, KVCacheTransform] + _onnx_transforms = [] + + def __init__(self, model: nn.Module, **kwargs): + """ + Initialize a QEFFAutoModelForSpeechSeq2Seq instance. + + Parameters + ---------- + model : nn.Module + A PyTorch model with a sequence-to-sequence speech-to-text head (e.g., Whisper). + **kwargs : + Additional keyword arguments passed to the base class constructor. + + Raises + ------ + TypeError + If the model is not a supported speech-to-text model (i.e., not a `ForConditionalGeneration` model). 
def export(self, export_dir: Optional[str] = None, **kwargs) -> str:
    """
    Export the wrapped speech model to an ONNX graph.

    The example inputs, dynamic axes and output names are all supplied by the
    wrapped model itself and handed to ``self._export`` unchanged.

    Parameters
    ----------
    export_dir : str, optional
        Directory path where the exported ONNX graph will be saved.
        Forwarded as-is (``None`` included) to ``self._export``.
    use_onnx_subfunctions : bool, optional
        Read from ``kwargs``; whether to enable ONNX subfunctions during export.
        Defaults to False. Any other keyword argument is ignored.

    Returns
    -------
    str
        Whatever ``self._export`` returns (the path of the generated ONNX graph).
    """
    example_inputs = self.model.get_dummy_inputs()
    axes = self.model.get_onnx_dynamic_axes()
    names = self.model.get_output_names()
    subfunctions_enabled = kwargs.get("use_onnx_subfunctions", False)

    return self._export(
        example_inputs,
        output_names=names,
        dynamic_axes=axes,
        export_dir=export_dir,
        use_onnx_subfunctions=subfunctions_enabled,
    )
+ batch_size : int, optional + Batch size. Default is 1. + num_devices : int, optional + Number of devices to compile for. Default is 1. + num_cores : int, optional + Number of cores to use for compilation. + mxfp6_matmul : bool, optional + Use MXFP6 compression for weights. Default is False. + mxint8_kv_cache : bool, optional + Use MXINT8 compression for KV cache. Default is False. + full_batch_size : int, optional + Not yet supported for this model. + kv_cache_batch_size : int, optional + Not yet supported for this model. + num_speculative_tokens : int, optional + Not yet supported for this model. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False + **compiler_options : dict + Additional compiler options for QAIC. + + **For QAIC Compiler:** Extra arguments for qaic-compile can be passed. Some common options include: + + - mos (int, optional): Effort level to reduce on-chip memory. Defaults to -1, meaning no effort. Defaults to -1. + - aic_enable_depth_first (bool, optional): Enables DFS with default memory size. Defaults to False. + - allow_mxint8_mdp_io (bool, optional): Allows MXINT8 compression of MDP IO traffic. Defaults to False. + + Params are converted to flags as below: + + - ``aic_num_cores=16`` -> ``-aic-num-cores=16`` + - ``convert_to_fp16=True`` -> ``-convert-to-fp16`` + + Returns + ------- + str + Path to the compiled QPC package. 
def generate(
    self,
    inputs: torch.Tensor,
    generation_len: int,
    streamer: Optional[TextStreamer] = None,
    device_ids: List[int] = None,
    write_io: bool = False,
) -> Union[torch.Tensor, np.ndarray]:
    """
    Greedily decode tokens by repeatedly running the compiled QPC session.

    The first session run is timed as the prefill ("encoder run"); the loop that
    follows runs the session up to `generation_len` more times, taking the argmax
    of ``logits`` each time, and stops early when the first batch entry produces
    ``eos_token_id``.

    Parameters
    ----------
    inputs : mapping
        Model inputs. Passed through ``self.auto_correct_inputs`` and then indexed
        with ``"input_features"``, on which ``.numpy()`` is called, so that entry
        is presumably a ``torch.Tensor`` (the signature annotation says
        ``torch.Tensor`` for the whole argument -- TODO confirm).
        ``"input_ids"`` and ``"position_ids"`` are created here from
        ``decoder_start_token_id`` and overwrite any values supplied by the caller.
    generation_len : int
        Upper bound on the number of decode-loop iterations. The token buffer has
        ``generation_len + 1`` columns.
    streamer : TextStreamer, optional
        If given, receives each chosen token via ``streamer.put``. Default is None.
    device_ids : List[int], optional
        Forwarded to ``QAICInferenceSession`` when the session is first created.
    write_io : bool, optional
        When True, inputs and outputs of the prefill run and of the first decode
        run are dumped to an ``io_dir`` folder next to ``self.onnx_path``.

    Returns
    -------
    CloudAI100ExecInfoNew
        Carries the batch size, the generated ID buffer and the ``PerfMetrics``
        built from ``calculate_latency`` (the ``Union[...]`` return annotation does
        not reflect this).

    Raises
    ------
    TypeError
        If ``self.qpc_path`` is not a ``Path`` (i.e., `compile` was not run).
    """
    if not isinstance(self.qpc_path, Path):
        raise TypeError("Please run compile API first!")

    self._write_io_dir = os.path.join(os.path.dirname(self.onnx_path), "io_dir") if write_io else None

    inputs = self.auto_correct_inputs(inputs)
    # The session is created lazily and reused by later calls; the batch size is
    # read from the first dimension of the first binding.
    if self.qpc_session is None:
        self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids)
        self.batch_size = self.qpc_session.bindings[0].dims[0]

    inputs["input_features"] = inputs["input_features"].numpy().astype(np.float16)

    # add start token id and initial position ids to inputs
    seq_len = 1
    inputs["input_ids"] = (
        torch.ones((self.batch_size, seq_len), dtype=torch.int64) * self.model.config.decoder_start_token_id
    ).numpy()
    inputs["position_ids"] = (
        torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(self.batch_size, 1).numpy()
    )

    # Every binding whose name starts with "past_" is skipped, so those buffers
    # are not exchanged with the host on each run.
    self.qpc_session.skip_buffers(
        [x for x in self.qpc_session.input_names + self.qpc_session.output_names if x.startswith("past_")]
    )

    # Placeholder array registered as the "logits" buffer; only its shape and
    # dtype appear to matter, the random values are replaced by the first run.
    outputs = {
        "logits": np.random.randn(self.batch_size, 1, self.model.config.vocab_size).astype(np.float32),
    }
    self.qpc_session.set_buffers(outputs)

    # encoder run
    start = perf_counter()
    outputs = self.qpc_session.run(inputs)

    if self._write_io_dir is not None:
        write_io_files(inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False)

    # array to hold generated tokens
    # Pre-filled with EOS so that columns never written read as end-of-sequence.
    generated_ids = np.full((self.batch_size, generation_len + 1), self.model.config.eos_token_id)
    generated_ids[:, 0] = [self.model.config.decoder_start_token_id]
    logits = outputs["logits"]
    next_token = logits.argmax(-1)
    # NOTE(review): with generation_len == 0 the buffer has a single column and
    # this write to column 1 raises IndexError -- confirm whether 0 must be supported.
    generated_ids[:, 1] = next_token.squeeze(1)

    if streamer:
        streamer.put(next_token)

    # From here on the audio features are replaced by a zero tensor with a
    # trailing dimension of 1.
    inputs["input_features"] = np.zeros((self.batch_size, self.model.config.num_mel_bins, 1)).astype(np.float16)

    loop_start = perf_counter()
    for num_tokens in range(generation_len):
        # NOTE(review): input_ids / position_ids are not advanced between the
        # prefill run and iteration 0, and iteration 0 writes column 1 again, so
        # the first decode step looks like a repeat of the prefill token (and the
        # streamer may receive it twice) -- confirm this is intended.
        outputs = self.qpc_session.run(inputs)
        if self._write_io_dir is not None:
            write_io_files(inputs, outputs, self._write_io_dir, "decode", "aic_batch_io", True, False)
            # Only the first decode step is dumped.
            self._write_io_dir = None

        logits = outputs["logits"]
        next_token = logits.argmax(-1)
        generated_ids[:, num_tokens + 1] = next_token.squeeze(1)

        # Early stop is decided by the first batch entry only.
        if next_token[0][0] == self.model.config.eos_token_id:
            break

        inputs["input_ids"] = next_token
        inputs["position_ids"] += 1

        if streamer:
            streamer.put(next_token)
    end = perf_counter()

    # num_tokens is the index of the last loop iteration that ran.
    prefill_time, decode_perf, total_perf, total_time = calculate_latency(num_tokens, loop_start, start, end)

    return CloudAI100ExecInfoNew(
        batch_size=self.batch_size,
        generated_ids=generated_ids,
        perf_metrics=PerfMetrics(prefill_time, decode_perf, total_perf, total_time),
    )
@classmethod
@with_replaced_quantizers
def from_pretrained(cls, pretrained_model_name_or_path, pooling=None, *args, **kwargs):
    """
    This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModelForCTC.
    Once the model is initialized, you can use other methods such as export, compile, and generate on the same object.

    Args:
        pretrained_model_name_or_path (str): The name or path of the pre-trained model.
        pooling: Forwarded unchanged to the class constructor.
        *args: Positional arguments forwarded to the Hugging Face ``from_pretrained``.
        **kwargs: Keyword arguments forwarded to the Hugging Face ``from_pretrained``, except
            ``enable_proxy`` and ``kv_offload``, which are consumed here.
            ``attn_implementation`` is always forced to ``"eager"`` and ``low_cpu_mem_usage``
            to ``False``; a warning is logged when a conflicting value was passed.

    Returns:
        An instance of ``cls``, or of the class registered in
        ``MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP`` for the loaded model's class name.

    .. code-block:: python

        import torchaudio
        from QEfficient import QEFFAutoModelForCTC
        from transformers import AutoProcessor

        # Initialize the model using from_pretrained similar to transformers.AutoModelForCTC.
        model=QEFFAutoModelForCTC.from_pretrained(model_name)

        # Now you can directly compile the model for Cloud AI 100
        model.compile(num_cores=16)  # Considering you have a Cloud AI 100 SKU

        #prepare input
        processor = AutoProcessor.from_pretrained(model_name)
        input_audio, sample_rate = [...] # audio data loaded in via some external audio package, such as librosa or soundfile

        # Resample the input_audio if necessary
        if input_audio.shape[0] > 1:
            input_audio = input_audio.mean(dim=0)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            input_audio = resampler(input_audio)

        # You can now execute the model
        out = model.generate(processor,inputs=input_audio)
    """
    # QEfficient-only options must be removed before kwargs reach Hugging Face.
    # `kv_offload` used to be popped only after the HF call, so it leaked into it.
    enable_proxy = kwargs.pop("enable_proxy", False)
    kv_offload = kwargs.pop("kv_offload", None)

    if kwargs.get("attn_implementation", None) not in {None, "eager"}:
        logger.warning('Updating attn_implementation="eager"')

    if kwargs.get("low_cpu_mem_usage", None):
        logger.warning("Updating low_cpu_mem_usage=False")

    kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})

    model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

    # `enable_proxy` is only forwarded to the constructor when it is truthy.
    if enable_proxy:
        kwargs["enable_proxy"] = enable_proxy

    # This is to support models that should be classified to a different auto class
    # but that transformers loads via this class.
    model_class_name = model.__class__.__name__
    if model_class_name in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP:
        return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model_class_name](
            model, kv_offload=kv_offload, **kwargs
        )

    return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path, pooling=pooling, **kwargs)
def compile(
    self,
    onnx_path: Optional[str] = None,
    compile_dir: Optional[str] = None,
    *,
    seq_len: Union[int, List[int]] = 480000,
    batch_size: int = 1,
    num_devices: int = 1,
    num_cores: int = 16,  # FIXME: Make this mandatory arg
    mxfp6_matmul: bool = False,
    use_onnx_subfunctions: bool = False,
    **compiler_options,
) -> str:
    """
    Build one specialization per requested sequence length and hand the model to ``self._compile``.

    A single sequence length yields one graph named ``CTC``; a list of several lengths yields
    graphs named ``CTC_0``, ``CTC_1``, ... in list order.

    ``Optional`` Args:
        :onnx_path (str, optional): Path to pre-exported onnx model.
        :compile_dir (str, optional): Path for saving the qpc generated.
        :seq_len (Union[int, List[int]]): Input length, or list of input lengths, to compile for. ``Defaults to 480000``.
        :batch_size (int, optional): Batch size. ``Defaults to 1``.
        :num_devices (int): Number of devices the model needs to be compiled for. Defaults to 1.
        :num_cores (int): Number of cores used to compile the model. Defaults to 16.
        :mxfp6_matmul (bool, optional): Whether to use ``mxfp6`` compression for weights. ``Defaults to False``.
        :use_onnx_subfunctions (bool, optional): Whether to enable ONNX subfunctions during export. Defaults to False.
        :compiler_options (dict, optional): Additional compiler options, forwarded unchanged to ``self._compile``.

    Returns:
        :str: Whatever ``self._compile`` returns (the path of the compiled ``qpc`` package).
    """
    requested_lengths = seq_len if isinstance(seq_len, list) else [seq_len]
    single_graph = len(requested_lengths) == 1

    specializations = []
    for index, length in enumerate(requested_lengths):
        graph_name = "CTC" if single_graph else f"CTC_{index}"
        specializations.append({"_graph_name": graph_name, "batch_size": batch_size, "seq_len": length})

    # fp16 conversion is requested only when the model dtype maps to "float16".
    target_dtype = getattr(self.model.config, "torch_dtype", torch.float32)
    use_fp16 = CUSTOM_IO_DTYPE_MAP[target_dtype] == "float16"

    return self._compile(
        onnx_path=onnx_path,
        compile_dir=compile_dir,
        specializations=specializations,
        convert_to_fp16=use_fp16,
        mxfp6_matmul=mxfp6_matmul,
        mdp_ts_num_devices=num_devices,
        aic_num_cores=num_cores,
        use_onnx_subfunctions=use_onnx_subfunctions,
        **compiler_options,
    )
def cloud_ai_100_feature_generate(
    self,
    processor,
    inputs: torch.Tensor,
    device_ids: List[int] = [0],
) -> np.ndarray:
    """
    Run one waveform through the compiled QPC session and decode the result.

    ``Mandatory`` Args:
        :processor (AutoProcessor): Called to encode ``inputs`` and, at the end, to ``batch_decode`` the predicted ids.
        :inputs (Union[torch.Tensor, np.ndarray]): Raw input handed to ``processor``.
    ``Optional`` Args:
        :device_ids (List[int], optional): Device IDs used when the session is first created. Defaults to [0].

    Returns:
        The value returned by ``processor.batch_decode`` (the transcriptions), not a raw
        ``np.ndarray`` as the annotation suggests.
    """
    # Create the session lazily; batch size and sequence length come from the
    # first two dimensions of the first binding.
    if self.qpc_session is None:
        self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids)
        first_binding_dims = self.qpc_session.bindings[0].dims
        self.batch_size = first_binding_dims[0]
        self.seq_len = first_binding_dims[1]

    # A session created elsewhere may not have set seq_len yet.
    if not hasattr(self, "seq_len"):
        self.seq_len = self.qpc_session.bindings[0].dims[1]

    encoded = processor(inputs, return_tensors="pt", max_length=self.seq_len, truncation=True, padding="max_length")
    waveform = encoded["input_values"]
    # Right-pad with zeros up to the compiled length.
    missing = self.seq_len - waveform.shape[-1]
    padded = np.array(torch.nn.functional.pad(waveform, (0, missing), "constant", 0))

    target_dtype = getattr(self.model.config, "torch_dtype", torch.float32)
    session_inputs = {"input_values": padded.astype(TORCH_TO_NUMPY_DTYPE_MAP[target_dtype])}
    outputs = self.qpc_session.run(session_inputs)

    if self._write_io_dir is not None:
        write_io_files(session_inputs, outputs, self._write_io_dir, "output", "aic_batch_io", True, False)

    predicted_ids = np.argmax(outputs["logits"], axis=-1)
    return processor.batch_decode(torch.tensor(predicted_ids))
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head dimension.

    Equivalent to ``torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)``: the input of shape
    (batch, num_key_value_heads, seqlen, head_dim) becomes (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    With ``n_rep == 1`` the input tensor itself is returned.
    """
    # Unpacking first means a non-4D input fails even when n_rep == 1.
    bsz, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states

    # Insert a repeat axis after the head axis, broadcast it, then fold it into the heads.
    expanded = hidden_states[:, :, None, :, :].expand(bsz, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(bsz, kv_heads * n_rep, seq_len, head_dim)
+ """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + # Cast back to original dtype + return q_embed.to(q.dtype), k_embed.to(k.dtype) + +class QEffDreamRotaryEmbedding(nn.Module): + """ + The only differences are: + - Add static sin/cos computations. + """ + + def __init__(self, config, device=None): + super().__init__() + dim = config.hidden_size // config.num_attention_heads + self.inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim)) + self.original_max_seq_len = config.max_position_embeddings or config.max_sequence_length + self._set_cos_sin_cache( + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, 
# NOTE(review): `DreamAttention` is not imported by any import visible in this module;
# unless it is injected elsewhere, defining this class raises NameError -- confirm.
class QEffDreamBlockedAttention(DreamAttention):
    """Dream attention that delegates the score/value product to blocked attention.

    The blocking configuration (mode, head block size, number of KV / Q blocks) is
    read through ``get_attention_blocking_config()`` on every forward call.
    """

    def __qeff_init__(self):
        self.rotary_emb = QEffDreamRotaryEmbedding(config=self.config)
        self.compile_length = self.config.max_position_embeddings

    # Adapted from DreamAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Project, rotate, optionally cache, and attend with blocked attention.

        Returns:
            ``(attn_output, None, past_key_value)``; attention weights are never
            returned, so ``output_attentions`` has no effect here.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, q_len)
        else:
            cos, sin = position_embeddings
        # cos/sin are indexed with position_ids inside the helper.
        query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # The mask is declared Optional; only convert it when one was actually supplied
        # (the unconditional `.bool()` raised AttributeError on None).
        # NOTE(review): presumably compute_blocked_attention accepts attention_mask=None -- confirm.
        if attention_mask is not None:
            attention_mask = attention_mask.bool()

        blocking_mode, head_block_size, num_kv_blocks, num_q_blocks = get_attention_blocking_config()
        attn_output = compute_blocked_attention(
            query_states,
            key_states,
            value_states,
            blocking_mode=blocking_mode,
            head_block_size=head_block_size,
            num_kv_blocks=num_kv_blocks,
            num_q_blocks=num_q_blocks,
            attention_mask=attention_mask,
        )

        # (bsz, heads, q_len, head_dim) -> (bsz, q_len, hidden_size)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
def top_k_logits(logits, top_k=None):
    """Keep the ``top_k`` largest logits along the last axis and floor the rest.

    Entries smaller than the k-th largest value are replaced by the dtype's minimum.
    ``top_k=None`` (the declared default) now returns ``logits`` unchanged instead of
    failing inside ``min(None, ...)``.
    """
    if top_k is None:
        return logits
    top_k = min(top_k, logits.size(-1))  # Safety check
    # Remove all tokens with a probability less than the last token of the top-k
    kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
    return logits.masked_fill(logits < kth_largest, torch.finfo(logits.dtype).min)


def sample_tokens(logits, temperature, top_p, top_k, neg_entropy):
    """Greedily pick a token per position and score how confident the pick is.

    Args:
        logits: scores with the vocabulary on the last axis.
        temperature: when > 0 the logits are divided by it before the softmax.
        top_p: accepted for interface compatibility; not used by this implementation.
        top_k: optional top-k filter applied before the softmax.
        neg_entropy: when true, confidence is ``sum(p * log(p + 1e-10))`` (negative
            entropy); otherwise it is the maximum probability.

    Returns:
        ``(confidence, x0)`` where ``x0`` is the argmax token id per position.
    """
    if temperature > 0:
        logits = logits / temperature
    if top_k is not None:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)

    # Selection is always the argmax; the former temperature branches were identical.
    confidence, x0 = probs.max(dim=-1)

    if neg_entropy:
        epsilon = 1e-10
        log_probs = torch.log(probs + epsilon)
        confidence = torch.sum(probs * log_probs, dim=-1)

    return confidence, x0


class QEffDreamSampler(nn.Module):
    """One diffusion decoding step: fill the most confident masked positions.

    Args:
        mask_token_id: id that marks a still-undecoded position.
        number_transfer_tokens: how many positions are committed per call
            (defaults to 2, the value that used to be hard-coded in ``forward``).
    """

    def __init__(self, mask_token_id, number_transfer_tokens=2):
        super().__init__()
        self.mask_token_id = mask_token_id
        self.temperature = 0.2
        self.top_p = 0.95
        self.top_k = 50
        self.entropy = True
        self.alg_temp = 0.0
        self.number_transfer_tokens = number_transfer_tokens

    def forward(self, x, logits):
        """Commit up to ``number_transfer_tokens`` predictions into ``x`` and return it.

        ``x`` is modified in place. Positions that are not masked are never changed;
        previously, when fewer masked positions remained than transfer slots, the
        surplus slots wrote ``mask_token_id`` over already-decoded tokens.
        """
        mask_index = x == self.mask_token_id
        # Shift right by one so position i is scored by the logits emitted at i - 1.
        logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
        mask_logits = torch.where(mask_index.unsqueeze(-1), logits, torch.tensor(0.0))
        confidence, x0 = sample_tokens(
            mask_logits,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            neg_entropy=self.entropy,
        )
        confidence = confidence.unsqueeze(0)

        num_transfer = min(self.number_transfer_tokens, x.size(-1))
        if num_transfer <= 0:
            return x

        # Unmasked positions get -inf so they are selected only when masks run out.
        full_confidence = torch.full_like(x, -torch.inf, device=x.device, dtype=logits.dtype)
        full_confidence = torch.where(mask_index, confidence, full_confidence)
        _, transfer_index = torch.topk(full_confidence, num_transfer)

        # Candidate sequence: predictions at masked slots, current tokens elsewhere,
        # so a surplus selection of an unmasked slot is a no-op.
        candidate = torch.where(mask_index, x0, x)
        row_indices = torch.arange(x.size(0), device=x.device, dtype=torch.int64).unsqueeze(1).expand_as(transfer_index)
        transfer_index = transfer_index.to(torch.int64)
        x[row_indices, transfer_index] = candidate[row_indices, transfer_index]

        return x
class QEffDreamModel(nn.Module):
    """Dream forward pass fused with one sampler step.

    ``forward`` runs the backbone, projects every position through ``lm_head`` and
    hands the logits to ``self.sampler``; it returns token ids, not logits.
    """

    def __qeff_init__(self, config=None):
        """Cache vocab/mask ids and build the sampler.

        ``config`` defaults to ``self.config`` so the hook can also be invoked
        without arguments, like the attention class's ``__qeff_init__``.
        """
        config = self.config if config is None else config
        self.vocab_size = config.vocab_size
        self.mask_token_id = config.mask_token_id
        self.sampler = QEffDreamSampler(mask_token_id=self.mask_token_id)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> torch.LongTensor:
        """Run the backbone and return the sampler's updated token ids.

        Args:
            input_ids: current token ids; required because the sampler looks for
                ``mask_token_id`` in them.
            labels: accepted for signature compatibility; unused.
            Remaining arguments are forwarded unchanged to ``self.model``.

        Returns:
            Whatever ``self.sampler(input_ids, logits)`` returns.

        Raises:
            ValueError: if ``input_ids`` is ``None``.
        """
        if input_ids is None:
            raise ValueError("QEffDreamModel.forward requires `input_ids`; the sampler operates on token ids.")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]

        # Every position is projected: the former `[:, -0:, :]` slice equals `[:, 0:, :]`.
        logits = self.lm_head(hidden_states)
        new_x = self.sampler(input_ids, logits)
        return new_x
class DreamGenerationConfig(GenerationConfig):
    """Generation settings for Dream diffusion decoding on a compiled QPC.

    All values are taken from keyword arguments; anything not recognised is set as a
    plain attribute unless ``_from_model_config`` is true.
    """

    def __init__(self, **kwargs):
        self.temperature: float = kwargs.pop("temperature", 0.0)
        self.top_p: Optional[float] = kwargs.pop("top_p", None)
        self.top_k: Optional[int] = kwargs.pop("top_k", None)
        self.max_length = kwargs.pop("max_length", 20)
        self.max_new_tokens = kwargs.pop("max_new_tokens", None)
        # diffusion specific params
        self.eps: float = kwargs.pop("eps", 1e-3)
        self.steps: int = kwargs.pop("steps", 512)
        self.alg: str = kwargs.pop("alg", 'origin')
        self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
        self.number_transfer_tokens = kwargs.pop("number_transfer_tokens", 1)
        self.expand_budget = kwargs.pop("expand_budget", None)
        self.pad_delete_to_right = kwargs.pop("pad_delete_to_right", False)

        # qpc specific params
        self.compile_length: int = kwargs.pop("compile_length", 1000)
        self.qpc_path: Optional[str] = kwargs.pop("qpc_path", None)
        self.device_ids: Optional[list] = kwargs.pop("device_ids", None)

        # Parameters that define the output variables of `generate`
        self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
        self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
        self.output_history: bool = kwargs.pop("output_history", False)

        # Special tokens that can be used at generation time
        self.mask_token_id = kwargs.pop("mask_token_id", None)
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.bos_token_id = kwargs.pop("bos_token_id", None)
        self.eos_token_id = kwargs.pop("eos_token_id", None)
        self.expand_token_id = kwargs.pop("expand_token_id", None)
        # Previously hard-coded to 151643; that value is kept as the default but the
        # id can now be supplied explicitly, including when `_from_model_config` is set.
        self.delete_token_id = kwargs.pop("delete_token_id", 151643)

        # Wild card
        self.generation_kwargs = kwargs.pop("generation_kwargs", {})

        # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
        # interface.
        self._from_model_config = kwargs.pop("_from_model_config", False)
        self._commit_hash = kwargs.pop("_commit_hash", None)
        self.transformers_version = kwargs.pop("transformers_version", __version__)

        # Additional attributes without default values
        if not self._from_model_config:
            # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
            # model's default configuration file
            for key, value in kwargs.items():
                try:
                    setattr(self, key, value)
                except AttributeError as err:
                    logger.error(f"Can't set {key} with value {value} for {self}")
                    raise err

        # Validate the values of the attributes
        self.validate(is_init=True)

    def validate(self, is_init=False):
        """Intentionally a no-op: no validation is performed for this config."""
        pass
def _prepare_special_tokens(
    self,
    generation_config: DreamGenerationConfig,
    device: Optional[Union[torch.device, str]] = None,
):
    """Store tensor versions of the special tokens on ``generation_config``.

    The bos/eos/pad/mask ids are converted to ``torch.long`` tensors and written to
    ``_bos_token_tensor``, ``_eos_token_tensor``, ``_pad_token_tensor`` and
    ``_mask_token_tensor``. When no pad id is set but an eos id is, the first eos id is
    used as pad id and a warning is logged.

    Note that `generation_config` is changed in place and stops being serializable after this method is called.
    If called outside `generate`, consider creating a copy of `generation_config` first.
    """

    # Convert special tokens to tensors
    def _tensor_or_none(token, device=None):
        if token is None:
            return token

        device = device if device is not None else self.device
        if isinstance(token, torch.Tensor):
            return token.to(device)
        return torch.tensor(token, device=device, dtype=torch.long)

    bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
    eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
    pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
    mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device)

    # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
    if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
        eos_token_tensor = eos_token_tensor.unsqueeze(0)

    # Set pad token if unset (and there are conditions to do so)
    if pad_token_tensor is None and eos_token_tensor is not None:
        pad_token_tensor = eos_token_tensor[0]
        logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")

    # Update generation config with the updated special tokens tensors
    # NOTE: this must be written into a different attribute name than the one holding the original special tokens
    # (in their non-tensor form), in order to enable end-to-end compilation. See
    # https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations
    generation_config._bos_token_tensor = bos_token_tensor
    generation_config._eos_token_tensor = eos_token_tensor
    generation_config._pad_token_tensor = pad_token_tensor
    generation_config._mask_token_tensor = mask_token_tensor


from dataclasses import dataclass  # noqa: E402  (needed only by the output container below)


# `ModelOutput` subclasses must be dataclasses; without the decorator transformers
# raises TypeError as soon as the class is instantiated.
@dataclass
class DreamModelOutput(ModelOutput):
    """Container for generated sequences and, optionally, the per-step history."""

    sequences: torch.LongTensor = None
    history: Optional[Tuple[torch.FloatTensor]] = None


def _prepare_generated_length(
    self,
    generation_config,
    has_default_max_length,
    input_ids_length,
):
    """Resolve ``max_length`` from ``max_new_tokens`` / defaults and return the config.

    ``max_new_tokens`` takes precedence when set. Otherwise, if the config still
    carries the default ``max_length``, the prompt length is added to it and the
    result is capped at ``self.config.max_position_embeddings`` when that exists.
    """

    if generation_config.max_new_tokens is not None:
        if not has_default_max_length and generation_config.max_length is not None:
            logger.warning(
                f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                "Please refer to the documentation for more information. "
                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
            )
        generation_config.max_length = generation_config.max_new_tokens + input_ids_length

    elif has_default_max_length:
        if generation_config.max_length == DreamGenerationConfig().max_length:
            generation_config.max_length = generation_config.max_length + input_ids_length
            max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
            if max_position_embeddings is not None:
                generation_config.max_length = min(generation_config.max_length, max_position_embeddings)

    return generation_config
def _prepare_generation_config(
    self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict
) -> DreamGenerationConfig:
    """Build the effective generation config for one ``generate`` call.

    When ``generation_config`` is ``None`` it is derived from ``self.config``.
    Outside of torch.compile tracing the config is deep-copied and updated with
    ``kwargs``; for a caller-supplied config, special tokens left as ``None`` fall
    back to the values on ``self.generation_config``.
    """
    # priority: `generation_config` argument > `model.generation_config` (the default generation config)
    using_model_generation_config = False
    if generation_config is None:
        generation_config = DreamGenerationConfig.from_model_config(self.config)
        using_model_generation_config = True

    # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
    # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an
    # exception will be raised in `_validate_model_kwargs`
    if not is_torchdynamo_compiling():
        generation_config = copy.deepcopy(generation_config)
        generation_config.update(**kwargs)
        # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
        if not using_model_generation_config:
            for token_name in (
                "bos_token_id",
                "eos_token_id",
                "pad_token_id",
                "mask_token_id",
                "expand_token_id",
            ):
                if getattr(generation_config, token_name) is None:
                    # The model-level config may not define the Dream-specific ids, so a
                    # missing attribute is treated as "no default" instead of raising.
                    setattr(generation_config, token_name, getattr(self.generation_config, token_name, None))

    return generation_config


# NOTE: a second, byte-identical definition of `_prepare_special_tokens` used to follow
# here and merely shadowed the earlier one; it was dropped so the helper exists once.
See + # https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations + generation_config._bos_token_tensor = bos_token_tensor + generation_config._eos_token_tensor = eos_token_tensor + generation_config._pad_token_tensor = pad_token_tensor + generation_config._mask_token_tensor = mask_token_tensor + +def _sample( + self, + start_time, + qpc_session, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.LongTensor], + generation_config: DreamGenerationConfig, + generation_tokens_hook_func, + generation_logits_hook_func + ) -> Union[DreamModelOutput, torch.LongTensor]: + # init values + output_history = generation_config.output_history + return_dict_in_generate = generation_config.return_dict_in_generate + max_length = generation_config.max_length + mask_token_id = generation_config.mask_token_id + eos_token_id = generation_config.eos_token_id + pad_token_id = generation_config.pad_token_id + steps = generation_config.steps + + compile_length = generation_config.compile_length + + histories = [] if (return_dict_in_generate and output_history) else None + + # pad input_ids to max_length + prompt_length = input_ids.shape[1] + x = F.pad(input_ids, (0, compile_length - input_ids.shape[1]), value=mask_token_id) + # attention_mask = torch.ones_like(x) + print('Number of steps are ',steps) + + if attention_mask is not None: + attention_mask = torch.logical_and( + attention_mask.unsqueeze(1).unsqueeze(-2), + attention_mask.unsqueeze(1).unsqueeze(-1), + ) + else: + tok_idx = None + attention_mask = "full" + + x = generation_tokens_hook_func(None, x, None) + x = x.numpy() + attention_mask = attention_mask.numpy().astype(np.int64) + for i in range(steps): + start_time_iter = time.perf_counter() + inputs = dict(input_ids=x, attention_mask = attention_mask) + x = qpc_session.run(inputs)['logits'] + x = torch.tensor(x) + x = generation_tokens_hook_func(None, x, None) + x = x.numpy() + + mask_token_mask = (x == mask_token_id).any() + + eos_mask = (x == 
eos_token_id) + eos_positions = np.where(eos_mask) + if len(eos_positions[0]) > 0: + first_eos_idx = eos_positions[-1][0] + x[:, first_eos_idx + 1:] = eos_token_id + if not mask_token_mask: + break + end_time_iter = time.perf_counter() + total_time_network_sample = end_time_iter - start_time_iter + print(f'Time only network, gradio at {i} iteration is {total_time_network_sample:.6f} for input_ids of length {x.shape[1]} ') + x = torch.tensor(x) + + if return_dict_in_generate: + return {'sequences': x, + 'history': histories} + else: + return {'sequences': x} \ No newline at end of file