import logging
import math

import torch
from torch import nn
from torch.nn import functional as F

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def find_module(root_module: nn.Module, key: str):
"""
Find a module with a specific name in a Transformer model
From OpenDelta https://github.com/thunlp/OpenDelta
"""
sub_keys = key.split(".")
parent_module = root_module
for sub_key in sub_keys[:-1]:
parent_module = getattr(parent_module, sub_key)
module = getattr(parent_module, sub_keys[-1])
    return parent_module, sub_keys[-1], module


class LoRALinear(nn.Linear):
"""
LoRA implemented in a dense layer
From https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
"""
def __init__(
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight as (fan_in, fan_out)
        merge_weights: bool = False,  # Merging might affect model saving/loading, so it is disabled by default
**kwargs
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
self.r = r
self.lora_alpha = lora_alpha
# Optional dropout
if lora_dropout > 0.:
self.lora_dropout = nn.Dropout(p=lora_dropout)
else:
self.lora_dropout = lambda x: x
# Mark the weight as unmerged
self.merged = False
self.merge_weights = merge_weights
self.fan_in_fan_out = fan_in_fan_out
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
nn.Linear.reset_parameters(self)
if hasattr(self, 'lora_A'):
            # Initialize A the same way as the default for nn.Linear and B to zero,
            # so the initial update B @ A is zero and training starts from the pre-trained weights
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
def T(w):
return w.transpose(0, 1) if self.fan_in_fan_out else w
nn.Linear.train(self, mode)
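        # When entering train mode, subtract the LoRA update from the merged
        # weight so lora_A / lora_B receive gradients on their own; when
        # entering eval mode (and merge_weights is set), fold the update into
        # the weight so inference uses a single fused matmul.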
if mode:
if self.merge_weights and self.merged:
# Make sure that the weights are not merged
if self.r > 0:
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
self.merged = False
else:
if self.merge_weights and not self.merged:
                # Merge the weights and mark the layer as merged
if self.r > 0:
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w

        result = F.linear(x, T(self.weight), bias=self.bias)
        if self.r > 0 and not self.merged:
            # Add the low-rank update x @ A^T @ B^T, scaled by alpha / r
            result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
        return result


class LoRA:
def __init__(self, model, r, alpha, float16):
"""
Input:
r, alpha: LoRA hyperparameters
float16: Whether the model parameters are float16 or not
"""
self.model = model
self.hidden_dim = model.config.hidden_size
self.float16 = float16
if model.config.model_type == "opt":
attention_name = "attn"
elif model.config.model_type == "roberta":
attention_name = "attention"
else:
attention_name = "attn"
# raise NotImplementedError
# Insert LoRA
for key, _ in model.named_modules():
if key[-len(attention_name):] == attention_name:
logger.info(f"Inject lora to: {key}")
_, _, attn = find_module(model, key)
if model.config.model_type == "opt" or model.config.model_type == "llama":
original_q_weight = attn.q_proj.weight.data
# original_q_bias = attn.q_proj.bias.data
original_v_weight= attn.v_proj.weight.data
# original_v_bias = attn.v_proj.bias.data
attn.q_proj = LoRALinear(model.config.hidden_size, model.config.hidden_size, r=r, lora_alpha=alpha).to(original_q_weight.device)
attn.v_proj = LoRALinear(model.config.hidden_size, model.config.hidden_size, r=r, lora_alpha=alpha).to(original_v_weight.device)
if float16:
attn.q_proj.half()
attn.v_proj.half()
attn.q_proj.weight.data = original_q_weight
# attn.q_proj.bias.data = original_q_bias
attn.v_proj.weight.data = original_v_weight
# attn.v_proj.bias.data = original_v_bias
else:
raise NotImplementedError
# Freeze non-LoRA parameters
for n, p in model.named_parameters():
if "lora" not in n:
p.requires_grad = False
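

if __name__ == "__main__":
    # A minimal usage sketch, not part of the original module: it assumes the
    # Hugging Face `transformers` package is installed and downloads the public
    # facebook/opt-125m checkpoint to demonstrate the injection and the
    # resulting trainable-parameter count.
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    LoRA(model, r=8, alpha=16, float16=False)

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    logger.info(f"Trainable parameters: {trainable:,} / {total:,} "
                f"({100 * trainable / total:.2f}%)")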