Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion ciao/algorithm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,29 @@
"""CIAO algorithm implementations."""

from ciao.algorithm.bitmask_graph import get_frontier, iter_bits
from ciao.algorithm.bitmask_graph import (
add_node,
get_frontier,
has_node,
iter_bits,
mask_to_ids,
pick_random_set_bit,
remove_node,
sample_connected_superset,
)
from ciao.algorithm.search_helpers import evaluate_masks, is_terminal


__all__ = [
    # Public API of the package, kept in alphabetical order.
    # `evaluate_masks` and `is_terminal` are the shared MCTS/MCGS helpers
    # imported from `search_helpers`; every other name is a bitmask graph
    # utility imported from `bitmask_graph`.
    "add_node",
    "evaluate_masks",
    "get_frontier",
    "has_node",
    "is_terminal",
    "iter_bits",
    "mask_to_ids",
    "pick_random_set_bit",
    "remove_node",
    "sample_connected_superset",
]
87 changes: 86 additions & 1 deletion ciao/algorithm/bitmask_graph.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
"""Bitmask-based graph utilities for efficient segment manipulation.

This module will provide low-level primitives for working with graph structures
This module provides low-level primitives for working with graph structures
represented as integer bitmasks, where each bit represents a node/segment.
"""

import random
from collections.abc import Iterator


def mask_to_ids(mask: int) -> list[int]:
    """Return the segment indices encoded in ``mask`` as a list.

    The indices appear in whatever order ``iter_bits`` yields them.
    """
    return [*iter_bits(mask)]


def iter_bits(mask: int) -> Iterator[int]:
"""Iterate over set bits in a mask using low-bit isolation.

Expand All @@ -25,6 +31,42 @@ def iter_bits(mask: int) -> Iterator[int]:
temp ^= low_bit


def has_node(mask: int, node: int) -> bool:
    """Report whether the bit for ``node`` is set in ``mask``."""
    node_bit = 1 << node
    return (mask & node_bit) != 0


def add_node(mask: int, node: int) -> int:
    """Return a copy of ``mask`` with the bit for ``node`` set.

    The result equals ``mask`` when the node is already present.
    """
    node_bit = 1 << node
    return mask | node_bit


def remove_node(mask: int, node: int) -> int:
    """Return a copy of ``mask`` with the bit for ``node`` cleared.

    The result equals ``mask`` when the node is not present.
    """
    node_bit = 1 << node
    keep_others = ~node_bit
    return mask & keep_others


def pick_random_set_bit(mask: int) -> int:
"""Select a random set bit from the mask in O(k) where k is the number of set bits.

Without allocating a list. Efficient for sparse masks.

Raises:
ValueError: If `mask` has no set bits (i.e. `mask == 0`).
"""
count = mask.bit_count()
if count == 0:
raise ValueError("Cannot pick a random bit from an empty mask (mask == 0).")

which = random.randrange(count)

temp = mask
for _ in range(which):
temp &= temp - 1 # Clear lowest set bit

return (temp & -temp).bit_length() - 1


def get_frontier(mask: int, adj_masks: tuple[int, ...], used_mask: int) -> int:
"""Compute the expansion frontier (valid neighbors) for graph traversal.

Expand Down Expand Up @@ -53,3 +95,46 @@ def get_frontier(mask: int, adj_masks: tuple[int, ...], used_mask: int) -> int:
frontier &= ~used_mask

return frontier


def sample_connected_superset(
base_mask: int,
target_length: int,
adj_masks: tuple[int, ...],
used_mask: int,
) -> int:
"""Sample a connected superset via random walk expansion.

IMPORTANT: This is NOT a uniform sampler over all connected supersets.
The distribution is biased towards segments discovered early and
depends on graph topology. This bias is acceptable for Monte Carlo
estimation in the parent algorithm.

Args:
base_mask: Starting set (must be non-empty and connected)
target_length: Desired size of the superset
adj_masks: Adjacency bitmasks for neighbor lookups
used_mask: Global exclusion mask (segments that must not be added)

Returns:
Bitmask of connected superset containing base_mask

Raises:
ValueError: If base_mask is empty.
"""
if base_mask == 0:
raise ValueError("base_mask must be non-empty to sample a connected superset.")

mask = base_mask

while mask.bit_count() < target_length:
frontier = get_frontier(mask, adj_masks, used_mask)
if frontier == 0:
break

# Select next segment
seg_id = pick_random_set_bit(frontier)

mask = add_node(mask, seg_id)

return mask
52 changes: 52 additions & 0 deletions ciao/algorithm/search_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Shared utilities for MCTS and MCGS search algorithms.

This module contains common functions used by both Monte Carlo Tree Search (MCTS)
and Monte Carlo Graph Search (MCGS) implementations.
"""

import numpy as np
import torch

from ciao.algorithm.bitmask_graph import get_frontier, mask_to_ids
from ciao.model.predictor import ModelPredictor
from ciao.scoring.hyperpixel import calculate_hyperpixel_deltas


def is_terminal(
mask: int, adj_masks: tuple[int, ...], used_mask: int, max_depth: int
) -> bool:
"""Check if state is terminal (max depth or no frontier)."""
return (
mask.bit_count() >= max_depth or get_frontier(mask, adj_masks, used_mask) == 0
)


def evaluate_masks(
    predictor: ModelPredictor,
    input_batch: torch.Tensor,
    segments: np.ndarray,
    target_class_idx: int,
    masks: list[int],
    replacement_image: torch.Tensor,
) -> list[float]:
    """Score a batch of segment masks by their class score deltas.

    Every mask is validated, converted to its list of segment ids, and the
    whole batch is passed to ``calculate_hyperpixel_deltas`` in one call.

    Args:
        predictor: Forwarded unchanged to ``calculate_hyperpixel_deltas``.
        input_batch: Forwarded unchanged to ``calculate_hyperpixel_deltas``.
        segments: Forwarded unchanged to ``calculate_hyperpixel_deltas``.
        target_class_idx: Forwarded unchanged to
            ``calculate_hyperpixel_deltas``.
        masks: Bitmasks to evaluate; each must be a positive integer.
        replacement_image: Forwarded unchanged to
            ``calculate_hyperpixel_deltas``.

    Returns:
        The result of ``calculate_hyperpixel_deltas`` (presumably one delta
        per mask, in input order -- confirm against that function).

    Raises:
        ValueError: If any mask is zero or negative.
    """
    # Reject invalid masks up front, before any model work is done.
    for candidate in masks:
        if candidate <= 0:
            raise ValueError(
                "Cannot evaluate invalid mask: A mask must be a positive integer. "
                "Zero masks contain no segments, and negative masks cause "
                "incorrect bit iteration due to two's complement representation."
            )

    segment_ids_per_mask = list(map(mask_to_ids, masks))

    return calculate_hyperpixel_deltas(
        predictor=predictor,
        input_batch=input_batch,
        segments=segments,
        replacement_image=replacement_image,
        target_class_idx=target_class_idx,
        hyperpixel_segment_ids_list=segment_ids_per_mask,
    )