gpt.py
import math

import torch
from torch import nn
from torch.utils.data import Dataset

from .model import GPT as GPTBody
from .base import TextGenerator
class TextDataset(Dataset):
    """Sliding-window language-modeling dataset: each item is a (context,
    target) pair where the target is the context shifted right by one token."""

    def __init__(self, tokens, seq_size):
        self.tokens = tokens
        self.seq_size = seq_size

    def __getitem__(self, idx):
        # Context window and its next-token targets, shifted by one position.
        context = self.tokens[idx:idx + self.seq_size]
        target = self.tokens[idx + 1:idx + 1 + self.seq_size]
        return torch.tensor(context), torch.tensor(target)

    def __len__(self):
        # Number of complete (context, target) windows in the token stream.
        return len(self.tokens) - self.seq_size
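# A minimal sketch of how the windows line up (hypothetical token ids):
#   ds = TextDataset(tokens=[5, 9, 2, 7, 1], seq_size=3)
#   ds[0] -> (tensor([5, 9, 2]), tensor([9, 2, 7]))
#   ds[1] -> (tensor([9, 2, 7]), tensor([2, 7, 1]))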
class GPT(GPTBody, TextGenerator):
    def __init__(self, vocab_size, embed_dim, hidden_size, num_heads, num_layers,
                 seq_size, lr=3e-4, test_size=0.8, device='cpu'):
        GPTBody.__init__(self, embed_dim, hidden_size, num_heads, num_layers, vocab_size)
        self.device = device
        self.test_size = test_size
        self.ws = seq_size  # context window size
        self.vocab_size = vocab_size
        # Token embedding table; predict() indexes its weight matrix directly.
        self.W1 = nn.Embedding(vocab_size, embed_dim)
        self.W1_weight = self.W1.weight
        self.optim = torch.optim.Adam(self.parameters(), lr=lr)
        self.loss = nn.CrossEntropyLoss(reduction='sum')
        # exp of a mean cross-entropy value, i.e. perplexity.
        self.metric = lambda x: math.exp(x)
        self.dataset = TextDataset
        self.to(self.device)

    def forward(self, idxs):
        # Embed the token ids, run the transformer body, and flatten the
        # logits to (batch * seq, vocab_size) for CrossEntropyLoss.
        self.embed = self.W1(idxs)
        output = GPTBody.forward(self, self.embed).view(-1, self.vocab_size)
        return output

    def predict(self, idxs, return_hidden=False):
        # Equivalent embedding lookup via direct indexing of the weight matrix.
        self.embed = self.W1_weight[idxs]
        output = GPTBody.predict(self, self.embed, return_hidden)
        return output
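# A minimal training sketch under stated assumptions: GPTBody.forward accepts
# (batch, seq, embed_dim) embeddings and returns per-token logits, and the
# module is run as part of its package (the relative imports above require it).
# All values below are hypothetical.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    model = GPT(vocab_size=100, embed_dim=32, hidden_size=64,
                num_heads=4, num_layers=2, seq_size=8)
    data = TextDataset(tokens=[i % 100 for i in range(512)], seq_size=8)
    loader = DataLoader(data, batch_size=16, shuffle=True)

    for context, target in loader:
        model.optim.zero_grad()
        logits = model(context)                     # (batch * seq, vocab_size)
        loss = model.loss(logits, target.view(-1))  # summed cross-entropy
        loss.backward()
        model.optim.step()
        print(model.metric(loss.item() / target.numel()))  # batch perplexity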