-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathmemory.lua
More file actions
112 lines (93 loc) · 4.24 KB
/
memory.lua
File metadata and controls
112 lines (93 loc) · 4.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
require 'nn'
require 'nngraph'
require 'layers/CumulativeSum'
require 'layers/ScalarAddTable'
-- Module table: every graph-building function in this file is attached to it
-- and it is the value returned to `require`.
local Memory = {}
--- Build the graph nodes that reduce the stored strengths by a pop amount.
-- The result is ReLU(prev_strength - ReLU(S(-(C - prev_strength), pop))),
-- where C is the output of nn.CumulativeSum(is_stack) applied to
-- prev_strength and S is the project's nn.ScalarAddTable layer.
-- @param prev_strength nngraph node holding the strengths before the update
-- @param pop nngraph node holding the pop amount; it is combined with the
--   negated sum by nn.ScalarAddTable (presumably a broadcast add -- confirm
--   in layers/ScalarAddTable)
-- @param is_stack value forwarded to nn.CumulativeSum; presumably selects the
--   accumulation direction -- confirm in layers/CumulativeSum
-- @return nngraph node holding the updated strengths
function Memory.updateStrength(prev_strength, pop, is_stack)
  -- Cumulative sum with each entry's own strength subtracted back out.
  local running_total = nn.CumulativeSum(is_stack)(prev_strength)
  local total_without_self = nn.CSubTable()({running_total, prev_strength})
  local negated_total = nn.MulConstant(-1)(total_without_self)

  -- Part of the pop that is left after the negated sum is added, clamped at 0.
  local pop_remainder = nn.ReLU()(nn.ScalarAddTable()({negated_total, pop}))

  -- Strength that survives once that remainder is removed, clamped at 0.
  return nn.ReLU()(nn.CSubTable()({prev_strength, pop_remainder}))
end
--- Build the graph nodes that produce a read from the memory.
-- Each read coefficient is the element-wise minimum of `strength` and
-- ReLU(1 - (C - strength)), where C is the output of
-- nn.CumulativeSum(is_stack) applied to `strength`. The coefficients and
-- `memory_vectors` are then combined by nn.MixtureTable(2).
-- @param strength nngraph node holding the current strengths
-- @param memory_vectors nngraph node holding the stored vectors
-- @param is_stack value forwarded to nn.CumulativeSum; presumably selects the
--   accumulation direction -- confirm in layers/CumulativeSum
-- @param opt options table; only opt.batch_size is read here
-- @return nngraph node holding the read
function Memory.computeRead(strength, memory_vectors, is_stack, opt)
  -- Cumulative sum with each entry's own strength subtracted back out.
  local running_total = nn.CumulativeSum(is_stack)(strength)
  local total_without_self = nn.CSubTable()({running_total, strength})
  local negated_total = nn.MulConstant(-1)(total_without_self)

  -- ReLU(1 - sum): the cap on how much each entry may contribute.
  local cap = nn.ReLU()(nn.AddConstant(1)(negated_total))

  -- View both as 1 x batch_size x N, join them along dim 1, and take the
  -- minimum over that dim to get min(strength, cap) element-wise.
  local strength_row = nn.View(1, opt.batch_size, -1)(strength)
  local cap_row = nn.View(1, opt.batch_size, -1)(cap)
  local weights = nn.Min(1)(nn.JoinTable(1)({strength_row, cap_row}))

  return nn.MixtureTable(2)({weights, memory_vectors})
