-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathsaliencyDatasetClass.py
More file actions
executable file
·81 lines (58 loc) · 2.12 KB
/
saliencyDatasetClass.py
File metadata and controls
executable file
·81 lines (58 loc) · 2.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import glob
from sklearn.utils import shuffle
import numpy as np
def load_train(trainSamples, image_size, labels):
    """Return the training images together with their labels.

    The samples are already loaded in memory, so this just copies them
    into a fresh list — no file I/O happens here.

    Args:
        trainSamples: iterable of already-loaded image arrays.
        image_size: unused; kept for backward compatibility with callers.
        labels: labels aligned index-for-index with ``trainSamples``.

    Returns:
        Tuple ``(images, labels)`` where ``images`` is a new list copy of
        ``trainSamples`` and ``labels`` is passed through unchanged.
    """
    print('Going to read training images')
    # list() replaces the original element-by-element append loop;
    # the result is an identical shallow copy.
    images = list(trainSamples)
    return images, labels
class DataSet(object):
    """Container pairing an image array with its labels, served in
    sequential mini-batches via :meth:`next_batch`."""

    def __init__(self, images, labels):
        # Sample count comes from the leading axis of the image array.
        self._num_examples = images.shape[0]
        self._images = images
        self._labels = labels
        self._epochs_done = 0
        self._index_in_epoch = 0

    @property
    def images(self):
        """The full image array held by this dataset."""
        return self._images

    @property
    def labels(self):
        """Labels aligned index-for-index with ``images``."""
        return self._labels

    @property
    def num_examples(self):
        """Total number of samples."""
        return self._num_examples

    @property
    def epochs_done(self):
        """Number of complete passes over the data so far."""
        return self._epochs_done

    def next_batch(self, batch_size):
        """Return the next `batch_size` examples from this data set."""
        begin = self._index_in_epoch
        cursor = begin + batch_size
        if cursor > self._num_examples:
            # Epoch exhausted: wrap back to the start. NOTE(review): any
            # trailing samples that don't fill a whole batch are skipped,
            # matching the original behavior.
            self._epochs_done += 1
            begin, cursor = 0, batch_size
            self._index_in_epoch = cursor
            assert batch_size <= self._num_examples
        else:
            self._index_in_epoch = cursor
        return self._images[begin:cursor], self._labels[begin:cursor]
def read_train_sets(trainSamples, image_size, labels, validation_size):
    """Shuffle the samples and split them into train/validation sets.

    Args:
        trainSamples: array of already-loaded images.
        image_size: unused; kept for backward compatibility with callers.
        labels: label array aligned with ``trainSamples``.
        validation_size: either an absolute sample count, or — when given
            as a float — a fraction of the whole data set to hold out.

    Returns:
        An object with ``train`` and ``valid`` attributes, each a
        :class:`DataSet`.
    """
    class DataSets(object):
        pass

    data_sets = DataSets()

    # Shuffle images and labels together so pairs stay aligned.
    images, labels = shuffle(trainSamples, labels)

    # A float validation_size is interpreted as a fraction of the data.
    if isinstance(validation_size, float):
        validation_size = int(validation_size * images.shape[0])

    # The first `validation_size` shuffled samples form the validation
    # split; everything after them is the training split.
    data_sets.train = DataSet(images[validation_size:], labels[validation_size:])
    data_sets.valid = DataSet(images[:validation_size], labels[:validation_size])
    return data_sets