-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathImageData.py
More file actions
73 lines (48 loc) · 1.76 KB
/
ImageData.py
File metadata and controls
73 lines (48 loc) · 1.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import functools
import random
import tensorflow as tf
import tensorflow_addons as tfa
from config import get_config_from_json
def decode_png(img):
    """Decode PNG bytes into a float32 RGB image scaled to ``[-1, 1]``.

    Args:
        img: Scalar string tensor holding the encoded PNG file contents.

    Returns:
        A float32 image tensor with 3 channels and values in ``[-1, 1]``.
    """
    img = tf.image.decode_png(img, channels=3)
    # convert_image_dtype rescales integer pixels (0..255) to floats in
    # [0, 1], so the [-1, 1] mapping is ``x * 2 - 1``. Dividing by 127.5
    # (the formula for raw 0..255 values) would squash every pixel into
    # roughly [-1, -0.992].
    img = tf.image.convert_image_dtype(img, tf.float32)
    img = img * 2.0 - 1.0
    return img
def process_path(img_path, label):
    """Load an image from ``faces_images/`` and pair it with its label.

    Args:
        img_path: File name relative to the ``faces_images/`` directory.
        label: Label associated with the image; returned unchanged.

    Returns:
        Tuple ``(image, label)`` where ``image`` is the output of
        ``decode_png`` applied to the file contents.
    """
    raw_bytes = tf.io.read_file("faces_images/" + img_path)
    return decode_png(raw_bytes), label
def process_only_path(img_path):
    """Load and decode an image from ``faces_images/`` (no label).

    Args:
        img_path: File name relative to the ``faces_images/`` directory.

    Returns:
        The image produced by ``decode_png`` for that file.
    """
    full_path = "faces_images/" + img_path
    return decode_png(tf.io.read_file(full_path))
def resize_and_crop(img: tf.Tensor):
    """Resize the image to 196x196 and take a random 128x128x3 crop.

    Args:
        img: Image tensor to augment.

    Returns:
        The randomly cropped image with shape ``(128, 128, 3)``.
    """
    # The op-level seed is drawn with Python's ``random`` each time this
    # Python function body executes.
    crop_seed = random.randint(0, 2 ** 31 - 1)
    enlarged = tf.image.resize(img, size=(196, 196))
    return tf.image.random_crop(enlarged, (128, 128, 3), seed=crop_seed)
def flip_left_right(img: tf.Tensor):
    """Randomly mirror the image horizontally.

    Args:
        img: Image tensor to augment.

    Returns:
        Either the input image or its left-right mirrored version, as
        decided by ``tf.image.random_flip_left_right``.
    """
    # The op-level seed is drawn with Python's ``random`` each time this
    # Python function body executes.
    flip_seed = random.randint(0, 2 ** 31 - 1)
    return tf.image.random_flip_left_right(img, seed=flip_seed)
def random_rotation(img: tf.Tensor):
    """Rotate the image by a random angle drawn from ``[-30, 30)`` degrees.

    Args:
        img: Image tensor to rotate.

    Returns:
        The rotated image as returned by ``tfa.image.rotate``.
    """
    # Local import: ``math`` is not imported at the top of this module.
    import math

    max_degrees = 30.0
    # Sample with a TensorFlow op rather than Python's ``random``: inside a
    # ``tf.function`` Python code runs only while tracing, which would freeze
    # a single angle for every subsequent call.
    angle_degrees = tf.random.uniform([], -max_degrees, max_degrees)
    # ``tfa.image.rotate`` expects radians, so convert from degrees.
    angle_radians = angle_degrees * (math.pi / 180.0)
    return tfa.image.rotate(img, angle_radians)
def random_shear(img: tf.Tensor):
    """Apply a random shear with an intensity of 15 to a channels-last image.

    Args:
        img: Image laid out as (height, width, channels), matching the
            images produced elsewhere in this module.

    Returns:
        The sheared image as returned by
        ``tf.keras.preprocessing.image.random_shear``.
    """
    # The Keras helper defaults to a channels-first layout
    # (row_axis=1, col_axis=2, channel_axis=0); the images here are
    # channels-last, so the axes must be given explicitly.
    # NOTE(review): this helper is NumPy-based; presumably it needs a NumPy
    # array or an eager tensor rather than a symbolic one -- confirm before
    # using it inside a tf.function / Dataset.map.
    return tf.keras.preprocessing.image.random_shear(
        img,
        intensity=15,
        row_axis=0,
        col_axis=1,
        channel_axis=2,
    )
def random_brihtness(img: tf.Tensor):
    """Apply a random brightness shift with a factor drawn from (0.8, 1).

    The misspelled name is kept so existing callers keep working; prefer
    the correctly spelled alias ``random_brightness`` in new code.

    Args:
        img: Image to adjust.

    Returns:
        The adjusted image as returned by
        ``tf.keras.preprocessing.image.random_brightness``.
    """
    # NOTE(review): the Keras helper is NumPy/PIL-based and may rescale its
    # output to the 0..255 range -- confirm this is compatible with images
    # normalised elsewhere in this module.
    return tf.keras.preprocessing.image.random_brightness(
        img, brightness_range=(0.8, 1)
    )


# Correctly spelled alias for the function above.
random_brightness = random_brihtness
def get_augmentation_list():
    """Return the augmentation callables in the order they are applied.

    Returns:
        A new list containing ``resize_and_crop``, ``flip_left_right`` and
        ``random_rotation``.
    """
    enabled = (resize_and_crop, flip_left_right, random_rotation)
    return list(enabled)
@tf.function
def augmentation(img: tf.Tensor, label: tf.Tensor, augment=None):
    """Apply ``augment`` to ``img`` with probability 0.25.

    Args:
        img: Image tensor to (possibly) augment.
        label: Label for the image; returned unchanged.
        augment: Callable taking and returning an image tensor. Its output
            must have the same shape and dtype as ``img`` because both
            ``tf.cond`` branches have to agree. When ``None`` the image is
            passed through untouched.

    Returns:
        Tuple ``(image, label)``.
    """
    if augment is None:
        # tf.cond traces both branches, so a missing callable would fail
        # with "'NoneType' object is not callable" even when the random
        # draw would not select the augmentation branch.
        return img, label

    # A uniform draw above 0.75 happens a quarter of the time.
    should_augment = tf.random.uniform([], 0, 1) > 0.75
    augmented = tf.cond(should_augment, lambda: augment(img), lambda: img)
    return augmented, label
if __name__ == "__main__":
    # Smoke test: run every augmentation once on a dummy image/label pair.
    config = get_config_from_json("config.json")
    img = tf.ones((128, 128, 3))
    label = tf.zeros((6,))
    # ``augmentation`` needs an explicit callable; each one in the list
    # maps a (128, 128, 3) image to the same shape, as tf.cond requires.
    for augment_fn in get_augmentation_list():
        augmented_img, label = augmentation(img, label, augment=augment_fn)