19 changes: 17 additions & 2 deletions src/heretic/main.py
@@ -70,6 +70,9 @@ def obtain_merge_strategy(settings: Settings) -> str | None:
"[yellow]WARNING: CPU merging requires dequantizing the entire model to system RAM.[/]"
)
print("[yellow]This can lead to system freezes if you run out of memory.[/]")
print(
"[yellow] However, you can choose to save the model shard-by-shard. It is slightly slower, but requires <10 GB of RAM for most models.[/]"
)

Owner:
I don't think we need this extra complexity. Always using sharded merging should be fine.

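If that suggestion is adopted, a minimal sketch of what the save path could look like (hypothetical; it reuses the strategy dispatch and the save_sharded method that appear later in this diff, and simply drops the separate "merge_sharded" choice):

if strategy == "adapter":
    print("Saving LoRA adapter...")
    model.model.save_pretrained(save_directory)
else:
    # Hypothetical simplification: the regular merge strategy always
    # writes the merged model shard by shard instead of materializing
    # the full merged model in RAM first.
    print("Saving merged model shard by shard...")
    model.save_sharded(settings.model, save_directory)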

try:
# Estimate memory requirements by loading the model structure on the "meta" device.
@@ -88,13 +91,13 @@ def obtain_merge_strategy(settings: Settings) -> str | None:
footprint_bytes = meta_model.get_memory_footprint()
footprint_gb = footprint_bytes / (1024**3)
print(
f"[yellow]Estimated RAM required (excluding overhead): [bold]~{footprint_gb:.2f} GB[/][/]"
f"[yellow]Estimated RAM in non-sharded mode required (excluding overhead): [bold]~{footprint_gb:.2f} GB[/][/]"
)
except Exception:
# Fallback if meta loading fails (e.g. owing to custom model code
# or bitsandbytes quantization config issues on the meta device).
print(
"[yellow]Rule of thumb: You need approximately 3x the parameter count in GB RAM.[/]"
"[yellow]Rule of thumb: In non-sharded mode, you need approximately 3x the parameter count in GB RAM.[/]"
)
print(
"[yellow]Example: A 27B model requires ~80GB RAM. A 70B model requires ~200GB RAM.[/]"
@@ -113,6 +116,15 @@ def obtain_merge_strategy(settings: Settings) -> str | None:
),
value="merge",
),
Choice(
title="Merge LoRA into full model (sharded)"
+ (
""
if settings.quantization == QuantizationMethod.NONE
else " (very low RAM usage)"
),
value="merge_sharded",
),
Choice(
title="Cancel",
value="cancel",
@@ -754,6 +766,9 @@ def count_completed_trials() -> int:
if strategy == "adapter":
print("Saving LoRA adapter...")
model.model.save_pretrained(save_directory)
elif strategy == "merge_sharded":
print("Saving merged model file by file...")
model.save_sharded(settings.model, save_directory)
else:
print("Saving merged model...")
merged_model = model.get_merged_model()
95 changes: 95 additions & 0 deletions src/heretic/model.py
@@ -5,6 +5,11 @@
from contextlib import suppress
from dataclasses import dataclass
from typing import Any, Type, cast
import os
import glob
import json
import shutil
from safetensors import safe_open
from safetensors.torch import save_file

import bitsandbytes as bnb
import torch
@@ -265,6 +270,96 @@ def get_merged_model(self) -> PreTrainedModel:
self.needs_reload = True
return merged_model


def save_sharded(self, model_path: str, save_directory: str) -> None:

assert isinstance(self.model, PeftModel)

lora_state_dict = self.model.state_dict()

# 1. resolve the local path
if not os.path.exists(model_path):
print("Model path not found locally, attempting to locate in HF cache.")
hf_hub_location = os.environ.get("HF_HUB_CACHE", os.path.join(os.environ.get("HOME", ""), ".cache/huggingface/hub"))

Owner:
This could be dicey. I think that if Xet is used (which HF is aggressively migrating towards), caches work quite differently, and aren't simply folders containing the model repository. It's an object storage similar to Git.

Author:
@p-e-w In this case, heretic can do "from_pretrained" on the model and then "save_pretrained" into a temp folder, on which the sharded ablation will be performed. Sharded modification is just too good not to have.

Author:
Of course, a warning similar to the current RAM one should be displayed if the model wasn't saved to a local folder beforehand (and a hint should be added to the documentation that local saving is preferred, and maybe enabled by default).

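A minimal sketch of the fallback described in this thread, assuming the model identifier is a Hub repo id; the helper name resolve_local_model_path is hypothetical, and the temp-folder route corresponds to the from_pretrained / save_pretrained approach suggested above:

import os
import tempfile

from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM


def resolve_local_model_path(model_path: str) -> str:
    # Hypothetical helper: return a directory that contains the model's
    # *.safetensors shards, without loading the weights into RAM if possible.
    if os.path.isdir(model_path):
        return model_path
    try:
        # If the repo is already in the local cache, this returns the
        # snapshot directory containing the actual repo files.
        return snapshot_download(model_path, local_files_only=True)
    except Exception:
        # Otherwise materialize the base model explicitly, as suggested above.
        tmp_dir = tempfile.mkdtemp(prefix="heretic-base-model-")
        model = AutoModelForCausalLM.from_pretrained(model_path)
        model.save_pretrained(tmp_dir, safe_serialization=True)
        return tmp_dir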

model_name_cached = "models--" + model_path.replace("/", "--")
model_path = os.path.join(hf_hub_location, model_name_cached)

if not os.path.exists(model_path):
raise ValueError(f"Could not find model at {model_path}")

# Parse LoRA keys to find matching base layers
lora_pairs = {}
for key in lora_state_dict.keys():
if "lora_A" in key:
base_key = key.replace(".lora_A.weight", "").replace("base_model.model.", "")
lora_pairs[base_key] = {
'A': lora_state_dict[key],
'B': lora_state_dict.get(key.replace("lora_A", "lora_B"))
}

# Get all safetensors shards
model_files = sorted(glob.glob(os.path.join(model_path, "*.safetensors")))
if not model_files:
raise ValueError(f"No safetensors found in {model_path}")

# Process each shard
print("Merging shards")
for shard_file in model_files:
shard_name = os.path.basename(shard_file)
merged_tensors = {}

with safe_open(shard_file, framework="pt", device='cpu') as f:
for key in f.keys():
tensor = f.get_tensor(key)

# Check if tensor is quantized
if tensor.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
raise ValueError(
f"Tensor {key} has dtype {tensor.dtype} - "
"quantized base weights not supported for merging"
)

# Check if this layer has LoRA weights
if key in lora_pairs:
lora_item = lora_pairs.pop(key)
lora_A = lora_item['A']
lora_B = lora_item['B']

# Verify shapes
assert lora_A.shape[0] == tensor.shape[0], \
f"Shape mismatch for {key}: tensor {tensor.shape}, lora_A {lora_A.shape}"

Contributor (commenting on lines +330 to +331):
critical

The shape verification for the LoRA weights is incorrect. The assertion lora_A.shape[0] == tensor.shape[0] compares the LoRA rank (r) with the base weight's output features, which will likely fail.

The correct checks should verify that the output dimensions of lora_B and input dimensions of lora_A match the base tensor's dimensions.

Suggested change
assert lora_A.shape[0] == tensor.shape[0], \
f"Shape mismatch for {key}: tensor {tensor.shape}, lora_A {lora_A.shape}"
# Verify shapes. W is (out, in), A is (r, in), B is (out, r).
assert lora_B.shape[0] == tensor.shape[0] and lora_A.shape[1] == tensor.shape[1], \
f"Shape mismatch for {key}: W={tensor.shape}, A={lora_A.shape}, B={lora_B.shape}"

Author:
If the LoRA is computed fine, this should be fine as well.

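For reference, a small self-contained check of the shape conventions discussed in this thread, with the standard PEFT scaling factor lora_alpha / r included; the concrete dimensions and lora_alpha value are made up for illustration:

import torch

out_features, in_features, r = 4096, 1024, 16
W = torch.zeros(out_features, in_features)  # base weight, shape (out, in)
lora_A = torch.randn(r, in_features)        # A, shape (r, in)
lora_B = torch.randn(out_features, r)       # B, shape (out, r)

# The original assertion compares the rank against out_features,
# so its condition is false for any r != out_features:
assert lora_A.shape[0] != W.shape[0]

# The corrected check from the suggestion above holds:
assert lora_B.shape[0] == W.shape[0] and lora_A.shape[1] == W.shape[1]

# Standard LoRA merge; PEFT scales the update by lora_alpha / r,
# which equals 1 only when lora_alpha == r.
lora_alpha = 32
W_merged = W + (lora_alpha / r) * (lora_B @ lora_A)
assert W_merged.shape == W.shape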

# Merge: W_merged = W + (B @ A)
lora_weight = lora_B @ lora_A

# Ensure same dtype for merging
lora_weight = lora_weight.to(tensor.dtype)

# Add to original weight
tensor = tensor + lora_weight

merged_tensors[key] = tensor

# Save merged shard
output_shard = os.path.join(save_directory, shard_name)
save_file(merged_tensors, output_shard, metadata={"format": "pt"})
del merged_tensors

# Copy non-safetensors files
for file in os.listdir(model_path):
if not file.endswith('.safetensors'):
src = os.path.join(model_path, file)
dst = os.path.join(save_directory, file)
if os.path.isfile(src):
shutil.copy2(src, dst)

assert len(lora_pairs) == 0, "Not all LoRA keys have been injected"

print(f"Merged model saved to {save_directory}")

# The in-memory model was not modified, so we can continue without reloading.

def reset_model(self):
"""
Resets the model to a clean state for the next trial or evaluation.