diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 6ebccdfbf8..508d9050d4 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -734,10 +734,16 @@ def from_legacy_cache( ) -> "HybridCache": """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for backward compatibility.""" + + # Get the sliding_window_pattern from config + sliding_window_pattern = getattr( + config, "_sliding_window_pattern", getattr(config, "sliding_window_pattern", None) + ) + cache = cls( config, batch_size=past_key_values[0][0].shape[0], - max_cache_len=past_key_values[config.sliding_window_pattern - 1][0].shape[2], + max_cache_len=past_key_values[sliding_window_pattern - 1][0].shape[2], sliding_window_len=past_key_values[0][0].shape[2], ) if past_key_values is not None: diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index 3d5a19bf96..8fb8cdbdda 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -198,7 +198,7 @@ def __qeff_init__(self): config = copy.deepcopy(self.config) config.rope_theta = config.rope_local_base_freq config.rope_scaling = {"rope_type": "default", "factor": 1.0} - self.is_local = _is_local(self.layer_idx, self.config.sliding_window_pattern) + self.is_local = _is_local(self.layer_idx, self.config._sliding_window_pattern) self.window = self.config.sliding_window if self.is_local else None self.rotary_emb_local = QEffGemma3RotaryEmbedding( @@ -253,7 +253,7 @@ def forward( "batch_index": batch_index, "position_ids": position_ids, "is_sliding": self.is_sliding, - "sliding_window_pattern": self.config.sliding_window_pattern, + "sliding_window_pattern": self.config._sliding_window_pattern, "sliding_window": past_key_values.sliding_window_len, } if comp_ctx_lengths is not None: @@ -322,7 +322,7 @@ def forward( else: attention_mask = _create_causal_mask( position_ids=position_ids, - target_length=past_key_value.key_cache[self.config.sliding_window_pattern - 1].shape[-2], + target_length=past_key_value.key_cache[self.config._sliding_window_pattern - 1].shape[-2], ) hidden_states, self_attn_weights = self.self_attn( @@ -569,7 +569,9 @@ def forward( def get_dummy_pkv_cache(self, config, batch_size, seq_len): n_heads = config.num_key_value_heads d_head = config.head_dim - layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC + layer_switch = ( + config._sliding_window_pattern if hasattr(config, "_sliding_window_pattern") else 2 + ) # 2 is for BC is_sliding = torch.tensor( [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool ) @@ -835,15 +837,15 @@ def get_onnx_dynamic_axes( pkv_dynamic_axes = {0: "full_batch_size" if continuous_batching else "batch_size", 2: "ctx_len"} pkv_dynamic_sliding_axes = {0: "full_batch_size" if continuous_batching else "batch_size", 2: "sliding_window"} layer_switch = ( - self.language_model.config.sliding_window_pattern - if hasattr(self.language_model.config, "sliding_window_pattern") + self.language_model.config._sliding_window_pattern + if hasattr(self.language_model.config, "_sliding_window_pattern") else 2 ) for i in range(self.language_model.config.num_hidden_layers): for kv in ["key", "value"]: apply_dynamic_axes = ( pkv_dynamic_sliding_axes - if ((i + 1) % layer_switch and hasattr(self.language_model.config, "sliding_window_pattern")) + if ((i + 1) % layer_switch and hasattr(self.language_model.config, "_sliding_window_pattern")) else pkv_dynamic_axes ) lang_dynamic_axes[f"past_{kv}.{i}"] = apply_dynamic_axes @@ -881,7 +883,9 @@ def get_output_names(self, kv_offload: bool = False): def get_dummy_pkv_cache(self, config, batch_size, seq_len): n_heads = config.num_key_value_heads d_head = config.head_dim - layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC + layer_switch = ( + config._sliding_window_pattern if hasattr(config, "_sliding_window_pattern") else 2 + ) # 2 is for BC is_sliding = torch.tensor( [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool )