From b78ec18746739c550df19dfa8d656bd3f505b470 Mon Sep 17 00:00:00 2001 From: dhalmazna Date: Sat, 14 Mar 2026 13:08:43 +0100 Subject: [PATCH 1/2] feat: enhance bitmask graph utilities and add search helper functions --- ciao/algorithm/__init__.py | 22 ++++++++- ciao/algorithm/bitmask_graph.py | 78 +++++++++++++++++++++++++++++++- ciao/algorithm/search_helpers.py | 52 +++++++++++++++++++++ 3 files changed, 150 insertions(+), 2 deletions(-) create mode 100644 ciao/algorithm/search_helpers.py diff --git a/ciao/algorithm/__init__.py b/ciao/algorithm/__init__.py index 83cc50d..4552852 100644 --- a/ciao/algorithm/__init__.py +++ b/ciao/algorithm/__init__.py @@ -1,9 +1,29 @@ """CIAO algorithm implementations.""" -from ciao.algorithm.bitmask_graph import get_frontier, iter_bits +from ciao.algorithm.bitmask_graph import ( + add_node, + get_frontier, + has_node, + iter_bits, + mask_to_ids, + pick_random_set_bit, + remove_node, + sample_connected_superset, +) +from ciao.algorithm.search_helpers import evaluate_masks, is_terminal __all__ = [ + # Bitmask graph utilities + "add_node", + # Shared MCTS/MCGS utilities + "evaluate_masks", "get_frontier", + "has_node", + "is_terminal", "iter_bits", + "mask_to_ids", + "pick_random_set_bit", + "remove_node", + "sample_connected_superset", ] diff --git a/ciao/algorithm/bitmask_graph.py b/ciao/algorithm/bitmask_graph.py index f270f8b..a768f9d 100644 --- a/ciao/algorithm/bitmask_graph.py +++ b/ciao/algorithm/bitmask_graph.py @@ -1,12 +1,18 @@ """Bitmask-based graph utilities for efficient segment manipulation. -This module will provide low-level primitives for working with graph structures +This module provides low-level primitives for working with graph structures represented as integer bitmasks, where each bit represents a node/segment. """ +import random from collections.abc import Iterator +def mask_to_ids(mask: int) -> list[int]: + """Convert integer bitmask to list of segment indices.""" + return [i for i in range(mask.bit_length()) if (mask >> i) & 1] + + def iter_bits(mask: int) -> Iterator[int]: """Iterate over set bits in a mask using low-bit isolation. @@ -25,6 +31,39 @@ def iter_bits(mask: int) -> Iterator[int]: temp ^= low_bit +def has_node(mask: int, node: int) -> bool: + """Test if a node is present in the mask.""" + return bool(mask & (1 << node)) + + +def add_node(mask: int, node: int) -> int: + """Add a node to the mask.""" + return mask | (1 << node) + + +def remove_node(mask: int, node: int) -> int: + """Remove a node from the mask.""" + return mask & ~(1 << node) + + +def pick_random_set_bit(mask: int) -> int: + """Select a random set bit from the mask in O(N) where N is the index of the bit. + + Without allocating a list. Efficient for sparse masks. + """ + count = mask.bit_count() + if count == 0: + return -1 + + which = random.randrange(count) + + temp = mask + for _ in range(which): + temp &= temp - 1 # Clear lowest set bit + + return (temp & -temp).bit_length() - 1 + + def get_frontier(mask: int, adj_masks: tuple[int, ...], used_mask: int) -> int: """Compute the expansion frontier (valid neighbors) for graph traversal. @@ -53,3 +92,40 @@ def get_frontier(mask: int, adj_masks: tuple[int, ...], used_mask: int) -> int: frontier &= ~used_mask return frontier + + +def sample_connected_superset( + base_mask: int, + target_length: int, + adj_masks: tuple[int, ...], + used_mask: int, +) -> int: + """Sample a connected superset via random walk expansion. + + IMPORTANT: This is NOT a uniform sampler over all connected supersets. + The distribution is biased towards segments discovered early and + depends on graph topology. This bias is acceptable for Monte Carlo + estimation in the parent algorithm. + + Args: + base_mask: Starting set (must be non-empty and connected) + target_length: Desired size of the superset + adj_masks: Adjacency bitmasks for neighbor lookups + used_mask: Global exclusion mask (segments that must not be added) + + Returns: + Bitmask of connected superset containing base_mask + """ + mask = base_mask + + while mask.bit_count() < target_length: + frontier = get_frontier(mask, adj_masks, used_mask) + if frontier == 0: + break + + # Select next segment + seg_id = pick_random_set_bit(frontier) + + mask = add_node(mask, seg_id) + + return mask diff --git a/ciao/algorithm/search_helpers.py b/ciao/algorithm/search_helpers.py new file mode 100644 index 0000000..8906dae --- /dev/null +++ b/ciao/algorithm/search_helpers.py @@ -0,0 +1,52 @@ +"""Shared utilities for MCTS and MCGS search algorithms. + +This module contains common functions used by both Monte Carlo Tree Search (MCTS) +and Monte Carlo Graph Search (MCGS) implementations. +""" + +import numpy as np +import torch + +from ciao.algorithm.bitmask_graph import get_frontier, iter_bits +from ciao.model.predictor import ModelPredictor +from ciao.scoring.hyperpixel import calculate_hyperpixel_deltas + + +def is_terminal( + mask: int, adj_masks: tuple[int, ...], used_mask: int, max_depth: int +) -> bool: + """Check if state is terminal (max depth or no frontier).""" + return ( + mask.bit_count() >= max_depth or get_frontier(mask, adj_masks, used_mask) == 0 + ) + + +def evaluate_masks( + predictor: ModelPredictor, + input_batch: torch.Tensor, + segments: np.ndarray, + target_class_idx: int, + masks: list[int], + replacement_image: torch.Tensor, +) -> list[float]: + """Evaluate multiple segment masks by computing class score deltas (batched).""" + # Guard against invalid masks (zero or negative) + if any(mask <= 0 for mask in masks): + raise ValueError( + "Cannot evaluate invalid mask: A mask must be a positive integer. " + "Zero masks contain no segments, and negative masks cause " + "incorrect bit iteration due to two's complement representation." + ) + + all_segment_ids = [list(iter_bits(mask)) for mask in masks] + + rewards = calculate_hyperpixel_deltas( + predictor=predictor, + input_batch=input_batch, + segments=segments, + replacement_image=replacement_image, + target_class_idx=target_class_idx, + hyperpixel_segment_ids_list=all_segment_ids, + ) + + return rewards From dd0ad921ccf8cfde87c9b1b739109f1b56ec3017 Mon Sep 17 00:00:00 2001 From: dhalmazna Date: Sat, 14 Mar 2026 16:23:11 +0100 Subject: [PATCH 2/2] feat: refactor bitmask utilities for improved clarity and error handling --- ciao/algorithm/bitmask_graph.py | 15 ++++++++++++--- ciao/algorithm/search_helpers.py | 4 ++-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/ciao/algorithm/bitmask_graph.py b/ciao/algorithm/bitmask_graph.py index a768f9d..6f5d89d 100644 --- a/ciao/algorithm/bitmask_graph.py +++ b/ciao/algorithm/bitmask_graph.py @@ -10,7 +10,7 @@ def mask_to_ids(mask: int) -> list[int]: """Convert integer bitmask to list of segment indices.""" - return [i for i in range(mask.bit_length()) if (mask >> i) & 1] + return list(iter_bits(mask)) def iter_bits(mask: int) -> Iterator[int]: @@ -47,13 +47,16 @@ def remove_node(mask: int, node: int) -> int: def pick_random_set_bit(mask: int) -> int: - """Select a random set bit from the mask in O(N) where N is the index of the bit. + """Select a random set bit from the mask in O(k) where k is the number of set bits. Without allocating a list. Efficient for sparse masks. + + Raises: + ValueError: If `mask` has no set bits (i.e. `mask == 0`). """ count = mask.bit_count() if count == 0: - return -1 + raise ValueError("Cannot pick a random bit from an empty mask (mask == 0).") which = random.randrange(count) @@ -115,7 +118,13 @@ def sample_connected_superset( Returns: Bitmask of connected superset containing base_mask + + Raises: + ValueError: If base_mask is empty. """ + if base_mask == 0: + raise ValueError("base_mask must be non-empty to sample a connected superset.") + mask = base_mask while mask.bit_count() < target_length: diff --git a/ciao/algorithm/search_helpers.py b/ciao/algorithm/search_helpers.py index 8906dae..d767866 100644 --- a/ciao/algorithm/search_helpers.py +++ b/ciao/algorithm/search_helpers.py @@ -7,7 +7,7 @@ import numpy as np import torch -from ciao.algorithm.bitmask_graph import get_frontier, iter_bits +from ciao.algorithm.bitmask_graph import get_frontier, mask_to_ids from ciao.model.predictor import ModelPredictor from ciao.scoring.hyperpixel import calculate_hyperpixel_deltas @@ -38,7 +38,7 @@ def evaluate_masks( "incorrect bit iteration due to two's complement representation." ) - all_segment_ids = [list(iter_bits(mask)) for mask in masks] + all_segment_ids = [mask_to_ids(mask) for mask in masks] rewards = calculate_hyperpixel_deltas( predictor=predictor,