Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion QEfficient/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,10 +734,16 @@ def from_legacy_cache(
) -> "HybridCache":
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
backward compatibility."""

# Get the sliding_window_pattern from config
sliding_window_pattern = getattr(
config, "_sliding_window_pattern", getattr(config, "sliding_window_pattern", None)
)

cache = cls(
config,
batch_size=past_key_values[0][0].shape[0],
max_cache_len=past_key_values[config.sliding_window_pattern - 1][0].shape[2],
max_cache_len=past_key_values[sliding_window_pattern - 1][0].shape[2],
sliding_window_len=past_key_values[0][0].shape[2],
)
if past_key_values is not None:
Expand Down
20 changes: 12 additions & 8 deletions QEfficient/transformers/models/gemma3/modeling_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def __qeff_init__(self):
config = copy.deepcopy(self.config)
config.rope_theta = config.rope_local_base_freq
config.rope_scaling = {"rope_type": "default", "factor": 1.0}
self.is_local = _is_local(self.layer_idx, self.config.sliding_window_pattern)
self.is_local = _is_local(self.layer_idx, self.config._sliding_window_pattern)
self.window = self.config.sliding_window if self.is_local else None

self.rotary_emb_local = QEffGemma3RotaryEmbedding(
Expand Down Expand Up @@ -253,7 +253,7 @@ def forward(
"batch_index": batch_index,
"position_ids": position_ids,
"is_sliding": self.is_sliding,
"sliding_window_pattern": self.config.sliding_window_pattern,
"sliding_window_pattern": self.config._sliding_window_pattern,
"sliding_window": past_key_values.sliding_window_len,
}
if comp_ctx_lengths is not None:
Expand Down Expand Up @@ -322,7 +322,7 @@ def forward(
else:
attention_mask = _create_causal_mask(
position_ids=position_ids,
target_length=past_key_value.key_cache[self.config.sliding_window_pattern - 1].shape[-2],
target_length=past_key_value.key_cache[self.config._sliding_window_pattern - 1].shape[-2],
)

hidden_states, self_attn_weights = self.self_attn(
Expand Down Expand Up @@ -569,7 +569,9 @@ def forward(
def get_dummy_pkv_cache(self, config, batch_size, seq_len):
n_heads = config.num_key_value_heads
d_head = config.head_dim
layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC
layer_switch = (
config._sliding_window_pattern if hasattr(config, "_sliding_window_pattern") else 2
) # 2 is for BC
is_sliding = torch.tensor(
[bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
)
Expand Down Expand Up @@ -835,15 +837,15 @@ def get_onnx_dynamic_axes(
pkv_dynamic_axes = {0: "full_batch_size" if continuous_batching else "batch_size", 2: "ctx_len"}
pkv_dynamic_sliding_axes = {0: "full_batch_size" if continuous_batching else "batch_size", 2: "sliding_window"}
layer_switch = (
self.language_model.config.sliding_window_pattern
if hasattr(self.language_model.config, "sliding_window_pattern")
self.language_model.config._sliding_window_pattern
if hasattr(self.language_model.config, "_sliding_window_pattern")
else 2
)
for i in range(self.language_model.config.num_hidden_layers):
for kv in ["key", "value"]:
apply_dynamic_axes = (
pkv_dynamic_sliding_axes
if ((i + 1) % layer_switch and hasattr(self.language_model.config, "sliding_window_pattern"))
if ((i + 1) % layer_switch and hasattr(self.language_model.config, "_sliding_window_pattern"))
else pkv_dynamic_axes
)
lang_dynamic_axes[f"past_{kv}.{i}"] = apply_dynamic_axes
Expand Down Expand Up @@ -881,7 +883,9 @@ def get_output_names(self, kv_offload: bool = False):
def get_dummy_pkv_cache(self, config, batch_size, seq_len):
n_heads = config.num_key_value_heads
d_head = config.head_dim
layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC
layer_switch = (
config._sliding_window_pattern if hasattr(config, "_sliding_window_pattern") else 2
) # 2 is for BC
is_sliding = torch.tensor(
[bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
)
Expand Down
Loading