Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions QEfficient/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1446,6 +1446,17 @@ def append_new_layers(self, layer_idx: int) -> None:
while len(self.layers) <= layer_idx:
self.layers.append(QEffGemma4DynamicLayer(is_sliding=self._is_sliding_layer(len(self.layers))))

def update(
    self,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    layer_idx: int,
    *args,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Make sure a layer exists for ``layer_idx``, then defer to the parent ``update``.

    Args:
        key_states: Key tensor forwarded unchanged to the parent implementation.
        value_states: Value tensor forwarded unchanged to the parent implementation.
        layer_idx: Index of the layer being updated; ``self.layers`` is grown
            until this index is valid.
        *args: Extra positional arguments passed through to the parent.
        **kwargs: Extra keyword arguments passed through to the parent.

    Returns:
        Whatever the parent ``update`` returns (annotated as a pair of tensors).
    """
    # Layers are created lazily: grow the list before the parent indexes into it.
    self.append_new_layers(layer_idx)
    cached_states = super().update(key_states, value_states, layer_idx, *args, **kwargs)
    return cached_states

@classmethod
def from_legacy_cache(
cls,
Expand Down
68 changes: 67 additions & 1 deletion QEfficient/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,31 @@
)
from transformers.models.gemma4.modeling_gemma4 import (
Gemma4ForCausalLM,
Gemma4ForConditionalGeneration,
Gemma4RMSNorm,
Gemma4TextAttention,
Gemma4TextDecoderLayer,
Gemma4TextExperts,
Gemma4TextModel,
Gemma4TextRouter,
)

# Optional dependency: the Gemma4-unified modeling module is only present in
# some transformers builds. When it cannot be imported, every name falls back
# to ``None`` so later code can test for availability.
try:
    from transformers.models.gemma4_unified.modeling_gemma4_unified import (
        Gemma4UnifiedForCausalLM,
        Gemma4UnifiedForConditionalGeneration,
        Gemma4UnifiedRMSNorm,
        Gemma4UnifiedTextAttention,
        Gemma4UnifiedTextDecoderLayer,
        Gemma4UnifiedTextModel,
    )
except ImportError:
    # ImportError (incl. ModuleNotFoundError) is the "not available" signal.
    # Any other exception indicates a real failure and is allowed to surface
    # instead of being silently turned into "unsupported".
    Gemma4UnifiedForCausalLM = None
    Gemma4UnifiedForConditionalGeneration = None
    Gemma4UnifiedRMSNorm = None
    Gemma4UnifiedTextAttention = None
    Gemma4UnifiedTextDecoderLayer = None
    Gemma4UnifiedTextModel = None
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2LMHeadModel, GPT2Model
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import (
GPTBigCodeAttention,
Expand Down Expand Up @@ -128,12 +146,30 @@
from .models.gemma4.modeling_gemma4 import (
QEffGemma4CustomRMSNormAIC,
QEffGemma4ForCausalLM,
QEffGemma4ForConditionalGeneration,
QEffGemma4TextAttention,
QEffGemma4TextDecoderLayer,
QEffGemma4TextExperts,
QEffGemma4TextModel,
QEffGemma4TextRouter,
)

# Optional QEff counterparts of the Gemma4-unified classes. When the module
# cannot be imported, every name falls back to ``None`` so later code can test
# for availability.
try:
    from .models.gemma4_unified.modeling_gemma4_unified import (
        QEffGemma4UnifiedCustomRMSNormAIC,
        QEffGemma4UnifiedForCausalLM,
        QEffGemma4UnifiedForConditionalGeneration,
        QEffGemma4UnifiedTextAttention,
        QEffGemma4UnifiedTextDecoderLayer,
        QEffGemma4UnifiedTextModel,
    )
except ImportError:
    # Only a failed import means "not available". Catching broader exceptions
    # here would hide genuine bugs inside the project's own modeling module.
    QEffGemma4UnifiedCustomRMSNormAIC = None
    QEffGemma4UnifiedForCausalLM = None
    QEffGemma4UnifiedForConditionalGeneration = None
    QEffGemma4UnifiedTextAttention = None
    QEffGemma4UnifiedTextDecoderLayer = None
    QEffGemma4UnifiedTextModel = None
from .models.gpt2.modeling_gpt2 import QEffGPT2Attention, QEffGPT2Block, QEffGPT2LMHeadModel, QEffGPT2Model
from .models.gpt_bigcode.modeling_gpt_bigcode import (
QEffGPTBigCodeAttention,
Expand Down Expand Up @@ -197,6 +233,7 @@
GemmaForCausalLM.__name__,
Gemma2ForCausalLM.__name__,
Gemma4ForCausalLM.__name__,
Gemma4ForConditionalGeneration.__name__,
MistralForCausalLM.__name__,
MixtralForCausalLM.__name__,
Phi3ForCausalLM.__name__,
Expand All @@ -210,9 +247,25 @@
]
)

# Advertise the Gemma4-unified architectures only when both the transformers
# classes and their QEff counterparts were imported. Each group comes from a
# single import statement, so checking one name per group is sufficient.
if Gemma4UnifiedForCausalLM is not None and QEffGemma4UnifiedForCausalLM is not None:
    qeff_supported_architectures.architectures.extend(
        [
            Gemma4UnifiedForCausalLM.__name__,
            Gemma4UnifiedForConditionalGeneration.__name__,
        ]
    )


# This is for supporting different seq_len for different layers for Sliding window attn, chunked attn etc.
# Single definition: the earlier one-line assignment was immediately
# overwritten by this one, so only the effective value is kept.
DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {
    "gemma3",
    "gemma3_text",
    "gemma4_text",
    "gemma4_unified",
    "gemma4_unified_text",
    "llama4",
    "llama4_text",
}

# This is for supporting different modelling classes specially written for prefill-only model
SPECIALIZED_DISAGG_SERVING_MODEL_ARCH = {"gpt_oss", "qwen3_moe", "glm4_moe", "kimi_k2", "kimi_k25"}
Expand Down Expand Up @@ -275,6 +328,7 @@ def _configure_proxy_for_model(instance: "QEFFBaseModel", enable_proxy: bool) ->
Gemma4TextAttention: QEffGemma4TextAttention,
Gemma4TextModel: QEffGemma4TextModel,
Gemma4ForCausalLM: QEffGemma4ForCausalLM,
Gemma4ForConditionalGeneration: QEffGemma4ForConditionalGeneration,
Gemma4TextDecoderLayer: QEffGemma4TextDecoderLayer,
Gemma4TextExperts: QEffGemma4TextExperts,
Gemma4TextRouter: QEffGemma4TextRouter,
Expand Down Expand Up @@ -340,6 +394,18 @@ def _configure_proxy_for_model(instance: "QEFFBaseModel", enable_proxy: bool) ->
WhisperForConditionalGeneration: QEffWhisperForConditionalGeneration,
}

# Register the Gemma4-unified module swaps only when BOTH optional imports
# succeeded. If the transformers classes are present but the QEff classes
# fell back to None, updating the mapping would insert None as the
# replacement for each transformers class.
if Gemma4UnifiedForCausalLM is not None and QEffGemma4UnifiedForCausalLM is not None:
    TransformersToQEffModulesDict.update(
        {
            Gemma4UnifiedTextAttention: QEffGemma4UnifiedTextAttention,
            Gemma4UnifiedTextModel: QEffGemma4UnifiedTextModel,
            Gemma4UnifiedForCausalLM: QEffGemma4UnifiedForCausalLM,
            Gemma4UnifiedForConditionalGeneration: QEffGemma4UnifiedForConditionalGeneration,
            Gemma4UnifiedTextDecoderLayer: QEffGemma4UnifiedTextDecoderLayer,
            Gemma4UnifiedRMSNorm: QEffGemma4UnifiedCustomRMSNormAIC,
        }
    )


def build_model_class_mapping(auto_model_class, qeff_class_name):
"""
Expand Down
4 changes: 3 additions & 1 deletion QEfficient/transformers/models/gemma3/modeling_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
Gemma3ForConditionalGeneration,
Gemma3TextConfig,
Gemma3TextModel,
logger,
repeat_kv,
rotate_half,
)
from transformers.utils import logging

from QEfficient.customop.rms_norm import CustomRMSNorm
from QEfficient.transformers.cache_utils import QEffSlidingWindowCache
Expand All @@ -33,6 +33,8 @@
from QEfficient.utils._utils import IOInfo
from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE

logger = logging.get_logger(__name__)


class GemmaRMSNormFunc(torch.autograd.Function):
@staticmethod
Expand Down
4 changes: 4 additions & 0 deletions QEfficient/transformers/models/gemma4_unified/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# -----------------------------------------------------------------------------
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
# -----------------------------------------------------------------------------
Loading
Loading