-
Notifications
You must be signed in to change notification settings - Fork 88
Created ReplicateKVHeadTransform to integrate KV-heads replication module within Qefficient library. #625
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Created ReplicateKVHeadTransform to integrate KV-heads replication module within Qefficient library. #625
Changes from all commits
f64970c
df4ffad
bd370b8
22dc3e4
08032e1
fb21640
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -50,6 +50,7 @@ | |
| PoolingTransform, | ||
| PrefillOnlyChunkedTransform, | ||
| PrefillOnlyTransform, | ||
| ReplicateKVHeadTransform, | ||
| RevertPrefillKeepAttentionTransform, | ||
| RevertPrefillOnlyTransform, | ||
| SamplerTransform, | ||
|
|
@@ -887,6 +888,11 @@ def __init__( | |
|
|
||
| self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) | ||
| self.lang_model = QEffCausalLMForTextImageToTextModel(model, qaic_config=qaic_config, **kwargs) | ||
| self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs) | ||
| # Since both modules use the entire config for hash creation, we're updating the params for consistency. | ||
| if replicate_kv_transformed: | ||
| self.lang_model.hash_params["config"] = model.config.to_diff_dict() | ||
| self.vision_model.hash_params["config"] = model.config.to_diff_dict() | ||
| self.continuous_batching = continuous_batching | ||
| self.ccl_enabled = False | ||
| if qaic_config: | ||
|
|
@@ -1623,6 +1629,9 @@ def __init__( | |
| self.model.config.text_config.use_cache = True | ||
| else: | ||
| self.model.config.use_cache = True | ||
| self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs) | ||
| if replicate_kv_transformed: | ||
| self.hash_params["config"] = model.config.to_diff_dict() | ||
| self.hash_params["qeff_auto_class"] = self.__class__.__name__ | ||
| self.ccl_enabled = False | ||
| if qaic_config: | ||
|
|
@@ -2257,7 +2266,9 @@ def from_pretrained( | |
| logger.warning("Updating low_cpu_mem_usage=False") | ||
|
|
||
| kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) | ||
| num_kv_heads_repeat = kwargs.pop("num_kv_heads_repeat", None) | ||
| model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) | ||
| kwargs.update({"num_kv_heads_repeat": num_kv_heads_repeat}) | ||
| return cls( | ||
| model, | ||
| kv_offload=kv_offload, | ||
|
|
@@ -2384,6 +2395,9 @@ def __init__( | |
|
|
||
| setattr(model.config, "max_seq_len_cached", max_seq_len_cached) | ||
| super().__init__(model, qaic_config=qaic_config, **kwargs) | ||
| self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs) | ||
| if replicate_kv_transformed: | ||
| self.hash_params["config"] = model.config.to_diff_dict() | ||
|
Comment on lines
+2398
to
+2400
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. better add it to |
||
| self.num_layers = model.config.num_hidden_layers | ||
| self.continuous_batching = continuous_batching | ||
| self.model.qaic_config = qaic_config | ||
|
|
@@ -2481,7 +2495,10 @@ def from_pretrained( | |
| kv_offload = kwargs.pop("kv_offload", None) | ||
|
|
||
| kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) | ||
| # InternVL causes an error if we pass the num_kv_heads_repeat parameter | ||
| num_kv_heads_repeat = kwargs.pop("num_kv_heads_repeat", 1) | ||
|
Comment on lines
+2498
to
+2499
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not needed |
||
| model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) | ||
| kwargs.update({"num_kv_heads_repeat": num_kv_heads_repeat}) | ||
| if qaic_config is not None: | ||
| qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,7 @@ | |
| from types import MethodType | ||
| from typing import Callable, Optional, Tuple, Union | ||
|
|
||
| import torch | ||
| from torch import nn | ||
| from transformers.models.codegen.modeling_codegen import ( | ||
| CodeGenAttention, | ||
|
|
@@ -446,8 +447,12 @@ | |
| QEffWhisperPositionalEmbedding, | ||
| ) | ||
| from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry | ||
| from QEfficient.transformers.quantizers.awq import WQLinear_GEMM | ||
| from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ | ||
| from QEfficient.transformers.quantizers.quantizer_compressed_tensors import FP8DeQuantLinear | ||
| from QEfficient.transformers.sampler.sampler import sampler_forward | ||
| from QEfficient.transformers.spd.spd_transform_forward import tlm_forward | ||
| from QEfficient.utils.logging_utils import logger | ||
|
|
||
| SPD_TARGET = "target" | ||
|
|
||
|
|
@@ -686,6 +691,164 @@ class RevertPrefillOnlyTransform(ModuleMappingTransform): | |
| } | ||
|
|
||
|
|
||
| class ReplicateKVHeadTransform: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Make this inherit |
||
| """ | ||
| Replicates KV heads in attention modules to match the number of KV heads in the target model. | ||
| This transform is used when the source model has fewer KV heads than required in target model. | ||
| """ | ||
|
|
||
| _module_mapping = { | ||
| QEffCodeGenForCausalLM, | ||
| QEffFalconForCausalLM, | ||
| QEffGPT2LMHeadModel, | ||
| QEffGPTJForCausalLM, | ||
| QEffLlamaForCausalLM, | ||
| QEffLlama4ForConditionalGeneration, | ||
| QEffLlavaForConditionalGeneration, | ||
| QEffLlavaNextForConditionalGeneration, | ||
| QEffGemmaForCausalLM, | ||
| QEffGemma2ForCausalLM, | ||
| QEffGemma3ForConditionalGeneration, | ||
| QEffGraniteForCausalLM, | ||
| QEffGraniteMoeForCausalLM, | ||
| QEffMllamaForConditionalGeneration, | ||
| QEffMistralForCausalLM, | ||
| QEffMistral3ForConditionalGeneration, | ||
| QEffMptForCausalLM, | ||
| QEffPhiForCausalLM, | ||
| QEffPhi3ForCausalLM, | ||
| QEffQwen2ForCausalLM, | ||
| QEffQwen3ForCausalLM, | ||
| QEffQwen_2_5_vl_ForConditionalGeneration, | ||
| QEffQwen3MoeForCausalLM, | ||
| QEffStarcoder2ForCausalLM, | ||
| QEffGPTBigCodeForCausalLM, | ||
| QEffOlmo2ForCausalLM, | ||
| } | ||
| _module_string_mapping = { | ||
| "InternVLChatModel", | ||
| } | ||
|
|
||
| def _duplicate_weights_for_linear_layer( | ||
| layer: nn.Module, orig_kv_heads: int, repeat: int, head_dim: int, hidden_size: int | ||
| ): | ||
| new_kv_heads = repeat * orig_kv_heads | ||
| if isinstance(layer, (WQLinear_GEMM, QuantLinearGPTQ)): | ||
| if head_dim % 8 != 0: | ||
| raise ValueError( | ||
| f"the value head_dim={head_dim} is not divisible by 8 which is \ | ||
| according to the assumption that model is 4-bit quantized." | ||
| ) | ||
| if hidden_size % layer.group_size != 0: | ||
| raise ValueError( | ||
| f"The value of hidden_size={hidden_size} is not divisible by \ | ||
| K_proj.group_size={layer.group_size}" | ||
| ) | ||
|
|
||
| # Duplication of quantized weights | ||
| layer.qweight.data = torch.repeat_interleave( | ||
| layer.qweight.data.view(hidden_size, orig_kv_heads, head_dim // 8), repeat, 1 | ||
| ).view(hidden_size, (new_kv_heads * head_dim) // 8) | ||
| # Duplication of quantized zero points | ||
| layer.qzeros.data = torch.repeat_interleave( | ||
| layer.qzeros.data.view(hidden_size // layer.group_size, orig_kv_heads, head_dim // 8), | ||
| repeat, | ||
| 1, | ||
| ).view(hidden_size // layer.group_size, (new_kv_heads * head_dim) // 8) | ||
| # Duplication of quantization scales | ||
| layer.scales.data = torch.repeat_interleave( | ||
| layer.scales.data.view(hidden_size // layer.group_size, orig_kv_heads, head_dim), | ||
| repeat, | ||
| 1, | ||
| ).view(hidden_size // layer.group_size, new_kv_heads * head_dim) | ||
| if layer.bias is not None: | ||
| layer.bias.data = torch.repeat_interleave( | ||
| layer.bias.data.view(orig_kv_heads, head_dim), repeat, 0 | ||
| ).view(new_kv_heads * head_dim) | ||
| layer.out_features = layer.out_features * repeat | ||
|
|
||
| elif isinstance(layer, FP8DeQuantLinear): | ||
| layer.weight.data = torch.repeat_interleave( | ||
| layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0 | ||
| ).view(new_kv_heads * head_dim, hidden_size) | ||
| layer.weight_scale.data = torch.repeat_interleave( | ||
| layer.weight_scale.data.view(orig_kv_heads, head_dim), repeat, 0 | ||
| ).view(new_kv_heads * head_dim, -1) | ||
|
|
||
| else: | ||
| layer.weight.data = torch.repeat_interleave( | ||
| layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0 | ||
| ).view(new_kv_heads * head_dim, hidden_size) | ||
| if layer.bias is not None: | ||
| layer.bias.data = torch.repeat_interleave( | ||
| layer.bias.data.view(orig_kv_heads, head_dim), repeat, 0 | ||
| ).view(new_kv_heads * head_dim) | ||
| if layer.bias is not None: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. lines 782-785 are repeated here, please remove |
||
| layer.bias.data = torch.repeat_interleave( | ||
| layer.bias.data.view(orig_kv_heads, head_dim), repeat, 0 | ||
| ).view(new_kv_heads * head_dim) | ||
|
|
||
| def _get_text_model(model): | ||
| """ | ||
| Determine and return the appropriate text_model from a given model object. | ||
| """ | ||
| # Check for VLMs | ||
| if hasattr(model, "language_model"): | ||
| if hasattr(model.language_model, "model"): | ||
| return model.language_model.model | ||
| else: | ||
| return model.language_model | ||
| # Check for CausalLMs | ||
| if hasattr(model, "model"): | ||
| return model.model | ||
|
|
||
| raise AttributeError("No suitable text model found in the provided model.") | ||
|
|
||
| @classmethod | ||
| def apply(cls, model: nn.Module, **kwargs) -> nn.Module: | ||
| """ | ||
| Replicates KV heads in attention modules based on provided multiplier. | ||
|
|
||
| Args: | ||
| model: The model to apply the transform to. | ||
| kwargs: Additional arguments for the transformation. Includes: | ||
| - num_kv_heads_repeat: The number of times to repeat the KV heads. | ||
| """ | ||
| n_repeat = kwargs.pop("num_kv_heads_repeat", 1) | ||
| transformed = False | ||
| if n_repeat is not None and n_repeat > 1: | ||
| if (model.__class__ in cls._module_mapping) or (model.__class__.__name__ in cls._module_string_mapping): | ||
| text_model = cls._get_text_model(model) | ||
|
|
||
| orig_kv_heads = text_model.config.num_key_value_heads | ||
| new_kv_heads = n_repeat * orig_kv_heads | ||
| text_model.config.orig_kv_heads = orig_kv_heads | ||
| text_model.config.num_key_value_heads = new_kv_heads | ||
|
|
||
| num_attention_heads = text_model.config.num_attention_heads | ||
| hidden_size = text_model.config.hidden_size | ||
|
|
||
| logger.warning(f"Original KV heads: {orig_kv_heads}") | ||
| logger.warning(f"Modified KV heads: {new_kv_heads}") | ||
| transformed = True | ||
| for block in text_model.layers: | ||
| attn = getattr(block, "cross_attn", getattr(block, "self_attn", None)) | ||
| attn.num_key_value_heads = new_kv_heads | ||
| attn.num_key_value_groups = num_attention_heads // new_kv_heads | ||
|
|
||
| cls._duplicate_weights_for_linear_layer( | ||
| attn.k_proj, orig_kv_heads, n_repeat, attn.head_dim, hidden_size | ||
| ) | ||
| cls._duplicate_weights_for_linear_layer( | ||
| attn.v_proj, orig_kv_heads, n_repeat, attn.head_dim, hidden_size | ||
| ) | ||
| else: | ||
| raise NotImplementedError( | ||
| f"Model class {model.__class__.__name__} is not supported for KV head replication." | ||
| ) | ||
| return model, transformed | ||
|
|
||
|
|
||
| class SpDTransform: | ||
| """ | ||
| Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits during decode phase and extract last predicted token during prefill. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
don't we already dump config somewhere? in
_generate_export_hash?You can just always add
repeat_kv_headsvalue to self.hash_params which will be 1 if nothing is passed.