Skip to content

Commit 8a2ffcf

Browse files
Add optional Gaussian noise augmentation (`noise`, `noise_std`) to the dataset and interface
1 parent 140e47e commit 8a2ffcf

2 files changed

Lines changed: 32 additions & 7 deletions

File tree

interface.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,9 @@ def __init__(
8989
load_model_path: str = None,
9090
custom_augment_figuratif=None,
9191
custom_augment_abstrait=None,
92-
n_transforms_augmented = 2
92+
n_transforms_augmented = 2,
93+
noise: bool = False,
94+
noise_std: float = 0.1
9395
):
9496

9597
# Profiler
@@ -151,6 +153,8 @@ def __init__(
151153
self.augmentation = augmentation
152154
self.input_size = input_size
153155
self.padding = padding
156+
self.noise = noise
157+
self.noise_std = noise_std
154158

155159
# Initialize the dataset and dataloader
156160

@@ -161,23 +165,29 @@ def __init__(
161165
image_input_size=self.input_size,
162166
custom_augment_abstrait=custom_augment_abstrait,
163167
custom_augment_figuratif=custom_augment_figuratif,
164-
n_transforms_augmented=n_transforms_augmented)
165-
168+
n_transforms_augmented=n_transforms_augmented,
169+
noise=self.noise,
170+
noise_std=self.noise_std)
171+
166172
self.dataset_val = PaintingsDataset(self.data_path+'val/',
167173
augment=False,
168174
transform=self.transform,
169175
padding=self.padding,
170176
image_input_size=self.input_size,
171177
custom_augment_abstrait=None,
172-
custom_augment_figuratif=None,)
173-
178+
custom_augment_figuratif=None,
179+
noise=self.noise,
180+
noise_std=self.noise_std)
181+
174182
self.dataset_test = PaintingsDataset(self.data_path+'test/',
175183
augment=False,
176184
transform=self.transform,
177185
padding=self.padding,
178186
image_input_size=self.input_size,
179187
custom_augment_abstrait=None,
180-
custom_augment_figuratif=None,)
188+
custom_augment_figuratif=None,
189+
noise=self.noise,
190+
noise_std=self.noise_std)
181191

182192
n_figurative = self.dataset_train.len_figurative #only for training
183193
n_abstract = self.dataset_train.len_abstract

paintings_dataset.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ class PaintingsDataset(Dataset):
7373
In case of `transform` being True, the dataset will apply transformations to the images or
7474
a custom transformation can be passed as a parameter named `custom_transform` (if None, the built-in custom transform is used).
7575
"""
76-
def __init__(self, data_path, augment= False, transform=False, custom_augment_figuratif=None,custom_augment_abstrait=None, padding: PaddingOptions = PaddingOptions.ZERO, image_input_size: int = 224,n_transforms_augmented=2):
76+
def __init__(self, data_path, augment= False, transform=False, custom_augment_figuratif=None,custom_augment_abstrait=None, padding: PaddingOptions = PaddingOptions.ZERO, image_input_size: int = 224,n_transforms_augmented=2, noise: bool = False, noise_std: float = 0.1):
7777

7878
# Path to the data directory
7979
self.data_path = data_path
@@ -106,6 +106,10 @@ def __init__(self, data_path, augment= False, transform=False, custom_augment_fi
106106
self.padding = padding
107107
self.image_input_size = image_input_size
108108

109+
# Noise configuration
110+
self.noise = noise
111+
self.noise_std = noise_std
112+
109113

110114
def __len__(self):
111115
return self.total_length
@@ -176,6 +180,17 @@ def __getitem__(self, idx: int)->Dict[str, Union[torch.Tensor, int]]:
176180
size=(max(imh, int(self.image_input_size / 2 + 1)), max(imw, int(self.image_input_size/2 + 1))),
177181
mode='nearest').squeeze(0).clone()
178182
output['image'] = output['image'].float() / 255.0 # Normalize the image to [0, 1]
183+
184+
# Apply additive Gaussian noise to the image (if enabled)
185+
if self.noise:
186+
noise = torch.randn_like(output['image']) * self.noise_std
187+
output['image'] = torch.clamp(output['image'] + noise, 0.0, 1.0)
188+
189+
# Apply noise to transformed image (if it exists and noise enabled)
190+
if self.noise and self.transform:
191+
noise_transformed = torch.randn_like(output['transformed_image']) * self.noise_std
192+
output['transformed_image'] = torch.clamp(output['transformed_image'] + noise_transformed, 0.0, 1.0)
193+
179194
return output
180195

181196
def change_padding(self, padding: Literal['zero', 'mirror', 'replicate'], image_input_size: int = 224):

0 commit comments

Comments (0)