forked from Element-Research/dpnn
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathBigrams.lua
More file actions
135 lines (111 loc) · 3.42 KB
/
Bigrams.lua
File metadata and controls
135 lines (111 loc) · 3.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
local Bigrams, parent = torch.class("nn.Bigrams", "nn.Module")
--Function taken by torchx Aliasmultinomial.lua
-- Build alias-method tables (J, q) for O(1) multinomial sampling.
-- Adapted from torchx's AliasMultinomial.lua (Walker's alias method).
-- probs : 1D tensor of outcome probabilities.
-- Returns:
--   J : LongTensor of alias indices (0 means "no alias needed").
--   q : tensor of acceptance probabilities in [0, 1].
function Bigrams:setup(probs)
   assert(probs:dim() == 1)
   local K = probs:nElement()
   local q = probs.new(K):zero()
   local J = torch.LongTensor(K):zero()
   -- Sort the outcomes into those with probabilities larger and
   -- smaller than 1/K.
   local smaller, larger = {}, {}
   for kk = 1,K do
      q[kk] = K*probs[kk]
      if q[kk] < 1 then
         table.insert(smaller, kk)
      else
         table.insert(larger, kk)
      end
   end
   -- Pair each small outcome with a large one: the large outcome's
   -- mass fills the remainder of the small outcome's bucket, so every
   -- bucket becomes a binary mixture over the uniform distribution.
   while #smaller > 0 and #larger > 0 do
      local small = table.remove(smaller)
      local large = table.remove(larger)
      J[small] = large
      q[large] = q[large] - (1.0 - q[small])
      if q[large] < 1.0 then
         table.insert(smaller, large)
      else
         table.insert(larger, large)
      end
   end
   assert(q:min() >= 0)
   -- Guard against floating-point drift pushing q slightly above 1.
   if q:max() > 1 then
      q:div(q:max())
   end
   assert(q:max() <= 1)
   if J:min() <= 0 then
      -- Sometimes a large index isn't assigned an alias in J.
      -- Fix it by forcing its acceptance probability to 1 so that
      -- J is never indexed for that bucket.
      local i = 0
      J:apply(function(x)
         i = i + 1
         if x <= 0 then
            q[i] = 1
         end
      end)
   end
   return J, q
end
-- Vectorized alias-method sampling: fills `output` (LongTensor) with
-- draws from the distribution encoded by alias tables (J, q).
-- For each element: pick a uniform bucket kk, accept kk with
-- probability q[kk], otherwise take the alias J[kk].
function Bigrams:batchdraw(output, J, q)
assert(torch.type(output) == 'torch.LongTensor')
assert(output:nElement() > 0)
local K = J:nElement()
-- _kk : uniformly drawn bucket index per output element
local _kk = output.new()
_kk:resizeAs(output):random(1,K)
-- _q[i] = q[_kk[i]] : acceptance probability of each drawn bucket
local _q = q.new()
_q:index(q, 1, _kk:view(-1))
-- _mask[i] = 1 with probability _q[i] (accept bucket), else 0 (use alias)
local _mask = torch.LongTensor()
_mask:resize(_q:size()):bernoulli(_q)
-- __kk keeps the accepted bucket indices, zeroed where the alias is used
local __kk = output.new()
__kk:resize(_kk:size()):copy(_kk)
__kk:cmul(_mask)
-- if mask == 0 then output[i] = J[kk[i]] else output[i] = 0
_mask:add(-1):mul(-1) -- (1,0) - > (0,1)
output:view(-1):index(J, 1, _kk:view(-1))
output:cmul(_mask)
-- elseif mask == 1 then output[i] = kk[i]
output:add(__kk)
return output
end
-- bigrams : table mapping unigram index -> {prob=tensor, index=LongTensor},
--           where prob holds successor probabilities and index maps the
--           sampled position back to a vocabulary index.
-- nsample : number of bigram successors to draw per input unigram.
function Bigrams:__init(bigrams, nsample)
   -- Initialize inherited nn.Module state (output, gradInput, etc.);
   -- the original omitted this, leaving base-class fields undefined.
   parent.__init(self)
   self.nsample = nsample
   self.bigrams = bigrams
   self.q = {}
   self.J = {}
   -- Precompute one alias table per unigram so sampling is O(1) per draw.
   for uniI, map in pairs(bigrams) do
      local J, q = self:setup(map.prob)
      self.J[uniI] = J
      self.q[uniI] = q
   end
end
-- Forward pass: for each unigram index in the input batch, draw
-- self.nsample bigram successors and map them back to vocabulary
-- indices via the per-unigram 'index' table.
-- input  : 1D LongTensor of unigram indices (batchSize).
-- output : LongTensor of size (batchSize, nsample).
function Bigrams:updateOutput(input)
   assert(torch.type(input) == 'torch.LongTensor')
   local batchSize = input:size(1)
   -- output must be a LongTensor regardless of any earlier :type() call
   if torch.type(self.output) ~= 'torch.LongTensor' then
      self.output = torch.LongTensor()
   end
   self.output:resize(batchSize, self.nsample)
   for b = 1, batchSize do
      local uniI = input[b]
      local row = self.output[b]
      self:batchdraw(row, self.J[uniI], self.q[uniI])
      -- translate alias-table positions into vocabulary indices
      local index = self.bigrams[uniI]['index']
      row:apply(function(x) return index[x] end)
   end
   return self.output
end
-- Sampling is not differentiable; the gradient w.r.t. the input
-- indices is defined as zero.
function Bigrams:updateGradInput(input, gradOutput)
   -- BUG FIX: the original used `... == 'torch.LongTensor' or torch.LongTensor()`,
   -- which assigns the boolean `true` once gradInput is already a LongTensor,
   -- so the following resizeAs errored. Use the and/or idiom as in updateOutput.
   self.gradInput = torch.type(self.gradInput) == 'torch.LongTensor' and self.gradInput or torch.LongTensor()
   self.gradInput:resizeAs(input):fill(0)
   return self.gradInput
end
-- Returns the mean number of bigram successors per unigram,
-- i.e. the average size of the per-unigram probability tensors.
function Bigrams:statistics()
   local total, n = 0, 0
   for _, map in pairs(self.bigrams) do
      total = total + map.prob:nElement()
      n = n + 1
   end
   return total / n
end