Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def _transform_names(self) -> List[str]:
def __init__(self, model: torch.nn.Module, **kwargs) -> None:
super().__init__()
self.model = model
self.config = model.config
self.hash_params = create_model_params(self, **kwargs)
self.onnx_path: Optional[str] = None
self.qpc_path: Optional[str] = None
Expand All @@ -77,11 +78,51 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
self.model, transformed = transform.apply(self.model)
any_transformed = any_transformed or transformed

self._normalize_torch_dtype()

if not any_transformed:
warnings.warn(f"No transforms applied to model: {self.model_name}. It may be an unsupported model!")
else:
logger.info(f"Pytorch transforms applied to model: {self.model_name}")

if self.config.torch_dtype == torch.bfloat16:
logger.warning("BFloat16 dtype is not yet supported; converting to float16 precision!")

def _normalize_torch_dtype(self):
"""
Normalizes torch_dtype across all nested configs to match the top-level config.

This method ensures consistency by propagating the top-level torch_dtype
to all nested configs (llm_config, vision_config, etc.) that may exist in
multimodal models.
"""
top_level_dtype = getattr(self.config, "torch_dtype", torch.float32)

if top_level_dtype is None:
top_level_dtype = torch.float32
elif isinstance(top_level_dtype, str):
top_level_dtype = getattr(torch, top_level_dtype, torch.float32)

self.config.torch_dtype = top_level_dtype

# Normalize llm_config if it exists
if hasattr(self.config, "llm_config"):
self.config.llm_config.torch_dtype = top_level_dtype
if hasattr(self.config.llm_config, "use_bfloat16"):
self.config.llm_config.use_bfloat16 = top_level_dtype == torch.bfloat16

# Normalize vision_config if it exists
if hasattr(self.config, "vision_config"):
self.config.vision_config.torch_dtype = top_level_dtype
if hasattr(self.config.vision_config, "use_bfloat16"):
self.config.vision_config.use_bfloat16 = top_level_dtype == torch.bfloat16

# Normalize text_config if it exists (for models like Qwen2.5-VL)
if hasattr(self.config, "text_config"):
self.config.text_config.torch_dtype = top_level_dtype

logger.info(f"Normalized all config torch_dtype to: {top_level_dtype}")

