Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
strict = True
ignore_missing_imports = True
disallow_untyped_calls = False
disable_error_code = no-any-return
disable_error_code = no-any-return
explicit_package_bases = True
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ ciao/
│ │ ├── nodes.py # Node classes for tree/graph search
│ │ └── search_helpers.py # Shared MCTS/MCGS helper functions
│ ├── data/ # Data loading and preprocessing
│ │ ├── loader.py # Image loaders
│ │ ├── loader.py # Path loaders
│ │ ├── preprocessing.py # Image preprocessing utilities
│ │ └── segmentation.py # Segmentation utilities (hex/square grids)
│ ├── evaluation/ # Scoring and evaluation
Expand Down
18 changes: 18 additions & 0 deletions ciao/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Data loading utilities for CIAO."""

from ciao.data.loader import iter_image_paths
from ciao.data.preprocessing import load_and_preprocess_image
from ciao.data.replacement import (
calculate_image_mean_color,
get_replacement_image,
)
from ciao.data.segmentation import create_segmentation


__all__ = [
"calculate_image_mean_color",
"create_segmentation",
"get_replacement_image",
"iter_image_paths",
"load_and_preprocess_image",
]
61 changes: 61 additions & 0 deletions ciao/data/loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Simple image path resolution utilities."""

from collections.abc import Iterator
from pathlib import Path

from omegaconf import DictConfig


# Supported image formats
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".webp")


def iter_image_paths(config: DictConfig) -> Iterator[Path]:
    """Resolve image paths from configuration, validating eagerly.

    Exactly one of ``data.image_path`` (a single image file) or
    ``data.batch_path`` (a directory searched recursively) must be set.

    Note: validation happens at call time, not at first iteration — the
    original implementation was itself a generator, which silently deferred
    every error below until the caller started iterating.

    Args:
        config: Hydra config object containing data.image_path or data.batch_path

    Returns:
        Iterator of Path objects pointing to valid images

    Raises:
        ValueError: If config specifies both or neither paths, or if the
            single image_path has an unsupported extension
        FileNotFoundError: If single image_path does not exist
        NotADirectoryError: If batch_path directory does not exist
    """
    image_path_value = config.data.get("image_path")
    batch_path_value = config.data.get("batch_path")

    if image_path_value and batch_path_value:
        raise ValueError("Specify exactly one of image_path or batch_path in config")

    if not image_path_value and not batch_path_value:
        raise ValueError("Must specify either image_path or batch_path in config")

    if image_path_value:
        image_path = Path(image_path_value)
        if not image_path.is_file():
            raise FileNotFoundError(
                f"image_path must be a valid file, got: {image_path}. "
                "Check for typos or incorrect path configuration."
            )
        if image_path.suffix.lower() not in IMAGE_EXTENSIONS:
            raise ValueError(
                f"image_path must use a supported image extension, got: {image_path}"
            )
        # iter() keeps the return type consistent with the directory branch.
        return iter([image_path])

    directory = Path(batch_path_value)
    if not directory.is_dir():
        raise NotADirectoryError(
            f"batch_path must be a valid directory, got: {directory}. "
            "Check for typos or incorrect path configuration."
        )

    def _walk() -> Iterator[Path]:
        # Recursive scan, keeping only files with a supported image extension.
        for path in directory.rglob("*"):
            if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS:
                yield path

    return _walk()
41 changes: 41 additions & 0 deletions ciao/data/preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from pathlib import Path
from typing import cast

import torch
import torchvision.transforms as transforms
from PIL import Image


# ImageNet preprocessing transforms
# Standard ImageNet evaluation pipeline: resize shorter side to 256, take the
# center 224x224 crop, convert PIL image to a float tensor in [0, 1], then
# normalize per channel with the ImageNet mean/std.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


