Skip to content

Refactor QKV Fusion Utilities to be LoRA-Aware#14047

Open
dg845 wants to merge 7 commits into
mainfrom
refactor/lora-aware-qkv-fusion
Open

Refactor QKV Fusion Utilities to be LoRA-Aware#14047
dg845 wants to merge 7 commits into
mainfrom
refactor/lora-aware-qkv-fusion

Conversation

@dg845

@dg845 dg845 commented Jun 22, 2026

Copy link
Copy Markdown
Collaborator

What does this PR do?

This PR refactors the attention QKV fusion utilities in AttentionMixin and AttentionModuleMixin to be more LoRA-aware. In particular, this PR adds guards when attempting to fuse/unfuse with a LoRA attached (because LoRAs cannot be easily transferred over when fusing/unfusing) and an inplace option to fuse without keeping copies of the split Q,K,V projections.

Changelist

  1. Adds guards when attempting to fuse/unfuse Q,K,V with a LoRA attached (will raise an error).
  2. Adds an inplace argument; if inplace=True, the module is modified to have only the fused QKV projection (e.g. to_qkv) with the split Q,K,V projections (e.g. to_q/to_k/to_v) removed. (inplace=False, the default, retains the current behavior).
  3. Supports fusion in the case where add_k_proj and add_v_proj are present without add_q_proj also present (e.g. Wan models).
  4. Adds get_qkv and get_added_qkv helper methods in AttentionModuleMixin which handles getting the Q, K, V (and added Q,K,V, for second stream projections in MM-DiT-style models like Flux) in both the fused and split case. This is intended to make it easier for attention processors to support both fused and split QKV.
  5. Adds an experimental restore_checkpoint_fusion_state method to AttentionMixin to put models back in the fusion state of the original model checkpoint. A new _native_fused_projections attribute on AttentionModuleMixin is added to allow this state to be described. (The motivation is to make it easier to support PEFT adapters which target the original checkpoint structure.)

Partially addresses #14003.

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul
@DN6

dg845 and others added 5 commits June 21, 2026 07:46
…e to original fusion state, module-level helpers to get Q,K,V in both fused and split cases
…duleMixin

Tests are in tests/models/test_attention_mixins.py and cover the four minimal
concrete AttentionModuleMixin fixtures (_MinimalSelfAttn, _MinimalCrossAttn,
_MinimalAddedKVAttn, _MinimalAddedQKVAttn):

TestAttentionModuleMixin (53 tests):
- Idempotency of fuse_projections/unfuse_projections
- Module attribute invariants for non-inplace and inplace paths
- Weight/bias correctness: fused weight equals concatenation of split weights
- Inplace round-trip weight preservation and storage-sharing (no copy on unfuse)
- Cross-attention to_kv path, added-KV to_added_kv path (Wan-style), and
  added-QKV to_added_qkv path (Flux-style)
- get_qkv and get_added_qkv numerical correctness in split and fused cases
- LoRA guard: fuse_projections/unfuse_projections raise ValueError when PEFT-style
  lora_A/lora_B submodules are detected on split or fused projections

TestAttentionMixin (6 tests):
- fuse_qkv_projections/unfuse_qkv_projections propagate to all eligible blocks
- restore_checkpoint_fusion_state respects _native_fused_projections=None/True/False
  per block, including mixed-state models

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@github-actions github-actions Bot added models tests size/L PR with diff > 200 LOC labels Jun 22, 2026
@HuggingFaceDocBuilderDev

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines +326 to +327
delattr(self, "to_k")
delattr(self, "to_v")

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it free the memory? If not, what is the purpose of deleting these attributes? Also, from what I understand this and some of the other refactors introduced in this PR aren't particularly for LoRA-awareness?

Comment on lines +348 to +350
delattr(self, "to_q")
delattr(self, "to_k")
delattr(self, "to_v")

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

@staticmethod
def _has_active_lora(module: nn.Module) -> bool:
"""Checks for the presence of PEFT-style LoRA modules without needing to import `peft`."""
return any("lora_A" in name or "lora_B" in name for name, _ in module.named_modules())

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a better way to detect it is:

for name, mod in module.named_modules():
    if isinstance(mod, BaseTunerLayer): ...

This way, we can cater towards any non-LoRA adapters in the future.

if isinstance(module, AttentionModuleMixin) and module._supports_qkv_fusion:
module.unfuse_projections()

def restore_checkpoint_fusion_state(self, inplace: bool = False):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you provide an example of where / how this is used?

My mental model says:

  • Load the pipeline
  • Call the fuse_qkv_projections() method on pipe.transformer.
  • Then, before loading we call this method.
  • And then load the LoRA weights?

@@ -0,0 +1,624 @@
import pytest

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it not make sense to make these a common mixin that can be added to models?

return self.base(x) + self.lora_B(self.lora_A(x))


class TestAttentionModuleMixin:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests are missing:

  • If outputs with and without fusion stay the same (they should).
  • If expected errors are raised when LoRA is attached.

Additionally, are we checking if the projection attributes like to_k, to_q, etc. are actually getting deleted for in-place fusions and are successfully restored when the user wants them?

assert not hasattr(model.block2, "to_kv")

# -------------------------------------------------------------------------
# restore_checkpoint_fusion_state

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the corresponding tests also check if we have the fused QKV projections rather than just checking the fused_projections attribute?

def test_restore_checkpoint_noop_for_none(self, model):
# Default _native_fused_projections is None — state should be unchanged.
model.restore_checkpoint_fusion_state()
assert model.block1.fused_projections is False

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prefer

Suggested change
assert model.block1.fused_projections is False
assert not model.block1.fused_projections

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

models size/L PR with diff > 200 LOC tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants