Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ ciao/
│ │ ├── loader.py # Path loaders
│ │ ├── preprocessing.py # Image preprocessing utilities
│ │ └── segmentation.py # Segmentation utilities (hex/square grids)
│ ├── evaluation/ # Scoring and evaluation
│ │ ├── surrogate.py # Surrogate dataset creation and segment scoring
│ ├── scoring/ # Scoring
│ │ ├── segments.py # Surrogate dataset creation and segment scoring
│ │ └── hyperpixel.py # Hyperpixel evaluation and selection
│ ├── explainer/ # Core explainer implementation
│ │ └── ciao_explainer.py # Main CIAO explainer class
Expand Down
9 changes: 9 additions & 0 deletions ciao/algorithm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""CIAO algorithm implementations."""

from ciao.algorithm.bitmask_graph import get_frontier, iter_bits


__all__ = [
"get_frontier",
"iter_bits",
]
55 changes: 55 additions & 0 deletions ciao/algorithm/bitmask_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Bitmask-based graph utilities for efficient segment manipulation.

This module will provide low-level primitives for working with graph structures
represented as integer bitmasks, where each bit represents a node/segment.
"""

from collections.abc import Iterator


def iter_bits(mask: int) -> Iterator[int]:
    """Yield the index of every set bit in ``mask``, lowest bit first.

    Runs in O(k) for k set bits: each iteration isolates the lowest set bit
    (``x & -x``) to get its index, then clears it (``x &= x - 1``).

    Example:
        mask = 0b10110  # bits 1, 2, 4 are set
        list(iter_bits(mask))  # [1, 2, 4]
    """
    remaining = mask
    while remaining:
        yield (remaining & -remaining).bit_length() - 1
        remaining &= remaining - 1


def get_frontier(mask: int, adj_masks: tuple[int, ...], used_mask: int) -> int:
    """Compute the expansion frontier (valid neighbors) for graph traversal.

    The frontier is the set of segments adjacent to the current structure
    that can be added in the next step.

    A segment is in the frontier if:
    - It is adjacent to at least one segment in the current mask
    - It is NOT already in the current mask
    - It is NOT in the used_mask (respects global exclusion constraints)

    Args:
        mask: Bitmask of current structure/connected component
        adj_masks: Tuple of adjacency bitmasks (adj_masks[i] = neighbors of segment i)
        used_mask: Bitmask of globally excluded segments

    Returns:
        Bitmask of valid frontier segments
    """
    # OR together the neighborhoods of every member of the current mask,
    # walking the set bits directly via low-bit isolation.
    neighborhood = 0
    remaining = mask
    while remaining:
        low_bit = remaining & -remaining
        neighborhood |= adj_masks[low_bit.bit_length() - 1]
        remaining ^= low_bit

    # Drop members already in the structure and globally excluded ones.
    return neighborhood & ~mask & ~used_mask
6 changes: 6 additions & 0 deletions ciao/model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Model prediction utilities for CIAO."""

from ciao.model.predictor import ModelPredictor


__all__ = ["ModelPredictor"]
37 changes: 37 additions & 0 deletions ciao/model/predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch


class ModelPredictor:
    """Wraps a torch model with prediction helpers for the CIAO explainer."""

    def __init__(self, model: torch.nn.Module, class_names: list[str]) -> None:
        self.model = model
        self.class_names = class_names

        # eval() disables Dropout and freezes BatchNorm statistics so
        # repeated inference calls are deterministic.
        self.model.eval()

        # Infer the device from the first parameter; a parameter-less model
        # falls back to CPU.
        first_param = next(model.parameters(), None)
        self.device = (
            torch.device("cpu") if first_param is None else first_param.device
        )

    def get_predictions(self, input_batch: torch.Tensor) -> torch.Tensor:
        """Return class probabilities (softmax over logits) for a batch."""
        batch = input_batch.to(self.device)

        with torch.no_grad():
            logits = self.model(batch)

        # logits were produced under no_grad, so this softmax tracks no graph.
        return torch.nn.functional.softmax(logits, dim=1)

    def get_class_logit_batch(
        self, input_batch: torch.Tensor, target_class_idx: int
    ) -> torch.Tensor:
        """Get raw logits for a specific target class across a batch of images."""
        batch = input_batch.to(self.device)

        with torch.no_grad():
            return self.model(batch)[:, target_class_idx]
18 changes: 18 additions & 0 deletions ciao/scoring/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Evaluation and scoring utilities for segments and hyperpixels."""

from ciao.scoring.hyperpixel import (
calculate_hyperpixel_deltas,
select_top_hyperpixels,
)
from ciao.scoring.segments import (
calculate_segment_scores,
create_surrogate_dataset,
)


