forked from shekkizh/FCN.tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathTFReader.py
More file actions
140 lines (124 loc) · 6.18 KB
/
TFReader.py
File metadata and controls
140 lines (124 loc) · 6.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import tensorflow as tf
from tensorflow.data import Dataset, Iterator
class DatasetReader:
    """Wraps a list of image/annotation file records in a batched, repeating tf.data pipeline."""

    # NOTE(review): these class-level attributes look vestigial -- they are shadowed or
    # unused by the instance code below. Kept for backward compatibility with any
    # external references; confirm before removing.
    filenames = []
    tf_filenames = tf.convert_to_tensor([])
    label_filenames = []
    image_options = {}

    def __init__(self, records_list, image_options=None, batch_size=1):
        """
        Initialize a generic file reader with batching for a list of files.

        :param records_list: list of file records to read -
            sample record: {'image': f, 'annotation': annotation_file, 'filename': filename}
        :param image_options: dictionary of options for modifying the output image.
            Available options:
                resize = True/False
                resize_height, resize_width = size of output image (bilinear resize)
                image_augmentation = True/False
                predict_dataset = True/False (default False) -- when True the dataset is
                    for predictions and records carry no 'annotation' entry
        :param batch_size: number of examples per batch (default 1)
        """
        # Avoid the shared-mutable-default-argument pitfall: treat None as "no options".
        self.image_options = {} if image_options is None else image_options
        self.batch_size = batch_size
        self.records = {}
        self.records["image"] = [record['image'] for record in records_list]
        self.records["filename"] = [record['filename'] for record in records_list]
        if not self.image_options.get("predict_dataset", False):
            self.records["annotation"] = [record['annotation'] for record in records_list]
        # Prediction datasets have no annotations, so the dataset elements are 2-tuples;
        # training/validation datasets are 3-tuples.
        if 'annotation' in self.records:
            self.dataset = Dataset.from_tensor_slices(
                (self.records['image'], self.records['filename'], self.records['annotation']))
        else:
            self.dataset = Dataset.from_tensor_slices(
                (self.records['image'], self.records['filename']))
        self.dataset = self.dataset.map(self._input_parser)
        self.dataset = self.dataset.batch(batch_size)
        self.dataset = self.dataset.repeat()

    def _input_parser(self, image_filename, name, annotation_filename=None):
        """Load one example: decode the image (and annotation, if any), then resize/augment
        per self.image_options. Returns (image, name) or (image, annotation, name)."""
        # Based on https://github.com/tensorflow/tensorflow/issues/9356, decode_jpeg and
        # decode_png both decode both formats. This is a workaround because decode_image
        # does not return a static size, which breaks resize_images.
        image = tf.image.decode_png(tf.read_file(image_filename))
        if self.image_options.get("resize", False):
            # [..., :3] drops a possible alpha channel before the bilinear resize.
            image = tf.image.resize_images(
                image[..., :3],
                (self.image_options["resize_height"], self.image_options["resize_width"]))
        annotation = None
        if annotation_filename is not None:
            annotation = tf.image.decode_png(tf.read_file(annotation_filename))
            if self.image_options.get("resize", False):
                annotation = tf.image.resize_images(
                    annotation,
                    (self.image_options["resize_height"], self.image_options["resize_width"]))
        if self.image_options.get("image_augmentation", False):
            if annotation is None:
                # Bug fix: without an annotation, _augment_image returns a bare tensor;
                # the old "result + (name,)" added a tensor to a tuple, which is invalid.
                return self._augment_image(image), name
            # With an annotation, _augment_image returns a (image, annotation) tuple.
            return self._augment_image(image, annotation) + (name,)
        elif annotation is None:
            return image, name
        else:
            return image, annotation, name

    def _augment_image(self, image, annotation_file=None):
        """Randomly flip (image and annotation in lockstep) and distort brightness/contrast.

        Returns (image, annotation) when an annotation is given, otherwise just the image.
        Assumes a 3-channel image and a 1-channel annotation -- TODO confirm upstream.
        """
        if annotation_file is not None:
            # Concatenate along channels so both tensors receive the SAME random flips.
            combined_image_label = tf.concat((image, annotation_file), axis=2)
        else:
            combined_image_label = image
        combined_image_label = tf.image.random_flip_left_right(combined_image_label)
        combined_image_label = tf.image.random_flip_up_down(combined_image_label)
        if annotation_file is not None:
            distorted_image = combined_image_label[:, :, :3]
            # Add extra dimension to make the annotation NxMx1 rather than NxM.
            distorted_annotation = tf.expand_dims(combined_image_label[:, :, 3], -1)
        else:
            distorted_image = combined_image_label
        # Photometric distortions apply to the image only, never the annotation.
        distorted_image = tf.image.random_brightness(distorted_image, max_delta=63)
        distorted_image = tf.image.random_contrast(distorted_image, lower=0.2, upper=1.8)
        if annotation_file is not None:
            # distorted_annotation is always bound on this path (same condition above).
            return distorted_image, distorted_annotation
        else:
            return distorted_image
class TrainVal:
    """Holds a training and a validation DatasetReader plus their one-shot iterators."""

    def __init__(self):
        self.train = None
        self.validation = None
        # Bug fix: initialize the iterator attributes so get_iterators() can safely
        # test them before _create_iterators() has ever run. Previously, calling
        # get_iterators() on a fresh TrainVal raised AttributeError.
        self.train_iterator = None
        self.validation_iterator = None

    @classmethod
    def from_DatasetReaders(cls, train_reader, val_reader):
        """Alternate constructor from two already-built DatasetReader instances."""
        train_val = cls()
        train_val.train = train_reader
        train_val.validation = val_reader
        train_val._create_iterators()
        return train_val

    @classmethod
    def from_records(cls, train_records, val_records, train_image_options, val_image_options,
                     train_batch_size=1, val_batch_size=1):
        """Alternate constructor that builds the two DatasetReaders from raw record lists."""
        train_reader = DatasetReader(train_records, train_image_options, train_batch_size)
        val_reader = DatasetReader(val_records, val_image_options, val_batch_size)
        return cls.from_DatasetReaders(train_reader, val_reader)

    def _create_iterators(self):
        """(Re)build one-shot iterators; a no-op unless both readers are set."""
        if self.train and self.validation:
            self.train_iterator = self.train.dataset.make_one_shot_iterator()
            self.validation_iterator = self.validation.dataset.make_one_shot_iterator()

    def get_iterators(self):
        """Return (train_iterator, validation_iterator), creating them lazily."""
        if not self.train_iterator or not self.validation_iterator:
            self._create_iterators()
        return self.train_iterator, self.validation_iterator
class SingleDataset:
    """Wraps a single DatasetReader and lazily exposes its one-shot iterator."""

    def __init__(self):
        self.reader = None
        self.iterator = None

    @classmethod
    def from_DatasetReaders(cls, reader):
        """Alternate constructor around an existing DatasetReader."""
        instance = cls()
        instance.reader = reader
        instance._create_iterator()
        return instance

    @classmethod
    def from_records(cls, records, image_options, batch_size=1):
        """Alternate constructor that builds the DatasetReader from a raw record list."""
        return cls.from_DatasetReaders(DatasetReader(records, image_options, batch_size))

    def _create_iterator(self):
        """Build the one-shot iterator; a no-op when no reader is attached."""
        if self.reader:
            self.iterator = self.reader.dataset.make_one_shot_iterator()

    def get_iterator(self):
        """Return the dataset iterator, creating it on first use."""
        if not self.iterator:
            self._create_iterator()
        return self.iterator