diff --git a/README.md b/README.md
index 17bee0b..e75e36f 100644
--- a/README.md
+++ b/README.md
@@ -101,8 +101,8 @@ ciao/
 │   │   ├── loader.py            # Path loaders
 │   │   ├── preprocessing.py     # Image preprocessing utilities
 │   │   └── segmentation.py      # Segmentation utilities (hex/square grids)
-│   ├── evaluation/              # Scoring and evaluation
-│   │   ├── surrogate.py         # Surrogate dataset creation and segment scoring
+│   ├── scoring/                 # Scoring
+│   │   ├── segments.py          # Surrogate dataset creation and segment scoring
 │   │   └── hyperpixel.py        # Hyperpixel evaluation and selection
 │   ├── explainer/               # Core explainer implementation
 │   │   └── ciao_explainer.py    # Main CIAO explainer class
diff --git a/ciao/algorithm/__init__.py b/ciao/algorithm/__init__.py
new file mode 100644
index 0000000..83cc50d
--- /dev/null
+++ b/ciao/algorithm/__init__.py
@@ -0,0 +1,9 @@
+"""CIAO algorithm implementations."""
+
+from ciao.algorithm.bitmask_graph import get_frontier, iter_bits
+
+
+__all__ = [
+    "get_frontier",
+    "iter_bits",
+]
diff --git a/ciao/algorithm/bitmask_graph.py b/ciao/algorithm/bitmask_graph.py
new file mode 100644
index 0000000..f270f8b
--- /dev/null
+++ b/ciao/algorithm/bitmask_graph.py
@@ -0,0 +1,55 @@
+"""Bitmask-based graph utilities for efficient segment manipulation.
+
+This module will provide low-level primitives for working with graph structures
+represented as integer bitmasks, where each bit represents a node/segment.
+"""
+
+from collections.abc import Iterator
+
+
+def iter_bits(mask: int) -> Iterator[int]:
+    """Iterate over set bits in a mask using low-bit isolation.
+
+    Yields node IDs in arbitrary order (depends on bit positions).
+    Performance: O(k) where k is the number of set bits.
+
+    Example:
+        mask = 0b10110  # bits 1, 2, 4 are set
+        list(iter_bits(mask))  # [1, 2, 4]
+    """
+    temp = mask
+    while temp:
+        low_bit = temp & -temp
+        node_id = low_bit.bit_length() - 1
+        yield node_id
+        temp ^= low_bit
+
+
+def get_frontier(mask: int, adj_masks: tuple[int, ...], used_mask: int) -> int:
+    """Compute the expansion frontier (valid neighbors) for graph traversal.
+
+    The frontier is the set of segments adjacent to the current structure
+    that can be added in the next step.
+
+    A segment is in the frontier if:
+    - It is adjacent to at least one segment in the current mask
+    - It is NOT already in the current mask
+    - It is NOT in the used_mask (respects global exclusion constraints)
+
+    Args:
+        mask: Bitmask of current structure/connected component
+        adj_masks: Tuple of adjacency bitmasks (adj_masks[i] = neighbors of segment i)
+        used_mask: Bitmask of globally excluded segments
+
+    Returns:
+        Bitmask of valid frontier segments
+    """
+    frontier = 0
+
+    for node_id in iter_bits(mask):
+        frontier |= adj_masks[node_id]
+
+    frontier &= ~mask
+    frontier &= ~used_mask
+
+    return frontier
diff --git a/ciao/model/__init__.py b/ciao/model/__init__.py
new file mode 100644
index 0000000..dd425b4
--- /dev/null
+++ b/ciao/model/__init__.py
@@ -0,0 +1,6 @@
+"""Model prediction utilities for CIAO."""
+
+from ciao.model.predictor import ModelPredictor
+
+
+__all__ = ["ModelPredictor"]
diff --git a/ciao/model/predictor.py b/ciao/model/predictor.py
new file mode 100644
index 0000000..5efd700
--- /dev/null
+++ b/ciao/model/predictor.py
@@ -0,0 +1,37 @@
+import torch
+
+
+class ModelPredictor:
+    """Handles model predictions and class information for the CIAO explainer."""
+
+    def __init__(self, model: torch.nn.Module, class_names: list[str]) -> None:
+        self.model = model
+        self.class_names = class_names
+
+        # Ensure deterministic inference by disabling Dropout and freezing BatchNorm
+        self.model.eval()
+
+        # Robustly determine the device (fall back to CPU if model has no parameters)
+        try:
+            self.device = next(model.parameters()).device
+        except StopIteration:
+            self.device = torch.device("cpu")
+
+    def get_predictions(self, input_batch: torch.Tensor) -> torch.Tensor:
+        """Get model predictions (returns probabilities)."""
+        input_batch = input_batch.to(self.device)
+
+        with torch.no_grad():
+            outputs = self.model(input_batch)
+            probabilities = torch.nn.functional.softmax(outputs, dim=1)
+            return probabilities
+
+    def get_class_logit_batch(
+        self, input_batch: torch.Tensor, target_class_idx: int
+    ) -> torch.Tensor:
+        """Get raw logits for a specific target class across a batch of images."""
+        input_batch = input_batch.to(self.device)
+
+        with torch.no_grad():
+            outputs = self.model(input_batch)
+            return outputs[:, target_class_idx]
diff --git a/ciao/scoring/__init__.py b/ciao/scoring/__init__.py
new file mode 100644
index 0000000..680d121
--- /dev/null
+++ b/ciao/scoring/__init__.py
@@ -0,0 +1,18 @@
+"""Evaluation and scoring utilities for segments and hyperpixels."""
+
+from ciao.scoring.hyperpixel import (
+    calculate_hyperpixel_deltas,
+    select_top_hyperpixels,
+)
+from ciao.scoring.segments import (
+    calculate_segment_scores,
+    create_surrogate_dataset,
+)
+
+
+__all__ = [
+    "calculate_hyperpixel_deltas",
+    "calculate_segment_scores",
+    "create_surrogate_dataset",
+    "select_top_hyperpixels",
+]
diff --git a/ciao/scoring/hyperpixel.py b/ciao/scoring/hyperpixel.py
new file mode 100644
index 0000000..39c2580
--- /dev/null
+++ b/ciao/scoring/hyperpixel.py
@@ -0,0 +1,115 @@
+import numpy as np
+import torch
+
+from ciao.model.predictor import ModelPredictor
+
+
+def calculate_hyperpixel_deltas(
+    predictor: ModelPredictor,
+    input_batch: torch.Tensor,
+    segments: np.ndarray,
+    hyperpixel_segment_ids_list: list[list[int]],
+    replacement_image: torch.Tensor,
+    target_class_idx: int,
+    batch_size: int = 64,
+) -> list[float]:
+    """Calculate masking deltas for hyperpixel candidates using batched inference.
+
+    Handles internal batching to prevent memory overflow with large path counts.
+
+    Args:
+        predictor: ModelPredictor instance
+        input_batch: Input tensor batch [1, C, H, W]
+        segments: Pixel-to-segment mapping array [H, W]
+        hyperpixel_segment_ids_list: List of segment ID lists, e.g. [[1,2,3], [4,5,6]]
+        replacement_image: Replacement tensor [C, H, W]
+        target_class_idx: Target class index
+        batch_size: Batch size
+
+    Returns:
+        List[float]: Delta scores for each candidate
+    """
+    if not hyperpixel_segment_ids_list:
+        return []
+
+    # Validate all segment lists are non-empty
+    for i, segment_ids in enumerate(hyperpixel_segment_ids_list):
+        if not segment_ids:
+            raise ValueError(f"Empty segment list at index {i}")
+
+    if input_batch.dim() != 4 or input_batch.shape[0] != 1:
+        raise ValueError(
+            f"input_batch must have shape [1, C, H, W], got {tuple(input_batch.shape)}"
+        )
+
+    expected_shape = input_batch.shape[1:]
+    if tuple(replacement_image.shape) != tuple(expected_shape):
+        raise ValueError(
+            "replacement_image must have shape [C, H, W] matching input_batch, "
+            f"got {tuple(replacement_image.shape)} vs expected {tuple(expected_shape)}"
+        )
+
+    # Move tensors to the predictor's device once to avoid repeated transfers.
+    # Align replacement_image dtype with input_batch to prevent torch.where errors.
+    input_batch = input_batch.to(predictor.device)
+    replacement_image = replacement_image.to(
+        device=predictor.device, dtype=input_batch.dtype
+    )
+
+    with torch.no_grad():
+        original_logit = predictor.get_class_logit_batch(input_batch, target_class_idx)[
+            0
+        ].item()
+
+    gpu_segments = torch.from_numpy(segments).to(predictor.device)
+
+    all_deltas = []
+    num_masks = len(hyperpixel_segment_ids_list)
+
+    for batch_start in range(0, num_masks, batch_size):
+        batch_end = min(batch_start + batch_size, num_masks)
+        current_batch_size = batch_end - batch_start
+
+        # Clone on GPU directly
+        batch_inputs = input_batch.repeat(current_batch_size, 1, 1, 1)
+
+        # Build per-candidate boolean masks (one torch.isin pass per candidate)
+        mask_list = []
+        for segment_ids in hyperpixel_segment_ids_list[batch_start:batch_end]:
+            target_ids = torch.tensor(
+                segment_ids, dtype=gpu_segments.dtype, device=predictor.device
+            )
+            mask_list.append(torch.isin(gpu_segments, target_ids))
+
+        # mask_tensor shape: [batch_size, H, W]
+        mask_tensor = torch.stack(mask_list)
+
+        # Apply masks using a single broadcasted operation
+        batch_inputs = torch.where(
+            mask_tensor.unsqueeze(1),  # [batch_size, 1, H, W]
+            replacement_image.unsqueeze(0),  # [1, C, H, W]
+            batch_inputs,  # [batch_size, C, H, W]
+        )
+
+        masked_logits = predictor.get_class_logit_batch(
+            batch_inputs, target_class_idx
+        )
+        batch_deltas = [
+            original_logit - masked_logit.item() for masked_logit in masked_logits
+        ]
+        all_deltas.extend(batch_deltas)
+
+        del batch_inputs, masked_logits, mask_tensor
+
+    return all_deltas
+
+
+def select_top_hyperpixels(
+    hyperpixels: list[dict[str, object]], max_hyperpixels: int = 10
+) -> list[dict[str, object]]:
+    """Select top hyperpixels by their primary algorithm-specific score."""
+    return sorted(
+        hyperpixels,
+        key=lambda hp: abs(hp["hyperpixel_score"]),  # type: ignore[arg-type]
+        reverse=True,
+    )[:max_hyperpixels]
diff --git a/ciao/scoring/segments.py b/ciao/scoring/segments.py
new file mode 100644
index 0000000..a4aa8ed
--- /dev/null
+++ b/ciao/scoring/segments.py
@@ -0,0 +1,136 @@
+import logging
+
+import numpy as np
+import torch
+
+from ciao.algorithm.bitmask_graph import get_frontier, iter_bits
+from ciao.model.predictor import ModelPredictor
+from ciao.scoring.hyperpixel import calculate_hyperpixel_deltas
+
+
+logger = logging.getLogger(__name__)
+
+
+def create_surrogate_dataset(
+    predictor: ModelPredictor,
+    input_batch: torch.Tensor,
+    replacement_image: torch.Tensor,
+    segments: np.ndarray,
+    adj_masks: tuple[int, ...],
+    target_class_idx: int,
+    neighborhood_distance: int = 1,
+    batch_size: int = 16,
+) -> tuple[np.ndarray, np.ndarray]:
+    """Create surrogate dataset for interpretability.
+
+    Each row represents one masking operation:
+    - Features (X): Binary indicator vector [num_segments] - 1 if segment was masked, 0 otherwise
+    - Target (y): Delta score (original_logit - masked_logit)
+
+    This dataset can be used for:
+    - Computing segment importance scores
+    - Training interpretable models (like LIME does)
+    - Analyzing masking effects
+
+    Args:
+        predictor: ModelPredictor instance
+        input_batch: Input tensor batch
+        replacement_image: Replacement tensor [C, H, W]
+        segments: Pixel-to-segment mapping array [H, W]
+        adj_masks: Tuple of adjacency bitmasks where bit i indicates neighbor i
+        target_class_idx: Target class index
+        neighborhood_distance: Distance for neighborhood masking
+        batch_size: Batch size for processing segments
+
+    Returns:
+        X: Binary indicator matrix [num_samples, num_segments]
+        y: Delta scores array [num_samples]
+    """
+    # BFS algorithm using low-level bitmask graph operations
+    local_groups = []
+    num_segments = len(adj_masks)
+
+    for segment_id in range(num_segments):
+        visited_mask = 1 << segment_id
+        current_layer_mask = visited_mask
+
+        for _ in range(neighborhood_distance):
+            next_layer_mask = get_frontier(
+                mask=current_layer_mask, adj_masks=adj_masks, used_mask=visited_mask
+            )
+
+            # Early exit if we reached the boundary of the isolated graph component
+            if not next_layer_mask:
+                break
+
+            visited_mask |= next_layer_mask
+            current_layer_mask = next_layer_mask
+
+        local_groups.append(list(iter_bits(visited_mask)))
+
+    # Calculate deltas for all local groups
+    deltas = calculate_hyperpixel_deltas(
+        predictor,
+        input_batch,
+        segments,
+        local_groups,
+        replacement_image,
+        target_class_idx,
+        batch_size=batch_size,
+    )
+
+    # Create surrogate dataset
+    num_samples = len(local_groups)
+    X = np.zeros((num_samples, num_segments), dtype=np.float32)
+    y = np.array(deltas, dtype=np.float32)
+
+    # Fast vectorized indicator matrix filling
+    for i, masked_segments in enumerate(local_groups):
+        X[i, masked_segments] = 1.0
+
+    logger.info(f"Created surrogate dataset: X shape {X.shape}, y shape {y.shape}")
+    logger.info(f"Average delta: {y.mean():.4f}, std: {y.std():.4f}")
+
+    return X, y
+
+
+def calculate_segment_scores(X: np.ndarray, y: np.ndarray) -> dict[int, float]:  # noqa: N803
+    """Calculate neighborhood-smoothed segment importance scores from sampled deltas.
+
+    This function computes the mean delta score for each segment across all
+    surrogate samples where that segment was masked. It acts as a fast
+    approximation of the segment's marginal contribution to the prediction.
+
+    Args:
+        X: Binary indicator matrix of shape [num_samples, num_segments].
+            X[i, j] == 1 if segment j is masked in sample i, else 0.
+        y: Delta scores array of shape [num_samples].
+
+    Returns:
+        Dict mapping segment_id -> averaged importance score.
+
+    Raises:
+        ValueError: If any segment was never masked in the surrogate dataset.
+    """
+    # Vectorized count of how many times each segment was masked
+    counts = X.sum(axis=0)
+
+    # Fail-fast validation for unmasked segments
+    unmasked_indices = np.where(counts == 0)[0]
+    if unmasked_indices.size > 0:
+        raise ValueError(
+            f"Segment(s) {unmasked_indices.tolist()} never appear in any local group. "
+            "This suggests a bug in group generation or segment ID mapping."
+        )
+
+    # Vectorized mean calculation using matrix multiplication
+    segment_means = (y @ X) / counts
+
+    # Convert the numpy array results to the expected dictionary format
+    scores = {int(i): float(score) for i, score in enumerate(segment_means)}
+
+    if scores:
+        score_values = list(scores.values())
+        logger.info(f"Score range: [{min(score_values):.4f}, {max(score_values):.4f}]")
+
+    return scores