-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathdata_loader.py
More file actions
40 lines (34 loc) · 1.65 KB
/
data_loader.py
File metadata and controls
40 lines (34 loc) · 1.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import numpy as np
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
# DataLoader class: customize it to match your own dataset
class DataLoader(object):
    """Example data loader that serves MNIST batches from memory.

    The whole dataset is read into memory once, in ``__init__``. Training
    samples are then handed out sequentially by ``next_batch`` and the test
    split is returned in one piece by ``load_test``. Customize this class
    for your own dataset.

    NOTE: ``read_data_sets`` is imported from ``tensorflow.contrib``, which
    only exists in TensorFlow 1.x.
    """

    def __init__(self, data_dir, width, height, channel):
        """Read the dataset and record its basic statistics.

        Args:
            data_dir: Location handed to ``read_data_sets`` (presumably the
                directory holding the MNIST files -- confirm against the
                TensorFlow version in use).
            width: Width every image is reshaped to.
            height: Height every image is reshaped to.
            channel: Number of channels every image is reshaped to.
        """
        # variable to hold the whole dataset (labels as class indices)
        self.dataset = read_data_sets(data_dir, one_hot=False)

        # basic stats of the dataset
        self.num = self.dataset.train.images.shape[0]
        self.test_num = self.dataset.test.images.shape[0]
        self.h = height
        self.w = width
        self.c = channel

        # counter that indicates which training image to load next
        self._idx = 0

    def next_batch(self, batch_size):
        """Load the next batch of training data.

        Samples are taken sequentially starting at the internal cursor; when
        the end of the training set is reached the cursor wraps around to the
        first sample, so a batch may span the boundary (or, if ``batch_size``
        exceeds the training-set size, contain repeated samples).

        Args:
            batch_size: Number of samples to return.

        Returns:
            A tuple ``(images_batch, labels_batch)`` of float64 arrays with
            shapes ``(batch_size, h, w, c)`` and ``(batch_size,)``.

        Raises:
            IndexError: If a non-empty batch is requested from an empty
                training set.
        """
        images_batch = np.zeros((batch_size, self.h, self.w, self.c))
        labels_batch = np.zeros(batch_size)

        if batch_size > 0:
            if self.num == 0:
                raise IndexError(
                    "cannot draw a batch from an empty training set")

            # Positions of the next `batch_size` samples, wrapping at the end
            # of the training set. One vectorized gather replaces the
            # per-sample Python loop.
            # When your dataset is huge, you may need to load images on the
            # fly here instead of indexing an in-memory array.
            indices = (self._idx + np.arange(batch_size)) % self.num
            train = self.dataset.train

            # Assigning into the preallocated buffers keeps the float64
            # dtype and the exact output shapes.
            images_batch[...] = train.images[indices].reshape(
                images_batch.shape)
            labels_batch[...] = train.labels[indices]

            self._idx = (self._idx + batch_size) % self.num

        return images_batch, labels_batch

    def load_test(self):
        """Load the testing data.

        Returns:
            A tuple ``(images, labels)``: the test images reshaped to
            ``(test_num, h, w, c)`` and the test labels exactly as stored in
            the dataset.
        """
        test = self.dataset.test
        images = test.images.reshape((self.test_num, self.h, self.w, self.c))
        return images, test.labels