Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 16 additions & 233 deletions Transformer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,205 +2,26 @@
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 127,
"metadata": {},
"outputs": [],
"source": [
"from torch import nn\n",
"import torch\n",
"import torch.nn.functional as F\n",
"from math import sqrt"
"from math import sqrt\n",
"\n",
"from encoder import EncoderLayer\n",
"from decoder import DecoderLayer\n",
"from multi_head_attention import MultiHeadAttention\n",
"from positional_encoding import PositionalEncoding\n",
"from scaled_dot_attention import attention\n",
"from embedding import WordEmbedding"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
class MultiHeadAttention(nn.Module):
    """Multi-head attention: project q/k/v, split into heads, run scaled
    dot-product attention per head, then merge heads and project out.
    """

    def __init__(self, num_heads, embed_dim, input_dim, dropout=0.1):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dim_heads = embed_dim // num_heads  # per-head width, aka d_k

        self.q_lin = nn.Linear(input_dim, embed_dim)
        self.k_lin = nn.Linear(input_dim, embed_dim)
        self.v_lin = nn.Linear(input_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        bsz = q.size(0)

        def split_heads(t):
            # (bsz, seq, embed_dim) -> (bsz, num_heads, seq, dim_heads)
            return t.reshape(bsz, -1, self.num_heads, self.dim_heads).transpose(1, 2)

        heads_q = split_heads(self.q_lin(q))
        heads_k = split_heads(self.k_lin(k))
        heads_v = split_heads(self.v_lin(v))

        ctx = attention(heads_q, heads_k, heads_v, self.dim_heads,
                        mask=mask, dropout=self.dropout)

        # merge heads back: (bsz, num_heads, seq, d_k) -> (bsz, seq, embed_dim)
        ctx = ctx.transpose(1, 2).contiguous().reshape(bsz, -1, self.embed_dim)

        return self.out_proj(ctx)
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
def attention(q, k, v, d_k, mask=None, dropout=None):
    """Scaled dot-product attention.

    q, k, v: tensors whose last two axes are (seq, d_k); positions where
    mask == 0 are excluded from the softmax. Returns the attention-weighted
    sum of v.
    """
    logits = (q @ k.transpose(-2, -1)) / sqrt(d_k)
    if mask is not None:
        # a large negative logit drives the softmax weight to ~0
        logits = logits.masked_fill(mask == 0, -1e9)
    weights = F.softmax(logits, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return weights @ v
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
class PositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Dropout -> Linear,
    applied independently at every sequence position.
    """

    def __init__(self, embed_dim, input_dim, dropout_rate=0.1):
        """
        embed_dim: num of expected features in input (same as d_model)
        input_dim: length of sequence (kept for interface compatibility;
            it no longer determines the hidden width)
        dropout_rate: dropout applied after the ReLU
        """
        super(PositionwiseFeedForward, self).__init__()
        self.embed_dim = embed_dim
        self.input_dim = input_dim
        self.dropout_rate = dropout_rate
        # Hidden width is d_ff = 4 * d_model (as in the original transformer
        # paper and feedforward.py in this repo). The previous code used the
        # sequence length as the hidden width, which tied the weights to one
        # sequence length for no benefit.
        self.w_1 = nn.Linear(embed_dim, 4 * embed_dim)
        self.w_2 = nn.Linear(4 * embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        # x: (batch_size, seq_len, embed_dim); output has the same shape
        return self.w_2(self.dropout(F.relu(self.w_1(x))))
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
class WordEmbedding(nn.Module):
    """Learned word-embedding lookup implemented as a matmul with a
    (vocab_size, embed_dim) weight matrix.
    """

    def __init__(self, vocab_size, embed_dim=512):
        """
        vocab_size: number of words in the vocabulary
        embed_dim: embedding dimension (usually 1024 or 512)
        """
        super().__init__()
        self.embed_dim = embed_dim
        # Build the float weight first, then register it as a Parameter once.
        # The previous code assigned `param.to(torch.float)` (a plain Tensor)
        # over the already-registered Parameter, which nn.Module.__setattr__
        # rejects with a TypeError; torch.empty is float32 already, so no
        # cast is needed at all.
        weight = torch.empty(vocab_size, embed_dim)
        nn.init.xavier_normal_(weight)
        self.embed_matrix = nn.Parameter(weight)

    def forward(self, x):
        # x: (batch_size, seq_len, vocab_size) one-hot / distribution rows;
        # (seq_len x vocab_size) @ (vocab_size x embed_dim) -> embeddings
        return torch.matmul(x, self.embed_matrix)
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input embeddings."""

    def __init__(self, embed_dim, input_dim):
        """
        embed_dim: num of expected features in input (same as d_model)
        input_dim: max length of sequence
        """
        super().__init__()

        encod = torch.zeros(input_dim, embed_dim)

        position = torch.arange(0, input_dim, dtype=torch.float).unsqueeze(1)  # numerator
        i = torch.arange(0, embed_dim, 2, dtype=torch.float)

        # 10000^(i / embed_dim) computed directly; the previous
        # exp(log(10000.0) * ...) form crashed because `log` was never
        # imported (the file only imports sqrt from math).
        denom = torch.pow(10000.0, i / embed_dim)

        encod[:, 0::2] = torch.sin(position / denom)
        encod[:, 1::2] = torch.cos(position / denom)

        # unsqueeze is NOT in-place: the result must be kept. Storing the
        # (1, input_dim, embed_dim) tensor makes the slice in forward() cut
        # along positions and broadcast over the batch; the old 2-D pe made
        # `pe[:, :x.size(1)]` slice the embedding axis instead.
        self.pe = encod.unsqueeze(0)

    def forward(self, x):
        # x: (batch, seq_len, embed_dim); add the first seq_len encodings
        return x + self.pe[:, : x.size(1)]
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
class DecoderLayer(nn.Module):
    """Single transformer decoder layer: masked self-attention, then
    encoder-decoder cross-attention, then a position-wise feed-forward,
    each wrapped in residual + dropout + LayerNorm.
    """

    def __init__(self, embed_dim, input_dim, num_heads):
        """
        embed_dim: num of expected features in input (same as d_model)
        input_dim: length of sequence
        num_heads: num of heads
        """
        super().__init__()

        self.attention1 = MultiHeadAttention(num_heads=num_heads, embed_dim=embed_dim, input_dim=input_dim, dropout=0.1)
        self.attention2 = MultiHeadAttention(num_heads=num_heads, embed_dim=embed_dim, input_dim=input_dim, dropout=0.1)
        self.feedforward = PositionwiseFeedForward(embed_dim=embed_dim, input_dim=input_dim)

        # LayerNorm normalizes over the LAST axis. Sublayer outputs are
        # (batch, seq, embed_dim) -- MultiHeadAttention reshapes its output
        # to embed_dim in the last axis -- so the normalized shape must be
        # embed_dim, not the sequence length (input_dim) as before.
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        self.dropout3 = nn.Dropout(0.1)

    def forward(self, x, encod_out, mask=None):
        # masked self-attention sublayer
        attn_1_out = self.attention1(q=x, k=x, v=x, mask=mask)
        x = self.norm1(x + self.dropout1(attn_1_out))

        # cross-attention over the encoder output (unmasked)
        attn_2_out = self.attention2(q=x, k=encod_out, v=encod_out, mask=None)
        x = self.norm2(x + self.dropout2(attn_2_out))

        # position-wise feed-forward sublayer
        ff_out = self.feedforward(x)
        x = self.norm3(x + self.dropout3(ff_out))

        return x
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 128,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -218,52 +39,14 @@
"\n",
" def forward(self, x, encod_out, mask=None):\n",
" for layer in self.decoder_layers:\n",
" x = layer(x, encod_out, mask)\n",
" x = layer(x, mask)\n",
" \n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
class EncoderLayer(nn.Module):
    """Single transformer encoder layer: self-attention + feed-forward,
    each wrapped in residual + dropout + LayerNorm.
    """

    def __init__(self, embed_dim, input_dim, num_heads, dropout=0.1):
        """
        Single Encoder layer
        embed_dim: num of expected features in input (same as d_model)
        input_dim: length of sequence
        num_heads: num of heads
        dropout: dropout probability used in all sublayers
        """
        super().__init__()

        # Pass the caller's dropout through instead of hard-coding 0.1,
        # so the `dropout` parameter actually takes effect everywhere.
        self.attention = MultiHeadAttention(num_heads=num_heads, embed_dim=embed_dim, input_dim=input_dim, dropout=dropout)
        self.feedforward = PositionwiseFeedForward(embed_dim, input_dim)

        # Normalize over the feature (last) axis: sublayer outputs are
        # (batch, seq, embed_dim), so LayerNorm needs embed_dim rather
        # than the sequence length (input_dim).
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # self-attention sublayer
        attn_out = self.attention(q=x, k=x, v=x, mask=mask)
        x = self.norm1(x + self.dropout1(attn_out))

        # feed-forward sublayer
        ff_out = self.feedforward(x)
        x = self.norm2(x + self.dropout2(ff_out))

        return x
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 129,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -277,7 +60,7 @@
" \"\"\"\n",
" super().__init__()\n",
"\n",
" self.encoder_layers = nn.ModuleList( [ EncoderLayer(embed_dim, input_dim, num_heads, dropout) for x in range(num_layers) ] )\n",
" self.encoder_layers = nn.ModuleList( [ EncoderLayer(embed_dim, input_dim, num_heads) for x in range(num_layers) ] )\n",
"\n",
" def forward(self, x, mask=None):\n",
" for layer in self.encoder_layers:\n",
Expand All @@ -288,13 +71,13 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 130,
"metadata": {},
"outputs": [],
"source": [
"class Transformer(nn.Module):\n",
" def __init__(self, vocab_size, embed_dim, input_dim, num_heads, num_layers_encod = 6, num_layers_decod = 6, dropout = 0.1):\n",
" super.__init__()\n",
" super().__init__()\n",
"\n",
" self.embedding1 = WordEmbedding(vocab_size, embed_dim)\n",
" self.embedding2 = WordEmbedding(vocab_size, embed_dim)\n",
Expand Down Expand Up @@ -326,7 +109,7 @@
" out = self.linear(decod_out)\n",
" out = self.soft(out)\n",
"\n",
" return out\n"
" return out"
]
}
],
Expand Down
17 changes: 15 additions & 2 deletions encoder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from torch import nn
import torch

from multi_head_attention import MultiHeadAttention
from feedforward import PositionwiseFeedForward
Expand Down Expand Up @@ -27,8 +28,20 @@ def forward(self, x, mask=None):
x = self.norm1(x)

# feedforward output
ff_out = self.feedforward(x) # TODO: needs to be implemented
ff_out = self.feedforward(x)
x = x + self.dropout2(ff_out)
x = self.norm2(x)

return x
return x

if __name__ == "__main__":
    # Quick smoke test: push one tiny batch through a single encoder layer.
    embed_dim = 3
    num_heads = 1

    # NOTE(review): x is 2-D (batch=1, features=3); the attention path is
    # normally fed (batch, seq, features) -- confirm the intended shape.
    x = torch.tensor([[0, 10, 0]], dtype=torch.float32)
    input_dim = 3

    encoder = EncoderLayer(embed_dim, input_dim, num_heads)
    # Invoke the module itself, not .forward(), so nn.Module.__call__
    # runs (hooks, etc.).
    output = encoder(x)
    print(output)
5 changes: 2 additions & 3 deletions feedforward.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@ def __init__(self, embed_dim, input_dim, dropout_rate=0.1):
super(PositionwiseFeedForward, self).__init__()
self.embed_dim = embed_dim
self.input_dim = input_dim
self.dropout_rate = dropout_rate
self.w_1 = nn.Linear(embed_dim, input_dim)
self.w_2 = nn.Linear(input_dim, embed_dim)
self.w_1 = nn.Linear(embed_dim, 4*embed_dim)
self.w_2 = nn.Linear(4*embed_dim, embed_dim)
self.dropout = nn.Dropout(dropout_rate)

def forward(self, x):
Expand Down