Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 81 additions & 1 deletion software/control/core/multi_point_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from datetime import datetime
from enum import Enum
from threading import Thread
from typing import Optional, Tuple, Any
from typing import List, Optional, Tuple, Any

import numpy as np
import pandas as pd
Expand All @@ -18,6 +18,7 @@
import control._def
from control.core.auto_focus_controller import AutoFocusController
from control.core.multi_point_utils import MultiPointControllerFunctions, ScanPositionInformation, AcquisitionParameters
from control.core.state_machine import TimepointStateMachine, TimepointState, FOVIdentifier
from control.core.scan_coordinates import ScanCoordinates
from control.core.laser_auto_focus_controller import LaserAutofocusController
from control.core.live_controller import LiveController
Expand Down Expand Up @@ -209,6 +210,7 @@ def __init__(
self.objectiveStore: ObjectiveStore = objective_store
self.callbacks: MultiPointControllerFunctions = callbacks
self.multiPointWorker: Optional[MultiPointWorker] = None
self._state_machine: Optional[TimepointStateMachine] = None
self.fluidics: Optional[Any] = microscope.addons.fluidics
self.thread: Optional[Thread] = None
self._per_acq_log_handler = None
Expand Down Expand Up @@ -808,6 +810,14 @@ def finish_fn():
self.overlap_percent,
)

# Calculate total FOVs for this timepoint (for state machine)
total_fovs = sum(len(coords) for coords in scan_position_information.scan_region_fov_coords_mm.values())

# Create state machine for pause/resume/retake functionality
self._state_machine = TimepointStateMachine()
self._state_machine.reset(total_fovs)
self._state_machine.on_state_changed = self._on_state_changed

self.multiPointWorker = MultiPointWorker(
scope=self.microscope,
live_controller=self.liveController,
Expand All @@ -821,6 +831,7 @@ def finish_fn():
extra_job_classes=[],
alignment_widget=self._alignment_widget,
slack_notifier=self._slack_notifier,
state_machine=self._state_machine,
)

# Signal after worker creation so backpressure_controller is available
Expand Down Expand Up @@ -946,6 +957,75 @@ def validate_acquisition_settings(self) -> bool:
return False
return True

# --- State Machine Control Methods ---

def request_pause(self) -> bool:
"""Request pause of current acquisition.

The acquisition will pause after completing the current FOV.

Returns:
True if pause request was accepted, False if not in a pausable state.
"""
if self._state_machine:
return self._state_machine.request_pause()
return False

def request_resume(self) -> bool:
"""Resume paused acquisition.

Returns:
True if successfully initiated resume, False if not paused.
"""
if self._state_machine:
return self._state_machine.resume()
return False

def request_retake(self, fovs: List[FOVIdentifier]) -> bool:
"""Request retake of specified FOVs.

Can only be called when acquisition is paused.

Args:
fovs: List of FOV identifiers to retake.

Returns:
True if retake started successfully, False otherwise.
"""
if self._state_machine:
return self._state_machine.retake(fovs)
return False

def abort_retake(self) -> bool:
"""Abort current retake operation only.

Does not abort the entire acquisition - returns to paused state.

Returns:
True if abort was handled, False otherwise.
"""
if self._state_machine:
accepted, abort_all = self._state_machine.abort()
if accepted and not abort_all and self.multiPointWorker:
self.multiPointWorker.request_retake_abort()
return accepted
return False

def get_acquisition_state(self) -> Optional[TimepointState]:
"""Get current acquisition state.

Returns:
Current TimepointState, or None if no state machine exists.
"""
if self._state_machine:
return self._state_machine.state
return None

def _on_state_changed(self, new_state: TimepointState) -> None:
"""Handle state machine state changes (called from state machine thread)."""
self._log.info(f"Acquisition state changed to: {new_state.name}")
self.callbacks.signal_state_changed(new_state)

