From 01c2bd0b8137ebab9727795749ce7f9cfb5acc7b Mon Sep 17 00:00:00 2001 From: dhalmazna Date: Wed, 4 Mar 2026 16:30:03 +0100 Subject: [PATCH 01/13] feat: add data loading utilities for images --- .mypy.ini | 3 ++- ciao/data/__init__.py | 6 ++++++ ciao/data/loader.py | 41 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 ciao/data/__init__.py create mode 100644 ciao/data/loader.py diff --git a/.mypy.ini b/.mypy.ini index 1c6f44e..8d15f3c 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -2,4 +2,5 @@ strict = True ignore_missing_imports = True disallow_untyped_calls = False -disable_error_code = no-any-return \ No newline at end of file +disable_error_code = no-any-return +explicit_package_bases = True \ No newline at end of file diff --git a/ciao/data/__init__.py b/ciao/data/__init__.py new file mode 100644 index 0000000..0771aa1 --- /dev/null +++ b/ciao/data/__init__.py @@ -0,0 +1,6 @@ +"""Data loading utilities for CIAO.""" + +from ciao.data.loader import get_image_loader + + +__all__ = ["get_image_loader"] diff --git a/ciao/data/loader.py b/ciao/data/loader.py new file mode 100644 index 0000000..ec2d331 --- /dev/null +++ b/ciao/data/loader.py @@ -0,0 +1,41 @@ +"""Simple image path loading utilities.""" + +from collections.abc import Iterator +from pathlib import Path +from typing import Any + + +# Supported image formats +IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".webp") + + +def get_image_loader(config: Any) -> Iterator[Path]: + """Create image loader based on configuration. + + Args: + config: Hydra config object + + Returns: + Iterator of Path objects + + Raises: + ValueError: If neither image_path nor batch_path is specified + """ + if config.data.get("image_path"): + # Single image mode + yield Path(config.data.image_path) + + elif config.data.get("batch_path"): + # Directory mode - find all images with supported extensions + directory = Path(config.data.batch_path) + if not directory.is_dir(): + raise ValueError( + f"batch_path must be a valid directory, got: {directory}. " + "Check for typos or incorrect path configuration." + ) + + for ext in IMAGE_EXTENSIONS: + yield from directory.glob(f"**/*{ext}") + + else: + raise ValueError("Must specify either image_path or batch_path in config") From 2643f12fe126d6c533496088516e832ef10b0280 Mon Sep 17 00:00:00 2001 From: dhalmazna Date: Wed, 4 Mar 2026 16:35:38 +0100 Subject: [PATCH 02/13] feat: add image preprocessing function --- ciao/data/__init__.py | 3 ++- ciao/data/preprocessing.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 ciao/data/preprocessing.py diff --git a/ciao/data/__init__.py b/ciao/data/__init__.py index 0771aa1..2a2ab3c 100644 --- a/ciao/data/__init__.py +++ b/ciao/data/__init__.py @@ -1,6 +1,7 @@ """Data loading utilities for CIAO.""" from ciao.data.loader import get_image_loader +from ciao.data.preprocessing import load_and_preprocess_image -__all__ = ["get_image_loader"] +__all__ = ["get_image_loader", "load_and_preprocess_image"] diff --git a/ciao/data/preprocessing.py b/ciao/data/preprocessing.py new file mode 100644 index 0000000..997e76d --- /dev/null +++ b/ciao/data/preprocessing.py @@ -0,0 +1,37 @@ +from pathlib import Path + +import torch +import torchvision.transforms as transforms +from PIL import Image + + +# ImageNet preprocessing transforms +preprocess = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] +) + + +def load_and_preprocess_image( + image_path: str | Path, device: torch.device | None = None +) -> torch.Tensor: + """Load and preprocess an image for the model. + + Args: + image_path: Path to image file + device: Device to place tensor on (defaults to cuda if available, else cpu) + + Returns: + Preprocessed image tensor [3, 224, 224] on specified device + """ + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + image = Image.open(image_path).convert("RGB") + input_tensor = preprocess(image).to(device) # (3, 224, 224) - on correct device + + return input_tensor From 00fd45292c24d6c9d46b4b5c7ee0b37d60827e96 Mon Sep 17 00:00:00 2001 From: dhalmazna Date: Wed, 4 Mar 2026 17:22:19 +0100 Subject: [PATCH 03/13] feat: add segmentation functionality for image processing --- ciao/data/__init__.py | 3 +- ciao/data/segmentation.py | 284 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 286 insertions(+), 1 deletion(-) create mode 100644 ciao/data/segmentation.py diff --git a/ciao/data/__init__.py b/ciao/data/__init__.py index 2a2ab3c..77cc547 100644 --- a/ciao/data/__init__.py +++ b/ciao/data/__init__.py @@ -2,6 +2,7 @@ from ciao.data.loader import get_image_loader from ciao.data.preprocessing import load_and_preprocess_image +from ciao.data.segmentation import create_segmentation -__all__ = ["get_image_loader", "load_and_preprocess_image"] +__all__ = ["create_segmentation", "get_image_loader", "load_and_preprocess_image"] diff --git a/ciao/data/segmentation.py b/ciao/data/segmentation.py new file mode 100644 index 0000000..0570c1d --- /dev/null +++ b/ciao/data/segmentation.py @@ -0,0 +1,284 @@ +import math + +import numpy as np +import torch + + +def _hex_round_vectorized( + q: np.ndarray, r: np.ndarray +) -> tuple[np.ndarray, np.ndarray]: + """Vectorized hex rounding for entire arrays of axial coordinates. + + Args: + q: Array of fractional axial q coordinates + r: Array of fractional axial r coordinates + + Returns: + (q_rounded, r_rounded): Integer axial coordinates of nearest hex + """ + x = q + z = r + y = -x - z + + rx = np.round(x) + ry = np.round(y) + rz = np.round(z) + + dx = np.abs(rx - x) + dy = np.abs(ry - y) + dz = np.abs(rz - z) + + # Vectorized conditional logic + cond1 = (dx > dy) & (dx > dz) + cond2 = dy > dz + + rx = np.where(cond1, -ry - rz, rx) + ry = np.where(~cond1 & cond2, -rx - rz, ry) + rz = np.where(~cond1 & ~cond2, -rx - ry, rz) + + return rx.astype(np.int32), rz.astype(np.int32) + + +def bitmasks_to_adjacency_list( + adj_masks: tuple[int, ...], +) -> tuple[tuple[int, ...], ...]: + """Convert adjacency bitmasks back to adjacency list format. + + Args: + adj_masks: Tuple of integer bitmasks where bit i indicates neighbor i + + Returns: + Adjacency list as tuple of tuples of neighbor IDs + """ + adjacency_list = [] + for mask in adj_masks: + neighbors = [] + for i in range(len(adj_masks)): + if mask & (1 << i): + neighbors.append(i) + adjacency_list.append(tuple(neighbors)) + return tuple(adjacency_list) + + +def _build_square_adjacency_list( + segments: np.ndarray, neighborhood: int = 8 +) -> tuple[tuple[int, ...], ...]: + """Build adjacency list from segment array (for square grids) using vectorized operations. + + Args: + segments: 2D array mapping pixels to segment IDs + neighborhood: 4 or 8 connectivity + + Returns: + Adjacency list as tuple of tuples + """ + num_segments = segments.max() + 1 + adjacency_sets: list[set[int]] = [set() for _ in range(num_segments)] + + # Vectorized horizontal adjacency + left = segments[:, :-1].ravel() + right = segments[:, 1:].ravel() + mask_h = left != right + edges_h = np.column_stack([left[mask_h], right[mask_h]]) + + for seg1, seg2 in edges_h: + adjacency_sets[seg1].add(seg2) + adjacency_sets[seg2].add(seg1) + + # Vectorized vertical adjacency + top = segments[:-1, :].ravel() + bottom = segments[1:, :].ravel() + mask_v = top != bottom + edges_v = np.column_stack([top[mask_v], bottom[mask_v]]) + + for seg1, seg2 in edges_v: + adjacency_sets[seg1].add(seg2) + adjacency_sets[seg2].add(seg1) + + if neighborhood == 8: + # Vectorized diagonal adjacency (down-right) + top_left = segments[:-1, :-1].ravel() + bottom_right = segments[1:, 1:].ravel() + mask_dr = top_left != bottom_right + edges_dr = np.column_stack([top_left[mask_dr], bottom_right[mask_dr]]) + + for seg1, seg2 in edges_dr: + adjacency_sets[seg1].add(seg2) + adjacency_sets[seg2].add(seg1) + + # Vectorized diagonal adjacency (down-left) + top_right = segments[:-1, 1:].ravel() + bottom_left = segments[1:, :-1].ravel() + mask_dl = top_right != bottom_left + edges_dl = np.column_stack([top_right[mask_dl], bottom_left[mask_dl]]) + + for seg1, seg2 in edges_dl: + adjacency_sets[seg1].add(seg2) + adjacency_sets[seg2].add(seg1) + + # Convert to tuple of tuples + return tuple(tuple(sorted(neighbors)) for neighbors in adjacency_sets) + + +def _build_fast_adjacency_list( + hex_to_id: dict[tuple[int, int], int], max_id: int +) -> tuple[tuple[int, ...], ...]: + """Create a static adjacency list optimized for fast reading. + + Args: + hex_to_id: Dict mapping (q, r) -> int_id (0 to N-1) + max_id: Total number of segments (N) + + Returns: + adj_list: Tuple of Tuples. + adj_list[5] returns e.g. (4, 6, 12) - neighbors of segment 5. + """ + # Initialize empty lists for each ID + # Use list of lists for construction + temp_adj: list[list[int]] = [[] for _ in range(max_id)] + + # Offsets for neighbors (axial coords) + hex_neighbors = [(+1, 0), (+1, -1), (0, -1), (-1, 0), (-1, +1), (0, +1)] + + for (q, r), seg_id in hex_to_id.items(): + for dq, dr in hex_neighbors: + neighbor_key = (q + dq, r + dr) + + # If neighbor exists (is within the image) + if neighbor_key in hex_to_id: + neighbor_id = hex_to_id[neighbor_key] + temp_adj[seg_id].append(neighbor_id) + + # Convert to tuple of tuples for maximum read speed and memory efficiency + # Sort neighbors (optional, but good for determinism) + final_adj = tuple(tuple(sorted(neighbors)) for neighbors in temp_adj) + + return final_adj + + +def _build_adjacency_bitmasks(adj_list: tuple[tuple[int, ...], ...]) -> tuple[int, ...]: + """Convert adjacency list to a list of integers. + + adj_masks[5] will be an integer with bits set at positions of hex 5's neighbors. + """ + adj_masks = [] + for neighbors in adj_list: + mask = 0 + for n in neighbors: + mask |= 1 << n + adj_masks.append(mask) + return tuple(adj_masks) + + +def _create_square_grid( + input_tensor: torch.Tensor, square_size: int = 14, neighborhood: int = 8 +) -> tuple[np.ndarray, tuple[tuple[int, ...], ...]]: + """Create a grid of squares with adjacency list representing spatial relationships.""" + _channels, height, width = input_tensor.shape + segments = np.zeros((height, width), dtype=np.int32) + + segment_id = 0 + + # Create square grid + for row in range(0, height, square_size): + for col in range(0, width, square_size): + # Define square boundaries + row_end = min(row + square_size, height) + col_end = min(col + square_size, width) + + # Assign segment ID to all pixels in this square + segments[row:row_end, col:col_end] = segment_id + segment_id += 1 + + # Build adjacency list + adjacency_list = _build_square_adjacency_list(segments, neighborhood=neighborhood) + + return segments, adjacency_list + + +def _create_hexagonal_grid( + input_tensor: torch.Tensor, hex_radius: int = 14 +) -> tuple[np.ndarray, tuple[tuple[int, ...], ...]]: + """Create a grid of hexagons with adjacency list using vectorized operations. + + Uses axial coordinate system for precise hexagonal tiling (flat-top orientation). + Each hexagon has exactly 6 neighbors (neighborhood parameter ignored). + + Args: + input_tensor: Input image tensor [C, H, W] + hex_radius: Hex size parameter (distance from center to flat edge, default: 14) + + Returns: + segments: 2D array mapping pixels to segment IDs + adjacency_list: Tuple of tuples representing segment relationships + """ + _channels, height, width = input_tensor.shape + + # Create coordinate grids using meshgrid + x_coords, y_coords = np.meshgrid(np.arange(width), np.arange(height)) + + # Vectorized pixel to hex conversion + sqrt3 = math.sqrt(3) + q_float = (sqrt3 / 3 * x_coords - 1 / 3 * y_coords) / hex_radius + r_float = (2 / 3 * y_coords) / hex_radius + + # Vectorized hex rounding + q_int, r_int = _hex_round_vectorized(q_float, r_float) + + # Stack q and r to create unique keys + qr_stacked = np.stack([q_int.ravel(), r_int.ravel()], axis=1) + + # Use np.unique to assign segment IDs efficiently + _, segments_flat = np.unique(qr_stacked, axis=0, return_inverse=True) + segments = segments_flat.reshape((height, width)).astype(np.int32) + + # Build hex_to_id mapping for adjacency construction + unique_qr = np.unique(qr_stacked, axis=0) + hex_to_id = {(int(q), int(r)): idx for idx, (q, r) in enumerate(unique_qr)} + + # Build adjacency list using axial coordinate neighbors + adjacency_list = _build_fast_adjacency_list(hex_to_id, len(hex_to_id)) + + return segments, adjacency_list + + +def create_segmentation( + input_tensor: torch.Tensor, + segmentation_type: str = "hexagonal", + segment_size: int = 14, + neighborhood: int = 8, +) -> tuple[np.ndarray, tuple[int, ...]]: + """Create image segmentation with specified type. + + Args: + input_tensor: Input image tensor [C, H, W] + segmentation_type: "square" or "hexagonal" + segment_size: Size parameter (square_size or hex_radius) + neighborhood: Neighborhood connectivity for squares (4, or 8) + + Returns: + segments: 2D array mapping pixels to segment IDs + adj_masks: Tuple of integer bitmasks representing adjacency relationships + """ + if segment_size <= 0: + raise ValueError( + f"segment_size must be positive, got {segment_size}. " + "Non-positive values cause division by zero or invalid range operations." + ) + + if segmentation_type == "square": + segments, adjacency_list = _create_square_grid( + input_tensor, square_size=segment_size, neighborhood=neighborhood + ) + elif segmentation_type == "hexagonal": + segments, adjacency_list = _create_hexagonal_grid( + input_tensor, hex_radius=segment_size + ) + else: + raise ValueError( + f"Unknown segmentation_type: {segmentation_type}. Use 'square' or 'hexagonal'." + ) + + # Convert adjacency list to bitmasks + adj_masks = _build_adjacency_bitmasks(adjacency_list) + return segments, adj_masks From 45e4f1d4df1a9d70ab649ca762eac50d33171b49 Mon Sep 17 00:00:00 2001 From: dhalmazna Date: Wed, 4 Mar 2026 17:29:21 +0100 Subject: [PATCH 04/13] chore: remove the unused networkx dependency --- pyproject.toml | 1 - uv.lock | 2 -- 2 files changed, 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 066f668..3d9d8d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,6 @@ dependencies = [ # Scientific computing "numpy>=1.21.0", - "networkx>=2.6.0", # Others "tqdm>=4.0.0", diff --git a/uv.lock b/uv.lock index 1122328..a6ebabe 100644 --- a/uv.lock +++ b/uv.lock @@ -2407,7 +2407,6 @@ dependencies = [ { name = "ipywidgets" }, { name = "matplotlib" }, { name = "mlflow" }, - { name = "networkx" }, { name = "numpy" }, { name = "omegaconf" }, { name = "pillow" }, @@ -2434,7 +2433,6 @@ requires-dist = [ { name = "ipywidgets", specifier = ">=7.0.0" }, { name = "matplotlib", specifier = ">=3.5.0" }, { name = "mlflow", specifier = ">=3.0" }, - { name = "networkx", specifier = ">=2.6.0" }, { name = "numpy", specifier = ">=1.21.0" }, { name = "omegaconf", specifier = ">=2.3.0" }, { name = "pillow", specifier = ">=9.0.0" }, From ce834bcf511059f88d28b3aa364ed360504a0e8b Mon Sep 17 00:00:00 2001 From: dhalmazna Date: Wed, 4 Mar 2026 18:02:50 +0100 Subject: [PATCH 05/13] feat: apply agents' suggestions --- ciao/data/loader.py | 22 ++++++++++++++++------ ciao/data/preprocessing.py | 8 ++++++-- ciao/data/segmentation.py | 32 +++++++++++++++----------------- 3 files changed, 37 insertions(+), 25 deletions(-) diff --git a/ciao/data/loader.py b/ciao/data/loader.py index ec2d331..5ef110e 100644 --- a/ciao/data/loader.py +++ b/ciao/data/loader.py @@ -2,14 +2,15 @@ from collections.abc import Iterator from pathlib import Path -from typing import Any + +from omegaconf import DictConfig # Supported image formats IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".webp") -def get_image_loader(config: Any) -> Iterator[Path]: +def get_image_loader(config: DictConfig) -> Iterator[Path]: """Create image loader based on configuration. Args: @@ -20,10 +21,17 @@ def get_image_loader(config: Any) -> Iterator[Path]: Raises: ValueError: If neither image_path nor batch_path is specified + FileNotFoundError: If single image_path does not exist """ if config.data.get("image_path"): - # Single image mode - yield Path(config.data.image_path) + # Single image mode - validate file exists + image_path = Path(config.data.image_path) + if not image_path.is_file(): + raise FileNotFoundError( + f"image_path must be a valid file, got: {image_path}. " + "Check for typos or incorrect path configuration." + ) + yield image_path elif config.data.get("batch_path"): # Directory mode - find all images with supported extensions @@ -34,8 +42,10 @@ def get_image_loader(config: Any) -> Iterator[Path]: "Check for typos or incorrect path configuration." ) - for ext in IMAGE_EXTENSIONS: - yield from directory.glob(f"**/*{ext}") + # Single rglob pass with suffix filtering + for path in directory.rglob("*"): + if path.suffix.lower() in IMAGE_EXTENSIONS: + yield path else: raise ValueError("Must specify either image_path or batch_path in config") diff --git a/ciao/data/preprocessing.py b/ciao/data/preprocessing.py index 997e76d..bc85cbf 100644 --- a/ciao/data/preprocessing.py +++ b/ciao/data/preprocessing.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import cast import torch import torchvision.transforms as transforms @@ -31,7 +32,10 @@ def load_and_preprocess_image( if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - image = Image.open(image_path).convert("RGB") - input_tensor = preprocess(image).to(device) # (3, 224, 224) - on correct device + # Use context manager to prevent file descriptor leaks + with Image.open(image_path) as img: + image = img.convert("RGB") + tensor = cast(torch.Tensor, preprocess(image)) # (3, 224, 224) + input_tensor = tensor.to(device) return input_tensor diff --git a/ciao/data/segmentation.py b/ciao/data/segmentation.py index 0570c1d..2c009c8 100644 --- a/ciao/data/segmentation.py +++ b/ciao/data/segmentation.py @@ -81,19 +81,14 @@ def _build_square_adjacency_list( mask_h = left != right edges_h = np.column_stack([left[mask_h], right[mask_h]]) - for seg1, seg2 in edges_h: - adjacency_sets[seg1].add(seg2) - adjacency_sets[seg2].add(seg1) - # Vectorized vertical adjacency top = segments[:-1, :].ravel() bottom = segments[1:, :].ravel() mask_v = top != bottom edges_v = np.column_stack([top[mask_v], bottom[mask_v]]) - for seg1, seg2 in edges_v: - adjacency_sets[seg1].add(seg2) - adjacency_sets[seg2].add(seg1) + # Collect all edges + edge_arrays = [edges_h, edges_v] if neighborhood == 8: # Vectorized diagonal adjacency (down-right) @@ -102,19 +97,19 @@ def _build_square_adjacency_list( mask_dr = top_left != bottom_right edges_dr = np.column_stack([top_left[mask_dr], bottom_right[mask_dr]]) - for seg1, seg2 in edges_dr: - adjacency_sets[seg1].add(seg2) - adjacency_sets[seg2].add(seg1) - # Vectorized diagonal adjacency (down-left) top_right = segments[:-1, 1:].ravel() bottom_left = segments[1:, :-1].ravel() mask_dl = top_right != bottom_left edges_dl = np.column_stack([top_right[mask_dl], bottom_left[mask_dl]]) - for seg1, seg2 in edges_dl: - adjacency_sets[seg1].add(seg2) - adjacency_sets[seg2].add(seg1) + edge_arrays.extend([edges_dr, edges_dl]) + + # Stack all edges together and populate adjacency sets in a single loop + all_edges = np.vstack(edge_arrays) + for seg1, seg2 in all_edges: + adjacency_sets[seg1].add(seg2) + adjacency_sets[seg2].add(seg1) # Convert to tuple of tuples return tuple(tuple(sorted(neighbors)) for neighbors in adjacency_sets) @@ -228,12 +223,11 @@ def _create_hexagonal_grid( # Stack q and r to create unique keys qr_stacked = np.stack([q_int.ravel(), r_int.ravel()], axis=1) - # Use np.unique to assign segment IDs efficiently - _, segments_flat = np.unique(qr_stacked, axis=0, return_inverse=True) + # Use np.unique to assign segment IDs efficiently (compute only once) + unique_qr, segments_flat = np.unique(qr_stacked, axis=0, return_inverse=True) segments = segments_flat.reshape((height, width)).astype(np.int32) # Build hex_to_id mapping for adjacency construction - unique_qr = np.unique(qr_stacked, axis=0) hex_to_id = {(int(q), int(r)): idx for idx, (q, r) in enumerate(unique_qr)} # Build adjacency list using axial coordinate neighbors @@ -267,6 +261,10 @@ def create_segmentation( ) if segmentation_type == "square": + if neighborhood not in (4, 8): + raise ValueError( + f"For square segmentation, neighborhood must be 4 or 8, got {neighborhood}." + ) segments, adjacency_list = _create_square_grid( input_tensor, square_size=segment_size, neighborhood=neighborhood ) From 677f408b53486ae7991470a4cee4461ddfe4b016 Mon Sep 17 00:00:00 2001 From: dhalmazna Date: Wed, 4 Mar 2026 18:19:15 +0100 Subject: [PATCH 06/13] feat: remove unused bitmask to adjacency list conversion function --- ciao/data/segmentation.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/ciao/data/segmentation.py b/ciao/data/segmentation.py index 2c009c8..cef3165 100644 --- a/ciao/data/segmentation.py +++ b/ciao/data/segmentation.py @@ -39,27 +39,6 @@ def _hex_round_vectorized( return rx.astype(np.int32), rz.astype(np.int32) -def bitmasks_to_adjacency_list( - adj_masks: tuple[int, ...], -) -> tuple[tuple[int, ...], ...]: - """Convert adjacency bitmasks back to adjacency list format. - - Args: - adj_masks: Tuple of integer bitmasks where bit i indicates neighbor i - - Returns: - Adjacency list as tuple of tuples of neighbor IDs - """ - adjacency_list = [] - for mask in adj_masks: - neighbors = [] - for i in range(len(adj_masks)): - if mask & (1 << i): - neighbors.append(i) - adjacency_list.append(tuple(neighbors)) - return tuple(adjacency_list) - - def _build_square_adjacency_list( segments: np.ndarray, neighborhood: int = 8 ) -> tuple[tuple[int, ...], ...]: From 76b0bc636ee38d308a11780216491e0ace1cbe21 Mon Sep 17 00:00:00 2001 From: dhalmazna Date: Thu, 5 Mar 2026 11:09:04 +0100 Subject: [PATCH 07/13] feat: implement image replacement strategies --- ciao/data/replacement.py | 119 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 ciao/data/replacement.py diff --git a/ciao/data/replacement.py b/ciao/data/replacement.py new file mode 100644 index 0000000..23e8fd9 --- /dev/null +++ b/ciao/data/replacement.py @@ -0,0 +1,119 @@ +"""Image replacement strategies for masking operations.""" + +import torch +import torchvision.transforms.functional as TF +from matplotlib import pyplot as plt + + +# ImageNet normalization constants +IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]) +IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]) + + +def calculate_image_mean_color(input_tensor: torch.Tensor) -> torch.Tensor: + """Calculate image mean color using ImageNet normalization constants. + + Args: + input_tensor: Input tensor [3, H, W] or [1, 3, H, W] (ImageNet normalized) + + Returns: + Mean color tensor [3, 1, 1] (ImageNet normalized) + """ + device = input_tensor.device + + # Add batch dimension if needed + if input_tensor.dim() == 3: + input_tensor = input_tensor.unsqueeze(0) + + # Move normalization constants to same device + imagenet_mean = IMAGENET_MEAN.view(1, 3, 1, 1).to(device) + imagenet_std = IMAGENET_STD.view(1, 3, 1, 1).to(device) + + # Unnormalize, calculate mean, then re-normalize + unnormalized = (input_tensor * imagenet_std) + imagenet_mean + mean_color = unnormalized.mean(dim=(2, 3), keepdim=True) + normalized_mean = (mean_color - imagenet_mean) / imagenet_std + + return normalized_mean.squeeze(0) # Remove batch dimension + + +def get_replacement_image( + input_tensor: torch.Tensor, + replacement: str = "mean_color", + color: tuple[int, int, int] = (0, 0, 0), +) -> torch.Tensor: + """Generate replacement image for masking operations. + + Args: + input_tensor: Input tensor [3, H, W] (ImageNet normalized) + replacement: Strategy - "mean_color", "interlacing", "blur", or "solid_color" + color: For solid_color mode, RGB tuple (0-255). Defaults to black (0, 0, 0) + + Returns: + replacement_image: torch tensor [3, H, W] on same device + """ + device = input_tensor.device + + # Extract spatial dimensions from input tensor + _, height, width = input_tensor.shape + + if replacement == "mean_color": + # Fill entire image with mean color + mean_color = calculate_image_mean_color(input_tensor) # [3, 1, 1] + replacement_image = mean_color.expand(-1, height, width) # [3, H, W] + + elif replacement == "interlacing": + # Create interlaced pattern: even columns flipped vertically, then even rows flipped horizontally + replacement_image = input_tensor.clone() + even_row_indices = torch.arange(0, height, 2) # Even row indices + even_col_indices = torch.arange(0, width, 2) # Even column indices + + # Step 1: Flip even columns vertically (upside down) + replacement_image[:, :, even_col_indices] = torch.flip( + replacement_image[:, :, even_col_indices], dims=[1] + ) + + # Step 2: Flip even rows horizontally (left-right) + replacement_image[:, even_row_indices, :] = torch.flip( + replacement_image[:, even_row_indices, :], dims=[2] + ) + + elif replacement == "blur": + # Apply Gaussian blur using torchvision functional API + input_batch = input_tensor.unsqueeze(0) # [1, 3, H, W] + replacement_image = TF.gaussian_blur( + input_batch, kernel_size=[7, 7], sigma=[1.5, 1.5] + ).squeeze(0) # [3, H, W] + + elif replacement == "solid_color": + # Fill with specified solid color (expects RGB values in 0-255 range) + # Convert color to torch tensor + color_tensor = torch.tensor(color, dtype=torch.float32, device=device) + + # Convert from 0-255 range to 0-1 range + color_tensor = color_tensor / 255.0 + + # Apply ImageNet normalization + mean = IMAGENET_MEAN.view(3, 1, 1).to(device) + std = IMAGENET_STD.view(3, 1, 1).to(device) + normalized_color = (color_tensor.view(3, 1, 1) - mean) / std + replacement_image = normalized_color.expand(-1, height, width) # [3, H, W] + + else: + raise ValueError(f"Unknown replacement strategy: {replacement}") + + return replacement_image + + +def plot_image_mean_color(input_tensor: torch.Tensor) -> None: + """Display the mean color of the image. + + Args: + input_tensor: Input tensor [3, H, W] (ImageNet normalized) + + Note: + The visualization shows the normalized tensor (ImageNet normalization). + """ + normalized_mean = calculate_image_mean_color(input_tensor).unsqueeze(0) + plt.imshow(normalized_mean[0].permute(1, 2, 0)) + plt.show() From de00a02b3198c9c15394573e60f4c6d29ac2d4b3 Mon Sep 17 00:00:00 2001 From: dhalmazna Date: Thu, 5 Mar 2026 11:18:35 +0100 Subject: [PATCH 08/13] chore: update __init__.py --- ciao/data/__init__.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/ciao/data/__init__.py b/ciao/data/__init__.py index 77cc547..e00a039 100644 --- a/ciao/data/__init__.py +++ b/ciao/data/__init__.py @@ -2,7 +2,19 @@ from ciao.data.loader import get_image_loader from ciao.data.preprocessing import load_and_preprocess_image +from ciao.data.replacement import ( + calculate_image_mean_color, + get_replacement_image, + plot_image_mean_color, +) from ciao.data.segmentation import create_segmentation -__all__ = ["create_segmentation", "get_image_loader", "load_and_preprocess_image"] +__all__ = [ + "calculate_image_mean_color", + "create_segmentation", + "get_image_loader", + "get_replacement_image", + "load_and_preprocess_image", + "plot_image_mean_color", +] From 60deed345af24c304425e0f0d079a3980f4b100d Mon Sep 17 00:00:00 2001 From: dhalmazna Date: Tue, 10 Mar 2026 13:41:53 +0100 Subject: [PATCH 09/13] chore: ruff checks --- ciao/data/preprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ciao/data/preprocessing.py b/ciao/data/preprocessing.py index bc85cbf..c0f482c 100644 --- a/ciao/data/preprocessing.py +++ b/ciao/data/preprocessing.py @@ -35,7 +35,7 @@ def load_and_preprocess_image( # Use context manager to prevent file descriptor leaks with Image.open(image_path) as img: image = img.convert("RGB") - tensor = cast(torch.Tensor, preprocess(image)) # (3, 224, 224) + tensor = cast("torch.Tensor", preprocess(image)) # (3, 224, 224) input_tensor = tensor.to(device) return input_tensor From 60a25b32a969fdbe4c821f040ea17bc2e53fe8e9 Mon Sep 17 00:00:00 2001 From: dhalmazna Date: Tue, 10 Mar 2026 14:17:22 +0100 Subject: [PATCH 10/13] chore: apply agents' suggestions --- ciao/data/loader.py | 14 ++++++++++---- ciao/data/replacement.py | 11 +++++++++-- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/ciao/data/loader.py b/ciao/data/loader.py index 5ef110e..52eaa30 100644 --- a/ciao/data/loader.py +++ b/ciao/data/loader.py @@ -23,9 +23,15 @@ def get_image_loader(config: DictConfig) -> Iterator[Path]: ValueError: If neither image_path nor batch_path is specified FileNotFoundError: If single image_path does not exist """ - if config.data.get("image_path"): + image_path_value = config.data.get("image_path") + batch_path_value = config.data.get("batch_path") + + if image_path_value and batch_path_value: + raise ValueError("Specify exactly one of image_path or batch_path in config") + + if image_path_value: # Single image mode - validate file exists - image_path = Path(config.data.image_path) + image_path = Path(image_path_value) if not image_path.is_file(): raise FileNotFoundError( f"image_path must be a valid file, got: {image_path}. " @@ -33,9 +39,9 @@ def get_image_loader(config: DictConfig) -> Iterator[Path]: ) yield image_path - elif config.data.get("batch_path"): + elif batch_path_value: # Directory mode - find all images with supported extensions - directory = Path(config.data.batch_path) + directory = Path(batch_path_value) if not directory.is_dir(): raise ValueError( f"batch_path must be a valid directory, got: {directory}. " diff --git a/ciao/data/replacement.py b/ciao/data/replacement.py index 23e8fd9..a362ab0 100644 --- a/ciao/data/replacement.py +++ b/ciao/data/replacement.py @@ -114,6 +114,13 @@ def plot_image_mean_color(input_tensor: torch.Tensor) -> None: Note: The visualization shows the normalized tensor (ImageNet normalization). """ - normalized_mean = calculate_image_mean_color(input_tensor).unsqueeze(0) - plt.imshow(normalized_mean[0].permute(1, 2, 0)) + mean = IMAGENET_MEAN.view(3, 1, 1).to( + device=input_tensor.device, dtype=input_tensor.dtype + ) + std = IMAGENET_STD.view(3, 1, 1).to( + device=input_tensor.device, dtype=input_tensor.dtype + ) + normalized_mean = calculate_image_mean_color(input_tensor) + display_mean = ((normalized_mean * std) + mean).clamp(0, 1) + plt.imshow(display_mean.permute(1, 2, 0).detach().cpu()) plt.show() From 6b1886a2dff602a23ece712cd53f534a9e1e8de5 Mon Sep 17 00:00:00 2001 From: dhalmazna Date: Wed, 11 Mar 2026 09:04:15 +0100 Subject: [PATCH 11/13] refactor: improve image path validation and fail-fast logic --- README.md | 2 +- ciao/data/__init__.py | 4 ++-- ciao/data/loader.py | 32 ++++++++++++++++---------------- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 4c4ff3f..17bee0b 100644 --- a/README.md +++ b/README.md @@ -98,7 +98,7 @@ ciao/ │ │ ├── nodes.py # Node classes for tree/graph search │ │ └── search_helpers.py # Shared MCTS/MCGS helper functions │ ├── data/ # Data loading and preprocessing -│ │ ├── loader.py # Image loaders +│ │ ├── loader.py # Path loaders │ │ ├── preprocessing.py # Image preprocessing utilities │ │ └── segmentation.py # Segmentation utilities (hex/square grids) │ ├── evaluation/ # Scoring and evaluation diff --git a/ciao/data/__init__.py b/ciao/data/__init__.py index e00a039..ee39bbc 100644 --- a/ciao/data/__init__.py +++ b/ciao/data/__init__.py @@ -1,6 +1,6 @@ """Data loading utilities for CIAO.""" -from ciao.data.loader import get_image_loader +from ciao.data.loader import iter_image_paths from ciao.data.preprocessing import load_and_preprocess_image from ciao.data.replacement import ( calculate_image_mean_color, @@ -13,8 +13,8 @@ __all__ = [ "calculate_image_mean_color", "create_segmentation", - "get_image_loader", "get_replacement_image", + "iter_image_paths", "load_and_preprocess_image", "plot_image_mean_color", ] diff --git a/ciao/data/loader.py b/ciao/data/loader.py index 52eaa30..e4a5d0f 100644 --- a/ciao/data/loader.py +++ b/ciao/data/loader.py @@ -1,4 +1,4 @@ -"""Simple image path loading utilities.""" +"""Simple image path resolution utilities.""" from collections.abc import Iterator from pathlib import Path @@ -10,18 +10,19 @@ IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".webp") -def get_image_loader(config: DictConfig) -> Iterator[Path]: - """Create image loader based on configuration. +def iter_image_paths(config: DictConfig) -> Iterator[Path]: + """Generate paths to images based on configuration. Args: - config: Hydra config object + config: Hydra config object containing data.image_path or data.batch_path Returns: - Iterator of Path objects + Iterator of Path objects pointing to valid images Raises: - ValueError: If neither image_path nor batch_path is specified + ValueError: If config specifies both or neither paths FileNotFoundError: If single image_path does not exist + NotADirectoryError: If batch_path directory does not exist """ image_path_value = config.data.get("image_path") batch_path_value = config.data.get("batch_path") @@ -29,29 +30,28 @@ def get_image_loader(config: DictConfig) -> Iterator[Path]: if image_path_value and batch_path_value: raise ValueError("Specify exactly one of image_path or batch_path in config") + if not image_path_value and not batch_path_value: + raise ValueError("Must specify either image_path or batch_path in config") + if image_path_value: - # Single image mode - validate file exists image_path = Path(image_path_value) if not image_path.is_file(): raise FileNotFoundError( f"image_path must be a valid file, got: {image_path}. " "Check for typos or incorrect path configuration." ) - yield image_path - elif batch_path_value: - # Directory mode - find all images with supported extensions + if batch_path_value: directory = Path(batch_path_value) if not directory.is_dir(): - raise ValueError( + raise NotADirectoryError( f"batch_path must be a valid directory, got: {directory}. " "Check for typos or incorrect path configuration." ) - # Single rglob pass with suffix filtering - for path in directory.rglob("*"): + if image_path_value: + yield Path(image_path_value) + else: + for path in Path(batch_path_value).rglob("*"): if path.suffix.lower() in IMAGE_EXTENSIONS: yield path - - else: - raise ValueError("Must specify either image_path or batch_path in config") From 6ac5c402ea64ae7c461cd436414815e5f70ce5f7 Mon Sep 17 00:00:00 2001 From: dhalmazna Date: Wed, 11 Mar 2026 09:10:24 +0100 Subject: [PATCH 12/13] refactor: improve fail-fast logic in segmentation; convert strings to literals --- ciao/data/__init__.py | 2 -- ciao/data/replacement.py | 28 +++++----------------------- ciao/data/segmentation.py | 23 +++++++++++++---------- 3 files changed, 18 insertions(+), 35 deletions(-) diff --git a/ciao/data/__init__.py b/ciao/data/__init__.py index ee39bbc..1a44cfa 100644 --- a/ciao/data/__init__.py +++ b/ciao/data/__init__.py @@ -5,7 +5,6 @@ from ciao.data.replacement import ( calculate_image_mean_color, get_replacement_image, - plot_image_mean_color, ) from ciao.data.segmentation import create_segmentation @@ -16,5 +15,4 @@ "get_replacement_image", "iter_image_paths", "load_and_preprocess_image", - "plot_image_mean_color", ] diff --git a/ciao/data/replacement.py b/ciao/data/replacement.py index a362ab0..86b83b1 100644 --- a/ciao/data/replacement.py +++ b/ciao/data/replacement.py @@ -1,8 +1,9 @@ """Image replacement strategies for masking operations.""" +from typing import Literal + import torch import torchvision.transforms.functional as TF -from matplotlib import pyplot as plt # ImageNet normalization constants @@ -39,7 +40,9 @@ def calculate_image_mean_color(input_tensor: torch.Tensor) -> torch.Tensor: def get_replacement_image( input_tensor: torch.Tensor, - replacement: str = "mean_color", + replacement: Literal[ + "mean_color", "interlacing", "blur", "solid_color" + ] = "mean_color", color: tuple[int, int, int] = (0, 0, 0), ) -> torch.Tensor: """Generate replacement image for masking operations. @@ -103,24 +106,3 @@ def get_replacement_image( raise ValueError(f"Unknown replacement strategy: {replacement}") return replacement_image - - -def plot_image_mean_color(input_tensor: torch.Tensor) -> None: - """Display the mean color of the image. - - Args: - input_tensor: Input tensor [3, H, W] (ImageNet normalized) - - Note: - The visualization shows the normalized tensor (ImageNet normalization). - """ - mean = IMAGENET_MEAN.view(3, 1, 1).to( - device=input_tensor.device, dtype=input_tensor.dtype - ) - std = IMAGENET_STD.view(3, 1, 1).to( - device=input_tensor.device, dtype=input_tensor.dtype - ) - normalized_mean = calculate_image_mean_color(input_tensor) - display_mean = ((normalized_mean * std) + mean).clamp(0, 1) - plt.imshow(display_mean.permute(1, 2, 0).detach().cpu()) - plt.show() diff --git a/ciao/data/segmentation.py b/ciao/data/segmentation.py index cef3165..e13c349 100644 --- a/ciao/data/segmentation.py +++ b/ciao/data/segmentation.py @@ -1,4 +1,5 @@ import math +from typing import Literal import numpy as np import torch @@ -217,7 +218,7 @@ def _create_hexagonal_grid( def create_segmentation( input_tensor: torch.Tensor, - segmentation_type: str = "hexagonal", + segmentation_type: Literal["square", "hexagonal"] = "hexagonal", segment_size: int = 14, neighborhood: int = 8, ) -> tuple[np.ndarray, tuple[int, ...]]: @@ -239,22 +240,24 @@ def create_segmentation( "Non-positive values cause division by zero or invalid range operations." ) + if segmentation_type not in ("square", "hexagonal"): + raise ValueError( + f"Unknown segmentation_type: {segmentation_type}. Use 'square' or 'hexagonal'." + ) + + if segmentation_type == "square" and neighborhood not in (4, 8): + raise ValueError( + f"For square segmentation, neighborhood must be 4 or 8, got {neighborhood}." + ) + if segmentation_type == "square": - if neighborhood not in (4, 8): - raise ValueError( - f"For square segmentation, neighborhood must be 4 or 8, got {neighborhood}." - ) segments, adjacency_list = _create_square_grid( input_tensor, square_size=segment_size, neighborhood=neighborhood ) - elif segmentation_type == "hexagonal": + else: segments, adjacency_list = _create_hexagonal_grid( input_tensor, hex_radius=segment_size ) - else: - raise ValueError( - f"Unknown segmentation_type: {segmentation_type}. Use 'square' or 'hexagonal'." - ) # Convert adjacency list to bitmasks adj_masks = _build_adjacency_bitmasks(adjacency_list) From 452935c4049d6ebfd057de8aeaa9a2cee75c29c0 Mon Sep 17 00:00:00 2001 From: dhalmazna Date: Wed, 11 Mar 2026 09:54:36 +0100 Subject: [PATCH 13/13] refactor: enhance image path validation --- ciao/data/loader.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/ciao/data/loader.py b/ciao/data/loader.py index e4a5d0f..dc9a9ac 100644 --- a/ciao/data/loader.py +++ b/ciao/data/loader.py @@ -40,6 +40,10 @@ def iter_image_paths(config: DictConfig) -> Iterator[Path]: f"image_path must be a valid file, got: {image_path}. " "Check for typos or incorrect path configuration." ) + if image_path.suffix.lower() not in IMAGE_EXTENSIONS: + raise ValueError( + f"image_path must use a supported image extension, got: {image_path}" + ) if batch_path_value: directory = Path(batch_path_value) @@ -50,8 +54,8 @@ def iter_image_paths(config: DictConfig) -> Iterator[Path]: ) if image_path_value: - yield Path(image_path_value) + yield image_path else: - for path in Path(batch_path_value).rglob("*"): - if path.suffix.lower() in IMAGE_EXTENSIONS: + for path in directory.rglob("*"): + if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS: yield path