encoder.py
import pytorch_lightning as pl
import torch
import torch.nn.functional as F

import dataset


class TermPredictor(pl.LightningModule):
    # Current version: maps Z^10 -> F_2^10, with both sides represented
    # as float tensors.
    def __init__(self, input_length=10, output_length=10) -> None:
        super().__init__()
        self.input_length = input_length
        self.output_length = output_length
        # Three stacked single-head self-attention layers over the
        # 3 * input_length augmented features built in forward().
        self.a1 = torch.nn.MultiheadAttention(3 * self.input_length, 1)
        self.a2 = torch.nn.MultiheadAttention(3 * self.input_length, 1)
        self.a3 = torch.nn.MultiheadAttention(3 * self.input_length, 1)
        self.fc = torch.nn.Linear(3 * self.input_length, output_length)

    def safe_log(self, x, eps=1e-7):
        # Clamp to non-negative and shift by eps so log never sees a
        # non-positive value, even when safe_log is applied twice in a row.
        x = F.relu(x)
        x = torch.log(x + eps)
        return x
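    # Worked example (hedged): safe_log(0) = log(1e-7) ≈ -16.12, and a second
    # application gives relu(-16.12) = 0 and then log(1e-7) ≈ -16.12 again,
    # so the composition safe_log(safe_log(x)) stays finite for any input.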

    def forward(self, x):
        # x has shape [batch_size, input_length]. Concatenate x with log(x)
        # and log(log(x)) so later layers see each term at several scales.
        augmented = torch.cat(
            (x, self.safe_log(x), self.safe_log(self.safe_log(x))), dim=1
        )
        # nn.MultiheadAttention expects [seq_len, batch, embed_dim]; treat
        # each sample as a one-token sequence so attention never mixes
        # different samples in the batch.
        augmented = augmented.unsqueeze(0)
        augmented = self.a1(augmented, augmented, augmented)[0]
        augmented = self.a2(augmented, augmented, augmented)[0]
        augmented = self.a3(augmented, augmented, augmented)[0]
        return self.fc(augmented.squeeze(0))
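    # Example values (hedged): for a term x = 1024 the augmented features are
    # roughly (1024, log 1024 ≈ 6.93, log log 1024 ≈ 1.94), so the attention
    # stack sees the same term at three growth scales at once.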

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        # Targets are F_2^10 bit vectors, so each output dimension is an
        # independent binary prediction; BCE-with-logits matches that, while
        # F.cross_entropy would treat the ten bits as a single softmax
        # distribution over classes.
        loss = F.binary_cross_entropy_with_logits(y_hat, y)
        self.log("train_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.binary_cross_entropy_with_logits(y_hat, y)
        self.log("val_loss", loss)

def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.02)


if __name__ == '__main__':
    # Quick smoke test, kept for reference:
    # tp = TermPredictor()
    # fd = dataset.FunctionDataset(size=1)
    # x, y = fd[0]
    # x = torch.stack((x, x))
    # print(tp(x))
    model = TermPredictor()
    dm = dataset.FunctionDataModule()
    trainer = pl.Trainer()
    trainer.fit(model, datamodule=dm)
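    # Hedged post-training sanity check (a sketch: the random integers below
    # are only a stand-in for real samples from dataset.FunctionDataModule,
    # and the 0.5 threshold assumes the BCE-with-logits training above).
    with torch.no_grad():
        demo_x = torch.randint(1, 100, (2, 10)).float()  # hypothetical Z^10 inputs
        bits = (torch.sigmoid(model(demo_x)) > 0.5).int()  # logits -> F_2^10 bits
        print(bits)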