def get_plate_view(self) -> np.ndarray:
"""Get the current plate view array from the acquisition.

Expand Down
8 changes: 8 additions & 0 deletions software/control/core/multi_point_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

if TYPE_CHECKING:
from control.slack_notifier import TimepointStats, AcquisitionStats
from control.core.state_machine import TimepointState, FOVIdentifier


@dataclass
Expand Down Expand Up @@ -125,3 +126,10 @@ class MultiPointControllerFunctions:
# Zarr frame written callback - called when subprocess completes writing a frame
# Args: (fov, time_point, z_index, channel_name, region_idx)
signal_zarr_frame_written: Callable[[int, int, int, str, int], None] = lambda *a, **kw: None

# State machine signals for pause/resume/retake functionality
signal_state_changed: Callable[["TimepointState"], None] = lambda *a, **kw: None
signal_paused: Callable[[], None] = lambda *a, **kw: None
signal_resumed: Callable[[], None] = lambda *a, **kw: None
signal_retake_started: Callable[[List["FOVIdentifier"]], None] = lambda *a, **kw: None
signal_retakes_complete: Callable[[], None] = lambda *a, **kw: None
103 changes: 103 additions & 0 deletions software/control/core/multi_point_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
PlateViewInit,
PlateViewUpdate,
)
from control.core.state_machine import TimepointStateMachine, TimepointState, FOVIdentifier
from control.core.objective_store import ObjectiveStore
from control.microcontroller import Microcontroller
from control.microscope import Microscope
Expand Down Expand Up @@ -81,12 +82,19 @@ def __init__(
abort_on_failed_jobs: bool = True,
alignment_widget=None,
slack_notifier=None,
state_machine: Optional[TimepointStateMachine] = None,
):
self._log = squid.logging.get_logger(__class__.__name__)
self._timing = utils.TimingManager("MultiPointWorker Timer Manager")
self._alignment_widget = alignment_widget # Optional AlignmentWidget for coordinate offset
self._slack_notifier = slack_notifier # Optional SlackNotifier for notifications

# State machine for pause/resume/retake functionality
self._state_machine = state_machine
self._fov_coords_map: Dict[Tuple[str, int], Tuple[float, float, float]] = {}
self._retake_abort_requested = False
self._current_path: Optional[str] = None # Store current timepoint path for retakes

# Slack notification tracking counters
self._timepoint_image_count = 0
self._timepoint_fov_count = 0
Expand Down Expand Up @@ -592,6 +600,14 @@ def run_single_time_point(self):
with self._timing.get_timer("run_coordinate_acquisition"):
self.run_coordinate_acquisition(current_path)

# Mark timepoint as captured in state machine
if self._state_machine:
self._state_machine.mark_all_captured()

# Check for pause after timepoint complete (allows review before next timepoint)
if self._state_machine and self._state_machine.is_pause_requested():
self._handle_pause()

# Save plate view for this timepoint
if self._generate_downsampled_views and self._downsampled_view_manager is not None:
# Wait for pending downsampled view jobs to complete
Expand Down Expand Up @@ -1127,6 +1143,9 @@ def save_plate_view(self, path: str) -> None:
self._downsampled_view_manager.save_plate_view(path)

def run_coordinate_acquisition(self, current_path):
# Store current path for potential retakes
self._current_path = current_path

# Reset backpressure counters at acquisition start
# IMPORTANT: Must be before any camera triggers
self._backpressure.reset()
Expand All @@ -1146,6 +1165,17 @@ def run_coordinate_acquisition(self, current_path):
self.total_scans = self.num_fovs * self.NZ * len(self.selected_configurations)

for fov, coordinate_mm in enumerate(coordinates):
# Store coordinates for potential retake
self._fov_coords_map[(region_id, fov)] = coordinate_mm

# STATE MACHINE: Check for pause request before each FOV
if self._state_machine and self._state_machine.is_pause_requested():
self._handle_pause()
# After resume, check state
state = self._state_machine.state
if state == TimepointState.CAPTURED:
return # Done with timepoint (user chose to skip remaining)

# Just so the job result queues don't get too big, check and print a summary of intermediate results here
with self._timing.get_timer("job result summaries"):
result = self._summarize_runner_outputs()
Expand All @@ -1159,6 +1189,10 @@ def run_coordinate_acquisition(self, current_path):
with self._timing.get_timer("acquire_at_position"):
self.acquire_at_position(region_id, current_path, fov)

# Mark FOV as captured in state machine
if self._state_machine:
self._state_machine.mark_fov_captured()

if self.abort_requested_fn():
self.handle_acquisition_abort(current_path)
return
Expand Down Expand Up @@ -1613,6 +1647,75 @@ def handle_acquisition_abort(self, current_path):

self._wait_for_outstanding_callback_images()

def _handle_pause(self) -> None:
"""Handle pause request - wait for resume or retake."""
if not self._state_machine:
return

# Ensure all in-flight images are processed before pausing
self._wait_for_outstanding_callback_images()

# Transition to PAUSED state
self._state_machine.complete_pause()
self.callbacks.signal_paused()

# Wait for resume (loop handles retake cycles)
while True:
self._state_machine.wait_for_resume()

state = self._state_machine.state
if state == TimepointState.ACQUIRING:
self.callbacks.signal_resumed()
break
elif state == TimepointState.CAPTURED:
self.callbacks.signal_resumed()
break
elif state == TimepointState.RETAKING:
self._run_retakes()
# After retakes complete, we're back in PAUSED - loop again

def _run_retakes(self) -> None:
"""Execute retakes for FOVs in retake list."""
if not self._state_machine:
return

retake_list = self._state_machine.get_retake_list()
self.callbacks.signal_retake_started(retake_list)
self._retake_abort_requested = False

self._log.info(f"Starting retake of {len(retake_list)} FOVs")

for fov_id in retake_list:
# Check for retake abort
if self._retake_abort_requested:
self._log.info("Retake aborted by user request")
break

# Check for full acquisition abort
if self.abort_requested_fn():
self._log.info("Retake aborted due to acquisition abort")
break

# Get stored coordinates
coord = self._fov_coords_map.get((fov_id.region_id, fov_id.fov_index))
if coord is None:
self._log.warning(f"No coordinates for FOV {fov_id}, skipping")
continue

self._log.info(f"Retaking FOV: region={fov_id.region_id}, fov={fov_id.fov_index}")

# Move and acquire
self.move_to_coordinate(coord, fov_id.region_id, fov_id.fov_index)
if self._current_path:
self.acquire_at_position(fov_id.region_id, self._current_path, fov_id.fov_index)

self._state_machine.complete_retakes()
self.callbacks.signal_retakes_complete()

def request_retake_abort(self) -> None:
"""Request abort of current retake operation (called from controller)."""
self._retake_abort_requested = True

def move_z_for_stack(self):
if self.use_piezo:
self.z_piezo_um += self.deltaZ * 1000
Expand Down
Loading