diff --git a/.mypy.ini b/.mypy.ini
index 1c6f44e..8d15f3c 100644
--- a/.mypy.ini
+++ b/.mypy.ini
@@ -2,4 +2,5 @@
 strict = True
 ignore_missing_imports = True
 disallow_untyped_calls = False
-disable_error_code = no-any-return
\ No newline at end of file
+disable_error_code = no-any-return
+explicit_package_bases = True
\ No newline at end of file
diff --git a/README.md b/README.md
index 4c4ff3f..17bee0b 100644
--- a/README.md
+++ b/README.md
@@ -98,7 +98,7 @@ ciao/
 │   │   ├── nodes.py            # Node classes for tree/graph search
 │   │   └── search_helpers.py   # Shared MCTS/MCGS helper functions
 │   ├── data/                   # Data loading and preprocessing
-│   │   ├── loader.py           # Image loaders
+│   │   ├── loader.py           # Path loaders
 │   │   ├── preprocessing.py    # Image preprocessing utilities
 │   │   └── segmentation.py     # Segmentation utilities (hex/square grids)
 │   ├── evaluation/             # Scoring and evaluation
diff --git a/ciao/data/__init__.py b/ciao/data/__init__.py
new file mode 100644
index 0000000..1a44cfa
--- /dev/null
+++ b/ciao/data/__init__.py
@@ -0,0 +1,18 @@
+"""Data loading utilities for CIAO."""
+
+from ciao.data.loader import iter_image_paths
+from ciao.data.preprocessing import load_and_preprocess_image
+from ciao.data.replacement import (
+    calculate_image_mean_color,
+    get_replacement_image,
+)
+from ciao.data.segmentation import create_segmentation
+
+
+__all__ = [
+    "calculate_image_mean_color",
+    "create_segmentation",
+    "get_replacement_image",
+    "iter_image_paths",
+    "load_and_preprocess_image",
+]
diff --git a/ciao/data/loader.py b/ciao/data/loader.py
new file mode 100644
index 0000000..dc9a9ac
--- /dev/null
+++ b/ciao/data/loader.py
@@ -0,0 +1,61 @@
+"""Simple image path resolution utilities."""
+
+from collections.abc import Iterator
+from pathlib import Path
+
+from omegaconf import DictConfig
+
+
+# Supported image formats
+IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".webp")
+
+
+def iter_image_paths(config: DictConfig) -> Iterator[Path]:
+    """Generate paths to images based on configuration.
+
+    Args:
+        config: Hydra config object containing data.image_path or data.batch_path
+
+    Returns:
+        Iterator of Path objects pointing to valid images
+
+    Raises:
+        ValueError: If config specifies both or neither paths
+        FileNotFoundError: If single image_path does not exist
+        NotADirectoryError: If batch_path directory does not exist
+    """
+    image_path_value = config.data.get("image_path")
+    batch_path_value = config.data.get("batch_path")
+
+    if image_path_value and batch_path_value:
+        raise ValueError("Specify exactly one of image_path or batch_path in config")
+
+    if not image_path_value and not batch_path_value:
+        raise ValueError("Must specify either image_path or batch_path in config")
+
+    if image_path_value:
+        image_path = Path(image_path_value)
+        if not image_path.is_file():
+            raise FileNotFoundError(
+                f"image_path must be a valid file, got: {image_path}. "
+                "Check for typos or incorrect path configuration."
+            )
+        if image_path.suffix.lower() not in IMAGE_EXTENSIONS:
+            raise ValueError(
+                f"image_path must use a supported image extension, got: {image_path}"
+            )
+
+    if batch_path_value:
+        directory = Path(batch_path_value)
+        if not directory.is_dir():
+            raise NotADirectoryError(
+                f"batch_path must be a valid directory, got: {directory}. "
+                "Check for typos or incorrect path configuration."
+            )
+
+    if image_path_value:
+        yield image_path
+    else:
+        for path in directory.rglob("*"):
+            if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS:
+                yield path
diff --git a/ciao/data/preprocessing.py b/ciao/data/preprocessing.py
new file mode 100644
index 0000000..c0f482c
--- /dev/null
+++ b/ciao/data/preprocessing.py
@@ -0,0 +1,41 @@
+from pathlib import Path
+from typing import cast
+
+import torch
+import torchvision.transforms as transforms
+from PIL import Image
+
+
+# ImageNet preprocessing transforms
+preprocess = transforms.Compose(
+    [
+        transforms.Resize(256),
+        transforms.CenterCrop(224),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    ]
+)
+
+
+def load_and_preprocess_image(
+    image_path: str | Path, device: torch.device | None = None
+) -> torch.Tensor:
+    """Load and preprocess an image for the model.
+
+    Args:
+        image_path: Path to image file
+        device: Device to place tensor on (defaults to cuda if available, else cpu)
+
+    Returns:
+        Preprocessed image tensor [3, 224, 224] on specified device
+    """
+    if device is None:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # Use context manager to prevent file descriptor leaks
+    with Image.open(image_path) as img:
+        image = img.convert("RGB")
+        tensor = cast("torch.Tensor", preprocess(image))  # (3, 224, 224)
+        input_tensor = tensor.to(device)
+
+    return input_tensor
diff --git a/ciao/data/replacement.py b/ciao/data/replacement.py
new file mode 100644
index 0000000..86b83b1
--- /dev/null
+++ b/ciao/data/replacement.py
@@ -0,0 +1,108 @@
+"""Image replacement strategies for masking operations."""
+
+from typing import Literal
+
+import torch
+import torchvision.transforms.functional as TF
+
+
+# ImageNet normalization constants
+IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406])
+IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225])
+
+
+def calculate_image_mean_color(input_tensor: torch.Tensor) -> torch.Tensor:
+    """Calculate image mean color using ImageNet normalization constants.
+
+    Args:
+        input_tensor: Input tensor [3, H, W] or [1, 3, H, W] (ImageNet normalized)
+
+    Returns:
+        Mean color tensor [3, 1, 1] (ImageNet normalized)
+    """
+    device = input_tensor.device
+
+    # Add batch dimension if needed
+    if input_tensor.dim() == 3:
+        input_tensor = input_tensor.unsqueeze(0)
+
+    # Move normalization constants to same device
+    imagenet_mean = IMAGENET_MEAN.view(1, 3, 1, 1).to(device)
+    imagenet_std = IMAGENET_STD.view(1, 3, 1, 1).to(device)
+
+    # Unnormalize, calculate mean, then re-normalize
+    unnormalized = (input_tensor * imagenet_std) + imagenet_mean
+    mean_color = unnormalized.mean(dim=(2, 3), keepdim=True)
+    normalized_mean = (mean_color - imagenet_mean) / imagenet_std
+
+    return normalized_mean.squeeze(0)  # Remove batch dimension
+
+
+def get_replacement_image(
+    input_tensor: torch.Tensor,
+    replacement: Literal[
+        "mean_color", "interlacing", "blur", "solid_color"
+    ] = "mean_color",
+    color: tuple[int, int, int] = (0, 0, 0),
+) -> torch.Tensor:
+    """Generate replacement image for masking operations.
+
+    Args:
+        input_tensor: Input tensor [3, H, W] (ImageNet normalized)
+        replacement: Strategy - "mean_color", "interlacing", "blur", or "solid_color"
+        color: For solid_color mode, RGB tuple (0-255). Defaults to black (0, 0, 0)
+
+    Returns:
+        replacement_image: torch tensor [3, H, W] on same device
+    """
+    device = input_tensor.device
+
+    # Extract spatial dimensions from input tensor
+    _, height, width = input_tensor.shape
+
+    if replacement == "mean_color":
+        # Fill entire image with mean color
+        mean_color = calculate_image_mean_color(input_tensor)  # [3, 1, 1]
+        replacement_image = mean_color.expand(-1, height, width)  # [3, H, W]
+
+    elif replacement == "interlacing":
+        # Create interlaced pattern: even columns flipped vertically, then even rows flipped horizontally
+        replacement_image = input_tensor.clone()
+        even_row_indices = torch.arange(0, height, 2)  # Even row indices
+        even_col_indices = torch.arange(0, width, 2)  # Even column indices
+
+        # Step 1: Flip even columns vertically (upside down)
+        replacement_image[:, :, even_col_indices] = torch.flip(
+            replacement_image[:, :, even_col_indices], dims=[1]
+        )
+
+        # Step 2: Flip even rows horizontally (left-right)
+        replacement_image[:, even_row_indices, :] = torch.flip(
+            replacement_image[:, even_row_indices, :], dims=[2]
+        )
+
+    elif replacement == "blur":
+        # Apply Gaussian blur using torchvision functional API
+        input_batch = input_tensor.unsqueeze(0)  # [1, 3, H, W]
+        replacement_image = TF.gaussian_blur(
+            input_batch, kernel_size=[7, 7], sigma=[1.5, 1.5]
+        ).squeeze(0)  # [3, H, W]
+
+    elif replacement == "solid_color":
+        # Fill with specified solid color (expects RGB values in 0-255 range)
+        # Convert color to torch tensor
+        color_tensor = torch.tensor(color, dtype=torch.float32, device=device)
+
+        # Convert from 0-255 range to 0-1 range
+        color_tensor = color_tensor / 255.0
+
+        # Apply ImageNet normalization
+        mean = IMAGENET_MEAN.view(3, 1, 1).to(device)
+        std = IMAGENET_STD.view(3, 1, 1).to(device)
+        normalized_color = (color_tensor.view(3, 1, 1) - mean) / std
+        replacement_image = normalized_color.expand(-1, height, width)  # [3, H, W]
+
+    else:
+        raise ValueError(f"Unknown replacement strategy: {replacement}")
+
+    return replacement_image
diff --git a/ciao/data/segmentation.py b/ciao/data/segmentation.py
new file mode 100644
index 0000000..e13c349
--- /dev/null
+++ b/ciao/data/segmentation.py
@@ -0,0 +1,264 @@
+import math
+from typing import Literal
+
+import numpy as np
+import torch
+
+
+def _hex_round_vectorized(
+    q: np.ndarray, r: np.ndarray
+) -> tuple[np.ndarray, np.ndarray]:
+    """Vectorized hex rounding for entire arrays of axial coordinates.
+
+    Args:
+        q: Array of fractional axial q coordinates
+        r: Array of fractional axial r coordinates
+
+    Returns:
+        (q_rounded, r_rounded): Integer axial coordinates of nearest hex
+    """
+    x = q
+    z = r
+    y = -x - z
+
+    rx = np.round(x)
+    ry = np.round(y)
+    rz = np.round(z)
+
+    dx = np.abs(rx - x)
+    dy = np.abs(ry - y)
+    dz = np.abs(rz - z)
+
+    # Vectorized conditional logic
+    cond1 = (dx > dy) & (dx > dz)
+    cond2 = dy > dz
+
+    rx = np.where(cond1, -ry - rz, rx)
+    ry = np.where(~cond1 & cond2, -rx - rz, ry)
+    rz = np.where(~cond1 & ~cond2, -rx - ry, rz)
+
+    return rx.astype(np.int32), rz.astype(np.int32)
+
+
+def _build_square_adjacency_list(
+    segments: np.ndarray, neighborhood: int = 8
+) -> tuple[tuple[int, ...], ...]:
+    """Build adjacency list from segment array (for square grids) using vectorized operations.
+
+    Args:
+        segments: 2D array mapping pixels to segment IDs
+        neighborhood: 4 or 8 connectivity
+
+    Returns:
+        Adjacency list as tuple of tuples
+    """
+    num_segments = segments.max() + 1
+    adjacency_sets: list[set[int]] = [set() for _ in range(num_segments)]
+
+    # Vectorized horizontal adjacency
+    left = segments[:, :-1].ravel()
+    right = segments[:, 1:].ravel()
+    mask_h = left != right
+    edges_h = np.column_stack([left[mask_h], right[mask_h]])
+
+    # Vectorized vertical adjacency
+    top = segments[:-1, :].ravel()
+    bottom = segments[1:, :].ravel()
+    mask_v = top != bottom
+    edges_v = np.column_stack([top[mask_v], bottom[mask_v]])
+
+    # Collect all edges
+    edge_arrays = [edges_h, edges_v]
+
+    if neighborhood == 8:
+        # Vectorized diagonal adjacency (down-right)
+        top_left = segments[:-1, :-1].ravel()
+        bottom_right = segments[1:, 1:].ravel()
+        mask_dr = top_left != bottom_right
+        edges_dr = np.column_stack([top_left[mask_dr], bottom_right[mask_dr]])
+
+        # Vectorized diagonal adjacency (down-left)
+        top_right = segments[:-1, 1:].ravel()
+        bottom_left = segments[1:, :-1].ravel()
+        mask_dl = top_right != bottom_left
+        edges_dl = np.column_stack([top_right[mask_dl], bottom_left[mask_dl]])
+
+        edge_arrays.extend([edges_dr, edges_dl])
+
+    # Stack all edges together and populate adjacency sets in a single loop
+    all_edges = np.vstack(edge_arrays)
+    for seg1, seg2 in all_edges:
+        adjacency_sets[seg1].add(seg2)
+        adjacency_sets[seg2].add(seg1)
+
+    # Convert to tuple of tuples
+    return tuple(tuple(sorted(neighbors)) for neighbors in adjacency_sets)
+
+
+def _build_fast_adjacency_list(
+    hex_to_id: dict[tuple[int, int], int], max_id: int
+) -> tuple[tuple[int, ...], ...]:
+    """Create a static adjacency list optimized for fast reading.
+
+    Args:
+        hex_to_id: Dict mapping (q, r) -> int_id (0 to N-1)
+        max_id: Total number of segments (N)
+
+    Returns:
+        adj_list: Tuple of Tuples.
+            adj_list[5] returns e.g. (4, 6, 12) - neighbors of segment 5.
+    """
+    # Initialize empty lists for each ID
+    # Use list of lists for construction
+    temp_adj: list[list[int]] = [[] for _ in range(max_id)]
+
+    # Offsets for neighbors (axial coords)
+    hex_neighbors = [(+1, 0), (+1, -1), (0, -1), (-1, 0), (-1, +1), (0, +1)]
+
+    for (q, r), seg_id in hex_to_id.items():
+        for dq, dr in hex_neighbors:
+            neighbor_key = (q + dq, r + dr)
+
+            # If neighbor exists (is within the image)
+            if neighbor_key in hex_to_id:
+                neighbor_id = hex_to_id[neighbor_key]
+                temp_adj[seg_id].append(neighbor_id)
+
+    # Convert to tuple of tuples for maximum read speed and memory efficiency
+    # Sort neighbors (optional, but good for determinism)
+    final_adj = tuple(tuple(sorted(neighbors)) for neighbors in temp_adj)
+
+    return final_adj
+
+
+def _build_adjacency_bitmasks(adj_list: tuple[tuple[int, ...], ...]) -> tuple[int, ...]:
+    """Convert adjacency list to a list of integers.
+
+    adj_masks[5] will be an integer with bits set at positions of hex 5's neighbors.
+    """
+    adj_masks = []
+    for neighbors in adj_list:
+        mask = 0
+        for n in neighbors:
+            mask |= 1 << n
+        adj_masks.append(mask)
+    return tuple(adj_masks)
+
+
+def _create_square_grid(
+    input_tensor: torch.Tensor, square_size: int = 14, neighborhood: int = 8
+) -> tuple[np.ndarray, tuple[tuple[int, ...], ...]]:
+    """Create a grid of squares with adjacency list representing spatial relationships."""
+    _channels, height, width = input_tensor.shape
+    segments = np.zeros((height, width), dtype=np.int32)
+
+    segment_id = 0
+
+    # Create square grid
+    for row in range(0, height, square_size):
+        for col in range(0, width, square_size):
+            # Define square boundaries
+            row_end = min(row + square_size, height)
+            col_end = min(col + square_size, width)
+
+            # Assign segment ID to all pixels in this square
+            segments[row:row_end, col:col_end] = segment_id
+            segment_id += 1
+
+    # Build adjacency list
+    adjacency_list = _build_square_adjacency_list(segments, neighborhood=neighborhood)
+
+    return segments, adjacency_list
+
+
+def _create_hexagonal_grid(
+    input_tensor: torch.Tensor, hex_radius: int = 14
+) -> tuple[np.ndarray, tuple[tuple[int, ...], ...]]:
+    """Create a grid of hexagons with adjacency list using vectorized operations.
+
+    Uses axial coordinate system for precise hexagonal tiling (flat-top orientation).
+    Each hexagon has exactly 6 neighbors (neighborhood parameter ignored).
+
+    Args:
+        input_tensor: Input image tensor [C, H, W]
+        hex_radius: Hex size parameter (distance from center to flat edge, default: 14)
+
+    Returns:
+        segments: 2D array mapping pixels to segment IDs
+        adjacency_list: Tuple of tuples representing segment relationships
+    """
+    _channels, height, width = input_tensor.shape
+
+    # Create coordinate grids using meshgrid
+    x_coords, y_coords = np.meshgrid(np.arange(width), np.arange(height))
+
+    # Vectorized pixel to hex conversion
+    sqrt3 = math.sqrt(3)
+    q_float = (sqrt3 / 3 * x_coords - 1 / 3 * y_coords) / hex_radius
+    r_float = (2 / 3 * y_coords) / hex_radius
+
+    # Vectorized hex rounding
+    q_int, r_int = _hex_round_vectorized(q_float, r_float)
+
+    # Stack q and r to create unique keys
+    qr_stacked = np.stack([q_int.ravel(), r_int.ravel()], axis=1)
+
+    # Use np.unique to assign segment IDs efficiently (compute only once)
+    unique_qr, segments_flat = np.unique(qr_stacked, axis=0, return_inverse=True)
+    segments = segments_flat.reshape((height, width)).astype(np.int32)
+
+    # Build hex_to_id mapping for adjacency construction
+    hex_to_id = {(int(q), int(r)): idx for idx, (q, r) in enumerate(unique_qr)}
+
+    # Build adjacency list using axial coordinate neighbors
+    adjacency_list = _build_fast_adjacency_list(hex_to_id, len(hex_to_id))
+
+    return segments, adjacency_list
+
+
+def create_segmentation(
+    input_tensor: torch.Tensor,
+    segmentation_type: Literal["square", "hexagonal"] = "hexagonal",
+    segment_size: int = 14,
+    neighborhood: int = 8,
+) -> tuple[np.ndarray, tuple[int, ...]]:
+    """Create image segmentation with specified type.
+
+    Args:
+        input_tensor: Input image tensor [C, H, W]
+        segmentation_type: "square" or "hexagonal"
+        segment_size: Size parameter (square_size or hex_radius)
+        neighborhood: Neighborhood connectivity for squares (4, or 8)
+
+    Returns:
+        segments: 2D array mapping pixels to segment IDs
+        adj_masks: Tuple of integer bitmasks representing adjacency relationships
+    """
+    if segment_size <= 0:
+        raise ValueError(
+            f"segment_size must be positive, got {segment_size}. "
+            "Non-positive values cause division by zero or invalid range operations."
+        )
+
+    if segmentation_type not in ("square", "hexagonal"):
+        raise ValueError(
+            f"Unknown segmentation_type: {segmentation_type}. Use 'square' or 'hexagonal'."
+        )
+
+    if segmentation_type == "square" and neighborhood not in (4, 8):
+        raise ValueError(
+            f"For square segmentation, neighborhood must be 4 or 8, got {neighborhood}."
+        )
+
+    if segmentation_type == "square":
+        segments, adjacency_list = _create_square_grid(
+            input_tensor, square_size=segment_size, neighborhood=neighborhood
+        )
+    else:
+        segments, adjacency_list = _create_hexagonal_grid(
+            input_tensor, hex_radius=segment_size
+        )
+
+    # Convert adjacency list to bitmasks
+    adj_masks = _build_adjacency_bitmasks(adjacency_list)
+    return segments, adj_masks
diff --git a/pyproject.toml b/pyproject.toml
index 066f668..3d9d8d3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,7 +27,6 @@ dependencies = [
 
     # Scientific computing
     "numpy>=1.21.0",
-    "networkx>=2.6.0",
 
     # Others
     "tqdm>=4.0.0",
diff --git a/uv.lock b/uv.lock
index 1122328..a6ebabe 100644
--- a/uv.lock
+++ b/uv.lock
@@ -2407,7 +2407,6 @@ dependencies = [
     { name = "ipywidgets" },
     { name = "matplotlib" },
    { name = "mlflow" },
-    { name = "networkx" },
     { name = "numpy" },
     { name = "omegaconf" },
     { name = "pillow" },
@@ -2434,7 +2433,6 @@ requires-dist = [
     { name = "ipywidgets", specifier = ">=7.0.0" },
     { name = "matplotlib", specifier = ">=3.5.0" },
     { name = "mlflow", specifier = ">=3.0" },
-    { name = "networkx", specifier = ">=2.6.0" },
     { name = "numpy", specifier = ">=1.21.0" },
    { name = "omegaconf", specifier = ">=2.3.0" },
     { name = "pillow", specifier = ">=9.0.0" },