Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/heretic/evaluator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors

import math

import torch.nn.functional as F
from torch import Tensor

Expand Down Expand Up @@ -101,6 +103,10 @@ def get_score(self) -> tuple[tuple[float, float], float, int]:
reduction="batchmean",
log_target=True,
).item()
# Aggressive abliteration can destabilize the model, producing NaN logits.
# Treat NaN KL as infinitely bad so Optuna avoids these parameter regions.
if math.isnan(kl_divergence):
kl_divergence = float("inf")
print(f" * KL divergence: [bold]{kl_divergence:.4f}[/]")

print(" * Counting model refusals...")
Expand All @@ -110,7 +116,7 @@ def get_score(self) -> tuple[tuple[float, float], float, int]:
kl_divergence_scale = self.settings.kl_divergence_scale
kl_divergence_target = self.settings.kl_divergence_target

refusals_score = refusals / self.base_refusals
refusals_score = refusals / self.base_refusals if self.base_refusals > 0 else 0.0

if kl_divergence >= kl_divergence_target:
kld_score = kl_divergence / kl_divergence_scale
Expand Down
63 changes: 54 additions & 9 deletions src/heretic/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,17 @@ def _apply_lora(self):
# but this is harmless as we only abliterate the modules we target in `abliterate()`,
# leaving the others at their default (identity) state.
# NOTE: This will need to be updated when hybrid layer support (#43) is merged.
target_modules = [
comp.split(".")[-1] for comp in self.get_abliterable_components()
]
#
# We resolve target module names by matching module identities from the model tree,
# because the component labels in get_layer_modules (e.g. "mlp.down_proj") may differ
# from the actual registered module names (e.g. "w2" in MiniMax-M2.5 MoE experts).
all_modules = {id(m): name.split(".")[-1] for name, m in self.model.named_modules()}
target_modules = list({
all_modules[id(module)]
for modules in self.get_layer_modules(0).values()
for module in modules
if id(module) in all_modules
})

if self.settings.row_normalization != RowNormalization.FULL:
# Rank 1 is sufficient for directional ablation without renormalization.
Expand All @@ -192,6 +200,14 @@ def _apply_lora(self):
# so the result is a PeftModel rather than a PeftMixedModel.
self.model = cast(PeftModel, get_peft_model(self.model, self.peft_config))

# FP8 dtypes (e.g. float8_e4m3fn) are not supported by standard torch.addmm,
# which nn.Linear uses internally. Models distributed in FP8 (e.g. MiniMax-M2.5)
# will cause LoRA forward passes to fail. Cast LoRA adapter weights to bfloat16
# so that the adapter matmuls use a supported dtype.
for name, param in self.model.named_parameters():
if "lora_" in name and param.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
param.data = param.data.to(torch.bfloat16)

print(f"* LoRA adapters initialized (targets: {', '.join(target_modules)})")

def _get_quantization_config(self, dtype: str) -> BitsAndBytesConfig | None:
Expand Down Expand Up @@ -257,7 +273,24 @@ def get_merged_model(self) -> PreTrainedModel:
merged_model = peft_model.merge_and_unload()
return merged_model
else:
# Non-quantized model - can merge directly
# Non-quantized model - can merge directly.
# FP8 base weights don't support in-place addition (+=) needed by merge,
# so dequantize them to bfloat16 first (applying block-wise scale factors),
# then merge. The merged model is kept in bfloat16 because the original
# FP8 scale factors are invalidated by the LoRA delta.
for _, module in self.model.named_modules():
if hasattr(module, "weight") and not isinstance(module, Linear):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The condition `not isinstance(module, Linear)` prevents this logic from running on LoRA-wrapped layers. These are precisely the layers whose base weights need to be upcast, because `merge_and_unload()` performs an in-place addition on them, which fails for FP8 dtypes. The `weight` property on a `peft.tuners.lora.layer.Linear` module delegates to the base layer's weight, so these modules should be processed too. Removing the `isinstance` check from the condition fixes the issue and allows the merge to succeed with FP8 models.

Suggested change
- if hasattr(module, "weight") and not isinstance(module, Linear):
+ if hasattr(module, "weight"):

w = module.weight
if w.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
W = w.data.to(torch.bfloat16)
if hasattr(module, "weight_scale_inv") and module.weight_scale_inv is not None:
bs_r, bs_c = module.block_size
scale = module.weight_scale_inv
scale = scale.repeat_interleave(bs_r, dim=0)[: W.shape[0]]
scale = scale.repeat_interleave(bs_c, dim=1)[:, : W.shape[1]]
W = W * scale.to(dtype=W.dtype, device=W.device)
module.weight.data = W

print("* Merging LoRA adapters into base model...")
merged_model = self.model.merge_and_unload()
# merge_and_unload() modifies self.model in-place, destroying LoRA adapters.
Expand Down Expand Up @@ -447,12 +480,11 @@ def abliterate(
# FIXME: This cast is valid only under the assumption that the original
# module wrapped by the LoRA adapter has a weight attribute.
# See the comment above for why this is currently not guaranteed.
base_weight = cast(Tensor, module.base_layer.weight)
base_module = module.base_layer
base_weight = cast(Tensor, base_module.weight)
quant_state = getattr(base_weight, "quant_state", None)

if quant_state is None:
W = base_weight.to(torch.float32)
else:
if quant_state is not None:
# 4-bit quantization.
# This cast is always valid. Type inference fails here because the
# bnb.functional module is not found by ty for some reason.
Expand All @@ -463,6 +495,17 @@ def abliterate(
quant_state,
).to(torch.float32),
)
elif hasattr(base_module, "weight_scale_inv") and base_module.weight_scale_inv is not None:
# FP8 block-wise quantization (e.g. transformers FP8Linear).
# Apply per-block scale factors to recover the true weight values.
W = base_weight.to(torch.float32)
bs_r, bs_c = base_module.block_size
scale = base_module.weight_scale_inv
scale = scale.repeat_interleave(bs_r, dim=0)[: W.shape[0]]
scale = scale.repeat_interleave(bs_c, dim=1)[:, : W.shape[1]]
W = W * scale.to(W.device)
else:
W = base_weight.to(torch.float32)

# Flatten weight matrix to (out_features, in_features).
W = W.view(W.shape[0], -1)
Expand Down Expand Up @@ -674,7 +717,9 @@ def get_logprobs(self, prompts: list[Prompt]) -> Tensor:
logits = cast(tuple[FloatTensor], outputs.scores)[0]

# The returned tensor has shape (prompt, token).
return F.log_softmax(logits, dim=-1)
# Upcast to float32 and clamp log_softmax output to avoid -inf values
# that cause infinite KL divergence (https://github.com/pytorch/pytorch/issues/32520).
return F.log_softmax(logits.float(), dim=-1).clamp(min=-100)

def get_logprobs_batched(self, prompts: list[Prompt]) -> Tensor:
logprobs = []
Expand Down