-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Refactor QKV Fusion Utilities to be LoRA-Aware #14047
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
4e18884
658f071
37c0f70
b761f4a
a7512db
e41dee6
799b8cf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -95,7 +95,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): | |
| for name, module in self.named_children(): | ||
| fn_recursive_attn_processor(name, module, processor) | ||
|
|
||
| def fuse_qkv_projections(self): | ||
| def fuse_qkv_projections(self, inplace: bool = False): | ||
| """ | ||
| Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) | ||
| are fused. For cross-attention modules, key and value projection matrices are fused. | ||
|
|
@@ -106,7 +106,7 @@ def fuse_qkv_projections(self): | |
|
|
||
| for module in self.modules(): | ||
| if isinstance(module, AttentionModuleMixin) and module._supports_qkv_fusion: | ||
| module.fuse_projections() | ||
| module.fuse_projections(inplace=inplace) | ||
|
|
||
| def unfuse_qkv_projections(self): | ||
| """Disables the fused QKV projection if enabled. | ||
|
|
@@ -117,11 +117,28 @@ def unfuse_qkv_projections(self): | |
| if isinstance(module, AttentionModuleMixin) and module._supports_qkv_fusion: | ||
| module.unfuse_projections() | ||
|
|
||
| def restore_checkpoint_fusion_state(self, inplace: bool = False): | ||
| """ | ||
| Restores the QKV fusion state back to that of the original model checkpoint (unlike `fuse_qkv_projections`, | ||
| which will fuse all eligible projections). This can be undone by `unfuse_qkv_projections`. The original | ||
| checkpoint fusion info is held on each `AttentionModuleMixin` module in the _native_fused_projections | ||
| attribute. | ||
|
|
||
| > [!WARNING] > This API is 🧪 experimental. | ||
| """ | ||
| for module in self.modules(): | ||
| if isinstance(module, AttentionModuleMixin) and module._supports_qkv_fusion: | ||
| if module._native_fused_projections is True: | ||
| module.fuse_projections(inplace=inplace) | ||
| elif module._native_fused_projections is False: | ||
| module.unfuse_projections() | ||
|
|
||
|
|
||
| class AttentionModuleMixin: | ||
| _default_processor_cls = None | ||
| _available_processors = [] | ||
| _supports_qkv_fusion = True | ||
| _native_fused_projections = None | ||
| fused_projections = False | ||
|
|
||
| def set_processor(self, processor: AttentionProcessor) -> None: | ||
|
|
@@ -244,11 +261,34 @@ def set_use_memory_efficient_attention_xformers( | |
|
|
||
| self.set_attention_backend("xformers") | ||
|
|
||
| @staticmethod | ||
| def _has_active_lora(module: nn.Module) -> bool: | ||
| """Checks for the presence of PEFT-style LoRA modules without needing to import `peft`.""" | ||
| return any("lora_A" in name or "lora_B" in name for name, _ in module.named_modules()) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think a better way to detect it is: for name, mod in module.named_modules():
if isinstance(mod, BaseTunerLayer): ...This way, we can cater towards any non-LoRA adapters in the future. |
||
|
|
||
| @torch.no_grad() | ||
| def fuse_projections(self): | ||
| def fuse_projections(self, inplace: bool = False): | ||
| """ | ||
| Fuse the query, key, and value projections into a single projection for efficiency. | ||
| """ | ||
| # Do not fuse if LoRA adapters are active on the Q,K,V projections. | ||
| possible_qkv_modules = [ | ||
| ("to_q", getattr(self, "to_q", None)), | ||
| ("to_k", getattr(self, "to_k", None)), | ||
| ("to_v", getattr(self, "to_v", None)), | ||
| ("add_q_proj", getattr(self, "add_q_proj", None)), | ||
| ("add_k_proj", getattr(self, "add_k_proj", None)), | ||
| ("add_v_proj", getattr(self, "add_v_proj", None)), | ||
| ] | ||
| active_lora_modules = [ | ||
| name for name, mod in possible_qkv_modules if mod is not None and self._has_active_lora(mod) | ||
| ] | ||
| if active_lora_modules: | ||
| raise ValueError( | ||
| f"Cannot fuse QKV projections: LoRA adapters are active on {active_lora_modules}. " | ||
| "Please detach the LoRA or call `merge_and_unload()` to merge LoRA weights first." | ||
| ) | ||
|
|
||
| # Skip if the AttentionModuleMixin subclass does not support fusion (for example, the QKV projections in Flux2 | ||
| # single stream blocks are always fused) | ||
| if not self._supports_qkv_fusion: | ||
|
|
@@ -275,6 +315,16 @@ def fuse_projections(self): | |
| if hasattr(self, "use_bias") and self.use_bias: | ||
| concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) | ||
| self.to_kv.bias.copy_(concatenated_bias) | ||
|
|
||
| if inplace: | ||
| # Keep the necessary K,V dims so that the individual projections can be reconstructed. | ||
| self._qkv_split_dims = ( | ||
| self.to_k.weight.shape[0], | ||
| self.to_v.weight.shape[0], | ||
| self.to_k.weight.shape[1], | ||
| ) | ||
| delattr(self, "to_k") | ||
| delattr(self, "to_v") | ||
|
Comment on lines
+326
to
+327
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does it free the memory? If not, what is the purpose of deleting these attributes? Also, from what I understand this and some of the other refactors introduced in this PR aren't particularly for LoRA-awareness? |
||
| else: | ||
| # Fuse self-attention projections | ||
| concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) | ||
|
|
@@ -287,27 +337,68 @@ def fuse_projections(self): | |
| concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) | ||
| self.to_qkv.bias.copy_(concatenated_bias) | ||
|
|
||
| if inplace: | ||
| # Keep the necessary Q,K,V dims so that the individual projections can be reconstructed. | ||
| self._qkv_split_dims = ( | ||
| self.to_q.weight.shape[0], | ||
| self.to_k.weight.shape[0], | ||
| self.to_v.weight.shape[0], | ||
| self.to_q.weight.shape[1], | ||
| ) | ||
| delattr(self, "to_q") | ||
| delattr(self, "to_k") | ||
| delattr(self, "to_v") | ||
|
Comment on lines
+348
to
+350
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above. |
||
|
|
||
| # Handle added projections for models like SD3, Flux, etc. | ||
| if ( | ||
| getattr(self, "add_q_proj", None) is not None | ||
| and getattr(self, "add_k_proj", None) is not None | ||
| and getattr(self, "add_v_proj", None) is not None | ||
| ): | ||
| concatenated_weights = torch.cat( | ||
| [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] | ||
| ) | ||
| in_features = concatenated_weights.shape[1] | ||
| out_features = concatenated_weights.shape[0] | ||
| if getattr(self, "add_k_proj", None) is not None and getattr(self, "add_v_proj", None) is not None: | ||
| if getattr(self, "add_q_proj", None) is not None: | ||
| # Added Self Attention (e.g. Flux) | ||
| concatenated_weights = torch.cat( | ||
| [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] | ||
| ) | ||
| in_features = concatenated_weights.shape[1] | ||
| out_features = concatenated_weights.shape[0] | ||
|
|
||
| self.to_added_qkv = nn.Linear( | ||
| in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype | ||
| ) | ||
| self.to_added_qkv.weight.copy_(concatenated_weights) | ||
| if self.added_proj_bias: | ||
| concatenated_bias = torch.cat( | ||
| [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data] | ||
| self.to_added_qkv = nn.Linear( | ||
| in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype | ||
| ) | ||
| self.to_added_qkv.bias.copy_(concatenated_bias) | ||
| self.to_added_qkv.weight.copy_(concatenated_weights) | ||
| if self.added_proj_bias: | ||
| concatenated_bias = torch.cat( | ||
| [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data] | ||
| ) | ||
| self.to_added_qkv.bias.copy_(concatenated_bias) | ||
|
|
||
| if inplace: | ||
| self._added_qkv_split_dims = ( | ||
| self.add_q_proj.weight.shape[0], | ||
| self.add_k_proj.weight.shape[0], | ||
| self.add_v_proj.weight.shape[0], | ||
| self.add_q_proj.weight.shape[1], | ||
| ) | ||
| delattr(self, "add_q_proj") | ||
| delattr(self, "add_k_proj") | ||
| delattr(self, "add_v_proj") | ||
| else: | ||
| # Added Cross Attention (e.g. Wan) | ||
| concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data]) | ||
| in_features = concatenated_weights.shape[1] | ||
| out_features = concatenated_weights.shape[0] | ||
|
|
||
| self.to_added_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) | ||
| self.to_added_kv.weight.copy_(concatenated_weights) | ||
| if hasattr(self, "use_bias") and self.use_bias: | ||
| concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data]) | ||
| self.to_added_kv.bias.copy_(concatenated_bias) | ||
|
|
||
| if inplace: | ||
| self._added_qkv_split_dims = ( | ||
| self.add_k_proj.weight.shape[0], | ||
| self.add_v_proj.weight.shape[0], | ||
| self.add_k_proj.weight.shape[1], | ||
| ) | ||
| delattr(self, "add_k_proj") | ||
| delattr(self, "add_v_proj") | ||
|
|
||
| self.fused_projections = True | ||
|
|
||
|
|
@@ -316,6 +407,22 @@ def unfuse_projections(self): | |
| """ | ||
| Unfuse the query, key, and value projections back to separate projections. | ||
| """ | ||
| # Do not unfuse if LoRA adapters are active on the Q,K,V projections. | ||
| possible_fused_modules = [ | ||
| ("to_qkv", getattr(self, "to_qkv", None)), | ||
| ("to_kv", getattr(self, "to_kv", None)), | ||
| ("to_added_qkv", getattr(self, "to_added_qkv", None)), | ||
| ("to_added_kv", getattr(self, "to_added_kv", None)), | ||
| ] | ||
| active_lora_modules = [ | ||
| name for name, mod in possible_fused_modules if mod is not None and self._has_active_lora(mod) | ||
| ] | ||
| if active_lora_modules: | ||
| raise ValueError( | ||
| f"Cannot unfuse QKV projections: LoRA adapters are active on {active_lora_modules}. " | ||
| "Please detach the LoRA or call `merge_and_unload()` to merge LoRA weights first." | ||
| ) | ||
|
|
||
| # Skip if the AttentionModuleMixin subclass does not support fusion (for example, the QKV projections in Flux2 | ||
| # single stream blocks are always fused) | ||
| if not self._supports_qkv_fusion: | ||
|
|
@@ -327,16 +434,122 @@ def unfuse_projections(self): | |
|
|
||
| # Remove fused projection layers | ||
| if hasattr(self, "to_qkv"): | ||
| if not hasattr(self, "to_q"): | ||
| # QKV fused in-place, need to reconstruct the individual Q,K,V projections | ||
| has_bias = self.to_qkv.bias is not None | ||
| d_q, d_k, d_v, d_in = self._qkv_split_dims | ||
| self.to_q = nn.Linear(d_in, d_q, bias=has_bias) | ||
| self.to_k = nn.Linear(d_in, d_k, bias=has_bias) | ||
| self.to_v = nn.Linear(d_in, d_v, bias=has_bias) | ||
| # Avoid copying by using a view which shares storage with the fused projection | ||
| self.to_q.weight = nn.Parameter(self.to_qkv.weight[:d_q]) | ||
| self.to_k.weight = nn.Parameter(self.to_qkv.weight[d_q : d_q + d_k]) | ||
| self.to_v.weight = nn.Parameter(self.to_qkv.weight[d_q + d_k :]) | ||
| if has_bias: | ||
| self.to_q.bias = nn.Parameter(self.to_qkv.bias[:d_q]) | ||
| self.to_k.bias = nn.Parameter(self.to_qkv.bias[d_q : d_q + d_k]) | ||
| self.to_v.bias = nn.Parameter(self.to_qkv.bias[d_q + d_k :]) | ||
| delattr(self, "to_qkv") | ||
|
|
||
| if hasattr(self, "to_kv"): | ||
| if not hasattr(self, "to_k"): | ||
| has_bias = self.to_kv.bias is not None | ||
| d_k, d_v, d_in = self._qkv_split_dims | ||
| self.to_k = nn.Linear(d_in, d_k, bias=has_bias) | ||
| self.to_v = nn.Linear(d_in, d_v, bias=has_bias) | ||
| self.to_k.weight = nn.Parameter(self.to_kv.weight[:d_k]) | ||
| self.to_v.weight = nn.Parameter(self.to_kv.weight[d_k:]) | ||
| if has_bias: | ||
| self.to_k.bias = nn.Parameter(self.to_kv.bias[:d_k]) | ||
| self.to_v.bias = nn.Parameter(self.to_kv.bias[d_k:]) | ||
| delattr(self, "to_kv") | ||
|
|
||
| if hasattr(self, "to_added_qkv"): | ||
| if not hasattr(self, "add_q_proj"): | ||
| has_bias = self.to_added_qkv.bias is not None | ||
| d_q, d_k, d_v, d_in = self._added_qkv_split_dims | ||
| self.add_q_proj = nn.Linear(d_in, d_q, bias=has_bias) | ||
| self.add_k_proj = nn.Linear(d_in, d_k, bias=has_bias) | ||
| self.add_v_proj = nn.Linear(d_in, d_v, bias=has_bias) | ||
| # Avoid copying by using a view which shares storage with the fused projection | ||
| self.add_q_proj.weight = nn.Parameter(self.to_added_qkv.weight[:d_q]) | ||
| self.add_k_proj.weight = nn.Parameter(self.to_added_qkv.weight[d_q : d_q + d_k]) | ||
| self.add_v_proj.weight = nn.Parameter(self.to_added_qkv.weight[d_q + d_k :]) | ||
| if has_bias: | ||
| self.add_q_proj.bias = nn.Parameter(self.to_added_qkv.bias[:d_q]) | ||
| self.add_k_proj.bias = nn.Parameter(self.to_added_qkv.bias[d_q : d_q + d_k]) | ||
| self.add_v_proj.bias = nn.Parameter(self.to_added_qkv.bias[d_q + d_k :]) | ||
| delattr(self, "to_added_qkv") | ||
|
|
||
| if hasattr(self, "to_added_kv"): | ||
| if not hasattr(self, "add_k_proj"): | ||
| has_bias = self.to_added_kv.bias is not None | ||
| d_k, d_v, d_in = self._added_qkv_split_dims | ||
| self.add_k_proj = nn.Linear(d_in, d_k, bias=has_bias) | ||
| self.add_v_proj = nn.Linear(d_in, d_v, bias=has_bias) | ||
| self.add_k_proj.weight = nn.Parameter(self.to_added_kv.weight[:d_k]) | ||
| self.add_v_proj.weight = nn.Parameter(self.to_added_kv.weight[d_k:]) | ||
| if has_bias: | ||
| self.add_k_proj.bias = nn.Parameter(self.to_added_kv.bias[:d_k]) | ||
| self.add_v_proj.bias = nn.Parameter(self.to_added_kv.bias[d_k:]) | ||
| delattr(self, "to_added_kv") | ||
|
|
||
| if hasattr(self, "_qkv_split_dims"): | ||
| delattr(self, "_qkv_split_dims") | ||
| if hasattr(self, "_added_qkv_split_dims"): | ||
| delattr(self, "_added_qkv_split_dims") | ||
| self.fused_projections = False | ||
|
|
||
| def get_qkv( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
| encoder_hidden_states: torch.Tensor | None = None, | ||
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | ||
| """ | ||
| Get the query, key, and value from the Q,K,V projections, handling both the split and fused cases. | ||
| """ | ||
| if self.fused_projections: | ||
| if hasattr(self, "to_kv"): | ||
| query = self.to_q(hidden_states) | ||
| key, value = self.to_kv(encoder_hidden_states).chunk(2, dim=-1) | ||
| elif hasattr(self, "to_qkv"): | ||
| query, key, value = self.to_qkv(hidden_states).chunk(3, dim=-1) | ||
| else: | ||
| raise RuntimeError("Cannot find fused self-attn proj `to_qkv` or cross-attn proj `to_kv`.") | ||
| else: | ||
| query = self.to_q(hidden_states) | ||
| kv_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states | ||
| key = self.to_k(kv_states) | ||
| value = self.to_v(kv_states) | ||
| return query, key, value | ||
|
|
||
| def get_added_qkv( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
| encoder_hidden_states: torch.Tensor | None = None, | ||
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | ||
| """ | ||
| Get the added query, key, and value from added Q,K,V projections (for example, second stream projections in a | ||
| MM-DiT-style model like Flux). Note that for models with only `add_k_proj`/`add_v_proj` such as Wan, Q comes | ||
| from the normal `to_q` projection. | ||
| """ | ||
| if self.fused_projections: | ||
| if hasattr(self, "to_added_kv"): | ||
| query = self.to_q(hidden_states) | ||
| key, value = self.to_added_kv(encoder_hidden_states).chunk(2, dim=-1) | ||
| elif hasattr(self, "to_added_qkv"): | ||
| query, key, value = self.to_added_qkv(hidden_states).chunk(3, dim=-1) | ||
| else: | ||
| raise RuntimeError( | ||
| "Cannot find added fused self-attn proj `to_added_qkv` or cross-attn proj `to_added_kv`." | ||
| ) | ||
| else: | ||
| query = self.add_q_proj(hidden_states) if hasattr(self, "add_q_proj") else self.to_q(hidden_states) | ||
| kv_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states | ||
| key = self.add_k_proj(kv_states) | ||
| value = self.add_v_proj(kv_states) | ||
| return query, key, value | ||
|
|
||
| def set_attention_slice(self, slice_size: int) -> None: | ||
| """ | ||
| Set the slice size for attention computation. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you provide an example of where / how this is used?
My mental model says:
fuse_qkv_projections()method onpipe.transformer.