def _offload_model_weights(self, offload_pt_weights: bool) -> bool:
"""Clear PyTorch model weights to reduce memory usage after ONNX export."""
if offload_pt_weights and not self._is_weights_offloaded:
Expand Down Expand Up @@ -506,12 +547,21 @@ def _compile(
command.append(f"-network-specialization-config={specializations_json}")

# Write custom_io.yaml file
model_in_bfloat16 = hasattr(self, "config") and (self.config.torch_dtype == torch.bfloat16)
pkv_in_bfloat16 = (custom_io is not None) and any(
"past_" in key and "bfloat16" in value for key, value in custom_io.items()
)
if custom_io is not None:
custom_io_yaml = compile_dir / "custom_io.yaml"
with open(custom_io_yaml, "w") as fp:
for io_name, dtype in custom_io.items():
fp.write(f" - IOName: {io_name}\n Precision: {dtype}\n\n")
command.append(f"-custom-IO-list-file={custom_io_yaml}")
if model_in_bfloat16 and pkv_in_bfloat16:
logger.warning(
"Model and Past KV types are both bfloat16. Custom IO list file will be ignored during compile."
)
else:
command.append(f"-custom-IO-list-file={custom_io_yaml}")

command.append(f"-aic-binary-dir={qpc_path}")
logger.info(f"Running compiler: {' '.join(command)}")
Expand Down
1 change: 1 addition & 0 deletions QEfficient/generation/cloud_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(

# Build dtype mapping once (depends on aicapi constants)
self.aic_to_np_dtype_mapping = {
getattr(aicapi, "BFLOAT16_TYPE", 11): np.dtype(np.float16),
aicapi.FLOAT_TYPE: np.dtype(np.float32),
aicapi.FLOAT_16_TYPE: np.dtype(np.float16),
aicapi.INT8_Q_TYPE: np.dtype(np.int8),
Expand Down
7 changes: 3 additions & 4 deletions QEfficient/transformers/models/codegen/modeling_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def _attn(
head_mask=None,
):
# Keep the attention weights computation in fp32 to avoid overflow issues
query = query.to(torch.float32)
key = key.to(torch.float32)
query = query.to(value.dtype)
key = key.to(value.dtype)

attn_weights = torch.matmul(query, key.transpose(-1, -2))

Expand Down Expand Up @@ -349,8 +349,7 @@ def forward(
# Cast to INT32 to avoid issue while running in ONNXRT
logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
hidden_states = transformer_outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
lm_logits = self.lm_head(hidden_states)

lm_logits = self.lm_head(hidden_states).float()
return CausalLMOutputWithPast(
loss=None,
logits=lm_logits,
Expand Down
8 changes: 5 additions & 3 deletions QEfficient/transformers/models/falcon/modeling_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,11 @@ def forward(
attention_scores = query_layer @ key_layer.transpose(-1, -2)
attention_scores /= math.sqrt(self.head_dim)
attention_scores = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attention_scores
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=self.config.torch_dtype), attention_scores
)
attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=torch.float32).to(
query_layer.dtype
)
attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype)
# It is unclear why neither dropout nor head_mask is applied here (while it is with alibi).
attn_output = attention_scores @ value_layer

Expand Down Expand Up @@ -401,7 +403,7 @@ def forward(
# Cast to INT32 to avoid issue while running in ONNXRT
logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
hidden_states = transformer_outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
lm_logits = self.lm_head(hidden_states)
lm_logits = self.lm_head(hidden_states).float()

return CausalLMOutputWithCrossAttentions(
loss=None,
Expand Down
2 changes: 1 addition & 1 deletion QEfficient/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
Expand Down
4 changes: 2 additions & 2 deletions QEfficient/transformers/models/gemma2/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
Expand Down Expand Up @@ -448,7 +448,7 @@ def forward(
logits = logits / self.config.final_logit_softcapping
logits = torch.tanh(logits)
logits = logits * self.config.final_logit_softcapping

logits = logits.float()
return CausalLMOutputWithPast(
loss=None,
logits=logits,
Expand Down
30 changes: 17 additions & 13 deletions QEfficient/transformers/models/gemma3/modeling_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@
class GemmaRMSNormFunc(torch.autograd.Function):
@staticmethod
def forward(hidden_states: torch.Tensor, weight: torch.Tensor, epsilon: float):
hidden_states = hidden_states.to(torch.float32)
div_first = hidden_states * torch.rsqrt(torch.tensor(hidden_states.shape[-1], dtype=torch.float32))
div_first = hidden_states * torch.rsqrt(torch.tensor(hidden_states.shape[-1], dtype=hidden_states.dtype))
variance = div_first.pow(2).sum(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + epsilon)
return weight * hidden_states
Expand All @@ -61,7 +60,7 @@ class QEffGemma3CustomRMSNormAIC(nn.Module):
def forward(self, hidden_states):
return GemmaRMSNormFunc.apply(
hidden_states,
self.weight.float() + 1.0,
(self.weight).to(hidden_states.dtype) + 1.0,
self.variance_epsilon if hasattr(self, "variance_epsilon") else self.eps,
)

Expand Down Expand Up @@ -164,7 +163,7 @@ def eager_attention_forward(

if attention_mask is not None:
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
Expand Down Expand Up @@ -272,7 +271,9 @@ def forward(

if attention_mask is not None: # no matter the length, we just slice it
attn_weights = torch.where(
attention_mask.bool(), torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask.bool(),
torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=self.config.torch_dtype),
attn_weights,
)

# upcast attention to fp32
Expand Down Expand Up @@ -534,6 +535,9 @@ def forward(
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if self.config.torch_dtype == torch.float16:
logger.warning("Accuracy might drop with float16 as torch_dtype")

outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
Expand All @@ -551,7 +555,7 @@ def forward(
)
logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
logits = self.lm_head(hidden_states)
logits = self.lm_head(hidden_states).float()

if self.config.final_logit_softcapping is not None:
logits = logits / self.config.final_logit_softcapping
Expand Down Expand Up @@ -583,8 +587,8 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len):
for i in range(config.num_hidden_layers):
if hasattr(config, "sliding_window"):
cache_shape = global_cache_shape if not is_sliding[i] else sliding_cache_shape
new_layer_key_cache = torch.zeros(cache_shape, dtype=torch.float32)
new_layer_value_cache = torch.zeros(cache_shape, dtype=torch.float32)
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.config.torch_dtype)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.config.torch_dtype)
pkv = (new_layer_key_cache, new_layer_value_cache)
past_key_values.append(pkv)
return past_key_values
Expand Down Expand Up @@ -897,8 +901,8 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len):
for i in range(config.num_hidden_layers):
if hasattr(config, "sliding_window"):
cache_shape = global_cache_shape if not is_sliding[i] else sliding_cache_shape
new_layer_key_cache = torch.zeros(cache_shape, dtype=torch.float32)
new_layer_value_cache = torch.zeros(cache_shape, dtype=torch.float32)
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.config.torch_dtype)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.config.torch_dtype)
pkv = (new_layer_key_cache, new_layer_value_cache)
past_key_values.append(pkv)
return past_key_values
Expand Down Expand Up @@ -935,9 +939,9 @@ def get_dummy_inputs(
# Define inputs
vision_inputs = {}
lang_inputs = {}
vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32)
vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=self.config.torch_dtype)
lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32)
lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=self.config.torch_dtype)
lang_inputs["position_ids"] = (
torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
.view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
Expand Down Expand Up @@ -976,7 +980,7 @@ def get_inputs_info(self):
IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")),
IOInfo(
name="pixel_values",
datatype=torch.float32,
datatype=self.config.torch_dtype,
shape=("batch_size", 3, "img_size", "img_size"),
),
]
4 changes: 2 additions & 2 deletions QEfficient/transformers/models/gpt2/modeling_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ def eager_attention_forward(module, query, key, value, attention_mask, head_mask
if attention_mask is not None:
# Apply the attention mask
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)

# Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
attn_weights = attn_weights.type(value.dtype)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def eager_attention_forward(

if attention_mask is not None:
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_output = torch.matmul(attn_weights, value_states)
Expand Down Expand Up @@ -439,7 +439,7 @@ def forward(
# Cast to INT32 to avoid issue while running in ONNXRT
logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
hidden_states = transformer_outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
lm_logits = self.lm_head(hidden_states)
lm_logits = self.lm_head(hidden_states).float()

return CausalLMOutputWithCrossAttentions(
loss=None,
Expand Down
6 changes: 3 additions & 3 deletions QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
)

sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
Expand Down Expand Up @@ -644,7 +644,7 @@ def eager_attention_forward_blocked(
scores = torch.matmul(q_block, key_states.transpose(2, 3)) * scaling
attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, :]
curr_attn_weights = torch.where(
attn_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), scores
attn_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), scores
)
sinks = module.sinks.reshape(1, -1, 1, 1).expand(
curr_attn_weights.shape[0], -1, curr_attn_weights.shape[-2], -1
Expand Down Expand Up @@ -707,7 +707,7 @@ def opt_eager_attention_forward_blocked(

scores = torch.matmul(q_block, k_block.transpose(2, 3)) * scaling
curr_attn_weights = torch.where(
attn_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), scores
attn_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), scores
)
sinks = module.sinks.reshape(1, -1, 1, 1).expand(
curr_attn_weights.shape[0], -1, curr_attn_weights.shape[-2], -1
Expand Down
9 changes: 5 additions & 4 deletions QEfficient/transformers/models/gptj/modeling_gptj.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@


def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
# Sin Cos are also fixated on fp32 here
sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
return (tensor * cos) + (rotate_every_two(tensor) * sin)
Expand All @@ -55,16 +56,16 @@ def _attn(
head_mask=None,
):
# Keep the attention weights computation in fp32 to avoid overflow issues
query = query.to(torch.float32)
key = key.to(torch.float32)
query = query.to(value.dtype)
key = key.to(value.dtype)

attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = attn_weights / self.scale_attn

if attention_mask is not None:
# Apply the attention mask
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=value.dtype), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1)
Expand Down Expand Up @@ -109,7 +110,7 @@ def forward(
embed_positions = get_embed_positions(self.embed_positions, position_ids)
else:
embed_positions = self._get_embed_positions(position_ids)

embed_positions = embed_positions.to(value.dtype)
repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
repeated_position_ids = torch.where(repeated_position_ids == -1, 0, repeated_position_ids)
sincos = torch.gather(embed_positions, 1, repeated_position_ids)
Expand Down
5 changes: 2 additions & 3 deletions QEfficient/transformers/models/granite/modeling_granite.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
Expand Down Expand Up @@ -431,8 +431,7 @@ def forward(
logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]

logits = self.lm_head(hidden_states)
logits = logits.float()
logits = self.lm_head(hidden_states).float()

return CausalLMOutputWithPast(
loss=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
)

# upcast attention to fp32
Expand Down
Loading
Loading