diff --git a/software/control/core/multi_point_controller.py b/software/control/core/multi_point_controller.py index f9051eb73..c68565a93 100644 --- a/software/control/core/multi_point_controller.py +++ b/software/control/core/multi_point_controller.py @@ -9,7 +9,7 @@ from datetime import datetime from enum import Enum from threading import Thread -from typing import Optional, Tuple, Any +from typing import List, Optional, Tuple, Any import numpy as np import pandas as pd @@ -18,6 +18,7 @@ import control._def from control.core.auto_focus_controller import AutoFocusController from control.core.multi_point_utils import MultiPointControllerFunctions, ScanPositionInformation, AcquisitionParameters +from control.core.state_machine import TimepointStateMachine, TimepointState, FOVIdentifier from control.core.scan_coordinates import ScanCoordinates from control.core.laser_auto_focus_controller import LaserAutofocusController from control.core.live_controller import LiveController @@ -209,6 +210,7 @@ def __init__( self.objectiveStore: ObjectiveStore = objective_store self.callbacks: MultiPointControllerFunctions = callbacks self.multiPointWorker: Optional[MultiPointWorker] = None + self._state_machine: Optional[TimepointStateMachine] = None self.fluidics: Optional[Any] = microscope.addons.fluidics self.thread: Optional[Thread] = None self._per_acq_log_handler = None @@ -808,6 +810,14 @@ def finish_fn(): self.overlap_percent, ) + # Calculate total FOVs for this timepoint (for state machine) + total_fovs = sum(len(coords) for coords in scan_position_information.scan_region_fov_coords_mm.values()) + + # Create state machine for pause/resume/retake functionality + self._state_machine = TimepointStateMachine() + self._state_machine.reset(total_fovs) + self._state_machine.on_state_changed = self._on_state_changed + self.multiPointWorker = MultiPointWorker( scope=self.microscope, live_controller=self.liveController, @@ -821,6 +831,7 @@ def finish_fn(): extra_job_classes=[], alignment_widget=self._alignment_widget, slack_notifier=self._slack_notifier, + state_machine=self._state_machine, ) # Signal after worker creation so backpressure_controller is available @@ -946,6 +957,75 @@ def validate_acquisition_settings(self) -> bool: return False return True + # --- State Machine Control Methods --- + + def request_pause(self) -> bool: + """Request pause of current acquisition. + + The acquisition will pause after completing the current FOV. + + Returns: + True if pause request was accepted, False if not in a pausable state. + """ + if self._state_machine: + return self._state_machine.request_pause() + return False + + def request_resume(self) -> bool: + """Resume paused acquisition. + + Returns: + True if successfully initiated resume, False if not paused. + """ + if self._state_machine: + return self._state_machine.resume() + return False + + def request_retake(self, fovs: List[FOVIdentifier]) -> bool: + """Request retake of specified FOVs. + + Can only be called when acquisition is paused. + + Args: + fovs: List of FOV identifiers to retake. + + Returns: + True if retake started successfully, False otherwise. + """ + if self._state_machine: + return self._state_machine.retake(fovs) + return False + + def abort_retake(self) -> bool: + """Abort current retake operation only. + + Does not abort the entire acquisition - returns to paused state. + + Returns: + True if abort was handled, False otherwise. + """ + if self._state_machine: + accepted, abort_all = self._state_machine.abort() + if accepted and not abort_all and self.multiPointWorker: + self.multiPointWorker.request_retake_abort() + return accepted + return False + + def get_acquisition_state(self) -> Optional[TimepointState]: + """Get current acquisition state. + + Returns: + Current TimepointState, or None if no state machine exists. + """ + if self._state_machine: + return self._state_machine.state + return None + + def _on_state_changed(self, new_state: TimepointState) -> None: + """Handle state machine state changes (called from state machine thread).""" + self._log.info(f"Acquisition state changed to: {new_state.name}") + self.callbacks.signal_state_changed(new_state) + def get_plate_view(self) -> np.ndarray: """Get the current plate view array from the acquisition. diff --git a/software/control/core/multi_point_utils.py b/software/control/core/multi_point_utils.py index 067ce06a6..8f805bd20 100644 --- a/software/control/core/multi_point_utils.py +++ b/software/control/core/multi_point_utils.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from control.slack_notifier import TimepointStats, AcquisitionStats + from control.core.state_machine import TimepointState, FOVIdentifier @dataclass @@ -125,3 +126,10 @@ class MultiPointControllerFunctions: # Zarr frame written callback - called when subprocess completes writing a frame # Args: (fov, time_point, z_index, channel_name, region_idx) signal_zarr_frame_written: Callable[[int, int, int, str, int], None] = lambda *a, **kw: None + + # State machine signals for pause/resume/retake functionality + signal_state_changed: Callable[["TimepointState"], None] = lambda *a, **kw: None + signal_paused: Callable[[], None] = lambda *a, **kw: None + signal_resumed: Callable[[], None] = lambda *a, **kw: None + signal_retake_started: Callable[[List["FOVIdentifier"]], None] = lambda *a, **kw: None + signal_retakes_complete: Callable[[], None] = lambda *a, **kw: None diff --git a/software/control/core/multi_point_worker.py b/software/control/core/multi_point_worker.py index 564f66395..787cdf8a4 100644 --- a/software/control/core/multi_point_worker.py +++ b/software/control/core/multi_point_worker.py @@ -25,6 +25,7 @@ PlateViewInit, PlateViewUpdate, ) +from control.core.state_machine import TimepointStateMachine, TimepointState, FOVIdentifier from control.core.objective_store import ObjectiveStore from control.microcontroller import Microcontroller from control.microscope import Microscope @@ -81,12 +82,19 @@ def __init__( abort_on_failed_jobs: bool = True, alignment_widget=None, slack_notifier=None, + state_machine: Optional[TimepointStateMachine] = None, ): self._log = squid.logging.get_logger(__class__.__name__) self._timing = utils.TimingManager("MultiPointWorker Timer Manager") self._alignment_widget = alignment_widget # Optional AlignmentWidget for coordinate offset self._slack_notifier = slack_notifier # Optional SlackNotifier for notifications + # State machine for pause/resume/retake functionality + self._state_machine = state_machine + self._fov_coords_map: Dict[Tuple[str, int], Tuple[float, float, float]] = {} + self._retake_abort_requested = False + self._current_path: Optional[str] = None # Store current timepoint path for retakes + # Slack notification tracking counters self._timepoint_image_count = 0 self._timepoint_fov_count = 0 @@ -592,6 +600,14 @@ def run_single_time_point(self): with self._timing.get_timer("run_coordinate_acquisition"): self.run_coordinate_acquisition(current_path) + # Mark timepoint as captured in state machine + if self._state_machine: + self._state_machine.mark_all_captured() + + # Check for pause after timepoint complete (allows review before next timepoint) + if self._state_machine and self._state_machine.is_pause_requested(): + self._handle_pause() + # Save plate view for this timepoint if self._generate_downsampled_views and self._downsampled_view_manager is not None: # Wait for pending downsampled view jobs to complete @@ -1127,6 +1143,9 @@ def save_plate_view(self, path: str) -> None: self._downsampled_view_manager.save_plate_view(path) def run_coordinate_acquisition(self, current_path): + # Store current path for potential retakes + self._current_path = current_path + # Reset backpressure counters at acquisition start # IMPORTANT: Must be before any camera triggers self._backpressure.reset() @@ -1146,6 +1165,17 @@ def run_coordinate_acquisition(self, current_path): self.total_scans = self.num_fovs * self.NZ * len(self.selected_configurations) for fov, coordinate_mm in enumerate(coordinates): + # Store coordinates for potential retake + self._fov_coords_map[(region_id, fov)] = coordinate_mm + + # STATE MACHINE: Check for pause request before each FOV + if self._state_machine and self._state_machine.is_pause_requested(): + self._handle_pause() + # After resume, check state + state = self._state_machine.state + if state == TimepointState.CAPTURED: + return # Done with timepoint (user chose to skip remaining) + # Just so the job result queues don't get too big, check and print a summary of intermediate results here with self._timing.get_timer("job result summaries"): result = self._summarize_runner_outputs() @@ -1159,6 +1189,10 @@ def run_coordinate_acquisition(self, current_path): with self._timing.get_timer("acquire_at_position"): self.acquire_at_position(region_id, current_path, fov) + # Mark FOV as captured in state machine + if self._state_machine: + self._state_machine.mark_fov_captured() + if self.abort_requested_fn(): self.handle_acquisition_abort(current_path) return @@ -1613,6 +1647,75 @@ def handle_acquisition_abort(self, current_path): self._wait_for_outstanding_callback_images() + def _handle_pause(self) -> None: + """Handle pause request - wait for resume or retake.""" + if not self._state_machine: + return + + # Ensure all in-flight images are processed before pausing + self._wait_for_outstanding_callback_images() + + # Transition to PAUSED state + self._state_machine.complete_pause() + self.callbacks.signal_paused() + + # Wait for resume (loop handles retake cycles) + while True: + self._state_machine.wait_for_resume() + + state = self._state_machine.state + if state == TimepointState.ACQUIRING: + self.callbacks.signal_resumed() + break + elif state == TimepointState.CAPTURED: + self.callbacks.signal_resumed() + break + elif state == TimepointState.RETAKING: + self._run_retakes() + # After retakes complete, we're back in PAUSED - loop again + + def _run_retakes(self) -> None: + """Execute retakes for FOVs in retake list.""" + if not self._state_machine: + return + + retake_list = self._state_machine.get_retake_list() + self.callbacks.signal_retake_started(retake_list) + self._retake_abort_requested = False + + self._log.info(f"Starting retake of {len(retake_list)} FOVs") + + for fov_id in retake_list: + # Check for retake abort + if self._retake_abort_requested: + self._log.info("Retake aborted by user request") + break + + # Check for full acquisition abort + if self.abort_requested_fn(): + self._log.info("Retake aborted due to acquisition abort") + break + + # Get stored coordinates + coord = self._fov_coords_map.get((fov_id.region_id, fov_id.fov_index)) + if coord is None: + self._log.warning(f"No coordinates for FOV {fov_id}, skipping") + continue + + self._log.info(f"Retaking FOV: region={fov_id.region_id}, fov={fov_id.fov_index}") + + # Move and acquire + self.move_to_coordinate(coord, fov_id.region_id, fov_id.fov_index) + if self._current_path: + self.acquire_at_position(fov_id.region_id, self._current_path, fov_id.fov_index) + + self._state_machine.complete_retakes() + self.callbacks.signal_retakes_complete() + + def request_retake_abort(self) -> None: + """Request abort of current retake operation (called from controller).""" + self._retake_abort_requested = True + def move_z_for_stack(self): if self.use_piezo: self.z_piezo_um += self.deltaZ * 1000 diff --git a/software/control/core/state_machine.py b/software/control/core/state_machine.py new file mode 100644 index 000000000..1ca23fc4b --- /dev/null +++ b/software/control/core/state_machine.py @@ -0,0 +1,269 @@ +"""Acquisition state machine for timepoint-level pause/resume/retake functionality. + +This module provides a thread-safe state machine for controlling acquisition flow, +enabling users to pause mid-acquisition, review images, and retake specific FOVs. +""" + +from enum import Enum, auto +from dataclasses import dataclass +from typing import List, Tuple, Optional, Callable +import threading + + +class TimepointState(Enum): + """States for timepoint acquisition.""" + + ACQUIRING = auto() # Actively acquiring FOVs + PAUSED = auto() # Acquisition paused, waiting for user action + RETAKING = auto() # Re-acquiring specific FOVs + CAPTURED = auto() # All FOVs captured, waiting for next timepoint or review + + +@dataclass(frozen=True) +class FOVIdentifier: + """Unique identifier for a field of view within an acquisition.""" + + region_id: str + fov_index: int + + +class TimepointStateMachine: + """Thread-safe state machine for timepoint acquisition. + + This state machine manages the acquisition flow within a single timepoint, + enabling pause/resume and retake functionality. It is designed to be used + from both the UI thread (for control operations) and the worker thread + (for state queries and transitions). + + State transitions: + ACQUIRING -> PAUSED (via request_pause + complete_pause) + ACQUIRING -> CAPTURED (via mark_all_captured) + PAUSED -> ACQUIRING (via resume, if FOVs remaining) + PAUSED -> CAPTURED (via resume, if no FOVs remaining) + PAUSED -> RETAKING (via retake) + RETAKING -> PAUSED (via complete_retakes or abort) + CAPTURED -> PAUSED (via request_pause + complete_pause) + """ + + def __init__(self): + self._state = TimepointState.ACQUIRING + self._lock = threading.Lock() + self._pause_requested = threading.Event() + self._resume_event = threading.Event() + self._retake_list: List[FOVIdentifier] = [] + self._fovs_remaining = 0 + + # Callbacks (set by controller) + self.on_state_changed: Optional[Callable[[TimepointState], None]] = None + + @property + def state(self) -> TimepointState: + """Get current state (thread-safe).""" + with self._lock: + return self._state + + def reset(self, total_fovs: int) -> None: + """Reset state machine for a new timepoint. + + Args: + total_fovs: Total number of FOVs to acquire in this timepoint. + """ + with self._lock: + self._state = TimepointState.ACQUIRING + self._fovs_remaining = total_fovs + self._retake_list.clear() + self._pause_requested.clear() + self._resume_event.clear() + + # --- Pause/Resume --- + + def request_pause(self) -> bool: + """Request pause of acquisition (called from UI thread). + + The actual transition to PAUSED happens when the worker calls + complete_pause() after finishing the current FOV. + + Returns: + True if pause request was accepted, False if not in a pausable state. + """ + with self._lock: + if self._state in (TimepointState.ACQUIRING, TimepointState.CAPTURED): + self._pause_requested.set() + return True + return False + + def is_pause_requested(self) -> bool: + """Check if pause has been requested (called from worker thread).""" + return self._pause_requested.is_set() + + def complete_pause(self) -> bool: + """Complete transition to PAUSED state (called from worker thread). + + This should be called after ensuring all in-flight images are processed. + + Returns: + True if successfully transitioned to PAUSED, False otherwise. + """ + with self._lock: + if self._pause_requested.is_set(): + self._pause_requested.clear() + old_state = self._state + self._state = TimepointState.PAUSED + self._notify_state_changed(old_state) + return True + return False + + def resume(self) -> bool: + """Resume acquisition from PAUSED state (called from UI thread). + + Transitions to ACQUIRING if FOVs remain, otherwise to CAPTURED. + + Returns: + True if successfully initiated resume, False if not paused. + """ + with self._lock: + if self._state != TimepointState.PAUSED: + return False + + old_state = self._state + if self._fovs_remaining > 0: + self._state = TimepointState.ACQUIRING + else: + self._state = TimepointState.CAPTURED + + self._notify_state_changed(old_state) + self._resume_event.set() + return True + + def wait_for_resume(self, timeout: Optional[float] = None) -> bool: + """Block until resumed or timeout (called from worker thread). + + Args: + timeout: Maximum seconds to wait, or None for indefinite wait. + + Returns: + True if resumed, False if timed out. + """ + result = self._resume_event.wait(timeout) + self._resume_event.clear() + return result + + # --- Retake --- + + def retake(self, fovs: List[FOVIdentifier]) -> bool: + """Start retaking specified FOVs (called from UI thread). + + Can only be called from PAUSED state. + + Args: + fovs: List of FOV identifiers to retake. + + Returns: + True if retake started successfully, False otherwise. + """ + with self._lock: + if self._state != TimepointState.PAUSED: + return False + if not fovs: + return False + + self._retake_list = list(fovs) + old_state = self._state + self._state = TimepointState.RETAKING + self._notify_state_changed(old_state) + self._resume_event.set() + return True + + def get_retake_list(self) -> List[FOVIdentifier]: + """Get list of FOVs to retake (called from worker thread). + + Returns: + Copy of the retake list. + """ + with self._lock: + return list(self._retake_list) + + def complete_retakes(self) -> bool: + """Finish retaking, return to PAUSED (called from worker thread). + + Returns: + True if successfully transitioned back to PAUSED, False otherwise. + """ + with self._lock: + if self._state != TimepointState.RETAKING: + return False + + self._retake_list.clear() + old_state = self._state + self._state = TimepointState.PAUSED + self._notify_state_changed(old_state) + return True + + # --- FOV Tracking --- + + def mark_fov_captured(self) -> None: + """Mark one FOV as captured (called from worker thread).""" + with self._lock: + if self._fovs_remaining > 0: + self._fovs_remaining -= 1 + + def mark_all_captured(self) -> bool: + """Transition to CAPTURED state after all FOVs done (called from worker thread). + + Returns: + True if successfully transitioned to CAPTURED, False otherwise. + """ + with self._lock: + if self._state != TimepointState.ACQUIRING: + return False + + old_state = self._state + self._state = TimepointState.CAPTURED + self._notify_state_changed(old_state) + return True + + @property + def fovs_remaining(self) -> int: + """Get count of remaining FOVs to acquire (thread-safe).""" + with self._lock: + return self._fovs_remaining + + # --- Abort --- + + def abort(self) -> Tuple[bool, bool]: + """Abort current operation (called from UI thread). + + When called during RETAKING, only aborts the retake operation and + returns to PAUSED. In other states, signals to abort the entire + acquisition. + + Returns: + Tuple of (accepted, abort_entire_acquisition): + - accepted: True if abort was handled + - abort_entire_acquisition: True if full acquisition should abort, + False if only retake was aborted + """ + with self._lock: + if self._state == TimepointState.RETAKING: + # Abort retake only - return to PAUSED + self._retake_list.clear() + old_state = self._state + self._state = TimepointState.PAUSED + self._notify_state_changed(old_state) + self._resume_event.set() # Unblock worker + return (True, False) + else: + # Abort entire acquisition + return (True, True) + + def _notify_state_changed(self, old_state: TimepointState) -> None: + """Call state change callback if state changed (must hold lock).""" + if self.on_state_changed and old_state != self._state: + new_state = self._state + # Schedule callback outside lock to avoid deadlock + # Using a daemon thread so it doesn't block shutdown + threading.Thread( + target=self.on_state_changed, + args=(new_state,), + daemon=True, + ).start() diff --git a/software/docs/plans/2026-02-04-acquisition-state-machine-design.md b/software/docs/plans/2026-02-04-acquisition-state-machine-design.md new file mode 100644 index 000000000..3e64b3a41 --- /dev/null +++ b/software/docs/plans/2026-02-04-acquisition-state-machine-design.md @@ -0,0 +1,396 @@ +# Acquisition State Machine Design + +## Overview + +This document describes a state machine design for the Squid microscope acquisition system that enables: + +1. **Pause/Resume**: Gracefully pause acquisition mid-timepoint or after completion +2. **Selective Retake**: Review recent FOVs and retake specific ones +3. **QC Integration**: Hook points for quality control checks that can trigger pauses or suggest retakes + +## Design Goals + +- Minimal changes to existing `MultiPointWorker` acquisition loop +- Clear separation between state management and acquisition logic +- Support manual pause/retake as primary use case +- Extensible for future automated QC workflows +- Per-timepoint scope (previous timepoints are immutable) + +## Architecture + +### Acquisition Level (No State Machine) + +Simple tracking variables - no formal state machine needed: + +```python +@dataclass +class AcquisitionContext: + current_timepoint: int = 0 + total_timepoints: int = 1 + aborted: bool = False + proceed_policy: ProceedPolicy = ProceedPolicy.AUTO + +class ProceedPolicy(Enum): + AUTO = auto() # Proceed immediately when CAPTURED + MANUAL = auto() # Wait for explicit proceed() + QC_GATED = auto() # Wait for QC approval (future) +``` + +**Logic:** +- Acquisition is "running" while `current_timepoint < total_timepoints` and not `aborted` +- Proceed to next timepoint when current timepoint reaches `CAPTURED` state +- Proceed behavior controlled by `proceed_policy` + +### Timepoint Level (State Machine) + +Manages pause/resume/retake within a single timepoint. + +``` + ┌───────────┐ pause() ┌────────┐ + │ ACQUIRING │─────────────▶│ PAUSED │ + └─────┬─────┘ └───┬────┘ + │ │ │ + │ │ ├─── resume() ──▶ ACQUIRING + abort() │ │ │ (FOVs remaining) + │ │ │ + │ │ all FOVs done ├─── resume() ──▶ CAPTURED + │ │ │ (no FOVs remaining) + │ ▼ │ + │ ┌──────────┐ pause() │ retake([fovs]) + │ │ CAPTURED │──────────────▶│ + │ └────┬─────┘ ▼ + │ │ ┌──────────┐ + │ │ abort() │ RETAKING │ + │ │ └────┬─────┘ + │ │ │ │ + │ │ ┌───────────┘ │ abort() + │ │ │ done │ (retaking only) + │ │ ▼ ▼ + │ │ ┌────────┐ ┌────────┐ + ▼ ▼ │ PAUSED │ │ PAUSED │ + Abort entire └────────┘ └────────┘ + acquisition +``` + +## States + +| State | Description | +|-------|-------------| +| `ACQUIRING` | Initial capture of FOVs in sequence | +| `PAUSED` | Decision point - operator reviews, selects retakes, or proceeds | +| `RETAKING` | Re-capturing specific FOVs from the retake list | +| `CAPTURED` | All FOVs captured, ready for next timepoint | + +## Transitions + +| From | To | Trigger | Notes | +|------|----|---------|-------| +| `ACQUIRING` | `PAUSED` | `pause()` | Graceful - completes current FOV first | +| `ACQUIRING` | `CAPTURED` | All FOVs done | Automatic transition | +| `CAPTURED` | `PAUSED` | `pause()` | For review before next timepoint | +| `PAUSED` | `ACQUIRING` | `resume()` | Only if FOVs remaining | +| `PAUSED` | `CAPTURED` | `resume()` | Only if no FOVs remaining | +| `PAUSED` | `RETAKING` | `retake(fov_list)` | Receives list of (region_id, fov_index) | +| `RETAKING` | `PAUSED` | Retake list complete | Returns to PAUSED for review | + +## Abort Behavior + +| Abort From | Effect | +|------------|--------| +| `ACQUIRING` | Abort entire acquisition | +| `PAUSED` | Abort entire acquisition | +| `CAPTURED` | Abort entire acquisition | +| `RETAKING` | Abort retaking only, return to `PAUSED` | + +## Pause Behavior + +Pause is **graceful**: +1. Current FOV capture completes +2. All jobs for current FOV are dispatched (save, QC, etc.) +3. Then state transitions to `PAUSED` + +No half-captured images or orphaned jobs. + +## Retake Mechanism + +**Input:** +- `retake(fov_list)` receives a list of `(region_id, fov_index)` tuples +- List provided by user (manual selection) or QC system (automated) +- State machine does not track per-FOV status internally + +**Behavior:** +- Retaking overwrites original files (no versioning) +- After retakes complete, returns to `PAUSED` for review +- Operator can trigger additional retakes or resume + +**Identification:** +- FOVs identified by index pair: `(region_id, fov_index)` +- Matches existing loop structure in `MultiPointWorker` + +## Interface + +```python +from enum import Enum, auto +from typing import List, Tuple, Callable, Optional +from dataclasses import dataclass +import threading + +class TimepointState(Enum): + ACQUIRING = auto() + PAUSED = auto() + RETAKING = auto() + CAPTURED = auto() + +@dataclass(frozen=True) +class FOVIdentifier: + region_id: str + fov_index: int + +class TimepointStateMachine: + """Manages state for a single timepoint.""" + + def __init__(self, total_fovs: int): + self._state = TimepointState.ACQUIRING + self._lock = threading.Lock() + self._pause_requested = threading.Event() + self._resume_event = threading.Event() + self._retake_list: List[FOVIdentifier] = [] + self._fovs_remaining = total_fovs + self._total_fovs = total_fovs + + @property + def state(self) -> TimepointState: + """Current state.""" + with self._lock: + return self._state + + @property + def fovs_remaining(self) -> int: + """Number of FOVs not yet captured.""" + with self._lock: + return self._fovs_remaining + + def request_pause(self) -> bool: + """ + Request pause. Returns True if request accepted. + Pause is graceful - completes current FOV first. + Valid from: ACQUIRING, CAPTURED + """ + with self._lock: + if self._state in (TimepointState.ACQUIRING, TimepointState.CAPTURED): + self._pause_requested.set() + return True + return False + + def wait_for_pause(self, timeout: Optional[float] = None) -> bool: + """Block until pause is requested. Used by worker thread.""" + return self._pause_requested.wait(timeout) + + def complete_pause(self) -> bool: + """ + Called by worker after completing current FOV. + Actually transitions to PAUSED state. + """ + with self._lock: + if self._pause_requested.is_set(): + self._pause_requested.clear() + self._state = TimepointState.PAUSED + return True + return False + + def resume(self) -> bool: + """ + Resume acquisition. + Valid from: PAUSED + Transitions to: ACQUIRING (if FOVs remaining) or CAPTURED (if done) + """ + with self._lock: + if self._state != TimepointState.PAUSED: + return False + + if self._fovs_remaining > 0: + self._state = TimepointState.ACQUIRING + else: + self._state = TimepointState.CAPTURED + + self._resume_event.set() + return True + + def wait_for_resume(self, timeout: Optional[float] = None) -> bool: + """Block until resumed. Used by worker thread.""" + result = self._resume_event.wait(timeout) + self._resume_event.clear() + return result + + def retake(self, fovs: List[FOVIdentifier]) -> bool: + """ + Start retaking specified FOVs. + Valid from: PAUSED + Transitions to: RETAKING + """ + with self._lock: + if self._state != TimepointState.PAUSED: + return False + if not fovs: + return False + + self._retake_list = list(fovs) + self._state = TimepointState.RETAKING + self._resume_event.set() + return True + + def get_retake_list(self) -> List[FOVIdentifier]: + """Get current retake list.""" + with self._lock: + return list(self._retake_list) + + def complete_retakes(self) -> bool: + """ + Called by worker when retake list is complete. + Transitions to: PAUSED + """ + with self._lock: + if self._state != TimepointState.RETAKING: + return False + + self._retake_list.clear() + self._state = TimepointState.PAUSED + return True + + def mark_fov_captured(self) -> None: + """Called by worker when an FOV is captured.""" + with self._lock: + if self._fovs_remaining > 0: + self._fovs_remaining -= 1 + + def mark_all_captured(self) -> bool: + """ + Called by worker when all FOVs are done. + Transitions to: CAPTURED + """ + with self._lock: + if self._state != TimepointState.ACQUIRING: + return False + + self._state = TimepointState.CAPTURED + return True + + def abort(self) -> Tuple[bool, bool]: + """ + Abort current operation. + Returns: (abort_accepted, abort_entire_acquisition) + + From RETAKING: aborts retake only, returns to PAUSED + From other states: aborts entire acquisition + """ + with self._lock: + if self._state == TimepointState.RETAKING: + self._retake_list.clear() + self._state = TimepointState.PAUSED + return (True, False) # Abort retake only + else: + return (True, True) # Abort entire acquisition +``` + +## Integration with MultiPointWorker + +```python +# Pseudocode for modified acquisition loop + +class MultiPointWorker: + def __init__(self, ..., state_machine: TimepointStateMachine): + self._state_machine = state_machine + ... + + def run_coordinate_acquisition(self): + for region_id, coords in self.regions: + for fov_index, coord in enumerate(coords): + + # Check for pause request before each FOV + if self._state_machine._pause_requested.is_set(): + self._finish_current_jobs() + self._state_machine.complete_pause() + + # Wait for resume or retake + self._state_machine.wait_for_resume() + + # Check what state we're in after resume + state = self._state_machine.state + if state == TimepointState.CAPTURED: + return # Done with this timepoint + elif state == TimepointState.RETAKING: + self._run_retakes() + continue # Check state again + + # Check for abort + if self._abort_requested(): + return + + # Acquire FOV + self.acquire_at_position(region_id, fov_index, coord) + self._state_machine.mark_fov_captured() + + # All FOVs done + self._state_machine.mark_all_captured() + + def _run_retakes(self): + """Execute retakes for FOVs in retake list.""" + retake_list = self._state_machine.get_retake_list() + + for fov_id in retake_list: + # Check for retake abort + if self._retake_abort_requested(): + break + + coord = self._get_fov_coordinates(fov_id) + self.acquire_at_position(fov_id.region_id, fov_id.fov_index, coord) + + self._state_machine.complete_retakes() +``` + +## Signals (PyQt Integration) + +```python +@dataclass +class StateMachineSignals: + # State changes + signal_state_changed: Callable[[TimepointState], None] + + # Pause/resume + signal_pause_requested: Callable[[], None] + signal_paused: Callable[[], None] + signal_resumed: Callable[[], None] + + # Retake + signal_retake_started: Callable[[List[FOVIdentifier]], None] + signal_retake_fov_complete: Callable[[FOVIdentifier], None] + signal_retakes_complete: Callable[[], None] + + # Progress + signal_fov_captured: Callable[[FOVIdentifier], None] + signal_timepoint_captured: Callable[[], None] +``` + +## Thread Safety + +State transitions are thread-safe: +- `threading.Lock` protects all state mutations +- `threading.Event` for pause/resume signaling +- Worker thread waits on events; UI thread signals them + +## QC Integration Points + +The state machine provides hooks for QC system (detailed in separate document): + +1. **After FOV captured**: QC job dispatched, result can call `request_pause()` +2. **In PAUSED state**: QC can provide suggested retake list via UI +3. **Before proceed**: QC gate can block advancement to next timepoint + +## Summary + +| Component | Responsibility | +|-----------|----------------| +| `AcquisitionContext` | Tracks timepoint index, abort flag, proceed policy | +| `TimepointStateMachine` | Pause/resume/retake within a timepoint | +| `MultiPointWorker` | Actual image acquisition, responds to state machine | +| `ProceedPolicy` | Configures auto vs manual timepoint progression | diff --git a/software/docs/plans/2026-02-04-qc-system-design.md b/software/docs/plans/2026-02-04-qc-system-design.md new file mode 100644 index 000000000..5662da0bc --- /dev/null +++ b/software/docs/plans/2026-02-04-qc-system-design.md @@ -0,0 +1,574 @@ +# QC System Design + +## Overview + +This document describes the Quality Control (QC) system for the Squid microscope acquisition. The QC system: + +1. **Collects metrics** per-FOV during acquisition (focus score, z-position, etc.) +2. **Stores metrics** for analysis and review +3. **Applies policies** to decide when to pause and which FOVs to flag for retake + +## Design Goals + +- QC runs in parallel with acquisition (as Jobs in subprocess) +- Clean separation: metrics collection vs. policy decisions +- Swappable QC methods for different applications +- Extensible for future automation +- Simple initial implementation focused on manual review + +## Architecture + +``` +┌─────────────────────────────────────────────────────────────┐ +│ MultiPointWorker │ +│ │ +│ acquire_at_position() │ +│ │ │ +│ ▼ │ +│ ┌─────────────┐ dispatch ┌─────────────────────┐ │ +│ │ CameraFrame │ ─────────────────▶│ SaveImageJob │ │ +│ └─────────────┘ │ └─────────────────────┘ │ +│ │ │ +│ │ ┌─────────────────────┐ │ +│ └─────────▶│ QCJob │ │ +│ └──────────┬──────────┘ │ +└───────────────────────────────────────────────│─────────────┘ + │ + ┌───────────▼───────────┐ + │ JobRunner │ + │ (QC subprocess) │ + │ │ + │ ┌─────────────────┐ │ + │ │ QCMetricsCalc │ │ + │ │ - focus_score() │ │ + │ │ - z_position() │ │ + │ │ - ... │ │ + │ └─────────────────┘ │ + └───────────┬───────────┘ + │ + ▼ + ┌───────────────────────┐ + │ QCResult │ + │ (via output_queue) │ + └───────────┬───────────┘ + │ + ┌─────────────────────────┼─────────────────────────┐ + │ │ │ + ▼ ▼ ▼ + ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ + │ MetricsStore │ │ QCPolicy │ │ UI Display │ + │ (per-timepoint) │◀────▶│ (check at end) │ │ (live metrics) │ + └─────────────────┘ └────────┬────────┘ └─────────────────┘ + │ + ▼ + ┌─────────────────┐ + │ PolicyDecision │ + │ - flagged_fovs │ + │ - should_pause │ + └─────────────────┘ +``` + +## Components + +### 1. QCJob + +A Job subclass that runs in a subprocess, calculates metrics for a single FOV. + +```python +@dataclass +class QCJob(Job[QCResult]): + """Quality control job for a single FOV.""" + + capture_info: CaptureInfo + capture_image: JobImage + qc_config: QCConfig + previous_timepoint_z: Optional[float] = None # For z-drift calculation + + def run(self) -> QCResult: + image = self.capture_image.get_image() + metrics = FOVMetrics( + fov_id=FOVIdentifier( + region_id=self.capture_info.region_id, + fov_index=self.capture_info.fov, + ), + timestamp=self.capture_info.capture_time, + z_position_um=self.capture_info.position.z_mm * 1000, + ) + + # Calculate enabled metrics + if self.qc_config.calculate_focus_score: + metrics.focus_score = calculate_focus_score(image) + + if self.qc_config.record_laser_af_displacement: + metrics.laser_af_displacement_um = self.capture_info.z_piezo_um + + if self.previous_timepoint_z is not None: + metrics.z_diff_from_last_timepoint_um = ( + metrics.z_position_um - self.previous_timepoint_z + ) + + return QCResult(metrics=metrics) +``` + +### 2. FOVMetrics + +Data class holding all QC metrics for a single FOV. + +```python +@dataclass +class FOVMetrics: + """QC metrics for a single FOV.""" + + fov_id: FOVIdentifier + timestamp: float + z_position_um: float + + # Optional metrics (calculated based on config) + focus_score: Optional[float] = None + laser_af_displacement_um: Optional[float] = None + z_diff_from_last_timepoint_um: Optional[float] = None + + # Extensible: add more metrics as needed + # e.g., saturation_percent, background_intensity, cell_count, etc. +``` + +### 3. QCResult + +Result returned by QCJob. + +```python +@dataclass +class QCResult: + """Result from QC job.""" + + metrics: FOVMetrics + error: Optional[str] = None # If QC calculation failed +``` + +### 4. TimepointMetricsStore + +Holds all QC metrics for the current timepoint. + +```python +class TimepointMetricsStore: + """Stores QC metrics for a single timepoint.""" + + def __init__(self, timepoint_index: int): + self._timepoint = timepoint_index + self._metrics: Dict[FOVIdentifier, FOVMetrics] = {} + self._lock = threading.Lock() + + def add(self, metrics: FOVMetrics) -> None: + """Add metrics for an FOV.""" + with self._lock: + self._metrics[metrics.fov_id] = metrics + + def get(self, fov_id: FOVIdentifier) -> Optional[FOVMetrics]: + """Get metrics for a specific FOV.""" + with self._lock: + return self._metrics.get(fov_id) + + def get_all(self) -> List[FOVMetrics]: + """Get all metrics for this timepoint.""" + with self._lock: + return list(self._metrics.values()) + + def get_metric_values(self, metric_name: str) -> Dict[FOVIdentifier, float]: + """Get a specific metric across all FOVs.""" + with self._lock: + result = {} + for fov_id, m in self._metrics.items(): + value = getattr(m, metric_name, None) + if value is not None: + result[fov_id] = value + return result + + def to_dataframe(self) -> pd.DataFrame: + """Export metrics as DataFrame for analysis.""" + ... + + def save(self, path: str) -> None: + """Persist metrics to disk (CSV or JSON).""" + ... +``` + +### 5. QCConfig + +Configuration for which metrics to calculate. + +```python +@dataclass +class QCConfig: + """Configuration for QC metrics collection.""" + + enabled: bool = False + + # Which metrics to calculate + calculate_focus_score: bool = True + record_laser_af_displacement: bool = False + calculate_z_diff_from_last_timepoint: bool = False + + # Focus score method + focus_score_method: str = "laplacian_variance" # or "normalized_variance", etc. +``` + +### 6. QCPolicy + +Configurable rules for deciding when to pause and which FOVs to flag. + +```python +@dataclass +class QCPolicyConfig: + """Configuration for QC policy decisions.""" + + enabled: bool = False + + # When to run policy checks + check_after_timepoint: bool = True + # Future: check_after_fov: bool = False + + # Threshold-based rules + focus_score_min: Optional[float] = None # Flag FOVs below this + z_drift_max_um: Optional[float] = None # Flag FOVs exceeding this + + # Outlier detection + detect_outliers: bool = False + outlier_metric: str = "focus_score" # Which metric to check + outlier_std_threshold: float = 2.0 # Flag if > N std from mean + + # Action when FOVs are flagged + pause_if_any_flagged: bool = True + + +class QCPolicy: + """Evaluates QC metrics and decides on actions.""" + + def __init__(self, config: QCPolicyConfig): + self._config = config + + def check_timepoint(self, metrics_store: TimepointMetricsStore) -> PolicyDecision: + """ + Evaluate all FOVs in a timepoint. + Called when timepoint reaches CAPTURED state. + + Returns: + PolicyDecision with flagged FOVs and whether to pause + """ + flagged: List[FOVIdentifier] = [] + reasons: Dict[FOVIdentifier, List[str]] = {} + + all_metrics = metrics_store.get_all() + + # Threshold checks + if self._config.focus_score_min is not None: + for m in all_metrics: + if m.focus_score is not None and m.focus_score < self._config.focus_score_min: + flagged.append(m.fov_id) + reasons.setdefault(m.fov_id, []).append( + f"focus_score={m.focus_score:.2f} < {self._config.focus_score_min}" + ) + + if self._config.z_drift_max_um is not None: + for m in all_metrics: + if m.z_diff_from_last_timepoint_um is not None: + if abs(m.z_diff_from_last_timepoint_um) > self._config.z_drift_max_um: + if m.fov_id not in flagged: + flagged.append(m.fov_id) + reasons.setdefault(m.fov_id, []).append( + f"z_drift={m.z_diff_from_last_timepoint_um:.2f}um > {self._config.z_drift_max_um}" + ) + + # Outlier detection + if self._config.detect_outliers: + outliers = self._detect_outliers( + metrics_store, + self._config.outlier_metric, + self._config.outlier_std_threshold, + ) + for fov_id in outliers: + if fov_id not in flagged: + flagged.append(fov_id) + reasons.setdefault(fov_id, []).append( + f"outlier in {self._config.outlier_metric}" + ) + + should_pause = self._config.pause_if_any_flagged and len(flagged) > 0 + + return PolicyDecision( + flagged_fovs=flagged, + flag_reasons=reasons, + should_pause=should_pause, + ) + + def _detect_outliers( + self, + metrics_store: TimepointMetricsStore, + metric_name: str, + std_threshold: float, + ) -> List[FOVIdentifier]: + """Detect outliers using standard deviation method.""" + values = metrics_store.get_metric_values(metric_name) + if len(values) < 3: + return [] + + arr = np.array(list(values.values())) + mean, std = arr.mean(), arr.std() + + outliers = [] + for fov_id, value in values.items(): + if abs(value - mean) > std_threshold * std: + outliers.append(fov_id) + + return outliers + + # Future: add check_fov() for immediate per-FOV decisions + # def check_fov(self, metrics: FOVMetrics) -> QCAction: + # """Called immediately after FOV QC completes.""" + # if self._config.focus_score_min and metrics.focus_score < self._config.focus_score_min: + # return QCAction.PAUSE + # return QCAction.CONTINUE + + +@dataclass +class PolicyDecision: + """Result of QC policy evaluation.""" + + flagged_fovs: List[FOVIdentifier] + flag_reasons: Dict[FOVIdentifier, List[str]] + should_pause: bool +``` + +## Integration with Acquisition + +### Job Dispatch + +QCJob is dispatched alongside other jobs in `MultiPointWorker._image_callback()`: + +```python +def _image_callback(self, camera_frame: CameraFrame): + # ... existing job dispatch ... + + # Add QC job if enabled + if self._qc_config.enabled: + qc_job = QCJob( + capture_info=info, + capture_image=JobImage(camera_frame.frame.copy()), + qc_config=self._qc_config, + previous_timepoint_z=self._get_previous_timepoint_z(info.fov_id), + ) + self._qc_job_runner.dispatch(qc_job) +``` + +### Result Processing + +QC results are handled in `_summarize_runner_outputs()`: + +```python +def _summarize_job_result(self, result: JobResult): + if isinstance(result.result, QCResult): + qc_result = result.result + + # Store metrics + self._metrics_store.add(qc_result.metrics) + + # Emit signal for UI update + self.callbacks.signal_qc_metrics_updated(qc_result.metrics) + + # Future: immediate per-FOV policy check + # if self._qc_policy.check_fov(qc_result.metrics) == QCAction.PAUSE: + # self._state_machine.request_pause() +``` + +### Timepoint End Policy Check + +When timepoint reaches CAPTURED state: + +```python +def _on_timepoint_captured(self): + # Run QC policy check + if self._qc_policy_config.enabled and self._qc_policy_config.check_after_timepoint: + decision = self._qc_policy.check_timepoint(self._metrics_store) + + # Emit signal with flagged FOVs for UI + self.callbacks.signal_qc_policy_decision(decision) + + if decision.should_pause: + self._state_machine.request_pause() + # UI will show flagged FOVs, user can select retakes +``` + +## Focus Score Methods + +Swappable focus score algorithms: + +```python +def calculate_focus_score(image: np.ndarray, method: str = "laplacian_variance") -> float: + """Calculate focus score using specified method.""" + + if method == "laplacian_variance": + # Variance of Laplacian - higher = more in focus + laplacian = cv2.Laplacian(image, cv2.CV_64F) + return laplacian.var() + + elif method == "normalized_variance": + # Normalized variance - robust to intensity variations + mean = image.mean() + if mean == 0: + return 0.0 + return image.var() / mean + + elif method == "gradient_magnitude": + # Sum of gradient magnitudes + gx = cv2.Sobel(image, cv2.CV_64F, 1, 0) + gy = cv2.Sobel(image, cv2.CV_64F, 0, 1) + return np.sqrt(gx**2 + gy**2).mean() + + elif method == "fft_high_freq": + # High-frequency content in FFT + fft = np.fft.fft2(image) + fft_shift = np.fft.fftshift(fft) + # Mask out low frequencies + h, w = image.shape[:2] + cy, cx = h // 2, w // 2 + mask_size = min(h, w) // 8 + fft_shift[cy-mask_size:cy+mask_size, cx-mask_size:cx+mask_size] = 0 + return np.abs(fft_shift).mean() + + else: + raise ValueError(f"Unknown focus method: {method}") +``` + +## Signals for UI + +```python +@dataclass +class QCSignals: + # Per-FOV metrics (for live display) + signal_qc_metrics_updated: Callable[[FOVMetrics], None] + + # Policy decision (at timepoint end) + signal_qc_policy_decision: Callable[[PolicyDecision], None] +``` + +## Data Persistence + +Metrics are saved alongside acquisition data: + +``` +{experiment_path}/ +├── 000/ # Timepoint 0 +│ ├── images/ +│ └── qc_metrics.csv # QC metrics for this timepoint +├── 001/ # Timepoint 1 +│ ├── images/ +│ └── qc_metrics.csv +└── qc_summary.json # Overall QC summary +``` + +## Configuration Example + +```yaml +# In acquisition config or separate qc_config.yaml + +qc: + enabled: true + + metrics: + calculate_focus_score: true + focus_score_method: "laplacian_variance" + record_laser_af_displacement: true + calculate_z_diff_from_last_timepoint: true + + policy: + enabled: true + check_after_timepoint: true + + # Threshold rules + focus_score_min: 100.0 + z_drift_max_um: 5.0 + + # Outlier detection + detect_outliers: true + outlier_metric: "focus_score" + outlier_std_threshold: 2.0 + + # Action + pause_if_any_flagged: true +``` + +## Future Extensions + +### 1. Immediate Per-FOV Pause + +Add `check_fov()` to QCPolicy: + +```python +def check_fov(self, metrics: FOVMetrics) -> QCAction: + if self._config.focus_score_min and metrics.focus_score < self._config.focus_score_min: + return QCAction.PAUSE + return QCAction.CONTINUE +``` + +Call from `_summarize_job_result()` after storing metrics. + +### 2. Custom QC Methods + +Plugin system for custom QC calculations: + +```python +class QCMethod(ABC): + @property + @abstractmethod + def name(self) -> str: ... + + @abstractmethod + def calculate(self, image: np.ndarray, capture_info: CaptureInfo) -> Dict[str, float]: ... + +# Register custom methods +qc_registry.register(MyCustomQCMethod()) +``` + +### 3. Cross-Timepoint Analysis + +Compare metrics across timepoints: + +```python +class AcquisitionMetricsStore: + """Holds metrics for all timepoints.""" + + def get_fov_history(self, fov_id: FOVIdentifier) -> List[FOVMetrics]: + """Get metrics for an FOV across all timepoints.""" + ... + + def detect_drift_trend(self, fov_id: FOVIdentifier) -> DriftTrend: + """Analyze z-drift trend over time.""" + ... +``` + +### 4. Automated Retake Strategies + +Different retry behaviors per QC failure type: + +```python +@dataclass +class RetakeStrategy: + re_autofocus: bool = True + adjust_exposure: bool = False + max_attempts: int = 2 +``` + +## Summary + +| Component | Responsibility | +|-----------|----------------| +| `QCJob` | Calculates metrics per-FOV in subprocess | +| `FOVMetrics` | Data class for QC measurements | +| `TimepointMetricsStore` | Holds all metrics for current timepoint | +| `QCPolicy` | Evaluates metrics, decides when to pause | +| `PolicyDecision` | Result with flagged FOVs and actions | + +The design prioritizes: +- **Parallel execution**: QC runs as Job, doesn't slow acquisition +- **Separation of concerns**: Metrics vs. policy decisions +- **Flexibility**: Swappable focus methods, configurable policies +- **Extensibility**: Easy to add immediate per-FOV checks later diff --git a/software/tests/control/test_state_machine.py b/software/tests/control/test_state_machine.py new file mode 100644 index 000000000..8bc813ca7 --- /dev/null +++ b/software/tests/control/test_state_machine.py @@ -0,0 +1,353 @@ +"""Unit tests for the acquisition state machine. + +Tests cover: +- State transitions +- Invalid transitions +- Thread safety +- Abort behavior +""" + +import threading +import time +import pytest + +from control.core.state_machine import TimepointStateMachine, TimepointState, FOVIdentifier + + +class TestTimepointStateMachine: + """Tests for TimepointStateMachine.""" + + def test_initial_state(self): + """State machine starts in ACQUIRING state.""" + sm = TimepointStateMachine() + assert sm.state == TimepointState.ACQUIRING + + def test_reset(self): + """Reset initializes state and FOV count.""" + sm = TimepointStateMachine() + sm.reset(total_fovs=10) + assert sm.state == TimepointState.ACQUIRING + assert sm.fovs_remaining == 10 + + # --- Pause/Resume Tests --- + + def test_request_pause_from_acquiring(self): + """Pause can be requested from ACQUIRING state.""" + sm = TimepointStateMachine() + sm.reset(total_fovs=5) + assert sm.request_pause() is True + assert sm.is_pause_requested() is True + + def test_complete_pause_transitions_to_paused(self): + """complete_pause transitions to PAUSED state.""" + sm = TimepointStateMachine() + sm.reset(total_fovs=5) + sm.request_pause() + assert sm.complete_pause() is True + assert sm.state == TimepointState.PAUSED + + def test_resume_from_paused_with_fovs_remaining(self): + """Resume from PAUSED returns to ACQUIRING if FOVs remain.""" + sm = TimepointStateMachine() + sm.reset(total_fovs=5) + sm.request_pause() + sm.complete_pause() + + assert sm.resume() is True + assert sm.state == TimepointState.ACQUIRING + + def test_resume_from_paused_no_fovs_remaining(self): + """Resume from PAUSED transitions to CAPTURED if no FOVs remain.""" + sm = TimepointStateMachine() + sm.reset(total_fovs=2) + sm.mark_fov_captured() + sm.mark_fov_captured() + sm.request_pause() + sm.complete_pause() + + assert sm.fovs_remaining == 0 + assert sm.resume() is True + assert sm.state == TimepointState.CAPTURED + + def test_resume_not_from_paused_fails(self): + """Resume fails if not in PAUSED state.""" + sm = TimepointStateMachine() + sm.reset(total_fovs=5) + assert sm.resume() is False # Still in ACQUIRING + + def test_wait_for_resume_unblocks_on_resume(self): + """wait_for_resume unblocks when resume is called.""" + sm = TimepointStateMachine() + sm.reset(total_fovs=5) + sm.request_pause() + sm.complete_pause() + + # Start a thread that will wait for resume + wait_result = [None] + + def wait_thread(): + wait_result[0] = sm.wait_for_resume(timeout=2.0) + + t = threading.Thread(target=wait_thread) + t.start() + + # Give the thread time to start waiting + time.sleep(0.1) + + # Resume should unblock the wait + sm.resume() + t.join(timeout=1.0) + + assert not t.is_alive() + assert wait_result[0] is True + + def test_wait_for_resume_times_out(self): + """wait_for_resume returns False on timeout.""" + sm = TimepointStateMachine() + sm.reset(total_fovs=5) + sm.request_pause() + sm.complete_pause() + + result = sm.wait_for_resume(timeout=0.1) + assert result is False + + # --- Retake Tests --- + + def test_retake_from_paused(self): + """Retake can be started from PAUSED state.""" + sm = TimepointStateMachine() + sm.reset(total_fovs=5) + sm.request_pause() + sm.complete_pause() + + fovs = [FOVIdentifier("A1", 0), FOVIdentifier("A1", 1)] + assert sm.retake(fovs) is True + assert sm.state == TimepointState.RETAKING + + def test_retake_not_from_paused_fails(self): + """Retake fails if not in PAUSED state.""" + sm = TimepointStateMachine() + sm.reset(total_fovs=5) + + fovs = [FOVIdentifier("A1", 0)] + assert sm.retake(fovs) is False + + def test_retake_empty_list_fails(self): + """Retake with empty list fails.""" + sm = TimepointStateMachine() + sm.reset(total_fovs=5) + sm.request_pause() + sm.complete_pause() + + assert sm.retake([]) is False + + def test_get_retake_list(self): + """get_retake_list returns copy of retake list.""" + sm = TimepointStateMachine() + sm.reset(total_fovs=5) + sm.request_pause() + sm.complete_pause() + + fovs = [FOVIdentifier("A1", 0), FOVIdentifier("A1", 1)] + sm.retake(fovs) + + retake_list = sm.get_retake_list() + assert len(retake_list) == 2 + assert retake_list[0] == FOVIdentifier("A1", 0) + assert retake_list[1] == FOVIdentifier("A1", 1) + + def test_complete_retakes_returns_to_paused(self): + """complete_retakes transitions back to PAUSED.""" + sm = TimepointStateMachine() + sm.reset(total_fovs=5) + sm.request_pause() + sm.complete_pause() + sm.retake([FOVIdentifier("A1", 0)]) + + assert sm.complete_retakes() is True + assert sm.state == TimepointState.PAUSED + assert sm.get_retake_list() == [] + + def test_complete_retakes_not_from_retaking_fails(self): + """complete_retakes fails if not in RETAKING state.""" + sm = TimepointStateMachine() + sm.reset(total_fovs=5) + sm.request_pause() + sm.complete_pause() + + assert sm.complete_retakes() is False + + # --- FOV Tracking Tests --- + + def test_mark_fov_captured(self): + """mark_fov_captured decrements remaining count.""" + sm = TimepointStateMachine() + sm.reset(total_fovs=5) + + sm.mark_fov_captured() + assert sm.fovs_remaining == 4 + + sm.mark_fov_captured() + assert sm.fovs_remaining == 3 + + def test_mark_fov_captured_at_zero_stays_zero(self): + """mark_fov_captured doesn't go negative.""" + sm = TimepointStateMachine() + sm.reset(total_fovs=1) + + sm.mark_fov_captured() + assert sm.fovs_remaining == 0 + + sm.mark_fov_captured() # Should not go negative + assert sm.fovs_remaining == 0 + + def test_mark_all_captured(self): + """mark_all_captured transitions to CAPTURED.""" + sm = TimepointStateMachine() + sm.reset(total_fovs=5) + + assert sm.mark_all_captured() is True + assert sm.state == TimepointState.CAPTURED + + def test_mark_all_captured_not_from_acquiring_fails(self): + """mark_all_captured fails if not in ACQUIRING state.""" + sm = TimepointStateMachine() + sm.reset(total_fovs=5) + sm.request_pause() + sm.complete_pause() + + assert sm.mark_all_captured() is False + + # --- Abort Tests --- + + def test_abort_from_retaking_returns_to_paused(self): + """Abort from RETAKING returns to PAUSED.""" + sm = TimepointStateMachine() + sm.reset(total_fovs=5) + sm.request_pause() + sm.complete_pause() + sm.retake([FOVIdentifier("A1", 0)]) + + accepted, abort_all = sm.abort() + assert accepted is True + assert abort_all is False + assert sm.state == TimepointState.PAUSED + + def test_abort_from_acquiring_aborts_all(self): + """Abort from ACQUIRING signals full abort.""" + sm = TimepointStateMachine() + sm.reset(total_fovs=5) + + accepted, abort_all = sm.abort() + assert accepted is True + assert abort_all is True + + def test_abort_from_paused_aborts_all(self): + """Abort from PAUSED signals full abort.""" + sm = TimepointStateMachine() + sm.reset(total_fovs=5) + sm.request_pause() + sm.complete_pause() + + accepted, abort_all = sm.abort() + assert accepted is True + assert abort_all is True + + # --- State Change Callback Tests --- + + def test_state_change_callback_called(self): + """State change callback is called on state transitions.""" + sm = TimepointStateMachine() + sm.reset(total_fovs=5) + + states_received = [] + + def callback(state): + states_received.append(state) + + sm.on_state_changed = callback + + sm.request_pause() + sm.complete_pause() + + # Give callback thread time to run + time.sleep(0.1) + + assert TimepointState.PAUSED in states_received + + # --- Thread Safety Tests --- + + def test_concurrent_pause_requests(self): + """Multiple concurrent pause requests are handled safely.""" + sm = TimepointStateMachine() + sm.reset(total_fovs=100) + + results = [] + + def request_pause_thread(): + result = sm.request_pause() + results.append(result) + + threads = [threading.Thread(target=request_pause_thread) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + # At least one should succeed + assert any(results) + # All should get consistent result + assert sm.is_pause_requested() + + def test_concurrent_fov_marking(self): + """Multiple concurrent mark_fov_captured calls are thread-safe.""" + sm = TimepointStateMachine() + sm.reset(total_fovs=100) + + def mark_fov_thread(): + for _ in range(10): + sm.mark_fov_captured() + + threads = [threading.Thread(target=mark_fov_thread) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + # All 100 FOVs should be marked + assert sm.fovs_remaining == 0 + + +class TestFOVIdentifier: + """Tests for FOVIdentifier dataclass.""" + + def test_fov_identifier_equality(self): + """FOVIdentifier equality is based on region_id and fov_index.""" + fov1 = FOVIdentifier("A1", 0) + fov2 = FOVIdentifier("A1", 0) + fov3 = FOVIdentifier("A1", 1) + fov4 = FOVIdentifier("A2", 0) + + assert fov1 == fov2 + assert fov1 != fov3 + assert fov1 != fov4 + + def test_fov_identifier_hashable(self): + """FOVIdentifier can be used in sets and as dict keys.""" + fov1 = FOVIdentifier("A1", 0) + fov2 = FOVIdentifier("A1", 0) + fov3 = FOVIdentifier("A1", 1) + + # Can use in set + fov_set = {fov1, fov2, fov3} + assert len(fov_set) == 2 + + # Can use as dict key + fov_dict = {fov1: "first", fov3: "second"} + assert fov_dict[fov2] == "first" + + def test_fov_identifier_immutable(self): + """FOVIdentifier is frozen (immutable).""" + fov = FOVIdentifier("A1", 0) + with pytest.raises(AttributeError): + fov.region_id = "A2"