__all__ = [
"calculate_hyperpixel_deltas",
"calculate_segment_scores",
"create_surrogate_dataset",
"select_top_hyperpixels",
]
115 changes: 115 additions & 0 deletions ciao/scoring/hyperpixel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import numpy as np
import torch

from ciao.model.predictor import ModelPredictor


def calculate_hyperpixel_deltas(
    predictor: ModelPredictor,
    input_batch: torch.Tensor,
    segments: np.ndarray,
    hyperpixel_segment_ids_list: list[list[int]],
    replacement_image: torch.Tensor,
    target_class_idx: int,
    batch_size: int = 64,
) -> list[float]:
    """Calculate masking deltas for hyperpixel candidates using batched inference.

    For every candidate (a list of segment IDs) the matching pixels are
    replaced with ``replacement_image``, the target-class logit is
    re-evaluated, and the drop relative to the unmasked logit is recorded.
    Inference is chunked into at most ``batch_size`` images at a time to
    prevent memory overflow with large path counts.

    Args:
        predictor: ModelPredictor instance
        input_batch: Input tensor batch [1, C, H, W]
        segments: Pixel-to-segment mapping array [H, W]
        hyperpixel_segment_ids_list: List of segment ID lists, e.g. [[1,2,3], [4,5,6]]
        replacement_image: Replacement tensor [C, H, W]
        target_class_idx: Target class index
        batch_size: Batch size

    Returns:
        List[float]: Delta scores for each candidate
    """
    if not hyperpixel_segment_ids_list:
        return []

    # Reject empty candidates up front.
    for idx, ids in enumerate(hyperpixel_segment_ids_list):
        if not ids:
            raise ValueError(f"Empty segment list at index {idx}")

    if input_batch.dim() != 4 or input_batch.shape[0] != 1:
        raise ValueError(
            f"input_batch must have shape [1, C, H, W], got {tuple(input_batch.shape)}"
        )

    expected_shape = input_batch.shape[1:]
    if tuple(replacement_image.shape) != tuple(expected_shape):
        raise ValueError(
            "replacement_image must have shape [C, H, W] matching input_batch, "
            f"got {tuple(replacement_image.shape)} vs expected {tuple(expected_shape)}"
        )

    # One-time device transfer; dtype of the replacement is aligned with the
    # input so torch.where does not raise on mixed dtypes.
    device = predictor.device
    input_batch = input_batch.to(device)
    replacement_image = replacement_image.to(device=device, dtype=input_batch.dtype)

    with torch.no_grad():
        baseline = predictor.get_class_logit_batch(input_batch, target_class_idx)
        original_logit = baseline[0].item()

    segment_map = torch.from_numpy(segments).to(device)

    deltas: list[float] = []
    total = len(hyperpixel_segment_ids_list)

    for start in range(0, total, batch_size):
        chunk = hyperpixel_segment_ids_list[start : start + batch_size]

        # One boolean [H, W] mask per candidate, stacked to [len(chunk), H, W].
        masks = torch.stack(
            [
                torch.isin(
                    segment_map,
                    torch.tensor(ids, dtype=segment_map.dtype, device=device),
                )
                for ids in chunk
            ]
        )

        # Broadcast the replacement over masked pixels in one operation:
        # [len(chunk), 1, H, W] mask x [1, C, H, W] replacement.
        replicated = input_batch.repeat(len(chunk), 1, 1, 1)
        masked_inputs = torch.where(
            masks.unsqueeze(1),
            replacement_image.unsqueeze(0),
            replicated,
        )

        chunk_logits = predictor.get_class_logit_batch(masked_inputs, target_class_idx)
        deltas.extend(original_logit - logit.item() for logit in chunk_logits)

        del replicated, masked_inputs, chunk_logits, masks

    return deltas


def select_top_hyperpixels(
    hyperpixels: list[dict[str, object]], max_hyperpixels: int = 10
) -> list[dict[str, object]]:
    """Select top hyperpixels by their primary algorithm-specific score.

    Candidates are ranked by the absolute value of ``hyperpixel_score`` so
    that strongly negative contributions rank alongside positive ones.
    """
    ranked = sorted(
        hyperpixels,
        key=lambda candidate: abs(candidate["hyperpixel_score"]),  # type: ignore[arg-type]
        reverse=True,
    )
    return ranked[:max_hyperpixels]
136 changes: 136 additions & 0 deletions ciao/scoring/segments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import logging

import numpy as np
import torch

from ciao.algorithm.bitmask_graph import get_frontier, iter_bits
from ciao.model.predictor import ModelPredictor
from ciao.scoring.hyperpixel import calculate_hyperpixel_deltas


