forked from DavidBert/ModIA_TP1
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_utils.py
More file actions
44 lines (38 loc) · 1.67 KB
/
data_utils.py
File metadata and controls
44 lines (38 loc) · 1.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from torchvision.datasets.folder import ImageFolder, default_loader, IMG_EXTENSIONS
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
class ImageFolderGrayColor(ImageFolder):
    """ImageFolder dataset that yields (input image, target image) pairs.

    Each item is built from a single image file. ``transform`` produces the
    model input and ``target_transform`` produces the target, both from the
    same loaded image. The class label that ``ImageFolder`` stores in
    ``self.samples`` is ignored.

    For colorization, pass a ``transform`` that converts to grayscale and a
    ``target_transform`` that keeps the color channels, as
    ``get_colorized_dataset_loader`` does.
    """

    def __init__(
        self,
        root,
        transform=None,
        target_transform=None,
    ):
        """Index the images found under ``root``.

        Args:
            root: Root directory, laid out as expected by ``ImageFolder``.
            transform: Callable applied to the loaded image to build the
                input, or None to use the loaded image unchanged.
            target_transform: Callable applied to the loaded image to build
                the target, or None to use the loaded image unchanged.
        """
        # Go through ImageFolder.__init__ instead of jumping straight to
        # DatasetFolder (which is what super(ImageFolder, self) did).
        # This way ImageFolder's own setup, such as the ``imgs`` alias, runs.
        # ImageFolder itself filters files by IMG_EXTENSIONS, which matches
        # the explicit arguments previously passed to DatasetFolder.
        super().__init__(
            root=root,
            transform=transform,
            target_transform=target_transform,
            loader=default_loader,
        )

    # TODO: to be modified.
    def __getitem__(self, index):
        """Return the image at ``index`` as an (input, target) pair.

        Args:
            index (int): Index into ``self.samples``.

        Returns:
            tuple: ``(sample, target)``, built from the same loaded image.

            - ``sample`` is the image passed through ``self.transform``.
            - ``target`` is the image passed through ``self.target_transform``.

            Either element is the untransformed loaded image when the
            corresponding transform is None. The class index from
            ``self.samples`` is discarded.
        """
        path, _ = self.samples[index]
        image = self.loader(path)
        # Without a target_transform the target is the loaded image itself.
        # Previously ``target`` was left unbound here (UnboundLocalError).
        # Both transforms receive the same loaded object.
        # NOTE(review): this assumes neither transform mutates it in place.
        if self.target_transform is not None:
            target = self.target_transform(image)
        else:
            target = image
        if self.transform is not None:
            sample = self.transform(image)
        else:
            sample = image
        return sample, target
def get_colorized_dataset_loader(path, **kwargs):
    """Build a DataLoader of (grayscale input, color target) image pairs.

    Every image under ``path`` is resized to 224x224.

    - The input is converted to a single grayscale channel, turned into a
      tensor and normalized with mean 0.5 / std 0.5.
    - The target is only resized and turned into a tensor.

    Args:
        path: Root directory, laid out as expected by torchvision's
            ``ImageFolder``.
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``
            (e.g. ``batch_size``, ``shuffle``, ``num_workers``).

    Returns:
        A ``DataLoader`` over an ``ImageFolderGrayColor`` dataset.
    """
    size = (224, 224)

    gray_input = transforms.Compose([
        transforms.Resize(size),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])
    color_target = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
    ])

    dataset = ImageFolderGrayColor(
        path,
        transform=gray_input,
        target_transform=color_target,
    )
    return DataLoader(dataset, **kwargs)