end
--- Build one step of a memory that is written and read through a single
-- update path (shared by Memory.Stack and Memory.Queue).
-- Appends `new_memory` to the stored vectors, updates the old strengths with
-- Memory.updateStrength, appends `push` to them, and computes a read with
-- Memory.computeRead.
-- @param prev_memory_vectors nngraph node holding the stored vectors
-- @param prev_strength nngraph node holding their strengths
-- @param new_memory nngraph node holding the vector to append; it is reshaped
--   to 1 x opt.memory_size per sample (nn.Reshape with batch mode on)
-- @param pop nngraph node holding the pop amount
-- @param push nngraph node holding the strength given to the new vector
-- @param is_stack value forwarded to updateStrength and computeRead
-- @param opt options table; opt.memory_size is read here and opt is passed on
--   to Memory.computeRead
-- @return table {new_memory_vectors, new_strength, read} of nngraph nodes
function Memory.oneSidedMemory(prev_memory_vectors, prev_strength,
                               new_memory, pop, push, is_stack, opt)
  -- Append the new vector after the existing ones along dim 2.
  local appended_slot = nn.Reshape(1, opt.memory_size, true)(new_memory)
  local all_vectors = nn.JoinTable(2)({prev_memory_vectors, appended_slot})

  -- Update the old strengths first, then append the push strength.
  local surviving = Memory.updateStrength(prev_strength, pop, is_stack)
  local all_strengths = nn.JoinTable(2)({surviving, push})

  local read_out = Memory.computeRead(all_strengths, all_vectors, is_stack, opt)
  return {all_vectors, all_strengths, read_out}
end
--- Build one step of the stack variant of the memory.
-- Forwards every argument to Memory.oneSidedMemory with is_stack = true.
-- @param prev_memory_vectors nngraph node holding the stored vectors
-- @param prev_strength nngraph node holding their strengths
-- @param new_memory nngraph node holding the vector to append
-- @param pop nngraph node holding the pop amount
-- @param push nngraph node holding the push strength
-- @param opt options table passed through unchanged
-- @return whatever Memory.oneSidedMemory returns
function Memory.Stack(prev_memory_vectors, prev_strength,
                      new_memory, pop, push, opt)
  local is_stack = true
  return Memory.oneSidedMemory(
    prev_memory_vectors, prev_strength, new_memory, pop, push, is_stack, opt)
end
--- Build one step of the queue variant of the memory.
-- Forwards every argument to Memory.oneSidedMemory with is_stack = false.
-- @param prev_memory_vectors nngraph node holding the stored vectors
-- @param prev_strength nngraph node holding their strengths
-- @param new_memory nngraph node holding the vector to append
-- @param pop nngraph node holding the pop amount
-- @param push nngraph node holding the push strength
-- @param opt options table passed through unchanged
-- @return whatever Memory.oneSidedMemory returns
function Memory.Queue(prev_memory_vectors, prev_strength,
                      new_memory, pop, push, opt)
  local is_stack = false
  return Memory.oneSidedMemory(
    prev_memory_vectors, prev_strength, new_memory, pop, push, is_stack, opt)
end
--- Build one step of a double-ended memory.
-- `memory_bot` is placed before the stored vectors and `memory_top` after
-- them. The old strengths go through Memory.updateStrength twice: first with
-- pop_top and is_stack = true, then with pop_bot and is_stack = false.
-- push_bot and push_top are then joined on either side, and one read is
-- computed for each end.
-- @param prev_memory_vectors nngraph node holding the stored vectors
-- @param prev_strength nngraph node holding their strengths
-- @param memory_top nngraph node holding the vector added after the others
-- @param memory_bot nngraph node holding the vector added before the others
-- @param pop_top nngraph node holding the pop amount for the is_stack = true pass
-- @param pop_bot nngraph node holding the pop amount for the is_stack = false pass
-- @param push_top nngraph node holding the strength joined after the others
-- @param push_bot nngraph node holding the strength joined before the others
-- @param opt options table; opt.memory_size is read here and opt is passed on
--   to Memory.computeRead
-- @return table {new_memory_vectors, new_strength, read_top, read_bot}
function Memory.DeQue(prev_memory_vectors, prev_strength,
                      memory_top, memory_bot, pop_top, pop_bot,
                      push_top, push_bot, opt)
  -- Both new vectors are reshaped to 1 x opt.memory_size per sample, then
  -- joined around the existing ones along dim 2: bottom, previous, top.
  local bottom_slot = nn.Reshape(1, opt.memory_size, true)(memory_bot)
  local top_slot = nn.Reshape(1, opt.memory_size, true)(memory_top)
  local all_vectors = nn.JoinTable(2)(
    {bottom_slot, prev_memory_vectors, top_slot})

  -- The second pass works on the strengths left by the first pass.
  local after_top_pop = Memory.updateStrength(prev_strength, pop_top, true)
  local after_both_pops = Memory.updateStrength(after_top_pop, pop_bot, false)

  -- Same bottom / previous / top ordering as the vectors.
  local all_strengths = nn.JoinTable(2)({push_bot, after_both_pops, push_top})

  local top_read = Memory.computeRead(all_strengths, all_vectors, true, opt)
  local bottom_read = Memory.computeRead(all_strengths, all_vectors, false, opt)
  return {all_vectors, all_strengths, top_read, bottom_read}
