Sharded LoRA merge #162
base: master
Changes from all commits: 8d433f7, a831ec0, 521bc94, 43e2998, 25862c3, 75b1d2a, cd7087b, e9fb28f, 9be07c4, b6a094a
@@ -5,6 +5,11 @@
 from contextlib import suppress
 from dataclasses import dataclass
 from typing import Any, Type, cast
+import os
+import glob
+import json
+import shutil
+from safetensors import safe_open, save_file

 import bitsandbytes as bnb
 import torch
@@ -265,6 +270,96 @@ def get_merged_model(self) -> PreTrainedModel:
         self.needs_reload = True
         return merged_model

+    def save_sharded(self, model_path: str, save_directory: str) -> None:
+        assert isinstance(self.model, PeftModel)
+
+        lora_state_dict = self.model.state_dict()
+
+        # 1. resolve the local path
+        if not os.path.exists(model_path):
+            print("Model path not found locally, attempting to locate in HF cache.")
+            hf_hub_location = os.environ.get("HF_HUB_CACHE", os.path.join(os.environ.get("HOME", ""), ".cache/huggingface/hub"))

Owner:
This could be dicey. I think that if Xet is used (which HF is aggressively migrating towards), caches work quite differently, and aren't simply folders containing the model repository. It's an object storage similar to Git.

Author:
@p-e-w In this case heretic can do "from_pretrained" on the model and then do "save_pretrained" into a temp folder, on which the sharded ablation will be performed. Sharded modification is just too good not to have it.

Author:
Of course, a warning similar to the current RAM one should be displayed if the model wasn't saved to a local folder beforehand (and a hint should be added to the documentation that local saving is preferred, and maybe enabled by default).

+            model_name_cached = "models--" + model_path.replace("/", "--")
+            model_path = os.path.join(hf_hub_location, model_name_cached)
+
+            if not os.path.exists(model_path):
+                raise ValueError(f"Could not find model at {model_path}")
+
+        # Parse LoRA keys to find matching base layers
+        lora_pairs = {}
+        for key in lora_state_dict.keys():
+            if "lora_A" in key:
+                base_key = key.replace(".lora_A.weight", "").replace("base_model.model.", "")
+                lora_pairs[base_key] = {
+                    'A': lora_state_dict[key],
+                    'B': lora_state_dict.get(key.replace("lora_A", "lora_B"))
+                }
+
+        # Get all safetensors shards
+        model_files = sorted(glob.glob(os.path.join(model_path, "*.safetensors")))
+        if not model_files:
+            raise ValueError(f"No safetensors found in {model_path}")
+
+        # Process each shard
+        print("Merging shards")
+        for shard_file in model_files:
+            shard_name = os.path.basename(shard_file)
+            merged_tensors = {}
+
+            with safe_open(shard_file, framework="pt", device='cpu') as f:
+                for key in f.keys():
+                    tensor = f.get_tensor(key)
+
+                    # Check if tensor is quantized
+                    if tensor.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
+                        raise ValueError(
+                            f"Tensor {key} has dtype {tensor.dtype} - "
+                            "quantized base weights not supported for merging"
+                        )
+
+                    # Check if this layer has LoRA weights
+                    if key in lora_pairs:
+                        lora_item = lora_pairs.pop(key)
+                        lora_A = lora_item['A']
+                        lora_B = lora_item['B']
+
+                        # Verify shapes
+                        assert lora_A.shape[0] == tensor.shape[0], \
+                            f"Shape mismatch for {key}: tensor {tensor.shape}, lora_A {lora_A.shape}"

Comment on lines +330 to +331:

Contributor:
The shape verification for the LoRA weights is incorrect. The assertion `lora_A.shape[0] == tensor.shape[0]` compares the LoRA rank (the first dimension of `lora_A`, which has shape `(r, in_features)`) against the weight's output dimension. The correct checks should verify that the output dimension of `lora_B` matches `tensor.shape[0]` and that the input dimension of `lora_A` matches `tensor.shape[1]`.

Author:
If the LoRA is computed fine, this should be fine as well.

+                        # Merge: W_merged = W + (B @ A)
+                        lora_weight = lora_B @ lora_A
+
+                        # Ensure same dtype for merging
+                        lora_weight = lora_weight.to(tensor.dtype)
+
+                        # Add to original weight
+                        tensor = tensor + lora_weight
+
+                    merged_tensors[key] = tensor
+
+            # Save merged shard
+            output_shard = os.path.join(save_directory, shard_name)
+            save_file(merged_tensors, output_shard, metadata={"format": "pt"})
+            del merged_tensors
+
+        # Copy non-safetensors files
+        for file in os.listdir(model_path):
+            if not file.endswith('.safetensors'):
+                src = os.path.join(model_path, file)
+                dst = os.path.join(save_directory, file)
+                if os.path.isfile(src):
+                    shutil.copy2(src, dst)
+
+        assert len(lora_pairs) == 0, "Not all LoRA keys have been injected"
+
+        print(f"Merged model saved to {save_directory}")
+
+        # Because no writing to the model is made, we can continue without reloading
+
     def reset_model(self):
         """
         Resets the model to a clean state for the next trial or evaluation.
I don't think we need this extra complexity. Always using sharded merging should be fine.
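
For context, a hedged usage sketch: the owning class isn't shown in the diff, so `merger` below is a placeholder for whatever object carries the PeftModel and exposes `save_sharded`; the model id and output path are illustrative. Since the method never creates `save_directory`, it presumably must exist beforehand:

    import os

    save_directory = "./merged-model"            # illustrative output path
    os.makedirs(save_directory, exist_ok=True)   # save_file() won't create it

    # `merger` stands in for the instance that gained save_sharded() in this PR
    merger.save_sharded(
        model_path="some-org/some-model",        # hub id or local folder (placeholder)
        save_directory=save_directory,
    )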