forked from wojzaremba/lstm
-
Notifications
You must be signed in to change notification settings - Fork 38
Expand file tree
/
Copy pathdata.lua
More file actions
79 lines (69 loc) · 2.15 KB
/
data.lua
File metadata and controls
79 lines (69 loc) · 2.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
--
---- Copyright (c) 2014, Facebook, Inc.
---- All rights reserved.
----
---- This source code is licensed under the Apache 2 license found in the
---- LICENSE file in the root directory of this source tree.
----
local stringx = require('pl.stringx')
local file = require('pl.file')
-- Directory holding the Penn Treebank text files, relative to the current
-- working directory.
local ptb_path = "./data/"
-- Word-level train / validation / test splits.
local trainfn, validfn, testfn =
  ptb_path .. "ptb.train.txt",
  ptb_path .. "ptb.valid.txt",
  ptb_path .. "ptb.test.txt"
--[[
-- Character-level alternative (no char-level test file is listed here):
local trainfn = ptb_path .. "ptb.char.train.txt"
local validfn = ptb_path .. "ptb.char.valid.txt"
--]]
-- Vocabulary shared by every file loaded through this module:
-- vocab_map maps token -> id, and vocab_idx is the last id handed out.
local vocab_idx, vocab_map = 0, {}
-- Splits the 1-D tensor x_inp into batch_size contiguous, non-overlapping
-- slices and stacks them as the columns of a new
-- floor(x_inp:size(1) / batch_size) x batch_size matrix (torch.zeros default
-- tensor type).
-- Column i starts at offset round((i - 1) * n / batch_size) of x_inp, so the
-- columns are shifted windows over the same stream. The n % batch_size
-- elements not covered by any column are dropped.
-- Raises an error when x_inp has fewer than batch_size elements, because not
-- even a single row could be formed.
local function replicate(x_inp, batch_size)
  local s = x_inp:size(1)
  local rows = math.floor(s / batch_size)
  -- Fail early with a readable message instead of an opaque tensor error.
  assert(rows >= 1, "replicate: need at least batch_size ("
    .. tostring(batch_size) .. ") tokens, got " .. tostring(s))
  local x = torch.zeros(rows, batch_size)
  for i = 1, batch_size do
    local start = torch.round((i - 1) * s / batch_size) + 1
    local finish = start + rows - 1
    x:sub(1, rows, i, i):copy(x_inp:sub(start, finish))
  end
  return x
end
-- Reads the text file fname and converts it into a 1-D tensor of
-- vocabulary ids (torch.zeros default tensor type), one entry per token.
-- Tokens are whitespace-separated, and every newline becomes an "<eos>"
-- token.
-- Tokens not seen before are added to the module-level vocab_map with the
-- next id (vocab_idx is incremented). Ids therefore stay consistent across
-- successive calls for the train / valid / test files.
-- Raises an error naming the file if it cannot be read.
local function load_data(fname)
  local data, err = file.read(fname)
  assert(data, "load_data: cannot read " .. tostring(fname)
    .. ": " .. tostring(err))
  -- Pad the marker with spaces so it is always a separate token, even when
  -- a line has no trailing whitespace or is blank. The default whitespace
  -- split collapses the extra spaces.
  data = stringx.replace(data, '\n', ' <eos> ')
  local words = stringx.split(data)
  local x = torch.zeros(#words)
  for i = 1, #words do
    local word = words[i]
    local id = vocab_map[word]
    if id == nil then
      vocab_idx = vocab_idx + 1
      id = vocab_idx
      vocab_map[word] = id
    end
    x[i] = id
  end
  return x
end
-- Loads the training split as token ids and arranges it into batch_size
-- shifted columns via replicate().
-- The `char` argument is accepted for call compatibility but is not used.
local function traindataset(batch_size, char)
  local ids = load_data(trainfn)
  return replicate(ids, batch_size)
end
-- The test split is deliberately NOT offset per column. Every one of the
-- batch_size columns holds the same full token sequence, so a pass over
-- this batch corresponds to fully sequential processing of the test data.
-- Returns nothing when testfn is not set.
local function testdataset(batch_size)
  if not testfn then
    return
  end
  local ids = load_data(testfn)
  local n = ids:size(1)
  return ids:resize(n, 1):expand(n, batch_size)
end
-- Loads the validation split as token ids and arranges it into batch_size
-- shifted columns via replicate(), the same layout as the training data.
local function validdataset(batch_size)
  local ids = load_data(validfn)
  return replicate(ids, batch_size)
end
-- Public interface of the module. vocab_map is the live shared table, so it
-- reflects every token loaded so far.
return {
  traindataset = traindataset,
  testdataset = testdataset,
  validdataset = validdataset,
  vocab_map = vocab_map,
}