diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 17b87afd14..defb449561 100755 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -47,6 +47,7 @@ require_value, to_named_specializations, ) +from QEfficient.utils.config_utils import calculate_num_replicate_kv_heads from QEfficient.utils.export_utils import export_wrapper logger = logging.getLogger(__name__) @@ -655,23 +656,48 @@ def transform( **compiler_options, ): # Apply the transformations that are dependent on compilation parameters + def _transform_tracking_root(module: torch.nn.Module) -> torch.nn.Module: + """ + Use the shared wrapped model as transform-tracking root when available. + This lets encoder/decoder wrappers coordinate one-time transforms. + """ + wrapped = getattr(module, "model", None) + return wrapped if isinstance(wrapped, torch.nn.Module) else module qaic_config = qaic_config if qaic_config else getattr(self.model, "qaic_config", None) - model_config = getattr(self.model, "config", None) or getattr(self.model.model, "config", None) + model_config = getattr(self.model, "config", None) or getattr( + getattr(self.model, "model", None), "config", None + ) + num_replicate_kv_heads = 1 + if model_config is not None: + num_replicate_kv_heads = calculate_num_replicate_kv_heads( + num_devices=num_devices, + text_model_config=model_config, + ) if model_config: - if "DeepseekV3ForCausalLM" in (getattr(model_config, "architectures", None) or []): - if qaic_config: - if qaic_config.get("blocking_mode", None) == "h": - qaic_config["head_block_size"] = qaic_config.get("head_block_size", num_devices) - num_kv_heads_repeat = qaic_config.get("num_kv_heads_repeat", 1) + if qaic_config is not None: + num_replicate_kv_heads = qaic_config.get("num_replicate_kv_heads", num_replicate_kv_heads) + qaic_config["num_replicate_kv_heads"] = num_replicate_kv_heads + transform_root = _transform_tracking_root(self.model) + applied_transforms = getattr(transform_root, "_qeff_runtime_transforms_applied", set()) + should_apply_repeat_kv = num_replicate_kv_heads is not None and num_replicate_kv_heads > 1 + if not should_apply_repeat_kv: + replicate_kv_transformed = False + elif ReplicateKVHeadTransform.__name__ in applied_transforms: + replicate_kv_transformed = False + logger.warning("Skipping RepeatKVTransform: already applied on this model instance.") + else: self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply( - self.model, num_kv_heads_repeat + self.model, + num_replicate_kv_heads, ) if replicate_kv_transformed: - self.hash_params["config"] = self.model.config.to_diff_dict() - + applied_transforms.add(ReplicateKVHeadTransform.__name__) + setattr(transform_root, "_qeff_runtime_transforms_applied", applied_transforms) + if replicate_kv_transformed: + self.hash_params["config"] = self.model.config.to_diff_dict() blocking_config = build_transformer_blocking_config_for_transform( model_config, ctx_len=ctx_len, @@ -688,6 +714,7 @@ def transform( if blocking_config is not None: self.model, _ = BlockingAttentionTransform.apply(self.model, attn_blocking_config=blocking_config) self.hash_params["blocking_kwargs"] = blocking_config + self.hash_params["num_replicate_kv_heads"] = num_replicate_kv_heads @dump_qconfig def _compile( diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index a3e9257a73..524a220811 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -626,6 +626,7 @@ class QEffGemma3EncoderWrapper(nn.Module): def __init__(self, model): super().__init__() self.model = model.model + self.config = self.model.config self.model.vision_model = self.model.vision_tower def get_submodules_for_export(self) -> Type[nn.Module]: diff --git a/QEfficient/transformers/models/internvl/modeling_internvl.py b/QEfficient/transformers/models/internvl/modeling_internvl.py index 563c42e256..7a0b7d524d 100644 --- a/QEfficient/transformers/models/internvl/modeling_internvl.py +++ b/QEfficient/transformers/models/internvl/modeling_internvl.py @@ -20,6 +20,7 @@ class QEffInternEncoderWrapper(nn.Module): def __init__(self, model): super().__init__() self.model = model + self.config = self.model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py index 2cf5dbb2e9..c2c4b8ad7e 100644 --- a/QEfficient/transformers/models/llama4/modeling_llama4.py +++ b/QEfficient/transformers/models/llama4/modeling_llama4.py @@ -831,6 +831,7 @@ class QEffLlama4EncoderWrapper(nn.Module): def __init__(self, model): super().__init__() self.model = model + self.config = self.model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ diff --git a/QEfficient/transformers/models/llava/modeling_llava.py b/QEfficient/transformers/models/llava/modeling_llava.py index 88bb5e1027..a4005497bd 100644 --- a/QEfficient/transformers/models/llava/modeling_llava.py +++ b/QEfficient/transformers/models/llava/modeling_llava.py @@ -29,6 +29,7 @@ def __init__(self, model): super().__init__() self.model = model self.model.vision_model = self.model.model.vision_tower + self.config = self.model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ diff --git a/QEfficient/transformers/models/llava_next/modeling_llava_next.py b/QEfficient/transformers/models/llava_next/modeling_llava_next.py index 342269ce50..43adfe7c5b 100755 --- a/QEfficient/transformers/models/llava_next/modeling_llava_next.py +++ b/QEfficient/transformers/models/llava_next/modeling_llava_next.py @@ -29,6 +29,7 @@ def __init__(self, model): super().__init__() self.model = model self.model.vision_model = self.model.model.vision_tower + self.config = self.model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ diff --git a/QEfficient/transformers/models/mistral3/modeling_mistral3.py b/QEfficient/transformers/models/mistral3/modeling_mistral3.py index 628d1dee2c..3406791b70 100644 --- a/QEfficient/transformers/models/mistral3/modeling_mistral3.py +++ b/QEfficient/transformers/models/mistral3/modeling_mistral3.py @@ -183,6 +183,7 @@ class QEFFMistral3EncoderWrapper(nn.Module): def __init__(self, model): super().__init__() self.model = model + self.config = self.model.config self.model.model.vision_model = self.model.model.vision_tower def get_submodules_for_export(self) -> Type[nn.Module]: diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 57689ede68..e2cccdf967 100755 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1326,6 +1326,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Option kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) _resolve_torch_dtype(kwargs) + num_replicate_kv_heads = kwargs.pop("num_replicate_kv_heads", 1) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {}) @@ -1334,6 +1335,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Option model, pretrained_model_name_or_path=pretrained_model_name_or_path, qaic_config=qaic_config, + num_replicate_kv_heads=num_replicate_kv_heads, **kwargs, ) @@ -1433,7 +1435,12 @@ def export( if prefill_only and prefill_seq_len > 1: offload_pt_weights = False # to keep weight for decode onnx else: - offload_pt_weights = kwargs.get("offload_pt_weights", True) + num_replicate_kv_heads = ( + (self.lang_model.model.qaic_config or {}).get("num_replicate_kv_heads", 1) + if hasattr(self.lang_model.model, "qaic_config") + else 1 + ) + offload_pt_weights = kwargs.get("offload_pt_weights", num_replicate_kv_heads <= 1) if not skip_lang and self.lang_model.onnx_path is None: self.lang_model.export( @@ -2228,6 +2235,7 @@ def from_pretrained( config._attn_implementation = "eager" config.vision_config.use_flash_attn = "false" _resolve_torch_dtype(kwargs) + num_replicate_kv_heads = kwargs.pop("num_replicate_kv_heads", 1) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, config, *args, **kwargs) kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {}) @@ -2236,6 +2244,7 @@ def from_pretrained( model, pretrained_model_name_or_path=pretrained_model_name_or_path, qaic_config=qaic_config, + num_replicate_kv_heads=num_replicate_kv_heads, **kwargs, ) @@ -2875,6 +2884,7 @@ def from_pretrained( kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) _resolve_torch_dtype(kwargs) + num_replicate_kv_heads = kwargs.pop("num_replicate_kv_heads", 1) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {}) @@ -2885,6 +2895,7 @@ def from_pretrained( continuous_batching=continuous_batching, pretrained_model_name_or_path=pretrained_model_name_or_path, qaic_config=qaic_config, + num_replicate_kv_heads=num_replicate_kv_heads, **kwargs, ) @@ -3129,6 +3140,7 @@ def from_pretrained( kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) _resolve_torch_dtype(kwargs) + num_replicate_kv_heads = kwargs.pop("num_replicate_kv_heads", 1) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) if qaic_config is not None: qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path @@ -3142,6 +3154,7 @@ def from_pretrained( pretrained_model_name_or_path=pretrained_model_name_or_path, qaic_config=qaic_config, continuous_batching=continuous_batching, + num_replicate_kv_heads=num_replicate_kv_heads, **kwargs, ) return cls( @@ -3150,6 +3163,7 @@ def from_pretrained( qaic_config=qaic_config, pretrained_model_name_or_path=pretrained_model_name_or_path, max_seq_len_cached=max_seq_len_cached, + num_replicate_kv_heads=num_replicate_kv_heads, **kwargs, ) diff --git a/QEfficient/transformers/models/molmo/modeling_molmo.py b/QEfficient/transformers/models/molmo/modeling_molmo.py index d59ca4e017..b673d9e060 100644 --- a/QEfficient/transformers/models/molmo/modeling_molmo.py +++ b/QEfficient/transformers/models/molmo/modeling_molmo.py @@ -565,6 +565,7 @@ class QEffMolmoEncoderWrapper(nn.Module): def __init__(self, model): super().__init__() self.model = model + self.config = self.model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index b2b447a780..fe70fb0551 100755 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -297,8 +297,13 @@ ) from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaModel -from QEfficient.base.pytorch_transforms import ExternalModuleMapperTransform, ModuleMappingTransform +from QEfficient.base.pytorch_transforms import ( + ExternalModuleMapperTransform, + ModuleMappingTransform, + ModuleMutatorTransform, +) from QEfficient.customop import CustomRMSNormAIC, GemmaCustomRMSNormAIC +from QEfficient.customop.matmulnbits import QuantLinearORT, dequantize_blockwise_bits from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function from QEfficient.transformers.models.bert.modeling_bert import ( QEffBertModel, @@ -627,8 +632,18 @@ QEffWhisperPositionalEmbedding, ) from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry +from QEfficient.transformers.quantizers.awq import WQLinear_GEMM +from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ +from QEfficient.transformers.quantizers.quantizer_compressed_tensors import FP8DeQuantLinear from QEfficient.transformers.sampler.sampler import sampler_forward from QEfficient.transformers.spd.spd_transform_forward import tlm_forward +from QEfficient.utils.config_utils import ( + resolve_attention_heads, + resolve_hidden_size, + resolve_kv_heads, + set_kv_head_aliases, +) +from QEfficient.utils.constants import ATTENTION_HEAD_CONFIG_KEYS, HIDDEN_SIZE_CONFIG_KEYS, KV_HEAD_CONFIG_KEYS from QEfficient.utils.logging_utils import logger SPD_TARGET = "target" @@ -963,72 +978,405 @@ class RevertPrefillOnlyTransform(ModuleMappingTransform): } -class ReplicateKVHeadTransform: +class ReplicateKVHeadTransform(ModuleMutatorTransform): """ Replicates KV heads in attention modules to match the number of KV heads in the target model. This transform is used when the source model has fewer KV heads than required in target model. """ + _module_mapping = { + QEffCodeGenForCausalLM, + QEffFalconForCausalLM, + QEffGPT2LMHeadModel, + QEffGPTJForCausalLM, + QEffLlamaForCausalLM, + QEffLlama4ForConditionalGeneration, + QEffLlavaForConditionalGeneration, + QEffLlavaNextForConditionalGeneration, + QEffGemmaForCausalLM, + QEffGemma2ForCausalLM, + QEffGemma3ForConditionalGeneration, + QEffGraniteForCausalLM, + QEffGraniteMoeForCausalLM, + QEffMllamaForConditionalGeneration, + QEffMistralForCausalLM, + QEffMistral3ForConditionalGeneration, + QEffMixtralForCausalLM, + QEffMptForCausalLM, + QEffPhiForCausalLM, + QEffPhi3ForCausalLM, + QEffQwen2ForCausalLM, + QEffQwen3ForCausalLM, + QEffQwen_2_5_vl_ForConditionalGeneration, + QEffQwen3MoeForCausalLM, + QEffQwen3VLForConditionalGeneration, + QEffQwen3VLMoeForConditionalGeneration, + QEffStarcoder2ForCausalLM, + QEffGPTBigCodeForCausalLM, + QEffOlmo2ForCausalLM, + } + _module_string_mapping = { + "DeepseekV3ForCausalLM", + "InternVLChatModel", + "MolmoForCausalLM,", + "QEffGemma3DecoderWrapper", + "QEffGemma3EncoderWrapper", + "QEffInternDecoderWrapper", + "QEffInternEncoderWrapper", + "QEffLlama4DecoderWrapper", + "QEffLlama4EncoderWrapper", + "QEFFLlavaDecoderWrapper", + "QEFFLlavaEncoderWrapper", + "QEffLlavaNextDecoderWrapper", + "QEffLlavaNextEncoderWrapper", + "QEFFMistral3DecoderWrapper", + "QEFFMistral3EncoderWrapper", + "QEffMolmoDecoderWrapper", + "QEffMolmoEncoderWrapper", + "QEffQwen_2_5_vl_DecoderWrapper", + "QEffQwen_2_5_vl_EncoderWrapper", + "QEffQwen3VLDecoderWrapper", + "QEffQwen3VLEncoderWrapper", + } + + @classmethod + def _get_attention_module(cls, block: nn.Module) -> nn.Module: + for attr in ("cross_attn", "self_attn", "attention", "attn"): + attn = getattr(block, attr, None) + if attn is not None: + return attn + raise AttributeError(f"No attention module found in block type {block.__class__.__name__}") + + @staticmethod + def _get_projection_layer(attn: nn.Module, names: tuple) -> nn.Module: + for name in names: + layer = getattr(attn, name, None) + if layer is not None: + return layer + raise AttributeError(f"Missing projection layer in {attn.__class__.__name__}; expected one of {names}") + + @staticmethod + def _is_mla_attention(attn: nn.Module) -> bool: + return ( + hasattr(attn, "kv_a_proj_with_mqa") and hasattr(attn, "kv_lora_rank") and hasattr(attn, "qk_rope_head_dim") + ) + + @classmethod + def _is_mla_model(cls, text_model: nn.Module) -> bool: + for block in getattr(text_model, "layers", []): + try: + attn = cls._get_attention_module(block) + except AttributeError: + continue + if cls._is_mla_attention(attn): + return True + return False + def _duplicate_weights_for_linear_layer( - layer: nn.Module, orig_kv_heads: int, repeat: int, dim: int, hidden_size: int + layer: nn.Module, orig_kv_heads: int, repeat: int, head_dim: int, hidden_size: int ): - new_kv_heads = repeat # for mla + new_kv_heads = repeat * orig_kv_heads + if isinstance(layer, WQLinear_GEMM): + # AWQ layout: + # qweight: [in_features, out_features/pack] + # qzeros: [in_features/group_size, out_features/pack] + # scales: [in_features/group_size, out_features] + if layer.qweight.shape[1] % orig_kv_heads != 0: + raise ValueError( + f"Invalid AWQ qweight shape for RepeatKV: qweight.shape={tuple(layer.qweight.shape)}, " + f"orig_kv_heads={orig_kv_heads}" + ) + if layer.qzeros.shape[1] % orig_kv_heads != 0 or layer.scales.shape[1] % orig_kv_heads != 0: + raise ValueError( + f"Invalid AWQ qzeros/scales shape for RepeatKV: qzeros.shape={tuple(layer.qzeros.shape)}, " + f"scales.shape={tuple(layer.scales.shape)}, orig_kv_heads={orig_kv_heads}" + ) + + layer.qweight.data = torch.repeat_interleave( + layer.qweight.data.view(layer.qweight.shape[0], orig_kv_heads, -1), repeat, 1 + ).view(layer.qweight.shape[0], -1) + layer.qzeros.data = torch.repeat_interleave( + layer.qzeros.data.view(layer.qzeros.shape[0], orig_kv_heads, -1), repeat, 1 + ).view(layer.qzeros.shape[0], -1) + layer.scales.data = torch.repeat_interleave( + layer.scales.data.view(layer.scales.shape[0], orig_kv_heads, -1), repeat, 1 + ).view(layer.scales.shape[0], -1) + layer.out_features = layer.out_features * repeat + + elif isinstance(layer, QuantLinearGPTQ): + # GPTQ layout: + # qweight: [in_features/pack, out_features] + # qzeros: [in_features/group_size, out_features/pack] + # scales: [in_features/group_size, out_features] + if layer.qweight.shape[1] % orig_kv_heads != 0: + raise ValueError( + f"Invalid GPTQ qweight shape for RepeatKV: qweight.shape={tuple(layer.qweight.shape)}, " + f"orig_kv_heads={orig_kv_heads}" + ) + if layer.qzeros.shape[1] % orig_kv_heads != 0 or layer.scales.shape[1] % orig_kv_heads != 0: + raise ValueError( + f"Invalid GPTQ qzeros/scales shape for RepeatKV: qzeros.shape={tuple(layer.qzeros.shape)}, " + f"scales.shape={tuple(layer.scales.shape)}, orig_kv_heads={orig_kv_heads}" + ) - layer.weight.data = torch.repeat_interleave( - layer.weight.data.view(orig_kv_heads, dim, hidden_size), repeat, 0 - ).view(new_kv_heads * dim, hidden_size) + layer.qweight.data = torch.repeat_interleave( + layer.qweight.data.view(layer.qweight.shape[0], orig_kv_heads, -1), repeat, 1 + ).view(layer.qweight.shape[0], -1) + layer.qzeros.data = torch.repeat_interleave( + layer.qzeros.data.view(layer.qzeros.shape[0], orig_kv_heads, -1), repeat, 1 + ).view(layer.qzeros.shape[0], -1) + layer.scales.data = torch.repeat_interleave( + layer.scales.data.view(layer.scales.shape[0], orig_kv_heads, -1), repeat, 1 + ).view(layer.scales.shape[0], -1) + layer.out_features = layer.out_features * repeat + + elif isinstance(layer, QuantLinearORT): + # QuantLinearORT stores blockwise packed buffers. Dequantize, replicate per-KV-head, + # then re-pack using existing QuantLinearORT.pack path. + float_weight, zeros_per_group, scales_per_group = dequantize_blockwise_bits( + layer.qweight, + layer.scales, + layer.qzeros, + layer.bits, + layer.group_size, + layer.g_idx, + layer.in_features, + layer.out_features, + ) + # float_weight: [out_features, in_features] + if float_weight.shape[0] % orig_kv_heads != 0: + raise ValueError( + f"Invalid QuantLinearORT weight shape for RepeatKV: " + f"weight.shape={tuple(float_weight.shape)}, orig_kv_heads={orig_kv_heads}" + ) + + duplicated_weight = torch.repeat_interleave( + float_weight.view(orig_kv_heads, -1, float_weight.shape[1]), + repeat, + dim=0, + ).view(new_kv_heads * (float_weight.shape[0] // orig_kv_heads), float_weight.shape[1]) + + duplicated_zeros = torch.repeat_interleave( + zeros_per_group.view(orig_kv_heads, -1, zeros_per_group.shape[1]), + repeat, + dim=0, + ).view(new_kv_heads * (zeros_per_group.shape[0] // orig_kv_heads), zeros_per_group.shape[1]) + duplicated_scales = torch.repeat_interleave( + scales_per_group.view(orig_kv_heads, -1, scales_per_group.shape[1]), + repeat, + dim=0, + ).view(new_kv_heads * (scales_per_group.shape[0] // orig_kv_heads), scales_per_group.shape[1]) + + original_out_features = layer.out_features + layer.out_features = original_out_features * repeat + q_rows = layer.in_features // layer.group_size + layer.qweight = torch.zeros( + (layer.out_features, q_rows, layer.group_size // (8 // layer.bits)), + dtype=layer.qweight.dtype, + device=layer.qweight.device, + ) + layer.qzeros = torch.zeros( + (q_rows + (q_rows & 1)) * (layer.out_features // 8 * layer.bits), + dtype=layer.qzeros.dtype, + device=layer.qzeros.device, + ) + layer.scales = torch.zeros( + (q_rows * layer.out_features), + dtype=layer.scales.dtype, + device=layer.scales.device, + ) + + linear = nn.Linear(layer.in_features, layer.out_features, bias=False, dtype=duplicated_weight.dtype) + linear.weight.data = duplicated_weight.to(linear.weight.dtype) + layer.pack( + linear, + duplicated_scales.contiguous().to(layer.scales.dtype), + duplicated_zeros.contiguous().to(torch.int32), + layer.g_idx, + ) + elif isinstance(layer, FP8DeQuantLinear): + layer.weight.data = torch.repeat_interleave( + layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0 + ).view(new_kv_heads * head_dim, hidden_size) + layer.weight_scale.data = torch.repeat_interleave( + layer.weight_scale.data.view(orig_kv_heads, head_dim), repeat, 0 + ).view(new_kv_heads * head_dim, -1) + + else: + layer.weight.data = torch.repeat_interleave( + layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0 + ).view(new_kv_heads * head_dim, hidden_size) if layer.bias is not None: - layer.bias.data = torch.repeat_interleave(layer.bias.data.view(orig_kv_heads, dim), repeat, 0).view( - new_kv_heads * dim + layer.bias.data = torch.repeat_interleave(layer.bias.data.view(orig_kv_heads, head_dim), repeat, 0).view( + new_kv_heads * head_dim ) + def _is_valid_text_model(candidate: nn.Module) -> bool: + """ + Validate whether a candidate object looks like a text stack suitable for KV replication. + """ + if candidate is None: + return False + cfg = getattr(candidate, "config", None) + layers = getattr(candidate, "layers", None) + attn_heads = resolve_attention_heads(cfg) if cfg is not None else None + kv_heads = resolve_kv_heads(cfg) if cfg is not None else None + hidden_size = resolve_hidden_size(cfg) if cfg is not None else None + return ( + cfg is not None + and layers is not None + and attn_heads is not None + and kv_heads is not None + and hidden_size is not None + ) + def _get_text_model(model): """ Determine and return the appropriate text_model from a given model object. + + Some VLM wrappers expose multiple nested text attributes (e.g. `language_model`, + `language_model.model`, `model.language_model`). We pick the first valid module + that has both `config` and `layers` required for KV head replication. """ - # Check for VLMs - if hasattr(model, "language_model"): - if hasattr(model.language_model, "model"): - return model.language_model.model - else: - return model.language_model - # Check for CausalLMs - if hasattr(model, "model"): - return model.model + candidate_paths = ( + ("language_model",), + ("language_model", "model"), + ("model", "language_model"), + ("model", "language_model", "model"), + ("model",), + ("model", "model"), + ("transformer",), + ("transformer", "model"), + ("llm",), + ("llm", "model"), + ("backbone",), + ) - raise AttributeError("No suitable text model found in the provided model.") + for path in candidate_paths: + candidate = model + valid_path = True + for attr in path: + if not hasattr(candidate, attr): + valid_path = False + break + candidate = getattr(candidate, attr) + if valid_path and ReplicateKVHeadTransform._is_valid_text_model(candidate): + return candidate + + raise AttributeError( + f"No suitable text model found in the provided model ({model.__class__.__name__}). " + "Expected a module with `layers` and text `config` attributes." + ) + + def _get_replication_root(model: nn.Module) -> nn.Module: + """ + Return a shared root module for wrapper and non-wrapper models so KV replication + can be applied once across encoder/decoder components of the same model. + """ + candidate = getattr(model, "model", None) + return candidate if isinstance(candidate, nn.Module) else model @classmethod - def apply(cls, model: nn.Module, num_kv_heads_repeat: int = 1) -> nn.Module: + def mutate(cls, original_module: nn.Module, parent_module: nn.Module, n_repeat: int) -> nn.Module: """ - Replicates KV heads in attention modules based on provided multiplier. + Mutates the matched top-level model module in-place by replicating its KV heads. Args: - model: The model to apply the transform to. - num_kv_heads_repeat: The number of times to repeat the KV heads. - """ - transformed = False - if num_kv_heads_repeat is not None and num_kv_heads_repeat > 1: - text_model = cls._get_text_model(model) + original_module: The matched top-level model module to mutate. + parent_module: The parent module (unused, present for interface compatibility). + n_repeat: The number of times to repeat the KV heads. - orig_kv_heads = 1 # for mla #text_model.config.num_key_value_heads - new_kv_heads = num_kv_heads_repeat * orig_kv_heads - text_model.config.orig_kv_heads = orig_kv_heads - text_model.config.num_key_value_heads = new_kv_heads + Returns: + The mutated module (same object, modified in-place). + """ + # breakpoint() + replication_root = cls._get_replication_root(original_module) + if getattr(replication_root, "_qeff_kv_replication_applied", False): + logger.warning("KV head replication already applied for this model instance; skipping.") + return original_module + + text_model = cls._get_text_model(original_module) + cfg = text_model.config + if cls._is_mla_model(text_model): + logger.warning("Skipping RepeatKVTransform: MLA models don't apply replicate KV changes.") + return original_module + + orig_kv_heads = resolve_kv_heads(cfg) + num_attention_heads = resolve_attention_heads(cfg) + hidden_size = resolve_hidden_size(cfg) + + if orig_kv_heads is None or num_attention_heads is None or hidden_size is None: + raise ValueError( + "Unable to resolve attention/KV heads or hidden size from config for RepeatKV transform. " + f"Supported attention keys={ATTENTION_HEAD_CONFIG_KEYS}, kv keys={KV_HEAD_CONFIG_KEYS}, " + f"hidden size keys={HIDDEN_SIZE_CONFIG_KEYS}." + ) + if orig_kv_heads < 1 or num_attention_heads < 1: + raise ValueError( + f"Invalid head values for RepeatKV transform: " + f"num_attention_heads={num_attention_heads}, num_key_value_heads={orig_kv_heads}" + ) + new_kv_heads = n_repeat * orig_kv_heads + if new_kv_heads > num_attention_heads or (num_attention_heads % new_kv_heads) != 0: + raise ValueError( + f"Invalid RepeatKV configuration: num_attention_heads={num_attention_heads}, " + f"orig_kv_heads={orig_kv_heads}, num_replicate_kv_heads={n_repeat}, new_kv_heads={new_kv_heads}. " + "Expected new_kv_heads <= num_attention_heads and divisibility." + ) - hidden_size = text_model.config.hidden_size + cfg.orig_kv_heads = orig_kv_heads + set_kv_head_aliases(cfg, new_kv_heads) - logger.warning(f"Original KV heads: {orig_kv_heads}") - logger.warning(f"Modified KV heads: {new_kv_heads}") - transformed = True - for block in text_model.layers: - attn = getattr(block, "cross_attn", getattr(block, "self_attn", None)) + logger.warning(f"Original KV heads: {orig_kv_heads}") + logger.warning(f"Modified KV heads: {new_kv_heads}") + for block in text_model.layers: + attn = cls._get_attention_module(block) + if hasattr(attn, "num_key_value_heads"): attn.num_key_value_heads = new_kv_heads - head_dim = attn.kv_lora_rank + attn.qk_rope_head_dim + if hasattr(attn, "n_kv_heads"): + attn.n_kv_heads = new_kv_heads + + n_kv_groups = num_attention_heads // new_kv_heads + if hasattr(attn, "num_key_value_groups"): + attn.num_key_value_groups = n_kv_groups + if hasattr(attn, "n_kv_groups"): + attn.n_kv_groups = n_kv_groups + head_dim = getattr(attn, "head_dim", hidden_size // num_attention_heads) + k_proj = cls._get_projection_layer(attn, ("k_proj", "key_proj")) + v_proj = cls._get_projection_layer(attn, ("v_proj", "value_proj")) + cls._duplicate_weights_for_linear_layer(k_proj, orig_kv_heads, n_repeat, head_dim, hidden_size) + cls._duplicate_weights_for_linear_layer(v_proj, orig_kv_heads, n_repeat, head_dim, hidden_size) + + setattr(replication_root, "_qeff_kv_replication_applied", True) + return original_module - cls._duplicate_weights_for_linear_layer( - attn.kv_a_proj_with_mqa, orig_kv_heads, num_kv_heads_repeat, head_dim, hidden_size + @classmethod + def apply(cls, model: nn.Module, num_replicate_kv_heads: Optional[int] = None, **kwargs) -> Tuple[nn.Module, bool]: + """ + Replicates KV heads in attention modules based on provided multiplier. + + Args: + model: The model to apply the transform to. + kwargs: Additional arguments for the transformation. Includes: + - num_replicate_kv_heads: The number of times to repeat the KV heads. + """ + if num_replicate_kv_heads is None: + n_repeat = kwargs.pop("num_replicate_kv_heads", 1) + else: + kwargs.pop("num_replicate_kv_heads", None) + n_repeat = num_replicate_kv_heads + transformed = False + if n_repeat is not None and n_repeat > 1: + if (model.__class__ in cls._module_mapping) or (model.__class__.__name__ in cls._module_string_mapping): + transform_root = cls._get_replication_root(model) + was_applied = getattr(transform_root, "_qeff_kv_replication_applied", False) + cls.mutate(model, None, n_repeat) + is_applied = getattr(transform_root, "_qeff_kv_replication_applied", False) + transformed = (not was_applied) and is_applied + else: + raise NotImplementedError( + f"Model class {model.__class__.__name__} is not supported for KV head replication." ) return model, transformed diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 357c4af16e..f970ba54b6 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -746,6 +746,7 @@ def __init__(self, model): super().__init__() self.model = model.model self.model.vision_model = self.model.visual + self.config = self.model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ diff --git a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 0f6ab210de..45a8a8fa5a 100644 --- a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -652,6 +652,7 @@ def __init__(self, model): super().__init__() self.model = model.model self.model.vision_model = self.model.visual + self.config = self.model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index db30350f8b..17ff828b42 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -745,6 +745,7 @@ def __init__(self, model): super().__init__() self.model = model.model self.model.vision_model = self.model.visual + self.config = self.model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ diff --git a/QEfficient/utils/config_utils.py b/QEfficient/utils/config_utils.py new file mode 100644 index 0000000000..4b28d54880 --- /dev/null +++ b/QEfficient/utils/config_utils.py @@ -0,0 +1,68 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from typing import Iterable, Optional + +from QEfficient.utils.constants import ATTENTION_HEAD_CONFIG_KEYS, HIDDEN_SIZE_CONFIG_KEYS, KV_HEAD_CONFIG_KEYS + + +def get_first_config_value(config, names: Iterable[str], default=None, cast_int: bool = False): + for name in names: + value = getattr(config, name, None) + if value is not None: + return int(value) if cast_int else value + return default + + +def resolve_attention_heads(config) -> Optional[int]: + return get_first_config_value(config, ATTENTION_HEAD_CONFIG_KEYS, cast_int=True) + + +def resolve_kv_heads(config) -> Optional[int]: + value = get_first_config_value(config, KV_HEAD_CONFIG_KEYS, cast_int=True) + if value is None: + value = resolve_attention_heads(config) + return value + + +def resolve_hidden_size(config) -> Optional[int]: + return get_first_config_value(config, HIDDEN_SIZE_CONFIG_KEYS, cast_int=True) + + +def set_kv_head_aliases(config, value: int): + setattr(config, "num_key_value_heads", value) + for key in KV_HEAD_CONFIG_KEYS: + if hasattr(config, key): + setattr(config, key, value) + + +def calculate_num_replicate_kv_heads(num_devices: int, text_model_config) -> int: + """ + Choose a KV-repeat value from model config and device count. + + Primary criteria: + 1. num_kv_heads * repeat is divisible by num_devices + 2. num_attention_heads is divisible by (num_kv_heads * repeat) + + Fallback: + repeat = num_attention_heads / num_kv_heads (integer-truncated if needed). + """ + num_attention_heads = resolve_attention_heads(text_model_config) + num_kv_heads = resolve_kv_heads(text_model_config) + + if num_attention_heads is None or num_kv_heads is None or num_attention_heads < 1 or num_kv_heads < 1: + return 1 + + num_devices = max(1, int(num_devices)) + max_repeat = max(1, int(num_attention_heads / num_kv_heads)) + + for repeat in range(max_repeat, 0, -1): + repeated_kv_heads = num_kv_heads * repeat + if (repeated_kv_heads % num_devices == 0) and (num_attention_heads % repeated_kv_heads == 0): + return repeat + + return 1 diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 3a03f6b1c9..0cc1c4ff61 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -140,6 +140,11 @@ def get_default_aic_hw_version() -> str: DEFAULT_AIC_HW_VERSION = get_default_aic_hw_version() ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL = 100 +# Generic config key aliases used across model families. +ATTENTION_HEAD_CONFIG_KEYS = ("num_attention_heads", "n_head", "n_heads", "num_heads") +KV_HEAD_CONFIG_KEYS = ("num_key_value_heads", "n_kv_heads", "num_kv_heads", "effective_n_kv_heads") +HIDDEN_SIZE_CONFIG_KEYS = ("hidden_size", "n_embd", "d_model") + # InternVL constants # Fixing the feature size with reference to OpenGVLab/InternVL2_5-1B, OpenGVLab/InternVL2_5-38B and OpenGVLab/InternVL2_5-78B INTERN_FEATURE_SIZE = 256 diff --git a/QEfficient/utils/test_utils.py b/QEfficient/utils/test_utils.py index 131ff59e26..34bc474e5b 100644 --- a/QEfficient/utils/test_utils.py +++ b/QEfficient/utils/test_utils.py @@ -289,6 +289,14 @@ def load_qeff_model_with_sampler( return qeff_model +def get_text_config(config): + if hasattr(config, "text_config"): + return config.text_config + elif hasattr(config, "llm_config"): + return config.llm_config + return config + + # Processor class for InternVL models class InternProcessor: """ @@ -492,6 +500,31 @@ class ModelConfig: "Qwen/Qwen3.6-35B-A3B", } + REPEAT_KV_TEST_MODELS = { + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "ibm-granite/granite-3.1-1b-a400m-base", + "Qwen/Qwen2-0.5B", + "bigcode/starcoder2-3b", + # "mistralai/Mixtral-8x7B-Instruct-v0.1", + "meta-llama/Llama-3.2-1B", + # "unsloth/gemma-2b", + # "unsloth/gemma-2-2b", + "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", + "TheBloke/Llama-2-7B-GPTQ", + "neuralmagic/Llama-3.2-3B-Instruct-FP8", + "ibm-granite/granite-3.1-2b-instruct", + "llava-hf/llava-1.5-7b-hf", + "meta-llama/Llama-4-Scout-17B-16E-Instruct", + # "google/gemma-3-4b-it", + "allenai/Molmo-7B-D-0924", + "mistralai/Mistral-Small-3.1-24B-Instruct-2503", + "Qwen/Qwen2.5-VL-3B-Instruct", + "Qwen/Qwen3-VL-2B-Instruct", + "Qwen/Qwen3-VL-30B-A3B-Instruct", + "allenai/Molmo-7B-D-0924", + "OpenGVLab/InternVL2_5-1B", + } + EXTERNAL_MODELS = { "hpcai-tech/grok-1": { "pytorch_hf_tokens_custom_case": [ diff --git a/examples/kimi_k2/README.md b/examples/kimi_k2/README.md index 230127ebbe..4fae4a8cfb 100644 --- a/examples/kimi_k2/README.md +++ b/examples/kimi_k2/README.md @@ -20,9 +20,9 @@ mla_absorption has 3 keys: # Blocking We have also implemented KV head replication, HEAD Blocking and KV Blocking which can be enable like this : - For No Blocking : qaic_config = {"mla_absorption" : mla_absorption} -- For No blocking with kv head replication : qaic_config = {"mla_absorption" : mla_absorption, "num_kv_heads_repeat": TS} +- For No blocking with kv head replication : qaic_config = {"mla_absorption" : mla_absorption, "num_replicate_kv_heads": TS} - For KV blocking : qaic_config = {"mla_absorption" : mla_absorption, "enable_blocking": True, "blocking_mode": "kv"} # for KV blocking -- For Head Blocking : qaic_config = {"mla_absorption" : mla_absorption, "enable_blocking": True, "blocking_mode": "h", "num_kv_heads_repeat": TS} for h blocking, it internally sets head_block_size equal to num_devices/num_kv_heads_repeat +- For Head Blocking : qaic_config = {"mla_absorption" : mla_absorption, "enable_blocking": True, "blocking_mode": "h", "num_replicate_kv_heads": TS} for h blocking, it internally sets head_block_size equal to num_devices/num_replicate_kv_heads - Currently Decode-Only model is giving best perf with Head Blocking and compressed cache. - Contnuous batching is not enabled yet. \ No newline at end of file diff --git a/examples/kimi_k2/export_kimik2.py b/examples/kimi_k2/export_kimik2.py index 1e70352165..ba6b26c064 100644 --- a/examples/kimi_k2/export_kimik2.py +++ b/examples/kimi_k2/export_kimik2.py @@ -18,16 +18,16 @@ # qaic_config = None # Full PKV Cache # qaic_config = {"enable_blocking": True, "blocking_mode": "h"} # Full PKV Cache with Head Blocking # qaic_config = {"mla_absorption": mla_absorption} # for No Blocking -# qaic_config = {"mla_absorption": mla_absorption, "num_kv_heads_repeat": TS} # No blocking with kv head replication +# qaic_config = {"mla_absorption": mla_absorption, "num_replicate_kv_heads": TS} # No blocking with kv head replication # qaic_config = {"mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "kv"} # for KV blocking -# qaic_config = {"mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "kv", "num_kv_heads_repeat":TS} # for KV blocking with kv head replication +# qaic_config = {"mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "kv", "num_replicate_kv_heads":TS} # for KV blocking with kv head replication qaic_config = { "mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "h", - "num_kv_heads_repeat": TS, + "num_replicate_kv_heads": TS, } -# for h blocking, it internally sets head_block_size equal to num_devices/num_kv_heads_repeat +# for h blocking, it internally sets head_block_size equal to num_devices/num_replicate_kv_heads model_name = "moonshotai/Kimi-K2-Thinking" model = AutoModelForCausalLM.from_pretrained( diff --git a/examples/text_generation/run_kimik2.py b/examples/text_generation/run_kimik2.py index 81767308ad..e85c572420 100644 --- a/examples/text_generation/run_kimik2.py +++ b/examples/text_generation/run_kimik2.py @@ -19,16 +19,16 @@ # qaic_config = None # Full PKV Cache # qaic_config = {"enable_blocking": True, "blocking_mode": "h"} # Full PKV Cache with Head Blocking # qaic_config = {"mla_absorption": mla_absorption} # for No Blocking -# qaic_config = {"mla_absorption": mla_absorption, "num_kv_heads_repeat": TS} # No blocking with kv head replication +# qaic_config = {"mla_absorption": mla_absorption, "num_replicate_kv_heads": TS} # No blocking with kv head replication # qaic_config = {"mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "kv"} # for KV blocking -# qaic_config = {"mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "kv", "num_kv_heads_repeat":TS} # for KV blocking with kv head replication +# qaic_config = {"mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "kv", "num_replicate_kv_heads":TS} # for KV blocking with kv head replication qaic_config = { "mla_absorption": mla_absorption, "enable_blocking": True, "blocking_mode": "h", - "num_kv_heads_repeat": TS, + "num_replicate_kv_heads": TS, } -# for h blocking, it internally sets head_block_size equal to num_devices/num_kv_heads_repeat +# for h blocking, it internally sets head_block_size equal to num_devices/num_replicate_kv_heads model_name = "moonshotai/Kimi-K2-Thinking" model = AutoModelForCausalLM.from_pretrained( diff --git a/tests/configs/causal_model_configs.json b/tests/configs/causal_model_configs.json index 93f4e7ae2f..2c092ed9ee 100644 --- a/tests/configs/causal_model_configs.json +++ b/tests/configs/causal_model_configs.json @@ -325,6 +325,19 @@ "num_key_value_heads": 1 } }, + { + "model_name": "hpcai-tech/grok-1", + "model_type": null, + "additional_params": { + "max_position_embeddings": 128, + "num_hidden_layers": 1, + "num_attention_heads": 2, + "hidden_size": 64, + "intermediate_size": 256, + "vocab_size": 131072, + "num_key_value_heads": 1 + } + }, { "model_name": "Snowflake/Llama-3.1-SwiftKV-8B-Instruct", "model_type": null, @@ -720,4 +733,4 @@ } } ] -} +} \ No newline at end of file diff --git a/tests/transformers/models/causal_lm_models/check_causal_models.py b/tests/transformers/models/causal_lm_models/check_causal_models.py index f878acbe73..78ff74cbfd 100644 --- a/tests/transformers/models/causal_lm_models/check_causal_models.py +++ b/tests/transformers/models/causal_lm_models/check_causal_models.py @@ -16,7 +16,8 @@ from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers from QEfficient.utils._utils import load_hf_tokenizer -from QEfficient.utils.constants import Constants +from QEfficient.utils.config_utils import get_first_config_value +from QEfficient.utils.constants import ATTENTION_HEAD_CONFIG_KEYS, KV_HEAD_CONFIG_KEYS, Constants from QEfficient.utils.run_utils import ApiRunner from QEfficient.utils.test_utils import ModelConfig, load_hf_causal_lm_model @@ -39,6 +40,52 @@ def get_custom_n_layers(model_name): return 1 +def check_kv_repeat_causal_lm_pytorch_vs_ai100( + model_name: str, + manual_cleanup: callable, + prompt_len: int = Constants.PROMPT_LEN, + ctx_len: int = Constants.CTX_LEN, + n_layer: int = -1, + config: Optional[AutoConfig] = None, +): + """ + Validate causal LM flow with repeated KV heads configuration. + """ + if config is None: + model_config = AutoConfig.from_pretrained( + model_name, + trust_remote_code=model_name in ModelConfig.EXTERNAL_MODELS, + ) + else: + model_config = config + + num_attention_heads = get_first_config_value(model_config, ATTENTION_HEAD_CONFIG_KEYS, default=1, cast_int=True) + num_key_value_heads = get_first_config_value(model_config, KV_HEAD_CONFIG_KEYS, default=None, cast_int=True) + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + if num_attention_heads < 1 or num_key_value_heads < 1: + raise ValueError( + f"Invalid heads in config for RepeatKV: " + f"num_attention_heads={num_attention_heads}, num_key_value_heads={num_key_value_heads}" + ) + if num_attention_heads % num_key_value_heads != 0: + raise ValueError( + f"Invalid heads in config for RepeatKV: num_attention_heads ({num_attention_heads}) " + f"is not divisible by num_key_value_heads ({num_key_value_heads})." + ) + num_replicate_kv_heads = num_attention_heads // num_key_value_heads + + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, + manual_cleanup=manual_cleanup, + prompt_len=prompt_len, + ctx_len=ctx_len, + n_layer=n_layer, + config=config, + qaic_config={"num_replicate_kv_heads": num_replicate_kv_heads}, + ) + + def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( model_name: str, manual_cleanup: callable, @@ -71,15 +118,6 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( pytorch_kv_tokens = None ort_tokens = None - api_runner = ApiRunner( - batch_size, - tokenizer, - config, - prompts, - Constants.PROMPT_LEN, - Constants.CTX_LEN, - full_batch_size if continuous_batching else None, - ) qeff_model = QEFFAutoModelForCausalLM( copy.deepcopy(model_hf), is_tlm=is_tlm, @@ -94,6 +132,15 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( num_devices=num_devices, qaic_config=qaic_config, ) + api_runner = ApiRunner( + batch_size, + tokenizer, + qeff_model.config, + prompts, + Constants.PROMPT_LEN, + Constants.CTX_LEN, + full_batch_size if continuous_batching else None, + ) if continuous_batching is False: pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) diff --git a/tests/transformers/models/causal_lm_models/test_causal_lm_models.py b/tests/transformers/models/causal_lm_models/test_causal_lm_models.py index 5011a670a6..1b0b07be6b 100644 --- a/tests/transformers/models/causal_lm_models/test_causal_lm_models.py +++ b/tests/transformers/models/causal_lm_models/test_causal_lm_models.py @@ -17,6 +17,7 @@ from .check_causal_models import ( check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100, + check_kv_repeat_causal_lm_pytorch_vs_ai100, get_custom_n_layers, ) @@ -73,6 +74,32 @@ def test_dummy_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanu check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, config=hf_config, manual_cleanup=manual_cleanup) +@pytest.mark.dummy_layers +@pytest.mark.on_qaic +@pytest.mark.llm_model +@pytest.mark.parametrize("model_name", test_models_causal) +def test_check_kv_repeat_custom_causal_lm_pytorch_vs_ai100(model_name, manual_cleanup): + """ + Test function to validate the PyTorch model and the Cloud AI 100 model with repeating original KV heads. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + custom_config = model_config_dict[model_name] + hf_config = AutoConfig.from_pretrained( + model_name, + trust_remote_code=model_name in ModelConfig.EXTERNAL_MODELS, + **custom_config.get("additional_params", {}), + ) + if model_name in ModelConfig.REPEAT_KV_TEST_MODELS: + if model_name in ModelConfig.QUANTIZED_MODELS: + n_layer = get_custom_n_layers(model_name) + check_kv_repeat_causal_lm_pytorch_vs_ai100(model_name, manual_cleanup=manual_cleanup, n_layer=n_layer) + else: + check_kv_repeat_causal_lm_pytorch_vs_ai100(model_name, manual_cleanup=manual_cleanup, config=hf_config) + else: + pytest.skip(f"Skipping {model_name} as it is not in REPEAT_KV_TEST_MODELS") + + @pytest.mark.full_layers @pytest.mark.on_qaic @pytest.mark.llm_model diff --git a/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py b/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py index 9b9e662e52..df9c3b9e8d 100644 --- a/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py +++ b/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py @@ -30,6 +30,7 @@ from QEfficient.utils.test_utils import ( InternProcessor, ModelConfig, + get_text_config, load_vlm_model, load_vlm_model_from_config, set_num_layers_vlm, @@ -56,6 +57,9 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( enable_qnn: Optional[bool] = False, qnn_config: Optional[str] = None, config: Optional[AutoConfig] = None, + qaic_config: Optional[dict] = None, + num_replicate_kv_heads: Optional[int] = 1, + test_kv_replicate: Optional[bool] = None, torch_dtype: Optional[torch.dtype] = torch.float32, compare_results: Optional[bool] = False, ): @@ -70,11 +74,17 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( pytorch_kv_tokens = None ort_tokens = None n_layer = num_hidden_layers + qaic_config = copy.deepcopy(qaic_config) if qaic_config is not None else None if config is None: config = AutoConfig.from_pretrained( model_name, trust_remote_code=True, padding=model_name not in ModelConfig.MOLMO_MODELS ) config = set_num_layers_vlm(config, n_layer=n_layer) + if test_kv_replicate: + text_config = get_text_config(config) + num_replicate_kv_heads = text_config.num_attention_heads // text_config.num_key_value_heads + qaic_config = qaic_config or {} + qaic_config["num_replicate_kv_heads"] = num_replicate_kv_heads if hasattr(config, "model_type") and config.model_type in ["gemma3"]: config.text_config._sliding_window_pattern = 2 config.text_config.layer_types = ["sliding_attention", "full_attention"] @@ -92,7 +102,9 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( model_name, kv_offload=kv_offload, config=config, + qaic_config=qaic_config, torch_dtype=torch_dtype, + num_replicate_kv_heads=num_replicate_kv_heads, ) else: model_hf = load_vlm_model(config) @@ -100,15 +112,24 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( model_name, kv_offload=kv_offload, config=config, + qaic_config=qaic_config, torch_dtype=torch_dtype, + num_replicate_kv_heads=num_replicate_kv_heads, ) else: + if test_kv_replicate: + text_config = get_text_config(config) + num_replicate_kv_heads = text_config.num_attention_heads // text_config.num_key_value_heads + qaic_config = qaic_config or {} + qaic_config["num_replicate_kv_heads"] = num_replicate_kv_heads model_hf = load_vlm_model_from_config(config) qeff_model = QEFFAutoModelForImageTextToText( copy.deepcopy(model_hf), kv_offload=kv_offload, config=model_hf.config, + qaic_config=qaic_config, torch_dtype=torch_dtype, + num_replicate_kv_heads=num_replicate_kv_heads, ) compile_kwargs = { "num_devices": num_devices, @@ -117,6 +138,7 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( "mxfp6": False, "enable_qnn": enable_qnn, "qnn_config": qnn_config, + "qaic_config": qaic_config, } if model_name in ModelConfig.INTERNVL_MODELS: @@ -239,7 +261,7 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( # "Tokens don't match for pytorch HF output and pytorch KV output" # ) - _ = qeff_model.export() + # _ = qeff_model.export() # ort_tokens = api_runner.run_vlm_kv_model_on_ort(onnx_model_path) # assert (pytorch_hf_tokens == ort_tokens).all(), "Tokens don't match for pytorch HF output and ORT output" @@ -337,6 +359,57 @@ def test_dummy_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(model_name, kv_o ) +@pytest.mark.on_qaic +@pytest.mark.multimodal +@pytest.mark.dummy_layers +@pytest.mark.parametrize("model_name", test_mm_models) +@pytest.mark.parametrize("kv_offload", [True, False]) +def test_custom_replicate_kv_pytorch_vs_ai100( + model_name, + kv_offload, + manual_cleanup, +): + """ + Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + torch.manual_seed(42) + if model_name in ModelConfig.SKIPPED_MODELS: + pytest.skip("Test skipped for this model due to some issues.") + if model_name in ModelConfig.DUAL_QPC_MODELS and not kv_offload: + pytest.skip("These models require kv_offload=True for testing.") + + if model_name in ModelConfig.REPEAT_KV_TEST_MODELS: + hf_config = None + if model_name in ModelConfig.STANDARD_VLM_MODELS: + model_type = model_config_dict[model_name].get("model_type") + custom_config = model_config_dict[model_name].get("additional_params", {}) + hf_config = AutoConfig.for_model(model_type, trust_remote_code=True, **custom_config) + hf_config.name_or_path = model_name + + if hf_config is not None: + check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, + kv_offload=kv_offload, + config=hf_config, + qaic_config={}, + test_kv_replicate=True, + manual_cleanup=manual_cleanup, + ) + else: + check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, + num_hidden_layers=model_config_dict[model_name]["num_layers"], + kv_offload=kv_offload, + qaic_config={}, + test_kv_replicate=True, + manual_cleanup=manual_cleanup, + ) + else: + pytest.skip(f"Skipping replicate KV test for {model_name} as it's not in REPEAT_KV_TEST_MODELS") + + ################################ QNN Tests ################################