Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions QEfficient/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
PoolingTransform,
PrefillOnlyChunkedTransform,
PrefillOnlyTransform,
ReplicateKVHeadTransform,
RevertPrefillKeepAttentionTransform,
RevertPrefillOnlyTransform,
SamplerTransform,
Expand Down Expand Up @@ -887,6 +888,11 @@ def __init__(

self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs)
self.lang_model = QEffCausalLMForTextImageToTextModel(model, qaic_config=qaic_config, **kwargs)
self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs)
# Since both modules use the entire config for hash creation, we're updating the params for consistency.
if replicate_kv_transformed:
self.lang_model.hash_params["config"] = model.config.to_diff_dict()
self.vision_model.hash_params["config"] = model.config.to_diff_dict()
Comment on lines +893 to +895
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't we already dump config somewhere? in _generate_export_hash?
You can just always add repeat_kv_heads value to self.hash_params which will be 1 if nothing is passed.

self.continuous_batching = continuous_batching
self.ccl_enabled = False
if qaic_config:
Expand Down Expand Up @@ -1623,6 +1629,9 @@ def __init__(
self.model.config.text_config.use_cache = True
else:
self.model.config.use_cache = True
self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs)
if replicate_kv_transformed:
self.hash_params["config"] = model.config.to_diff_dict()
self.hash_params["qeff_auto_class"] = self.__class__.__name__
self.ccl_enabled = False
if qaic_config:
Expand Down Expand Up @@ -2257,7 +2266,9 @@ def from_pretrained(
logger.warning("Updating low_cpu_mem_usage=False")

kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
num_kv_heads_repeat = kwargs.pop("num_kv_heads_repeat", None)
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
kwargs.update({"num_kv_heads_repeat": num_kv_heads_repeat})
return cls(
model,
kv_offload=kv_offload,
Expand Down Expand Up @@ -2384,6 +2395,9 @@ def __init__(

setattr(model.config, "max_seq_len_cached", max_seq_len_cached)
super().__init__(model, qaic_config=qaic_config, **kwargs)
self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs)
if replicate_kv_transformed:
self.hash_params["config"] = model.config.to_diff_dict()
Comment on lines +2398 to +2400
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better add it to _pytorch_transforms if we are always going to call it.

self.num_layers = model.config.num_hidden_layers
self.continuous_batching = continuous_batching
self.model.qaic_config = qaic_config
Expand Down Expand Up @@ -2481,7 +2495,10 @@ def from_pretrained(
kv_offload = kwargs.pop("kv_offload", None)

kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
# InternVL causes an error if we pass the num_kv_heads_repeat parameter
num_kv_heads_repeat = kwargs.pop("num_kv_heads_repeat", 1)
Comment on lines +2498 to +2499
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not needed

model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
kwargs.update({"num_kv_heads_repeat": num_kv_heads_repeat})
if qaic_config is not None:
qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path

Expand Down
163 changes: 163 additions & 0 deletions QEfficient/transformers/models/pytorch_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from types import MethodType
from typing import Callable, Optional, Tuple, Union

import torch
from torch import nn
from transformers.models.codegen.modeling_codegen import (
CodeGenAttention,
Expand Down Expand Up @@ -446,8 +447,12 @@
QEffWhisperPositionalEmbedding,
)
from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry
from QEfficient.transformers.quantizers.awq import WQLinear_GEMM
from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ
from QEfficient.transformers.quantizers.quantizer_compressed_tensors import FP8DeQuantLinear
from QEfficient.transformers.sampler.sampler import sampler_forward
from QEfficient.transformers.spd.spd_transform_forward import tlm_forward
from QEfficient.utils.logging_utils import logger

SPD_TARGET = "target"

Expand Down Expand Up @@ -686,6 +691,164 @@ class RevertPrefillOnlyTransform(ModuleMappingTransform):
}


class ReplicateKVHeadTransform:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make this inherit ModuleMutatorTransform
You may need to implement mutate method which is similar to apply here

"""
Replicates KV heads in attention modules to match the number of KV heads in the target model.
This transform is used when the source model has fewer KV heads than required in target model.
"""

_module_mapping = {
QEffCodeGenForCausalLM,
QEffFalconForCausalLM,
QEffGPT2LMHeadModel,
QEffGPTJForCausalLM,
QEffLlamaForCausalLM,
QEffLlama4ForConditionalGeneration,
QEffLlavaForConditionalGeneration,
QEffLlavaNextForConditionalGeneration,
QEffGemmaForCausalLM,
QEffGemma2ForCausalLM,
QEffGemma3ForConditionalGeneration,
QEffGraniteForCausalLM,
QEffGraniteMoeForCausalLM,
QEffMllamaForConditionalGeneration,
QEffMistralForCausalLM,
QEffMistral3ForConditionalGeneration,
QEffMptForCausalLM,
QEffPhiForCausalLM,
QEffPhi3ForCausalLM,
QEffQwen2ForCausalLM,
QEffQwen3ForCausalLM,
QEffQwen_2_5_vl_ForConditionalGeneration,
QEffQwen3MoeForCausalLM,
QEffStarcoder2ForCausalLM,
QEffGPTBigCodeForCausalLM,
QEffOlmo2ForCausalLM,
}
_module_string_mapping = {
"InternVLChatModel",
}

def _duplicate_weights_for_linear_layer(
layer: nn.Module, orig_kv_heads: int, repeat: int, head_dim: int, hidden_size: int
):
new_kv_heads = repeat * orig_kv_heads
if isinstance(layer, (WQLinear_GEMM, QuantLinearGPTQ)):
if head_dim % 8 != 0:
raise ValueError(
f"the value head_dim={head_dim} is not divisible by 8 which is \
according to the assumption that model is 4-bit quantized."
)
if hidden_size % layer.group_size != 0:
raise ValueError(
f"The value of hidden_size={hidden_size} is not divisible by \
K_proj.group_size={layer.group_size}"
)

# Duplication of quantized weights
layer.qweight.data = torch.repeat_interleave(
layer.qweight.data.view(hidden_size, orig_kv_heads, head_dim // 8), repeat, 1
).view(hidden_size, (new_kv_heads * head_dim) // 8)
# Duplication of quantized zero points
layer.qzeros.data = torch.repeat_interleave(
layer.qzeros.data.view(hidden_size // layer.group_size, orig_kv_heads, head_dim // 8),
repeat,
1,
).view(hidden_size // layer.group_size, (new_kv_heads * head_dim) // 8)
# Duplication of quantization scales
layer.scales.data = torch.repeat_interleave(
layer.scales.data.view(hidden_size // layer.group_size, orig_kv_heads, head_dim),
repeat,
1,
).view(hidden_size // layer.group_size, new_kv_heads * head_dim)
if layer.bias is not None:
layer.bias.data = torch.repeat_interleave(
layer.bias.data.view(orig_kv_heads, head_dim), repeat, 0
).view(new_kv_heads * head_dim)
layer.out_features = layer.out_features * repeat

elif isinstance(layer, FP8DeQuantLinear):
layer.weight.data = torch.repeat_interleave(
layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
).view(new_kv_heads * head_dim, hidden_size)
layer.weight_scale.data = torch.repeat_interleave(
layer.weight_scale.data.view(orig_kv_heads, head_dim), repeat, 0
).view(new_kv_heads * head_dim, -1)

else:
layer.weight.data = torch.repeat_interleave(
layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
).view(new_kv_heads * head_dim, hidden_size)
if layer.bias is not None:
layer.bias.data = torch.repeat_interleave(
layer.bias.data.view(orig_kv_heads, head_dim), repeat, 0
).view(new_kv_heads * head_dim)
if layer.bias is not None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lines 782-785 are repeated here, please remove

layer.bias.data = torch.repeat_interleave(
layer.bias.data.view(orig_kv_heads, head_dim), repeat, 0
).view(new_kv_heads * head_dim)

def _get_text_model(model):
"""
Determine and return the appropriate text_model from a given model object.
"""
# Check for VLMs
if hasattr(model, "language_model"):
if hasattr(model.language_model, "model"):
return model.language_model.model
else:
return model.language_model
# Check for CausalLMs
if hasattr(model, "model"):
return model.model

raise AttributeError("No suitable text model found in the provided model.")

@classmethod
def apply(cls, model: nn.Module, **kwargs) -> nn.Module:
"""
Replicates KV heads in attention modules based on provided multiplier.

Args:
model: The model to apply the transform to.
kwargs: Additional arguments for the transformation. Includes:
- num_kv_heads_repeat: The number of times to repeat the KV heads.
"""
n_repeat = kwargs.pop("num_kv_heads_repeat", 1)
transformed = False
if n_repeat is not None and n_repeat > 1:
if (model.__class__ in cls._module_mapping) or (model.__class__.__name__ in cls._module_string_mapping):
text_model = cls._get_text_model(model)

orig_kv_heads = text_model.config.num_key_value_heads
new_kv_heads = n_repeat * orig_kv_heads
text_model.config.orig_kv_heads = orig_kv_heads
text_model.config.num_key_value_heads = new_kv_heads

num_attention_heads = text_model.config.num_attention_heads
hidden_size = text_model.config.hidden_size

logger.warning(f"Original KV heads: {orig_kv_heads}")
logger.warning(f"Modified KV heads: {new_kv_heads}")
transformed = True
for block in text_model.layers:
attn = getattr(block, "cross_attn", getattr(block, "self_attn", None))
attn.num_key_value_heads = new_kv_heads
attn.num_key_value_groups = num_attention_heads // new_kv_heads

cls._duplicate_weights_for_linear_layer(
attn.k_proj, orig_kv_heads, n_repeat, attn.head_dim, hidden_size
)
cls._duplicate_weights_for_linear_layer(
attn.v_proj, orig_kv_heads, n_repeat, attn.head_dim, hidden_size
)
else:
raise NotImplementedError(
f"Model class {model.__class__.__name__} is not supported for KV head replication."
)
return model, transformed


class SpDTransform:
"""
Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits during decode phase and extract last predicted token during prefill.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@
# "https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-1-2048.jpg",
# "Please describe the image in detail.",
# 2,
# ), # commented becuase QNN Convertor is not supported for this model yet.
# ),
]

molmo_model_config = [
Expand Down Expand Up @@ -249,6 +249,14 @@ def set_num_layers(config, n_layer=1):
return config


def get_text_config(config):
if hasattr(config, "text_config"):
return config.text_config
elif hasattr(config, "llm_config"):
return config.llm_config
return config


def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
model_name: str,
img_size: int,
Expand All @@ -263,6 +271,8 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
num_devices: int = 1,
enable_qnn: Optional[bool] = False,
qnn_config: Optional[str] = None,
num_kv_heads_repeat: Optional[int] = None,
test_kv_replicate: Optional[bool] = None,
):
model_config = {"model_name": model_name}
model_config["img_size"] = img_size
Expand Down Expand Up @@ -304,10 +314,15 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32)
streamer = TextStreamer(processor.tokenizer)
pytorch_hf_tokens = api_runner.run_vlm_hf_model_on_pytorch(model_hf, inputs)
if test_kv_replicate:
text_config = get_text_config(config)
num_kv_heads_repeat = text_config.num_attention_heads // text_config.num_key_value_heads

qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
model_config["model_name"],
kv_offload=kv_offload,
config=config,
num_kv_heads_repeat=num_kv_heads_repeat,
)

# pytorch_kv_tokens = api_runner.run_vlm_kv_model_on_pytorch(qeff_model.model)
Expand Down Expand Up @@ -428,6 +443,8 @@ def check_intern_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
num_devices: int = 1,
enable_qnn: Optional[bool] = False,
qnn_config: Optional[str] = None,
num_kv_heads_repeat: Optional[int] = None,
test_kv_replicate: Optional[bool] = None,
):
model_config = {"model_name": model_name}

Expand Down Expand Up @@ -490,10 +507,15 @@ def check_intern_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
)
pytorch_hf_tokens = api_runner.run_vlm_hf_model_on_pytorch(model_hf, inputs, generation_config)

if test_kv_replicate:
text_config = get_text_config(config)
num_kv_heads_repeat = text_config.num_attention_heads // text_config.num_key_value_heads

qeff_model = QEFFAutoModelForCausalLM.from_pretrained(
model_config["model_name"],
kv_offload=kv_offload,
config=config,
num_kv_heads_repeat=num_kv_heads_repeat,
)
# pytorch_kv_tokens = api_runner.run_vlm_kv_model_on_pytorch(qeff_model.model)
# assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), (
Expand Down Expand Up @@ -551,6 +573,34 @@ def test_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
)


@pytest.mark.on_qaic
@pytest.mark.multimodal
@pytest.mark.parametrize(
"model_name, kv_offload, batch_size, prompt_len, ctx_len, img_size, img_url, query, n_layer", test_models_config
)
def test_replicate_kv_pytorch_vs_ai100(
model_name, kv_offload, batch_size, prompt_len, ctx_len, img_size, img_url, query, n_layer
):
"""
Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, without continuous batching.
``Mandatory`` Args:
:model_name (str): Hugging Face Model Card name, Example: ``gpt2``
"""
check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
model_name=model_name,
prompt_len=prompt_len,
ctx_len=ctx_len,
max_gen_len=NEW_GENERATION_TOKENS,
img_size=img_size,
img_url=img_url,
query=query,
n_layer=n_layer,
batch_size=batch_size,
kv_offload=kv_offload,
test_kv_replicate=True,
)


@pytest.mark.on_qaic
@pytest.mark.qnn
@pytest.mark.multimodal
Expand Down Expand Up @@ -608,6 +658,28 @@ def test_image_text_to_text_molmo_pytorch_vs_kv_vs_ort_vs_ai100(
)


@pytest.mark.on_qaic
@pytest.mark.multimodal
@pytest.mark.parametrize(
"model_name, kv_offload, batch_size, prompt_len, ctx_len, img_url, query, n_layer", intern_model_config
)
def test_replicate_kv_intern_pytorch_vs_ai100(
model_name, kv_offload, batch_size, prompt_len, ctx_len, img_url, query, n_layer
):
check_intern_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
model_name=model_name,
prompt_len=prompt_len,
ctx_len=ctx_len,
max_gen_len=NEW_GENERATION_TOKENS,
img_url=img_url,
query=query,
n_layer=n_layer,
batch_size=batch_size,
kv_offload=kv_offload,
test_kv_replicate=True,
)


@pytest.mark.on_qaic
@pytest.mark.multimodal
@pytest.mark.parametrize(
Expand Down
Loading
Loading