-
Notifications
You must be signed in to change notification settings - Fork 31
Expand file tree
/
Copy pathdataset.py
More file actions
132 lines (112 loc) · 4.58 KB
/
dataset.py
File metadata and controls
132 lines (112 loc) · 4.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import glob
import random
import numpy as np
from PIL import Image
from skimage.color import rgb2lab
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageChops
from utils import load_default_configs, split_lab_channels
def is_greyscale(im):
    """
    Check if an image is monochrome (1 channel or 3 identical channels).

    You can use this to filter your dataset of black and white images.

    Args:
        im: a file path (str) or a PIL Image in mode "L" or "RGB".

    Returns:
        bool: True if the image is single-channel or every pixel has
        identical R, G and B values; False otherwise.

    Raises:
        ValueError: if the image mode is neither "L" nor "RGB".
    """
    if isinstance(im, str):
        im = Image.open(im).convert("RGB")
    if im.mode not in ("L", "RGB"):
        # Fixed typo in the original message ("Unsuported").
        raise ValueError("Unsupported image mode")
    if im.mode == "RGB":
        r, g, b = im.split()
        # R==G and R==B together imply G==B, so two comparisons suffice.
        if ImageChops.difference(r, g).getextrema()[1] != 0:
            return False
        if ImageChops.difference(r, b).getextrema()[1] != 0:
            return False
    return True
class ColorizationDataset(Dataset):
    """Dataset yielding images as normalized L*a*b tensors.

    Each item is a (3, H, W) float32 tensor: channel 0 is L scaled to
    [-1, 1], channels 1-2 are a/b scaled by 1/110 to roughly [-1, 1].
    """

    def __init__(self, paths, split='train', config=None):
        """
        Args:
            paths: list of image file paths.
            split: 'train' enables flip/color-jitter augmentation; any
                other value (e.g. 'val', 'test') applies only the resize.
            config: dict with at least an "img_size" entry.
        """
        size = config["img_size"]
        self.resize = transforms.Resize((size, size), Image.BICUBIC)
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=0.3,
                                       contrast=0.1,
                                       saturation=(1., 2.),
                                       hue=0.05),
                self.resize
            ])
        else:
            # Bug fix: previously only split == 'val' set self.transforms,
            # so any other split name crashed with AttributeError on use.
            self.transforms = self.resize
        self.paths = paths

    def tensor_to_lab(self, base_img_tens):
        """Convert an RGB PIL image/array to a normalized Lab tensor."""
        base_img = np.array(base_img_tens)
        img_lab = rgb2lab(base_img).astype(
            "float32")  # Converting RGB to L*a*b
        img_lab = transforms.ToTensor()(img_lab)
        L = img_lab[[0], ...] / 50. - 1.  # L in [0, 100] -> [-1, 1]
        ab = img_lab[[1, 2], ...] / 110.  # a/b roughly -> [-1, 1]
        return torch.cat((L, ab), dim=0)

    def get_lab_from_path(self, path):
        """Load an image from disk, transform it, return its Lab tensor."""
        img = Image.open(path).convert("RGB")
        img = self.transforms(img)
        return self.tensor_to_lab(img)

    def get_rgb(self, idx=0):
        """Return item `idx` as a transformed RGB numpy array."""
        img = Image.open(self.paths[idx]).convert("RGB")
        img = self.transforms(img)
        return np.array(img)

    def get_grayscale(self, idx=0):
        """Return item `idx` as a resized single-channel numpy array."""
        img = Image.open(self.paths[idx]).convert("L")
        img = self.resize(img)
        return np.array(img)

    def get_lab_grayscale(self, idx=0):
        """Return item `idx` as a Lab tensor with a/b channels zeroed."""
        img = self.get_lab_from_path(self.paths[idx])
        l, _ = split_lab_channels(img.unsqueeze(0))
        return torch.cat((l, *[torch.zeros_like(l)] * 2), dim=1)

    def __getitem__(self, idx):
        return self.get_lab_from_path(self.paths[idx])

    def __len__(self):
        return len(self.paths)
class PickleColorizationDataset(ColorizationDataset):
    """Variant of ColorizationDataset whose paths point at tensors
    pre-serialized with torch.save, loaded directly instead of images."""

    def __getitem__(self, idx):
        pickle_path = self.paths[idx]
        return torch.load(pickle_path)
def make_datasets(path, config, limit=None, val_fraction=0.1):
    """Build train/val ColorizationDatasets from image files under `path`.

    Args:
        path: directory containing image files (searched non-recursively).
        config: config dict forwarded to ColorizationDataset.
        limit: if truthy, randomly subsample at most this many files.
        val_fraction: fraction of files assigned to the validation split
            (default 0.1, preserving the original 90/10 split).

    Returns:
        (train_dataset, val_dataset) tuple.
    """
    img_paths = glob.glob(path + "/*")
    if limit:
        # Clamp so a limit larger than the dataset doesn't raise
        # ValueError from random.sample.
        img_paths = random.sample(img_paths, min(limit, len(img_paths)))
    n_imgs = len(img_paths)
    n_train = int(n_imgs * (1. - val_fraction))
    train_split = img_paths[:n_train]
    val_split = img_paths[n_train:]
    train_dataset = ColorizationDataset(
        train_split, split="train", config=config)
    val_dataset = ColorizationDataset(val_split, split="val", config=config)
    print(f"Train size: {len(train_split)}")
    print(f"Val size: {len(val_split)}")
    return train_dataset, val_dataset
def make_dataloaders(path, config, num_workers=2, shuffle=True, limit=None):
    """Build train/val DataLoaders over the images under `path`.

    Args:
        path: directory of image files.
        config: dict providing "batch_size" and "pin_memory".
        num_workers: worker processes per loader.
        shuffle: whether to shuffle (applied to both loaders, as before).
        limit: optional cap on the number of files used.

    Returns:
        (train_dl, val_dl) tuple of DataLoaders.
    """
    train_dataset, val_dataset = make_datasets(path, config, limit=limit)
    # Bug fix: DataLoader raises ValueError when persistent_workers=True
    # and num_workers=0, so only enable it with real worker processes.
    persistent = num_workers > 0
    train_dl = DataLoader(train_dataset,
                          batch_size=config["batch_size"],
                          num_workers=num_workers,
                          pin_memory=config["pin_memory"],
                          persistent_workers=persistent,
                          shuffle=shuffle)
    val_dl = DataLoader(val_dataset,
                        batch_size=config["batch_size"],
                        num_workers=num_workers,
                        pin_memory=config["pin_memory"],
                        persistent_workers=persistent,
                        shuffle=shuffle)
    return train_dl, val_dl
if __name__ == "__main__":
    # Smoke test: build loaders over ./fairface and print the shape of
    # one batch from each split.
    configs = load_default_configs()
    enc_config, unet_config, colordiff_config = configs
    train_dl, val_dl = make_dataloaders(
        "./fairface", colordiff_config, num_workers=4)
    x, y = next(iter(train_dl)), next(iter(val_dl))
    print(f"y.shape = {y.shape}")
    print(f"x.shape = {x.shape}")