-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathcnnTestData.m
More file actions
35 lines (30 loc) · 1.11 KB
/
cnnTestData.m
File metadata and controls
35 lines (30 loc) · 1.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
function [acc, e] = cnnTestData(cnn, VX, VY, numImages)
% cnnTestData  Measure classification accuracy of a CNN on validation data.
%
%   [acc, e] = cnnTestData(cnn, VX, VY, numImages) runs a forward pass of
%   CNN on the first numImages samples of VX. The predicted class of each
%   sample is the row index of the maximum along dimension 1 of the last
%   layer's output, cnn.OutData{cnn.LNum}. Predictions are compared with
%   the corresponding labels in VY.
%
%   Inputs:
%     cnn       - network struct. If cnn.to.useGPU == 1 the data is moved
%                 to the GPU as single and cnnFeedForward_GPU is used;
%                 otherwise cnnFeedForward is used on the data as given.
%                 cnn.to.test is set to 1 on the local copy only.
%     VX        - validation data, [x-dim, y-dim, channel-num, data-count]
%     VY        - validation labels, [1, data-count]. Presumably these are
%                 class indices matching the output-layer rows; confirm
%                 against the training code.
%     numImages - number of leading samples to evaluate. Optional: if it
%                 is omitted or empty, all size(VX, 4) samples are used.
%
%   Outputs:
%     acc - fraction of correctly classified samples, gathered to host
%           memory. It is NaN when numImages is 0.
%     e   - logical row, true where the prediction equals the label,
%           gathered to host memory.

if nargin < 4 || isempty(numImages)
    numImages = size(VX, 4);
end

cnn.to.test = 1;
idx = 1:numImages;
if cnn.to.useGPU == 1
    images = gpuArray(single(VX(:, :, :, idx)));
    mb_labels = gpuArray(single(VY(:, idx)));
    cnn = cnnFeedForward_GPU(cnn, images);
else
    images = VX(:, :, :, idx);
    mb_labels = VY(:, idx);
    cnn = cnnFeedForward(cnn, images);
end

% Predicted class = index of the largest output along dimension 1.
% The output is assumed to be [num-classes, numImages] (TODO confirm).
[~, preds] = max(cnn.OutData{cnn.LNum}, [], 1);

% Compare once and bring the result back to host memory so that both
% outputs live in the same place regardless of the GPU setting.
e = gather(preds == mb_labels);
acc = sum(e) / numImages;
end