end
--- Wrap a one-sided memory builder into a standalone nn.gModule.
-- Creates five nn.Identity input nodes, hands them (plus opt) to MemoryType,
-- and wraps the resulting output nodes in a graph module.
-- @param MemoryType builder function called as
--   MemoryType(prev_memory_vectors, prev_strength, new_memory, pop, push, opt)
--   and expected to return the output nodes (e.g. Memory.Stack, Memory.Queue)
-- @param opt options table passed through to MemoryType
-- @return nn.gModule whose inputs are, in order,
--   {prev_memory_vectors, prev_strength, new_memory, pop, push}
function Memory.oneSidedMemoryModule(MemoryType, opt)
  -- Identity nodes act as the graph's input placeholders.
  local prev_memory_vectors, prev_strength = nn.Identity()(), nn.Identity()()
  local new_memory = nn.Identity()()
  local pop, push = nn.Identity()(), nn.Identity()()

  local graph_inputs =
    {prev_memory_vectors, prev_strength, new_memory, pop, push}
  local graph_outputs = MemoryType(
    prev_memory_vectors, prev_strength, new_memory, pop, push, opt)

  return nn.gModule(graph_inputs, graph_outputs)
end
--- Create the graph module for the stack variant.
-- @param opt options table passed through to Memory.oneSidedMemoryModule
-- @return the nn.gModule built by Memory.oneSidedMemoryModule from Memory.Stack
function Memory.StackModule(opt)
  local builder = Memory.Stack
  return Memory.oneSidedMemoryModule(builder, opt)
end
--- Create the graph module for the queue variant.
-- @param opt options table passed through to Memory.oneSidedMemoryModule
-- @return the nn.gModule built by Memory.oneSidedMemoryModule from Memory.Queue
function Memory.QueueModule(opt)
  local builder = Memory.Queue
  return Memory.oneSidedMemoryModule(builder, opt)
end
--- Create the graph module for the double-ended memory.
-- Creates eight nn.Identity input nodes, passes them (plus opt) to
-- Memory.DeQue, and wraps the returned output nodes in a graph module.
-- @param opt options table passed through to Memory.DeQue
-- @return nn.gModule whose inputs are, in order,
--   {prev_memory_vectors, prev_strength, memory_top, memory_bot,
--    pop_top, pop_bot, push_top, push_bot}
function Memory.DeQueModule(opt)
  -- Identity nodes act as the graph's input placeholders.
  local prev_memory_vectors, prev_strength = nn.Identity()(), nn.Identity()()
  local memory_top, memory_bot = nn.Identity()(), nn.Identity()()
  local pop_top, pop_bot = nn.Identity()(), nn.Identity()()
  local push_top, push_bot = nn.Identity()(), nn.Identity()()

  local graph_inputs = {
    prev_memory_vectors, prev_strength, memory_top, memory_bot,
    pop_top, pop_bot, push_top, push_bot,
  }
  local graph_outputs = Memory.DeQue(
    prev_memory_vectors, prev_strength, memory_top, memory_bot,
    pop_top, pop_bot, push_top, push_bot, opt)

  return nn.gModule(graph_inputs, graph_outputs)
end
-- Hand the module table back to the caller of `require`.
return Memory