forked from LynnHo/CycleGAN-Tensorflow-2
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdata.py
More file actions
124 lines (108 loc) · 3.15 KB
/
data.py
File metadata and controls
124 lines (108 loc) · 3.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import numpy as np
import tensorflow as tf
import tf2lib as tl
def make_dataset(
    img_paths,
    batch_size,
    load_size,
    crop_size,
    channels,
    training,
    drop_remainder=True,
    shuffle=True,
    repeat=1,
):
    """Build a batched image dataset from files on disk.

    Every image is clipped to [0, 255] and linearly rescaled to [-1, 1].
    With a truthy ``training`` the preprocessing is:
    1. random horizontal flip;
    2. resize to ``load_size x load_size``;
    3. random crop to ``crop_size x crop_size``, keeping all channels.

    Otherwise the image is resized directly to ``crop_size x crop_size``.

    Args:
        img_paths: Image file paths, forwarded to
            ``tl.disk_image_batch_dataset``.
        batch_size: Forwarded to ``tl.disk_image_batch_dataset``.
        load_size: Resize target applied before random cropping (training
            only).
        crop_size: Side length of the output images.
        channels: Forwarded to ``tl.disk_image_batch_dataset``; presumably the
            number of channels to decode (TODO confirm).
        training: Selects the augmenting (truthy) or deterministic (falsy)
            preprocessing.
        drop_remainder: Forwarded to ``tl.disk_image_batch_dataset``.
        shuffle: Forwarded to ``tl.disk_image_batch_dataset``.
        repeat: Forwarded to ``tl.disk_image_batch_dataset``.

    Returns:
        The result of ``tl.disk_image_batch_dataset`` with the selected
        preprocessing passed as ``map_fn``.
    """

    def _to_signed_unit(image):
        # [0, 255] -> [0, 1] -> [-1, 1]  (tl.minmax_norm is an alternative)
        return tf.clip_by_value(image, 0, 255) / 255.0 * 2 - 1

    @tf.function
    def _augment(image):
        image = tf.image.random_flip_left_right(image)
        image = tf.image.resize(image, [load_size, load_size])
        # Keep whatever channel count the decoded image has.
        crop_shape = [crop_size, crop_size, tf.shape(image)[-1]]
        image = tf.image.random_crop(image, crop_shape)
        return _to_signed_unit(image)

    @tf.function
    def _resize_only(image):
        # Alternative: resize to load_size, then tl.center_crop(image, crop_size).
        image = tf.image.resize(image, [crop_size, crop_size])
        return _to_signed_unit(image)

    preprocess = _augment if training else _resize_only

    return tl.disk_image_batch_dataset(
        img_paths,
        channels,
        batch_size,
        drop_remainder=drop_remainder,
        map_fn=preprocess,
        shuffle=shuffle,
        repeat=repeat,
    )
def make_zip_dataset(
    A_img_paths,
    B_img_paths,
    batch_size,
    load_size,
    crop_size,
    channels,
    training,
    shuffle=True,
    repeat=False,
):
    """Zip an A image dataset and a B image dataset into one paired stream.

    The pair is aligned to the longer of the two path lists. Unless
    ``repeat`` is truthy, the longer side is read once (``repeat=1``) and the
    shorter side is cycled (``repeat=None``). When the lists have equal
    length, A is the side read once. With a truthy ``repeat`` both sides are
    cycled. Both sides are built by ``make_dataset`` with
    ``drop_remainder=True``.

    Args:
        A_img_paths: Image paths for the A side.
        B_img_paths: Image paths for the B side.
        batch_size, load_size, crop_size, channels, training, shuffle:
            Forwarded unchanged to ``make_dataset`` for both sides.
        repeat: If truthy, cycle both sides instead of only the shorter one.

    Returns:
        A tuple ``(A_B_dataset, len_dataset)``:
        - ``A_B_dataset`` is ``tf.data.Dataset.zip`` of the A and B datasets.
        - ``len_dataset`` is
          ``max(len(A_img_paths), len(B_img_paths)) // batch_size``.
    """
    num_A = len(A_img_paths)
    num_B = len(B_img_paths)

    # `None` cycles a side; `1` reads it once.
    if repeat:
        A_repeat, B_repeat = None, None
    elif num_A >= num_B:
        A_repeat, B_repeat = 1, None
    else:
        A_repeat, B_repeat = None, 1

    def _one_side(paths, side_repeat):
        return make_dataset(
            paths,
            batch_size,
            load_size,
            crop_size,
            channels,
            training,
            drop_remainder=True,
            shuffle=shuffle,
            repeat=side_repeat,
        )

    # Build A before B, as the original ordering does.
    A_dataset = _one_side(A_img_paths, A_repeat)
    B_dataset = _one_side(B_img_paths, B_repeat)
    A_B_dataset = tf.data.Dataset.zip((A_dataset, B_dataset))
    return A_B_dataset, max(num_A, num_B) // batch_size
class ItemPool:
    """History buffer that mixes incoming items with previously seen ones.

    Incoming items are stored until ``pool_size`` of them have been
    collected, and during that phase each one is passed through unchanged.
    Once the pool is full, each incoming item is handled as follows:
    - with probability 0.5, it replaces a randomly chosen stored item, and
      that stored item is returned in its place;
    - otherwise, it is passed through and the pool is left unchanged.

    Args:
        pool_size: Maximum number of stored items. Any non-positive value
            disables the pool, so calls return their input unchanged.
    """

    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.items = []  # stored items; never more than `pool_size` of them

    def __call__(self, in_items):
        """Return a batch where some items may be swapped for stored ones.

        Args:
            in_items: Batch to process, iterated item by item; presumably a
                tensor whose first axis indexes items (TODO confirm against
                callers).

        Returns:
            ``in_items`` itself when the pool is disabled. Otherwise the
            selected items stacked with ``tf.stack(..., axis=0)``.
        """
        # Treat every non-positive size as "disabled". Previously only 0 was;
        # a negative size reached `np.random.randint(0, 0)` and raised.
        if self.pool_size <= 0:
            return in_items
        out_items = []
        for in_item in in_items:
            if len(self.items) < self.pool_size:
                # Pool still filling: remember the item and pass it through.
                self.items.append(in_item)
                out_items.append(in_item)
            elif np.random.rand() > 0.5:
                # Hand back a stored item and keep the new one in its slot.
                idx = np.random.randint(0, len(self.items))
                out_item, self.items[idx] = self.items[idx], in_item
                out_items.append(out_item)
            else:
                out_items.append(in_item)
        return tf.stack(out_items, axis=0)