diff --git a/lxmls/transformers/model.py b/lxmls/transformers/model.py index d29a8eb2..cd3a34ae 100644 --- a/lxmls/transformers/model.py +++ b/lxmls/transformers/model.py @@ -16,7 +16,6 @@ from lxmls.transformers.utils import CfgNode as CN from lxmls.transformers.bpe import BPETokenizer -from lxmls.transformers.pretrained_attention import PretrainedCausalSelfAttention # ----------------------------------------------------------------------------- @@ -97,7 +96,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Apply softmax activation to get attention weights # Check the correct axis for the softmax function! What should be the shape of the weights? weights = F.softmax(scores, dim=-1) - # Apply dropout to the attention weights weights = self.attn_dropout(weights) @@ -120,10 +118,7 @@ class Block(nn.Module): def __init__(self, config): super().__init__() self.ln_1 = nn.LayerNorm(config.n_embd) - if config.pretrained: - self.attn = PretrainedCausalSelfAttention(config) - else: - self.attn = CausalSelfAttention(config) + self.attn = CausalSelfAttention(config) self.ln_2 = nn.LayerNorm(config.n_embd) self.mlp = nn.ModuleDict( dict( @@ -160,7 +155,6 @@ def get_default_config(): C.embd_pdrop = 0.1 C.resid_pdrop = 0.1 C.attn_pdrop = 0.1 - C.pretrained = False return C def __init__(self, config): @@ -252,7 +246,6 @@ def from_pretrained(cls, model_type): config.model_type = model_type config.vocab_size = 50257 # openai's model vocabulary config.block_size = 1024 # openai's model block_size - config.pretrained = True model = GPT(config) sd = model.state_dict() @@ -261,6 +254,41 @@ def from_pretrained(cls, model_type): sd_hf = model_hf.state_dict() # copy while ensuring all of the parameters are aligned and match in names and shapes + def transfer_projection(sd): + keys_to_remove = [] + keys_to_add = [] + + for name, param in sd.items(): + if "c_attn" in name: + num_splits = 3 + if len(param.shape) > 1: + param = param.T + num_rows = param.shape[0] + if num_rows % num_splits == 0: + q, k, v = param.split(num_rows // num_splits, dim=0) + keys_to_remove.append(name) + keys_to_add.append( + (name.replace("c_attn.", "query_proj."), q)) + keys_to_add.append( + (name.replace("c_attn.", "key_proj."), k)) + keys_to_add.append( + (name.replace("c_attn.", "value_proj."), v)) + elif "attn.c_proj" in name: + keys_to_remove.append(name) + keys_to_add.append((name.replace("c_proj.", + "output_proj."), param)) + + # remove the keys from the OrderedDict + for key in keys_to_remove: + del sd[key] + + # add the new keys to the OrderedDict + for key, value in keys_to_add: + sd[key] = value + return sd + + sd_hf = transfer_projection(sd_hf) + keys = [k for k in sd_hf if not k.endswith('attn.masked_bias')] # ignore these keys = [ @@ -272,14 +300,16 @@ def from_pretrained(cls, model_type): ] # ignore these transposed = [ - 'attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', - 'mlp.c_proj.weight' + 'attn.output_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight' ] + # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla nn.Linear. # this means that we have to transpose these weights when we import them # This assert might fail for some transformers library versions. Please comment out if that is the case + assert len(keys) == len(sd_keys) + #sd = transfer_weights(sd_hf, sd) for k in keys: if any(k.endswith(w) for w in transposed): @@ -359,7 +389,6 @@ def forward(self, idx, targets=None): assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}" pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) - # forward the GPT model itself tok_emb = self.transformer.wte( idx) # token embeddings of shape (b, t, n_embd) diff --git a/lxmls/transformers/pretrained_attention.py b/lxmls/transformers/pretrained_attention.py deleted file mode 100644 index 01205558..00000000 --- a/lxmls/transformers/pretrained_attention.py +++ /dev/null @@ -1,56 +0,0 @@ -import torch -import torch.nn as nn -import math -from torch.nn import functional as F - - -class PretrainedCausalSelfAttention(nn.Module): - """ - A vanilla multi-head masked self-attention layer with a projection at the end. - It is possible to use torch.nn.MultiheadAttention here but I am including an - explicit implementation here to show that there is nothing too scary here. - """ - - def __init__(self, config): - super().__init__() - assert config.n_embd % config.n_head == 0 - # key, query, value projections for all heads, but in a batch - self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd) - # output projection - self.c_proj = nn.Linear(config.n_embd, config.n_embd) - # regularization - self.attn_dropout = nn.Dropout(config.attn_pdrop) - self.resid_dropout = nn.Dropout(config.resid_pdrop) - # causal mask to ensure that attention is only applied to the left in the input sequence - self.register_buffer( - "bias", - torch.tril(torch.ones(config.block_size, config.block_size)).view( - 1, 1, config.block_size, config.block_size)) - self.n_head = config.n_head - self.n_embd = config.n_embd - - def forward(self, x): - B, T, C = x.size( - ) # batch size, sequence length, embedding dimensionality (n_embd) - - # calculate query, key, values for all heads in batch and move head forward to be the batch dim - q, k, v = self.c_attn(x).split(self.n_embd, dim=2) - k = k.view(B, T, self.n_head, - C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - q = q.view(B, T, self.n_head, - C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - v = v.view(B, T, self.n_head, - C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - - # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) - att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf')) - att = F.softmax(att, dim=-1) - att = self.attn_dropout(att) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = y.transpose(1, 2).contiguous().view( - B, T, C) # re-assemble all head outputs side by side - - # output projection - y = self.resid_dropout(self.c_proj(y)) - return y \ No newline at end of file