def load_and_preprocess_image(
    image_path: str | Path, device: torch.device | None = None
) -> torch.Tensor:
    """Load an image from disk and prepare it for the model.

    Args:
        image_path: Path to image file
        device: Device to place tensor on (defaults to cuda if available, else cpu)

    Returns:
        Preprocessed image tensor [3, 224, 224] on specified device
    """
    target_device = device
    if target_device is None:
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Context manager guarantees the file handle is released promptly.
    with Image.open(image_path) as raw:
        rgb_image = raw.convert("RGB")
        preprocessed = cast("torch.Tensor", preprocess(rgb_image))  # (3, 224, 224)

    return preprocessed.to(target_device)
108 changes: 108 additions & 0 deletions ciao/data/replacement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Image replacement strategies for masking operations."""

from typing import Literal

import torch
import torchvision.transforms.functional as TF


# ImageNet normalization constants
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406])
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225])


def calculate_image_mean_color(input_tensor: torch.Tensor) -> torch.Tensor:
    """Calculate image mean color using ImageNet normalization constants.

    The mean is computed in unnormalized (pixel) space so it is a true color
    average, then mapped back into normalized space.

    Args:
        input_tensor: Input tensor [3, H, W] or [1, 3, H, W] (ImageNet normalized)

    Returns:
        Mean color tensor [3, 1, 1] (ImageNet normalized)
    """
    # Promote [3, H, W] to a single-image batch for uniform handling.
    batched = input_tensor.unsqueeze(0) if input_tensor.dim() == 3 else input_tensor

    # Broadcastable [1, 3, 1, 1] stats placed on the input's device.
    mean = IMAGENET_MEAN.view(1, 3, 1, 1).to(batched.device)
    std = IMAGENET_STD.view(1, 3, 1, 1).to(batched.device)

    # Undo normalization, average over the spatial axes, then re-normalize.
    pixel_space = batched * std + mean
    average_color = pixel_space.mean(dim=(2, 3), keepdim=True)
    return ((average_color - mean) / std).squeeze(0)


def get_replacement_image(
    input_tensor: torch.Tensor,
    replacement: Literal[
        "mean_color", "interlacing", "blur", "solid_color"
    ] = "mean_color",
    color: tuple[int, int, int] = (0, 0, 0),
) -> torch.Tensor:
    """Generate replacement image for masking operations.

    Args:
        input_tensor: Input tensor [3, H, W] (ImageNet normalized)
        replacement: Strategy - "mean_color", "interlacing", "blur", or "solid_color"
        color: For solid_color mode, RGB tuple (0-255). Defaults to black (0, 0, 0)

    Returns:
        replacement_image: torch tensor [3, H, W] on same device
    """
    _, height, width = input_tensor.shape

    if replacement == "mean_color":
        # Uniform fill with the image's (still normalized) mean color.
        return calculate_image_mean_color(input_tensor).expand(-1, height, width)

    if replacement == "interlacing":
        # Even columns flipped vertically, then even rows flipped horizontally.
        result = input_tensor.clone()
        even_cols = torch.arange(0, width, 2)
        even_rows = torch.arange(0, height, 2)
        result[:, :, even_cols] = torch.flip(result[:, :, even_cols], dims=[1])
        result[:, even_rows, :] = torch.flip(result[:, even_rows, :], dims=[2])
        return result

    if replacement == "blur":
        # gaussian_blur expects a batch dim; add it for the call, strip after.
        blurred = TF.gaussian_blur(
            input_tensor.unsqueeze(0), kernel_size=[7, 7], sigma=[1.5, 1.5]
        )
        return blurred.squeeze(0)  # [3, H, W]

    if replacement == "solid_color":
        # 0-255 RGB -> [0, 1] -> ImageNet-normalized, broadcast to H x W.
        rgb = torch.tensor(color, dtype=torch.float32, device=input_tensor.device)
        rgb = rgb / 255.0
        mean = IMAGENET_MEAN.view(3, 1, 1).to(input_tensor.device)
        std = IMAGENET_STD.view(3, 1, 1).to(input_tensor.device)
        return ((rgb.view(3, 1, 1) - mean) / std).expand(-1, height, width)

    raise ValueError(f"Unknown replacement strategy: {replacement}")
Loading