From df64316be48d0f2e49fb0a6f967111f704cee8d8 Mon Sep 17 00:00:00 2001 From: Faried Abu Zaid Date: Mon, 18 Aug 2025 23:47:49 +0200 Subject: [PATCH 1/3] store checkpoint at correct location --- src/usflows/explib/hyperopt.py | 98 ++++++++++++++++++---------------- 1 file changed, 51 insertions(+), 47 deletions(-) diff --git a/src/usflows/explib/hyperopt.py b/src/usflows/explib/hyperopt.py index 6fb6dac..5fba98b 100644 --- a/src/usflows/explib/hyperopt.py +++ b/src/usflows/explib/hyperopt.py @@ -130,9 +130,12 @@ def _trial(cls, config: T.Dict[str, T.Any], device: torch.device = None) -> Dict best_loss = val_loss # Create checkpoint - - torch.save(flow.state_dict(), f"./checkpoint.pt") - + wdr = os.getcwd() + wdr_split = wdr.split("/") + expdir = [d for d in wdr_split if d.startswith("_trial_")][0] + trialdir = wdr_split[-1] + torch.save(flow.state_dict(), f"{config['storage_path']}/{expdir}/{trialdir}/checkpoint.pt") + # Advanced logging try: cfg_log = config["logging"] @@ -184,55 +187,56 @@ def conduct(self, report_dir: os.PathLike, storage_path: os.PathLike = None): if self.skip: return + + #ray.init() + if storage_path is None: - storage_path = os.path.expanduser("~") + storage_path = os.path.expanduser("~/ray_results") + + runcfg = RunConfig(storage_path=storage_path) + runcfg.local_dir = f"{storage_path}/local/" + tuner_config = {"run_config": runcfg} self.temp_dir = os.path.join(storage_path, "temp") ray.init(_temp_dir=f"{storage_path}/temp/") - #ray.init() - - if storage_path is not None: - runcfg = RunConfig(storage_path=storage_path) - runcfg.local_dir = f"{storage_path}/local/" - tuner_config = {"run_config": runcfg} - else: - storage_path = os.path.expanduser("~/ray_results") - tuner_config = {} - - exptime = str(datetime.now()) - tuner = tune.Tuner( - tune.with_resources( - tune.with_parameters(HyperoptExperiment._trial), - resources={"cpu": self.cpus_per_trial, "gpu": self.gpus_per_trial}, - ), - tune_config=tune.TuneConfig( - scheduler=self.scheduler, - #search_alg=search_alg, - num_samples=self.num_hyperopt_samples, - **(self.tuner_params), - ), - param_space=self.trial_config, - **(tuner_config), - ) - results = tuner.fit() - - # TODO: hacky way to determine the last experiment - exppath = ( - storage_path - + [ - "/" + f - for f in sorted(os.listdir(storage_path)) - if f.startswith("_trial") - ][-1] - ) - report_file = os.path.join( - report_dir, f"report_{self.name}_" + exptime + ".csv" - ) - results = self._build_report(exppath, report_file=report_file, config_prefix="param_") - best_result = results.iloc[results["val_loss_best"].argmin()].copy() + self.trial_config["storage_path"] = storage_path + + try: + exptime = str(datetime.now()) + tuner = tune.Tuner( + tune.with_resources( + tune.with_parameters(HyperoptExperiment._trial), + resources={"cpu": self.cpus_per_trial, "gpu": self.gpus_per_trial}, + ), + tune_config=tune.TuneConfig( + scheduler=self.scheduler, + #search_alg=search_alg, + num_samples=self.num_hyperopt_samples, + **(self.tuner_params), + ), + param_space=self.trial_config, + **(tuner_config), + ) + results = tuner.fit() + + # TODO: hacky way to determine the last experiment + exppath = ( + storage_path + + [ + "/" + f + for f in sorted(os.listdir(storage_path)) + if f.startswith("_trial") + ][-1] + ) + report_file = os.path.join( + report_dir, f"report_{self.name}_" + exptime + ".csv" + ) + results = self._build_report(exppath, report_file=report_file, config_prefix="param_") + best_result = results.iloc[results["val_loss_best"].argmin()].copy() - self._test_best_model(best_result, exppath, report_dir, device=self.device, exp_id=exptime) - ray.shutdown() + self._test_best_model(best_result, exppath, report_dir, device=self.device, exp_id=exptime) + finally: + ray.shutdown() def _test_best_model(self, best_result: pd.Series, expdir: str, report_dir: str, device: torch.device = "cpu", exp_id: str = "foo" ) -> pd.Series: trial_id = best_result.trial_id From 1bc37798a8786103ae337aff9b7dfb9a4cc685c7 Mon Sep 17 00:00:00 2001 From: Faried Abu Zaid Date: Mon, 25 Aug 2025 13:33:58 +0200 Subject: [PATCH 2/3] add MVTec dataset --- src/usflows/explib/datasets.py | 264 ++++++++++++++++++++++++++++++++- src/usflows/explib/eval.py | 16 ++ 2 files changed, 279 insertions(+), 1 deletion(-) diff --git a/src/usflows/explib/datasets.py b/src/usflows/explib/datasets.py index 3da1440..53f2e59 100644 --- a/src/usflows/explib/datasets.py +++ b/src/usflows/explib/datasets.py @@ -10,6 +10,7 @@ from sklearn.datasets import make_blobs, make_checkerboard, make_circles, make_moons from torch import Tensor from torchvision.datasets import MNIST, FashionMNIST, CIFAR10 +from PIL import Image # Base dataset classes @@ -596,4 +597,265 @@ def __init__( train = DistributionDataset(distribution, num_train, device) val = DistributionDataset(distribution, num_val, device) test = DistributionDataset(distribution, num_test, device) - super().__init__(train, test, val) \ No newline at end of file + super().__init__(train, test, val) + +# MVTec AD Dataset +class MVTecADDequantized(DequantizedDataset): + def __init__( + self, + dataloc: os.PathLike = None, + train: bool = True, + category: str = "bottle", + is_anomaly: bool = False, + device: torch.device = None, + space_to_depth_factor: int = 1, + download: bool = True, + *args, + **kwargs, + ): + """ + MVTec AD dataset for anomaly detection. + + Args: + dataloc: Path to the dataset directory + train: Whether to load training or test data + category: Category/class of objects (e.g., 'bottle', 'cable', 'capsule') + is_anomaly: Whether to load anomalous samples (only applicable for test set) + device: Device to store data on + space_to_depth_factor: Factor for space-to-depth transformation + download: Whether to download the dataset if not found + """ + if dataloc is None: + dataloc = os.path.join(os.getcwd(), "data", "mvtec_ad") + + # Create directory if it doesn't exist + os.makedirs(dataloc, exist_ok=True) + + # Check if dataset exists, download if needed + category_path = os.path.join(dataloc, category) + if not os.path.exists(category_path) and download: + self._download_mvtec_ad(dataloc, category) + + # Define paths + split = "train" if train else "test" + base_path = os.path.join(dataloc, category, split) + + if not os.path.exists(base_path): + raise RuntimeError(f"MVTec AD dataset not found at {base_path}. " + "Set download=True to download it automatically.") + + if train or not is_anomaly: + # For training or normal test samples + img_dir = os.path.join(base_path, "good") + + # Check if the good directory exists + if not os.path.exists(img_dir): + raise RuntimeError(f"Good images directory not found at {img_dir}. " + "The dataset structure might be incorrect.") + + img_paths = [os.path.join(img_dir, f) + for f in os.listdir(img_dir) + if f.endswith(('.png', '.jpg', '.jpeg'))] + else: + # For anomalous test samples, we need to handle multiple anomaly types + anomaly_dirs = [d for d in os.listdir(base_path) + if os.path.isdir(os.path.join(base_path, d)) and d != "good"] + + # Check if there are any anomaly directories + if not anomaly_dirs: + raise RuntimeError(f"No anomaly directories found in {base_path}.") + + img_paths = [] + for anomaly_type in anomaly_dirs: + anomaly_dir = os.path.join(base_path, anomaly_type) + img_paths.extend([os.path.join(anomaly_dir, f) + for f in os.listdir(anomaly_dir) + if f.endswith(('.png', '.jpg', '.jpeg'))]) + + # Check if we found any images + if not img_paths: + raise RuntimeError(f"No images found in {base_path}.") + + # Read and preprocess images + images = [] + transform_to_tensor = transforms.ToTensor() + + for img_path in img_paths: + img = Image.open(img_path).convert('RGB') + img_tensor = transform_to_tensor(img) + images.append(img_tensor) + + dataset = torch.stack(images) * 255 # Scale to [0, 255] + dataset = dataset.to(torch.uint8) + + # Create labels: 0 for normal, 1 for anomalous + labels = torch.zeros(len(dataset)) if not is_anomaly else torch.ones(len(dataset)) + + super().__init__( + dataset, + num_bits=8, + space_to_depth_factor=space_to_depth_factor, + device=device, + *args, + **kwargs + ) + self.labels = labels + + def _download_mvtec_ad(self, dataloc: os.PathLike, category: str): + """ + Download the MVTec AD dataset if it doesn't exist, handling XZ compression. + """ + import requests + import tarfile + import lzma + from tqdm import tqdm + import shutil + + # Official MVTec AD download URL (this might need to be updated) + base_url = "https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420938113-1629952094" + url = f"{base_url}/{category}.xz" # Note the .xz extension + + # Download to a temporary file + temp_path = os.path.join(dataloc, f"{category}.tar.xz") + + print(f"Downloading MVTec AD {category} dataset (XZ compressed)...") + + try: + # Stream the download with timeout and headers + headers = { + 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36' + } + + response = requests.get(url, stream=True, headers=headers, timeout=30) + response.raise_for_status() # Check for HTTP errors + + total_size = int(response.headers.get('content-length', 0)) + + with open(temp_path, 'wb') as f, tqdm( + desc=f"Downloading {category}", + total=total_size, + unit='iB', + unit_scale=True, + unit_divisor=1024, + ) as pbar: + for data in response.iter_content(chunk_size=8192): + size = f.write(data) + pbar.update(size) + + # Extract the XZ-compressed tar archive + print(f"Extracting {category} dataset...") + + # Method 1: Using tarfile with xz compression (Python 3.3+) + try: + with tarfile.open(temp_path, 'r:xz') as tar: + tar.extractall(dataloc) + except Exception as e: + print(f"Tar extraction failed: {e}, trying manual xz extraction...") + + # Method 2: Manual extraction as fallback + # First decompress xz, then extract tar + decompressed_path = temp_path.replace('.xz', '') + + # Decompress xz + with lzma.open(temp_path) as compressed: + with open(decompressed_path, 'wb') as decompressed: + shutil.copyfileobj(compressed, decompressed) + + # Extract tar + with tarfile.open(decompressed_path, 'r:') as tar: + tar.extractall(dataloc) + + # Clean up intermediate file + os.remove(decompressed_path) + + print(f"Download and extraction of {category} complete.") + + except requests.exceptions.RequestException as e: + print(f"Download failed: {e}") + raise RuntimeError(f"Failed to download MVTec AD dataset. Please download manually from https://www.mvtec.com/company/research/datasets/mvtec-ad") + except Exception as e: + print(f"Failed to process MVTec AD dataset: {e}") + raise + finally: + # Clean up the temporary file if it exists + if os.path.exists(temp_path): + os.remove(temp_path) + + def __getitem__(self, index: int): + x = self.transform(self.dataset[index]) + y = self.labels[index] + return x, y + +class MVTecADSplit(DataSplit): + def __init__( + self, + dataloc: os.PathLike = None, + category: str = "bottle", + val_split: float = 0.1, + space_to_depth_factor: int = 1, + device: torch.device = None, + download: bool = True, + ): + """ + Data split for MVTec AD dataset. + + Args: + dataloc: Path to the dataset directory + category: Category/class of objects + val_split: Fraction of training data to use for validation + space_to_depth_factor: Factor for space-to-depth transformation + device: Device to store data on + download: Whether to download the dataset if not found + """ + if dataloc is None: + dataloc = os.path.join(os.getcwd(), "data", "mvtec_ad") + + # Training data (only normal samples) + self.train = MVTecADDequantized( + dataloc=dataloc, + train=True, + category=category, + is_anomaly=False, + space_to_depth_factor=space_to_depth_factor, + device=device, + download=download + ) + + # Split training data into train and validation + shuffle = torch.randperm(len(self.train)) + val_size = int(len(self.train) * val_split) + self.val = torch.utils.data.Subset(self.train, shuffle[:val_size]) + self.train = torch.utils.data.Subset(self.train, shuffle[val_size:]) + + # Test data (both normal and anomalous samples) + test_normal = MVTecADDequantized( + dataloc=dataloc, + train=False, + category=category, + is_anomaly=False, + space_to_depth_factor=space_to_depth_factor, + device=device, + download=download + ) + + test_anomaly = MVTecADDequantized( + dataloc=dataloc, + train=False, + category=category, + is_anomaly=True, + space_to_depth_factor=space_to_depth_factor, + device=device, + download=download + ) + + # Combine normal and anomalous test samples + self.test = torch.utils.data.ConcatDataset([test_normal, test_anomaly]) + + def get_train(self) -> torch.utils.data.Dataset: + return self.train + + def get_test(self) -> torch.utils.data.Dataset: + return self.test + + def get_val(self) -> torch.utils.data.Dataset: + return self.val \ No newline at end of file diff --git a/src/usflows/explib/eval.py b/src/usflows/explib/eval.py index 03af51f..9e73829 100644 --- a/src/usflows/explib/eval.py +++ b/src/usflows/explib/eval.py @@ -557,12 +557,22 @@ def nll_norm_scatter_plot(self, ref_distribution, ax=None, n_samples=10000): nlls = -ref_distribution.log_prob(base_samples).cpu().numpy() latent_norms = (self.flow.backward(base_samples) - self.loc).norm(p=self.p, dim=1).cpu().numpy() + # Compute Pearson correlation + pearson_r, _ = stats.pearsonr(nlls, latent_norms) + spearman_rho, _ = stats.spearmanr(nlls, latent_norms) + kendall_tau, _ = stats.kendalltau(nlls, latent_norms) + + + # Scatter plot ax.scatter(nlls, latent_norms, alpha=0.5) ax.set_xlabel("Negative Log-Likelihood") ax.set_ylabel("Latent Norm") ax.set_title("Negative Log-Likelihood vs Latent Norm") + ax.text(0.6, 0.05, f"Pearson R: {pearson_r:.2f}\nSpearman Rho: {spearman_rho:.2f}\nKendall Tau: {kendall_tau:.2f}", + transform=ax.transAxes, bbox=dict(facecolor='white', alpha=0.5)) + return ax def logprob_reference_scatter_plot(self, ref_distribution, ax=None, n_samples=10000): @@ -596,10 +606,16 @@ def logprob_reference_scatter_plot(self, ref_distribution, ax=None, n_samples=10 ref_log_probs = ref_distribution.log_prob(base_samples).cpu().numpy() learned_log_probs = self.flow.log_prob(base_samples).cpu().numpy() + pearson_r, _ = stats.pearsonr(ref_log_probs, learned_log_probs) + spearman_rho, _ = stats.spearmanr(ref_log_probs, learned_log_probs) + kendall_tau, _ = stats.kendalltau(ref_log_probs, learned_log_probs) + # Scatter plot ax.scatter(ref_log_probs, learned_log_probs, alpha=0.5) ax.set_xlabel("Reference Log-Probability") ax.set_ylabel("Estimated Log-Probability") ax.set_title("Log-Probability Comparison") + ax.text(0.6, 0.05, f"Pearson R: {pearson_r:.2f}\nSpearman Rho: {spearman_rho:.2f}\nKendall Tau: {kendall_tau:.2f}", + transform=ax.transAxes, bbox=dict(facecolor='white', alpha=0.5)) return ax \ No newline at end of file From 9cb7dfd573d6ab8af311508b8f7a298735c9a92e Mon Sep 17 00:00:00 2001 From: Faried Abu Zaid Date: Wed, 27 Aug 2025 15:59:26 +0200 Subject: [PATCH 3/3] Load MVTec into memory optional --- src/usflows/explib/config_parser.py | 5 +- src/usflows/explib/datasets.py | 172 +++++++++++++++++++--------- 2 files changed, 122 insertions(+), 55 deletions(-) diff --git a/src/usflows/explib/config_parser.py b/src/usflows/explib/config_parser.py index 68132e3..e36196c 100644 --- a/src/usflows/explib/config_parser.py +++ b/src/usflows/explib/config_parser.py @@ -193,7 +193,10 @@ def parse_raw_config(d: dict) -> Any: C = getattr(import_module(module), cls) d.pop("__object__") d = parse_raw_config(d) - return C(**d) + try: + return C(**d) + except TypeError as e: + raise ValueError(f"Error while instantiating {C} with {d}: {e}") elif "__eval__" in d: return eval(d["__eval__"]) elif "__class__" in d: diff --git a/src/usflows/explib/datasets.py b/src/usflows/explib/datasets.py index 53f2e59..0dcad81 100644 --- a/src/usflows/explib/datasets.py +++ b/src/usflows/explib/datasets.py @@ -260,6 +260,7 @@ def __init__( train: bool = True, label: T.Optional[int] = None, scale: bool = False, + three_channel: bool = False, # NEW *args, **kwargs, ): @@ -275,6 +276,10 @@ def __init__( dataset = idx2numpy.convert_from_file(path) if scale: dataset = dataset[:, ::3, ::3] + # If requested, convert grayscale -> 3-channel by repeating channel + if three_channel: + # dataset shape: (N, H, W) -> (N, 1, H, W) then repeat to 3 channels + dataset = np.repeat(dataset[:, None, :, :], 3, axis=1) #dataset = dataset.reshape(dataset.shape[0], -1) if label is not None: rel_path = ( @@ -303,11 +308,18 @@ def __init__( val_split: float = 0.1, label: T.Optional[int] = None, space_to_depth_factor: int = 1, + three_channel: bool = False, # NEW ): if dataloc is None: dataloc = os.path.join(os.getcwd(), "data") self.dataloc = dataloc - self.train = FashionMnistDequantized(self.dataloc, train=True, label=label, space_to_depth_factor=space_to_depth_factor) + self.train = FashionMnistDequantized( + self.dataloc, + train=True, + label=label, + space_to_depth_factor=space_to_depth_factor, + three_channel=three_channel, # PASS THROUGH + ) shuffle = torch.randperm(len(self.train)) self.val = torch.utils.data.Subset( self.train, shuffle[: int(len(self.train) * val_split)] @@ -315,7 +327,7 @@ def __init__( self.train = torch.utils.data.Subset( self.train, shuffle[int(len(self.train) * val_split) :] ) - self.test = FashionMnistDequantized(self.dataloc, train=False, label=label, space_to_depth_factor=space_to_depth_factor) + self.test = FashionMnistDequantized(self.dataloc, train=False, label=label, space_to_depth_factor=space_to_depth_factor, three_channel=three_channel) def get_train(self) -> torch.utils.data.Dataset: return self.train @@ -337,7 +349,8 @@ def __init__( flatten=False, scale: bool = False, device: torch.device = None, - space_to_depth_factor: int = 1 + space_to_depth_factor: int = 1, + three_channel: bool = False, # NEW ): if train: rel_path = "MNIST/raw/train-images-idx3-ubyte" @@ -352,6 +365,9 @@ def __init__( dataset = dataset[:, ::3, ::3] if flatten: dataset = dataset.reshape(dataset.shape[0], -1) + # Convert to 3-channel if requested + if three_channel and dataset.ndim == 3: + dataset = np.repeat(dataset[:, None, :, :], 3, axis=1) if digit is not None: if train: rel_path = "MNIST/raw/train-labels-idx1-ubyte" @@ -384,7 +400,8 @@ def __init__( digit: T.Optional[int] = None, scale: bool = False, device: torch.device = None, - space_to_depth_factor: int = 1 + space_to_depth_factor: int = 1, + three_channel: bool = False, # NEW ): if dataloc is None: dataloc = os.path.join(os.getcwd(), "data") @@ -395,7 +412,8 @@ def __init__( digit=digit, scale=scale, space_to_depth_factor=space_to_depth_factor, - device=device + device=device, + three_channel=three_channel, # PASS THROUGH ) shuffle = torch.randperm(len(self.train)) self.val = torch.utils.data.Subset( @@ -410,7 +428,8 @@ def __init__( digit=digit, scale=scale, space_to_depth_factor=space_to_depth_factor, - device=device + device=device, + three_channel=three_channel, # PASS THROUGH ) def get_train(self) -> torch.utils.data.Dataset: @@ -610,6 +629,7 @@ def __init__( device: torch.device = None, space_to_depth_factor: int = 1, download: bool = True, + load_into_memory: bool = False, *args, **kwargs, ): @@ -627,79 +647,84 @@ def __init__( """ if dataloc is None: dataloc = os.path.join(os.getcwd(), "data", "mvtec_ad") - + # Create directory if it doesn't exist os.makedirs(dataloc, exist_ok=True) - + # Check if dataset exists, download if needed category_path = os.path.join(dataloc, category) if not os.path.exists(category_path) and download: self._download_mvtec_ad(dataloc, category) - + # Define paths split = "train" if train else "test" base_path = os.path.join(dataloc, category, split) - + if not os.path.exists(base_path): raise RuntimeError(f"MVTec AD dataset not found at {base_path}. " "Set download=True to download it automatically.") - + + # Gather image paths depending on mode if train or not is_anomaly: # For training or normal test samples img_dir = os.path.join(base_path, "good") - - # Check if the good directory exists if not os.path.exists(img_dir): raise RuntimeError(f"Good images directory not found at {img_dir}. " - "The dataset structure might be incorrect.") - - img_paths = [os.path.join(img_dir, f) - for f in os.listdir(img_dir) - if f.endswith(('.png', '.jpg', '.jpeg'))] + "The dataset structure might be incorrect.") + img_paths = [os.path.join(img_dir, f) + for f in sorted(os.listdir(img_dir)) + if f.lower().endswith(('.png', '.jpg', '.jpeg'))] + labels = [0] * len(img_paths) else: - # For anomalous test samples, we need to handle multiple anomaly types - anomaly_dirs = [d for d in os.listdir(base_path) - if os.path.isdir(os.path.join(base_path, d)) and d != "good"] - - # Check if there are any anomaly directories + # For anomalous test samples, handle multiple anomaly types + anomaly_dirs = [d for d in sorted(os.listdir(base_path)) + if os.path.isdir(os.path.join(base_path, d)) and d != "good"] if not anomaly_dirs: raise RuntimeError(f"No anomaly directories found in {base_path}.") - img_paths = [] + labels = [] for anomaly_type in anomaly_dirs: anomaly_dir = os.path.join(base_path, anomaly_type) - img_paths.extend([os.path.join(anomaly_dir, f) - for f in os.listdir(anomaly_dir) - if f.endswith(('.png', '.jpg', '.jpeg'))]) - - # Check if we found any images + files = [os.path.join(anomaly_dir, f) + for f in sorted(os.listdir(anomaly_dir)) + if f.lower().endswith(('.png', '.jpg', '.jpeg'))] + img_paths.extend(files) + labels.extend([1] * len(files)) + + # Ensure we found images if not img_paths: raise RuntimeError(f"No images found in {base_path}.") - - # Read and preprocess images - images = [] + + # If requested, load everything into memory (backward compatible) transform_to_tensor = transforms.ToTensor() - - for img_path in img_paths: - img = Image.open(img_path).convert('RGB') - img_tensor = transform_to_tensor(img) - images.append(img_tensor) - - dataset = torch.stack(images) * 255 # Scale to [0, 255] - dataset = dataset.to(torch.uint8) - - # Create labels: 0 for normal, 1 for anomalous - labels = torch.zeros(len(dataset)) if not is_anomaly else torch.ones(len(dataset)) - - super().__init__( - dataset, - num_bits=8, - space_to_depth_factor=space_to_depth_factor, - device=device, - *args, - **kwargs + self.device = device + self.space_to_depth_factor = space_to_depth_factor + self.num_bits = 8 + self.num_levels = 2 ** self.num_bits + self.transform = transforms.Compose( + [ + transforms.Lambda(lambda x: x / self.num_levels), + transforms.Lambda(lambda x: x + torch.rand_like(x) / self.num_levels), + ] ) - self.labels = labels + + self._load_into_memory = bool(load_into_memory) + + if self._load_into_memory: + images = [] + for img_path in img_paths: + img = Image.open(img_path).convert('RGB') + img_tensor = transform_to_tensor(img) * 255.0 + images.append(img_tensor) + dataset = torch.stack(images) + # keep dtype float for safe operations; DequantizedDataset handled uint8 but transform works on float + super().__init__(dataset, num_bits=self.num_bits, space_to_depth_factor=space_to_depth_factor, device=device, *args, **kwargs) + self.labels = torch.tensor(labels, dtype=torch.long) + else: + # Store paths and labels and avoid loading images until requested + self.img_paths = img_paths + self.labels = torch.tensor(labels, dtype=torch.long) + # don't call parent constructor to avoid converting full dataset to tensor def _download_mvtec_ad(self, dataloc: os.PathLike, category: str): """ @@ -782,10 +807,47 @@ def _download_mvtec_ad(self, dataloc: os.PathLike, category: str): os.remove(temp_path) def __getitem__(self, index: int): - x = self.transform(self.dataset[index]) + # Support both loaded-into-memory and on-demand modes + if getattr(self, "_load_into_memory", False): + x = self.dataset[index] + else: + # Load image on demand + img_path = self.img_paths[index] + img = Image.open(img_path).convert('RGB') + img_tensor = transforms.ToTensor()(img) * 255.0 + x = img_tensor + + # Apply space-to-depth if needed + if getattr(self, "space_to_depth_factor", 1) > 1: + x = self._apply_space_to_depth(x) + + # Move to device if set + if getattr(self, "device", None) is not None: + x = x.to(self.device) + + # Apply dequantization transform (returns float) + x = self.transform(x) y = self.labels[index] return x, y + def _apply_space_to_depth(self, x: Tensor) -> Tensor: + """Apply space-to-depth on a single image tensor of shape (C, H, W).""" + f = int(self.space_to_depth_factor) + if f <= 1: + return x + # Ensure dimensions are compatible + c, h, w = x.shape + if h % f != 0 or w % f != 0: + raise RuntimeError(f"Image size ({h},{w}) not divisible by space_to_depth_factor {f}") + # Reshape and permute: (C, H, W) -> (C, H//f, f, W//f, f) -> (C, f, f, H//f, W//f) -> (C*f*f, H//f, W//f) + x = x.reshape(c, h // f, f, w // f, f).permute(0, 2, 4, 1, 3).reshape(c * f * f, h // f, w // f) + return x + + def __len__(self): + if getattr(self, "_load_into_memory", False): + return len(self.dataset) + return len(self.img_paths) + class MVTecADSplit(DataSplit): def __init__( self, @@ -795,6 +857,7 @@ def __init__( space_to_depth_factor: int = 1, device: torch.device = None, download: bool = True, + load_into_memory: bool = False, ): """ Data split for MVTec AD dataset. @@ -818,7 +881,8 @@ def __init__( is_anomaly=False, space_to_depth_factor=space_to_depth_factor, device=device, - download=download + download=download, + load_into_memory=load_into_memory ) # Split training data into train and validation