From bc6d907c1c7078041c102b60439635ded15bef44 Mon Sep 17 00:00:00 2001 From: dhalmazna Date: Thu, 5 Mar 2026 11:20:22 +0100 Subject: [PATCH 1/8] feat: add ModelPredictor class for model predictions --- ciao/model/__init__.py | 6 ++++++ ciao/model/predictor.py | 44 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 ciao/model/__init__.py create mode 100644 ciao/model/predictor.py diff --git a/ciao/model/__init__.py b/ciao/model/__init__.py new file mode 100644 index 0000000..dd425b4 --- /dev/null +++ b/ciao/model/__init__.py @@ -0,0 +1,6 @@ +"""Model prediction utilities for CIAO.""" + +from ciao.model.predictor import ModelPredictor + + +__all__ = ["ModelPredictor"] diff --git a/ciao/model/predictor.py b/ciao/model/predictor.py new file mode 100644 index 0000000..bb75c3c --- /dev/null +++ b/ciao/model/predictor.py @@ -0,0 +1,44 @@ +import torch + + +class ModelPredictor: + """Handles model predictions and class information.""" + + def __init__(self, model: torch.nn.Module, class_names: list[str]) -> None: + self.model = model + self.class_names = class_names + self.device = next(model.parameters()).device + + def get_predictions(self, input_batch: torch.Tensor) -> torch.Tensor: + """Get model predictions (returns probabilities).""" + with torch.no_grad(): + outputs = self.model(input_batch) + probabilities = torch.nn.functional.softmax(outputs, dim=1) + return probabilities + + def predict_image( + self, input_batch: torch.Tensor, top_k: int = 5 + ) -> list[tuple[int, str, float]]: + """Get top-k predictions for an image.""" + probabilities = self.get_predictions(input_batch) + top_probs, top_indices = torch.topk(probabilities[0], top_k) + + results = [] + for i in range(top_k): + class_idx = int(top_indices[i].item()) + prob = float(top_probs[i].item()) + class_name = ( + self.class_names[class_idx] + if class_idx < len(self.class_names) + else f"class_{class_idx}" + ) + results.append((class_idx, class_name, prob)) + return results + + def get_class_logit_batch( + self, input_batch: torch.Tensor, target_class_idx: int + ) -> torch.Tensor: + """Get logits for a batch of images - optimized for batched inference (directly from model outputs).""" + with torch.no_grad(): + outputs = self.model(input_batch) + return outputs[:, target_class_idx] From ab2a08bbf99e3b17e3ffff77d7ac058daa205752 Mon Sep 17 00:00:00 2001 From: dhalmazna Date: Tue, 10 Mar 2026 14:07:17 +0100 Subject: [PATCH 2/8] chore: apply agents' suggestions --- ciao/model/predictor.py | 37 +++++++++++++++---------------------- 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/ciao/model/predictor.py b/ciao/model/predictor.py index bb75c3c..5efd700 100644 --- a/ciao/model/predictor.py +++ b/ciao/model/predictor.py @@ -2,43 +2,36 @@ class ModelPredictor: - """Handles model predictions and class information.""" + """Handles model predictions and class information for the CIAO explainer.""" def __init__(self, model: torch.nn.Module, class_names: list[str]) -> None: self.model = model self.class_names = class_names - self.device = next(model.parameters()).device + + # Ensure deterministic inference by disabling Dropout and freezing BatchNorm + self.model.eval() + + # Robustly determine the device (fall back to CPU if model has no parameters) + try: + self.device = next(model.parameters()).device + except StopIteration: + self.device = torch.device("cpu") def get_predictions(self, input_batch: torch.Tensor) -> torch.Tensor: """Get model predictions (returns probabilities).""" + input_batch = input_batch.to(self.device) + with torch.no_grad(): outputs = self.model(input_batch) probabilities = torch.nn.functional.softmax(outputs, dim=1) return probabilities - def predict_image( - self, input_batch: torch.Tensor, top_k: int = 5 - ) -> list[tuple[int, str, float]]: - """Get top-k predictions for an image.""" - probabilities = self.get_predictions(input_batch) - top_probs, top_indices = torch.topk(probabilities[0], top_k) - - results = [] - for i in range(top_k): - class_idx = int(top_indices[i].item()) - prob = float(top_probs[i].item()) - class_name = ( - self.class_names[class_idx] - if class_idx < len(self.class_names) - else f"class_{class_idx}" - ) - results.append((class_idx, class_name, prob)) - return results - def get_class_logit_batch( self, input_batch: torch.Tensor, target_class_idx: int ) -> torch.Tensor: - """Get logits for a batch of images - optimized for batched inference (directly from model outputs).""" + """Get raw logits for a specific target class across a batch of images.""" + input_batch = input_batch.to(self.device) + with torch.no_grad(): outputs = self.model(input_batch) return outputs[:, target_class_idx] From 48cef9b94f94c212e4b896824a3671f8fdcba715 Mon Sep 17 00:00:00 2001 From: dhalmazna Date: Fri, 13 Mar 2026 15:11:46 +0100 Subject: [PATCH 3/8] feat: add scoring utilities for hyperpixels --- ciao/scoring/__init__.py | 12 +++++ ciao/scoring/hyperpixel.py | 107 +++++++++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+) create mode 100644 ciao/scoring/__init__.py create mode 100644 ciao/scoring/hyperpixel.py diff --git a/ciao/scoring/__init__.py b/ciao/scoring/__init__.py new file mode 100644 index 0000000..92d01d7 --- /dev/null +++ b/ciao/scoring/__init__.py @@ -0,0 +1,12 @@ +"""Evaluation and scoring utilities for segments and hyperpixels.""" + +from ciao.scoring.hyperpixel import ( + calculate_hyperpixel_deltas, + select_top_hyperpixels, +) + + +__all__ = [ + "calculate_hyperpixel_deltas", + "select_top_hyperpixels", +] diff --git a/ciao/scoring/hyperpixel.py b/ciao/scoring/hyperpixel.py new file mode 100644 index 0000000..7b8bb1a --- /dev/null +++ b/ciao/scoring/hyperpixel.py @@ -0,0 +1,107 @@ +import numpy as np +import torch + +from ciao.model.predictor import ModelPredictor + + +def calculate_hyperpixel_deltas( + predictor: ModelPredictor, + input_batch: torch.Tensor, + segments: np.ndarray, + hyperpixel_segment_ids_list: list[list[int]], + replacement_image: torch.Tensor, + target_class_idx: int, + batch_size: int = 64, +) -> list[float]: + """Calculate masking deltas for hyperpixel candidates using batched inference. + + Handles internal batching to prevent memory overflow with large path counts. + + Args: + predictor: ModelPredictor instance + input_batch: Input tensor batch [1, C, H, W] + segments: Pixel-to-segment mapping array [H, W] + hyperpixel_segment_ids_list: List of segment ID lists, e.g. [[1,2,3], [4,5,6]] + replacement_image: Replacement tensor [C, H, W] + target_class_idx: Target class index + batch_size: Batch size + + Returns: + List[float]: Delta scores for each candidate + """ + if not hyperpixel_segment_ids_list: + return [] + + # Validate all segment lists are non-empty + for i, segment_ids in enumerate(hyperpixel_segment_ids_list): + if not segment_ids: + raise ValueError(f"Empty segment list at index {i}") + + if input_batch.dim() != 4 or input_batch.shape[0] != 1: + raise ValueError( + f"input_batch must have shape [1, C, H, W], got {tuple(input_batch.shape)}" + ) + + expected_shape = input_batch.shape[1:] + if tuple(replacement_image.shape) != tuple(expected_shape): + raise ValueError( + "replacement_image must have shape [C, H, W] matching input_batch, " + f"got {tuple(replacement_image.shape)} vs expected {tuple(expected_shape)}" + ) + + replacement_image = replacement_image.to(predictor.device) + + with torch.no_grad(): + original_logit = predictor.get_class_logit_batch(input_batch, target_class_idx)[ + 0 + ].item() + + gpu_segments = torch.from_numpy(segments).to(predictor.device) + + all_deltas = [] + num_masks = len(hyperpixel_segment_ids_list) + + for batch_start in range(0, num_masks, batch_size): + batch_end = min(batch_start + batch_size, num_masks) + current_batch_size = batch_end - batch_start + + batch_inputs = input_batch.repeat(current_batch_size, 1, 1, 1) + + for i, segment_ids in enumerate( + hyperpixel_segment_ids_list[batch_start:batch_end] + ): + target_ids = torch.tensor( + segment_ids, dtype=gpu_segments.dtype, device=predictor.device + ) + combined_mask = torch.isin(gpu_segments, target_ids) + + # Apply mask with proper broadcasting + batch_inputs[i] = torch.where( + combined_mask.unsqueeze(0), + replacement_image, + batch_inputs[i], + ) + + masked_logits = predictor.get_class_logit_batch( + batch_inputs, target_class_idx + ) + batch_deltas = [ + original_logit - masked_logit.item() for masked_logit in masked_logits + ] + all_deltas.extend(batch_deltas) + + del batch_inputs, masked_logits + + return all_deltas + + +def select_top_hyperpixels( + hyperpixels: list[dict[str, object]], max_hyperpixels: int = 10 +) -> list[dict[str, object]]: + """Select top hyperpixels by their primary algorithm-specific score.""" + # Use hyperpixel_score + return sorted( + hyperpixels, + key=lambda hp: abs(hp.get("hyperpixel_score", 0)), # type: ignore[arg-type] + reverse=True, + )[:max_hyperpixels] From 35251f4f031dc35f9b9954735bfab8086a20fef2 Mon Sep 17 00:00:00 2001 From: dhalmazna Date: Fri, 13 Mar 2026 15:53:56 +0100 Subject: [PATCH 4/8] feat: add basic bitmask graph utilities --- ciao/algorithm/__init__.py | 9 ++++++ ciao/algorithm/bitmask_graph.py | 55 +++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) create mode 100644 ciao/algorithm/__init__.py create mode 100644 ciao/algorithm/bitmask_graph.py diff --git a/ciao/algorithm/__init__.py b/ciao/algorithm/__init__.py new file mode 100644 index 0000000..83cc50d --- /dev/null +++ b/ciao/algorithm/__init__.py @@ -0,0 +1,9 @@ +"""CIAO algorithm implementations.""" + +from ciao.algorithm.bitmask_graph import get_frontier, iter_bits + + +__all__ = [ + "get_frontier", + "iter_bits", +] diff --git a/ciao/algorithm/bitmask_graph.py b/ciao/algorithm/bitmask_graph.py new file mode 100644 index 0000000..f270f8b --- /dev/null +++ b/ciao/algorithm/bitmask_graph.py @@ -0,0 +1,55 @@ +"""Bitmask-based graph utilities for efficient segment manipulation. + +This module will provide low-level primitives for working with graph structures +represented as integer bitmasks, where each bit represents a node/segment. +""" + +from collections.abc import Iterator + + +def iter_bits(mask: int) -> Iterator[int]: + """Iterate over set bits in a mask using low-bit isolation. + + Yields node IDs in arbitrary order (depends on bit positions). + Performance: O(k) where k is the number of set bits. + + Example: + mask = 0b10110 # bits 1, 2, 4 are set + list(iter_bits(mask)) # [1, 2, 4] + """ + temp = mask + while temp: + low_bit = temp & -temp + node_id = low_bit.bit_length() - 1 + yield node_id + temp ^= low_bit + + +def get_frontier(mask: int, adj_masks: tuple[int, ...], used_mask: int) -> int: + """Compute the expansion frontier (valid neighbors) for graph traversal. + + The frontier is the set of segments adjacent to the current structure + that can be added in the next step. + + A segment is in the frontier if: + - It is adjacent to at least one segment in the current mask + - It is NOT already in the current mask + - It is NOT in the used_mask (respects global exclusion constraints) + + Args: + mask: Bitmask of current structure/connected component + adj_masks: Tuple of adjacency bitmasks (adj_masks[i] = neighbors of segment i) + used_mask: Bitmask of globally excluded segments + + Returns: + Bitmask of valid frontier segments + """ + frontier = 0 + + for node_id in iter_bits(mask): + frontier |= adj_masks[node_id] + + frontier &= ~mask + frontier &= ~used_mask + + return frontier From a82814ac04a5bf9691fffbb647cf9505f4e1b534 Mon Sep 17 00:00:00 2001 From: dhalmazna Date: Fri, 13 Mar 2026 15:54:38 +0100 Subject: [PATCH 5/8] feat: add segment scoring and surrogate dataset utilities --- ciao/scoring/__init__.py | 6 ++ ciao/scoring/segments.py | 146 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 152 insertions(+) create mode 100644 ciao/scoring/segments.py diff --git a/ciao/scoring/__init__.py b/ciao/scoring/__init__.py index 92d01d7..680d121 100644 --- a/ciao/scoring/__init__.py +++ b/ciao/scoring/__init__.py @@ -4,9 +4,15 @@ calculate_hyperpixel_deltas, select_top_hyperpixels, ) +from ciao.scoring.segments import ( + calculate_segment_scores, + create_surrogate_dataset, +) __all__ = [ "calculate_hyperpixel_deltas", + "calculate_segment_scores", + "create_surrogate_dataset", "select_top_hyperpixels", ] diff --git a/ciao/scoring/segments.py b/ciao/scoring/segments.py new file mode 100644 index 0000000..e93b21f --- /dev/null +++ b/ciao/scoring/segments.py @@ -0,0 +1,146 @@ +import logging + +import numpy as np +import torch + +from ciao.algorithm.bitmask_graph import get_frontier, iter_bits +from ciao.model.predictor import ModelPredictor +from ciao.scoring.hyperpixel import calculate_hyperpixel_deltas + + +logger = logging.getLogger(__name__) + + +def create_surrogate_dataset( + predictor: ModelPredictor, + input_batch: torch.Tensor, + replacement_image: torch.Tensor, + segments: np.ndarray, + adj_masks: tuple[int, ...], + target_class_idx: int, + neighborhood_distance: int = 1, + batch_size: int = 16, +) -> tuple[np.ndarray, np.ndarray]: + """Create surrogate dataset for interpretability. + + Each row represents one masking operation: + - Features (X): Binary indicator vector [num_segments] - 1 if segment was masked, 0 otherwise + - Target (y): Delta score (original_logit - masked_logit) + + This dataset can be used for: + - Computing segment importance scores + - Training interpretable models (like LIME does) + - Analyzing masking effects + + Args: + predictor: ModelPredictor instance + input_batch: Input tensor batch + replacement_image: Replacement tensor [C, H, W] + segments: Pixel-to-segment mapping array [H, W] + adj_masks: Tuple of adjacency bitmasks where bit i indicates neighbor i + target_class_idx: Target class index + neighborhood_distance: Distance for neighborhood masking + batch_size: Batch size for processing segments + + Returns: + X: Binary indicator matrix [num_samples, num_segments] + y: Delta scores array [num_samples] + """ + # Get original logit + original_logit = predictor.get_class_logit_batch(input_batch, target_class_idx)[ + 0 + ].item() + logger.debug(f"Original logit: {original_logit}") + logger.debug( + f"Probability of class {target_class_idx}: " + f"{predictor.get_predictions(input_batch)[0, target_class_idx].item()}" + ) + + # BFS algorithm + local_groups = [] + num_segments = len(adj_masks) + + for segment_id in range(num_segments): + visited_mask = 1 << segment_id + current_layer_mask = visited_mask + + for _ in range(neighborhood_distance): + next_layer_mask = get_frontier( + mask=current_layer_mask, adj_masks=adj_masks, used_mask=visited_mask + ) + + # Early exit if we reached the boundary of the isolated graph component + if not next_layer_mask: + break + + visited_mask |= next_layer_mask + current_layer_mask = next_layer_mask + + local_groups.append(list(iter_bits(visited_mask))) + + # Calculate deltas for all local groups + deltas = calculate_hyperpixel_deltas( + predictor, + input_batch, + segments, + local_groups, + replacement_image, + target_class_idx, + batch_size=batch_size, + ) + + # Create surrogate dataset + num_samples = len(local_groups) + X = np.zeros((num_samples, num_segments), dtype=np.float32) + y = np.array(deltas, dtype=np.float32) + + # Fill indicator matrix + for i, masked_segments in enumerate(local_groups): + X[i, masked_segments] = 1.0 + + logger.info(f"Created surrogate dataset: X shape {X.shape}, y shape {y.shape}") + logger.info(f"Average delta: {y.mean():.4f}, std: {y.std():.4f}") + + return X, y + + +def calculate_segment_scores(X: np.ndarray, y: np.ndarray) -> dict[int, float]: # noqa: N803 + """Calculate neighborhood-smoothed segment importance scores from sampled deltas. + + This function computes the mean delta score for each segment across all + surrogate samples where that segment was masked. It acts as a fast + approximation of the segment's marginal contribution to the prediction. + + Args: + X: Binary indicator matrix of shape [num_samples, num_segments]. + X[i, j] == 1 if segment j is masked in sample i, else 0. + y: Delta scores array of shape [num_samples]. + + Returns: + Dict mapping segment_id -> averaged importance score. + + Raises: + ValueError: If any segment was never masked in the surrogate dataset. + """ + # Vectorized count of how many times each segment was masked + counts = X.sum(axis=0) + + # Fail-fast validation for unmasked segments + unmasked_indices = np.where(counts == 0)[0] + if unmasked_indices.size > 0: + raise ValueError( + f"Segment(s) {unmasked_indices.tolist()} never appear in any local group. " + "This suggests a bug in group generation or segment ID mapping." + ) + + # Vectorized mean calculation using matrix multiplication + segment_means = (y @ X) / counts + + # Convert the numpy array results to the expected dictionary format + scores = {int(i): float(score) for i, score in enumerate(segment_means)} + + if scores: + score_values = list(scores.values()) + logger.info(f"Score range: [{min(score_values):.4f}, {max(score_values):.4f}]") + + return scores From 548986a0db9cccbe547d64eff1ab4f3d0887bc5c Mon Sep 17 00:00:00 2001 From: dhalmazna Date: Fri, 13 Mar 2026 15:58:24 +0100 Subject: [PATCH 6/8] docs: change README to use scoring/ instead of evaluation/ --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 17bee0b..e75e36f 100644 --- a/README.md +++ b/README.md @@ -101,8 +101,8 @@ ciao/ │ │ ├── loader.py # Path loaders │ │ ├── preprocessing.py # Image preprocessing utilities │ │ └── segmentation.py # Segmentation utilities (hex/square grids) -│ ├── evaluation/ # Scoring and evaluation -│ │ ├── surrogate.py # Surrogate dataset creation and segment scoring +│ ├── scoring/ # Scoring +│ │ ├── segments.py # Surrogate dataset creation and segment scoring │ │ └── hyperpixel.py # Hyperpixel evaluation and selection │ ├── explainer/ # Core explainer implementation │ │ └── ciao_explainer.py # Main CIAO explainer class From 4ad0d92f96cf028703660747f842143257a178e8 Mon Sep 17 00:00:00 2001 From: dhalmazna Date: Fri, 13 Mar 2026 16:39:07 +0100 Subject: [PATCH 7/8] chore: apply agents' suggestions --- ciao/scoring/hyperpixel.py | 34 +++++++++++++++++++++------------- ciao/scoring/segments.py | 14 ++------------ 2 files changed, 23 insertions(+), 25 deletions(-) diff --git a/ciao/scoring/hyperpixel.py b/ciao/scoring/hyperpixel.py index 7b8bb1a..3bd0342 100644 --- a/ciao/scoring/hyperpixel.py +++ b/ciao/scoring/hyperpixel.py @@ -49,7 +49,12 @@ def calculate_hyperpixel_deltas( f"got {tuple(replacement_image.shape)} vs expected {tuple(expected_shape)}" ) - replacement_image = replacement_image.to(predictor.device) + # Move tensors to the predictor's device once to avoid repeated transfers. + # Align replacement_image dtype with input_batch to prevent torch.where errors. + input_batch = input_batch.to(predictor.device) + replacement_image = replacement_image.to( + device=predictor.device, dtype=input_batch.dtype + ) with torch.no_grad(): original_logit = predictor.get_class_logit_batch(input_batch, target_class_idx)[ @@ -65,22 +70,26 @@ def calculate_hyperpixel_deltas( batch_end = min(batch_start + batch_size, num_masks) current_batch_size = batch_end - batch_start + # Clone on GPU directly batch_inputs = input_batch.repeat(current_batch_size, 1, 1, 1) - for i, segment_ids in enumerate( - hyperpixel_segment_ids_list[batch_start:batch_end] - ): + # Fully vectorized mask creation + mask_list = [] + for segment_ids in hyperpixel_segment_ids_list[batch_start:batch_end]: target_ids = torch.tensor( segment_ids, dtype=gpu_segments.dtype, device=predictor.device ) - combined_mask = torch.isin(gpu_segments, target_ids) + mask_list.append(torch.isin(gpu_segments, target_ids)) - # Apply mask with proper broadcasting - batch_inputs[i] = torch.where( - combined_mask.unsqueeze(0), - replacement_image, - batch_inputs[i], - ) + # mask_tensor shape: [batch_size, H, W] + mask_tensor = torch.stack(mask_list) + + # Apply masks using a single broadcasted operation + batch_inputs = torch.where( + mask_tensor.unsqueeze(1), # [batch_size, 1, H, W] + replacement_image.unsqueeze(0), # [1, C, H, W] + batch_inputs, # [batch_size, C, H, W] + ) masked_logits = predictor.get_class_logit_batch( batch_inputs, target_class_idx @@ -90,7 +99,7 @@ def calculate_hyperpixel_deltas( ] all_deltas.extend(batch_deltas) - del batch_inputs, masked_logits + del batch_inputs, masked_logits, mask_tensor return all_deltas @@ -99,7 +108,6 @@ def select_top_hyperpixels( hyperpixels: list[dict[str, object]], max_hyperpixels: int = 10 ) -> list[dict[str, object]]: """Select top hyperpixels by their primary algorithm-specific score.""" - # Use hyperpixel_score return sorted( hyperpixels, key=lambda hp: abs(hp.get("hyperpixel_score", 0)), # type: ignore[arg-type] diff --git a/ciao/scoring/segments.py b/ciao/scoring/segments.py index e93b21f..a4aa8ed 100644 --- a/ciao/scoring/segments.py +++ b/ciao/scoring/segments.py @@ -46,17 +46,7 @@ def create_surrogate_dataset( X: Binary indicator matrix [num_samples, num_segments] y: Delta scores array [num_samples] """ - # Get original logit - original_logit = predictor.get_class_logit_batch(input_batch, target_class_idx)[ - 0 - ].item() - logger.debug(f"Original logit: {original_logit}") - logger.debug( - f"Probability of class {target_class_idx}: " - f"{predictor.get_predictions(input_batch)[0, target_class_idx].item()}" - ) - - # BFS algorithm + # BFS algorithm using low-level bitmask graph operations local_groups = [] num_segments = len(adj_masks) @@ -94,7 +84,7 @@ def create_surrogate_dataset( X = np.zeros((num_samples, num_segments), dtype=np.float32) y = np.array(deltas, dtype=np.float32) - # Fill indicator matrix + # Fast vectorized indicator matrix filling for i, masked_segments in enumerate(local_groups): X[i, masked_segments] = 1.0 From f30a5306f0c123da0ef07fb104978a44646a88ed Mon Sep 17 00:00:00 2001 From: dhalmazna Date: Sun, 15 Mar 2026 18:18:30 +0100 Subject: [PATCH 8/8] fix: change hyperpixel score retrieval --- ciao/scoring/hyperpixel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ciao/scoring/hyperpixel.py b/ciao/scoring/hyperpixel.py index 3bd0342..39c2580 100644 --- a/ciao/scoring/hyperpixel.py +++ b/ciao/scoring/hyperpixel.py @@ -110,6 +110,6 @@ def select_top_hyperpixels( """Select top hyperpixels by their primary algorithm-specific score.""" return sorted( hyperpixels, - key=lambda hp: abs(hp.get("hyperpixel_score", 0)), # type: ignore[arg-type] + key=lambda hp: abs(hp["hyperpixel_score"]), # type: ignore[arg-type] reverse=True, )[:max_hyperpixels]