Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/heretic/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ def get_score(self) -> tuple[tuple[float, float], float, int]:
kl_divergence_scale = self.settings.kl_divergence_scale
kl_divergence_target = self.settings.kl_divergence_target

refusals_score = refusals / self.base_refusals
refusals_score = (
refusals / self.base_refusals if self.base_refusals > 0 else 0.0
)

if kl_divergence >= kl_divergence_target:
kld_score = kl_divergence / kl_divergence_scale
Expand Down
16 changes: 12 additions & 4 deletions src/heretic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,22 @@
)


def obtain_merge_strategy(settings: Settings) -> str | None:
def obtain_merge_strategy(settings: Settings, model: Model) -> str | None:
"""
Prompts the user for how to proceed with saving the model.
Provides info to the user if the model is quantized on memory use.
Returns "merge", "adapter", or None (if cancelled/invalid).
"""

if settings.quantization == QuantizationMethod.BNB_4BIT:
# Also detect pre-quantized models (FP8, MXFP4, etc.) via their built-in
# quantization_config, which HuggingFace stores in the model's config.json.
pre_quantized = (
getattr(model.model.config, "quantization_config", None) is not None
and settings.quantization == QuantizationMethod.NONE
)
is_quantized = settings.quantization == QuantizationMethod.BNB_4BIT or pre_quantized

if is_quantized:
print()
print(
"Model was loaded with quantization. Merging requires reloading the base model."
Expand Down Expand Up @@ -753,7 +761,7 @@ def count_completed_trials() -> int:
if not save_directory:
continue

strategy = obtain_merge_strategy(settings)
strategy = obtain_merge_strategy(settings, model)
if strategy is None:
continue

Expand Down Expand Up @@ -802,7 +810,7 @@ def count_completed_trials() -> int:
)
private = visibility == "Private"

strategy = obtain_merge_strategy(settings)
strategy = obtain_merge_strategy(settings, model)
if strategy is None:
continue

Expand Down
Loading