Enh: add IncrementalWatershedCut for interactive seeded watershed #299

Open
lapertor wants to merge 2 commits into higra:master from lapertor:enh/seeded-watershed-incremental

Contributor

@lapertor lapertor commented Apr 1, 2026

Closes #298

Motivation

labelisation_seeded_watershed recomputes the full watershed from scratch on every seed change, which is expensive in interactive segmentation scenarios. This PR introduces IncrementalWatershedCut, a stateful class that maintains a cached Binary Partition Tree (BPT) and updates only the affected regions when seeds are added or removed.

The implementation follows the algorithm described in:

Q. Lebon, J. Lefevre, J. Cousty, B. Perret.
Interactive Segmentation With Incremental Watershed Cuts, CIARP 2023.
https://hal.science/hal-04069187v1

Algorithm

The key data structures are:

  • A canonical BPT built once at construction time via bpt_canonical
  • A visitCount array on BPT nodes, tracking how many seed paths pass through each internal node
  • A is_cut boolean array on MST edges: an MST edge is a watershed edge when its corresponding BPT node has visitCount >= 2
  • A BFS on the MST forest (skipping cut edges) to assign labels to connected components
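The labeling step in the last bullet can be sketched in plain Python. This is only an illustration, not the C++ code from this PR: the function name `label_mst_forest`, the edge-list representation, and the convention that label 0 means "unlabeled" are assumptions made for the example.

```python
from collections import deque


def label_mst_forest(num_vertices, mst_edges, is_cut, seeds):
    """Give every vertex the label of the seed found in its component.

    mst_edges: list of (u, v) pairs; is_cut[i] is True when edge i is a
    watershed edge; seeds: dict mapping a seed vertex to a non-zero label.
    Vertices in a component without any seed keep the label 0.
    """
    # Adjacency lists of the forest obtained by dropping the cut edges.
    adjacency = [[] for _ in range(num_vertices)]
    for (u, v), cut in zip(mst_edges, is_cut):
        if not cut:
            adjacency[u].append(v)
            adjacency[v].append(u)

    labels = [0] * num_vertices
    for seed, label in seeds.items():
        if labels[seed] != 0:
            continue  # component already labeled by an earlier seed
        labels[seed] = label
        queue = deque([seed])
        while queue:
            u = queue.popleft()
            for v in adjacency[u]:
                if labels[v] == 0:
                    labels[v] = label
                    queue.append(v)
    return labels


# Path 0-1-2-3 whose middle edge is a watershed edge.
labels = label_mst_forest(4, [(0, 1), (1, 2), (2, 3)], [False, True, False], {0: 1, 3: 2})
print(labels)  # [1, 1, 2, 2]
```

Because the search never crosses a cut edge, each vertex is visited at most once per labeling pass.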

When a seed is added, the algorithm walks up the BPT from the seed leaf, incrementing visitCount at each node; when a node's count reaches 2, the corresponding MST edge is marked as a cut. When a seed is removed, the walk decrements visitCount; when a node's count drops from 2 to 1, the corresponding edge is unmarked. A subsequent BFS re-labels the graph. The BFS handles both splits (seed added) and merges (seed removed), whereas a union-find structure only supports merges, which is why union-find is not used here.
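The walk-up can be illustrated on a hand-built tree. The sketch below is a toy model, not the class added by this PR: `VisitCountSketch`, the parent-array layout (leaves first, root pointing to itself) and the choice to stop the upward walk as soon as a count reaches 2 are assumptions, the last one mirroring the 2->1 break described for removal in the commit message. It also assumes at least two leaves and that a given leaf is never seeded twice.

```python
class VisitCountSketch:
    """Toy visitCount bookkeeping on a binary partition tree."""

    def __init__(self, parent, num_leaves):
        self.parent = parent          # parent[i] is the parent of node i
        self.num_leaves = num_leaves  # nodes 0..num_leaves-1 are leaves
        self.visit_count = [0] * len(parent)
        # One flag per internal node, i.e. per MST edge.
        self.is_cut = [False] * (len(parent) - num_leaves)

    def add_seed(self, leaf):
        node = self.parent[leaf]
        while True:
            self.visit_count[node] += 1
            if self.visit_count[node] == 2:
                # Both children now contain a seed: watershed edge.
                self.is_cut[node - self.num_leaves] = True
                break
            if self.parent[node] == node:  # reached the root
                break
            node = self.parent[node]

    def remove_seed(self, leaf):
        node = self.parent[leaf]
        while True:
            self.visit_count[node] -= 1
            if self.visit_count[node] == 1:
                # 2 -> 1 transition: the edge is no longer a cut.
                self.is_cut[node - self.num_leaves] = False
                break
            if self.parent[node] == node:  # reached the root
                break
            node = self.parent[node]


# Leaves 0..3; node 4 merges (0, 1), node 5 merges (2, 3), node 6 is the root.
sketch = VisitCountSketch([4, 4, 5, 5, 6, 6, 6], num_leaves=4)
sketch.add_seed(0)
sketch.add_seed(3)
print(sketch.is_cut)  # [False, False, True]
sketch.add_seed(1)
print(sketch.is_cut)  # [True, False, True]
sketch.remove_seed(0)
print(sketch.is_cut)  # [False, False, True]
```

In this model a node ends up with a count of 2 exactly when both of its child subtrees contain a seed, so k distinct seeds produce k - 1 cut edges.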

Changes

  • include/higra/algo/watershed.hpp: new incremental_watershed_cut class
  • higra/algo/py_watershed.cpp: pybind11 binding for IncrementalWatershedCut
  • higra/algo/watershed.py: Python wrapper class IncrementalWatershedCut
  • test/cpp/algo/test_watershed.cpp: 5 new C++ tests
  • test/python/test_algo/test_watershed.py: 10 new Python tests in class TestIncrementalWatershed

Usage

import higra as hg
import numpy as np

graph = hg.get_4_adjacency_graph((512, 512))
edge_weights = np.random.rand(graph.num_edges())

iws = hg.IncrementalWatershedCut(graph, edge_weights)  # build BPT once

# Add seeds interactively
iws.add_seeds(np.array([0, 100]), np.array([1, 2]))
labels = iws.get_labeling()  # fast update

# Modify seeds without full recomputation
iws.add_seeds(np.array([2000]), np.array([1]))
labels = iws.get_labeling()

iws.remove_seeds(np.array([100]))
labels = iws.get_labeling()

Performance

Benchmarks on a 500x500 image (20 interactions, 5 seeds each):

| Method | Mean per interaction | Total (20 interactions) |
| --- | --- | --- |
| labelisation_seeded_watershed | ~78 ms | ~1575 ms |
| IncrementalWatershedCut (update only) | ~5 ms | ~234 ms |
| Speedup | ~15x | ~6.7x (init amortized) |

On 1000x1000: ~11x speedup per interaction. The speedup is stable across
interactions regardless of the number of accumulated seeds.
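As a quick sanity check, the 500x500 speedups can be re-derived from the raw figures in the benchmark output of this PR. The break-even estimate at the end is an addition for illustration; it assumes the per-interaction means stay constant over the session.

```python
import math

# Figures copied from the 500x500 benchmark output (milliseconds).
init_ms = 132.65
baseline_mean_ms = 78.78
incremental_mean_ms = 5.07
baseline_total_ms = 1575.63
incremental_total_ms = 233.96  # init + all 20 updates

per_interaction_speedup = baseline_mean_ms / incremental_mean_ms
amortized_speedup = baseline_total_ms / incremental_total_ms

# Number of interactions after which the one-time init cost is paid back,
# assuming constant per-interaction times.
break_even = math.ceil(init_ms / (baseline_mean_ms - incremental_mean_ms))

print(f"{per_interaction_speedup:.1f}x per interaction")
print(f"{amortized_speedup:.1f}x including init")
print(f"init cost recovered after {break_even} interactions")
```

With these numbers the init cost is recovered after 2 interactions, which is why the amortized speedup over 20 interactions stays well above 1.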

Consistency

test_consistency_with_seeded_watershed verifies that IncrementalWatershedCut produces the same labeling as labelisation_seeded_watershed given identical seeds, so results from the new class agree with the existing function.

Benchmark script
"""
Benchmark: IncrementalWatershedCut vs labelisation_seeded_watershed
in an interactive segmentation scenario.

Simulates an interactive session where seeds are added incrementally.
At each interaction, the incremental version only updates its internal state,
while the baseline recomputes the full watershed from scratch.
"""

import higra as hg
import numpy as np
import time


def benchmark_interactive(shape=(500, 500), n_interactions=20, seeds_per_interaction=5,
                          n_runs=3):
    """
    Simulate an interactive session:
    - Start with 2 initial seeds (foreground + background)
    - At each interaction, add a batch of new seeds
    - Compare the time per interaction for both methods
    """
    h, w = shape
    n_vertices = h * w
    graph = hg.get_4_adjacency_graph(shape)

    rng = np.random.RandomState(42)
    edge_weights = rng.rand(graph.num_edges()).astype(np.float64)

    # Pre-generate all seed positions and labels
    all_seed_vertices = rng.choice(n_vertices, size=2 + n_interactions * seeds_per_interaction,
                                   replace=False)
    all_seed_labels = np.ones(len(all_seed_vertices), dtype=np.int64)
    # First seed is foreground (1), second is background (2)
    all_seed_labels[0] = 1
    all_seed_labels[1] = 2
    # Remaining seeds alternate between foreground and background
    for i in range(2, len(all_seed_labels)):
        all_seed_labels[i] = 1 + (i % 2)

    print(f"Shape: {shape}")
    print(f"Num vertices: {n_vertices}")
    print(f"Num edges: {graph.num_edges()}")
    print(f"Interactions: {n_interactions}")
    print(f"Seeds per interaction: {seeds_per_interaction}")
    print(f"Runs per measurement: {n_runs}")
    print()

    # ---- Baseline: labelisation_seeded_watershed from scratch each time ----
    times_baseline = []
    for interaction in range(n_interactions):
        n_seeds = 2 + (interaction + 1) * seeds_per_interaction
        seed_v = all_seed_vertices[:n_seeds]
        seed_l = all_seed_labels[:n_seeds]

        # Build the full seeds array
        vertex_seeds = np.zeros(n_vertices, dtype=np.int64)
        vertex_seeds[seed_v] = seed_l

        elapsed = 0.0
        for _ in range(n_runs):
            t0 = time.perf_counter()
            labels_baseline = hg.labelisation_seeded_watershed(graph, edge_weights, vertex_seeds)
            t1 = time.perf_counter()
            elapsed += (t1 - t0)
        times_baseline.append(elapsed / n_runs)

    # ---- Incremental: IncrementalWatershedCut ----
    # Per-interaction update times (add_seeds + get_labeling). Both lists hold
    # the same values; the one-time init cost is added separately in the report.
    times_incremental_total = []
    times_incremental_update = []

    t0 = time.perf_counter()
    iws = hg.IncrementalWatershedCut(graph, edge_weights)
    t_init = time.perf_counter() - t0

    # Add initial 2 seeds
    iws.add_seeds(all_seed_vertices[:2], all_seed_labels[:2])
    _ = iws.get_labeling()

    for interaction in range(n_interactions):
        start = 2 + interaction * seeds_per_interaction
        end = start + seeds_per_interaction
        new_v = all_seed_vertices[start:end]
        new_l = all_seed_labels[start:end]

        elapsed = 0.0
        for run in range(n_runs):
            if run == 0:
                # First run: actually modify state
                t0 = time.perf_counter()
                iws.add_seeds(new_v, new_l)
                labels_incr = iws.get_labeling()
                t1 = time.perf_counter()
                elapsed += (t1 - t0)
            else:
                # Subsequent runs: remove then re-add to measure consistently
                iws.remove_seeds(new_v)
                t0 = time.perf_counter()
                iws.add_seeds(new_v, new_l)
                labels_incr = iws.get_labeling()
                t1 = time.perf_counter()
                elapsed += (t1 - t0)

        avg_time = elapsed / n_runs
        times_incremental_update.append(avg_time)
        times_incremental_total.append(avg_time)

    # ---- Report ----
    print("=" * 70)
    print(f"{'Interaction':>12} | {'Baseline (ms)':>14} | {'Incremental (ms)':>17} | {'Speedup':>8}")
    print("-" * 70)
    for i in range(n_interactions):
        t_base = times_baseline[i] * 1000
        t_incr = times_incremental_update[i] * 1000
        speedup = t_base / t_incr if t_incr > 0 else float('inf')
        print(f"{i+1:>12} | {t_base:>14.2f} | {t_incr:>17.2f} | {speedup:>7.1f}x")

    print("-" * 70)
    mean_base = np.mean(times_baseline) * 1000
    mean_incr = np.mean(times_incremental_update) * 1000
    print(f"{'Mean':>12} | {mean_base:>14.2f} | {mean_incr:>17.2f} | {mean_base/mean_incr:>7.1f}x")
    print()
    print(f"Incremental init time: {t_init*1000:.2f} ms (one-time cost)")
    print(f"Baseline total time:   {sum(times_baseline)*1000:.2f} ms")
    print(f"Incremental total time (init + all updates): "
          f"{t_init*1000 + sum(times_incremental_total)*1000:.2f} ms")


def benchmark_scaling():
    """
    Show how the incremental method scales with image size.
    """
    print("\n" + "=" * 70)
    print("SCALING BENCHMARK: per-interaction time vs image size")
    print("=" * 70)

    sizes = [(100, 100), (500, 500), (1000, 1000), (2000, 2000)]
    n_runs = 3

    for shape in sizes:
        h, w = shape
        n_vertices = h * w
        graph = hg.get_4_adjacency_graph(shape)

        rng = np.random.RandomState(42)
        edge_weights = rng.rand(graph.num_edges()).astype(np.float64)

        seed_vertices = rng.choice(n_vertices, size=12, replace=False)
        seed_labels = np.array([1, 2] * 6, dtype=np.int64)

        # Baseline
        vertex_seeds = np.zeros(n_vertices, dtype=np.int64)
        vertex_seeds[seed_vertices] = seed_labels
        elapsed_base = 0.0
        for _ in range(n_runs):
            t0 = time.perf_counter()
            hg.labelisation_seeded_watershed(graph, edge_weights, vertex_seeds)
            elapsed_base += time.perf_counter() - t0
        t_base = (elapsed_base / n_runs) * 1000

        # Incremental: init + add all seeds + get_labeling
        t0 = time.perf_counter()
        iws = hg.IncrementalWatershedCut(graph, edge_weights)
        t_init = (time.perf_counter() - t0) * 1000

        # Add 10 seeds, then measure adding the last 2
        iws.add_seeds(seed_vertices[:10], seed_labels[:10])
        _ = iws.get_labeling()

        elapsed_incr = 0.0
        for run in range(n_runs):
            if run == 0:
                t0 = time.perf_counter()
                iws.add_seeds(seed_vertices[10:], seed_labels[10:])
                _ = iws.get_labeling()
                elapsed_incr += time.perf_counter() - t0
            else:
                iws.remove_seeds(seed_vertices[10:])
                t0 = time.perf_counter()
                iws.add_seeds(seed_vertices[10:], seed_labels[10:])
                _ = iws.get_labeling()
                elapsed_incr += time.perf_counter() - t0
        t_incr = (elapsed_incr / n_runs) * 1000

        speedup = t_base / t_incr if t_incr > 0 else float('inf')
        print(f"{shape[0]:>5}x{shape[1]:<5} | "
              f"Init: {t_init:>8.2f} ms | "
              f"Baseline: {t_base:>8.2f} ms | "
              f"Incremental update: {t_incr:>8.2f} ms | "
              f"Speedup: {speedup:.1f}x")


if __name__ == "__main__":
    benchmark_interactive(shape=(500, 500), n_interactions=20, seeds_per_interaction=5)
    benchmark_interactive(shape=(1000, 1000), n_interactions=20, seeds_per_interaction=5)
    benchmark_scaling()
Benchmark output (Windows, Intel Core i7-14700KF, Python 3.13.12)
Shape: (500, 500)
Num vertices: 250000
Num edges: 499000
Interactions: 20
Seeds per interaction: 5
Runs per measurement: 3

======================================================================
 Interaction |  Baseline (ms) |  Incremental (ms) |  Speedup
----------------------------------------------------------------------
           1 |          83.57 |              5.53 |    15.1x
           2 |          80.64 |              5.20 |    15.5x
           3 |          77.09 |              5.16 |    14.9x
           4 |          86.14 |              5.24 |    16.4x
           5 |          77.70 |              5.14 |    15.1x
           6 |          78.92 |              4.99 |    15.8x
           7 |          77.85 |              5.11 |    15.2x
           8 |          79.57 |              4.87 |    16.3x
           9 |          77.41 |              4.79 |    16.2x
          10 |          78.74 |              5.36 |    14.7x
          11 |          78.33 |              5.07 |    15.4x
          12 |          77.94 |              4.87 |    16.0x
          13 |          80.04 |              4.80 |    16.7x
          14 |          77.10 |              5.04 |    15.3x
          15 |          78.47 |              4.79 |    16.4x
          16 |          76.55 |              4.87 |    15.7x
          17 |          77.44 |              5.20 |    14.9x
          18 |          76.18 |              5.25 |    14.5x
          19 |          77.91 |              4.97 |    15.7x
          20 |          78.03 |              5.05 |    15.4x
----------------------------------------------------------------------
        Mean |          78.78 |              5.07 |    15.6x

Incremental init time: 132.65 ms (one-time cost)
Baseline total time:   1575.63 ms
Incremental total time (init + all updates): 233.96 ms
Shape: (1000, 1000)
Num vertices: 1000000
Num edges: 1998000
Interactions: 20
Seeds per interaction: 5
Runs per measurement: 3

======================================================================
 Interaction |  Baseline (ms) |  Incremental (ms) |  Speedup
----------------------------------------------------------------------
           1 |         387.37 |             35.89 |    10.8x
           2 |         399.85 |             36.06 |    11.1x
           3 |         398.65 |             35.02 |    11.4x
           4 |         396.37 |             36.10 |    11.0x
           5 |         389.45 |             37.91 |    10.3x
           6 |         388.10 |             36.56 |    10.6x
           7 |         383.60 |             35.09 |    10.9x
           8 |         385.06 |             35.98 |    10.7x
           9 |         385.41 |             35.53 |    10.8x
          10 |         409.90 |             37.44 |    10.9x
          11 |         388.49 |             34.13 |    11.4x
          12 |         394.39 |             35.33 |    11.2x
          13 |         403.05 |             34.53 |    11.7x
          14 |         389.26 |             34.39 |    11.3x
          15 |         390.24 |             34.67 |    11.3x
          16 |         390.97 |             33.66 |    11.6x
          17 |         390.84 |             36.26 |    10.8x
          18 |         387.86 |             34.14 |    11.4x
          19 |         386.30 |             35.78 |    10.8x
          20 |         390.45 |             35.16 |    11.1x
----------------------------------------------------------------------
        Mean |         391.78 |             35.48 |    11.0x

Incremental init time: 734.49 ms (one-time cost)
Baseline total time:   7835.60 ms
Incremental total time (init + all updates): 1444.11 ms

======================================================================
SCALING BENCHMARK: per-interaction time vs image size
======================================================================
  100x100   | Init:     3.43 ms | Baseline:     2.51 ms | Incremental update:     0.18 ms | Speedup: 14.3x
  500x500   | Init:   141.11 ms | Baseline:    77.33 ms | Incremental update:     5.49 ms | Speedup: 14.1x
 1000x1000  | Init:   725.87 ms | Baseline:   399.25 ms | Incremental update:    35.18 ms | Speedup: 11.3x
 2000x2000  | Init:  3898.48 ms | Baseline:  2001.76 ms | Incremental update:   179.41 ms | Speedup: 11.2x

Comment thread higra/algo/py_watershed.cpp
Comment thread include/higra/algo/watershed.hpp Outdated
Comment thread include/higra/algo/watershed.hpp Outdated
Comment thread include/higra/algo/watershed.hpp Outdated
@lapertor lapertor force-pushed the enh/seeded-watershed-incremental branch from 956d3c3 to 03054ba Compare April 1, 2026 17:21
}

// Local relabeling: BFS from v in its component
relabel_component_from_seed(v, l);
Member

Calling relabel immediately after finding the new cut edge is probably suboptimal, as it can lead to relabelling the same vertices several times in the same batch.

Comment thread include/higra/algo/watershed.hpp Outdated
}

// Local relabeling: find merged component and relabel from remaining seeds
relabel_merged_component(v);
Member

Same here

Comment thread include/higra/algo/watershed.hpp Outdated
* (BFS from v respecting current cuts), reset labels to 0, then
* relabel from all remaining seeds in the component.
*/
void relabel_merged_component(index_t v) {
Member

This is sub-optimal. When a seed is removed, we also know which MST edge is reactivated, and that edge can be used to determine the label of the "other side" (except if the other side is also removed in the same batch, in which case relabelling will start somewhere else).
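One possible reading of this suggestion, as a plain-Python sketch rather than the actual C++ change: the helper name `relabel_after_uncut` and the adjacency-list representation are hypothetical, and the batch exception mentioned above (the other side losing its seed too) is not handled.

```python
from collections import deque


def relabel_after_uncut(adjacency, labels, reactivated_edge, removed_seed):
    """Merge the removed seed's component into the other side of the edge.

    adjacency: forest adjacency lists *before* the edge is reactivated.
    labels and adjacency are updated in place.
    """
    # Collect the component that has just lost its seed.
    component = {removed_seed}
    queue = deque([removed_seed])
    while queue:
        u = queue.popleft()
        for v in adjacency[u]:
            if v not in component:
                component.add(v)
                queue.append(v)

    # The endpoint outside that component carries the surviving label.
    u, v = reactivated_edge
    other = v if u in component else u
    for w in component:
        labels[w] = labels[other]

    # Reactivate the edge in the forest.
    adjacency[u].append(v)
    adjacency[v].append(u)
    return labels


# Path 0-1-2-3 cut at edge (1, 2); the seed at vertex 3 is removed.
adjacency = [[1], [0], [3], [2]]
labels = [1, 1, 2, 2]
relabel_after_uncut(adjacency, labels, (1, 2), removed_seed=3)
print(labels)  # [1, 1, 1, 1]
```

Only the component that lost its seed is traversed; the other side keeps its labels untouched.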

* The algorithm maintains a canonical BPT and a visitCount array to identify
* watershed edges without recomputing from scratch at each interaction.
* The labeling is cached and updated locally when seeds change.
*
Member

Please add a short usage example.

Also add a reference to the class in the global doc.

lapertor added 2 commits May 12, 2026 10:51
Add `incremental_watershed_cut` C++ class and `IncrementalWatershedCut`
Python class that maintain a cached BPT to update seeded watershed cuts
incrementally, avoiding a full recomputation on each seed change.

Based on the algorithm described in:
  Q. Lebon, J. Lefevre, J. Cousty, B. Perret.
  'Interactive Segmentation With Incremental Watershed Cuts', CIARP 2023.
  https://hal.science/hal-04069187v1

Signed-off-by: lapertor <raphael.lapertot@gmail.com>
…ions, optimize BFS

- Fix visitCount batch removal bug: walk-up loop now always decrements
  visitCount and breaks on 2->1 transition (matching Lebon's removeMarker),
  instead of stopping on visitCount==1 without decrementing.
- Add hg_assert in unreachable else-branch of Pass 2a (both sides of
  de-cut with different seed labels should not happen).
- Remove unused #include <unordered_set>.
- Optimize component_seed_label: replace per-call std::vector allocation
  with m_visited generation-counter pattern (zero-cost reset).
- Add trailing newline to watershed.rst.
- Clean test_watershed.py diff (73 insertions, no whitespace changes).
- Add 3 regression tests: batch remove equals sequential, both sides of
  edge, and interactive churn.
- Add C++ test for batch remove sibling subtrees visitCount.

Signed-off-by: lapertor <raphael.lapertot@gmail.com>
@lapertor lapertor force-pushed the enh/seeded-watershed-incremental branch from 03054ba to b1430a7 Compare May 13, 2026 11:23

Development

Successfully merging this pull request may close these issues.

Feature proposal: incremental seeded watershed cut based on BPT

2 participants