feat: implement scoring module for segments #7
Status: Open. dhalmazna wants to merge 8 commits into `master` from `feat/segment-and-hyperpixel-scoring`.
Changes from all commits (8 commits):
- bc6d907 feat: add ModelPredictor class for model predictions (dhalmazna)
- ab2a08b chore: apply agents' suggestions (dhalmazna)
- 48cef9b feat: add scoring utilities for hyperpixels (dhalmazna)
- 35251f4 feat: add basic bitmask graph utilities (dhalmazna)
- a82814a feat: add segment scoring and surrogate dataset utilities (dhalmazna)
- 548986a docs: change README to use scoring/ instead of evaluation/ (dhalmazna)
- 4ad0d92 chore: apply agents' suggestions (dhalmazna)
- f30a530 fix: change hyperpixel score retrieval (dhalmazna)
New file (+9 lines):

```python
"""CIAO algorithm implementations."""

from ciao.algorithm.bitmask_graph import get_frontier, iter_bits

__all__ = [
    "get_frontier",
    "iter_bits",
]
```
New file (+55 lines; imported elsewhere in the PR as `ciao.algorithm.bitmask_graph`):

```python
"""Bitmask-based graph utilities for efficient segment manipulation.

This module will provide low-level primitives for working with graph structures
represented as integer bitmasks, where each bit represents a node/segment.
"""

from collections.abc import Iterator


def iter_bits(mask: int) -> Iterator[int]:
    """Iterate over set bits in a mask using low-bit isolation.

    Yields node IDs in arbitrary order (depends on bit positions).
    Performance: O(k) where k is the number of set bits.

    Example:
        mask = 0b10110  # bits 1, 2, 4 are set
        list(iter_bits(mask))  # [1, 2, 4]
    """
    temp = mask
    while temp:
        low_bit = temp & -temp
        node_id = low_bit.bit_length() - 1
        yield node_id
        temp ^= low_bit


def get_frontier(mask: int, adj_masks: tuple[int, ...], used_mask: int) -> int:
    """Compute the expansion frontier (valid neighbors) for graph traversal.

    The frontier is the set of segments adjacent to the current structure
    that can be added in the next step.

    A segment is in the frontier if:
    - It is adjacent to at least one segment in the current mask
    - It is NOT already in the current mask
    - It is NOT in the used_mask (respects global exclusion constraints)

    Args:
        mask: Bitmask of current structure/connected component
        adj_masks: Tuple of adjacency bitmasks (adj_masks[i] = neighbors of segment i)
        used_mask: Bitmask of globally excluded segments

    Returns:
        Bitmask of valid frontier segments
    """
    frontier = 0

    for node_id in iter_bits(mask):
        frontier |= adj_masks[node_id]

    frontier &= ~mask
    frontier &= ~used_mask

    return frontier
```
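To make the bitmask conventions concrete, here is a small self-contained sketch on a hypothetical 4-node path graph (0–1–2–3). The helpers are compact copies of the two functions above so the snippet runs on its own; the adjacency values are made up for the example.

```python
def iter_bits(mask):
    # Yield indices of set bits via low-bit isolation (copy of the helper above)
    while mask:
        low = mask & -mask
        yield low.bit_length() - 1
        mask ^= low

def get_frontier(mask, adj_masks, used_mask):
    # Union of neighbors of all set bits, minus the mask itself and exclusions
    frontier = 0
    for node_id in iter_bits(mask):
        frontier |= adj_masks[node_id]
    return frontier & ~mask & ~used_mask

# Path graph 0-1-2-3: adj_masks[i] encodes the neighbors of node i
adj_masks = (0b0010, 0b0101, 0b1010, 0b0100)

mask = 0b0010  # current structure: just node 1
print(sorted(iter_bits(get_frontier(mask, adj_masks, used_mask=0))))       # [0, 2]
print(sorted(iter_bits(get_frontier(mask, adj_masks, used_mask=0b0001))))  # [2]
```

The second call shows `used_mask` in action: excluding node 0 removes it from the frontier even though it is adjacent to node 1.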
New file (+6 lines):

```python
"""Model prediction utilities for CIAO."""

from ciao.model.predictor import ModelPredictor

__all__ = ["ModelPredictor"]
```
New file (+37 lines; imported as `ciao.model.predictor`):

```python
import torch


class ModelPredictor:
    """Handles model predictions and class information for the CIAO explainer."""

    def __init__(self, model: torch.nn.Module, class_names: list[str]) -> None:
        self.model = model
        self.class_names = class_names

        # Ensure deterministic inference by disabling Dropout and freezing BatchNorm
        self.model.eval()

        # Robustly determine the device (fall back to CPU if model has no parameters)
        try:
            self.device = next(model.parameters()).device
        except StopIteration:
            self.device = torch.device("cpu")

    def get_predictions(self, input_batch: torch.Tensor) -> torch.Tensor:
        """Get model predictions (returns probabilities)."""
        input_batch = input_batch.to(self.device)

        with torch.no_grad():
            outputs = self.model(input_batch)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
            return probabilities

    def get_class_logit_batch(
        self, input_batch: torch.Tensor, target_class_idx: int
    ) -> torch.Tensor:
        """Get raw logits for a specific target class across a batch of images."""
        input_batch = input_batch.to(self.device)

        with torch.no_grad():
            outputs = self.model(input_batch)
            return outputs[:, target_class_idx]
```
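The prediction path above (eval mode, device detection, `no_grad` softmax) can be exercised without the `ciao` package. This is a self-contained sketch that mirrors what `ModelPredictor.get_predictions` does, using a toy two-class linear model; the model and shapes are invented for the demo.

```python
import torch

# Toy stand-in "classifier": 4 features in, 2 class logits out
model = torch.nn.Linear(4, 2)
model.eval()  # deterministic inference, as ModelPredictor does in __init__

# Device detection as in ModelPredictor (next(...).device, CPU if no params)
device = next(model.parameters()).device

batch = torch.ones(3, 4).to(device)
with torch.no_grad():
    logits = model(batch)
    probs = torch.nn.functional.softmax(logits, dim=1)

print(probs.shape)  # torch.Size([3, 2]); each row is a distribution summing to 1
```

The values of `probs` depend on the randomly initialized weights, but the invariants tested here (shape, rows summing to 1) hold for any model wrapped this way.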
New file (+18 lines):

```python
"""Evaluation and scoring utilities for segments and hyperpixels."""

from ciao.scoring.hyperpixel import (
    calculate_hyperpixel_deltas,
    select_top_hyperpixels,
)
from ciao.scoring.segments import (
    calculate_segment_scores,
    create_surrogate_dataset,
)

__all__ = [
    "calculate_hyperpixel_deltas",
    "calculate_segment_scores",
    "create_surrogate_dataset",
    "select_top_hyperpixels",
]
```
New file (+115 lines; imported as `ciao.scoring.hyperpixel`):

```python
import numpy as np
import torch

from ciao.model.predictor import ModelPredictor


def calculate_hyperpixel_deltas(
    predictor: ModelPredictor,
    input_batch: torch.Tensor,
    segments: np.ndarray,
    hyperpixel_segment_ids_list: list[list[int]],
    replacement_image: torch.Tensor,
    target_class_idx: int,
    batch_size: int = 64,
) -> list[float]:
    """Calculate masking deltas for hyperpixel candidates using batched inference.

    Handles internal batching to prevent memory overflow with large path counts.

    Args:
        predictor: ModelPredictor instance
        input_batch: Input tensor batch [1, C, H, W]
        segments: Pixel-to-segment mapping array [H, W]
        hyperpixel_segment_ids_list: List of segment ID lists, e.g. [[1, 2, 3], [4, 5, 6]]
        replacement_image: Replacement tensor [C, H, W]
        target_class_idx: Target class index
        batch_size: Number of masked variants scored per forward pass

    Returns:
        List[float]: Delta scores for each candidate
    """
    if not hyperpixel_segment_ids_list:
        return []

    # Validate all segment lists are non-empty
    for i, segment_ids in enumerate(hyperpixel_segment_ids_list):
        if not segment_ids:
            raise ValueError(f"Empty segment list at index {i}")

    if input_batch.dim() != 4 or input_batch.shape[0] != 1:
        raise ValueError(
            f"input_batch must have shape [1, C, H, W], got {tuple(input_batch.shape)}"
        )

    expected_shape = input_batch.shape[1:]
    if tuple(replacement_image.shape) != tuple(expected_shape):
        raise ValueError(
            "replacement_image must have shape [C, H, W] matching input_batch, "
            f"got {tuple(replacement_image.shape)} vs expected {tuple(expected_shape)}"
        )

    # Move tensors to the predictor's device once to avoid repeated transfers.
    # Align replacement_image dtype with input_batch to prevent torch.where errors.
    input_batch = input_batch.to(predictor.device)
    replacement_image = replacement_image.to(
        device=predictor.device, dtype=input_batch.dtype
    )

    with torch.no_grad():
        original_logit = predictor.get_class_logit_batch(input_batch, target_class_idx)[
            0
        ].item()

    gpu_segments = torch.from_numpy(segments).to(predictor.device)

    all_deltas = []
    num_masks = len(hyperpixel_segment_ids_list)

    for batch_start in range(0, num_masks, batch_size):
        batch_end = min(batch_start + batch_size, num_masks)
        current_batch_size = batch_end - batch_start

        # Replicate the input on-device, one copy per candidate in this batch
        batch_inputs = input_batch.repeat(current_batch_size, 1, 1, 1)

        # Build one boolean pixel mask per candidate (torch.isin per candidate,
        # then stacked into a single tensor)
        mask_list = []
        for segment_ids in hyperpixel_segment_ids_list[batch_start:batch_end]:
            target_ids = torch.tensor(
                segment_ids, dtype=gpu_segments.dtype, device=predictor.device
            )
            mask_list.append(torch.isin(gpu_segments, target_ids))

        # mask_tensor shape: [batch_size, H, W]
        mask_tensor = torch.stack(mask_list)

        # Apply masks using a single broadcasted operation
        batch_inputs = torch.where(
            mask_tensor.unsqueeze(1),  # [batch_size, 1, H, W]
            replacement_image.unsqueeze(0),  # [1, C, H, W]
            batch_inputs,  # [batch_size, C, H, W]
        )

        masked_logits = predictor.get_class_logit_batch(
            batch_inputs, target_class_idx
        )
        batch_deltas = [
            original_logit - masked_logit.item() for masked_logit in masked_logits
        ]
        all_deltas.extend(batch_deltas)

        del batch_inputs, masked_logits, mask_tensor

    return all_deltas


def select_top_hyperpixels(
    hyperpixels: list[dict[str, object]], max_hyperpixels: int = 10
) -> list[dict[str, object]]:
    """Select top hyperpixels by their primary algorithm-specific score."""
    return sorted(
        hyperpixels,
        key=lambda hp: abs(hp["hyperpixel_score"]),  # type: ignore[arg-type]
        reverse=True,
    )[:max_hyperpixels]
```
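`select_top_hyperpixels` ranks by absolute score, so strongly negative deltas outrank weakly positive ones. A self-contained sketch with made-up hyperpixel dicts (only the `hyperpixel_score` key matters; the function body is copied from above so the snippet runs standalone):

```python
def select_top_hyperpixels(hyperpixels, max_hyperpixels=10):
    # Rank by magnitude of the score, keep the strongest candidates
    return sorted(
        hyperpixels, key=lambda hp: abs(hp["hyperpixel_score"]), reverse=True
    )[:max_hyperpixels]

hyperpixels = [
    {"id": 0, "hyperpixel_score": 0.3},
    {"id": 1, "hyperpixel_score": -0.9},  # large negative delta still ranks first
    {"id": 2, "hyperpixel_score": 0.5},
]

top2 = select_top_hyperpixels(hyperpixels, max_hyperpixels=2)
print([hp["id"] for hp in top2])  # [1, 2]
```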
New file (+136 lines; imported as `ciao.scoring.segments`):

```python
import logging

import numpy as np
import torch

from ciao.algorithm.bitmask_graph import get_frontier, iter_bits
from ciao.model.predictor import ModelPredictor
from ciao.scoring.hyperpixel import calculate_hyperpixel_deltas


logger = logging.getLogger(__name__)


def create_surrogate_dataset(
    predictor: ModelPredictor,
    input_batch: torch.Tensor,
    replacement_image: torch.Tensor,
    segments: np.ndarray,
    adj_masks: tuple[int, ...],
    target_class_idx: int,
    neighborhood_distance: int = 1,
    batch_size: int = 16,
) -> tuple[np.ndarray, np.ndarray]:
    """Create surrogate dataset for interpretability.

    Each row represents one masking operation:
    - Features (X): Binary indicator vector [num_segments] - 1 if segment was masked, 0 otherwise
    - Target (y): Delta score (original_logit - masked_logit)

    This dataset can be used for:
    - Computing segment importance scores
    - Training interpretable models (like LIME does)
    - Analyzing masking effects

    Args:
        predictor: ModelPredictor instance
        input_batch: Input tensor batch
        replacement_image: Replacement tensor [C, H, W]
        segments: Pixel-to-segment mapping array [H, W]
        adj_masks: Tuple of adjacency bitmasks where bit i indicates neighbor i
        target_class_idx: Target class index
        neighborhood_distance: Distance for neighborhood masking
        batch_size: Batch size for processing segments

    Returns:
        X: Binary indicator matrix [num_samples, num_segments]
        y: Delta scores array [num_samples]
    """
    # BFS expansion using the low-level bitmask graph operations
    local_groups = []
    num_segments = len(adj_masks)

    for segment_id in range(num_segments):
        visited_mask = 1 << segment_id
        current_layer_mask = visited_mask

        for _ in range(neighborhood_distance):
            next_layer_mask = get_frontier(
                mask=current_layer_mask, adj_masks=adj_masks, used_mask=visited_mask
            )

            # Early exit if we reached the boundary of the isolated graph component
            if not next_layer_mask:
                break

            visited_mask |= next_layer_mask
            current_layer_mask = next_layer_mask

        local_groups.append(list(iter_bits(visited_mask)))

    # Calculate deltas for all local groups
    deltas = calculate_hyperpixel_deltas(
        predictor,
        input_batch,
        segments,
        local_groups,
        replacement_image,
        target_class_idx,
        batch_size=batch_size,
    )

    # Create surrogate dataset
    num_samples = len(local_groups)
    X = np.zeros((num_samples, num_segments), dtype=np.float32)
    y = np.array(deltas, dtype=np.float32)

    # Fill the indicator matrix row by row via fancy indexing
    for i, masked_segments in enumerate(local_groups):
        X[i, masked_segments] = 1.0

    logger.info(f"Created surrogate dataset: X shape {X.shape}, y shape {y.shape}")
    logger.info(f"Average delta: {y.mean():.4f}, std: {y.std():.4f}")

    return X, y


def calculate_segment_scores(X: np.ndarray, y: np.ndarray) -> dict[int, float]:  # noqa: N803
    """Calculate neighborhood-smoothed segment importance scores from sampled deltas.

    This function computes the mean delta score for each segment across all
    surrogate samples where that segment was masked. It acts as a fast
    approximation of the segment's marginal contribution to the prediction.

    Args:
        X: Binary indicator matrix of shape [num_samples, num_segments].
            X[i, j] == 1 if segment j is masked in sample i, else 0.
        y: Delta scores array of shape [num_samples].

    Returns:
        Dict mapping segment_id -> averaged importance score.

    Raises:
        ValueError: If any segment was never masked in the surrogate dataset.
    """
    # Vectorized count of how many times each segment was masked
    counts = X.sum(axis=0)

    # Fail-fast validation for unmasked segments
    unmasked_indices = np.where(counts == 0)[0]
    if unmasked_indices.size > 0:
        raise ValueError(
            f"Segment(s) {unmasked_indices.tolist()} never appear in any local group. "
            "This suggests a bug in group generation or segment ID mapping."
        )

    # Vectorized mean calculation using matrix multiplication
    segment_means = (y @ X) / counts

    # Convert the numpy array results to the expected dictionary format
    scores = {int(i): float(score) for i, score in enumerate(segment_means)}

    if scores:
        score_values = list(scores.values())
        logger.info(f"Score range: [{min(score_values):.4f}, {max(score_values):.4f}]")

    return scores
```
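The core of `calculate_segment_scores` is the `(y @ X) / counts` step: `y @ X` sums each segment's deltas across the samples that masked it, and dividing by the per-segment mask counts turns those sums into means. A minimal numpy sketch with invented toy numbers (2 segments, 3 samples):

```python
import numpy as np

# Toy surrogate dataset: sample i masks segment j where X[i, j] == 1,
# and y[i] is that sample's logit delta. Values are made up for the demo.
X = np.array([[1, 0],
              [1, 1],
              [0, 1]], dtype=np.float32)
y = np.array([0.2, 0.6, 0.4], dtype=np.float32)

counts = X.sum(axis=0)            # how often each segment was masked: [2., 2.]
segment_means = (y @ X) / counts  # per-segment average delta

print(segment_means)  # [0.4 0.5]
```

Segment 0 appears in samples 0 and 1, so its score is (0.2 + 0.6) / 2 = 0.4; segment 1 appears in samples 1 and 2, giving (0.6 + 0.4) / 2 = 0.5.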