From 340c35e26958e65097f7bf427a31b9567e2e35ff Mon Sep 17 00:00:00 2001 From: Dhiraj Kumar Sah Date: Wed, 15 Apr 2026 11:24:13 +0530 Subject: [PATCH 1/8] Added support for repeat_kv_heads as a pytorch transform for LLM and VLMs. Based on PR #625. Addressed most of the comments made on the previous PR. Repeat check is done on a subset of models during CI, primarily due to difference in configs of such models. Signed-off-by: Dhiraj Kumar Sah --- QEfficient/base/modeling_qeff.py | 1 + .../transformers/models/modeling_auto.py | 13 + .../transformers/models/pytorch_transforms.py | 180 +++- QEfficient/utils/test_utils.py | 31 + tests/configs/causal_model_configs.json | 19 +- .../test_image_text_to_text_models.py | 55 ++ .../models/test_causal_lm_models.py | 810 ++++++++++++++++++ 7 files changed, 1107 insertions(+), 2 deletions(-) create mode 100644 tests/transformers/models/test_causal_lm_models.py diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 17b87afd1..c61248ffd 100755 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -76,6 +76,7 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None: self.model = model self.config = model.config self.hash_params = create_model_params(self, **kwargs) + self.hash_params["num_kv_heads_repeat"] = kwargs.get("num_kv_heads_repeat", 1) self.onnx_path: Optional[str] = None self.qpc_path: Optional[str] = None self.qpc_session: Optional[QAICInferenceSession] = None diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 57689ede6..5cf45df1a 100755 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -55,6 +55,7 @@ PrefillOnlyChunkedTransform, PrefillOnlyExternalModuleMapperTransform, PrefillOnlyTransform, + ReplicateKVHeadTransform, RevertPrefillKeepAttentionTransform, RevertPrefillOnlyExternalModuleMapperTransform, RevertPrefillOnlyTransform, @@ -1290,6 +1291,7 @@ def __init__( self.ccl_enabled = qaic_config.get("ccl_enabled", False) self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None self.input_shapes, self.output_names = None, None + self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs) # ---Sampling--- # Note: SamplerTransform should be applied after all other transforms # are done. The role of the sampler is to just add nodes at the output of the @@ -1326,6 +1328,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Option kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) _resolve_torch_dtype(kwargs) + num_kv_heads_repeat = kwargs.pop("num_kv_heads_repeat", 1) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {}) @@ -1334,6 +1337,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Option model, pretrained_model_name_or_path=pretrained_model_name_or_path, qaic_config=qaic_config, + num_kv_heads_repeat=num_kv_heads_repeat, **kwargs, ) @@ -2178,6 +2182,7 @@ def __init__( self.model.config.text_config.use_cache = True else: self.model.config.use_cache = True + self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs) self.hash_params["qeff_auto_class"] = self.__class__.__name__ self.ccl_enabled = False if qaic_config: @@ -2228,6 +2233,7 @@ def from_pretrained( config._attn_implementation = "eager" config.vision_config.use_flash_attn = "false" _resolve_torch_dtype(kwargs) + num_kv_heads_repeat = kwargs.pop("num_kv_heads_repeat", 1) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, config, *args, **kwargs) kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {}) @@ -2236,6 +2242,7 @@ def from_pretrained( model, pretrained_model_name_or_path=pretrained_model_name_or_path, qaic_config=qaic_config, + num_kv_heads_repeat=num_kv_heads_repeat, **kwargs, ) @@ -2875,6 +2882,7 @@ def from_pretrained( kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) _resolve_torch_dtype(kwargs) + num_kv_heads_repeat = kwargs.pop("num_kv_heads_repeat", 1) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {}) @@ -2885,6 +2893,7 @@ def from_pretrained( continuous_batching=continuous_batching, pretrained_model_name_or_path=pretrained_model_name_or_path, qaic_config=qaic_config, + num_kv_heads_repeat=num_kv_heads_repeat, **kwargs, ) @@ -3044,6 +3053,7 @@ def __init__( setattr(self.model, "mla_absorption", mla_absorption) self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None self.hash_params["max_seq_len_cached"] = max_seq_len_cached + self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs) # ---Sampling--- # Note: SamplerTransform should be applied after all other transforms @@ -3129,6 +3139,7 @@ def from_pretrained( kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) _resolve_torch_dtype(kwargs) + num_kv_heads_repeat = kwargs.pop("num_kv_heads_repeat", 1) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) if qaic_config is not None: qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path @@ -3142,6 +3153,7 @@ def from_pretrained( pretrained_model_name_or_path=pretrained_model_name_or_path, qaic_config=qaic_config, continuous_batching=continuous_batching, + num_kv_heads_repeat=num_kv_heads_repeat, **kwargs, ) return cls( @@ -3150,6 +3162,7 @@ def from_pretrained( qaic_config=qaic_config, pretrained_model_name_or_path=pretrained_model_name_or_path, max_seq_len_cached=max_seq_len_cached, + num_kv_heads_repeat=num_kv_heads_repeat, **kwargs, ) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index b2b447a78..712d28307 100755 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -297,8 +297,13 @@ ) from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaModel -from QEfficient.base.pytorch_transforms import ExternalModuleMapperTransform, ModuleMappingTransform +from QEfficient.base.pytorch_transforms import ( + ExternalModuleMapperTransform, + ModuleMappingTransform, + ModuleMutatorTransform, +) from QEfficient.customop import CustomRMSNormAIC, GemmaCustomRMSNormAIC +from QEfficient.customop.matmulnbits import QuantLinearORT from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function from QEfficient.transformers.models.bert.modeling_bert import ( QEffBertModel, @@ -627,6 +632,9 @@ QEffWhisperPositionalEmbedding, ) from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry +from QEfficient.transformers.quantizers.awq import WQLinear_GEMM +from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ +from QEfficient.transformers.quantizers.quantizer_compressed_tensors import FP8DeQuantLinear from QEfficient.transformers.sampler.sampler import sampler_forward from QEfficient.transformers.spd.spd_transform_forward import tlm_forward from QEfficient.utils.logging_utils import logger @@ -962,6 +970,176 @@ class RevertPrefillOnlyTransform(ModuleMappingTransform): **{v: k for k, v in PrefillOnlyChunkedTransform._module_mapping.items()}, } +class ReplicateKVHeadTransform(ModuleMutatorTransform): + """ + Replicates KV heads in attention modules to match the number of KV heads in the target model. + This transform is used when the source model has fewer KV heads than required in target model. + """ + + _module_mapping = { + QEffCodeGenForCausalLM, + QEffFalconForCausalLM, + QEffGPT2LMHeadModel, + QEffGPTJForCausalLM, + QEffLlamaForCausalLM, + QEffLlama4ForConditionalGeneration, + QEffLlavaForConditionalGeneration, + QEffLlavaNextForConditionalGeneration, + QEffGemmaForCausalLM, + QEffGemma2ForCausalLM, + QEffGemma3ForConditionalGeneration, + QEffGraniteForCausalLM, + QEffGraniteMoeForCausalLM, + QEffMllamaForConditionalGeneration, + QEffMistralForCausalLM, + QEffMistral3ForConditionalGeneration, + QEffMptForCausalLM, + QEffPhiForCausalLM, + QEffPhi3ForCausalLM, + QEffQwen2ForCausalLM, + QEffQwen3ForCausalLM, + QEffQwen_2_5_vl_ForConditionalGeneration, + QEffQwen3MoeForCausalLM, + QEffQwen3VLForConditionalGeneration, + QEffQwen3VLMoeForConditionalGeneration, + QEffStarcoder2ForCausalLM, + QEffGPTBigCodeForCausalLM, + QEffOlmo2ForCausalLM, + } + _module_string_mapping = { + "InternVLChatModel", + "MolmoForCausalLM," + } + + def _duplicate_weights_for_linear_layer( + layer: nn.Module, orig_kv_heads: int, repeat: int, head_dim: int, hidden_size: int + ): + new_kv_heads = repeat * orig_kv_heads + if isinstance(layer, (WQLinear_GEMM, QuantLinearGPTQ, QuantLinearORT)): + if head_dim % 8 != 0: + raise ValueError( + f"the value head_dim={head_dim} is not divisible by 8 which is according to the assumption that model is 4-bit quantized." + ) + if hidden_size % layer.group_size != 0: + raise ValueError( + f"The value of hidden_size={hidden_size} is not divisible by k_proj.group_size={layer.group_size}" + ) + + # Duplication of quantized weights + layer.qweight.data = torch.repeat_interleave( + layer.qweight.data.view(hidden_size, orig_kv_heads, head_dim // 8), repeat, 1 + ).view(hidden_size, (new_kv_heads * head_dim) // 8) + # Duplication of quantized zero points + layer.qzeros.data = torch.repeat_interleave( + layer.qzeros.data.view(hidden_size // layer.group_size, orig_kv_heads, head_dim // 8), + repeat, + 1, + ).view(hidden_size // layer.group_size, (new_kv_heads * head_dim) // 8) + # Duplication of quantization scales + layer.scales.data = torch.repeat_interleave( + layer.scales.data.view(hidden_size // layer.group_size, orig_kv_heads, head_dim), + repeat, + 1, + ).view(hidden_size // layer.group_size, new_kv_heads * head_dim) + layer.out_features = layer.out_features * repeat + + elif isinstance(layer, FP8DeQuantLinear): + layer.weight.data = torch.repeat_interleave( + layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0 + ).view(new_kv_heads * head_dim, hidden_size) + layer.weight_scale.data = torch.repeat_interleave( + layer.weight_scale.data.view(orig_kv_heads, head_dim), repeat, 0 + ).view(new_kv_heads * head_dim, -1) + + else: + layer.weight.data = torch.repeat_interleave( + layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0 + ).view(new_kv_heads * head_dim, hidden_size) + if layer.bias is not None: + layer.bias.data = torch.repeat_interleave( + layer.bias.data.view(orig_kv_heads, head_dim), repeat, 0 + ).view(new_kv_heads * head_dim) + + def _get_text_model(model): + """ + Determine and return the appropriate text_model from a given model object. + """ + # Check for VLMs + if hasattr(model, "language_model"): + if hasattr(model.language_model, "model"): + return model.language_model.model + else: + return model.language_model + if hasattr(model, "model"): + return model.model + if hasattr(model, "transformer"): + return model.transformer + if hasattr(model, "llm"): + return model.llm + if hasattr(model, "backbone"): + return model.backbone + + raise AttributeError("No suitable text model found in the provided model.") + + @classmethod + def mutate(cls, original_module: nn.Module, parent_module: nn.Module, n_repeat: int) -> nn.Module: + """ + Mutates the matched top-level model module in-place by replicating its KV heads. + + Args: + original_module: The matched top-level model module to mutate. + parent_module: The parent module (unused, present for interface compatibility). + n_repeat: The number of times to repeat the KV heads. + + Returns: + The mutated module (same object, modified in-place). + """ + text_model = cls._get_text_model(original_module) + orig_kv_heads = text_model.config.num_key_value_heads + new_kv_heads = n_repeat * orig_kv_heads + text_model.config.orig_kv_heads = orig_kv_heads + text_model.config.num_key_value_heads = new_kv_heads + + num_attention_heads = text_model.config.num_attention_heads + hidden_size = text_model.config.hidden_size + + logger.warning(f"Original KV heads: {orig_kv_heads}") + logger.warning(f"Modified KV heads: {new_kv_heads}") + for block in text_model.layers: + attn = getattr(block, "cross_attn", getattr(block, "self_attn", None)) + attn.num_key_value_heads = new_kv_heads + attn.num_key_value_groups = num_attention_heads // new_kv_heads + + cls._duplicate_weights_for_linear_layer( + attn.k_proj, orig_kv_heads, n_repeat, attn.head_dim, hidden_size + ) + cls._duplicate_weights_for_linear_layer( + attn.v_proj, orig_kv_heads, n_repeat, attn.head_dim, hidden_size + ) + + return original_module + + @classmethod + def apply(cls, model: nn.Module, **kwargs) -> Tuple[nn.Module, bool]: + """ + Replicates KV heads in attention modules based on provided multiplier. + + Args: + model: The model to apply the transform to. + kwargs: Additional arguments for the transformation. Includes: + - num_kv_heads_repeat: The number of times to repeat the KV heads. + """ + n_repeat = kwargs.pop("num_kv_heads_repeat", 1) + transformed = False + if n_repeat is not None and n_repeat > 1: + if (model.__class__ in cls._module_mapping) or (model.__class__.__name__ in cls._module_string_mapping): + cls.mutate(model, None, n_repeat) + transformed = True + else: + raise NotImplementedError( + f"Model class {model.__class__.__name__} is not supported for KV head replication." + ) + return model, transformed class ReplicateKVHeadTransform: """ diff --git a/QEfficient/utils/test_utils.py b/QEfficient/utils/test_utils.py index 131ff59e2..77cfe7178 100644 --- a/QEfficient/utils/test_utils.py +++ b/QEfficient/utils/test_utils.py @@ -288,6 +288,12 @@ def load_qeff_model_with_sampler( return qeff_model +def get_text_config(config): + if hasattr(config, "text_config"): + return config.text_config + elif hasattr(config, "llm_config"): + return config.llm_config + return config # Processor class for InternVL models class InternProcessor: @@ -492,6 +498,31 @@ class ModelConfig: "Qwen/Qwen3.6-35B-A3B", } + REPEAT_KV_TEST_MODELS = { + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "ibm-granite/granite-3.1-1b-a400m-base", + "Qwen/Qwen2-0.5B", + "bigcode/starcoder2-3b", + # "mistralai/Mixtral-8x7B-Instruct-v0.1", + "meta-llama/Llama-3.2-1B", + # "unsloth/gemma-2b", + # "unsloth/gemma-2-2b", + # "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", + "TheBloke/Llama-2-7B-GPTQ", + "neuralmagic/Llama-3.2-3B-Instruct-FP8", + "ibm-granite/granite-3.1-2b-instruct", + "llava-hf/llava-1.5-7b-hf", + "meta-llama/Llama-4-Scout-17B-16E-Instruct", + # "google/gemma-3-4b-it", + "allenai/Molmo-7B-D-0924", + "mistralai/Mistral-Small-3.1-24B-Instruct-2503", + "Qwen/Qwen2.5-VL-3B-Instruct", + "Qwen/Qwen3-VL-2B-Instruct", + "Qwen/Qwen3-VL-30B-A3B-Instruct", + "allenai/Molmo-7B-D-0924", + "OpenGVLab/InternVL2_5-1B", + } + EXTERNAL_MODELS = { "hpcai-tech/grok-1": { "pytorch_hf_tokens_custom_case": [ diff --git a/tests/configs/causal_model_configs.json b/tests/configs/causal_model_configs.json index 93f4e7ae2..9c7160522 100644 --- a/tests/configs/causal_model_configs.json +++ b/tests/configs/causal_model_configs.json @@ -325,6 +325,19 @@ "num_key_value_heads": 1 } }, + { + "model_name": "hpcai-tech/grok-1", + "model_type": null, + "additional_params": { + "max_position_embeddings": 128, + "num_hidden_layers": 1, + "num_attention_heads": 2, + "hidden_size": 64, + "intermediate_size": 256, + "vocab_size": 131072, + "num_key_value_heads": 1 + } + }, { "model_name": "Snowflake/Llama-3.1-SwiftKV-8B-Instruct", "model_type": null, @@ -342,6 +355,10 @@ "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, + "rope_type": "llama3""factor": 8.0, + "high_freq_factor": 4.0, + "low_freq_factor": 1.0, + "original_max_position_embeddings": 8192, "rope_type": "llama3" } } @@ -720,4 +737,4 @@ } } ] -} +} \ No newline at end of file diff --git a/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py b/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py index 9b9e662e5..0bd26efe8 100644 --- a/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py +++ b/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py @@ -30,6 +30,7 @@ from QEfficient.utils.test_utils import ( InternProcessor, ModelConfig, + get_text_config, load_vlm_model, load_vlm_model_from_config, set_num_layers_vlm, @@ -56,6 +57,8 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( enable_qnn: Optional[bool] = False, qnn_config: Optional[str] = None, config: Optional[AutoConfig] = None, + num_kv_heads_repeat: Optional[int] = 1, + test_kv_replicate: Optional[bool] = None, torch_dtype: Optional[torch.dtype] = torch.float32, compare_results: Optional[bool] = False, ): @@ -75,6 +78,9 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( model_name, trust_remote_code=True, padding=model_name not in ModelConfig.MOLMO_MODELS ) config = set_num_layers_vlm(config, n_layer=n_layer) + if test_kv_replicate: + text_config = get_text_config(config) + num_kv_heads_repeat = text_config.num_attention_heads // text_config.num_key_value_heads if hasattr(config, "model_type") and config.model_type in ["gemma3"]: config.text_config._sliding_window_pattern = 2 config.text_config.layer_types = ["sliding_attention", "full_attention"] @@ -93,6 +99,7 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( kv_offload=kv_offload, config=config, torch_dtype=torch_dtype, + num_kv_heads_repeat=num_kv_heads_repeat, ) else: model_hf = load_vlm_model(config) @@ -101,14 +108,19 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( kv_offload=kv_offload, config=config, torch_dtype=torch_dtype, + num_kv_heads_repeat=num_kv_heads_repeat, ) else: + if test_kv_replicate: + text_config = get_text_config(config) + num_kv_heads_repeat = text_config.num_attention_heads // text_config.num_key_value_heads model_hf = load_vlm_model_from_config(config) qeff_model = QEFFAutoModelForImageTextToText( copy.deepcopy(model_hf), kv_offload=kv_offload, config=model_hf.config, torch_dtype=torch_dtype, + num_kv_heads_repeat=num_kv_heads_repeat, ) compile_kwargs = { "num_devices": num_devices, @@ -336,6 +348,49 @@ def test_dummy_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(model_name, kv_o manual_cleanup=manual_cleanup, ) +@pytest.mark.on_qaic +@pytest.mark.multimodal +@pytest.mark.regular +@pytest.mark.parametrize("model_name", test_mm_models) +@pytest.mark.parametrize("kv_offload", [True, False]) +def test_custom_replicate_kv_pytorch_vs_ai100( + model_name, kv_offload +): + """ + Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + torch.manual_seed(42) + if model_name in ModelConfig.SKIPPED_MODELS: + pytest.skip("Test skipped for this model due to some issues.") + if model_name in ModelConfig.DUAL_QPC_MODELS and not kv_offload: + pytest.skip("These models require kv_offload=True for testing.") + + img_size = model_config_dict[model_name].get("img_size") + + hf_config = None + model_type = model_config_dict[model_name].get("model_type", None) + if model_name in ModelConfig.STANDARD_VLM_MODELS and model_type is not None: + custom_config = model_config_dict[model_name].get("additional_params", {}) + hf_config = AutoConfig.for_model(model_type, trust_remote_code=True, **custom_config) + hf_config.name_or_path = model_name + if model_name in ModelConfig.REPEAT_KV_TEST_MODELS: + check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, + prompt_len=model_config_dict[model_name]["prompt_len"], + ctx_len=model_config_dict[model_name]["ctx_len"], + max_gen_len=NEW_GENERATION_TOKENS, + img_size=img_size, + img_url=model_config_dict[model_name]["img_url"], + query=model_config_dict[model_name]["text_prompt"], + n_layer=model_config_dict[model_name]["num_layers"], + batch_size=model_config_dict[model_name]["batch_size"], + kv_offload=kv_offload, + test_kv_replicate=True, + ) + else: + pytest.skip(f"Skipping replicate KV test for {model_name} as it's not in REPEAT_KV_TEST_MODELS") ################################ QNN Tests ################################ diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py new file mode 100644 index 000000000..0fd406125 --- /dev/null +++ b/tests/transformers/models/test_causal_lm_models.py @@ -0,0 +1,810 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import copy +import json +import os +from typing import Optional + +import numpy as np +import pytest +import torch +from transformers import AutoConfig, AutoModelForCausalLM + +from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM +from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers +from QEfficient.utils import hf_download +from QEfficient.utils._utils import create_json, load_hf_tokenizer +from QEfficient.utils.constants import Constants, QnnConstants +from QEfficient.utils.device_utils import get_available_device_id +from QEfficient.utils.run_utils import ApiRunner +from QEfficient.utils.test_utils import ModelConfig + +CONFIG_PATH = "tests/configs/causal_model_configs.json" + +with open(CONFIG_PATH, "r") as f: + config_data = json.load(f) + causal_lm_models = config_data["causal_lm_models"] + causal_lm_fp16_models = config_data["causal_lm_fp16_test_models"] + spd_models = config_data["spd_causal_lm_models"] + qnn_models = config_data["qnn_causal_lm_models"] + blockedKV_models = config_data["blockedKV_causal_lm_models"] + + +# Create a list of model names for parameterization +test_models_causal = [model["model_name"] for model in causal_lm_models] +test_fp16_causal_models = [model["model_name"] for model in causal_lm_fp16_models] +test_models_spd = [model["model_name"] for model in spd_models] +test_models_qnn = [model["model_name"] for model in qnn_models] +test_models_blockedKV = [model["model_name"] for model in blockedKV_models] + +all_models = causal_lm_models + causal_lm_fp16_models + +# Create a dictionary mapping model names to their configs +model_config_dict = {model["model_name"]: model for model in all_models} + + +def get_hf_config_from_custom_config(model_name): + """ + Function to get HF config from custom config file + -------- + :model_name: str + + :return config + """ + custom_config = model_config_dict[model_name] + + hf_config = AutoConfig.from_pretrained( + model_name, + trust_remote_code=model_name in ModelConfig.EXTERNAL_MODELS, + **custom_config.get("additional_params", {}), + ) + return hf_config + + +def get_custom_n_layers(model_name): + """ + Function to set number layers of the variuos types of models such as swiftkv models and others + -------- + + :model_name: str + + :return n_layer + """ + if model_name in {"microsoft/Phi-3-mini-4k-instruct", "neuralmagic/Qwen2-0.5B-Instruct-FP8", "openai/gpt-oss-20b"}: + return 2 + elif model_name in ModelConfig.SWIFTKV_MODELS: + return None + return 1 + + +def load_causal_lm_model(model_name, n_layer=1, config=None, dtype=torch.float32): + """ + Function to load model from huggingface and transform to KV model + -------- + + :model_name: str + :n_layer: int + :config: Autoconfig + + :return model_hf, params + """ + torch.manual_seed(42) + model_path = hf_download( + repo_id=model_name, + ignore_patterns=["*.onnx", "*.ot", "*.md", "*.tflite", "*.pdf", "*.h5", "*.msgpack"], + ) + if config is None: # If custom config is not provided, load the model config from Hugging Face + if n_layer is not None: + model_hf = AutoModelForCausalLM.from_pretrained( + model_path, + use_cache=True, + num_hidden_layers=n_layer, + attn_implementation="eager", + low_cpu_mem_usage=False, + torch_dtype=dtype, + trust_remote_code=model_name in ModelConfig.EXTERNAL_MODELS, + ) + else: + # If n_layer is not specified, load the model without specifying the number of layers + model_hf = AutoModelForCausalLM.from_pretrained( + model_path, + use_cache=True, + attn_implementation="eager", + low_cpu_mem_usage=False, + torch_dtype=dtype, + trust_remote_code=model_name in ModelConfig.EXTERNAL_MODELS, + ) + else: # If custom config is provided, load the model using the config + model_hf = AutoModelForCausalLM.from_config( + config, + attn_implementation="eager", + torch_dtype=dtype, + trust_remote_code=model_name in ModelConfig.EXTERNAL_MODELS, + ) + # Convert to intended dtype + try: + model_hf = model_hf.to(dtype) + model_hf.config.torch_dtype = dtype + except ValueError: + pass # fully ignore + params = sum(p.numel() for p in model_hf.parameters()) + model_hf.eval() + return model_hf, params + + +def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name: str, + prompt_len: int = Constants.PROMPT_LEN, + ctx_len: int = Constants.CTX_LEN, + n_layer: int = 1, + num_speculative_tokens: Optional[int] = None, + prefill_only: Optional[bool] = None, + enable_qnn: Optional[bool] = False, + qnn_config: Optional[str] = None, + config: Optional[AutoConfig] = None, + pytorch_hf_tokens: Optional[list] = None, + qaic_config: Optional[dict] = None, + retain_full_kv: Optional[bool] = None, +): + """ + Validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + :prompt_len (int): Prompt length for the model to compile. + :ctx_len (int): Maximum context length to compile the model. + :n_layers (int): Number of layers for the Model. + """ + replace_transformers_quantizers() + if config is None: + n_layer = get_custom_n_layers(model_name) + model_hf, _ = load_causal_lm_model(model_name, n_layer=n_layer) + else: + model_hf, _ = load_causal_lm_model(model_name, config=config) + + tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name) + config = model_hf.config + batch_size = len(Constants.INPUT_STR) + api_runner = ApiRunner( + batch_size, + tokenizer, + config, + Constants.INPUT_STR, + Constants.PROMPT_LEN, + Constants.CTX_LEN, + ) + if model_name not in ModelConfig.SWIFTKV_MODELS and model_name not in ModelConfig.EXTERNAL_MODELS: + pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf) + + is_tlm = False if num_speculative_tokens is None else True + qeff_model = QEFFAutoModelForCausalLM( + copy.deepcopy(model_hf), is_tlm=is_tlm, pretrained_model_name_or_path=model_name, qaic_config=qaic_config + ) + qeff_model.transform( + ctx_len=ctx_len, seq_len=prompt_len, batch_size=batch_size, num_devices=1, qaic_config=qaic_config + ) + pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) + + if model_name not in ModelConfig.SWIFTKV_MODELS and model_name not in ModelConfig.EXTERNAL_MODELS: + assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), ( + "Tokens don't match for HF PyTorch model output and KV PyTorch model output" + ) + onnx_model_path = qeff_model.export() + ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=is_tlm) + gen_len = ort_tokens.shape[-1] + + assert (pytorch_kv_tokens == ort_tokens).all(), "Tokens don't match for ONNXRT output and PyTorch output." + + qpc_path = qeff_model.compile( + prefill_seq_len=prompt_len, + ctx_len=ctx_len, + num_cores=14, + mxfp6=False, + aic_enable_depth_first=False, + num_speculative_tokens=num_speculative_tokens, + prefill_only=prefill_only, + enable_qnn=enable_qnn, + qnn_config=qnn_config, + retain_full_kv=retain_full_kv, + ) + exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR) + cloud_ai_100_tokens = exec_info.generated_ids[0][ + :, :gen_len + ] # Because we always run for single input and single batch size + if prefill_only: + assert (ort_tokens[0][0] == cloud_ai_100_tokens[0][0]).all(), ( + "prefill run output tokens don't match for ONNXRT output and Cloud AI 100 output." + ) + else: + assert (ort_tokens == cloud_ai_100_tokens).all(), ( + "Tokens don't match for ONNXRT output and Cloud AI 100 output." + ) + assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json")) + if prefill_only is not None: + return + + # return # skip CB tests for now + + # testing for CB models + full_batch_size = 4 + fbs_prompts = Constants.INPUT_STR * 4 + api_runner = ApiRunner( + batch_size, + tokenizer, + config, + fbs_prompts, + Constants.PROMPT_LEN, + Constants.CTX_LEN, + full_batch_size, + ) + if model_name not in ModelConfig.SWIFTKV_MODELS and model_name not in ModelConfig.EXTERNAL_MODELS: + pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch_CB(model_hf) + pytorch_hf_tokens = np.vstack(pytorch_hf_tokens) + + qeff_model = QEFFAutoModelForCausalLM( + copy.deepcopy(model_hf), + continuous_batching=True, + is_tlm=is_tlm, + pretrained_model_name_or_path=model_name, + qaic_config=qaic_config, + ) + qeff_model.transform(ctx_len=ctx_len, seq_len=prompt_len, batch_size=full_batch_size, num_devices=1) + onnx_model_path = qeff_model.export() + + if not get_available_device_id(): + pytest.skip("No available devices to run model on Cloud AI 100") + + compiler_options = {} + if prompt_len == 1: + prefill_spec = { + "batch_size": batch_size, + "seq_len": 1, + "ctx_len": ctx_len, + "full_batch_size": full_batch_size, + "sliding_window": 128, + } + decode_spec = { + "batch_size": full_batch_size, + "seq_len": 1, + "ctx_len": ctx_len, + "full_batch_size": full_batch_size, + "sliding_window": 128, + } + compiler_options = {"specializations": [prefill_spec, decode_spec]} + + # TODO: add prefill_only tests + qpc_path = qeff_model.compile( + prefill_seq_len=prompt_len, + ctx_len=ctx_len, + num_cores=14, + mxfp6=False, + aic_enable_depth_first=False, + batch_size=batch_size, + full_batch_size=full_batch_size, + num_speculative_tokens=num_speculative_tokens, + enable_qnn=enable_qnn, + qnn_config=qnn_config, + retain_full_kv=retain_full_kv, + **compiler_options, + ) + exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts) + if model_name in ModelConfig.SWIFTKV_MODELS or model_name in ModelConfig.EXTERNAL_MODELS: + assert all( + [ + all(ort_token[:24] == cloud_token[:24]) + for ort_token, cloud_token in zip(ort_tokens, exec_info_fbs.generated_ids) + ] + ), "Tokens don't match for HF PyTorch model output and Cloud AI 100 output." + else: + assert all( + [ + all(pt_token[:24] == cloud_token[:24]) + for pt_token, cloud_token in zip(pytorch_hf_tokens, exec_info_fbs.generated_ids) + ] + ), "Tokens don't match for HF PyTorch model output and Cloud AI 100 output." + + assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json")) + +def check_kv_repeat_causal_lm_pytorch_vs_ai100( + model_name: str, + prompt_len: int = Constants.PROMPT_LEN, + ctx_len: int = Constants.CTX_LEN, + n_layer: int = 1, + num_kv_heads_repeat: int = 1, + config: Optional[AutoConfig] = None, + pytorch_hf_tokens: Optional[list] = None, +): + """ + Validate the PyTorch model and the Cloud AI 100 model with repeating original KV heads. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + :prompt_len (int): Prompt length for the model to compile. + :ctx_len (int): Maximum context length to compile the model. + :n_layers (int): Number of layers for the Model. + :num_kv_heads_repeat (int): Number of times to repeat KV heads. + """ + replace_transformers_quantizers() + if config is None: + n_layer = get_custom_n_layers(model_name) + model_hf, _ = load_causal_lm_model(model_name, n_layer=n_layer) + else: + model_hf, _ = load_causal_lm_model(model_name, config=config) + tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name) + config = model_hf.config + batch_size = len(Constants.INPUT_STR) + api_runner = ApiRunner( + batch_size, + tokenizer, + config, + Constants.INPUT_STR, + Constants.PROMPT_LEN, + Constants.CTX_LEN, + ) + if model_name not in ModelConfig.SWIFTKV_MODELS and model_name not in ModelConfig.EXTERNAL_MODELS: + pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf) + + # Generate num_kv_heads_repeat from config so that divisibility error doesn't occur. + num_kv_heads_repeat = getattr(config, "num_attention_heads", getattr(config, "n_head", 1)) // getattr(config, "num_key_value_heads", getattr(config, "n_head", 1)) + breakpoint() + qeff_model = QEFFAutoModelForCausalLM( + copy.deepcopy(model_hf), + pretrained_model_name_or_path=model_name, + num_kv_heads_repeat=num_kv_heads_repeat, + ) + if not get_available_device_id(): + pytest.skip("No available devices to run model on Cloud AI 100") + + qpc_path = qeff_model.compile( + prefill_seq_len=prompt_len, + ctx_len=ctx_len, + num_cores=14, + mxfp6=False, + aic_enable_depth_first=False, + ) + exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR) + gen_len = len(pytorch_hf_tokens) + cloud_ai_100_tokens = exec_info.generated_ids[0][:, :gen_len] + assert (pytorch_hf_tokens == cloud_ai_100_tokens).all(), ( + "Tokens don't match for Pytorch HF output and Cloud AI 100 output." + ) + assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json")) + +def check_causal_lm_pytorch_vs_kv_vs_ai100( + model_name: str, + prompt_len: int = Constants.PROMPT_LEN, + ctx_len: int = Constants.CTX_LEN, + n_layer: int = 1, + num_speculative_tokens: Optional[int] = None, + prefill_only: Optional[bool] = None, + enable_qnn: Optional[bool] = False, + qnn_config: Optional[str] = None, + config: Optional[AutoConfig] = None, + pytorch_hf_tokens: Optional[list] = None, + qaic_config: Optional[dict] = None, + retain_full_kv: Optional[bool] = None, + dtype: Optional[torch.dtype] = torch.float32, +): + """ + Validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + :prompt_len (int): Prompt length for the model to compile. + :ctx_len (int): Maximum context length to compile the model. + :n_layers (int): Number of layers for the Model. + """ + replace_transformers_quantizers() + if config is None: + n_layer = get_custom_n_layers(model_name) + model_hf, _ = load_causal_lm_model(model_name, n_layer=n_layer, dtype=dtype) + else: + model_hf, _ = load_causal_lm_model(model_name, config=config, dtype=dtype) + + tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name) + config = model_hf.config + batch_size = len(Constants.INPUT_STR) + api_runner = ApiRunner( + batch_size, + tokenizer, + config, + Constants.INPUT_STR, + Constants.PROMPT_LEN, + Constants.CTX_LEN, + dtype=dtype, + ) + + if model_name not in ModelConfig.SWIFTKV_MODELS and model_name not in ModelConfig.EXTERNAL_MODELS: + pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf) + + is_tlm = False if num_speculative_tokens is None else True + qeff_model = QEFFAutoModelForCausalLM( + copy.deepcopy(model_hf), is_tlm=is_tlm, pretrained_model_name_or_path=model_name, qaic_config=qaic_config + ) + + pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) + + if model_name not in ModelConfig.SWIFTKV_MODELS and model_name not in ModelConfig.EXTERNAL_MODELS: + assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), ( + "Tokens don't match for HF PyTorch model output and KV PyTorch model output" + ) + qeff_model.export() + qpc_path = qeff_model.compile( + prefill_seq_len=prompt_len, + ctx_len=ctx_len, + num_cores=16, + mxfp6=False, + aic_hw_version="ai100", + aic_enable_depth_first=False, + num_speculative_tokens=num_speculative_tokens, + prefill_only=prefill_only, + enable_qnn=enable_qnn, + qnn_config=qnn_config, + ) + exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR) + gen_len = pytorch_kv_tokens.shape[-1] + cloud_ai_100_tokens = exec_info.generated_ids[0][ + :, :gen_len + ] # Because we always run for single input and single batch size + if prefill_only: + assert (pytorch_hf_tokens[0][0] == cloud_ai_100_tokens[0][0]).all(), ( + "prefill run output tokens don't match for ONNXRT output and Cloud AI 100 output." + ) + else: + assert (pytorch_hf_tokens == cloud_ai_100_tokens).all(), ( + "Tokens don't match for ONNXRT output and Cloud AI 100 output." + ) + assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json")) + if prefill_only is not None: + return + + assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json")) + + +# FIXME: there should be a CB test here +@pytest.mark.parametrize("model_name", ["gpt2"], ids=lambda x: x) +def test_causal_lm_export_with_deprecated_api(model_name): + model, _ = load_causal_lm_model(model_name, n_layer=1) + tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name) + qeff_model = QEFFAutoModelForCausalLM(model, model_name=model_name, pretrained_model_name_or_path=model_name) + new_api_onnx_model_path = qeff_model.export() + + # Again loading model since the export moves model to meta device + model, _ = load_causal_lm_model(model_name, n_layer=1) + qeff_model = QEFFAutoModelForCausalLM(model, model_name=model_name, pretrained_model_name_or_path=model_name) + _, old_api_onnx_model_path = qualcomm_efficient_converter( + model_name=model_name, model_kv=qeff_model, tokenizer=tokenizer + ) + + api_runner = ApiRunner( + batch_size=1, + tokenizer=tokenizer, + config=model.config, + prompt=Constants.INPUT_STR, + prompt_len=Constants.PROMPT_LEN, + ctx_len=Constants.CTX_LEN, + ) + + new_api_ort_tokens = api_runner.run_kv_model_on_ort(new_api_onnx_model_path) + old_api_ort_tokens = api_runner.run_kv_model_on_ort(old_api_onnx_model_path) + + assert (new_api_ort_tokens == old_api_ort_tokens).all(), ( + "New API output does not match old API output for ONNX export function" + ) + + +@pytest.mark.on_qaic +@pytest.mark.regular +@pytest.mark.llm_model +@pytest.mark.parametrize("model_name", test_models_causal) +def test_custom_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): + """ + Test function to validate the dummy PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + + hf_config = get_hf_config_from_custom_config(model_name) + if model_name in ModelConfig.QUANTIZED_MODELS: + n_layer = get_custom_n_layers(model_name) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, n_layer=n_layer) + else: + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, config=hf_config) + + +@pytest.mark.on_qaic +@pytest.mark.regular +@pytest.mark.llm_model +@pytest.mark.parametrize("model_name", test_fp16_causal_models) +def test_custom_causal_lm_pytorch_vs_kv_vs_ai100(model_name): + """ + Test function to validate the dummy PyTorch model, the PyTorch model after KV changes, and the Cloud AI 100 model, without continuous batching for custom dtype. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + + hf_config = get_hf_config_from_custom_config(model_name) + if model_name in ModelConfig.QUANTIZED_MODELS: + n_layer = get_custom_n_layers(model_name) + check_causal_lm_pytorch_vs_kv_vs_ai100(model_name, n_layer=n_layer, dtype=torch.float16) + else: + check_causal_lm_pytorch_vs_kv_vs_ai100(model_name, config=hf_config, dtype=torch.float16) + + +@pytest.mark.nightly +@pytest.mark.on_qaic +@pytest.mark.llm_model +@pytest.mark.parametrize("model_name", test_models_causal) +def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): + """ + Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + n_layer = get_custom_n_layers(model_name) + + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer) + +@pytest.mark.nightly +@pytest.mark.on_qaic +@pytest.mark.parametrize("model_name", test_models_causal) +def test_check_kv_repeat_custom_causal_lm_pytorch_vs_ai100(model_name): + """ + Test function to validate the PyTorch model and the Cloud AI 100 model with repeating original KV heads. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + hf_config = get_hf_config_from_custom_config(model_name) + if model_name in ModelConfig.REPEAT_KV_TEST_MODELS: + if model_name in ModelConfig.QUANTIZED_MODELS: + n_layer = get_custom_n_layers(model_name) + check_kv_repeat_causal_lm_pytorch_vs_ai100(model_name, n_layer=n_layer) + else: + check_kv_repeat_causal_lm_pytorch_vs_ai100(model_name, config=hf_config) + else: + pytest.skip(f"Skipping {model_name} as it is not in REPEAT_KV_TEST_MODELS") + +@pytest.mark.nightly +@pytest.mark.on_qaic +@pytest.mark.llm_model +@pytest.mark.parametrize("model_name", test_fp16_causal_models) +def test_causal_lm_pytorch_vs_kv_vs_ai100(model_name): + """ + Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + n_layer = get_custom_n_layers(model_name) + + check_causal_lm_pytorch_vs_kv_vs_ai100(model_name=model_name, n_layer=n_layer, dtype=torch.float16) + + +@pytest.mark.nightly +@pytest.mark.on_qaic +@pytest.mark.parametrize("retain_full_kv", [True, False]) +def test_causal_lm_gpt_oss_pytorch_vs_kv_vs_ort_vs_ai100_pl1(retain_full_kv): + """ + Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + model_name = "openai/gpt-oss-20b" + n_layer = get_custom_n_layers(model_name) + prompt_len = 1 + + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, n_layer=n_layer, prompt_len=prompt_len, retain_full_kv=retain_full_kv + ) + + +@pytest.mark.on_qaic +@pytest.mark.regular +@pytest.mark.qnn +@pytest.mark.llm_model +@pytest.mark.parametrize("model_name", test_models_qnn) +def test_custom_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_qnn(model_name): + """ + QNN Setup + Test function to validate the dummy PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + hf_config = get_hf_config_from_custom_config(model_name) + qnn_config_json_path = os.path.join(os.getcwd(), "qnn_config.json") + create_json(qnn_config_json_path, QnnConstants.QNN_SAMPLE_CONFIG) + + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name, enable_qnn=True, qnn_config=qnn_config_json_path, config=hf_config + ) + + +@pytest.mark.nightly +@pytest.mark.on_qaic +@pytest.mark.qnn +@pytest.mark.llm_model +@pytest.mark.parametrize("model_name", test_models_qnn) +def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_qnn(model_name): + """ + QNN Setup + Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + qnn_config_json_path = os.path.join(os.getcwd(), "qnn_config.json") + create_json(qnn_config_json_path, QnnConstants.QNN_SAMPLE_CONFIG) + n_layer = get_custom_n_layers(model_name) + + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, n_layer=n_layer, enable_qnn=True, qnn_config=qnn_config_json_path + ) + + +@pytest.mark.regular +@pytest.mark.on_qaic +@pytest.mark.qnn +@pytest.mark.llm_model +@pytest.mark.parametrize("model_name", test_models_spd) +def test_custom_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): + """ + Test function to validate the dummy PyTorch model for speculative decoding, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + hf_config = get_hf_config_from_custom_config(model_name) + + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, + num_speculative_tokens=Constants.NUM_SPECULATIVE_TOKENS, + config=hf_config, + ) + + +@pytest.mark.nightly +@pytest.mark.on_qaic +@pytest.mark.llm_model +@pytest.mark.parametrize("model_name", test_models_spd) +def test_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): + """ + Test function to validate the PyTorch model for speculative decoding, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + n_layer = get_custom_n_layers(model_name) + + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, n_layer=n_layer, num_speculative_tokens=Constants.NUM_SPECULATIVE_TOKENS + ) + + +@pytest.mark.on_qaic +@pytest.mark.llm_model +def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(): + """ + Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model for a prompt length of 1, both with and without continuous batching. + """ + model_name = "gpt2" + prompt_len = 1 + + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, prompt_len=prompt_len) + + +@pytest.mark.on_qaic +@pytest.mark.qnn +@pytest.mark.llm_model +def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1_qnn(): + """ + Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model for a prompt length of 1, both with and without continuous batching. + """ + model_name = "gpt2" + prompt_len = 1 + + qnn_config_json_path = os.path.join(os.getcwd(), "qnn_config.json") + create_json(qnn_config_json_path, QnnConstants.QNN_SAMPLE_CONFIG) + + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, prompt_len=prompt_len, enable_qnn=True, qnn_config=qnn_config_json_path + ) + + +@pytest.mark.on_qaic +@pytest.mark.llm_model +def test_prefiill_only_pytorch_vs_kv_vs_ort_vs_ai100(): + model_name = "gpt2" + n_layer = 1 + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, n_layer=n_layer, prefill_only=True) + + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, n_layer=n_layer, prefill_only=False) + + +@pytest.mark.on_qaic +@pytest.mark.qnn +@pytest.mark.llm_model +def test_prefiill_only_pytorch_vs_kv_vs_ort_vs_ai100_qnn(): + model_name = "gpt2" + n_layer = 1 + + qnn_config_json_path = os.path.join(os.getcwd(), "qnn_config.json") + create_json(qnn_config_json_path, QnnConstants.QNN_SAMPLE_CONFIG) + + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name, n_layer=n_layer, prefill_only=True, enable_qnn=True, qnn_config=qnn_config_json_path + ) + + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name, n_layer=n_layer, prefill_only=False, enable_qnn=True, qnn_config=qnn_config_json_path + ) + + +@pytest.mark.on_qaic +@pytest.mark.llm_model +@pytest.mark.regular +@pytest.mark.parametrize("model_name", test_models_blockedKV) +def test_custom_causal_blockedKV_pytorch_vs_kv_vs_ort_vs_ai100(model_name): + """ + Test function to validate the PyTorch model for KV blocking, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + hf_config = get_hf_config_from_custom_config(model_name) + + NUM_KV_BLOCKS = 2 + + qaic_config = dict(enable_blocking=True, num_kv_blocks=NUM_KV_BLOCKS) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, config=hf_config, qaic_config=qaic_config) + + +@pytest.mark.on_qaic +@pytest.mark.parametrize("model_name", test_models_blockedKV) +def test_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name): + """ + Test function to validate the PyTorch model for HQKV blocking, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + n_layer = get_custom_n_layers(model_name) + + HEAD_BLOCK_SIZE = 8 + NUM_KV_BLOCKS = 2 + NUM_Q_BLOCKS = 2 + + # head blocking only + qaic_config = dict(enable_blocking=True, head_block_size=HEAD_BLOCK_SIZE) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer, qaic_config=qaic_config) + + # kv blocking only + qaic_config = dict(enable_blocking=True, num_kv_blocks=NUM_KV_BLOCKS) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer, qaic_config=qaic_config) + + # q block only + qaic_config = dict(enable_blocking=True, num_q_blocks=NUM_Q_BLOCKS) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer, qaic_config=qaic_config) + + # qkv blocking + qaic_config = dict(enable_blocking=True, num_kv_blocks=NUM_KV_BLOCKS, num_q_blocks=NUM_Q_BLOCKS) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer, qaic_config=qaic_config) + + # head qkv blocking + qaic_config = dict( + enable_blocking=True, + head_block_size=HEAD_BLOCK_SIZE, + num_kv_blocks=NUM_KV_BLOCKS, + num_q_blocks=NUM_Q_BLOCKS, + ) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer, qaic_config=qaic_config) + + +@pytest.mark.on_qaic +@pytest.mark.llm_model +@pytest.mark.parametrize("model_name", test_models_blockedKV) +def test_causal_nonBlockedKV_pytorch_vs_kv_vs_ort_vs_ai100(model_name): + """ + Test function to validate the PyTorch model for KV blocking, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + n_layer = get_custom_n_layers(model_name) + + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer) From 913828d0fa8dc00c96518a0e8a6f1ce1e6ee3895 Mon Sep 17 00:00:00 2001 From: Dhiraj Kumar Sah Date: Tue, 19 May 2026 15:25:59 +0530 Subject: [PATCH 2/8] Changes are made based on PR #625 and addressing the comments along with changes made for the new transforms. TODO: Check for the ONNX directory path name being different. Check if the list of classes for mapping covers all the models that we support. Signed-off-by: Dhiraj Kumar Sah --- .../transformers/models/modeling_auto.py | 7 +- .../transformers/models/pytorch_transforms.py | 106 +++++------------- QEfficient/utils/test_utils.py | 2 + .../test_image_text_to_text_models.py | 17 +-- .../models/test_causal_lm_models.py | 8 +- 5 files changed, 45 insertions(+), 95 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 5cf45df1a..aff0d9f07 100755 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -55,7 +55,6 @@ PrefillOnlyChunkedTransform, PrefillOnlyExternalModuleMapperTransform, PrefillOnlyTransform, - ReplicateKVHeadTransform, RevertPrefillKeepAttentionTransform, RevertPrefillOnlyExternalModuleMapperTransform, RevertPrefillOnlyTransform, @@ -1291,7 +1290,7 @@ def __init__( self.ccl_enabled = qaic_config.get("ccl_enabled", False) self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None self.input_shapes, self.output_names = None, None - self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs) + # self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs) # ---Sampling--- # Note: SamplerTransform should be applied after all other transforms # are done. The role of the sampler is to just add nodes at the output of the @@ -2182,7 +2181,7 @@ def __init__( self.model.config.text_config.use_cache = True else: self.model.config.use_cache = True - self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs) + # self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs) self.hash_params["qeff_auto_class"] = self.__class__.__name__ self.ccl_enabled = False if qaic_config: @@ -3053,7 +3052,7 @@ def __init__( setattr(self.model, "mla_absorption", mla_absorption) self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None self.hash_params["max_seq_len_cached"] = max_seq_len_cached - self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs) + # self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs) # ---Sampling--- # Note: SamplerTransform should be applied after all other transforms diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 712d28307..58a5ee7e4 100755 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -970,6 +970,7 @@ class RevertPrefillOnlyTransform(ModuleMappingTransform): **{v: k for k, v in PrefillOnlyChunkedTransform._module_mapping.items()}, } + class ReplicateKVHeadTransform(ModuleMutatorTransform): """ Replicates KV heads in attention modules to match the number of KV heads in the target model. @@ -1008,7 +1009,27 @@ class ReplicateKVHeadTransform(ModuleMutatorTransform): } _module_string_mapping = { "InternVLChatModel", - "MolmoForCausalLM," + "MolmoForCausalLM,", + "QEffGemma3DecoderWrapper", + "QEffGemma3EncoderWrapper", + "QEffInternDecoderWrapper", + "QEffInternEncoderWrapper", + "QEffLlama4DecoderWrapper", + "QEffLlama4EncoderWrapper", + "QEFFLlavaDecoderWrapper", + "QEFFLlavaEncoderWrapper", + "QEffLlavaNextDecoderWrapper", + "QEffLlavaNextEncoderWrapper", + "QEFFMistral3DecoderWrapper", + "QEFFMistral3EncoderWrapper", + "QEffMolmoDecoderWrapper", + "QEffMolmoEncoderWrapper", + "QEffQwen_2_5_vl_DecoderWrapper", + "QEffQwen_2_5_vl_EncoderWrapper", + "QEffQwen3VLDecoderWrapper", + "QEffQwen3VLEncoderWrapper", + "QEffQwen3VLDecoderWrapper", + "QEffQwen3VLEncoderWrapper", } def _duplicate_weights_for_linear_layer( @@ -1056,9 +1077,9 @@ def _duplicate_weights_for_linear_layer( layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0 ).view(new_kv_heads * head_dim, hidden_size) if layer.bias is not None: - layer.bias.data = torch.repeat_interleave( - layer.bias.data.view(orig_kv_heads, head_dim), repeat, 0 - ).view(new_kv_heads * head_dim) + layer.bias.data = torch.repeat_interleave(layer.bias.data.view(orig_kv_heads, head_dim), repeat, 0).view( + new_kv_heads * head_dim + ) def _get_text_model(model): """ @@ -1110,12 +1131,8 @@ def mutate(cls, original_module: nn.Module, parent_module: nn.Module, n_repeat: attn.num_key_value_heads = new_kv_heads attn.num_key_value_groups = num_attention_heads // new_kv_heads - cls._duplicate_weights_for_linear_layer( - attn.k_proj, orig_kv_heads, n_repeat, attn.head_dim, hidden_size - ) - cls._duplicate_weights_for_linear_layer( - attn.v_proj, orig_kv_heads, n_repeat, attn.head_dim, hidden_size - ) + cls._duplicate_weights_for_linear_layer(attn.k_proj, orig_kv_heads, n_repeat, attn.head_dim, hidden_size) + cls._duplicate_weights_for_linear_layer(attn.v_proj, orig_kv_heads, n_repeat, attn.head_dim, hidden_size) return original_module @@ -1141,75 +1158,6 @@ def apply(cls, model: nn.Module, **kwargs) -> Tuple[nn.Module, bool]: ) return model, transformed -class ReplicateKVHeadTransform: - """ - Replicates KV heads in attention modules to match the number of KV heads in the target model. - This transform is used when the source model has fewer KV heads than required in target model. - """ - - def _duplicate_weights_for_linear_layer( - layer: nn.Module, orig_kv_heads: int, repeat: int, dim: int, hidden_size: int - ): - new_kv_heads = repeat # for mla - - layer.weight.data = torch.repeat_interleave( - layer.weight.data.view(orig_kv_heads, dim, hidden_size), repeat, 0 - ).view(new_kv_heads * dim, hidden_size) - - if layer.bias is not None: - layer.bias.data = torch.repeat_interleave(layer.bias.data.view(orig_kv_heads, dim), repeat, 0).view( - new_kv_heads * dim - ) - - def _get_text_model(model): - """ - Determine and return the appropriate text_model from a given model object. - """ - # Check for VLMs - if hasattr(model, "language_model"): - if hasattr(model.language_model, "model"): - return model.language_model.model - else: - return model.language_model - # Check for CausalLMs - if hasattr(model, "model"): - return model.model - - raise AttributeError("No suitable text model found in the provided model.") - - @classmethod - def apply(cls, model: nn.Module, num_kv_heads_repeat: int = 1) -> nn.Module: - """ - Replicates KV heads in attention modules based on provided multiplier. - - Args: - model: The model to apply the transform to. - num_kv_heads_repeat: The number of times to repeat the KV heads. - """ - transformed = False - if num_kv_heads_repeat is not None and num_kv_heads_repeat > 1: - text_model = cls._get_text_model(model) - - orig_kv_heads = 1 # for mla #text_model.config.num_key_value_heads - new_kv_heads = num_kv_heads_repeat * orig_kv_heads - text_model.config.orig_kv_heads = orig_kv_heads - text_model.config.num_key_value_heads = new_kv_heads - - hidden_size = text_model.config.hidden_size - - logger.warning(f"Original KV heads: {orig_kv_heads}") - logger.warning(f"Modified KV heads: {new_kv_heads}") - transformed = True - for block in text_model.layers: - attn = getattr(block, "cross_attn", getattr(block, "self_attn", None)) - attn.num_key_value_heads = new_kv_heads - head_dim = attn.kv_lora_rank + attn.qk_rope_head_dim - - cls._duplicate_weights_for_linear_layer( - attn.kv_a_proj_with_mqa, orig_kv_heads, num_kv_heads_repeat, head_dim, hidden_size - ) - return model, transformed - class SpDTransform: """ diff --git a/QEfficient/utils/test_utils.py b/QEfficient/utils/test_utils.py index 77cfe7178..371202111 100644 --- a/QEfficient/utils/test_utils.py +++ b/QEfficient/utils/test_utils.py @@ -288,6 +288,7 @@ def load_qeff_model_with_sampler( return qeff_model + def get_text_config(config): if hasattr(config, "text_config"): return config.text_config @@ -295,6 +296,7 @@ def get_text_config(config): return config.llm_config return config + # Processor class for InternVL models class InternProcessor: """ diff --git a/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py b/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py index 0bd26efe8..d8e820e7a 100644 --- a/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py +++ b/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py @@ -348,13 +348,16 @@ def test_dummy_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(model_name, kv_o manual_cleanup=manual_cleanup, ) + @pytest.mark.on_qaic @pytest.mark.multimodal @pytest.mark.regular @pytest.mark.parametrize("model_name", test_mm_models) @pytest.mark.parametrize("kv_offload", [True, False]) def test_custom_replicate_kv_pytorch_vs_ai100( - model_name, kv_offload + model_name, + kv_offload, + manual_cleanup, ): """ Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, without continuous batching. @@ -367,8 +370,6 @@ def test_custom_replicate_kv_pytorch_vs_ai100( if model_name in ModelConfig.DUAL_QPC_MODELS and not kv_offload: pytest.skip("These models require kv_offload=True for testing.") - img_size = model_config_dict[model_name].get("img_size") - hf_config = None model_type = model_config_dict[model_name].get("model_type", None) if model_name in ModelConfig.STANDARD_VLM_MODELS and model_type is not None: @@ -378,20 +379,14 @@ def test_custom_replicate_kv_pytorch_vs_ai100( if model_name in ModelConfig.REPEAT_KV_TEST_MODELS: check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( model_name=model_name, - prompt_len=model_config_dict[model_name]["prompt_len"], - ctx_len=model_config_dict[model_name]["ctx_len"], - max_gen_len=NEW_GENERATION_TOKENS, - img_size=img_size, - img_url=model_config_dict[model_name]["img_url"], - query=model_config_dict[model_name]["text_prompt"], - n_layer=model_config_dict[model_name]["num_layers"], - batch_size=model_config_dict[model_name]["batch_size"], kv_offload=kv_offload, test_kv_replicate=True, + manual_cleanup=manual_cleanup, ) else: pytest.skip(f"Skipping replicate KV test for {model_name} as it's not in REPEAT_KV_TEST_MODELS") + ################################ QNN Tests ################################ diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index 0fd406125..643778c65 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -310,6 +310,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json")) + def check_kv_repeat_causal_lm_pytorch_vs_ai100( model_name: str, prompt_len: int = Constants.PROMPT_LEN, @@ -349,7 +350,9 @@ def check_kv_repeat_causal_lm_pytorch_vs_ai100( pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf) # Generate num_kv_heads_repeat from config so that divisibility error doesn't occur. - num_kv_heads_repeat = getattr(config, "num_attention_heads", getattr(config, "n_head", 1)) // getattr(config, "num_key_value_heads", getattr(config, "n_head", 1)) + num_kv_heads_repeat = getattr(config, "num_attention_heads", getattr(config, "n_head", 1)) // getattr( + config, "num_key_value_heads", getattr(config, "n_head", 1) + ) breakpoint() qeff_model = QEFFAutoModelForCausalLM( copy.deepcopy(model_hf), @@ -374,6 +377,7 @@ def check_kv_repeat_causal_lm_pytorch_vs_ai100( ) assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json")) + def check_causal_lm_pytorch_vs_kv_vs_ai100( model_name: str, prompt_len: int = Constants.PROMPT_LEN, @@ -548,6 +552,7 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer) + @pytest.mark.nightly @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", test_models_causal) @@ -567,6 +572,7 @@ def test_check_kv_repeat_custom_causal_lm_pytorch_vs_ai100(model_name): else: pytest.skip(f"Skipping {model_name} as it is not in REPEAT_KV_TEST_MODELS") + @pytest.mark.nightly @pytest.mark.on_qaic @pytest.mark.llm_model From bdeb038a483ae8e5aa958b582f5d08675f2e0d94 Mon Sep 17 00:00:00 2001 From: Dhiraj Kumar Sah Date: Fri, 22 May 2026 15:12:50 +0530 Subject: [PATCH 3/8] Updated script to enable repeatkv export of VLMs as well. Encoder/Decoder Wrappers were added to string mapping list to enable dummy model export for CI. Changes were made to prevent multiple application of ReplicateKVTransform if done in either Encoder or Decoder Wrapper already. Modeling files updated to access config in EncoderWrapper as well. Infra added for causalLM and VLM checks for repeatKV setup CI tests. CausalLM script APIRunner instantiation moved to allow updated input shapes to be made. Similarly commented export in VLM script since compile will call it with updated changes already. TODO: Confirm the changes that were made for DeepSeekV3 model for RepeatKV, currently they were removed for a generic approach. Signed-off-by: Dhiraj Kumar Sah --- QEfficient/base/modeling_qeff.py | 36 +- .../models/gemma3/modeling_gemma3.py | 1 + .../models/internvl/modeling_internvl.py | 1 + .../models/llama4/modeling_llama4.py | 1 + .../models/llava/modeling_llava.py | 1 + .../models/llava_next/modeling_llava_next.py | 1 + .../models/mistral3/modeling_mistral3.py | 1 + .../transformers/models/modeling_auto.py | 7 +- .../models/molmo/modeling_molmo.py | 1 + .../transformers/models/pytorch_transforms.py | 81 +- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 1 + .../models/qwen3_vl/modeling_qwen3_vl.py | 1 + .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 1 + tests/configs/causal_model_configs.json | 4 - .../causal_lm_models/check_causal_models.py | 52 +- .../causal_lm_models/test_causal_lm_models.py | 27 + .../test_image_text_to_text_models.py | 51 +- .../models/test_causal_lm_models.py | 816 ------------------ 18 files changed, 213 insertions(+), 871 deletions(-) delete mode 100644 tests/transformers/models/test_causal_lm_models.py diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index c61248ffd..a5721a80b 100755 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -656,23 +656,41 @@ def transform( **compiler_options, ): # Apply the transformations that are dependent on compilation parameters + def _transform_tracking_root(module: torch.nn.Module) -> torch.nn.Module: + """ + Use the shared wrapped model as transform-tracking root when available. + This lets encoder/decoder wrappers coordinate one-time transforms. + """ + wrapped = getattr(module, "model", None) + return wrapped if isinstance(wrapped, torch.nn.Module) else module qaic_config = qaic_config if qaic_config else getattr(self.model, "qaic_config", None) - model_config = getattr(self.model, "config", None) or getattr(self.model.model, "config", None) + model_config = getattr(self.model, "config", None) or getattr( + getattr(self.model, "model", None), "config", None + ) if model_config: - if "DeepseekV3ForCausalLM" in (getattr(model_config, "architectures", None) or []): - if qaic_config: - if qaic_config.get("blocking_mode", None) == "h": - qaic_config["head_block_size"] = qaic_config.get("head_block_size", num_devices) - num_kv_heads_repeat = qaic_config.get("num_kv_heads_repeat", 1) + # if "DeepseekV3ForCausalLM" in (getattr(model_config, "architectures", None) or []): + if qaic_config: + if qaic_config.get("blocking_mode", None) == "h": + qaic_config["head_block_size"] = qaic_config.get("head_block_size", num_devices) + num_kv_heads_repeat = qaic_config.get("num_kv_heads_repeat", 1) + transform_root = _transform_tracking_root(self.model) + applied_transforms = getattr(transform_root, "_qeff_runtime_transforms_applied", set()) + if ReplicateKVHeadTransform.__name__ in applied_transforms: + replicate_kv_transformed = False + logger.warning("Skipping RepeatKVTransform: already applied on this model instance.") + else: self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply( - self.model, num_kv_heads_repeat + self.model, + num_kv_heads_repeat=num_kv_heads_repeat, ) if replicate_kv_transformed: - self.hash_params["config"] = self.model.config.to_diff_dict() - + applied_transforms.add(ReplicateKVHeadTransform.__name__) + setattr(transform_root, "_qeff_runtime_transforms_applied", applied_transforms) + if replicate_kv_transformed: + self.hash_params["config"] = self.model.config.to_diff_dict() blocking_config = build_transformer_blocking_config_for_transform( model_config, ctx_len=ctx_len, diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index a3e9257a7..524a22081 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -626,6 +626,7 @@ class QEffGemma3EncoderWrapper(nn.Module): def __init__(self, model): super().__init__() self.model = model.model + self.config = self.model.config self.model.vision_model = self.model.vision_tower def get_submodules_for_export(self) -> Type[nn.Module]: diff --git a/QEfficient/transformers/models/internvl/modeling_internvl.py b/QEfficient/transformers/models/internvl/modeling_internvl.py index 563c42e25..7a0b7d524 100644 --- a/QEfficient/transformers/models/internvl/modeling_internvl.py +++ b/QEfficient/transformers/models/internvl/modeling_internvl.py @@ -20,6 +20,7 @@ class QEffInternEncoderWrapper(nn.Module): def __init__(self, model): super().__init__() self.model = model + self.config = self.model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py index 2cf5dbb2e..c2c4b8ad7 100644 --- a/QEfficient/transformers/models/llama4/modeling_llama4.py +++ b/QEfficient/transformers/models/llama4/modeling_llama4.py @@ -831,6 +831,7 @@ class QEffLlama4EncoderWrapper(nn.Module): def __init__(self, model): super().__init__() self.model = model + self.config = self.model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ diff --git a/QEfficient/transformers/models/llava/modeling_llava.py b/QEfficient/transformers/models/llava/modeling_llava.py index 88bb5e102..a4005497b 100644 --- a/QEfficient/transformers/models/llava/modeling_llava.py +++ b/QEfficient/transformers/models/llava/modeling_llava.py @@ -29,6 +29,7 @@ def __init__(self, model): super().__init__() self.model = model self.model.vision_model = self.model.model.vision_tower + self.config = self.model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ diff --git a/QEfficient/transformers/models/llava_next/modeling_llava_next.py b/QEfficient/transformers/models/llava_next/modeling_llava_next.py index 342269ce5..43adfe7c5 100755 --- a/QEfficient/transformers/models/llava_next/modeling_llava_next.py +++ b/QEfficient/transformers/models/llava_next/modeling_llava_next.py @@ -29,6 +29,7 @@ def __init__(self, model): super().__init__() self.model = model self.model.vision_model = self.model.model.vision_tower + self.config = self.model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ diff --git a/QEfficient/transformers/models/mistral3/modeling_mistral3.py b/QEfficient/transformers/models/mistral3/modeling_mistral3.py index 628d1dee2..3406791b7 100644 --- a/QEfficient/transformers/models/mistral3/modeling_mistral3.py +++ b/QEfficient/transformers/models/mistral3/modeling_mistral3.py @@ -183,6 +183,7 @@ class QEFFMistral3EncoderWrapper(nn.Module): def __init__(self, model): super().__init__() self.model = model + self.config = self.model.config self.model.model.vision_model = self.model.model.vision_tower def get_submodules_for_export(self) -> Type[nn.Module]: diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index aff0d9f07..12381c513 100755 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1436,7 +1436,12 @@ def export( if prefill_only and prefill_seq_len > 1: offload_pt_weights = False # to keep weight for decode onnx else: - offload_pt_weights = kwargs.get("offload_pt_weights", True) + num_kv_heads_repeat = ( + (self.lang_model.model.qaic_config or {}).get("num_kv_heads_repeat", 1) + if hasattr(self.lang_model.model, "qaic_config") + else 1 + ) + offload_pt_weights = kwargs.get("offload_pt_weights", num_kv_heads_repeat <= 1) if not skip_lang and self.lang_model.onnx_path is None: self.lang_model.export( diff --git a/QEfficient/transformers/models/molmo/modeling_molmo.py b/QEfficient/transformers/models/molmo/modeling_molmo.py index d59ca4e01..b673d9e06 100644 --- a/QEfficient/transformers/models/molmo/modeling_molmo.py +++ b/QEfficient/transformers/models/molmo/modeling_molmo.py @@ -565,6 +565,7 @@ class QEffMolmoEncoderWrapper(nn.Module): def __init__(self, model): super().__init__() self.model = model + self.config = self.model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 58a5ee7e4..74e7c583f 100755 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -1028,8 +1028,6 @@ class ReplicateKVHeadTransform(ModuleMutatorTransform): "QEffQwen_2_5_vl_EncoderWrapper", "QEffQwen3VLDecoderWrapper", "QEffQwen3VLEncoderWrapper", - "QEffQwen3VLDecoderWrapper", - "QEffQwen3VLEncoderWrapper", } def _duplicate_weights_for_linear_layer( @@ -1081,26 +1079,67 @@ def _duplicate_weights_for_linear_layer( new_kv_heads * head_dim ) + def _is_valid_text_model(candidate: nn.Module) -> bool: + """ + Validate whether a candidate object looks like a text stack suitable for KV replication. + """ + if candidate is None: + return False + cfg = getattr(candidate, "config", None) + layers = getattr(candidate, "layers", None) + return ( + cfg is not None + and layers is not None + and hasattr(cfg, "num_key_value_heads") + and hasattr(cfg, "num_attention_heads") + and hasattr(cfg, "hidden_size") + ) + def _get_text_model(model): """ Determine and return the appropriate text_model from a given model object. + + Some VLM wrappers expose multiple nested text attributes (e.g. `language_model`, + `language_model.model`, `model.language_model`). We pick the first valid module + that has both `config` and `layers` required for KV head replication. """ - # Check for VLMs - if hasattr(model, "language_model"): - if hasattr(model.language_model, "model"): - return model.language_model.model - else: - return model.language_model - if hasattr(model, "model"): - return model.model - if hasattr(model, "transformer"): - return model.transformer - if hasattr(model, "llm"): - return model.llm - if hasattr(model, "backbone"): - return model.backbone - - raise AttributeError("No suitable text model found in the provided model.") + candidate_paths = ( + ("language_model",), + ("language_model", "model"), + ("model", "language_model"), + ("model", "language_model", "model"), + ("model",), + ("model", "model"), + ("transformer",), + ("transformer", "model"), + ("llm",), + ("llm", "model"), + ("backbone",), + ) + + for path in candidate_paths: + candidate = model + valid_path = True + for attr in path: + if not hasattr(candidate, attr): + valid_path = False + break + candidate = getattr(candidate, attr) + if valid_path and ReplicateKVHeadTransform._is_valid_text_model(candidate): + return candidate + + raise AttributeError( + f"No suitable text model found in the provided model ({model.__class__.__name__}). " + "Expected a module with `layers` and text `config` attributes." + ) + + def _get_replication_root(model: nn.Module) -> nn.Module: + """ + Return a shared root module for wrapper and non-wrapper models so KV replication + can be applied once across encoder/decoder components of the same model. + """ + candidate = getattr(model, "model", None) + return candidate if isinstance(candidate, nn.Module) else model @classmethod def mutate(cls, original_module: nn.Module, parent_module: nn.Module, n_repeat: int) -> nn.Module: @@ -1115,6 +1154,11 @@ def mutate(cls, original_module: nn.Module, parent_module: nn.Module, n_repeat: Returns: The mutated module (same object, modified in-place). """ + replication_root = cls._get_replication_root(original_module) + if getattr(replication_root, "_qeff_kv_replication_applied", False): + logger.warning("KV head replication already applied for this model instance; skipping.") + return original_module + text_model = cls._get_text_model(original_module) orig_kv_heads = text_model.config.num_key_value_heads new_kv_heads = n_repeat * orig_kv_heads @@ -1134,6 +1178,7 @@ def mutate(cls, original_module: nn.Module, parent_module: nn.Module, n_repeat: cls._duplicate_weights_for_linear_layer(attn.k_proj, orig_kv_heads, n_repeat, attn.head_dim, hidden_size) cls._duplicate_weights_for_linear_layer(attn.v_proj, orig_kv_heads, n_repeat, attn.head_dim, hidden_size) + setattr(replication_root, "_qeff_kv_replication_applied", True) return original_module @classmethod diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 357c4af16..f970ba54b 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -746,6 +746,7 @@ def __init__(self, model): super().__init__() self.model = model.model self.model.vision_model = self.model.visual + self.config = self.model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ diff --git a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 0f6ab210d..45a8a8fa5 100644 --- a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -652,6 +652,7 @@ def __init__(self, model): super().__init__() self.model = model.model self.model.vision_model = self.model.visual + self.config = self.model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index db30350f8..17ff828b4 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -745,6 +745,7 @@ def __init__(self, model): super().__init__() self.model = model.model self.model.vision_model = self.model.visual + self.config = self.model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ diff --git a/tests/configs/causal_model_configs.json b/tests/configs/causal_model_configs.json index 9c7160522..2c092ed9e 100644 --- a/tests/configs/causal_model_configs.json +++ b/tests/configs/causal_model_configs.json @@ -355,10 +355,6 @@ "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, - "rope_type": "llama3""factor": 8.0, - "high_freq_factor": 4.0, - "low_freq_factor": 1.0, - "original_max_position_embeddings": 8192, "rope_type": "llama3" } } diff --git a/tests/transformers/models/causal_lm_models/check_causal_models.py b/tests/transformers/models/causal_lm_models/check_causal_models.py index f878acbe7..7b1b78abc 100644 --- a/tests/transformers/models/causal_lm_models/check_causal_models.py +++ b/tests/transformers/models/causal_lm_models/check_causal_models.py @@ -39,6 +39,40 @@ def get_custom_n_layers(model_name): return 1 +def check_kv_repeat_causal_lm_pytorch_vs_ai100( + model_name: str, + manual_cleanup: callable, + prompt_len: int = Constants.PROMPT_LEN, + ctx_len: int = Constants.CTX_LEN, + n_layer: int = -1, + config: Optional[AutoConfig] = None, +): + """ + Validate causal LM flow with repeated KV heads configuration. + """ + if config is None: + model_config = AutoConfig.from_pretrained( + model_name, + trust_remote_code=model_name in ModelConfig.EXTERNAL_MODELS, + ) + else: + model_config = config + + num_attention_heads = getattr(model_config, "num_attention_heads", getattr(model_config, "n_head", 1)) + num_key_value_heads = getattr(model_config, "num_key_value_heads", num_attention_heads) + num_kv_heads_repeat = max(1, num_attention_heads // max(1, num_key_value_heads)) + + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, + manual_cleanup=manual_cleanup, + prompt_len=prompt_len, + ctx_len=ctx_len, + n_layer=n_layer, + config=config, + qaic_config={"num_kv_heads_repeat": num_kv_heads_repeat}, + ) + + def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( model_name: str, manual_cleanup: callable, @@ -71,15 +105,6 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( pytorch_kv_tokens = None ort_tokens = None - api_runner = ApiRunner( - batch_size, - tokenizer, - config, - prompts, - Constants.PROMPT_LEN, - Constants.CTX_LEN, - full_batch_size if continuous_batching else None, - ) qeff_model = QEFFAutoModelForCausalLM( copy.deepcopy(model_hf), is_tlm=is_tlm, @@ -94,6 +119,15 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( num_devices=num_devices, qaic_config=qaic_config, ) + api_runner = ApiRunner( + batch_size, + tokenizer, + qeff_model.config, + prompts, + Constants.PROMPT_LEN, + Constants.CTX_LEN, + full_batch_size if continuous_batching else None, + ) if continuous_batching is False: pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) diff --git a/tests/transformers/models/causal_lm_models/test_causal_lm_models.py b/tests/transformers/models/causal_lm_models/test_causal_lm_models.py index 5011a670a..1b0b07be6 100644 --- a/tests/transformers/models/causal_lm_models/test_causal_lm_models.py +++ b/tests/transformers/models/causal_lm_models/test_causal_lm_models.py @@ -17,6 +17,7 @@ from .check_causal_models import ( check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100, + check_kv_repeat_causal_lm_pytorch_vs_ai100, get_custom_n_layers, ) @@ -73,6 +74,32 @@ def test_dummy_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanu check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, config=hf_config, manual_cleanup=manual_cleanup) +@pytest.mark.dummy_layers +@pytest.mark.on_qaic +@pytest.mark.llm_model +@pytest.mark.parametrize("model_name", test_models_causal) +def test_check_kv_repeat_custom_causal_lm_pytorch_vs_ai100(model_name, manual_cleanup): + """ + Test function to validate the PyTorch model and the Cloud AI 100 model with repeating original KV heads. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + custom_config = model_config_dict[model_name] + hf_config = AutoConfig.from_pretrained( + model_name, + trust_remote_code=model_name in ModelConfig.EXTERNAL_MODELS, + **custom_config.get("additional_params", {}), + ) + if model_name in ModelConfig.REPEAT_KV_TEST_MODELS: + if model_name in ModelConfig.QUANTIZED_MODELS: + n_layer = get_custom_n_layers(model_name) + check_kv_repeat_causal_lm_pytorch_vs_ai100(model_name, manual_cleanup=manual_cleanup, n_layer=n_layer) + else: + check_kv_repeat_causal_lm_pytorch_vs_ai100(model_name, manual_cleanup=manual_cleanup, config=hf_config) + else: + pytest.skip(f"Skipping {model_name} as it is not in REPEAT_KV_TEST_MODELS") + + @pytest.mark.full_layers @pytest.mark.on_qaic @pytest.mark.llm_model diff --git a/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py b/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py index d8e820e7a..1495ffb0b 100644 --- a/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py +++ b/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py @@ -57,6 +57,7 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( enable_qnn: Optional[bool] = False, qnn_config: Optional[str] = None, config: Optional[AutoConfig] = None, + qaic_config: Optional[dict] = None, num_kv_heads_repeat: Optional[int] = 1, test_kv_replicate: Optional[bool] = None, torch_dtype: Optional[torch.dtype] = torch.float32, @@ -73,6 +74,7 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( pytorch_kv_tokens = None ort_tokens = None n_layer = num_hidden_layers + qaic_config = copy.deepcopy(qaic_config) if qaic_config is not None else None if config is None: config = AutoConfig.from_pretrained( model_name, trust_remote_code=True, padding=model_name not in ModelConfig.MOLMO_MODELS @@ -81,6 +83,8 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( if test_kv_replicate: text_config = get_text_config(config) num_kv_heads_repeat = text_config.num_attention_heads // text_config.num_key_value_heads + qaic_config = qaic_config or {} + qaic_config["num_kv_heads_repeat"] = num_kv_heads_repeat if hasattr(config, "model_type") and config.model_type in ["gemma3"]: config.text_config._sliding_window_pattern = 2 config.text_config.layer_types = ["sliding_attention", "full_attention"] @@ -98,6 +102,7 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( model_name, kv_offload=kv_offload, config=config, + qaic_config=qaic_config, torch_dtype=torch_dtype, num_kv_heads_repeat=num_kv_heads_repeat, ) @@ -107,6 +112,7 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( model_name, kv_offload=kv_offload, config=config, + qaic_config=qaic_config, torch_dtype=torch_dtype, num_kv_heads_repeat=num_kv_heads_repeat, ) @@ -114,11 +120,14 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( if test_kv_replicate: text_config = get_text_config(config) num_kv_heads_repeat = text_config.num_attention_heads // text_config.num_key_value_heads + qaic_config = qaic_config or {} + qaic_config["num_kv_heads_repeat"] = num_kv_heads_repeat model_hf = load_vlm_model_from_config(config) qeff_model = QEFFAutoModelForImageTextToText( copy.deepcopy(model_hf), kv_offload=kv_offload, config=model_hf.config, + qaic_config=qaic_config, torch_dtype=torch_dtype, num_kv_heads_repeat=num_kv_heads_repeat, ) @@ -129,6 +138,7 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( "mxfp6": False, "enable_qnn": enable_qnn, "qnn_config": qnn_config, + "qaic_config": qaic_config, } if model_name in ModelConfig.INTERNVL_MODELS: @@ -251,7 +261,7 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( # "Tokens don't match for pytorch HF output and pytorch KV output" # ) - _ = qeff_model.export() + # _ = qeff_model.export() # ort_tokens = api_runner.run_vlm_kv_model_on_ort(onnx_model_path) # assert (pytorch_hf_tokens == ort_tokens).all(), "Tokens don't match for pytorch HF output and ORT output" @@ -351,7 +361,7 @@ def test_dummy_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(model_name, kv_o @pytest.mark.on_qaic @pytest.mark.multimodal -@pytest.mark.regular +@pytest.mark.dummy_layers @pytest.mark.parametrize("model_name", test_mm_models) @pytest.mark.parametrize("kv_offload", [True, False]) def test_custom_replicate_kv_pytorch_vs_ai100( @@ -370,19 +380,32 @@ def test_custom_replicate_kv_pytorch_vs_ai100( if model_name in ModelConfig.DUAL_QPC_MODELS and not kv_offload: pytest.skip("These models require kv_offload=True for testing.") - hf_config = None - model_type = model_config_dict[model_name].get("model_type", None) - if model_name in ModelConfig.STANDARD_VLM_MODELS and model_type is not None: - custom_config = model_config_dict[model_name].get("additional_params", {}) - hf_config = AutoConfig.for_model(model_type, trust_remote_code=True, **custom_config) - hf_config.name_or_path = model_name if model_name in ModelConfig.REPEAT_KV_TEST_MODELS: - check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( - model_name=model_name, - kv_offload=kv_offload, - test_kv_replicate=True, - manual_cleanup=manual_cleanup, - ) + hf_config = None + if model_name in ModelConfig.STANDARD_VLM_MODELS: + model_type = model_config_dict[model_name].get("model_type") + custom_config = model_config_dict[model_name].get("additional_params", {}) + hf_config = AutoConfig.for_model(model_type, trust_remote_code=True, **custom_config) + hf_config.name_or_path = model_name + + if hf_config is not None: + check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, + kv_offload=kv_offload, + config=hf_config, + qaic_config={}, + test_kv_replicate=True, + manual_cleanup=manual_cleanup, + ) + else: + check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, + num_hidden_layers=model_config_dict[model_name]["num_layers"], + kv_offload=kv_offload, + qaic_config={}, + test_kv_replicate=True, + manual_cleanup=manual_cleanup, + ) else: pytest.skip(f"Skipping replicate KV test for {model_name} as it's not in REPEAT_KV_TEST_MODELS") diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py deleted file mode 100644 index 643778c65..000000000 --- a/tests/transformers/models/test_causal_lm_models.py +++ /dev/null @@ -1,816 +0,0 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. -# SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- - -import copy -import json -import os -from typing import Optional - -import numpy as np -import pytest -import torch -from transformers import AutoConfig, AutoModelForCausalLM - -from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter -from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM -from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers -from QEfficient.utils import hf_download -from QEfficient.utils._utils import create_json, load_hf_tokenizer -from QEfficient.utils.constants import Constants, QnnConstants -from QEfficient.utils.device_utils import get_available_device_id -from QEfficient.utils.run_utils import ApiRunner -from QEfficient.utils.test_utils import ModelConfig - -CONFIG_PATH = "tests/configs/causal_model_configs.json" - -with open(CONFIG_PATH, "r") as f: - config_data = json.load(f) - causal_lm_models = config_data["causal_lm_models"] - causal_lm_fp16_models = config_data["causal_lm_fp16_test_models"] - spd_models = config_data["spd_causal_lm_models"] - qnn_models = config_data["qnn_causal_lm_models"] - blockedKV_models = config_data["blockedKV_causal_lm_models"] - - -# Create a list of model names for parameterization -test_models_causal = [model["model_name"] for model in causal_lm_models] -test_fp16_causal_models = [model["model_name"] for model in causal_lm_fp16_models] -test_models_spd = [model["model_name"] for model in spd_models] -test_models_qnn = [model["model_name"] for model in qnn_models] -test_models_blockedKV = [model["model_name"] for model in blockedKV_models] - -all_models = causal_lm_models + causal_lm_fp16_models - -# Create a dictionary mapping model names to their configs -model_config_dict = {model["model_name"]: model for model in all_models} - - -def get_hf_config_from_custom_config(model_name): - """ - Function to get HF config from custom config file - -------- - :model_name: str - - :return config - """ - custom_config = model_config_dict[model_name] - - hf_config = AutoConfig.from_pretrained( - model_name, - trust_remote_code=model_name in ModelConfig.EXTERNAL_MODELS, - **custom_config.get("additional_params", {}), - ) - return hf_config - - -def get_custom_n_layers(model_name): - """ - Function to set number layers of the variuos types of models such as swiftkv models and others - -------- - - :model_name: str - - :return n_layer - """ - if model_name in {"microsoft/Phi-3-mini-4k-instruct", "neuralmagic/Qwen2-0.5B-Instruct-FP8", "openai/gpt-oss-20b"}: - return 2 - elif model_name in ModelConfig.SWIFTKV_MODELS: - return None - return 1 - - -def load_causal_lm_model(model_name, n_layer=1, config=None, dtype=torch.float32): - """ - Function to load model from huggingface and transform to KV model - -------- - - :model_name: str - :n_layer: int - :config: Autoconfig - - :return model_hf, params - """ - torch.manual_seed(42) - model_path = hf_download( - repo_id=model_name, - ignore_patterns=["*.onnx", "*.ot", "*.md", "*.tflite", "*.pdf", "*.h5", "*.msgpack"], - ) - if config is None: # If custom config is not provided, load the model config from Hugging Face - if n_layer is not None: - model_hf = AutoModelForCausalLM.from_pretrained( - model_path, - use_cache=True, - num_hidden_layers=n_layer, - attn_implementation="eager", - low_cpu_mem_usage=False, - torch_dtype=dtype, - trust_remote_code=model_name in ModelConfig.EXTERNAL_MODELS, - ) - else: - # If n_layer is not specified, load the model without specifying the number of layers - model_hf = AutoModelForCausalLM.from_pretrained( - model_path, - use_cache=True, - attn_implementation="eager", - low_cpu_mem_usage=False, - torch_dtype=dtype, - trust_remote_code=model_name in ModelConfig.EXTERNAL_MODELS, - ) - else: # If custom config is provided, load the model using the config - model_hf = AutoModelForCausalLM.from_config( - config, - attn_implementation="eager", - torch_dtype=dtype, - trust_remote_code=model_name in ModelConfig.EXTERNAL_MODELS, - ) - # Convert to intended dtype - try: - model_hf = model_hf.to(dtype) - model_hf.config.torch_dtype = dtype - except ValueError: - pass # fully ignore - params = sum(p.numel() for p in model_hf.parameters()) - model_hf.eval() - return model_hf, params - - -def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( - model_name: str, - prompt_len: int = Constants.PROMPT_LEN, - ctx_len: int = Constants.CTX_LEN, - n_layer: int = 1, - num_speculative_tokens: Optional[int] = None, - prefill_only: Optional[bool] = None, - enable_qnn: Optional[bool] = False, - qnn_config: Optional[str] = None, - config: Optional[AutoConfig] = None, - pytorch_hf_tokens: Optional[list] = None, - qaic_config: Optional[dict] = None, - retain_full_kv: Optional[bool] = None, -): - """ - Validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. - ``Mandatory`` Args: - :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` - :prompt_len (int): Prompt length for the model to compile. - :ctx_len (int): Maximum context length to compile the model. - :n_layers (int): Number of layers for the Model. - """ - replace_transformers_quantizers() - if config is None: - n_layer = get_custom_n_layers(model_name) - model_hf, _ = load_causal_lm_model(model_name, n_layer=n_layer) - else: - model_hf, _ = load_causal_lm_model(model_name, config=config) - - tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name) - config = model_hf.config - batch_size = len(Constants.INPUT_STR) - api_runner = ApiRunner( - batch_size, - tokenizer, - config, - Constants.INPUT_STR, - Constants.PROMPT_LEN, - Constants.CTX_LEN, - ) - if model_name not in ModelConfig.SWIFTKV_MODELS and model_name not in ModelConfig.EXTERNAL_MODELS: - pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf) - - is_tlm = False if num_speculative_tokens is None else True - qeff_model = QEFFAutoModelForCausalLM( - copy.deepcopy(model_hf), is_tlm=is_tlm, pretrained_model_name_or_path=model_name, qaic_config=qaic_config - ) - qeff_model.transform( - ctx_len=ctx_len, seq_len=prompt_len, batch_size=batch_size, num_devices=1, qaic_config=qaic_config - ) - pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) - - if model_name not in ModelConfig.SWIFTKV_MODELS and model_name not in ModelConfig.EXTERNAL_MODELS: - assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), ( - "Tokens don't match for HF PyTorch model output and KV PyTorch model output" - ) - onnx_model_path = qeff_model.export() - ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=is_tlm) - gen_len = ort_tokens.shape[-1] - - assert (pytorch_kv_tokens == ort_tokens).all(), "Tokens don't match for ONNXRT output and PyTorch output." - - qpc_path = qeff_model.compile( - prefill_seq_len=prompt_len, - ctx_len=ctx_len, - num_cores=14, - mxfp6=False, - aic_enable_depth_first=False, - num_speculative_tokens=num_speculative_tokens, - prefill_only=prefill_only, - enable_qnn=enable_qnn, - qnn_config=qnn_config, - retain_full_kv=retain_full_kv, - ) - exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR) - cloud_ai_100_tokens = exec_info.generated_ids[0][ - :, :gen_len - ] # Because we always run for single input and single batch size - if prefill_only: - assert (ort_tokens[0][0] == cloud_ai_100_tokens[0][0]).all(), ( - "prefill run output tokens don't match for ONNXRT output and Cloud AI 100 output." - ) - else: - assert (ort_tokens == cloud_ai_100_tokens).all(), ( - "Tokens don't match for ONNXRT output and Cloud AI 100 output." - ) - assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json")) - if prefill_only is not None: - return - - # return # skip CB tests for now - - # testing for CB models - full_batch_size = 4 - fbs_prompts = Constants.INPUT_STR * 4 - api_runner = ApiRunner( - batch_size, - tokenizer, - config, - fbs_prompts, - Constants.PROMPT_LEN, - Constants.CTX_LEN, - full_batch_size, - ) - if model_name not in ModelConfig.SWIFTKV_MODELS and model_name not in ModelConfig.EXTERNAL_MODELS: - pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch_CB(model_hf) - pytorch_hf_tokens = np.vstack(pytorch_hf_tokens) - - qeff_model = QEFFAutoModelForCausalLM( - copy.deepcopy(model_hf), - continuous_batching=True, - is_tlm=is_tlm, - pretrained_model_name_or_path=model_name, - qaic_config=qaic_config, - ) - qeff_model.transform(ctx_len=ctx_len, seq_len=prompt_len, batch_size=full_batch_size, num_devices=1) - onnx_model_path = qeff_model.export() - - if not get_available_device_id(): - pytest.skip("No available devices to run model on Cloud AI 100") - - compiler_options = {} - if prompt_len == 1: - prefill_spec = { - "batch_size": batch_size, - "seq_len": 1, - "ctx_len": ctx_len, - "full_batch_size": full_batch_size, - "sliding_window": 128, - } - decode_spec = { - "batch_size": full_batch_size, - "seq_len": 1, - "ctx_len": ctx_len, - "full_batch_size": full_batch_size, - "sliding_window": 128, - } - compiler_options = {"specializations": [prefill_spec, decode_spec]} - - # TODO: add prefill_only tests - qpc_path = qeff_model.compile( - prefill_seq_len=prompt_len, - ctx_len=ctx_len, - num_cores=14, - mxfp6=False, - aic_enable_depth_first=False, - batch_size=batch_size, - full_batch_size=full_batch_size, - num_speculative_tokens=num_speculative_tokens, - enable_qnn=enable_qnn, - qnn_config=qnn_config, - retain_full_kv=retain_full_kv, - **compiler_options, - ) - exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts) - if model_name in ModelConfig.SWIFTKV_MODELS or model_name in ModelConfig.EXTERNAL_MODELS: - assert all( - [ - all(ort_token[:24] == cloud_token[:24]) - for ort_token, cloud_token in zip(ort_tokens, exec_info_fbs.generated_ids) - ] - ), "Tokens don't match for HF PyTorch model output and Cloud AI 100 output." - else: - assert all( - [ - all(pt_token[:24] == cloud_token[:24]) - for pt_token, cloud_token in zip(pytorch_hf_tokens, exec_info_fbs.generated_ids) - ] - ), "Tokens don't match for HF PyTorch model output and Cloud AI 100 output." - - assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json")) - - -def check_kv_repeat_causal_lm_pytorch_vs_ai100( - model_name: str, - prompt_len: int = Constants.PROMPT_LEN, - ctx_len: int = Constants.CTX_LEN, - n_layer: int = 1, - num_kv_heads_repeat: int = 1, - config: Optional[AutoConfig] = None, - pytorch_hf_tokens: Optional[list] = None, -): - """ - Validate the PyTorch model and the Cloud AI 100 model with repeating original KV heads. - ``Mandatory`` Args: - :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` - :prompt_len (int): Prompt length for the model to compile. - :ctx_len (int): Maximum context length to compile the model. - :n_layers (int): Number of layers for the Model. - :num_kv_heads_repeat (int): Number of times to repeat KV heads. - """ - replace_transformers_quantizers() - if config is None: - n_layer = get_custom_n_layers(model_name) - model_hf, _ = load_causal_lm_model(model_name, n_layer=n_layer) - else: - model_hf, _ = load_causal_lm_model(model_name, config=config) - tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name) - config = model_hf.config - batch_size = len(Constants.INPUT_STR) - api_runner = ApiRunner( - batch_size, - tokenizer, - config, - Constants.INPUT_STR, - Constants.PROMPT_LEN, - Constants.CTX_LEN, - ) - if model_name not in ModelConfig.SWIFTKV_MODELS and model_name not in ModelConfig.EXTERNAL_MODELS: - pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf) - - # Generate num_kv_heads_repeat from config so that divisibility error doesn't occur. - num_kv_heads_repeat = getattr(config, "num_attention_heads", getattr(config, "n_head", 1)) // getattr( - config, "num_key_value_heads", getattr(config, "n_head", 1) - ) - breakpoint() - qeff_model = QEFFAutoModelForCausalLM( - copy.deepcopy(model_hf), - pretrained_model_name_or_path=model_name, - num_kv_heads_repeat=num_kv_heads_repeat, - ) - if not get_available_device_id(): - pytest.skip("No available devices to run model on Cloud AI 100") - - qpc_path = qeff_model.compile( - prefill_seq_len=prompt_len, - ctx_len=ctx_len, - num_cores=14, - mxfp6=False, - aic_enable_depth_first=False, - ) - exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR) - gen_len = len(pytorch_hf_tokens) - cloud_ai_100_tokens = exec_info.generated_ids[0][:, :gen_len] - assert (pytorch_hf_tokens == cloud_ai_100_tokens).all(), ( - "Tokens don't match for Pytorch HF output and Cloud AI 100 output." - ) - assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json")) - - -def check_causal_lm_pytorch_vs_kv_vs_ai100( - model_name: str, - prompt_len: int = Constants.PROMPT_LEN, - ctx_len: int = Constants.CTX_LEN, - n_layer: int = 1, - num_speculative_tokens: Optional[int] = None, - prefill_only: Optional[bool] = None, - enable_qnn: Optional[bool] = False, - qnn_config: Optional[str] = None, - config: Optional[AutoConfig] = None, - pytorch_hf_tokens: Optional[list] = None, - qaic_config: Optional[dict] = None, - retain_full_kv: Optional[bool] = None, - dtype: Optional[torch.dtype] = torch.float32, -): - """ - Validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. - ``Mandatory`` Args: - :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` - :prompt_len (int): Prompt length for the model to compile. - :ctx_len (int): Maximum context length to compile the model. - :n_layers (int): Number of layers for the Model. - """ - replace_transformers_quantizers() - if config is None: - n_layer = get_custom_n_layers(model_name) - model_hf, _ = load_causal_lm_model(model_name, n_layer=n_layer, dtype=dtype) - else: - model_hf, _ = load_causal_lm_model(model_name, config=config, dtype=dtype) - - tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name) - config = model_hf.config - batch_size = len(Constants.INPUT_STR) - api_runner = ApiRunner( - batch_size, - tokenizer, - config, - Constants.INPUT_STR, - Constants.PROMPT_LEN, - Constants.CTX_LEN, - dtype=dtype, - ) - - if model_name not in ModelConfig.SWIFTKV_MODELS and model_name not in ModelConfig.EXTERNAL_MODELS: - pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf) - - is_tlm = False if num_speculative_tokens is None else True - qeff_model = QEFFAutoModelForCausalLM( - copy.deepcopy(model_hf), is_tlm=is_tlm, pretrained_model_name_or_path=model_name, qaic_config=qaic_config - ) - - pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) - - if model_name not in ModelConfig.SWIFTKV_MODELS and model_name not in ModelConfig.EXTERNAL_MODELS: - assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), ( - "Tokens don't match for HF PyTorch model output and KV PyTorch model output" - ) - qeff_model.export() - qpc_path = qeff_model.compile( - prefill_seq_len=prompt_len, - ctx_len=ctx_len, - num_cores=16, - mxfp6=False, - aic_hw_version="ai100", - aic_enable_depth_first=False, - num_speculative_tokens=num_speculative_tokens, - prefill_only=prefill_only, - enable_qnn=enable_qnn, - qnn_config=qnn_config, - ) - exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR) - gen_len = pytorch_kv_tokens.shape[-1] - cloud_ai_100_tokens = exec_info.generated_ids[0][ - :, :gen_len - ] # Because we always run for single input and single batch size - if prefill_only: - assert (pytorch_hf_tokens[0][0] == cloud_ai_100_tokens[0][0]).all(), ( - "prefill run output tokens don't match for ONNXRT output and Cloud AI 100 output." - ) - else: - assert (pytorch_hf_tokens == cloud_ai_100_tokens).all(), ( - "Tokens don't match for ONNXRT output and Cloud AI 100 output." - ) - assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json")) - if prefill_only is not None: - return - - assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json")) - - -# FIXME: there should be a CB test here -@pytest.mark.parametrize("model_name", ["gpt2"], ids=lambda x: x) -def test_causal_lm_export_with_deprecated_api(model_name): - model, _ = load_causal_lm_model(model_name, n_layer=1) - tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name) - qeff_model = QEFFAutoModelForCausalLM(model, model_name=model_name, pretrained_model_name_or_path=model_name) - new_api_onnx_model_path = qeff_model.export() - - # Again loading model since the export moves model to meta device - model, _ = load_causal_lm_model(model_name, n_layer=1) - qeff_model = QEFFAutoModelForCausalLM(model, model_name=model_name, pretrained_model_name_or_path=model_name) - _, old_api_onnx_model_path = qualcomm_efficient_converter( - model_name=model_name, model_kv=qeff_model, tokenizer=tokenizer - ) - - api_runner = ApiRunner( - batch_size=1, - tokenizer=tokenizer, - config=model.config, - prompt=Constants.INPUT_STR, - prompt_len=Constants.PROMPT_LEN, - ctx_len=Constants.CTX_LEN, - ) - - new_api_ort_tokens = api_runner.run_kv_model_on_ort(new_api_onnx_model_path) - old_api_ort_tokens = api_runner.run_kv_model_on_ort(old_api_onnx_model_path) - - assert (new_api_ort_tokens == old_api_ort_tokens).all(), ( - "New API output does not match old API output for ONNX export function" - ) - - -@pytest.mark.on_qaic -@pytest.mark.regular -@pytest.mark.llm_model -@pytest.mark.parametrize("model_name", test_models_causal) -def test_custom_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): - """ - Test function to validate the dummy PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. - ``Mandatory`` Args: - :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` - """ - - hf_config = get_hf_config_from_custom_config(model_name) - if model_name in ModelConfig.QUANTIZED_MODELS: - n_layer = get_custom_n_layers(model_name) - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, n_layer=n_layer) - else: - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, config=hf_config) - - -@pytest.mark.on_qaic -@pytest.mark.regular -@pytest.mark.llm_model -@pytest.mark.parametrize("model_name", test_fp16_causal_models) -def test_custom_causal_lm_pytorch_vs_kv_vs_ai100(model_name): - """ - Test function to validate the dummy PyTorch model, the PyTorch model after KV changes, and the Cloud AI 100 model, without continuous batching for custom dtype. - ``Mandatory`` Args: - :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` - """ - - hf_config = get_hf_config_from_custom_config(model_name) - if model_name in ModelConfig.QUANTIZED_MODELS: - n_layer = get_custom_n_layers(model_name) - check_causal_lm_pytorch_vs_kv_vs_ai100(model_name, n_layer=n_layer, dtype=torch.float16) - else: - check_causal_lm_pytorch_vs_kv_vs_ai100(model_name, config=hf_config, dtype=torch.float16) - - -@pytest.mark.nightly -@pytest.mark.on_qaic -@pytest.mark.llm_model -@pytest.mark.parametrize("model_name", test_models_causal) -def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): - """ - Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. - ``Mandatory`` Args: - :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` - """ - n_layer = get_custom_n_layers(model_name) - - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer) - - -@pytest.mark.nightly -@pytest.mark.on_qaic -@pytest.mark.parametrize("model_name", test_models_causal) -def test_check_kv_repeat_custom_causal_lm_pytorch_vs_ai100(model_name): - """ - Test function to validate the PyTorch model and the Cloud AI 100 model with repeating original KV heads. - ``Mandatory`` Args: - :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` - """ - hf_config = get_hf_config_from_custom_config(model_name) - if model_name in ModelConfig.REPEAT_KV_TEST_MODELS: - if model_name in ModelConfig.QUANTIZED_MODELS: - n_layer = get_custom_n_layers(model_name) - check_kv_repeat_causal_lm_pytorch_vs_ai100(model_name, n_layer=n_layer) - else: - check_kv_repeat_causal_lm_pytorch_vs_ai100(model_name, config=hf_config) - else: - pytest.skip(f"Skipping {model_name} as it is not in REPEAT_KV_TEST_MODELS") - - -@pytest.mark.nightly -@pytest.mark.on_qaic -@pytest.mark.llm_model -@pytest.mark.parametrize("model_name", test_fp16_causal_models) -def test_causal_lm_pytorch_vs_kv_vs_ai100(model_name): - """ - Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. - ``Mandatory`` Args: - :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` - """ - n_layer = get_custom_n_layers(model_name) - - check_causal_lm_pytorch_vs_kv_vs_ai100(model_name=model_name, n_layer=n_layer, dtype=torch.float16) - - -@pytest.mark.nightly -@pytest.mark.on_qaic -@pytest.mark.parametrize("retain_full_kv", [True, False]) -def test_causal_lm_gpt_oss_pytorch_vs_kv_vs_ort_vs_ai100_pl1(retain_full_kv): - """ - Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. - ``Mandatory`` Args: - :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` - """ - model_name = "openai/gpt-oss-20b" - n_layer = get_custom_n_layers(model_name) - prompt_len = 1 - - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( - model_name=model_name, n_layer=n_layer, prompt_len=prompt_len, retain_full_kv=retain_full_kv - ) - - -@pytest.mark.on_qaic -@pytest.mark.regular -@pytest.mark.qnn -@pytest.mark.llm_model -@pytest.mark.parametrize("model_name", test_models_qnn) -def test_custom_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_qnn(model_name): - """ - QNN Setup - Test function to validate the dummy PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. - ``Mandatory`` Args: - :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` - """ - hf_config = get_hf_config_from_custom_config(model_name) - qnn_config_json_path = os.path.join(os.getcwd(), "qnn_config.json") - create_json(qnn_config_json_path, QnnConstants.QNN_SAMPLE_CONFIG) - - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( - model_name, enable_qnn=True, qnn_config=qnn_config_json_path, config=hf_config - ) - - -@pytest.mark.nightly -@pytest.mark.on_qaic -@pytest.mark.qnn -@pytest.mark.llm_model -@pytest.mark.parametrize("model_name", test_models_qnn) -def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_qnn(model_name): - """ - QNN Setup - Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. - ``Mandatory`` Args: - :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` - """ - qnn_config_json_path = os.path.join(os.getcwd(), "qnn_config.json") - create_json(qnn_config_json_path, QnnConstants.QNN_SAMPLE_CONFIG) - n_layer = get_custom_n_layers(model_name) - - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( - model_name=model_name, n_layer=n_layer, enable_qnn=True, qnn_config=qnn_config_json_path - ) - - -@pytest.mark.regular -@pytest.mark.on_qaic -@pytest.mark.qnn -@pytest.mark.llm_model -@pytest.mark.parametrize("model_name", test_models_spd) -def test_custom_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): - """ - Test function to validate the dummy PyTorch model for speculative decoding, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. - ``Mandatory`` Args: - :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` - """ - hf_config = get_hf_config_from_custom_config(model_name) - - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( - model_name=model_name, - num_speculative_tokens=Constants.NUM_SPECULATIVE_TOKENS, - config=hf_config, - ) - - -@pytest.mark.nightly -@pytest.mark.on_qaic -@pytest.mark.llm_model -@pytest.mark.parametrize("model_name", test_models_spd) -def test_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): - """ - Test function to validate the PyTorch model for speculative decoding, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. - ``Mandatory`` Args: - :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` - """ - n_layer = get_custom_n_layers(model_name) - - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( - model_name=model_name, n_layer=n_layer, num_speculative_tokens=Constants.NUM_SPECULATIVE_TOKENS - ) - - -@pytest.mark.on_qaic -@pytest.mark.llm_model -def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(): - """ - Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model for a prompt length of 1, both with and without continuous batching. - """ - model_name = "gpt2" - prompt_len = 1 - - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, prompt_len=prompt_len) - - -@pytest.mark.on_qaic -@pytest.mark.qnn -@pytest.mark.llm_model -def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1_qnn(): - """ - Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model for a prompt length of 1, both with and without continuous batching. - """ - model_name = "gpt2" - prompt_len = 1 - - qnn_config_json_path = os.path.join(os.getcwd(), "qnn_config.json") - create_json(qnn_config_json_path, QnnConstants.QNN_SAMPLE_CONFIG) - - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( - model_name=model_name, prompt_len=prompt_len, enable_qnn=True, qnn_config=qnn_config_json_path - ) - - -@pytest.mark.on_qaic -@pytest.mark.llm_model -def test_prefiill_only_pytorch_vs_kv_vs_ort_vs_ai100(): - model_name = "gpt2" - n_layer = 1 - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, n_layer=n_layer, prefill_only=True) - - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, n_layer=n_layer, prefill_only=False) - - -@pytest.mark.on_qaic -@pytest.mark.qnn -@pytest.mark.llm_model -def test_prefiill_only_pytorch_vs_kv_vs_ort_vs_ai100_qnn(): - model_name = "gpt2" - n_layer = 1 - - qnn_config_json_path = os.path.join(os.getcwd(), "qnn_config.json") - create_json(qnn_config_json_path, QnnConstants.QNN_SAMPLE_CONFIG) - - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( - model_name, n_layer=n_layer, prefill_only=True, enable_qnn=True, qnn_config=qnn_config_json_path - ) - - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( - model_name, n_layer=n_layer, prefill_only=False, enable_qnn=True, qnn_config=qnn_config_json_path - ) - - -@pytest.mark.on_qaic -@pytest.mark.llm_model -@pytest.mark.regular -@pytest.mark.parametrize("model_name", test_models_blockedKV) -def test_custom_causal_blockedKV_pytorch_vs_kv_vs_ort_vs_ai100(model_name): - """ - Test function to validate the PyTorch model for KV blocking, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. - ``Mandatory`` Args: - :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` - """ - hf_config = get_hf_config_from_custom_config(model_name) - - NUM_KV_BLOCKS = 2 - - qaic_config = dict(enable_blocking=True, num_kv_blocks=NUM_KV_BLOCKS) - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, config=hf_config, qaic_config=qaic_config) - - -@pytest.mark.on_qaic -@pytest.mark.parametrize("model_name", test_models_blockedKV) -def test_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name): - """ - Test function to validate the PyTorch model for HQKV blocking, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. - ``Mandatory`` Args: - :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` - """ - n_layer = get_custom_n_layers(model_name) - - HEAD_BLOCK_SIZE = 8 - NUM_KV_BLOCKS = 2 - NUM_Q_BLOCKS = 2 - - # head blocking only - qaic_config = dict(enable_blocking=True, head_block_size=HEAD_BLOCK_SIZE) - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer, qaic_config=qaic_config) - - # kv blocking only - qaic_config = dict(enable_blocking=True, num_kv_blocks=NUM_KV_BLOCKS) - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer, qaic_config=qaic_config) - - # q block only - qaic_config = dict(enable_blocking=True, num_q_blocks=NUM_Q_BLOCKS) - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer, qaic_config=qaic_config) - - # qkv blocking - qaic_config = dict(enable_blocking=True, num_kv_blocks=NUM_KV_BLOCKS, num_q_blocks=NUM_Q_BLOCKS) - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer, qaic_config=qaic_config) - - # head qkv blocking - qaic_config = dict( - enable_blocking=True, - head_block_size=HEAD_BLOCK_SIZE, - num_kv_blocks=NUM_KV_BLOCKS, - num_q_blocks=NUM_Q_BLOCKS, - ) - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer, qaic_config=qaic_config) - - -@pytest.mark.on_qaic -@pytest.mark.llm_model -@pytest.mark.parametrize("model_name", test_models_blockedKV) -def test_causal_nonBlockedKV_pytorch_vs_kv_vs_ort_vs_ai100(model_name): - """ - Test function to validate the PyTorch model for KV blocking, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. - ``Mandatory`` Args: - :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` - """ - n_layer = get_custom_n_layers(model_name) - - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer) From 3d14efd382566d259bdf76f50287376cf78d8dbb Mon Sep 17 00:00:00 2001 From: Dhiraj Kumar Sah Date: Mon, 25 May 2026 14:57:54 +0530 Subject: [PATCH 4/8] Added RepeatKVTransform operations needed for DeepseekV3ForCausalLM. Made changes to allow generic name based transformation of heads (num_attention_heads, n_heads, n_head etc). Minor edits and utils created for this task. Signed-off-by: Dhiraj Kumar Sah --- QEfficient/base/modeling_qeff.py | 10 +- .../transformers/models/pytorch_transforms.py | 139 ++++++++++++++++-- QEfficient/utils/config_utils.py | 40 +++++ QEfficient/utils/constants.py | 5 + .../causal_lm_models/check_causal_models.py | 21 ++- 5 files changed, 192 insertions(+), 23 deletions(-) create mode 100644 QEfficient/utils/config_utils.py diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index a5721a80b..8645e494d 100755 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -671,20 +671,24 @@ def _transform_tracking_root(module: torch.nn.Module) -> torch.nn.Module: ) if model_config: - # if "DeepseekV3ForCausalLM" in (getattr(model_config, "architectures", None) or []): + architectures = getattr(model_config, "architectures", None) or [] + is_deepseek_v3 = "DeepseekV3ForCausalLM" in architectures if qaic_config: if qaic_config.get("blocking_mode", None) == "h": qaic_config["head_block_size"] = qaic_config.get("head_block_size", num_devices) num_kv_heads_repeat = qaic_config.get("num_kv_heads_repeat", 1) transform_root = _transform_tracking_root(self.model) applied_transforms = getattr(transform_root, "_qeff_runtime_transforms_applied", set()) - if ReplicateKVHeadTransform.__name__ in applied_transforms: + should_apply_repeat_kv = is_deepseek_v3 or (num_kv_heads_repeat is not None and num_kv_heads_repeat > 1) + if not should_apply_repeat_kv: + replicate_kv_transformed = False + elif ReplicateKVHeadTransform.__name__ in applied_transforms: replicate_kv_transformed = False logger.warning("Skipping RepeatKVTransform: already applied on this model instance.") else: self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply( self.model, - num_kv_heads_repeat=num_kv_heads_repeat, + num_kv_heads_repeat, ) if replicate_kv_transformed: applied_transforms.add(ReplicateKVHeadTransform.__name__) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 74e7c583f..2a1f8ef2d 100755 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -637,6 +637,13 @@ from QEfficient.transformers.quantizers.quantizer_compressed_tensors import FP8DeQuantLinear from QEfficient.transformers.sampler.sampler import sampler_forward from QEfficient.transformers.spd.spd_transform_forward import tlm_forward +from QEfficient.utils.config_utils import ( + resolve_attention_heads, + resolve_hidden_size, + resolve_kv_heads, + set_kv_head_aliases, +) +from QEfficient.utils.constants import ATTENTION_HEAD_CONFIG_KEYS, HIDDEN_SIZE_CONFIG_KEYS, KV_HEAD_CONFIG_KEYS from QEfficient.utils.logging_utils import logger SPD_TARGET = "target" @@ -1008,6 +1015,7 @@ class ReplicateKVHeadTransform(ModuleMutatorTransform): QEffOlmo2ForCausalLM, } _module_string_mapping = { + "DeepseekV3ForCausalLM", "InternVLChatModel", "MolmoForCausalLM,", "QEffGemma3DecoderWrapper", @@ -1030,6 +1038,51 @@ class ReplicateKVHeadTransform(ModuleMutatorTransform): "QEffQwen3VLEncoderWrapper", } + @classmethod + def _get_attention_module(cls, block: nn.Module) -> nn.Module: + for attr in ("cross_attn", "self_attn", "attention", "attn"): + attn = getattr(block, attr, None) + if attn is not None: + return attn + raise AttributeError(f"No attention module found in block type {block.__class__.__name__}") + + @staticmethod + def _get_projection_layer(attn: nn.Module, names: tuple) -> nn.Module: + for name in names: + layer = getattr(attn, name, None) + if layer is not None: + return layer + raise AttributeError(f"Missing projection layer in {attn.__class__.__name__}; expected one of {names}") + + @staticmethod + def _is_mla_attention(attn: nn.Module) -> bool: + return ( + hasattr(attn, "kv_a_proj_with_mqa") and hasattr(attn, "kv_lora_rank") and hasattr(attn, "qk_rope_head_dim") + ) + + @classmethod + def _is_mla_model(cls, text_model: nn.Module) -> bool: + for block in getattr(text_model, "layers", []): + try: + attn = cls._get_attention_module(block) + except AttributeError: + continue + if cls._is_mla_attention(attn): + return True + return False + + @staticmethod + def _duplicate_weights_for_mla_layer(layer: nn.Module, orig_kv_heads: int, repeat: int, dim: int, hidden_size: int): + new_kv_heads = repeat * orig_kv_heads + layer.weight.data = torch.repeat_interleave( + layer.weight.data.view(orig_kv_heads, dim, hidden_size), repeat, 0 + ).view(new_kv_heads * dim, hidden_size) + + if layer.bias is not None: + layer.bias.data = torch.repeat_interleave(layer.bias.data.view(orig_kv_heads, dim), repeat, 0).view( + new_kv_heads * dim + ) + def _duplicate_weights_for_linear_layer( layer: nn.Module, orig_kv_heads: int, repeat: int, head_dim: int, hidden_size: int ): @@ -1087,12 +1140,15 @@ def _is_valid_text_model(candidate: nn.Module) -> bool: return False cfg = getattr(candidate, "config", None) layers = getattr(candidate, "layers", None) + attn_heads = resolve_attention_heads(cfg) if cfg is not None else None + kv_heads = resolve_kv_heads(cfg) if cfg is not None else None + hidden_size = resolve_hidden_size(cfg) if cfg is not None else None return ( cfg is not None and layers is not None - and hasattr(cfg, "num_key_value_heads") - and hasattr(cfg, "num_attention_heads") - and hasattr(cfg, "hidden_size") + and attn_heads is not None + and kv_heads is not None + and hidden_size is not None ) def _get_text_model(model): @@ -1160,29 +1216,76 @@ def mutate(cls, original_module: nn.Module, parent_module: nn.Module, n_repeat: return original_module text_model = cls._get_text_model(original_module) - orig_kv_heads = text_model.config.num_key_value_heads + cfg = text_model.config + orig_kv_heads = resolve_kv_heads(cfg) + num_attention_heads = resolve_attention_heads(cfg) + hidden_size = resolve_hidden_size(cfg) + is_mla_model = cls._is_mla_model(text_model) + + if orig_kv_heads is None or num_attention_heads is None or hidden_size is None: + raise ValueError( + "Unable to resolve attention/KV heads or hidden size from config for RepeatKV transform. " + f"Supported attention keys={ATTENTION_HEAD_CONFIG_KEYS}, kv keys={KV_HEAD_CONFIG_KEYS}, " + f"hidden size keys={HIDDEN_SIZE_CONFIG_KEYS}." + ) + if orig_kv_heads < 1 or num_attention_heads < 1: + raise ValueError( + f"Invalid head values for RepeatKV transform: " + f"num_attention_heads={num_attention_heads}, num_key_value_heads={orig_kv_heads}" + ) + if is_mla_model: + # Legacy MLA path treats compressed-KV projection as single KV head. + orig_kv_heads = 1 + new_kv_heads = n_repeat * orig_kv_heads - text_model.config.orig_kv_heads = orig_kv_heads - text_model.config.num_key_value_heads = new_kv_heads + if (not is_mla_model) and (new_kv_heads > num_attention_heads or (num_attention_heads % new_kv_heads) != 0): + raise ValueError( + f"Invalid RepeatKV configuration: num_attention_heads={num_attention_heads}, " + f"orig_kv_heads={orig_kv_heads}, num_kv_heads_repeat={n_repeat}, new_kv_heads={new_kv_heads}. " + "Expected new_kv_heads <= num_attention_heads and divisibility." + ) - num_attention_heads = text_model.config.num_attention_heads - hidden_size = text_model.config.hidden_size + cfg.orig_kv_heads = orig_kv_heads + set_kv_head_aliases(cfg, new_kv_heads) logger.warning(f"Original KV heads: {orig_kv_heads}") logger.warning(f"Modified KV heads: {new_kv_heads}") for block in text_model.layers: - attn = getattr(block, "cross_attn", getattr(block, "self_attn", None)) - attn.num_key_value_heads = new_kv_heads - attn.num_key_value_groups = num_attention_heads // new_kv_heads - - cls._duplicate_weights_for_linear_layer(attn.k_proj, orig_kv_heads, n_repeat, attn.head_dim, hidden_size) - cls._duplicate_weights_for_linear_layer(attn.v_proj, orig_kv_heads, n_repeat, attn.head_dim, hidden_size) + attn = cls._get_attention_module(block) + if hasattr(attn, "num_key_value_heads"): + attn.num_key_value_heads = new_kv_heads + if hasattr(attn, "n_kv_heads"): + attn.n_kv_heads = new_kv_heads + + if cls._is_mla_attention(attn): + # Legacy MLA support: KV compression projection is organized as + # [kv_heads, kv_lora_rank + qk_rope_head_dim, hidden_size]. + mla_orig_kv_heads = 1 + mla_head_dim = int(attn.kv_lora_rank + attn.qk_rope_head_dim) + cls._duplicate_weights_for_mla_layer( + attn.kv_a_proj_with_mqa, + mla_orig_kv_heads, + n_repeat, + mla_head_dim, + hidden_size, + ) + else: + n_kv_groups = num_attention_heads // new_kv_heads + if hasattr(attn, "num_key_value_groups"): + attn.num_key_value_groups = n_kv_groups + if hasattr(attn, "n_kv_groups"): + attn.n_kv_groups = n_kv_groups + head_dim = getattr(attn, "head_dim", hidden_size // num_attention_heads) + k_proj = cls._get_projection_layer(attn, ("k_proj", "key_proj")) + v_proj = cls._get_projection_layer(attn, ("v_proj", "value_proj")) + cls._duplicate_weights_for_linear_layer(k_proj, orig_kv_heads, n_repeat, head_dim, hidden_size) + cls._duplicate_weights_for_linear_layer(v_proj, orig_kv_heads, n_repeat, head_dim, hidden_size) setattr(replication_root, "_qeff_kv_replication_applied", True) return original_module @classmethod - def apply(cls, model: nn.Module, **kwargs) -> Tuple[nn.Module, bool]: + def apply(cls, model: nn.Module, num_kv_heads_repeat: Optional[int] = None, **kwargs) -> Tuple[nn.Module, bool]: """ Replicates KV heads in attention modules based on provided multiplier. @@ -1191,7 +1294,11 @@ def apply(cls, model: nn.Module, **kwargs) -> Tuple[nn.Module, bool]: kwargs: Additional arguments for the transformation. Includes: - num_kv_heads_repeat: The number of times to repeat the KV heads. """ - n_repeat = kwargs.pop("num_kv_heads_repeat", 1) + if num_kv_heads_repeat is None: + n_repeat = kwargs.pop("num_kv_heads_repeat", 1) + else: + kwargs.pop("num_kv_heads_repeat", None) + n_repeat = num_kv_heads_repeat transformed = False if n_repeat is not None and n_repeat > 1: if (model.__class__ in cls._module_mapping) or (model.__class__.__name__ in cls._module_string_mapping): diff --git a/QEfficient/utils/config_utils.py b/QEfficient/utils/config_utils.py new file mode 100644 index 000000000..1e6125d7e --- /dev/null +++ b/QEfficient/utils/config_utils.py @@ -0,0 +1,40 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from typing import Iterable, Optional + +from QEfficient.utils.constants import ATTENTION_HEAD_CONFIG_KEYS, HIDDEN_SIZE_CONFIG_KEYS, KV_HEAD_CONFIG_KEYS + + +def get_first_config_value(config, names: Iterable[str], default=None, cast_int: bool = False): + for name in names: + value = getattr(config, name, None) + if value is not None: + return int(value) if cast_int else value + return default + + +def resolve_attention_heads(config) -> Optional[int]: + return get_first_config_value(config, ATTENTION_HEAD_CONFIG_KEYS, cast_int=True) + + +def resolve_kv_heads(config) -> Optional[int]: + value = get_first_config_value(config, KV_HEAD_CONFIG_KEYS, cast_int=True) + if value is None: + value = resolve_attention_heads(config) + return value + + +def resolve_hidden_size(config) -> Optional[int]: + return get_first_config_value(config, HIDDEN_SIZE_CONFIG_KEYS, cast_int=True) + + +def set_kv_head_aliases(config, value: int): + setattr(config, "num_key_value_heads", value) + for key in KV_HEAD_CONFIG_KEYS: + if hasattr(config, key): + setattr(config, key, value) diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 3a03f6b1c..0cc1c4ff6 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -140,6 +140,11 @@ def get_default_aic_hw_version() -> str: DEFAULT_AIC_HW_VERSION = get_default_aic_hw_version() ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL = 100 +# Generic config key aliases used across model families. +ATTENTION_HEAD_CONFIG_KEYS = ("num_attention_heads", "n_head", "n_heads", "num_heads") +KV_HEAD_CONFIG_KEYS = ("num_key_value_heads", "n_kv_heads", "num_kv_heads", "effective_n_kv_heads") +HIDDEN_SIZE_CONFIG_KEYS = ("hidden_size", "n_embd", "d_model") + # InternVL constants # Fixing the feature size with reference to OpenGVLab/InternVL2_5-1B, OpenGVLab/InternVL2_5-38B and OpenGVLab/InternVL2_5-78B INTERN_FEATURE_SIZE = 256 diff --git a/tests/transformers/models/causal_lm_models/check_causal_models.py b/tests/transformers/models/causal_lm_models/check_causal_models.py index 7b1b78abc..e604cb72f 100644 --- a/tests/transformers/models/causal_lm_models/check_causal_models.py +++ b/tests/transformers/models/causal_lm_models/check_causal_models.py @@ -16,7 +16,8 @@ from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers from QEfficient.utils._utils import load_hf_tokenizer -from QEfficient.utils.constants import Constants +from QEfficient.utils.config_utils import get_first_config_value +from QEfficient.utils.constants import ATTENTION_HEAD_CONFIG_KEYS, KV_HEAD_CONFIG_KEYS, Constants from QEfficient.utils.run_utils import ApiRunner from QEfficient.utils.test_utils import ModelConfig, load_hf_causal_lm_model @@ -58,9 +59,21 @@ def check_kv_repeat_causal_lm_pytorch_vs_ai100( else: model_config = config - num_attention_heads = getattr(model_config, "num_attention_heads", getattr(model_config, "n_head", 1)) - num_key_value_heads = getattr(model_config, "num_key_value_heads", num_attention_heads) - num_kv_heads_repeat = max(1, num_attention_heads // max(1, num_key_value_heads)) + num_attention_heads = get_first_config_value(model_config, ATTENTION_HEAD_CONFIG_KEYS, default=1, cast_int=True) + num_key_value_heads = get_first_config_value(model_config, KV_HEAD_CONFIG_KEYS, default=None, cast_int=True) + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + if num_attention_heads < 1 or num_key_value_heads < 1: + raise ValueError( + f"Invalid heads in config for RepeatKV: " + f"num_attention_heads={num_attention_heads}, num_key_value_heads={num_key_value_heads}" + ) + if num_attention_heads % num_key_value_heads != 0: + raise ValueError( + f"Invalid heads in config for RepeatKV: num_attention_heads ({num_attention_heads}) " + f"is not divisible by num_key_value_heads ({num_key_value_heads})." + ) + num_kv_heads_repeat = num_attention_heads // num_key_value_heads check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( model_name=model_name, From 7b60678c94cd0e1ac4b31a5b48b4c45a9123a135 Mon Sep 17 00:00:00 2001 From: Dhiraj Kumar Sah Date: Tue, 26 May 2026 14:17:26 +0530 Subject: [PATCH 5/8] Addressed Internal Code Review comments. Edited the changes as suggested by quic-mamta. Signed-off-by: Dhiraj Kumar Sah --- QEfficient/base/modeling_qeff.py | 10 ++++------ QEfficient/transformers/models/pytorch_transforms.py | 7 +++++++ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 8645e494d..216f56cc7 100755 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -674,21 +674,19 @@ def _transform_tracking_root(module: torch.nn.Module) -> torch.nn.Module: architectures = getattr(model_config, "architectures", None) or [] is_deepseek_v3 = "DeepseekV3ForCausalLM" in architectures if qaic_config: - if qaic_config.get("blocking_mode", None) == "h": + if is_deepseek_v3 and (qaic_config.get("blocking_mode", None) == "h"): qaic_config["head_block_size"] = qaic_config.get("head_block_size", num_devices) num_kv_heads_repeat = qaic_config.get("num_kv_heads_repeat", 1) transform_root = _transform_tracking_root(self.model) applied_transforms = getattr(transform_root, "_qeff_runtime_transforms_applied", set()) - should_apply_repeat_kv = is_deepseek_v3 or (num_kv_heads_repeat is not None and num_kv_heads_repeat > 1) - if not should_apply_repeat_kv: - replicate_kv_transformed = False - elif ReplicateKVHeadTransform.__name__ in applied_transforms: + + if ReplicateKVHeadTransform.__name__ in applied_transforms: replicate_kv_transformed = False logger.warning("Skipping RepeatKVTransform: already applied on this model instance.") else: self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply( self.model, - num_kv_heads_repeat, + num_kv_heads_repeat=num_kv_heads_repeat, ) if replicate_kv_transformed: applied_transforms.add(ReplicateKVHeadTransform.__name__) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 2a1f8ef2d..53512683a 100755 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -1083,6 +1083,7 @@ def _duplicate_weights_for_mla_layer(layer: nn.Module, orig_kv_heads: int, repea new_kv_heads * dim ) + @staticmethod def _duplicate_weights_for_linear_layer( layer: nn.Module, orig_kv_heads: int, repeat: int, head_dim: int, hidden_size: int ): @@ -1299,6 +1300,12 @@ def apply(cls, model: nn.Module, num_kv_heads_repeat: Optional[int] = None, **kw else: kwargs.pop("num_kv_heads_repeat", None) n_repeat = num_kv_heads_repeat + # Validate n_repeat is a positive integer + if not isinstance(n_repeat, int) or n_repeat < 1: + raise ValueError( + f"num_kv_heads_repeat must be a positive integer, got: {n_repeat} (type: {type(n_repeat).__name__})" + ) + transformed = False if n_repeat is not None and n_repeat > 1: if (model.__class__ in cls._module_mapping) or (model.__class__.__name__ in cls._module_string_mapping): From 76ff96c3e21d1d6727f08a913f7ba176a744b16d Mon Sep 17 00:00:00 2001 From: Dhiraj Kumar Sah Date: Fri, 29 May 2026 14:47:04 +0530 Subject: [PATCH 6/8] Revert "Addressed Internal Code Review comments." This reverts commit b40a34d561274feed16696418c51c7c6f4d1a06a. --- QEfficient/base/modeling_qeff.py | 10 ++++++---- QEfficient/transformers/models/pytorch_transforms.py | 7 ------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 216f56cc7..8645e494d 100755 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -674,19 +674,21 @@ def _transform_tracking_root(module: torch.nn.Module) -> torch.nn.Module: architectures = getattr(model_config, "architectures", None) or [] is_deepseek_v3 = "DeepseekV3ForCausalLM" in architectures if qaic_config: - if is_deepseek_v3 and (qaic_config.get("blocking_mode", None) == "h"): + if qaic_config.get("blocking_mode", None) == "h": qaic_config["head_block_size"] = qaic_config.get("head_block_size", num_devices) num_kv_heads_repeat = qaic_config.get("num_kv_heads_repeat", 1) transform_root = _transform_tracking_root(self.model) applied_transforms = getattr(transform_root, "_qeff_runtime_transforms_applied", set()) - - if ReplicateKVHeadTransform.__name__ in applied_transforms: + should_apply_repeat_kv = is_deepseek_v3 or (num_kv_heads_repeat is not None and num_kv_heads_repeat > 1) + if not should_apply_repeat_kv: + replicate_kv_transformed = False + elif ReplicateKVHeadTransform.__name__ in applied_transforms: replicate_kv_transformed = False logger.warning("Skipping RepeatKVTransform: already applied on this model instance.") else: self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply( self.model, - num_kv_heads_repeat=num_kv_heads_repeat, + num_kv_heads_repeat, ) if replicate_kv_transformed: applied_transforms.add(ReplicateKVHeadTransform.__name__) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 53512683a..2a1f8ef2d 100755 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -1083,7 +1083,6 @@ def _duplicate_weights_for_mla_layer(layer: nn.Module, orig_kv_heads: int, repea new_kv_heads * dim ) - @staticmethod def _duplicate_weights_for_linear_layer( layer: nn.Module, orig_kv_heads: int, repeat: int, head_dim: int, hidden_size: int ): @@ -1300,12 +1299,6 @@ def apply(cls, model: nn.Module, num_kv_heads_repeat: Optional[int] = None, **kw else: kwargs.pop("num_kv_heads_repeat", None) n_repeat = num_kv_heads_repeat - # Validate n_repeat is a positive integer - if not isinstance(n_repeat, int) or n_repeat < 1: - raise ValueError( - f"num_kv_heads_repeat must be a positive integer, got: {n_repeat} (type: {type(n_repeat).__name__})" - ) - transformed = False if n_repeat is not None and n_repeat > 1: if (model.__class__ in cls._module_mapping) or (model.__class__.__name__ in cls._module_string_mapping): From 12ca65ed359b85f1fd3642deb59495bb0b717fa7 Mon Sep 17 00:00:00 2001 From: Dhiraj Kumar Sah Date: Wed, 3 Jun 2026 15:07:04 +0530 Subject: [PATCH 7/8] Addressed comments. Enabled method to calculate best possible repeat_kv count based on model and num devices. Added repeat_kv method for AWQ quantized models. Signed-off-by: Dhiraj Kumar Sah --- QEfficient/base/modeling_qeff.py | 20 +- .../transformers/models/modeling_auto.py | 3 - .../transformers/models/pytorch_transforms.py | 192 ++++++++++++------ QEfficient/utils/config_utils.py | 28 +++ QEfficient/utils/test_utils.py | 2 +- 5 files changed, 172 insertions(+), 73 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 8645e494d..63d0959fb 100755 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -47,6 +47,7 @@ require_value, to_named_specializations, ) +from QEfficient.utils.config_utils import calculate_num_kv_heads_repeat from QEfficient.utils.export_utils import export_wrapper logger = logging.getLogger(__name__) @@ -76,7 +77,6 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None: self.model = model self.config = model.config self.hash_params = create_model_params(self, **kwargs) - self.hash_params["num_kv_heads_repeat"] = kwargs.get("num_kv_heads_repeat", 1) self.onnx_path: Optional[str] = None self.qpc_path: Optional[str] = None self.qpc_session: Optional[QAICInferenceSession] = None @@ -669,17 +669,20 @@ def _transform_tracking_root(module: torch.nn.Module) -> torch.nn.Module: model_config = getattr(self.model, "config", None) or getattr( getattr(self.model, "model", None), "config", None ) + num_kv_heads_repeat = 1 + if model_config is not None: + num_kv_heads_repeat = calculate_num_kv_heads_repeat( + num_devices=num_devices, + text_model_config=model_config, + ) if model_config: - architectures = getattr(model_config, "architectures", None) or [] - is_deepseek_v3 = "DeepseekV3ForCausalLM" in architectures - if qaic_config: - if qaic_config.get("blocking_mode", None) == "h": - qaic_config["head_block_size"] = qaic_config.get("head_block_size", num_devices) - num_kv_heads_repeat = qaic_config.get("num_kv_heads_repeat", 1) + if qaic_config is not None: + num_kv_heads_repeat = qaic_config.get("num_kv_heads_repeat", num_kv_heads_repeat) + qaic_config["num_kv_heads_repeat"] = num_kv_heads_repeat transform_root = _transform_tracking_root(self.model) applied_transforms = getattr(transform_root, "_qeff_runtime_transforms_applied", set()) - should_apply_repeat_kv = is_deepseek_v3 or (num_kv_heads_repeat is not None and num_kv_heads_repeat > 1) + should_apply_repeat_kv = num_kv_heads_repeat is not None and num_kv_heads_repeat > 1 if not should_apply_repeat_kv: replicate_kv_transformed = False elif ReplicateKVHeadTransform.__name__ in applied_transforms: @@ -711,6 +714,7 @@ def _transform_tracking_root(module: torch.nn.Module) -> torch.nn.Module: if blocking_config is not None: self.model, _ = BlockingAttentionTransform.apply(self.model, attn_blocking_config=blocking_config) self.hash_params["blocking_kwargs"] = blocking_config + self.hash_params["num_kv_heads_repeat"] = num_kv_heads_repeat @dump_qconfig def _compile( diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 12381c513..87f029263 100755 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1290,7 +1290,6 @@ def __init__( self.ccl_enabled = qaic_config.get("ccl_enabled", False) self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None self.input_shapes, self.output_names = None, None - # self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs) # ---Sampling--- # Note: SamplerTransform should be applied after all other transforms # are done. The role of the sampler is to just add nodes at the output of the @@ -2186,7 +2185,6 @@ def __init__( self.model.config.text_config.use_cache = True else: self.model.config.use_cache = True - # self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs) self.hash_params["qeff_auto_class"] = self.__class__.__name__ self.ccl_enabled = False if qaic_config: @@ -3057,7 +3055,6 @@ def __init__( setattr(self.model, "mla_absorption", mla_absorption) self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None self.hash_params["max_seq_len_cached"] = max_seq_len_cached - # self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs) # ---Sampling--- # Note: SamplerTransform should be applied after all other transforms diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 2a1f8ef2d..84b0ed690 100755 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -303,7 +303,7 @@ ModuleMutatorTransform, ) from QEfficient.customop import CustomRMSNormAIC, GemmaCustomRMSNormAIC -from QEfficient.customop.matmulnbits import QuantLinearORT +from QEfficient.customop.matmulnbits import QuantLinearORT, dequantize_blockwise_bits from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function from QEfficient.transformers.models.bert.modeling_bert import ( QEffBertModel, @@ -1001,6 +1001,7 @@ class ReplicateKVHeadTransform(ModuleMutatorTransform): QEffMllamaForConditionalGeneration, QEffMistralForCausalLM, QEffMistral3ForConditionalGeneration, + QEffMixtralForCausalLM, QEffMptForCausalLM, QEffPhiForCausalLM, QEffPhi3ForCausalLM, @@ -1071,50 +1072,129 @@ def _is_mla_model(cls, text_model: nn.Module) -> bool: return True return False - @staticmethod - def _duplicate_weights_for_mla_layer(layer: nn.Module, orig_kv_heads: int, repeat: int, dim: int, hidden_size: int): - new_kv_heads = repeat * orig_kv_heads - layer.weight.data = torch.repeat_interleave( - layer.weight.data.view(orig_kv_heads, dim, hidden_size), repeat, 0 - ).view(new_kv_heads * dim, hidden_size) - - if layer.bias is not None: - layer.bias.data = torch.repeat_interleave(layer.bias.data.view(orig_kv_heads, dim), repeat, 0).view( - new_kv_heads * dim - ) - def _duplicate_weights_for_linear_layer( layer: nn.Module, orig_kv_heads: int, repeat: int, head_dim: int, hidden_size: int ): new_kv_heads = repeat * orig_kv_heads - if isinstance(layer, (WQLinear_GEMM, QuantLinearGPTQ, QuantLinearORT)): - if head_dim % 8 != 0: + if isinstance(layer, WQLinear_GEMM): + # AWQ layout: + # qweight: [in_features, out_features/pack] + # qzeros: [in_features/group_size, out_features/pack] + # scales: [in_features/group_size, out_features] + if layer.qweight.shape[1] % orig_kv_heads != 0: raise ValueError( - f"the value head_dim={head_dim} is not divisible by 8 which is according to the assumption that model is 4-bit quantized." + f"Invalid AWQ qweight shape for RepeatKV: qweight.shape={tuple(layer.qweight.shape)}, " + f"orig_kv_heads={orig_kv_heads}" ) - if hidden_size % layer.group_size != 0: + if layer.qzeros.shape[1] % orig_kv_heads != 0 or layer.scales.shape[1] % orig_kv_heads != 0: raise ValueError( - f"The value of hidden_size={hidden_size} is not divisible by k_proj.group_size={layer.group_size}" + f"Invalid AWQ qzeros/scales shape for RepeatKV: qzeros.shape={tuple(layer.qzeros.shape)}, " + f"scales.shape={tuple(layer.scales.shape)}, orig_kv_heads={orig_kv_heads}" ) - # Duplication of quantized weights layer.qweight.data = torch.repeat_interleave( - layer.qweight.data.view(hidden_size, orig_kv_heads, head_dim // 8), repeat, 1 - ).view(hidden_size, (new_kv_heads * head_dim) // 8) - # Duplication of quantized zero points + layer.qweight.data.view(layer.qweight.shape[0], orig_kv_heads, -1), repeat, 1 + ).view(layer.qweight.shape[0], -1) layer.qzeros.data = torch.repeat_interleave( - layer.qzeros.data.view(hidden_size // layer.group_size, orig_kv_heads, head_dim // 8), - repeat, - 1, - ).view(hidden_size // layer.group_size, (new_kv_heads * head_dim) // 8) - # Duplication of quantization scales + layer.qzeros.data.view(layer.qzeros.shape[0], orig_kv_heads, -1), repeat, 1 + ).view(layer.qzeros.shape[0], -1) layer.scales.data = torch.repeat_interleave( - layer.scales.data.view(hidden_size // layer.group_size, orig_kv_heads, head_dim), - repeat, - 1, - ).view(hidden_size // layer.group_size, new_kv_heads * head_dim) + layer.scales.data.view(layer.scales.shape[0], orig_kv_heads, -1), repeat, 1 + ).view(layer.scales.shape[0], -1) + layer.out_features = layer.out_features * repeat + + elif isinstance(layer, QuantLinearGPTQ): + # GPTQ layout: + # qweight: [in_features/pack, out_features] + # qzeros: [in_features/group_size, out_features/pack] + # scales: [in_features/group_size, out_features] + if layer.qweight.shape[1] % orig_kv_heads != 0: + raise ValueError( + f"Invalid GPTQ qweight shape for RepeatKV: qweight.shape={tuple(layer.qweight.shape)}, " + f"orig_kv_heads={orig_kv_heads}" + ) + if layer.qzeros.shape[1] % orig_kv_heads != 0 or layer.scales.shape[1] % orig_kv_heads != 0: + raise ValueError( + f"Invalid GPTQ qzeros/scales shape for RepeatKV: qzeros.shape={tuple(layer.qzeros.shape)}, " + f"scales.shape={tuple(layer.scales.shape)}, orig_kv_heads={orig_kv_heads}" + ) + + layer.qweight.data = torch.repeat_interleave( + layer.qweight.data.view(layer.qweight.shape[0], orig_kv_heads, -1), repeat, 1 + ).view(layer.qweight.shape[0], -1) + layer.qzeros.data = torch.repeat_interleave( + layer.qzeros.data.view(layer.qzeros.shape[0], orig_kv_heads, -1), repeat, 1 + ).view(layer.qzeros.shape[0], -1) + layer.scales.data = torch.repeat_interleave( + layer.scales.data.view(layer.scales.shape[0], orig_kv_heads, -1), repeat, 1 + ).view(layer.scales.shape[0], -1) layer.out_features = layer.out_features * repeat + elif isinstance(layer, QuantLinearORT): + # QuantLinearORT stores blockwise packed buffers. Dequantize, replicate per-KV-head, + # then re-pack using existing QuantLinearORT.pack path. + float_weight, zeros_per_group, scales_per_group = dequantize_blockwise_bits( + layer.qweight, + layer.scales, + layer.qzeros, + layer.bits, + layer.group_size, + layer.g_idx, + layer.in_features, + layer.out_features, + ) + # float_weight: [out_features, in_features] + if float_weight.shape[0] % orig_kv_heads != 0: + raise ValueError( + f"Invalid QuantLinearORT weight shape for RepeatKV: " + f"weight.shape={tuple(float_weight.shape)}, orig_kv_heads={orig_kv_heads}" + ) + + duplicated_weight = torch.repeat_interleave( + float_weight.view(orig_kv_heads, -1, float_weight.shape[1]), + repeat, + dim=0, + ).view(new_kv_heads * (float_weight.shape[0] // orig_kv_heads), float_weight.shape[1]) + + duplicated_zeros = torch.repeat_interleave( + zeros_per_group.view(orig_kv_heads, -1, zeros_per_group.shape[1]), + repeat, + dim=0, + ).view(new_kv_heads * (zeros_per_group.shape[0] // orig_kv_heads), zeros_per_group.shape[1]) + duplicated_scales = torch.repeat_interleave( + scales_per_group.view(orig_kv_heads, -1, scales_per_group.shape[1]), + repeat, + dim=0, + ).view(new_kv_heads * (scales_per_group.shape[0] // orig_kv_heads), scales_per_group.shape[1]) + + original_out_features = layer.out_features + layer.out_features = original_out_features * repeat + q_rows = layer.in_features // layer.group_size + layer.qweight = torch.zeros( + (layer.out_features, q_rows, layer.group_size // (8 // layer.bits)), + dtype=layer.qweight.dtype, + device=layer.qweight.device, + ) + layer.qzeros = torch.zeros( + (q_rows + (q_rows & 1)) * (layer.out_features // 8 * layer.bits), + dtype=layer.qzeros.dtype, + device=layer.qzeros.device, + ) + layer.scales = torch.zeros( + (q_rows * layer.out_features), + dtype=layer.scales.dtype, + device=layer.scales.device, + ) + + linear = nn.Linear(layer.in_features, layer.out_features, bias=False, dtype=duplicated_weight.dtype) + linear.weight.data = duplicated_weight.to(linear.weight.dtype) + layer.pack( + linear, + duplicated_scales.contiguous().to(layer.scales.dtype), + duplicated_zeros.contiguous().to(torch.int32), + layer.g_idx, + ) + elif isinstance(layer, FP8DeQuantLinear): layer.weight.data = torch.repeat_interleave( layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0 @@ -1210,6 +1290,7 @@ def mutate(cls, original_module: nn.Module, parent_module: nn.Module, n_repeat: Returns: The mutated module (same object, modified in-place). """ + # breakpoint() replication_root = cls._get_replication_root(original_module) if getattr(replication_root, "_qeff_kv_replication_applied", False): logger.warning("KV head replication already applied for this model instance; skipping.") @@ -1217,10 +1298,13 @@ def mutate(cls, original_module: nn.Module, parent_module: nn.Module, n_repeat: text_model = cls._get_text_model(original_module) cfg = text_model.config + if cls._is_mla_model(text_model): + logger.warning("Skipping RepeatKVTransform: MLA models don't apply replicate KV changes.") + return original_module + orig_kv_heads = resolve_kv_heads(cfg) num_attention_heads = resolve_attention_heads(cfg) hidden_size = resolve_hidden_size(cfg) - is_mla_model = cls._is_mla_model(text_model) if orig_kv_heads is None or num_attention_heads is None or hidden_size is None: raise ValueError( @@ -1233,12 +1317,8 @@ def mutate(cls, original_module: nn.Module, parent_module: nn.Module, n_repeat: f"Invalid head values for RepeatKV transform: " f"num_attention_heads={num_attention_heads}, num_key_value_heads={orig_kv_heads}" ) - if is_mla_model: - # Legacy MLA path treats compressed-KV projection as single KV head. - orig_kv_heads = 1 - new_kv_heads = n_repeat * orig_kv_heads - if (not is_mla_model) and (new_kv_heads > num_attention_heads or (num_attention_heads % new_kv_heads) != 0): + if new_kv_heads > num_attention_heads or (num_attention_heads % new_kv_heads) != 0: raise ValueError( f"Invalid RepeatKV configuration: num_attention_heads={num_attention_heads}, " f"orig_kv_heads={orig_kv_heads}, num_kv_heads_repeat={n_repeat}, new_kv_heads={new_kv_heads}. " @@ -1257,29 +1337,16 @@ def mutate(cls, original_module: nn.Module, parent_module: nn.Module, n_repeat: if hasattr(attn, "n_kv_heads"): attn.n_kv_heads = new_kv_heads - if cls._is_mla_attention(attn): - # Legacy MLA support: KV compression projection is organized as - # [kv_heads, kv_lora_rank + qk_rope_head_dim, hidden_size]. - mla_orig_kv_heads = 1 - mla_head_dim = int(attn.kv_lora_rank + attn.qk_rope_head_dim) - cls._duplicate_weights_for_mla_layer( - attn.kv_a_proj_with_mqa, - mla_orig_kv_heads, - n_repeat, - mla_head_dim, - hidden_size, - ) - else: - n_kv_groups = num_attention_heads // new_kv_heads - if hasattr(attn, "num_key_value_groups"): - attn.num_key_value_groups = n_kv_groups - if hasattr(attn, "n_kv_groups"): - attn.n_kv_groups = n_kv_groups - head_dim = getattr(attn, "head_dim", hidden_size // num_attention_heads) - k_proj = cls._get_projection_layer(attn, ("k_proj", "key_proj")) - v_proj = cls._get_projection_layer(attn, ("v_proj", "value_proj")) - cls._duplicate_weights_for_linear_layer(k_proj, orig_kv_heads, n_repeat, head_dim, hidden_size) - cls._duplicate_weights_for_linear_layer(v_proj, orig_kv_heads, n_repeat, head_dim, hidden_size) + n_kv_groups = num_attention_heads // new_kv_heads + if hasattr(attn, "num_key_value_groups"): + attn.num_key_value_groups = n_kv_groups + if hasattr(attn, "n_kv_groups"): + attn.n_kv_groups = n_kv_groups + head_dim = getattr(attn, "head_dim", hidden_size // num_attention_heads) + k_proj = cls._get_projection_layer(attn, ("k_proj", "key_proj")) + v_proj = cls._get_projection_layer(attn, ("v_proj", "value_proj")) + cls._duplicate_weights_for_linear_layer(k_proj, orig_kv_heads, n_repeat, head_dim, hidden_size) + cls._duplicate_weights_for_linear_layer(v_proj, orig_kv_heads, n_repeat, head_dim, hidden_size) setattr(replication_root, "_qeff_kv_replication_applied", True) return original_module @@ -1302,8 +1369,11 @@ def apply(cls, model: nn.Module, num_kv_heads_repeat: Optional[int] = None, **kw transformed = False if n_repeat is not None and n_repeat > 1: if (model.__class__ in cls._module_mapping) or (model.__class__.__name__ in cls._module_string_mapping): + transform_root = cls._get_replication_root(model) + was_applied = getattr(transform_root, "_qeff_kv_replication_applied", False) cls.mutate(model, None, n_repeat) - transformed = True + is_applied = getattr(transform_root, "_qeff_kv_replication_applied", False) + transformed = (not was_applied) and is_applied else: raise NotImplementedError( f"Model class {model.__class__.__name__} is not supported for KV head replication." diff --git a/QEfficient/utils/config_utils.py b/QEfficient/utils/config_utils.py index 1e6125d7e..215e92b59 100644 --- a/QEfficient/utils/config_utils.py +++ b/QEfficient/utils/config_utils.py @@ -38,3 +38,31 @@ def set_kv_head_aliases(config, value: int): for key in KV_HEAD_CONFIG_KEYS: if hasattr(config, key): setattr(config, key, value) + + +def calculate_num_kv_heads_repeat(num_devices: int, text_model_config) -> int: + """ + Choose a KV-repeat value from model config and device count. + + Primary criteria: + 1. num_kv_heads * repeat is divisible by num_devices + 2. num_attention_heads is divisible by (num_kv_heads * repeat) + + Fallback: + repeat = num_attention_heads / num_kv_heads (integer-truncated if needed). + """ + num_attention_heads = resolve_attention_heads(text_model_config) + num_kv_heads = resolve_kv_heads(text_model_config) + + if num_attention_heads is None or num_kv_heads is None or num_attention_heads < 1 or num_kv_heads < 1: + return 1 + + num_devices = max(1, int(num_devices)) + max_repeat = max(1, int(num_attention_heads / num_kv_heads)) + + for repeat in range(max_repeat, 0, -1): + repeated_kv_heads = num_kv_heads * repeat + if (repeated_kv_heads % num_devices == 0) and (num_attention_heads % repeated_kv_heads == 0): + return repeat + + return 1 diff --git a/QEfficient/utils/test_utils.py b/QEfficient/utils/test_utils.py index 371202111..34bc474e5 100644 --- a/QEfficient/utils/test_utils.py +++ b/QEfficient/utils/test_utils.py @@ -509,7 +509,7 @@ class ModelConfig: "meta-llama/Llama-3.2-1B", # "unsloth/gemma-2b", # "unsloth/gemma-2-2b", - # "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", + "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", "TheBloke/Llama-2-7B-GPTQ", "neuralmagic/Llama-3.2-3B-Instruct-FP8", "ibm-granite/granite-3.1-2b-instruct", From 008adc73ee2585d4af9b7ad165f5b009f8b1087d Mon Sep 17 00:00:00 2001 From: Dhiraj Kumar Sah Date: Fri, 5 Jun 2026 13:47:34 +0530 Subject: [PATCH 8/8] Renamed num_kv_heads_repeat to num_replicate_kv_heads as suggested. Signed-off-by: Dhiraj Kumar Sah --- QEfficient/base/modeling_qeff.py | 16 ++++++------- .../transformers/models/modeling_auto.py | 24 +++++++++---------- .../transformers/models/pytorch_transforms.py | 14 +++++------ QEfficient/utils/config_utils.py | 2 +- examples/kimi_k2/README.md | 4 ++-- examples/kimi_k2/export_kimik2.py | 8 +++---- examples/text_generation/run_kimik2.py | 8 +++---- .../causal_lm_models/check_causal_models.py | 4 ++-- .../test_image_text_to_text_models.py | 16 ++++++------- 9 files changed, 48 insertions(+), 48 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 63d0959fb..defb44956 100755 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -47,7 +47,7 @@ require_value, to_named_specializations, ) -from QEfficient.utils.config_utils import calculate_num_kv_heads_repeat +from QEfficient.utils.config_utils import calculate_num_replicate_kv_heads from QEfficient.utils.export_utils import export_wrapper logger = logging.getLogger(__name__) @@ -669,20 +669,20 @@ def _transform_tracking_root(module: torch.nn.Module) -> torch.nn.Module: model_config = getattr(self.model, "config", None) or getattr( getattr(self.model, "model", None), "config", None ) - num_kv_heads_repeat = 1 + num_replicate_kv_heads = 1 if model_config is not None: - num_kv_heads_repeat = calculate_num_kv_heads_repeat( + num_replicate_kv_heads = calculate_num_replicate_kv_heads( num_devices=num_devices, text_model_config=model_config, ) if model_config: if qaic_config is not None: - num_kv_heads_repeat = qaic_config.get("num_kv_heads_repeat", num_kv_heads_repeat) - qaic_config["num_kv_heads_repeat"] = num_kv_heads_repeat + num_replicate_kv_heads = qaic_config.get("num_replicate_kv_heads", num_replicate_kv_heads) + qaic_config["num_replicate_kv_heads"] = num_replicate_kv_heads transform_root = _transform_tracking_root(self.model) applied_transforms = getattr(transform_root, "_qeff_runtime_transforms_applied", set()) - should_apply_repeat_kv = num_kv_heads_repeat is not None and num_kv_heads_repeat > 1 + should_apply_repeat_kv = num_replicate_kv_heads is not None and num_replicate_kv_heads > 1 if not should_apply_repeat_kv: replicate_kv_transformed = False elif ReplicateKVHeadTransform.__name__ in applied_transforms: @@ -691,7 +691,7 @@ def _transform_tracking_root(module: torch.nn.Module) -> torch.nn.Module: else: self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply( self.model, - num_kv_heads_repeat, + num_replicate_kv_heads, ) if replicate_kv_transformed: applied_transforms.add(ReplicateKVHeadTransform.__name__) @@ -714,7 +714,7 @@ def _transform_tracking_root(module: torch.nn.Module) -> torch.nn.Module: if blocking_config is not None: self.model, _ = BlockingAttentionTransform.apply(self.model, attn_blocking_config=blocking_config) self.hash_params["blocking_kwargs"] = blocking_config - self.hash_params["num_kv_heads_repeat"] = num_kv_heads_repeat + self.hash_params["num_replicate_kv_heads"] = num_replicate_kv_heads @dump_qconfig def _compile( diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 87f029263..e2cccdf96 100755 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1326,7 +1326,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Option kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) _resolve_torch_dtype(kwargs) - num_kv_heads_repeat = kwargs.pop("num_kv_heads_repeat", 1) + num_replicate_kv_heads = kwargs.pop("num_replicate_kv_heads", 1) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {}) @@ -1335,7 +1335,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Option model, pretrained_model_name_or_path=pretrained_model_name_or_path, qaic_config=qaic_config, - num_kv_heads_repeat=num_kv_heads_repeat, + num_replicate_kv_heads=num_replicate_kv_heads, **kwargs, ) @@ -1435,12 +1435,12 @@ def export( if prefill_only and prefill_seq_len > 1: offload_pt_weights = False # to keep weight for decode onnx else: - num_kv_heads_repeat = ( - (self.lang_model.model.qaic_config or {}).get("num_kv_heads_repeat", 1) + num_replicate_kv_heads = ( + (self.lang_model.model.qaic_config or {}).get("num_replicate_kv_heads", 1) if hasattr(self.lang_model.model, "qaic_config") else 1 ) - offload_pt_weights = kwargs.get("offload_pt_weights", num_kv_heads_repeat <= 1) + offload_pt_weights = kwargs.get("offload_pt_weights", num_replicate_kv_heads <= 1) if not skip_lang and self.lang_model.onnx_path is None: self.lang_model.export( @@ -2235,7 +2235,7 @@ def from_pretrained( config._attn_implementation = "eager" config.vision_config.use_flash_attn = "false" _resolve_torch_dtype(kwargs) - num_kv_heads_repeat = kwargs.pop("num_kv_heads_repeat", 1) + num_replicate_kv_heads = kwargs.pop("num_replicate_kv_heads", 1) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, config, *args, **kwargs) kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {}) @@ -2244,7 +2244,7 @@ def from_pretrained( model, pretrained_model_name_or_path=pretrained_model_name_or_path, qaic_config=qaic_config, - num_kv_heads_repeat=num_kv_heads_repeat, + num_replicate_kv_heads=num_replicate_kv_heads, **kwargs, ) @@ -2884,7 +2884,7 @@ def from_pretrained( kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) _resolve_torch_dtype(kwargs) - num_kv_heads_repeat = kwargs.pop("num_kv_heads_repeat", 1) + num_replicate_kv_heads = kwargs.pop("num_replicate_kv_heads", 1) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {}) @@ -2895,7 +2895,7 @@ def from_pretrained( continuous_batching=continuous_batching, pretrained_model_name_or_path=pretrained_model_name_or_path, qaic_config=qaic_config, - num_kv_heads_repeat=num_kv_heads_repeat, + num_replicate_kv_heads=num_replicate_kv_heads, **kwargs, ) @@ -3140,7 +3140,7 @@ def from_pretrained( kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) _resolve_torch_dtype(kwargs) - num_kv_heads_repeat = kwargs.pop("num_kv_heads_repeat", 1) + num_replicate_kv_heads = kwargs.pop("num_replicate_kv_heads", 1) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) if qaic_config is not None: qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path @@ -3154,7 +3154,7 @@ def from_pretrained( pretrained_model_name_or_path=pretrained_model_name_or_path, qaic_config=qaic_config, continuous_batching=continuous_batching, - num_kv_heads_repeat=num_kv_heads_repeat, + num_replicate_kv_heads=num_replicate_kv_heads, **kwargs, ) return cls( @@ -3163,7 +3163,7 @@ def from_pretrained( qaic_config=qaic_config, pretrained_model_name_or_path=pretrained_model_name_or_path, max_seq_len_cached=max_seq_len_cached, - num_kv_heads_repeat=num_kv_heads_repeat, + num_replicate_kv_heads=num_replicate_kv_heads, **kwargs, ) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 84b0ed690..fe70fb055 100755 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -1321,7 +1321,7 @@ def mutate(cls, original_module: nn.Module, parent_module: nn.Module, n_repeat: if new_kv_heads > num_attention_heads or (num_attention_heads % new_kv_heads) != 0: raise ValueError( f"Invalid RepeatKV configuration: num_attention_heads={num_attention_heads}, " - f"orig_kv_heads={orig_kv_heads}, num_kv_heads_repeat={n_repeat}, new_kv_heads={new_kv_heads}. " + f"orig_kv_heads={orig_kv_heads}, num_replicate_kv_heads={n_repeat}, new_kv_heads={new_kv_heads}. " "Expected new_kv_heads <= num_attention_heads and divisibility." ) @@ -1352,20 +1352,20 @@ def mutate(cls, original_module: nn.Module, parent_module: nn.Module, n_repeat: return original_module @classmethod - def apply(cls, model: nn.Module, num_kv_heads_repeat: Optional[int] = None, **kwargs) -> Tuple[nn.Module, bool]: + def apply(cls, model: nn.Module, num_replicate_kv_heads: Optional[int] = None, **kwargs) -> Tuple[nn.Module, bool]: """ Replicates KV heads in attention modules based on provided multiplier. Args: model: The model to apply the transform to. kwargs: Additional arguments for the transformation. Includes: - - num_kv_heads_repeat: The number of times to repeat the KV heads. + - num_replicate_kv_heads: The number of times to repeat the KV heads. """ - if num_kv_heads_repeat is None: - n_repeat = kwargs.pop("num_kv_heads_repeat", 1) + if num_replicate_kv_heads is None: + n_repeat = kwargs.pop("num_replicate_kv_heads", 1) else: - kwargs.pop("num_kv_heads_repeat", None) - n_repeat = num_kv_heads_repeat + kwargs.pop("num_replicate_kv_heads", None) + n_repeat = num_replicate_kv_heads transformed = False if n_repeat is not None and n_repeat > 1: if (model.__class__ in cls._module_mapping) or (model.__class__.__name__ in cls._module_string_mapping): diff --git a/QEfficient/utils/config_utils.py b/QEfficient/utils/config_utils.py index 215e92b59..4b28d5488 100644 --- a/QEfficient/utils/config_utils.py +++ b/QEfficient/utils/config_utils.py @@ -40,7 +40,7 @@ def set_kv_head_aliases(config, value: int): setattr(config, key, value) -def calculate_num_kv_heads_repeat(num_devices: int, text_model_config) -> int: +def calculate_num_replicate_kv_heads(num_devices: int, text_model_config) -> int: """ Choose a KV-repeat value from model config and device count. diff --git a/examples/kimi_k2/README.md b/examples/kimi_k2/README.md index 230127ebb..4fae4a8cf 100644 --- a/examples/kimi_k2/README.md +++ b/examples/kimi_k2/README.md @@ -20,9 +20,9 @@ mla_absorption has 3 keys: # Blocking We have also implemented KV head replication, HEAD Blocking and KV Blocking which can be enable like this : - For No Blocking : qaic_config = {"mla_absorption" : mla_absorption} -- For No blocking with kv head replication : qaic_config = {"mla_absorption" : mla_absorption, "num_kv_heads_repeat": TS} +- For No blocking with kv head replication : qaic_config = {"mla_absorption" : mla_absorption, "num_replicate_kv_heads": TS} - For KV blocking : qaic_config = {"mla_absorption" : mla_absorption, "enable_blocking": True, "blocking_mode": "kv"} # for KV blocking -- For Head Blocking : qaic_config = {"mla_absorption" : mla_absorption, "enable_blocking": True, "blocking_mode": "h", "num_kv_heads_repeat": TS} for h blocking, it internally sets head_block_size equal to num_devices/num_kv_heads_repeat +- For Head Blocking : qaic_config = {"mla_absorption" : mla_absorption, "enable_blocking": True, "blocking_mode": "h", "num_replicate_kv_heads": TS} for h blocking, it internally sets head_block_size equal to num_devices/num_replicate_kv_heads - Currently Decode-Only model is giving best perf with Head Blocking and compressed cache. - Contnuous batching is not enabled yet. \ No newline at end of file diff --git a/examples/kimi_k2/export_kimik2.py b/examples/kimi_k2/export_kimik2.py index 1e7035216..ba6b26c06 100644 --- a/examples/kimi_k2/export_kimik2.py +++ b/examples/kimi_k2/export_kimik2.py @@ -18,16 +18,16 @@ # qaic_config = None # Full PKV Cache # qaic_config = {"enable_blocking": True, "blocking_mode": "h"} # Full PKV Cache with Head Blocking # qaic_config = {"mla_absorption": mla_absorption} # for No Blocking -# qaic_config = {"mla_absorption": mla_absorption, "num_kv_heads_repeat": TS} # No blocking with kv head replication +# qaic_config = {"mla_absorption": mla_absorption, "num_replicate_kv_heads": TS} # No blocking with kv head replication # qaic_config = {"mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "kv"} # for KV blocking -# qaic_config = {"mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "kv", "num_kv_heads_repeat":TS} # for KV blocking with kv head replication +# qaic_config = {"mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "kv", "num_replicate_kv_heads":TS} # for KV blocking with kv head replication qaic_config = { "mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "h", - "num_kv_heads_repeat": TS, + "num_replicate_kv_heads": TS, } -# for h blocking, it internally sets head_block_size equal to num_devices/num_kv_heads_repeat +# for h blocking, it internally sets head_block_size equal to num_devices/num_replicate_kv_heads model_name = "moonshotai/Kimi-K2-Thinking" model = AutoModelForCausalLM.from_pretrained( diff --git a/examples/text_generation/run_kimik2.py b/examples/text_generation/run_kimik2.py index 81767308a..e85c57242 100644 --- a/examples/text_generation/run_kimik2.py +++ b/examples/text_generation/run_kimik2.py @@ -19,16 +19,16 @@ # qaic_config = None # Full PKV Cache # qaic_config = {"enable_blocking": True, "blocking_mode": "h"} # Full PKV Cache with Head Blocking # qaic_config = {"mla_absorption": mla_absorption} # for No Blocking -# qaic_config = {"mla_absorption": mla_absorption, "num_kv_heads_repeat": TS} # No blocking with kv head replication +# qaic_config = {"mla_absorption": mla_absorption, "num_replicate_kv_heads": TS} # No blocking with kv head replication # qaic_config = {"mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "kv"} # for KV blocking -# qaic_config = {"mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "kv", "num_kv_heads_repeat":TS} # for KV blocking with kv head replication +# qaic_config = {"mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "kv", "num_replicate_kv_heads":TS} # for KV blocking with kv head replication qaic_config = { "mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "h", - "num_kv_heads_repeat": TS, + "num_replicate_kv_heads": TS, } -# for h blocking, it internally sets head_block_size equal to num_devices/num_kv_heads_repeat +# for h blocking, it internally sets head_block_size equal to num_devices/num_replicate_kv_heads model_name = "moonshotai/Kimi-K2-Thinking" model = AutoModelForCausalLM.from_pretrained( diff --git a/tests/transformers/models/causal_lm_models/check_causal_models.py b/tests/transformers/models/causal_lm_models/check_causal_models.py index e604cb72f..78ff74cbf 100644 --- a/tests/transformers/models/causal_lm_models/check_causal_models.py +++ b/tests/transformers/models/causal_lm_models/check_causal_models.py @@ -73,7 +73,7 @@ def check_kv_repeat_causal_lm_pytorch_vs_ai100( f"Invalid heads in config for RepeatKV: num_attention_heads ({num_attention_heads}) " f"is not divisible by num_key_value_heads ({num_key_value_heads})." ) - num_kv_heads_repeat = num_attention_heads // num_key_value_heads + num_replicate_kv_heads = num_attention_heads // num_key_value_heads check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( model_name=model_name, @@ -82,7 +82,7 @@ def check_kv_repeat_causal_lm_pytorch_vs_ai100( ctx_len=ctx_len, n_layer=n_layer, config=config, - qaic_config={"num_kv_heads_repeat": num_kv_heads_repeat}, + qaic_config={"num_replicate_kv_heads": num_replicate_kv_heads}, ) diff --git a/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py b/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py index 1495ffb0b..df9c3b9e8 100644 --- a/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py +++ b/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py @@ -58,7 +58,7 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( qnn_config: Optional[str] = None, config: Optional[AutoConfig] = None, qaic_config: Optional[dict] = None, - num_kv_heads_repeat: Optional[int] = 1, + num_replicate_kv_heads: Optional[int] = 1, test_kv_replicate: Optional[bool] = None, torch_dtype: Optional[torch.dtype] = torch.float32, compare_results: Optional[bool] = False, @@ -82,9 +82,9 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( config = set_num_layers_vlm(config, n_layer=n_layer) if test_kv_replicate: text_config = get_text_config(config) - num_kv_heads_repeat = text_config.num_attention_heads // text_config.num_key_value_heads + num_replicate_kv_heads = text_config.num_attention_heads // text_config.num_key_value_heads qaic_config = qaic_config or {} - qaic_config["num_kv_heads_repeat"] = num_kv_heads_repeat + qaic_config["num_replicate_kv_heads"] = num_replicate_kv_heads if hasattr(config, "model_type") and config.model_type in ["gemma3"]: config.text_config._sliding_window_pattern = 2 config.text_config.layer_types = ["sliding_attention", "full_attention"] @@ -104,7 +104,7 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( config=config, qaic_config=qaic_config, torch_dtype=torch_dtype, - num_kv_heads_repeat=num_kv_heads_repeat, + num_replicate_kv_heads=num_replicate_kv_heads, ) else: model_hf = load_vlm_model(config) @@ -114,14 +114,14 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( config=config, qaic_config=qaic_config, torch_dtype=torch_dtype, - num_kv_heads_repeat=num_kv_heads_repeat, + num_replicate_kv_heads=num_replicate_kv_heads, ) else: if test_kv_replicate: text_config = get_text_config(config) - num_kv_heads_repeat = text_config.num_attention_heads // text_config.num_key_value_heads + num_replicate_kv_heads = text_config.num_attention_heads // text_config.num_key_value_heads qaic_config = qaic_config or {} - qaic_config["num_kv_heads_repeat"] = num_kv_heads_repeat + qaic_config["num_replicate_kv_heads"] = num_replicate_kv_heads model_hf = load_vlm_model_from_config(config) qeff_model = QEFFAutoModelForImageTextToText( copy.deepcopy(model_hf), @@ -129,7 +129,7 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( config=model_hf.config, qaic_config=qaic_config, torch_dtype=torch_dtype, - num_kv_heads_repeat=num_kv_heads_repeat, + num_replicate_kv_heads=num_replicate_kv_heads, ) compile_kwargs = { "num_devices": num_devices,