-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexpand.py
More file actions
112 lines (94 loc) · 3.6 KB
/
expand.py
File metadata and controls
112 lines (94 loc) · 3.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import tarfile
from h5py import File
import numpy
import six
from six.moves import range, cPickle
from fuel.converters.base import fill_hdf5_file, check_exists
DISTRIBUTION_FILE = 'cifar-10-python.tar.gz'
def _load_pickled_member(tar_file, name):
    """Extract a member of the CIFAR-10 tarball and unpickle it.

    Uses latin-1 decoding on Python 3 so the Python-2-era pickles of the
    official CIFAR-10 distribution load correctly. The extracted file
    object is always closed.
    """
    member = tar_file.extractfile(name)
    try:
        if six.PY3:
            return cPickle.load(member, encoding='latin1')
        return cPickle.load(member)
    finally:
        member.close()


@check_exists(required_files=[DISTRIBUTION_FILE])
def convert_cifar10(directory, output_directory,
                    output_filename='cifar10.hdf5'):
    """Converts the CIFAR-10 dataset to HDF5.

    Converts the CIFAR-10 dataset to an HDF5 dataset compatible with
    :class:`fuel.datasets.CIFAR10`. The training set is augmented with a
    horizontally flipped copy of every image (interleaved with the
    originals, labels repeated to match), doubling it to 100,000
    examples. The converted dataset is saved as 'cifar10.hdf5'.

    It assumes the existence of the following file:

    * `cifar-10-python.tar.gz`

    Parameters
    ----------
    directory : str
        Directory in which input files reside.
    output_directory : str
        Directory in which to save the converted dataset.
    output_filename : str, optional
        Name of the saved dataset. Defaults to 'cifar10.hdf5'.

    Returns
    -------
    output_paths : tuple of str
        Single-element tuple containing the path to the converted dataset.

    """
    output_path = os.path.join(output_directory, output_filename)
    h5file = File(output_path, mode='w')
    try:
        input_file = os.path.join(directory, DISTRIBUTION_FILE)
        tar_file = tarfile.open(input_file, 'r:gz')
        try:
            # The training set is shipped as five pickled batches.
            train_batches = [
                _load_pickled_member(
                    tar_file, 'cifar-10-batches-py/data_batch_%d' % batch)
                for batch in range(1, 6)]
            # Each batch stores images as flat rows; reshape to NCHW.
            train_features = numpy.concatenate(
                [batch['data'].reshape(batch['data'].shape[0], 3, 32, 32)
                 for batch in train_batches])
            train_labels = numpy.concatenate(
                [numpy.array(batch['labels'], dtype=numpy.uint8)
                 for batch in train_batches])
            train_labels = numpy.expand_dims(train_labels, 1)
            print(train_features.shape)
            print(train_labels.shape)
            # Augmentation: mirror every image along the width axis and
            # interleave original/flipped pairs so each example is
            # immediately followed by its flip; repeat each label twice
            # to keep features and targets aligned.
            flipped_train_features = train_features[:, :, :, ::-1]
            train_features = numpy.array(
                [val for pair in zip(train_features, flipped_train_features)
                 for val in pair])
            train_labels = numpy.repeat(train_labels, 2, axis=0)
            print(train_features.shape)
            print(train_labels.shape)

            test = _load_pickled_member(
                tar_file, 'cifar-10-batches-py/test_batch')
        finally:
            tar_file.close()

        test_features = test['data'].reshape(test['data'].shape[0],
                                             3, 32, 32)
        test_labels = numpy.array(test['labels'], dtype=numpy.uint8)
        test_labels = numpy.expand_dims(test_labels, 1)

        data = (('train', 'features', train_features),
                ('train', 'targets', train_labels),
                ('test', 'features', test_features),
                ('test', 'targets', test_labels))
        fill_hdf5_file(h5file, data)
        h5file['features'].dims[0].label = 'batch'
        h5file['features'].dims[1].label = 'channel'
        h5file['features'].dims[2].label = 'height'
        h5file['features'].dims[3].label = 'width'
        h5file['targets'].dims[0].label = 'batch'
        h5file['targets'].dims[1].label = 'index'
        h5file.flush()
    finally:
        h5file.close()

    return (output_path,)
def fill_subparser(subparser):
    """Set up the subparser for the `cifar10` conversion command.

    No extra command-line arguments are needed beyond the common ones,
    so the subparser is left untouched.

    Parameters
    ----------
    subparser : :class:`argparse.ArgumentParser`
        Subparser handling the `cifar10` command.

    Returns
    -------
    callable
        The conversion entry point, :func:`convert_cifar10`.

    """
    return convert_cifar10