Skip to content
45 changes: 36 additions & 9 deletions QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
require_value,
to_named_specializations,
)
from QEfficient.utils.config_utils import calculate_num_replicate_kv_heads
from QEfficient.utils.export_utils import export_wrapper

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -655,23 +656,48 @@ def transform(
**compiler_options,
):
# Apply the transformations that are dependent on compilation parameters
def _transform_tracking_root(module: torch.nn.Module) -> torch.nn.Module:
"""
Use the shared wrapped model as transform-tracking root when available.
This lets encoder/decoder wrappers coordinate one-time transforms.
"""
wrapped = getattr(module, "model", None)
return wrapped if isinstance(wrapped, torch.nn.Module) else module

qaic_config = qaic_config if qaic_config else getattr(self.model, "qaic_config", None)

model_config = getattr(self.model, "config", None) or getattr(self.model.model, "config", None)
model_config = getattr(self.model, "config", None) or getattr(
getattr(self.model, "model", None), "config", None
)
num_replicate_kv_heads = 1
if model_config is not None:
num_replicate_kv_heads = calculate_num_replicate_kv_heads(
num_devices=num_devices,
text_model_config=model_config,
)

if model_config:
if "DeepseekV3ForCausalLM" in (getattr(model_config, "architectures", None) or []):
if qaic_config:
if qaic_config.get("blocking_mode", None) == "h":
qaic_config["head_block_size"] = qaic_config.get("head_block_size", num_devices)
num_kv_heads_repeat = qaic_config.get("num_kv_heads_repeat", 1)
if qaic_config is not None:
num_replicate_kv_heads = qaic_config.get("num_replicate_kv_heads", num_replicate_kv_heads)
qaic_config["num_replicate_kv_heads"] = num_replicate_kv_heads
transform_root = _transform_tracking_root(self.model)
applied_transforms = getattr(transform_root, "_qeff_runtime_transforms_applied", set())
should_apply_repeat_kv = num_replicate_kv_heads is not None and num_replicate_kv_heads > 1
if not should_apply_repeat_kv:
replicate_kv_transformed = False
elif ReplicateKVHeadTransform.__name__ in applied_transforms:
replicate_kv_transformed = False
logger.warning("Skipping RepeatKVTransform: already applied on this model instance.")
else:
self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(
self.model, num_kv_heads_repeat
self.model,
num_replicate_kv_heads,
)
if replicate_kv_transformed:
self.hash_params["config"] = self.model.config.to_diff_dict()

applied_transforms.add(ReplicateKVHeadTransform.__name__)
setattr(transform_root, "_qeff_runtime_transforms_applied", applied_transforms)
if replicate_kv_transformed:
self.hash_params["config"] = self.model.config.to_diff_dict()
blocking_config = build_transformer_blocking_config_for_transform(
model_config,
ctx_len=ctx_len,
Expand All @@ -688,6 +714,7 @@ def transform(
if blocking_config is not None:
self.model, _ = BlockingAttentionTransform.apply(self.model, attn_blocking_config=blocking_config)
self.hash_params["blocking_kwargs"] = blocking_config
self.hash_params["num_replicate_kv_heads"] = num_replicate_kv_heads

@dump_qconfig
def _compile(
Expand Down
1 change: 1 addition & 0 deletions QEfficient/transformers/models/gemma3/modeling_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,7 @@ class QEffGemma3EncoderWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model.model
self.config = self.model.config
self.model.vision_model = self.model.vision_tower

def get_submodules_for_export(self) -> Type[nn.Module]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class QEffInternEncoderWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.config = self.model.config

def get_submodules_for_export(self) -> Type[nn.Module]:
"""
Expand Down
1 change: 1 addition & 0 deletions QEfficient/transformers/models/llama4/modeling_llama4.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,7 @@ class QEffLlama4EncoderWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.config = self.model.config

def get_submodules_for_export(self) -> Type[nn.Module]:
"""
Expand Down
1 change: 1 addition & 0 deletions QEfficient/transformers/models/llava/modeling_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(self, model):
super().__init__()
self.model = model
self.model.vision_model = self.model.model.vision_tower
self.config = self.model.config

def get_submodules_for_export(self) -> Type[nn.Module]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(self, model):
super().__init__()
self.model = model
self.model.vision_model = self.model.model.vision_tower
self.config = self.model.config

def get_submodules_for_export(self) -> Type[nn.Module]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ class QEFFMistral3EncoderWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.config = self.model.config
self.model.model.vision_model = self.model.model.vision_tower

def get_submodules_for_export(self) -> Type[nn.Module]:
Expand Down
16 changes: 15 additions & 1 deletion QEfficient/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -1326,6 +1326,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Option
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})

_resolve_torch_dtype(kwargs)
num_replicate_kv_heads = kwargs.pop("num_replicate_kv_heads", 1)
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)

kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {})
Expand All @@ -1334,6 +1335,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Option
model,
pretrained_model_name_or_path=pretrained_model_name_or_path,
qaic_config=qaic_config,
num_replicate_kv_heads=num_replicate_kv_heads,
**kwargs,
)

Expand Down Expand Up @@ -1433,7 +1435,12 @@ def export(
if prefill_only and prefill_seq_len > 1:
offload_pt_weights = False # to keep weight for decode onnx
else:
offload_pt_weights = kwargs.get("offload_pt_weights", True)
num_replicate_kv_heads = (
(self.lang_model.model.qaic_config or {}).get("num_replicate_kv_heads", 1)
if hasattr(self.lang_model.model, "qaic_config")
else 1
)
offload_pt_weights = kwargs.get("offload_pt_weights", num_replicate_kv_heads <= 1)

if not skip_lang and self.lang_model.onnx_path is None:
self.lang_model.export(
Expand Down Expand Up @@ -2228,6 +2235,7 @@ def from_pretrained(
config._attn_implementation = "eager"
config.vision_config.use_flash_attn = "false"
_resolve_torch_dtype(kwargs)
num_replicate_kv_heads = kwargs.pop("num_replicate_kv_heads", 1)
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, config, *args, **kwargs)

kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {})
Expand All @@ -2236,6 +2244,7 @@ def from_pretrained(
model,
pretrained_model_name_or_path=pretrained_model_name_or_path,
qaic_config=qaic_config,
num_replicate_kv_heads=num_replicate_kv_heads,
**kwargs,
)

Expand Down Expand Up @@ -2875,6 +2884,7 @@ def from_pretrained(
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})

_resolve_torch_dtype(kwargs)
num_replicate_kv_heads = kwargs.pop("num_replicate_kv_heads", 1)
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)

kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {})
Expand All @@ -2885,6 +2895,7 @@ def from_pretrained(
continuous_batching=continuous_batching,
pretrained_model_name_or_path=pretrained_model_name_or_path,
qaic_config=qaic_config,
num_replicate_kv_heads=num_replicate_kv_heads,
**kwargs,
)

Expand Down Expand Up @@ -3129,6 +3140,7 @@ def from_pretrained(
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})

_resolve_torch_dtype(kwargs)
num_replicate_kv_heads = kwargs.pop("num_replicate_kv_heads", 1)
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
if qaic_config is not None:
qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path
Expand All @@ -3142,6 +3154,7 @@ def from_pretrained(
pretrained_model_name_or_path=pretrained_model_name_or_path,
qaic_config=qaic_config,
continuous_batching=continuous_batching,
num_replicate_kv_heads=num_replicate_kv_heads,
**kwargs,
)
return cls(
Expand All @@ -3150,6 +3163,7 @@ def from_pretrained(
qaic_config=qaic_config,
pretrained_model_name_or_path=pretrained_model_name_or_path,
max_seq_len_cached=max_seq_len_cached,
num_replicate_kv_heads=num_replicate_kv_heads,
**kwargs,
)

Expand Down
1 change: 1 addition & 0 deletions QEfficient/transformers/models/molmo/modeling_molmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,7 @@ class QEffMolmoEncoderWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.config = self.model.config

def get_submodules_for_export(self) -> Type[nn.Module]:
"""
Expand Down
Loading
Loading