logger = logging.getLogger(__name__)


def create_surrogate_dataset(
    predictor: ModelPredictor,
    input_batch: torch.Tensor,
    replacement_image: torch.Tensor,
    segments: np.ndarray,
    adj_masks: tuple[int, ...],
    target_class_idx: int,
    neighborhood_distance: int = 1,
    batch_size: int = 16,
) -> tuple[np.ndarray, np.ndarray]:
    """Create surrogate dataset for interpretability.

    Each row represents one masking operation:
    - Features (X): Binary indicator vector [num_segments] - 1 if segment was masked, 0 otherwise
    - Target (y): Delta score (original_logit - masked_logit)

    This dataset can be used for:
    - Computing segment importance scores
    - Training interpretable models (like LIME does)
    - Analyzing masking effects

    Args:
        predictor: ModelPredictor instance
        input_batch: Input tensor batch
        replacement_image: Replacement tensor [C, H, W]
        segments: Pixel-to-segment mapping array [H, W]
        adj_masks: Tuple of adjacency bitmasks where bit i indicates neighbor i
        target_class_idx: Target class index
        neighborhood_distance: Distance for neighborhood masking (BFS hop count)
        batch_size: Batch size for processing segments

    Returns:
        X: Binary indicator matrix [num_samples, num_segments]
        y: Delta scores array [num_samples]
    """
    # BFS over the adjacency bitmasks: for each seed segment, grow its
    # neighborhood one layer per hop up to neighborhood_distance.
    local_groups = []
    num_segments = len(adj_masks)

    for segment_id in range(num_segments):
        visited_mask = 1 << segment_id
        current_layer_mask = visited_mask

        for _ in range(neighborhood_distance):
            next_layer_mask = get_frontier(
                mask=current_layer_mask, adj_masks=adj_masks, used_mask=visited_mask
            )

            # Early exit if we reached the boundary of the isolated graph component
            if not next_layer_mask:
                break

            visited_mask |= next_layer_mask
            current_layer_mask = next_layer_mask

        local_groups.append(list(iter_bits(visited_mask)))

    # Score every local group via batched masked inference.
    deltas = calculate_hyperpixel_deltas(
        predictor,
        input_batch,
        segments,
        local_groups,
        replacement_image,
        target_class_idx,
        batch_size=batch_size,
    )

    # Build the surrogate dataset: one indicator row per local group.
    num_samples = len(local_groups)
    X = np.zeros((num_samples, num_segments), dtype=np.float32)
    y = np.array(deltas, dtype=np.float32)

    # Fancy-indexed row fill of the indicator matrix (one assignment per row).
    for i, masked_segments in enumerate(local_groups):
        X[i, masked_segments] = 1.0

    # Lazy %-style args: skips string formatting entirely when INFO is disabled.
    logger.info("Created surrogate dataset: X shape %s, y shape %s", X.shape, y.shape)
    logger.info("Average delta: %.4f, std: %.4f", y.mean(), y.std())

    return X, y


def calculate_segment_scores(X: np.ndarray, y: np.ndarray) -> dict[int, float]:  # noqa: N803
    """Calculate neighborhood-smoothed segment importance scores from sampled deltas.

    This function computes the mean delta score for each segment across all
    surrogate samples where that segment was masked. It acts as a fast
    approximation of the segment's marginal contribution to the prediction.

    Args:
        X: Binary indicator matrix of shape [num_samples, num_segments].
            X[i, j] == 1 if segment j is masked in sample i, else 0.
        y: Delta scores array of shape [num_samples].

    Returns:
        Dict mapping segment_id -> averaged importance score.

    Raises:
        ValueError: If any segment was never masked in the surrogate dataset.
    """
    # Vectorized count of how many times each segment was masked
    counts = X.sum(axis=0)

    # Fail-fast validation for unmasked segments
    unmasked_indices = np.where(counts == 0)[0]
    if unmasked_indices.size > 0:
        raise ValueError(
            f"Segment(s) {unmasked_indices.tolist()} never appear in any local group. "
            "This suggests a bug in group generation or segment ID mapping."
        )

    # Vectorized mean calculation using matrix multiplication:
    # (y @ X)[j] is the summed delta over all samples that mask segment j.
    segment_means = (y @ X) / counts

    # Convert the numpy array results to the expected dictionary format
    scores = {int(i): float(score) for i, score in enumerate(segment_means)}

    if scores:
        score_values = list(scores.values())
        # getLogger returns the same module logger singleton; lazy %-args
        # avoid formatting cost when INFO logging is disabled.
        logging.getLogger(__name__).info(
            "Score range: [%.4f, %.4f]", min(score_values), max(score_values)
        )

    return scores