fix: handle FP8 model weights in LoRA adapters and merge #182

shifusen329 wants to merge 7 commits into p-e-w:master
Conversation
Models distributed in FP8 (e.g. MiniMax-M2.5) cause failures because torch.addmm has no FP8 kernel. Cast LoRA adapter weights to bfloat16 after initialization, and upcast FP8 base weights before merge to avoid unsupported in-place addition. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
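The adapter half of the fix can be sketched as a pass over the injected adapters. The `lora_A`/`lora_B` naming follows PEFT's parameter-name convention, but the helper name and structure here are illustrative assumptions, not the repo's actual code:

```python
import torch
import torch.nn as nn

def cast_lora_adapters(model: nn.Module, dtype: torch.dtype = torch.bfloat16) -> nn.Module:
    # Illustrative sketch: cast only the LoRA A/B matrices so their
    # matmuls run in a dtype torch.addmm supports, while the FP8 base
    # weights are left untouched.
    for name, param in model.named_parameters():
        if "lora_A" in name or "lora_B" in name:
            param.data = param.data.to(dtype)
    return model
```

After this pass, only the adapter parameters change dtype; any non-adapter parameters keep their original precision.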
Summary of Changes

Hello @shifusen329, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces compatibility fixes for working with FP8-quantized models, such as MiniMax-M2.5, when using LoRA adapters. It addresses issues where FP8 models would crash during inference due to unsupported FP8 operations.
Activity
Code Review
This pull request aims to fix issues with FP8 models by casting weights to bfloat16 during LoRA forward passes and model merging. The change in _apply_lora to handle adapter weights seems correct. However, in get_merged_model, there's a critical issue in the logic for upcasting base model weights before merging. The current implementation incorrectly excludes the LoRA-wrapped layers, which prevents the fix from working. I've provided a specific comment and suggestion to address this.
```python
# so upcast them to bfloat16 first, merge, then cast back.
fp8_params = {}
for name, module in self.model.named_modules():
    if hasattr(module, "weight") and not isinstance(module, Linear):
```
The condition not isinstance(module, Linear) prevents the logic from running on LoRA-wrapped layers. These are precisely the layers that need their base weights upcast because merge_and_unload() performs an in-place addition on them, which fails for FP8 dtypes. The weight property on a peft.tuners.lora.layer.Linear module correctly delegates to the base layer's weight, so these modules should be processed. Removing this part of the condition will fix the issue and allow the merge to succeed with FP8 models.
```diff
- if hasattr(module, "weight") and not isinstance(module, Linear):
+ if hasattr(module, "weight"):
```
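The merge half of the fix, with the reviewer's correction applied (no `isinstance` exclusion, so LoRA-wrapped layers are included), amounts to a generic upcast-merge-downcast pattern. The names and structure below are an illustrative sketch, not the repo's code:

```python
import torch

def upcast_merge_downcast(weights, merge_fn, work_dtype=torch.bfloat16):
    # Record each tensor's original dtype, upcast so the merge's
    # in-place addition has a supported kernel, run the merge, then
    # cast everything back.
    original = {name: w.dtype for name, w in weights.items()}
    for name in weights:
        weights[name] = weights[name].to(work_dtype)
    merge_fn(weights)  # stands in for PEFT's merge_and_unload() here
    for name, dtype in original.items():
        weights[name] = weights[name].to(dtype)
    return weights
```

Note the PR ultimately keeps the merged model in bfloat16 instead of casting back, since the FP8 scale factors are invalidated by the merge; the cast-back step is shown only to illustrate the general pattern.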
- `_apply_lora`: resolve target module names from the model tree by matching module identities instead of parsing component labels. Fixes MoE models where registered names differ from heretic's labels (e.g. "w2" vs "down_proj" in MiniMax-M2.5).
- `abliterate`: dequantize FP8 block-wise quantized weights by applying `weight_scale_inv` per block, so abliteration computes correct refusal-direction projections.
- `get_merged_model`: apply the same FP8 dequantization before merge. The merged model is kept in bfloat16 since the original scale factors are invalidated.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
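The per-block dequantization described in that commit can be sketched as follows. The tile layout and `block` size are assumptions for illustration (FP8 checkpoints that ship a `weight_scale_inv` tensor commonly use 128x128 blocks), not taken from the repo:

```python
import torch

def dequantize_blockwise(weight: torch.Tensor,
                         scale_inv: torch.Tensor,
                         block: int = 128) -> torch.Tensor:
    # Assumed layout: `weight` is block-quantized with one inverse
    # scale per (block x block) tile in `scale_inv`. Multiply each
    # tile by its scale to recover full-precision values, and return
    # the result in bfloat16.
    w = weight.to(torch.float32)
    for i in range(scale_inv.shape[0]):
        for j in range(scale_inv.shape[1]):
            w[i * block:(i + 1) * block,
              j * block:(j + 1) * block] *= scale_inv[i, j]
    return w.to(torch.bfloat16)
```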
Models that don't refuse any baseline prompts (e.g. MiniMax-M2.5) cause a ZeroDivisionError when computing refusals_score. Return 0.0 when base_refusals is zero since there are no refusals to remove. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
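The guard is straightforward; the function name below is hypothetical, kept minimal to show just the division guard:

```python
def refusals_score(current_refusals: int, base_refusals: int) -> float:
    # With zero baseline refusals there are no refusals to remove,
    # so return 0.0 rather than dividing by zero.
    if base_refusals == 0:
        return 0.0
    return current_refusals / base_refusals
```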
Aggressive abliteration can destabilize the model, producing NaN logits that propagate into the KL divergence. NaN silently bypasses the kl_divergence >= target comparison (always False), producing a misleadingly finite score. Replace NaN with inf so Optuna correctly identifies these trials as maximally bad. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
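The NaN trap and its fix can be shown in plain Python; `sanitize_kl` is a hypothetical helper name standing in for wherever the score is computed:

```python
import math

def sanitize_kl(kl: float) -> float:
    # NaN compares False against every threshold, so `kl >= target`
    # silently lets a destabilized trial through; map NaN to +inf so
    # the optimizer scores it as maximally bad.
    return math.inf if math.isnan(kl) else kl
```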
All trials were producing inf KL divergence because the 0.8-1.5 range for max_weight was too aggressive for the model. Lowered to 0.1-0.8. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Low-precision dtypes (bfloat16/float16) cause log_softmax to produce -inf for low-probability tokens. When abliteration shifts which tokens underflow, kl_div returns NaN/inf. Upcasting to float32 matches the existing pattern used for residual vectors. Also reverts the max_weight range change from 2e6cfc6 since the search ranges were not the actual problem. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
log_softmax produces -inf for near-zero probability tokens, which causes kl_div to return inf regardless of actual distribution similarity. Clamping to -100 keeps values finite while preserving effectively-zero probabilities (exp(-100) ≈ 3.7e-44). Ref: pytorch/pytorch#32520 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
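The clamping idea in miniature; `clamped_log_prob` is an illustrative helper, not the repo's code (the actual fix clamps the tensor output of `log_softmax`):

```python
import math

def clamped_log_prob(p: float, floor: float = -100.0) -> float:
    # A zero (or underflowed) probability yields log p = -inf, which
    # makes any KL term infinite; clamping to -100 keeps the value
    # finite while exp(-100) ≈ 3.7e-44 is still effectively zero.
    lp = math.log(p) if p > 0.0 else float("-inf")
    return max(lp, floor)
```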
This seems to be quite similar to #151, which is also concerned with FP8 support.
Summary

- `torch.addmm` has no FP8 (Float8_e4m3fn) kernel, causing `NotImplementedError` in PEFT LoRA forward passes
- Cast LoRA adapter weights to bfloat16 in `_apply_lora()` so adapter matmuls use a supported dtype
- Upcast FP8 base weights before `merge_and_unload()` in `get_merged_model()`, then cast back, to avoid unsupported in-place addition during merge

Test plan

- Load an FP8 model (e.g. `MiniMaxAI/MiniMax-M2.5`) and verify inference completes without `NotImplementedError`

🤖 Generated with Claude Code