-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTestModel.lua
More file actions
66 lines (64 loc) · 1.74 KB
/
TestModel.lua
File metadata and controls
66 lines (64 loc) · 1.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
--[Input] to [Output]
--- Check a single sample: run `model` forward on `dataIn` and compare the
-- result with `dataOut`.
-- @param model object exposing `:forward(input)` (e.g. an nn module)
-- @param dataIn input passed to `model:forward`
-- @param dataOut expected output
-- @return 1 when `output ~= dataOut`, otherwise 0
function TestModelOne(model, dataIn, dataOut)
  -- forward pass
  local output = model:forward(dataIn)
  -- NOTE(review): `~=` compares tables/userdata by reference unless both
  -- operands share an __eq metamethod. If these are tensors without __eq,
  -- any two distinct objects count as a mismatch — confirm whether an
  -- element-wise or argmax comparison was intended here.
  if output ~= dataOut then
    return 1
  end
  return 0
end
--[X][Input] to [X][OutputSize]
--- Batch accuracy by argmax.
-- For every row k in 1..dataOut:size(1), the predicted class is the index of
-- the largest entry of the model's output row k. The target class is the
-- index of the largest entry of dataOut row k.
-- @param model object exposing `:forward(input)`
-- @param dataIn batch input passed to `model:forward`
-- @param dataOut batch of targets, one row per sample
-- @return all number of rows compared
-- @return err number of rows whose predicted and target argmax differ
function TestModelRaw(model, dataIn, dataOut)
  local predictions = model:forward(dataIn)

  -- Position of the maximum along dimension 1 of `row`, squeezed to a value.
  local function argmaxOf(row)
    local _, position = row:max(1)
    return position:squeeze()
  end

  local rowCount, mismatches = 0, 0
  for row = 1, dataOut:size(1) do
    local predicted = argmaxOf(predictions[row])
    local expected = argmaxOf(dataOut[row])
    if predicted ~= expected then
      mismatches = mismatches + 1
    end
    rowCount = rowCount + 1
  end

  return rowCount, mismatches
end
--[X][Input] to [X][OutputSize]
--- Batch evaluation by argmax, broken down per target class.
-- For every row k in 1..dataOut:size(1), the predicted class is the index of
-- the largest entry of the model's output row k. The target class is the
-- index of the largest entry of dataOut row k.
-- Both indices are used to index tensors of size `numClasses`, so they must
-- lie in 1..numClasses.
-- @param model object exposing `:forward(input)`
-- @param dataIn batch input passed to `model:forward`
-- @param dataOut batch of targets, one row per sample
-- @param numClasses optional class count; defaults to Settings.OutputSize
-- @return all 1-D tensor; all[t] = number of rows whose target class is t
-- @return err 2-D tensor; err[t][p] = number of rows with target t that were
--   predicted as p. Only mistakes are counted, so the diagonal stays 0.
function TestModelRawX(model, dataIn, dataOut, numClasses)
  -- Fall back to the global setting so existing 3-argument callers are
  -- unaffected; an explicit count avoids depending on global state.
  local n = numClasses or Settings.OutputSize
  local err = torch.Tensor(n, n):fill(0)
  local all = torch.Tensor(n):fill(0)
  -- forward pass
  local output = model:forward(dataIn)
  for k = 1, dataOut:size(1) do
    local _, mx = output[k]:max(1)
    local _, mxt = dataOut[k]:max(1)
    local predicted = mx:squeeze()
    local target = mxt:squeeze()
    if predicted ~= target then
      -- rows are indexed by target class, columns by predicted class
      err[target][predicted] = err[target][predicted] + 1
    end
    all[target] = all[target] + 1
  end
  return all, err
end