diff --git a/bec_ipython_client/bec_ipython_client/callbacks/live_table.py b/bec_ipython_client/bec_ipython_client/callbacks/live_table.py index 88b98d864..014b6b907 100644 --- a/bec_ipython_client/bec_ipython_client/callbacks/live_table.py +++ b/bec_ipython_client/bec_ipython_client/callbacks/live_table.py @@ -226,10 +226,10 @@ def _run_update(self, target_num_points: int): while True: self.check_alarms() self.point_data = self.scan_item.live_data.get(self.point_id) - if self.scan_item.num_points: - progressbar.max_points = self.scan_item.num_points + if self.scan_item.num_monitored_readouts: + progressbar.max_points = self.scan_item.num_monitored_readouts if target_num_points == 0: - target_num_points = self.scan_item.num_points + target_num_points = self.scan_item.num_monitored_readouts progressbar.update(self.point_id) if self.point_data: @@ -256,12 +256,12 @@ def _run_update(self, target_num_points: int): f"Scan {self.scan_item.scan_number} was aborted by user." ) - if not self.scan_item.num_points: + if not self.scan_item.num_monitored_readouts and self.scan_item.status != "closed": continue if self.point_id == target_num_points: break - if self.point_id > self.scan_item.num_points: + if self.point_id > self.scan_item.num_monitored_readouts: raise RuntimeError("Received more points than expected.") if len(self.scan_item.live_data) == 0 and self.scan_item.status == "closed": diff --git a/bec_ipython_client/tests/client_tests/test_live_table.py b/bec_ipython_client/tests/client_tests/test_live_table.py index 891803831..f87c3457c 100644 --- a/bec_ipython_client/tests/client_tests/test_live_table.py +++ b/bec_ipython_client/tests/client_tests/test_live_table.py @@ -146,11 +146,13 @@ def test_run_update(self, bec_client_mock, scan_item): ) live_update.scan_item = scan_item scan_item.num_points = 2 + scan_item.num_monitored_readouts = 2 scan_item.live_data = {0: data} with mock.patch.object(live_update, "print_table_data") as mock_print_table_data: live_update._run_update(1) assert mock_print_table_data.called scan_item.num_points = 2 + scan_item.num_monitored_readouts = 2 scan_item.live_data = {0: data, 1: data} scan_item.status = "closed" with mock.patch.object(live_update, "print_table_data") as mock_print_table_data: @@ -173,11 +175,13 @@ def test_run_update_without_monitored_devices(self, bec_client_mock, scan_item): ) live_update.scan_item = scan_item scan_item.num_points = 2 + scan_item.num_monitored_readouts = 2 scan_item.live_data = {0: data} with mock.patch.object(live_update, "print_table_data") as mock_print_table_data: live_update._run_update(1) assert mock_print_table_data.called scan_item.num_points = 2 + scan_item.num_monitored_readouts = 2 scan_item.live_data = {} scan_item.status_message = messages.ScanStatusMessage( readout_priority={"monitored": [], "baseline": ["samx"]}, diff --git a/bec_ipython_client/tests/end-2-end/test_scans_v4_lib_e2e.py b/bec_ipython_client/tests/end-2-end/test_scans_v4_lib_e2e.py new file mode 100644 index 000000000..acd6337d7 --- /dev/null +++ b/bec_ipython_client/tests/end-2-end/test_scans_v4_lib_e2e.py @@ -0,0 +1,268 @@ +from __future__ import annotations + +import time + +import numpy as np +import pytest + +from bec_server.scan_server.scans import position_generators + + +def _get_v4_scan_runner(bec, scan_name: str): + return getattr(bec.scans, f"_v4_{scan_name}") + + +def _run_v4_scan( + bec, scan_name: str, *args, timeout: float = 60, wait_for_num_points: bool = True, **kwargs +): + bec.metadata.update({"unit_test": f"test_v4_{scan_name}_lib"}) + status = _get_v4_scan_runner(bec, scan_name)(*args, **kwargs) + status.wait(timeout=timeout, num_points=wait_for_num_points, file_written=False) + return status + + +def _assert_device_position(device, target: float): + current = device.read(cached=True)[device.full_name]["value"] + tolerance = device._config["deviceConfig"].get("tolerance", 0.05) + assert np.isclose(current, target, atol=tolerance) + + +def _resolve_scan_args(scan_args: tuple, dev): + resolved_args = [] + for arg in scan_args: + if isinstance(arg, str) and arg.startswith("dev."): + resolved_args.append(getattr(dev, arg.removeprefix("dev."))) + continue + resolved_args.append(arg) + return tuple(resolved_args) + + +def _wait_for_live_data_count(bec, status, expected_count: int, timeout: float = 5): + deadline = time.time() + timeout + while time.time() < deadline: + bec.callbacks.poll() + if len(status.scan.live_data) >= expected_count: + return + time.sleep(0.1) + + +def _wait_for_scan_status(status, expected_status: str, timeout: float = 10): + deadline = time.time() + timeout + while time.time() < deadline: + if status.status == expected_status: + return + time.sleep(0.1) + raise TimeoutError(f"Timed out waiting for scan status {expected_status!r}.") + + +def _wait_for_queue_status(bec, queue_name: str, expected_status: str, timeout: float = 10): + deadline = time.time() + timeout + while time.time() < deadline: + current_status = bec.queue.queue_storage.current_scan_queue[queue_name].status + if current_status == expected_status: + return + time.sleep(0.1) + raise TimeoutError(f"Timed out waiting for queue status {expected_status!r}.") + + +@pytest.mark.timeout(120) +@pytest.mark.parametrize( + ("scan_name", "scan_args", "scan_kwargs", "expected_num_points", "expected_num_readouts"), + [ + ("acquire", (), {"exp_time": 0.01, "burst_at_each_point": 3}, 1, 3), + ("line_scan", ("dev.samx", -1, 1), {"steps": 4, "exp_time": 0.01, "relative": False}, 4, 4), + ( + "grid_scan", + ("dev.samx", -1, 1, 3, "dev.samy", -1, 1, 2), + {"exp_time": 0.01, "relative": False}, + 6, + 6, + ), + ( + "list_scan", + ("dev.samx", [0, 0.5, 1.0], "dev.samy", [0, -0.5, -1.0]), + {"exp_time": 0.01, "relative": False}, + 3, + 3, + ), + ("log_scan", ("dev.samx", 1, 10), {"steps": 4, "exp_time": 0.01, "relative": False}, 4, 4), + ( + "fermat_scan", + ("dev.samx", -1, 1, "dev.samy", -1, 1), + {"step": 1.0, "exp_time": 0.01, "relative": False}, + len(position_generators.fermat_spiral_pos(-1, 1, -1, 1, step=1.0)), + len(position_generators.fermat_spiral_pos(-1, 1, -1, 1, step=1.0)), + ), + ( + "hexagonal_scan", + ("dev.samx", -1, 1, 1, "dev.samy", -1, 1, 1), + {"exp_time": 0.01, "relative": False}, + len(position_generators.hex_grid_2d([(-1, 1, 1), (-1, 1, 1)], snaked=True)), + len(position_generators.hex_grid_2d([(-1, 1, 1), (-1, 1, 1)], snaked=True)), + ), + ( + "multi_region_line_scan", + ("dev.samx",), + {"regions": [(-1, 0, 2), (1, 2, 2)], "exp_time": 0.01, "relative": False}, + len(position_generators.multi_region_line_positions([(-1, 0, 2), (1, 2, 2)])), + len(position_generators.multi_region_line_positions([(-1, 0, 2), (1, 2, 2)])), + ), + ( + "multi_region_grid_scan", + ("dev.samx", "dev.samy"), + { + "regions": [((-1, 0, 2), (-1, 0, 2)), ((1, 2, 2), (1, 2, 2))], + "exp_time": 0.01, + "relative": False, + }, + len( + position_generators.multi_region_grid_positions( + [((-1, 0, 2), (-1, 0, 2)), ((1, 2, 2), (1, 2, 2))], snaked=True + ) + ), + len( + position_generators.multi_region_grid_positions( + [((-1, 0, 2), (-1, 0, 2)), ((1, 2, 2), (1, 2, 2))], snaked=True + ) + ), + ), + ( + "round_scan", + ("dev.samx", "dev.samy", 0.0, 2.0, 2, 3), + {"exp_time": 0.01, "relative": False}, + len( + position_generators.round_scan_positions( + inner_radius=0.0, outer_radius=2.0, number_of_rings=2, points_in_first_ring=3 + ) + ), + len( + position_generators.round_scan_positions( + inner_radius=0.0, outer_radius=2.0, number_of_rings=2, points_in_first_ring=3 + ) + ), + ), + ( + "round_roi_scan", + ("dev.samx", -1.0, 1.0, "dev.samy", -1.0, 1.0), + {"shell_spacing": 1.0, "pos_in_first_ring": 3, "exp_time": 0.01, "relative": False}, + len( + position_generators.get_round_roi_scan_positions( + motor_1_start=-1.0, + motor_1_stop=1.0, + motor_2_start=-1.0, + motor_2_stop=1.0, + radial_step=1.0, + points_in_first_shell=3, + ) + ), + len( + position_generators.get_round_roi_scan_positions( + motor_1_start=-1.0, + motor_1_stop=1.0, + motor_2_start=-1.0, + motor_2_stop=1.0, + radial_step=1.0, + points_in_first_shell=3, + ) + ), + ), + ("time_scan", (), {"points": 3, "interval": 0.05, "exp_time": 0.01}, 3, 3), + ], +) +def test_v4_fixed_point_scans_lib( + bec_client_lib, scan_name, scan_args, scan_kwargs, expected_num_points, expected_num_readouts +): + bec = bec_client_lib + dev = bec.device_manager.devices + resolved_args = _resolve_scan_args(scan_args, dev) + + status = _run_v4_scan(bec, scan_name, *resolved_args, **scan_kwargs) + + assert status.scan is not None + assert status.scan.num_points == expected_num_points + assert status.scan.num_monitored_readouts == expected_num_readouts + assert len(status.scan.live_data) == expected_num_readouts + + +@pytest.mark.timeout(120) +def test_v4_mv_scan_lib(bec_client_lib): + bec = bec_client_lib + dev = bec.device_manager.devices + + status = _run_v4_scan(bec, "mv", dev.samx, 1.5, dev.samy, -1.5, relative=False) + status.wait(timeout=30) + + _assert_device_position(dev.samx, 1.5) + _assert_device_position(dev.samy, -1.5) + + +@pytest.mark.timeout(120) +def test_v4_umv_scan_lib(bec_client_lib): + bec = bec_client_lib + dev = bec.device_manager.devices + + status = _run_v4_scan(bec, "umv", dev.samx, -1.0, dev.samy, 1.0, relative=False) + status.wait(timeout=30) + + _assert_device_position(dev.samx, -1.0) + _assert_device_position(dev.samy, 1.0) + + +@pytest.mark.timeout(120) +def test_v4_cont_line_scan_lib(bec_client_lib): + bec = bec_client_lib + dev = bec.device_manager.devices + original_velocity = dev.samx.velocity.get() + try: + dev.samx.velocity.set(1).wait() + status = _run_v4_scan( + bec, "cont_line_scan", dev.samx, 0.0, 0.2, steps=3, exp_time=0.01, relative=False + ) + finally: + dev.samx.velocity.set(original_velocity).wait() + + assert status.scan is not None + assert status.scan.num_points == 3 + assert len(status.scan.live_data) == 3 + + +@pytest.mark.timeout(120) +def test_v4_line_sweep_scan_lib(bec_client_lib): + bec = bec_client_lib + dev = bec.device_manager.devices + original_velocity = dev.samx.velocity.get() + try: + dev.samx.velocity.set(1).wait() + dev.samx.limits = [-50, 50] + status = _run_v4_scan( + bec, + "line_sweep_scan", + dev.samx, + -5.0, + 5.0, + min_update=0.01, + relative=False, + wait_for_num_points=False, + ) + finally: + dev.samx.velocity.set(original_velocity).wait() + + assert status.scan is not None + _wait_for_live_data_count(bec, status, expected_count=1) + assert len(status.scan.live_data) > 0 + + +@pytest.mark.timeout(120) +def test_v4_scan_lib_stop_resolves_cleanly(bec_client_lib): + bec = bec_client_lib + status = _get_v4_scan_runner(bec, "time_scan")(points=100, interval=0.2, exp_time=0.01) + + time.sleep(0.5) + status.cancel() + + _wait_for_scan_status(status, "STOPPED", timeout=15) + assert status.status == "STOPPED" + _wait_for_queue_status(bec, "primary", "PAUSED", timeout=15) + + bec.queue.request_scan_continuation() + _wait_for_queue_status(bec, "primary", "RUNNING", timeout=15) diff --git a/bec_lib/bec_lib/bl_state_machine.py b/bec_lib/bec_lib/bl_state_machine.py new file mode 100644 index 000000000..0b5601ee8 --- /dev/null +++ b/bec_lib/bec_lib/bl_state_machine.py @@ -0,0 +1,94 @@ +""" +Module for managing aggregated beamline states based on configuration files. + +Example of the YAML configuration file: +``` yaml +alignment: + devices: + samx: + readback: + value: 0 + abs_tol: 0.1 + measurement: + devices: + samx: + readback: + value: 19 + abs_tol: 0.1 + velocity: + value: 5 + abs_tol: 0.1 + samy: + readback: + value: 0 + abs_tol: 0.1 + test: + devices: + samy: + readback: + value: 0 + abs_tol: 0.1 +``` + +""" + +from __future__ import annotations + +import yaml + +from bec_lib.bl_state_manager import BeamlineStateManager +from bec_lib.bl_states import AggregatedStateConfig + + +class BeamlineStateMachine: + + def __init__(self, manager: BeamlineStateManager) -> None: + self._manager = manager + self._configs: dict[str, AggregatedStateConfig] = {} + + def load_from_config( + self, name: str, config_path: str | None = None, config_dict: dict | None = None + ) -> None: + """ + Load a state configuration from a YAML file or a dictionary. If None or both are provided, + an error will be raised. Config must be states for an AggregatedStateConfig or a dictionary/YAML file that + can be parsed into one. Please check AggregatedStateConfig state field for the expected format of the configuration. + + Args: + name (str): The name of the aggregated state to load. + config_path (str | None): The path to the YAML configuration file. + config_dict (dict | None): A dictionary containing the configuration. If provided, this will be used instead of loading from a file. + """ + self._check_inputs(config_path=config_path, config_dict=config_dict) + if config_path: + with open(config_path, "r", encoding="utf-8") as f: + config_dict = yaml.safe_load(f) + config = AggregatedStateConfig(name=name, states=config_dict) + self._manager.add(config) + + def update_config( + self, name: str, config_path: str | None = None, config_dict: dict | None = None + ) -> None: + """ + Update a state configuration from a YAML file or a dictionary. If None or both are provided, + an error will be raised. Config must be states for an AggregatedStateConfig or a dictionary/YAML file that + can be parsed into one. Please check AggregatedStateConfig state field for the expected format of the configuration. + + Args: + name (str): The name of the aggregated state to update. + config_path (str | None): The path to the YAML configuration file. + config_dict (dict | None): A dictionary containing the configuration. If provided, this will be used instead of loading from a file. + """ + self._check_inputs(config_path=config_path, config_dict=config_dict) + if config_path: + with open(config_path, "r", encoding="utf-8") as f: + config_dict = yaml.safe_load(f) + # Load the new state + config = AggregatedStateConfig(name=name, states=config_dict) + self._manager._update_state(config) + + def _check_inputs(self, config_path: str | None, config_dict: dict | None) -> None: + if (config_path is None and config_dict is None) or ( + config_path is not None and config_dict is not None + ): + raise ValueError("Either config_path or config_dict must be provided, but not both.") diff --git a/bec_lib/bec_lib/bl_states.py b/bec_lib/bec_lib/bl_states.py index 8d32959c2..d8689fdb5 100644 --- a/bec_lib/bec_lib/bl_states.py +++ b/bec_lib/bec_lib/bl_states.py @@ -1,11 +1,15 @@ +"""Module defining beamline states and their evaluation logic.""" + from __future__ import annotations import functools import keyword import traceback from abc import ABC, abstractmethod -from typing import Callable, ClassVar, Generic, Type, TypeVar, cast +from dataclasses import dataclass +from typing import Any, Callable, ClassVar, Generic, Literal, Type, TypeVar, cast +import yaml from pydantic import BaseModel, field_validator, model_validator from bec_lib import messages @@ -121,6 +125,61 @@ class DeviceWithinLimitsStateConfig(DeviceStateConfig): tolerance: float = 0.1 +class SignalConfig(BaseModel): + """Target value for a signal inside a named machine state.""" + + value: float | int | str | bool + abs_tol: float = 0.0 + + +class DeviceConfig(BaseModel): + """Configuration for a device inside a named machine state.""" + + abs_tol: float = 0.0 + value: float | int | str | bool | None = None + low_limit: SignalConfig | None = None + high_limit: SignalConfig | None = None + signals: dict[str, SignalConfig] | None = None + + @model_validator(mode="after") + def validate_config(self) -> DeviceConfig: + """ + Validate that either value, low_limit, high_limit, or signals are provided. + """ + if ( + self.value is None + and self.low_limit is None + and self.high_limit is None + and self.signals is None + ): + raise ValueError( + "At least one of value, low_limit, high_limit, or signals must be provided." + ) + return self + + +class SubDeviceStateConfig(BaseModel): + """ + Configuration for a sub-state with a specific label. + This is a device/signal mappping to either a DeviceConfig or SignalConfig. + """ + + devices: dict[str, DeviceConfig | SignalConfig] + transition_metadata: dict[str, Any] | None = None + + +class AggregatedStateConfig(BeamlineStateConfig): + """ + Configuration for a state machine driven by multiple device signals. + + Keys of the states dictionary are the labels of the different states. + """ + + state_type: ClassVar[str] = "AggregatedState" + + states: dict[str, SubDeviceStateConfig] + + C = TypeVar("C", bound=BeamlineStateConfig) D = TypeVar("D", bound=DeviceStateConfig) @@ -322,6 +381,366 @@ def _update_device_state(self, msg_obj: MessageObject) -> messages.BeamlineState return self.evaluate(msg) +SignalSource = TypeVar("SignalSource", bound=Literal["readback", "configuration", "limits"]) + + +@dataclass(frozen=True) +class ResolvedStateSignal: + label: str + device_name: str + signal_name: str + expected_value: float | int | str | bool + abs_tolerance: float | int + source: SignalSource + + +class AggregatedState(BeamlineState[AggregatedStateConfig]): + """Beamline state that infers the current named state from multiple device signals.""" + + CONFIG_CLASS = AggregatedStateConfig + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + # Mapping from signal updates to affected state labels, used for efficient evaluation when a signal update is received + self._signal_info_to_labels: dict[tuple[str, SignalSource, str], set[str]] = {} + # Mapping from state labels to the list of signal requirements that define that state + self._requirements_for_label: dict[str, list[ResolvedStateSignal]] = {} + # Set of subscriptions to signal updates + self._subscriptions: set[tuple[str, SignalSource]] = set() + # Cache of the latest signal values + self._signal_value_cache: dict[tuple[str, SignalSource, str], Any] = {} + # List of currently active state labels + self._current_labels: list[str] = [] + + @staticmethod + def _endpoint(device: str, source: SignalSource): + """Static method to get the appropriate message endpoint based on the signal source.""" + if source == "readback": + return MessageEndpoints.device_readback(device) + if source == "configuration": + return MessageEndpoints.device_read_configuration(device) + if source == "limits": + return MessageEndpoints.device_limits(device) + raise ValueError( + f"Invalid signal source '{source}', please use 'readback', 'configuration', or 'limits'." + ) + + def _get_device_manager(self): + if self.device_manager is None: + # pylint: disable=import-outside-toplevel + from bec_lib.client import BECClient + + bec = BECClient() + return bec.device_manager + return self.device_manager + + @staticmethod + def _get_signal_source(signal_info: dict[str, Any], error_prefix: str) -> SignalSource: + kind_str = str(signal_info.get("kind_str", "")).lower() + if "hinted" in kind_str or "normal" in kind_str: + return "readback" + if "config" in kind_str: + return "configuration" + raise ValueError( + f"{error_prefix} Unsupported kind: '{kind_str}' for signal : \n {yaml.dump(signal_info, indent=4)}" + ) + + @staticmethod + def _resolve_signal( + device_name: str, signal_name: str, device_manager: DeviceManagerBase, error_prefix: str + ) -> tuple[str, SignalSource]: + devices = device_manager.devices + try: + if not isinstance(device_name, str): + raise ValueError( + f"{error_prefix} Device name must be a string, got {type(device_name)}" + ) + device_obj: DeviceBase = devices[device_name] + except KeyError: + raise ValueError(f"{error_prefix} Device '{device_name}' not found.") from None + + # Special handling for limits, as they are not regular signals. + if signal_name in ["low_limit", "low_limit_travel"]: + return "low", "limits" + if signal_name in ["high_limit", "high_limit_travel"]: + return "high", "limits" + + signal_info = None + # This case is relevant if we are looking at a Signal directly + if device_name == signal_name and len(device_obj.root._info["signals"]) == 0: + signal_info = {"obj_name": signal_name, "kind_str": "hinted"} + # Case where we have a signal specified as a dotted name, e.g. + elif "." in signal_name: + try: + signal_obj = devices[signal_name] + except AttributeError: + raise ValueError( + f"{error_prefix} Signal '{signal_name}' not found for device '{device_name}'." + ) from None + if signal_obj.parent != device_obj: + raise ValueError( + f"{error_prefix} Signal '{signal_name}' does not belong to device '{device_name}'." + ) + signal_component = ".".join(signal_name.split(".")[1:]) + signal_info = device_obj.root._info["signals"].get(signal_component) + # Case where the signal is specified as the signal + else: + signal_info = device_obj.root._info["signals"].get(signal_name) + if signal_info is None: + for candidate in device_obj.root._info["signals"].values(): + if candidate.get("obj_name") == signal_name: + signal_info = candidate + break + + if signal_info is None: + raise ValueError( + f"{error_prefix} Signal '{signal_name}' not found for device '{device_name}'." + ) + + obj_name = signal_info.get("obj_name") + signal_source = AggregatedState._get_signal_source(signal_info, error_prefix) + return obj_name, signal_source + + @staticmethod + def get_state_requirements( + label: str, + state_config: SubDeviceStateConfig, + device_manager: DeviceManagerBase, + error_prefix: str, + ) -> list[ResolvedStateSignal]: + state_requirements: list[ResolvedStateSignal] = [] + for device_name, config in state_config.devices.items(): + if isinstance(config, SignalConfig): + state_requirements.append( + AggregatedState._build_requirement_for_signal( + device_name, + device_name, + config.value, + config.abs_tol, + label, + device_manager, + error_prefix, + ) + ) + elif isinstance(config, DeviceConfig): + # If a value is specified for the device, add it as a requirement + if config.value is not None: + state_requirements.append( + AggregatedState._build_requirement_for_signal( + device_name, + device_name, + config.value, + config.abs_tol, + label, + device_manager, + error_prefix, + ) + ) + if config.low_limit is not None: + state_requirements.append( + AggregatedState._build_requirement_for_signal( + device_name, + "low_limit", + config.low_limit.value, + config.low_limit.abs_tol, + label, + device_manager, + error_prefix, + ) + ) + if config.high_limit is not None: + state_requirements.append( + AggregatedState._build_requirement_for_signal( + device_name, + "high_limit", + config.high_limit.value, + config.high_limit.abs_tol, + label, + device_manager, + error_prefix, + ) + ) + for signal_name, signal_config in (config.signals or {}).items(): + state_requirements.append( + AggregatedState._build_requirement_for_signal( + device_name, + signal_name, + signal_config.value, + signal_config.abs_tol, + label, + device_manager, + error_prefix, + ) + ) + return state_requirements + + def _build_rules(self) -> None: + self._signal_info_to_labels.clear() + self._requirements_for_label.clear() + self._subscriptions.clear() + for label, device_configs in self.config.states.items(): + state_requirements: list[ResolvedStateSignal] = AggregatedState.get_state_requirements( + label, device_configs, self._get_device_manager(), self._error_prefix + ) + for requirement in state_requirements: + device_name = requirement.device_name + signal_name = requirement.signal_name + source = requirement.source + self._subscriptions.add((device_name, source)) + self._signal_info_to_labels.setdefault( + (device_name, source, signal_name), set() + ).add(label) + self._requirements_for_label[label] = state_requirements + + @staticmethod + def _build_requirement_for_signal( + device_name: str, + signal_name: str, + value: Any, + abs_tol: float, + label: str, + device_manager: DeviceManagerBase, + error_prefix: str, + ) -> ResolvedStateSignal: + resolved_signal_name, source = AggregatedState._resolve_signal( + device_name, signal_name, device_manager, error_prefix + ) + + return ResolvedStateSignal( + label=label, + device_name=device_name, + signal_name=resolved_signal_name, + expected_value=value, + abs_tolerance=abs_tol, + source=source, + ) + + def start(self) -> None: + if self.started: + return + + if self.connector is None: + raise RuntimeError("Redis connector is not set.") + + try: + msg = None + self._build_rules() + affected_labels = self._fill_cache() + msg = self.evaluate(affected_labels=affected_labels) + except Exception as exc: + self._handle_state_exception(exc) + + if msg is not None: + self._emit_state(msg) + for device, source in self._subscriptions: + self.connector.register( + self._endpoint(device, source), + cb=self._update_aggregated_state, + device=device, + source=source, + ) + super().start() + + def _fill_cache(self) -> set[str]: + affected_labels: set[str] = set() + for device, source in self._subscriptions: + endpoint = self._endpoint(device, source) + msg = self.connector.get(endpoint) + if msg is not None: + affected_labels.update(self._cache_message(device, source, msg)) + return affected_labels + + def _cache_message( + self, device: str, source: SignalSource, msg: messages.DeviceMessage + ) -> set[str]: + affected_labels: set[str] = set() + for signal_name, signal_data in msg.signals.items(): + key = (device, source, signal_name) + labels = self._signal_info_to_labels.get(key) + if labels is None: # signal not relevant for any state + continue + self._signal_value_cache[key] = signal_data.get("value") + affected_labels.update(labels) + return affected_labels + + def stop(self) -> None: + if not self.started: + return + if self.connector is not None: + for device, source in self._subscriptions: + self.connector.unregister( + self._endpoint(device, source), cb=self._update_aggregated_state + ) + super().stop() + + def _update_aggregated_state( + self, msg_obj: MessageObject, device: str, source: SignalSource, **_kwargs + ) -> None: + try: + msg: messages.DeviceMessage = msg_obj.value # type: ignore ; we know it's a DeviceMessage + affected_labels = self._cache_message(device, source, msg) + if affected_labels: + msg = self.evaluate(affected_labels=affected_labels) + if msg is not None: + self._emit_state(msg) + except Exception as exc: + self._handle_state_exception(exc) + + def evaluate( + self, affected_labels: set[str] | None = None + ) -> messages.BeamlineStateMessage | None: + if affected_labels is None: + return None + # We need to always extend the affected labels with the current labels, + # as the signal that updated might be not relevant for the currently active state, + # but the state should still be checked for validity. + affected_labels.update(self._current_labels) + matching_labels = [label for label in affected_labels if self._label_matches(label)] + if matching_labels: + self._current_labels = matching_labels + state_msg = messages.BeamlineStateMessage( + name=self.config.name, status="valid", label="|".join(matching_labels) + ) + return state_msg + + self._current_labels = [] + state_msg = messages.BeamlineStateMessage( + name=self.config.name, status="invalid", label="No matching state" + ) + return state_msg + + def _label_matches(self, label: str) -> bool: + requirements = self._requirements_for_label.get(label, []) + return bool(requirements) and all( + self._requirement_matches(requirement) for requirement in requirements + ) + + def _requirement_matches(self, requirement: ResolvedStateSignal) -> bool: + key = (requirement.device_name, requirement.source, requirement.signal_name) + cached_value = self._signal_value_cache.get(key, None) + if cached_value is None: + return False + + try: + # Cast to float to make sure comparison with abs works as expected. + value = float(cached_value) + expected_value = float(requirement.expected_value) + return abs(value - expected_value) <= requirement.abs_tolerance + # Catch TypeError and ValueError in case the value is not a number or cannot be cast to float, + # in that case we fall back to exact equality. + except (TypeError, ValueError): + try: + result = cached_value == requirement.expected_value + except (TypeError, ValueError): + return False + # In case this comparison runs on comparing two arrays. + # We do not consider this comparsion as valid currently. + try: + return bool(result) + except (TypeError, ValueError): + return False + + class ShutterState(DeviceBeamlineState[DeviceStateConfig]): """ A state that checks if the shutter is open. diff --git a/bec_lib/bec_lib/client.py b/bec_lib/bec_lib/client.py index 883b2ad0a..0ef137d2a 100644 --- a/bec_lib/bec_lib/client.py +++ b/bec_lib/bec_lib/client.py @@ -20,6 +20,7 @@ from bec_lib.alarm_handler import AlarmHandler, Alarms from bec_lib.bec_service import BECService +from bec_lib.bl_state_machine import BeamlineStateMachine from bec_lib.bl_state_manager import BeamlineStateManager from bec_lib.callback_handler import CallbackHandler, EventType from bec_lib.config_helper import ConfigHelperUser @@ -162,6 +163,7 @@ def __init__( self._username = "" self._system_user = "" self.beamline_states = None + self.state_machine = None self.messaging: MessagingContainer = None # type: ignore def __new__(cls, *args, forced=False, **kwargs): @@ -241,6 +243,7 @@ def _start_services(self): self.device_monitor = DeviceMonitorPlugin(self.connector) self._update_username() self.beamline_states = BeamlineStateManager(client=self) + self.state_machine = BeamlineStateMachine(manager=self.beamline_states) def alarms(self, severity=Alarms.WARNING): """get the next alarm with at least the specified severity""" diff --git a/bec_lib/bec_lib/messages.py b/bec_lib/bec_lib/messages.py index 37fd0599c..59803cfb3 100644 --- a/bec_lib/bec_lib/messages.py +++ b/bec_lib/bec_lib/messages.py @@ -222,6 +222,9 @@ class ScanStatusMessage(BECMessage): default=None, description="Number of points in the scan. Only relevant if the number of points is determined by BEC.", ) + num_monitored_readouts: int | None = Field( + default=0, description="Number of monitored readouts in the scan." + ) scan_name: str | None = Field(default=None, description="Name of the scan, e.g. 'line_scan'") scan_type: Literal["step", "fly"] | None = Field(default=None, description="Type of scan") dataset_number: int | None = None @@ -904,6 +907,7 @@ class ScanHistoryMessage(BECMessage): end_time (float): End time of the scan. scan_name (str): Name of the scan. num_points (int): Number of points in the scan. + num_monitored_readouts (int): Number of monitored readouts in the scan. request_inputs (dict, optional): Inputs for the scan request, if available. stored_data_info (dict[str, dict[str, _StoredDataInfo]], optional): Information about the stored data for each device in the scan. metadata (dict, optional): Additional metadata. @@ -920,7 +924,8 @@ class ScanHistoryMessage(BECMessage): start_time: float end_time: float scan_name: str - num_points: int + num_points: int | None = None + num_monitored_readouts: int | None = None request_inputs: dict | None = None stored_data_info: dict[str, dict[str, _StoredDataInfo]] | None = None diff --git a/bec_lib/bec_lib/scan_data_container.py b/bec_lib/bec_lib/scan_data_container.py index 6011e04b3..444305e15 100644 --- a/bec_lib/bec_lib/scan_data_container.py +++ b/bec_lib/bec_lib/scan_data_container.py @@ -709,7 +709,13 @@ def __repr__(self) -> str: scan_number = f"\tScan number: {self._msg.scan_number}\n" scan_name = f"\tScan name: {self._msg.scan_name}\n" exit_status = f"\tStatus: {self._msg.exit_status}\n" - num_points = f"\tNumber of points (monitored): {self._msg.num_points}\n" + num_points = f"\tNumber of points: {self._msg.num_points}\n" + num_monitored_readouts = ( + f"\tNumber of monitored readouts: {self._msg.num_monitored_readouts}\n" + if self._msg.num_monitored_readouts is not None + else "" + ) + public_file = f"\tFile: {self._msg.file_path}\n" details = ( start_time @@ -720,6 +726,7 @@ def __repr__(self) -> str: + scan_name + exit_status + num_points + + num_monitored_readouts + public_file ) return f"ScanDataContainer:\n {details}" diff --git a/bec_lib/bec_lib/scan_items.py b/bec_lib/bec_lib/scan_items.py index 6efb493cc..51cc6f3b2 100644 --- a/bec_lib/bec_lib/scan_items.py +++ b/bec_lib/bec_lib/scan_items.py @@ -48,6 +48,7 @@ class ScanItem: open_scan_defs: Set of open scan definition IDs. open_queue_group: Queue group this scan belongs to. num_points: Total number of data points in the scan. + num_monitored_readouts: Total number of monitored readouts in the scan. start_time: Unix timestamp when the scan started. end_time: Unix timestamp when the scan ended. scan_report_instructions: Instructions for generating scan reports. @@ -85,6 +86,7 @@ def __init__( self.open_scan_defs = set() self.open_queue_group = None self.num_points: int | None = None + self.num_monitored_readouts: int | None = None self.start_time: float | None = None self.end_time: float | None = None self.scan_report_instructions: list[dict] = [] @@ -215,13 +217,25 @@ def describe(self) -> str: scan_id = f"\tScan ID: {self.scan_id}\n" if self.scan_id else "" scan_number = f"\tScan number: {self.scan_number}\n" if self.scan_number else "" num_points = f"\tNumber of points: {self.num_points}\n" if self.num_points else "" + num_monitored_readouts = ( + f"\tNumber of monitored readouts: {self.num_monitored_readouts}\n" + if self.num_monitored_readouts + else "" + ) public_file = "" for file_path in self.public_files: file_name = file_path.split("/")[-1] if "_master" in file_name: public_file = "\tFile: " + file_path + "\n" details = ( - start_time + end_time + elapsed_time + scan_id + scan_number + num_points + public_file + start_time + + end_time + + elapsed_time + + scan_id + + scan_number + + num_points + + num_monitored_readouts + + public_file ) return details @@ -344,11 +358,13 @@ def update_with_scan_status(self, scan_status: messages.ScanStatusMessage) -> No ) return + terminal_states = {"aborted", "halted", "closed", "user_completed"} + # update timestamps if scan_status.status == "open": scan_item.start_time = scan_status.timestamp - elif scan_status.timestamp: - # update for all other statuses if timestamp is provided + elif scan_status.status in terminal_states and scan_status.timestamp: + # Only terminal states should stamp the end time; paused scans remain open. scan_item.end_time = scan_status.timestamp # update status message @@ -359,6 +375,9 @@ def update_with_scan_status(self, scan_status: messages.ScanStatusMessage) -> No if scan_status.num_points: scan_item.num_points = scan_status.num_points + if scan_status.num_monitored_readouts: + scan_item.num_monitored_readouts = scan_status.num_monitored_readouts + # update scan number if scan_number is not None: scan_item.scan_number = scan_number @@ -369,8 +388,8 @@ def update_with_scan_status(self, scan_status: messages.ScanStatusMessage) -> No # add scan def id scan_def_id = scan_status.info.get("scan_def_id") if scan_def_id: - if scan_status.status != "open": - scan_item.open_scan_defs.remove(scan_def_id) + if scan_status.status in terminal_states: + scan_item.open_scan_defs.discard(scan_def_id) else: scan_item.open_scan_defs.add(scan_def_id) diff --git a/bec_lib/bec_lib/scan_report.py b/bec_lib/bec_lib/scan_report.py index 3c65ad646..bc00e6ff7 100644 --- a/bec_lib/bec_lib/scan_report.py +++ b/bec_lib/bec_lib/scan_report.py @@ -274,7 +274,7 @@ def _num_points_reached(self) -> bool: """ if not self.scan: return False - return self.scan.num_points == len(self.scan.live_data) + return self.scan.num_monitored_readouts == len(self.scan.live_data) def __str__(self) -> str: separator = "--" * 10 diff --git a/bec_lib/bec_lib/scans.py b/bec_lib/bec_lib/scans.py index 041d8d72f..4e1025db4 100644 --- a/bec_lib/bec_lib/scans.py +++ b/bec_lib/bec_lib/scans.py @@ -262,29 +262,53 @@ def _get_runtime_arg_type(dtype: object) -> type | tuple[type, ...]: return tuple(runtime_types) if dtype is None: return types.NoneType + if dtype is float: + return (float, int) return dtype @staticmethod - def get_arg_type(in_type: str | dict | list): - """translate type string into python type""" - # pylint: disable=too-many-return-statements - if in_type == "float": - return (float, int) - if in_type == "int": - return int - if in_type == "list": - return list - if in_type in ("boolean", "bool"): - return bool - if in_type == "str": - return str - if in_type == "dict": - return dict - if in_type in ("device", "DeviceBase"): - return DeviceBase - if dtype := deserialize_dtype(in_type): - return Scans._get_runtime_arg_type(dtype) - raise TypeError(f"Unknown type {in_type}") + def _arg_matches_type(arg, dtype: object) -> bool: + """Validate an argument against a possibly nested type annotation.""" + if get_origin(dtype) is Annotated: + return Scans._arg_matches_type(arg, get_args(dtype)[0]) + if get_origin(dtype) is Literal: + return any(arg == literal for literal in get_args(dtype)) + if dtype.__class__.__name__ == "_UnionGenericAlias" or dtype.__class__ == types.UnionType: + return any(Scans._arg_matches_type(arg, union_arg) for union_arg in get_args(dtype)) + if dtype is None: + return arg is None + origin = get_origin(dtype) + if origin is list: + if not isinstance(arg, list): + return False + args = get_args(dtype) + if not args: + return True + return all(Scans._arg_matches_type(item, args[0]) for item in arg) + if origin is dict: + if not isinstance(arg, dict): + return False + key_type, value_type = get_args(dtype) or (object, object) + return all( + Scans._arg_matches_type(key, key_type) + and Scans._arg_matches_type(value, value_type) + for key, value in arg.items() + ) + if origin is tuple: + if not isinstance(arg, tuple): + return False + args = get_args(dtype) + if not args: + return True + if len(args) == 2 and args[1] is Ellipsis: + return all(Scans._arg_matches_type(item, args[0]) for item in arg) + if len(arg) != len(args): + return False + return all( + Scans._arg_matches_type(item, item_type) for item, item_type in zip(arg, args) + ) + runtime_type = Scans._get_runtime_arg_type(dtype) + return isinstance(arg, runtime_type) @staticmethod def prepare_scan_request( @@ -334,10 +358,12 @@ def prepare_scan_request( # check that all arguments are of the correct type for ii, arg in enumerate(args): - if not isinstance(arg, Scans.get_arg_type(arg_input[ii % len(arg_input)])): + serialized_dtype = arg_input[ii % len(arg_input)] + dtype = deserialize_dtype(serialized_dtype) + if not Scans._arg_matches_type(arg, dtype): raise TypeError( f"{scan_info.get('doc')}\n Argument {ii} must be of type" - f" {arg_input[ii%len(arg_input)]}, not {type(arg).__name__}." + f" {serialized_dtype}, not {type(arg).__name__}." ) metadata = {} diff --git a/bec_lib/bec_lib/signature_serializer.py b/bec_lib/bec_lib/signature_serializer.py index 577978e51..33fadad6a 100644 --- a/bec_lib/bec_lib/signature_serializer.py +++ b/bec_lib/bec_lib/signature_serializer.py @@ -40,6 +40,14 @@ def _serialize_dtype(dtype: object) -> Generator[str | dict, None, None]: yield from itertools.chain.from_iterable(_serialize_union_arg(x) for x in dtype.__args__) # type: ignore if dtype.__class__.__name__ == "_LiteralGenericAlias": yield {"Literal": dtype.__args__} # type: ignore + origin = get_origin(dtype) + if origin is not None and origin not in {Annotated, Literal, Union, types.UnionType}: + yield { + "Generic": { + "origin": serialize_dtype(origin), + "args": [serialize_dtype(arg) for arg in get_args(dtype)], + } + } def _serialize_union_arg(dtype: object) -> list[str | dict]: @@ -209,6 +217,13 @@ def deserialize_dtype(dtype: list | dict | str) -> object: if "Annotated" in dtype: annotated_dtype = dtype.get("Annotated") return _deserialize_annotated_dtype(annotated_dtype) + if "Generic" in dtype: + generic_dtype = dtype["Generic"] + origin = deserialize_dtype(generic_dtype["origin"]) + args = tuple(deserialize_dtype(arg) for arg in generic_dtype.get("args", [])) + if origin is inspect._empty: + return inspect._empty + return origin[args] if args else origin return Literal[*dtype["Literal"]] if dtype == "_empty": # pylint: disable=protected-access diff --git a/bec_lib/bec_lib/tests/utils.py b/bec_lib/bec_lib/tests/utils.py index b273fcf7e..9d5a1fd9f 100644 --- a/bec_lib/bec_lib/tests/utils.py +++ b/bec_lib/bec_lib/tests/utils.py @@ -661,20 +661,20 @@ def get_device_info_mock(device_name, device_class) -> messages.DeviceInfoMessag device_base_class = "positioner" if device_class == "SimPositioner" else "signal" if device_base_class == "positioner": signals = positioner_info["device_info"]["signals"] - elif device_base_class == "signal": - signals = { - device_name: { - "metadata": { - "connected": True, - "read_access": True, - "write_access": False, - "timestamp": 0, - "status": None, - "severity": None, - "precision": None, - } - } - } + # elif device_base_class == "signal": + # signals = { + # device_name: { + # "metadata": { + # "connected": True, + # "read_access": True, + # "write_access": False, + # "timestamp": 0, + # "status": None, + # "severity": None, + # "precision": None, + # } + # } + # } else: signals = {} dev_info = { diff --git a/bec_lib/tests/test_beamline_states.py b/bec_lib/tests/test_beamline_states.py index e6a936e8e..29ef4ac77 100644 --- a/bec_lib/tests/test_beamline_states.py +++ b/bec_lib/tests/test_beamline_states.py @@ -3,10 +3,13 @@ import inspect from unittest import mock +import numpy as np import pytest +import yaml from pydantic import BaseModel from bec_lib import bl_states, messages +from bec_lib.bl_state_machine import BeamlineStateMachine from bec_lib.bl_state_manager import ( BeamlineStateClientBase, BeamlineStateManager, @@ -196,6 +199,315 @@ def test_device_within_limits_state(self, connected_connector, dm_with_devices): assert state.evaluate(invalid).status == "invalid" assert state.evaluate(missing).status == "invalid" + @pytest.fixture(scope="function") + def aggregated_state_config(self): + """Fixture for an test aggregated state configuration.""" + return bl_states.AggregatedStateConfig( + name="alignment", + states={ + "alignment": { + "devices": { + "samx": { + "value": 0, + "abs_tol": 0.1, + "low_limit": {"value": -20, "abs_tol": 0.1}, + "high_limit": {"value": 20, "abs_tol": 0.1}, + }, + "bpm4i": {"value": 0, "abs_tol": 0.1}, + } + }, + "measurement": { + "devices": { + "samx": { + "value": 19, + "abs_tol": 0.1, + "low_limit": {"value": -20, "abs_tol": 0.1}, + "high_limit": {"value": 20, "abs_tol": 0.1}, + "signals": {"velocity": {"value": 5, "abs_tol": 0.1}}, + }, + "bpm4i": {"value": 2, "abs_tol": 0.1}, + } + }, + "test": {"devices": {"bpm4i": {"value": 0, "abs_tol": 0.1}}}, + "string_state": {"devices": {"bpm3i": {"value": "ok"}}}, + }, + ) + + def test_aggregated_state_init_and_start( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + """ + Test the initialization of the AggregatedState. + + Based on the provided configuration, we expect certain callbacks to be registered with the + Redis connector. This test checks this which essentially checks the proper functionality + of the 'start' method. + """ + + state = bl_states.AggregatedState( + name=aggregated_state_config.name, + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + state.start() + # We should now have subscriptions on samx limits, readback and read_configuration, and bpm4i & bpm4i + info = [ + MessageEndpoints.device_readback("samx"), + MessageEndpoints.device_read_configuration("samx"), + MessageEndpoints.device_limits("samx"), + MessageEndpoints.device_readback("bpm4i"), + MessageEndpoints.device_readback("bpm3i"), + ] + for endpoint in info: + assert endpoint.endpoint in state.connector._topics_cb + + def test_aggregated_state_evaluation( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + """ + Test the evaluation of the AggregatedState when receiving message updates. This should trigger a state evaluation for + the affected labels and the current state, and if the state changes, a new state should be published. + """ + state = bl_states.AggregatedState( + name=aggregated_state_config.name, + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + state.start() + + with ( + mock.patch.object(state, "evaluate", return_value=None) as evaluate, + mock.patch.object(state, "_emit_state") as emit_state, + ): + # Test triggering evaluation for multiple labels + # samx affects alignment and measurement, so both should be evaluated. + msg_with_2_states = messages.DeviceMessage( + signals={"samx": {"value": 5.0, "timestamp": 1.0}} + ) + msg_obj = MessageObject( + value=msg_with_2_states, topic=MessageEndpoints.device_readback("samx").endpoint + ) + state._update_aggregated_state(msg_obj, device="samx", source="readback") + evaluate.assert_called_once_with(affected_labels=set(["alignment", "measurement"])) + emit_state.assert_not_called() # As evaluate is mocked to return None, _emit_state should not be called + + def test_aggregated_state_evaluate( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + """ + Test the evaluate method. + We manually cache the relevant messages and then call evaluate with the affected label. + We then check if the output message has the expected status and label, and if the current labels are updated correctly. + """ + state = bl_states.AggregatedState( + name=aggregated_state_config.name, + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + state._build_rules() + # Assume that we are currently in test + state._current_labels = ["test"] + state._cache_message( + "samx", + "readback", + messages.DeviceMessage( + signals={"samx": {"value": 0, "timestamp": 1.0}}, metadata={"stream": "primary"} + ), + ) + state._cache_message( + "samx", + "configuration", + messages.DeviceMessage( + signals={"samx_velocity": {"value": 5, "timestamp": 1.0}}, + metadata={"stream": "baseline"}, + ), + ) + state._cache_message( + "samx", + "limits", + messages.DeviceMessage( + signals={ + "low": {"value": -20, "timestamp": 1.0}, + "high": {"value": 20, "timestamp": 1.0}, + }, + metadata={"stream": "baseline"}, + ), + ) + state._cache_message( + "bpm4i", + "readback", + messages.DeviceMessage( + signals={"bpm4i": {"value": 0, "timestamp": 1.0}}, metadata={"stream": "primary"} + ), + ) + + msg = state.evaluate(affected_labels={"alignment"}) + + assert msg.status == "valid" + # The order of the labels is not guaranteed + assert msg.label in ["alignment|test", "test|alignment"] + assert set(state._current_labels) == set(["alignment", "test"]) + + state._cache_message( + "samx", + "readback", + messages.DeviceMessage( + signals={"samx": {"value": 3, "timestamp": 2.0}}, metadata={"stream": "primary"} + ), + ) + + msg = state.evaluate(affected_labels={"alignment"}) + + assert msg.status == "valid" + assert msg.label == "test" + assert state._current_labels == ["test"] + + state._cache_message( + "bpm4i", + "readback", + messages.DeviceMessage( + signals={"bpm4i": {"value": 2, "timestamp": 2.0}}, metadata={"stream": "primary"} + ), + ) + + msg = state.evaluate(affected_labels={"alignment", "test", "measurement"}) + + assert msg.status == "invalid" + assert msg.label == "No matching state" + assert state._current_labels == [] + + def test_aggregated_state_exception_handling( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + """ + Test that if an exception is raised during the evaluation of the state, this is properly handled and an alarm is raised. + We check that the evaluate method is called and that if it raises an exception, the raise_alarm method of the connector + is called, and a state with status "unknown" and label "broken state" is published. + """ + state = bl_states.AggregatedState( + name=aggregated_state_config.name, + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + state.start() + msg = messages.DeviceMessage( + signals={"samx": {"value": 0, "timestamp": 1.0}}, metadata={"stream": "primary"} + ) + msg_obj = MessageObject(value=msg, topic=MessageEndpoints.device_readback("samx").endpoint) + + with ( + mock.patch.object( + state, "evaluate", side_effect=RuntimeError("broken state") + ) as evaluate, + mock.patch.object(connected_connector, "raise_alarm") as raise_alarm, + ): + state._update_aggregated_state(msg_obj, device="samx", source="readback") + + evaluate.assert_called_once_with(affected_labels={"alignment", "measurement"}) + raise_alarm.assert_called_once() + out = connected_connector.xread( + MessageEndpoints.beamline_state("alignment"), from_start=True + ) + assert out[-1]["data"].status == "unknown" + assert out[-1]["data"].label == "broken state" + assert state.raised_warning is True + + def test_aggregated_state_transitions_between_labels( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + """ + Test the transitions between different labels of the aggregated state. We simulate the messages that would trigger + the transitions and check that the output message has the expected status and label, and that the current labels are updated correctly. + """ + state = bl_states.AggregatedState( + name=aggregated_state_config.name, + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + state.start() + + def update(device, source, signals): + msg = messages.DeviceMessage(signals=signals, metadata={"stream": "primary"}) + msg_obj = MessageObject(value=msg, topic=state._endpoint(device, source).endpoint) + state._update_aggregated_state(msg_obj, device=device, source=source) + out = connected_connector.xread( + MessageEndpoints.beamline_state("alignment"), from_start=True + ) + return out[-1]["data"] + + msg = update("samx", "configuration", {"samx_velocity": {"value": 5, "timestamp": 1.0}}) + assert msg.status == "invalid" + + update( + "samx", + "limits", + {"low": {"value": -20, "timestamp": 1.0}, "high": {"value": 20, "timestamp": 1.0}}, + ) + update("samx", "readback", {"samx": {"value": 0, "timestamp": 1.0}}) + msg = update("bpm4i", "readback", {"bpm4i": {"value": 0, "timestamp": 1.0}}) + assert msg.status == "valid" + assert set(msg.label.split("|")) == {"alignment", "test"} + + msg = update("samx", "readback", {"samx": {"value": 19, "timestamp": 2.0}}) + assert msg.status == "valid" + assert msg.label == "test" + + msg = update("bpm4i", "readback", {"bpm4i": {"value": 2, "timestamp": 2.0}}) + assert msg.status == "valid" + assert msg.label == "measurement" + + @pytest.mark.parametrize( + ("cached_value", "expected_value", "abs_tolerance", "matches"), + [ + (1.05, 1.0, 0.1, True), + (1.2, 1.0, 0.1, False), + (5, 5, 0.0, True), + (np.int64(5), 5, 0.0, True), + (np.float64(1.05), 1.0, 0.1, True), + ("ok", "ok", 0.0, True), + ("not-ok", "ok", 0.0, False), + ([1, 2], 1, 0.0, False), + (np.array([1.0, 2.0]), 1.0, 0.1, False), + (np.array([1.0, 2.0]), np.array([1.0, 2.0]), 0.0, False), + ], + ) + def test_aggregated_state_requirement_matches( + self, + connected_connector, + dm_with_devices, + aggregated_state_config, + cached_value, + expected_value, + abs_tolerance, + matches, + ): + """ + Test the evaluation of requirements in the aggregated state. We manually set the signal value + cache and then call the _requirement_matches method with a requirement, and check if the output is as expected. + """ + state = bl_states.AggregatedState( + name=aggregated_state_config.name, + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + requirement = bl_states.ResolvedStateSignal( + label="alignment", + device_name="bpm4i", + signal_name="bpm4i", + expected_value=expected_value, + abs_tolerance=abs_tolerance, + source="readback", + ) + state._signal_value_cache[("bpm4i", "readback", "bpm4i")] = cached_value + + assert state._requirement_matches(requirement) is matches + class TestBeamlineStateManager: def test_manager_registers_for_state_updates(self, connected_connector): @@ -321,3 +633,83 @@ def test_show_all_prints_table(self, state_manager, capsys): captured = capsys.readouterr() assert "shutter_open" in (captured.out + captured.err) + + +class TestStateMachine: + + @pytest.fixture() + def state_machine(self, state_manager): + state_machine = BeamlineStateMachine(manager=state_manager) + return state_machine + + @pytest.fixture() + def config_dict(self): + return { + "alignment": { + "devices": { + "samx": { + "value": 0, + "abs_tol": 0.1, + "signals": {"velocity": {"value": 5, "abs_tol": 0.1}}, + } + } + } + } + + def test_load_from_config_with_dict( + self, state_machine: BeamlineStateMachine, tmp_path, config_dict + ): + """Test loading configuration from a dictionary or file.""" + + # Load valid configuration from dictionary + with mock.patch.object(state_machine._manager, "add") as manager_add: + state_machine.load_from_config( + name="alignment", config_path=None, config_dict=config_dict + ) + manager_add.assert_called_once_with( + bl_states.AggregatedStateConfig(name="alignment", states=config_dict) + ) + # Loading with both config_path and config_dict should raise an error + with pytest.raises(ValueError): + state_machine.load_from_config( + name="alignment", config_path="path/to/config.yaml", config_dict=config_dict + ) + # Loading with neither config_path nor config_dict should raise an error + with pytest.raises(ValueError): + state_machine.load_from_config(name="alignment", config_path=None, config_dict=None) + + # Loading from file should work. + config_path = tmp_path / "config.yaml" + with open(config_path, "w", encoding="utf-8") as f: + yaml.dump(config_dict, f) + state_machine.load_from_config(name="alignment", config_path=str(config_path)) + manager_add.assert_called_with( + bl_states.AggregatedStateConfig(name="alignment", states=config_dict) + ) + + def test_update_config(self, state_machine: BeamlineStateMachine, config_dict, tmp_path): + """Test update method of state machine.""" + with mock.patch.object(state_machine._manager, "_update_state") as manager_update: + config = bl_states.AggregatedStateConfig(name="alignment", states=config_dict) + state_machine.update_config(name="alignment", config_dict=config_dict) + manager_update.assert_called_once_with(config) + + manager_update.reset_mock() + + # Invalid updates should raise an error + with pytest.raises(ValueError): + state_machine.update_config(name="alignment", config_dict=None) + manager_update.assert_not_called() + + with pytest.raises(ValueError): + state_machine.update_config( + name="alignment", config_path="path/to/config.yaml", config_dict=config_dict + ) + manager_update.assert_not_called() + manager_update.reset_mock() + # Updating from file should work. + config_path = tmp_path / "config.yaml" + with open(config_path, "w", encoding="utf-8") as f: + yaml.dump(config_dict, f) + state_machine.update_config(name="alignment", config_path=str(config_path)) + manager_update.assert_called_once_with(config) diff --git a/bec_lib/tests/test_scan_context.py b/bec_lib/tests/test_scan_context.py index b8fe2b037..7523de79f 100644 --- a/bec_lib/tests/test_scan_context.py +++ b/bec_lib/tests/test_scan_context.py @@ -144,30 +144,41 @@ def test_parameter_bundler(bec_client_mock): @pytest.mark.parametrize( - "in_type,out", + "dtype,out", [ - ("float", (float, int)), - ("int", int), - ("list", list), - ("boolean", bool), - ("bool", bool), - ("str", str), - ("dict", dict), - ("device", DeviceBase), - ("DeviceBase", DeviceBase), - (serialize_dtype(Annotated[float, ScanArgument(description="Step size")]), float), - ( - serialize_dtype(Annotated[float, ScanArgument(description="Step size")] | None), - (float, type(None)), - ), + (float, (float, int)), + (int, int), + (list, list), + (bool, bool), + (str, str), + (dict, dict), + (DeviceBase, DeviceBase), + (list[float], list[float]), + (Annotated[float, ScanArgument(description="Step size")], (float, int)), + (Annotated[float, ScanArgument(description="Step size")] | None, (float, int, type(None))), ], ) -def test_get_arg_type(bec_client_mock, in_type, out): +def test_get_runtime_arg_type(bec_client_mock, dtype, out): client = bec_client_mock - res = client.scans.get_arg_type(in_type) + res = client.scans._get_runtime_arg_type(dtype) assert res == out +@pytest.mark.parametrize( + ("arg", "dtype", "matches"), + [ + ([1, 2.5], list[float], True), + ([1, "nope"], list[float], False), + ({1: 2.0}, dict[int, float], True), + ({1: "nope"}, dict[int, float], False), + ], +) +def test_arg_matches_type_for_generic_containers(bec_client_mock, arg, dtype, matches): + client = bec_client_mock + + assert client.scans._arg_matches_type(arg, dtype) is matches + + def test_strip_scan_signature_annotations_for_ipython_signature(): signature = [ { @@ -202,12 +213,6 @@ def test_strip_scan_signature_annotations_for_ipython_signature(): ] -def test_get_arg_type_raises(bec_client_mock): - client = bec_client_mock - with pytest.raises(TypeError): - client.scans.get_arg_type("not_existing") - - def test_interactive_scan_cm(bec_client_mock): client = bec_client_mock client.scans._open_interactive_scan = mock.MagicMock() diff --git a/bec_lib/tests/test_scan_items.py b/bec_lib/tests/test_scan_items.py index e56375114..49f6dcd5c 100644 --- a/bec_lib/tests/test_scan_items.py +++ b/bec_lib/tests/test_scan_items.py @@ -253,6 +253,20 @@ def test_update_with_scan_status_updates_end_time(): assert scan_item.end_time == 10 +def test_update_with_scan_status_does_not_update_end_time_for_paused(): + scan_manager = ScanManager(ConnectorMock("")) + with mock.patch.object(scan_manager.scan_storage, "find_scan_by_ID") as mock_find_scan: + scan_item = mock.MagicMock() + scan_item.end_time = 0 + mock_find_scan.return_value = scan_item + scan_manager.scan_storage.update_with_scan_status( + messages.ScanStatusMessage( + scan_id="scan_id", status="paused", scan_number=1, info={}, timestamp=10 + ) + ) + assert scan_item.end_time == 0 + + def test_update_with_scan_status_does_not_update_end_time(): scan_manager = ScanManager(ConnectorMock("")) with mock.patch.object(scan_manager.scan_storage, "find_scan_by_ID") as mock_find_scan: @@ -371,6 +385,26 @@ def test_update_with_scan_status_removes_scan_def_id(): assert "scan_def_id" not in scan_item.open_scan_defs +def test_update_with_scan_status_keeps_scan_def_id_for_paused(): + scan_manager = ScanManager(ConnectorMock("")) + scan_manager.scan_storage.last_scan_number = 0 + with mock.patch.object(scan_manager.scan_storage, "find_scan_by_ID") as mock_find_scan: + scan_item = mock.MagicMock() + scan_item.open_scan_defs = {"scan_def_id"} + mock_find_scan.return_value = scan_item + scan_manager.scan_storage.update_with_scan_status( + messages.ScanStatusMessage( + scan_id="scan_id", + status="paused", + scan_number=1, + num_points=10, + info={"scan_def_id": "scan_def_id"}, + timestamp=10, + ) + ) + assert "scan_def_id" in scan_item.open_scan_defs + + def test_add_scan_segment_emits_data(): scan_manager = ScanManager(ConnectorMock("")) scan_item = mock.MagicMock() diff --git a/bec_lib/tests/test_scan_object.py b/bec_lib/tests/test_scan_object.py index e51d63da8..323d52690 100644 --- a/bec_lib/tests/test_scan_object.py +++ b/bec_lib/tests/test_scan_object.py @@ -9,7 +9,7 @@ def scan_obj(bec_client_mock): scan_info = { "class": "FermatSpiralScan", - "arg_input": {"device": "device", "start": "float", "stop": "float"}, + "arg_input": {"device": "DeviceBase", "start": "float", "stop": "float"}, "required_kwargs": ["step", "relative"], "arg_bundle_size": {"bundle": 3, "min": 2, "max": 2}, "doc": ( diff --git a/bec_lib/tests/test_scan_report.py b/bec_lib/tests/test_scan_report.py index 6d7e8e42b..5006061a5 100644 --- a/bec_lib/tests/test_scan_report.py +++ b/bec_lib/tests/test_scan_report.py @@ -171,20 +171,20 @@ def test_scan_report_file_written_no_master(scan_report): def test_scan_report_num_points_reached(scan_report): with mock.patch.object(scan_report.request, "scan") as mock_scan: - mock_scan.num_points = 10 + mock_scan.num_monitored_readouts = 10 mock_scan.live_data = {"0": "msg", "1": "msg", "2": "msg"} assert scan_report._num_points_reached() is False def test_scan_report_num_points_reached_no_points(scan_report): with mock.patch.object(scan_report.request, "scan") as mock_scan: - mock_scan.num_points = 0 + mock_scan.num_monitored_readouts = 0 mock_scan.live_data = {} assert scan_report._num_points_reached() is True def test_scan_report_num_points_reached_match(scan_report): with mock.patch.object(scan_report.request, "scan") as mock_scan: - mock_scan.num_points = 3 + mock_scan.num_monitored_readouts = 3 mock_scan.live_data = {"0": "msg", "1": "msg", "2": "msg"} assert scan_report._num_points_reached() is True diff --git a/bec_lib/tests/test_signature_serializer.py b/bec_lib/tests/test_signature_serializer.py index d585cb15a..c47ac95e6 100644 --- a/bec_lib/tests/test_signature_serializer.py +++ b/bec_lib/tests/test_signature_serializer.py @@ -322,6 +322,7 @@ def test_func(step: Annotated[float, scan_argument] | None = None): }, ), (Annotated[float, "unknown metadata"], "float"), + (list[float], {"Generic": {"origin": "list", "args": ["float"]}}), ], ) def test_serialize_dtype(dtype_in, dtype_out): @@ -372,6 +373,7 @@ def test_serialize_dtype(dtype_in, dtype_out): ({"Annotated": {"type": "float", "metadata": {"Other": {}}}}, float), ({"Annotated": {"type": "float", "metadata": {}}}, float), ({"Annotated": {"type": "float"}}, float), + ({"Generic": {"origin": "list", "args": ["float"]}}, list[float]), ( { "Annotated": { diff --git a/bec_server/bec_server/device_server/device_server.py b/bec_server/bec_server/device_server/device_server.py index cc4f26753..01995184c 100644 --- a/bec_server/bec_server/device_server/device_server.py +++ b/bec_server/bec_server/device_server/device_server.py @@ -761,7 +761,8 @@ def status_callback(self, status): content = status.instruction.content is_config_set = content["action"] == "set" - is_rpc_set = content["action"] == "rpc" and (".set" in content["parameter"]["func"]) + rpc_func = content["parameter"].get("func", "") + is_rpc_set = content["action"] == "rpc" and (rpc_func == "set" or ".set" in rpc_func) if is_config_set or is_rpc_set: if obj.kind == Kind.config: @@ -807,8 +808,11 @@ def _read_device(self, instr: messages.DeviceInstructionMessage, new_status=True return self.requests_handler.add_request(instr, num_status_objects=0) - self._read_and_update_devices(devices, instr.metadata) - self.requests_handler.set_finished(instr.metadata["device_instr_id"], success=True) + result = self._read_and_update_devices(devices, instr.metadata) + response_result = result if instr.parameter.get("return_result", False) else None + self.requests_handler.set_finished( + instr.metadata["device_instr_id"], success=True, result=response_result + ) def _read_and_update_devices(self, devices: list[str], metadata: dict) -> list: start = time.time() diff --git a/bec_server/bec_server/file_writer/file_writer_manager.py b/bec_server/bec_server/file_writer/file_writer_manager.py index 3da18e040..0fc61a44c 100644 --- a/bec_server/bec_server/file_writer/file_writer_manager.py +++ b/bec_server/bec_server/file_writer/file_writer_manager.py @@ -37,6 +37,7 @@ def __init__(self, scan_number: int, scan_id: str) -> None: self.scan_segments = {} self.scan_finished = False self.num_points: int | None = None + self.num_monitored_readouts: int | None = None self.baseline = {} self.async_writer: AsyncWriter | None = None self.beamline_states: dict[str, list[messages.BeamlineStateMessage]] = defaultdict(list) @@ -66,14 +67,16 @@ def ready_to_write(self) -> bool: if self.enforce_sync: # wait for all points to be received. Since this method will be called for every # update of the scan segments, we can also accept to write after the scan is finished - _ready_to_write = self.scan_finished and (self.num_points == len(self.scan_segments)) + _ready_to_write = self.scan_finished and ( + self.num_monitored_readouts == len(self.scan_segments) + ) if not _ready_to_write: if self.status_msg is None or self.status_msg.readout_priority is None: return False monitored_devices = self.status_msg.readout_priority.get("monitored") if not monitored_devices: logger.info( - f"Received number of segments: {len(self.scan_segments)}, Number of points (expected): {self.num_points}, Ready to write: {_ready_to_write}" + f"Received number of segments: {len(self.scan_segments)}, Number of monitored readouts (expected): {self.num_monitored_readouts}, Ready to write: {_ready_to_write}" ) return self.scan_finished return _ready_to_write @@ -267,6 +270,7 @@ def update_scan_storage_with_status(self, msg: messages.ScanStatusMessage) -> No scan_storage.scan_finished = True scan_storage.num_points = msg.num_points + scan_storage.num_monitored_readouts = msg.num_monitored_readouts info = msg.content.get("info") if info: if msg.scan_type == "step": @@ -455,6 +459,7 @@ def write_file(self, scan_id: str) -> None: start_time=storage.start_time, end_time=storage.end_time, num_points=storage.num_points, + num_monitored_readouts=storage.num_monitored_readouts, scan_name=storage.metadata.get("scan_name"), request_inputs=storage.metadata.get("request_inputs", {}), stored_data_info=self.file_writer.stored_data_info or {}, diff --git a/bec_server/bec_server/scan_bundler/bec_emitter.py b/bec_server/bec_server/scan_bundler/bec_emitter.py index f252c1be8..69d7ddb23 100644 --- a/bec_server/bec_server/scan_bundler/bec_emitter.py +++ b/bec_server/bec_server/scan_bundler/bec_emitter.py @@ -105,9 +105,12 @@ def _update_scan_progress(self, scan_id: str, point_id: int, done=False) -> None ) return info = self.scan_bundler.sync_storage[scan_id]["info"] + + num_monitored_readouts = info.get("num_monitored_readouts", info.get("num_points", 0)) + msg = messages.ProgressMessage( value=point_id + 1, - max_value=info.get("num_points", point_id + 1), + max_value=num_monitored_readouts or point_id + 1, done=done, metadata={ "scan_id": scan_id, @@ -143,8 +146,9 @@ def on_scan_status_update(self, status_msg: messages.ScanStatusMessage): return num_points = max(status_msg.info.get("num_points", 0) - 1, 0) + num_monitored_readouts = status_msg.info.get("num_monitored_readouts", num_points) if status_msg.status == "closed": - self._update_scan_progress(status_msg.scan_id, num_points, done=True) + self._update_scan_progress(status_msg.scan_id, num_monitored_readouts, done=True) return sb = self.scan_bundler diff --git a/bec_server/bec_server/scan_bundler/scan_bundler.py b/bec_server/bec_server/scan_bundler/scan_bundler.py index 5ddd16363..78fd7a0be 100644 --- a/bec_server/bec_server/scan_bundler/scan_bundler.py +++ b/bec_server/bec_server/scan_bundler/scan_bundler.py @@ -310,11 +310,19 @@ def _add_device_to_storage(self, msgs, device, timeout_time=10): device_is_monitor_sync = self.sync_storage[scan_id]["info"]["monitor_sync"] == device dev_obj = self.device_manager.devices.get(device) if dev_obj in self.monitored_devices[scan_id]["devices"] or device_is_monitor_sync: - if self.sync_storage[scan_id]["info"]["scan_type"] == "step": + if self.sync_storage[scan_id]["info"]["scan_type"] in [ + "step", # DEPRECATED: will be removed in the future, only software_triggered and hardware_triggered will be supported + "software_triggered", + "hardware_triggered", + ]: self._step_scan_update(scan_id, device, signal, metadata) elif self.sync_storage[scan_id]["info"]["scan_type"] == "fly": + # DEPRECATED: will be removed in the future self._fly_scan_update(scan_id, device, signal, metadata) else: + logger.error( + f"Unknown scan type {self.sync_storage[scan_id]['info']['scan_type']}" + ) raise RuntimeError( f"Unknown scan type {self.sync_storage[scan_id]['info']['scan_type']}" ) diff --git a/bec_server/bec_server/scan_server/direct_scan_worker.py b/bec_server/bec_server/scan_server/direct_scan_worker.py new file mode 100644 index 000000000..d3460a02b --- /dev/null +++ b/bec_server/bec_server/scan_server/direct_scan_worker.py @@ -0,0 +1,203 @@ +from __future__ import annotations + +import time +import traceback +from typing import TYPE_CHECKING + +from bec_lib import messages +from bec_lib.alarm_handler import Alarms +from bec_lib.logger import bec_logger +from bec_server.scan_server.errors import DeviceInstructionError, ScanAbortion, UserScanInterruption +from bec_server.scan_server.scan_queue import InstructionQueueStatus +from bec_server.scan_server.scans.scans_v4 import ScanBase + +logger = bec_logger.logger + +if TYPE_CHECKING: + from bec_server.scan_server.scan_queue import DirectInstructionQueueItem + from bec_server.scan_server.scan_worker import ScanWorker + +SCAN_SEQUENCE = [ + "prepare_scan", + "open_scan", + "stage", + "pre_scan", + "scan_core", + "post_scan", + "unstage", + "close_scan", +] + + +class DirectScanWorker: + """ + DirectScanWorker runs scan lifecycle methods directly. + Unlike GeneratorScanWorker, it does not interpret instructions. + Instructions are sent directly to Redis by the scan itself. + """ + + def __init__(self, *, worker: ScanWorker): + self.worker = worker + self.scan = None + + def reset(self): + self.scan = None + + def process_instructions(self, queue: DirectInstructionQueueItem) -> None: + self.worker.current_instruction_queue_item = queue + + scan = queue.move_to_next_scan() + if scan is None: + logger.error("No scan found in the queue item to process.") + return + self.run(scan) + + queue.status = InstructionQueueStatus.COMPLETED + self.worker.current_instruction_queue_item = None + self.reset() + + def run(self, scan: ScanBase): + """ + Run the scan. + + Args: + scan (ScanBase): Scan to run + """ + self.scan = scan + + # pylint: disable=protected-access + scan.actions._interruption_callback = self.check_for_interruption + scan.actions._update_queue_info_callback = self.update_queue_info + queue = self.worker.current_instruction_queue_item + try: + with self.worker.device_manager._rpc_method(scan.actions.rpc_call): + for step in SCAN_SEQUENCE: + method = getattr(scan, step, None) + if not method: + raise ScanAbortion(f"Scan is missing required method: {step}") + self.check_for_interruption() + method() + except Exception as exc: + if self.worker.signal_event.is_set(): + # If the signal event is set, it means that the scan worker is shutting down, so we don't need to handle the abortion + return + if queue is None: + return + if queue.stopped or not queue.active_request_block: + raise exc + queue.stopped = True + try: + # We reset the worker to RUNNING to allow for cleanup tasks + # during the on_exception hook. + self.worker.status = InstructionQueueStatus.RUNNING + self.scan.actions._metadata_suffix = "__on-exception" + self._run_on_exception_hook(exc) + except Exception as exc_cleanup: + self.worker.connector.send_client_info("") + self._handle_exception(exc_cleanup) + self._handle_exception(exc) + if queue is None: + return + queue.status = InstructionQueueStatus.COMPLETED + self.worker.current_instruction_queue_item = None + self.reset() + + def _handle_exception(self, exc: Exception): + content = traceback.format_exc() + logger.error(content) + + def _raise_alarm(error_info: messages.ErrorInfo): + self.worker.connector.raise_alarm( + severity=Alarms.MAJOR, info=error_info, metadata=self.get_metadata_for_alarm() + ) + + if isinstance(exc, DeviceInstructionError): + _raise_alarm(error_info=exc.error_info) + raise ScanAbortion from exc + error_info = messages.ErrorInfo( + error_message=content, + compact_error_message=traceback.format_exc(limit=0), + exception_type=exc.__class__.__name__, + device=None, + ) + _raise_alarm(error_info=error_info) + raise ScanAbortion from exc + + def check_for_interruption(self): + if self.worker.status == InstructionQueueStatus.PAUSED: + if self.scan is not None: + self.scan.actions._send_scan_status("paused") + while self.worker.status == InstructionQueueStatus.PAUSED: + time.sleep(0.1) + if self.worker.status == InstructionQueueStatus.STOPPED: + item = self.worker.current_instruction_queue_item + if item is None or item.exit_info is None: + raise ScanAbortion() + raise UserScanInterruption(exit_info=item.exit_info) + + def update_queue_info(self): + self.worker.current_instruction_queue_item.parent.queue_manager.send_queue_status() + + def _propagate_error(self, content: str, exc: Exception): + logger.error(content) + error_info = messages.ErrorInfo( + error_message=content, + compact_error_message=traceback.format_exc(limit=0), + exception_type=exc.__class__.__name__, + device=None, + ) + self.worker.connector.raise_alarm( + severity=Alarms.MAJOR, info=error_info, metadata=self.get_metadata_for_alarm() + ) + + def get_metadata_for_alarm(self) -> dict: + if self.scan is None: + return {} + metadata = {} + if self.scan.scan_info.scan_id is not None: + metadata["scan_id"] = self.scan.scan_info.scan_id + if self.scan.scan_info.scan_number is not None: + metadata["scan_number"] = self.scan.scan_info.scan_number + return metadata + + def _run_on_exception_hook(self, exc: Exception): + scan = self.scan + if scan is None: + return + if not self.worker.current_instruction_queue_item.run_on_exception_hook: + return + hook_exc = exc.__cause__ if exc.__cause__ is not None else exc + if not hasattr(scan, "on_exception") or not callable(getattr(scan, "on_exception")): + return + try: + scan._shutdown_event.clear() + with self.worker.device_manager._rpc_method(scan.actions.rpc_call): + scan.on_exception(hook_exc) # type: ignore + except Exception: + scan.actions.send_client_info("") + logger.exception("Failed to run direct scan on_exception hook") + + def _handle_scan_abortion(self, queue: DirectInstructionQueueItem, exc: ScanAbortion): + # TODO: We currently access the method from the scan worker for being backwards compatible with + # the generator-based worker. Once we have fully switched to the direct worker, we should move + # the method to the run method of the direct worker and remove it from the scan worker. + content = traceback.format_exc() + logger.error(content) + if self.scan is None: + return + + exit_info = exc.exit_info if isinstance(exc, UserScanInterruption) else queue.exit_info + if exit_info: + self.scan.actions._send_scan_status(exit_info[0], reason=exit_info[1]) + else: + reason = "alarm" + if queue.run_on_exception_hook: + self.scan.actions._send_scan_status("aborted", reason=reason) + else: + self.scan.actions._send_scan_status("halted", reason=reason) + + queue.status = InstructionQueueStatus.STOPPED + queue.append_to_queue_history() + self.worker.parent.queue_manager.queues[self.worker.queue_name].abort() + self.reset() + self.worker.status = InstructionQueueStatus.RUNNING diff --git a/bec_server/bec_server/scan_server/errors.py b/bec_server/bec_server/scan_server/errors.py index 80dca70f6..7bc022632 100644 --- a/bec_server/bec_server/scan_server/errors.py +++ b/bec_server/bec_server/scan_server/errors.py @@ -11,6 +11,10 @@ class ScanAbortion(Exception): pass +class ScanInputValidationError(ValueError): + pass + + class UserScanInterruption(ScanAbortion): def __init__(self, exit_info: ExitInfoType): super().__init__() diff --git a/bec_server/bec_server/scan_server/generator_scan_worker.py b/bec_server/bec_server/scan_server/generator_scan_worker.py new file mode 100644 index 000000000..117a3f269 --- /dev/null +++ b/bec_server/bec_server/scan_server/generator_scan_worker.py @@ -0,0 +1,566 @@ +from __future__ import annotations + +import os +import time +import traceback +from string import Template +from typing import TYPE_CHECKING, Literal + +from bec_lib import messages +from bec_lib.alarm_handler import Alarms +from bec_lib.endpoints import MessageEndpoints +from bec_lib.file_utils import compile_file_components +from bec_lib.logger import bec_logger + +from .errors import DeviceInstructionError, ScanAbortion, UserScanInterruption +from .scan_queue import InstructionQueueItem, InstructionQueueStatus, RequestBlock +from .scan_stubs import ScanStubStatus + +logger = bec_logger.logger + +if TYPE_CHECKING: + from bec_server.scan_server.scan_worker import ScanWorker + + +class GeneratorScanWorker: + """ + Scan worker class that processes scan instructions and sends device instructions to the device server. + """ + + def __init__(self, *, worker: ScanWorker): + self.worker = worker + self.scan_id = None + self.readout_priority = {} + self.scan_type = None + self.current_scan_id: str = "" + self.current_scan_info = None + self.max_point_id = 0 + self._exposure_time = None + self.interception_msg = None + self.reset() + + def open_scan(self, instr: messages.DeviceInstructionMessage) -> None: + """ + Open a new scan and emit a scan status message. + + Args: + instr (DeviceInstructionMessage): Device instruction received from the scan assembler + + """ + if not self.scan_id: + self.scan_id = instr.metadata.get("scan_id") + self.readout_priority = instr.content["parameter"].get("readout_priority", {}) + self.scan_type = instr.content["parameter"].get("scan_type") + + if not instr.metadata.get("scan_def_id"): + self.max_point_id = 0 + instr_num_points = instr.content["parameter"].get("num_points", 0) + if instr_num_points is None: + instr_num_points = 0 + num_points = self.max_point_id + instr_num_points + if self.max_point_id: + num_points += 1 + + active_rb = self.worker.current_instruction_queue_item.active_request_block + + self._initialize_scan_info(active_rb, instr, num_points) + + # only append the scan_progress if the scan is not using device_progress + if active_rb.scan.use_scan_progress_report: + if not self.scan_report_instructions or not self.scan_report_instructions[-1].get( + "device_progress" + ): + self.scan_report_instructions.append( + { + "scan_progress": { + "points": num_points, + "show_table": active_rb.scan.show_live_table, + } + } + ) + self.worker.current_instruction_queue_item.parent.queue_manager.send_queue_status() + + self._send_scan_status("open") + + def close_scan(self, instr: messages.DeviceInstructionMessage, max_point_id: int) -> None: + """ + Close a scan and emit a scan status message. + + Args: + instr (DeviceInstructionMessage): Device instruction received from the scan assembler + max_point_id (int): Maximum point ID of the scan + """ + scan_id = instr.metadata.get("scan_id") + + if self.scan_id != scan_id: + return + + # reset the scan ID now that the scan will be closed + self.scan_id = None + + scan_info = self.current_scan_info + if scan_info.get("scan_type") == "fly": + # flyers do not increase the point_id but instead set the num_points directly + num_points = ( + self.worker.current_instruction_queue_item.active_request_block.scan.num_pos + ) + self.current_scan_info["num_points"] = num_points + + else: + # point_id starts at 0 + scan_info["num_points"] = max_point_id + 1 + + self._send_scan_status("closed") + + def publish_data_as_read(self, instr: messages.DeviceInstructionMessage): + """ + Publish data as read by sending a DeviceMessage to the device_read endpoint. + This instruction replicates the behaviour of the device server when it receives a read instruction. + + Args: + instr (DeviceInstructionMessage): Device instruction received from the scan assembler + """ + connector = self.worker.device_manager.connector + data = instr.content["parameter"]["data"] + devices = instr.content["device"] + if not isinstance(devices, list): + devices = [devices] + if not isinstance(data, list): + data = [data] + for device, dev_data in zip(devices, data): + msg = messages.DeviceMessage(signals=dev_data, metadata=instr.metadata) + connector.set_and_publish(MessageEndpoints.device_read(device), msg) + + def process_scan_report_instruction(self, instr): + """ + Process a scan report instruction by appending it to the scan_report_instructions list. + + Args: + instr (DeviceInstructionMessage): Device instruction received from the scan assembler + + """ + self.scan_report_instructions.append(instr.content["parameter"]) + self.worker.current_instruction_queue_item.parent.queue_manager.send_queue_status() + + def forward_instruction(self, instr: messages.DeviceInstructionMessage) -> None: + """ + Forward an instruction to the device server. + + Args: + instr (DeviceInstructionMessage): Device instruction received from the scan assembler + + """ + self.worker.connector.send(MessageEndpoints.device_instructions(), instr) + + @property + def scan_report_instructions(self): + """ + List of scan report instructions + """ + req_block = self.worker.current_instruction_queue_item.active_request_block + return req_block.scan_report_instructions + + def _wait_for_device_server(self) -> None: + self.worker.parent.wait_for_service("DeviceServer") + + def _check_for_interruption(self) -> None: + if self.worker.status == InstructionQueueStatus.PAUSED: + self._send_scan_status("paused") + while self.worker.status == InstructionQueueStatus.PAUSED: + time.sleep(0.1) + if self.worker.status == InstructionQueueStatus.STOPPED: + item = self.worker.current_instruction_queue_item + if item is None or item.exit_info is None: + raise ScanAbortion() + raise UserScanInterruption(exit_info=item.exit_info) + + def _initialize_scan_info( + self, active_rb: RequestBlock, instr: messages.DeviceInstructionMessage, num_points: int + ): + + metadata = active_rb.metadata + self.current_scan_info = {**instr.metadata, **instr.content["parameter"]} + self.current_scan_info.update(metadata) + self.current_scan_info.update( + { + "scan_number": self.worker.parent.scan_number, + "dataset_number": self.worker.parent.dataset_number, + "exp_time": self._exposure_time, + "frames_per_trigger": active_rb.scan.frames_per_trigger, + "settling_time": active_rb.scan.settling_time, + "readout_time": active_rb.scan.readout_time, + "scan_report_devices": active_rb.scan.scan_report_devices, + "monitor_sync": active_rb.scan.monitor_sync, + "num_points": num_points, + "scan_parameters": active_rb.scan.scan_parameters, + "request_inputs": active_rb.scan.request_inputs, + "file_components": compile_file_components( + base_path=self._get_file_base_path(), + scan_nr=self.worker.parent.scan_number, + file_directory=active_rb.scan.scan_parameters["system_config"][ + "file_directory" + ], + user_suffix=active_rb.scan.scan_parameters["system_config"]["file_suffix"], + ), + } + ) + self.current_scan_info["scan_msgs"] = [ + str(scan_msg) for scan_msg in self.worker.current_instruction_queue_item.scan_msgs + ] + self.current_scan_info["args"] = active_rb.scan.parameter["args"] + self.current_scan_info["kwargs"] = active_rb.scan.parameter["kwargs"] + self.current_scan_info["readout_priority"] = { + "monitored": [ + dev.full_name + for dev in self.worker.device_manager.devices.monitored_devices( + readout_priority=self.readout_priority + ) + ], + "baseline": [ + dev.full_name + for dev in self.worker.device_manager.devices.baseline_devices( + readout_priority=self.readout_priority + ) + ], + "async": [ + dev.full_name + for dev in self.worker.device_manager.devices.async_devices( + readout_priority=self.readout_priority + ) + ], + "continuous": [ + dev.full_name + for dev in self.worker.device_manager.devices.continuous_devices( + readout_priority=self.readout_priority + ) + ], + "on_request": [ + dev.full_name + for dev in self.worker.device_manager.devices.on_request_devices( + readout_priority=self.readout_priority + ) + ], + } + + def _get_file_base_path(self) -> str: + """ + Get the file base path for the scan data. The base path can be a string or a template. + If it is a template, the account name will be substituted into the template. + The account name is retrieved from the current account message. + If the account name is not found, an empty string will be used. + """ + current_account_msg = self.worker.connector.get_last(MessageEndpoints.account(), "data") + if current_account_msg: + current_account = current_account_msg.value + if not isinstance(current_account, str): + logger.warning( + f"Account name is not a string: {current_account}. " "Ignoring specified value." + ) + current_account = None + else: + if "/" in current_account: + raise ValueError( + f"Account name cannot contain a slash (/): {current_account}. " + ) + # _ and - are allowed + check_value = current_account.replace("_", "").replace("-", "") + if not check_value.isalnum() or not check_value.isascii(): + raise ValueError( + f"Account name can only contain alphanumeric characters: {current_account}. " + ) + + else: + current_account = None + + # pylint: disable=protected-access + file_base_path = self.worker.parent._service_config.config["file_writer"]["base_path"] + if "$" not in file_base_path: + # we deal with a normal string + if current_account: + return os.path.abspath(os.path.join(file_base_path, current_account)) + # if there is no account, we return the base path with the data folder + return os.path.abspath(file_base_path) + + # we deal with a string template + file_base_path = Template(file_base_path) + + try: + # check if the template is valid + return os.path.abspath(file_base_path.substitute(account=current_account or "")) + except KeyError as exc: + raise ValueError( + f"Invalid template variable: {exc} in the file base path. " + "Please check your service config." + ) from exc + + def _send_scan_status( + self, + status: Literal["open", "paused", "closed", "aborted", "halted", "user_completed"], + reason: Literal["user", "alarm"] | None = None, + ) -> None: + if not self.current_scan_info: + return + current_scan_info_print = self.current_scan_info.copy() + if current_scan_info_print.get("positions", []): + current_scan_info_print["positions"] = "..." + logger.info( + f"New scan status: {self.current_scan_id} / {status} / {current_scan_info_print}" + ) + si = self.current_scan_info + update_fields = [ + "scan_name", + "scan_number", + "session_id", + "dataset_number", + "num_points", + "scan_type", + "scan_report_devices", + "user_metadata", + "readout_priority", + "scan_parameters", + "request_inputs", + ] + update = {k: si.get(k) for k in update_fields if si.get(k) is not None} + msg = messages.ScanStatusMessage( + scan_id=self.current_scan_id, + status=status, + reason=reason, + num_monitored_readouts=si.get("num_points"), + info=self.current_scan_info, + **update, + ) + if msg.readout_priority != (cur_rp := self.current_scan_info.get("readout_priority")): + raise RuntimeError( + f"Readout priority mismatch: expected {cur_rp}, got {msg.readout_priority}" + ) + expire = None if status in ["open", "paused"] else 1800 + pipe = self.worker.device_manager.connector.pipeline() + self.worker.device_manager.connector.set( + MessageEndpoints.public_scan_info(self.current_scan_id), msg, pipe=pipe, expire=expire + ) + self.worker.device_manager.connector.set_and_publish( + MessageEndpoints.scan_status(), msg, pipe=pipe + ) + pipe.execute() + + def update_instr_with_scan_report(self, instr: messages.DeviceInstructionMessage): + if not self.scan_report_instructions: + return + for scan_report in self.scan_report_instructions: + if "readback" not in scan_report: + continue + readback = scan_report["readback"] + instr_device = ( + instr.content["device"] + if isinstance(instr.content["device"], list) + else [instr.content["device"]] + ) + + if set(readback.get("devices", [])) & set(instr_device): + instr.metadata["response"] = True + + def get_metadata_for_alarm(self) -> dict: + """ + Get metadata for the alarm to be raised in case of an error. + This includes the scan ID and scan number if available. + + Returns: + dict: Metadata dictionary with scan ID and scan number. + """ + metadata = {} + if not self.current_scan_info: + return metadata + + if self.current_scan_info.get("scan_id"): + metadata["scan_id"] = self.current_scan_info["scan_id"] + if self.current_scan_info.get("scan_number"): + metadata["scan_number"] = self.current_scan_info["scan_number"] + return metadata + + ############################# + # PROCESS INSTRUCTIONS LOOP # + ############################# + + def _init_instruction_loop(self, queue: InstructionQueueItem) -> float | None: + """Get ready to run the process instructions loop, and return the start time if successful.""" + if not queue: + return None + self.worker.current_instruction_queue_item = queue + start = time.time() + self.max_point_id = 0 + # make sure the device server is ready to receive data + self._wait_for_device_server() + queue.is_active = True + return start + + def _propagate_pi_error(self, content: str, error_info: messages.ErrorInfo): + logger.error(content) + self.worker.connector.raise_alarm( + severity=Alarms.MAJOR, info=error_info, metadata=self.get_metadata_for_alarm() + ) + + def process_instructions(self, queue: InstructionQueueItem) -> None: + """ + Process scan instructions and send DeviceInstructions to OPAAS. + For now this is an in-memory communication. In the future however, + we might want to pass it through a dedicated Kafka topic. + Args: + queue: instruction queue + + Returns: + + """ + if (start := self._init_instruction_loop(queue)) is None: + return + try: + rpc_method = queue.queue.request_blocks[0].scan.stubs._rpc_call + with self.worker.device_manager._rpc_method(rpc_method): + for instr in queue: + self._check_for_interruption() + if instr is None: + continue + self._exposure_time = getattr(queue.active_request_block.scan, "exp_time", None) + self._instruction_step(instr) + except ScanAbortion as exc: + if self.worker.signal_event.is_set(): + return + if queue.stopped or not (queue.return_to_start and queue.active_request_block): + raise exc + queue.stopped = True + try: + cleanup = queue.active_request_block.scan.move_to_start() + rpc_method = queue.active_request_block.scan.stubs._rpc_call + self.worker.status = InstructionQueueStatus.RUNNING + with self.worker.device_manager._rpc_method(rpc_method): + for instr in cleanup: + self._check_for_interruption() + instr.metadata["scan_id"] = queue.queue.active_rb.scan_id + instr.metadata["queue_id"] = queue.queue_id + self._instruction_step(instr) + except DeviceInstructionError as exc_di: + self._propagate_pi_error(traceback.format_exc(), exc_di.error_info) + raise ScanAbortion from exc_di + except Exception as exc_return_to_start: + # if the return_to_start fails, raise the original exception + content = traceback.format_exc() + error_info = messages.ErrorInfo( + error_message=content, + compact_error_message=traceback.format_exc(limit=0), + exception_type=exc_return_to_start.__class__.__name__, + device=None, + ) + self._propagate_pi_error(content, error_info) + raise exc + raise exc + except DeviceInstructionError as exc_di: + self._propagate_pi_error(traceback.format_exc(), exc_di.error_info) + raise ScanAbortion from exc_di + except Exception as exc: + content = traceback.format_exc() + error_info = messages.ErrorInfo( + error_message=content, + compact_error_message=traceback.format_exc(limit=0), + exception_type=exc.__class__.__name__, + device=None, + ) + self._propagate_pi_error(content, error_info) + raise ScanAbortion from exc + queue.is_active = False + queue.status = InstructionQueueStatus.COMPLETED + self.worker.current_instruction_queue_item = None + + logger.info(f"QUEUE ITEM finished after {time.time()-start:.2f} seconds") + self.reset() + + def _instruction_step(self, instr: messages.DeviceInstructionMessage): + logger.debug(instr) + action = instr.content.get("action") + scan_def_id = instr.metadata.get("scan_def_id") + self.current_scan_id = instr.metadata.get("scan_id", "") + + if "point_id" in instr.metadata: + self.max_point_id = instr.metadata["point_id"] + + logger.debug(f"Device instruction: {instr}") + self._check_for_interruption() + + if action == "open_scan": + self.open_scan(instr) + elif action == "close_scan" and scan_def_id is None: + self.close_scan(instr, self.max_point_id) + elif action == "close_scan" and scan_def_id is not None: + pass + elif action == "open_scan_def": + pass + elif action == "close_scan_def": + self.close_scan(instr, self.max_point_id) + elif action == "publish_data_as_read": + self.publish_data_as_read(instr) + elif action == "scan_report_instruction": + self.process_scan_report_instruction(instr) + elif action == "set": + self.update_instr_with_scan_report(instr) + self.forward_instruction(instr) + elif action in [ + "trigger", + "kickoff", + "complete", + "baseline_reading", + "pre_scan", + "rpc", + "read", + "stage", + "unstage", + ]: + self.forward_instruction(instr) + + else: + raise ValueError(f"Unknown device instruction: {instr}") + + def reset(self): + """reset the scan worker and its member variables""" + self.current_scan_id = "" + self.current_scan_info = {} + self.scan_id = None + self.interception_msg = None + self.worker.current_instruction_queue_item = None + + def cleanup(self): + """perform cleanup instructions""" + status = ScanStubStatus(self.worker.parent.queue_manager.instruction_handler) + staged_devices = [ + dev.root.name for dev in self.worker.device_manager.devices.enabled_devices + ] + msg = messages.DeviceInstructionMessage( + device=staged_devices, + action="unstage", + parameter={}, + metadata={"device_instr_id": status._device_instr_id}, + ) + self.forward_instruction(msg) + # status.wait() + + def _handle_scan_abortion(self, queue: InstructionQueueItem, exc: ScanAbortion): + content = traceback.format_exc() + logger.error(content) + + exit_info = None + if isinstance(exc, UserScanInterruption): + exit_info = exc.exit_info + else: + exit_info = queue.exit_info + if exit_info: + self._send_scan_status(exit_info[0], reason=exit_info[1]) + else: + reason = "alarm" + if queue.return_to_start: + self._send_scan_status("aborted", reason=reason) + else: + self._send_scan_status("halted", reason=reason) + logger.info(f"Scan aborted: {queue.queue_id}") + queue.status = InstructionQueueStatus.STOPPED + queue.append_to_queue_history() + self.cleanup() + self.worker.parent.queue_manager.queues[self.worker.queue_name].abort() + self.reset() + self.worker.status = InstructionQueueStatus.RUNNING diff --git a/bec_server/bec_server/scan_server/scan_assembler.py b/bec_server/bec_server/scan_server/scan_assembler.py index 57b0eb1e2..f3112034a 100644 --- a/bec_server/bec_server/scan_server/scan_assembler.py +++ b/bec_server/bec_server/scan_server/scan_assembler.py @@ -4,9 +4,13 @@ from typing import TYPE_CHECKING from bec_lib import messages +from bec_lib.device import DeviceBase from bec_lib.logger import bec_logger -from .scans import RequestBase, ScanBase, unpack_scan_args +from .scan_gui_models import GUIInput +from .scan_input_validator import ScanInputValidator +from .scans.legacy_scans import RequestBase, ScanArgType, ScanBase, unpack_scan_args +from .scans.scans_v4 import ScanBase as ScanBaseV4 logger = bec_logger.logger @@ -24,6 +28,7 @@ def __init__(self, *, parent: ScanServer): self.device_manager = self.parent.device_manager self.connector = self.parent.connector self.scan_manager = self.parent.scan_manager + self.input_validator = ScanInputValidator() def is_scan_message(self, msg: messages.ScanQueueMessage) -> bool: """Check if the scan queue message would construct a new scan. @@ -35,10 +40,21 @@ def is_scan_message(self, msg: messages.ScanQueueMessage) -> bool: bool: True if the message is a scan message, False otherwise """ scan = msg.content.get("scan_type") - cls_name = self.scan_manager.available_scans[scan]["class"] - scan_cls = self.scan_manager.scan_dict[cls_name] + scan_cls = self.scan_manager.scan_dict[scan] return issubclass(scan_cls, ScanBase) + def is_direct_scan_message(self, msg: messages.ScanQueueMessage) -> bool: + """Check if the scan queue message would construct a new direct scan. + + Args: + msg (messages.ScanQueueMessage): message to be checked + Returns: + bool: True if the message is a direct scan message, False otherwise + """ + scan = msg.content.get("scan_type") + scan_cls = self.scan_manager.scan_dict[scan] + return issubclass(scan_cls, ScanBaseV4) + def assemble_device_instructions( self, msg: messages.ScanQueueMessage, scan_id: str ) -> RequestBase: @@ -56,19 +72,70 @@ def assemble_device_instructions( RequestBase: Scan instance of the initialized scan class """ scan = msg.content.get("scan_type") - cls_name = self.scan_manager.available_scans[scan]["class"] - scan_cls = self.scan_manager.scan_dict[cls_name] + scan_cls = self.scan_manager.scan_dict[scan] logger.info(f"Preparing instructions of request of type {scan} / {scan_cls.__name__}") args = unpack_scan_args(msg.content.get("parameter", {}).get("args", [])) kwargs = msg.content.get("parameter", {}).get("kwargs", {}) + request_inputs = self._assemble_request_inputs(scan_cls, args, kwargs) + + scan_instance = scan_cls( + *args, + device_manager=self.device_manager, + parameter=msg.content.get("parameter"), + metadata=msg.metadata, + instruction_handler=self.parent.queue_manager.instruction_handler, + scan_id=scan_id, + request_inputs=request_inputs, + **kwargs, + ) + return scan_instance + + def assemble_direct_scan(self, msg: messages.ScanQueueMessage, scan_id: str) -> ScanBaseV4: + """Assemble the device instructions for a given ScanQueueMessage. + This will be achieved by calling the specified class (must be a derived class of ScanBaseV4) + + Args: + msg (messages.ScanQueueMessage): scan queue message for which the instruction should be assembled + scan_id (str): scan id of the scan + + Raises: + ScanAbortion: Raised if the scan initialization fails. + + Returns: + ScanBaseV4: Scan instance of the initialized scan class + """ + scan = msg.content.get("scan_type") + scan_cls = self.scan_manager.scan_dict[scan] + + logger.info(f"Preparing instructions of direct scan of type {scan} / {scan_cls.__name__}") + args = unpack_scan_args(msg.content.get("parameter", {}).get("args", [])) + kwargs = msg.content.get("parameter", {}).get("kwargs", {}) + + request_inputs = self._assemble_request_inputs(scan_cls, args, kwargs) + resolved_args, resolved_kwargs = self._resolve_direct_scan_inputs(scan_cls, args, kwargs) + self.input_validator.validate(scan_cls, resolved_args, resolved_kwargs) + + scan_instance = scan_cls( + *resolved_args, + device_manager=self.device_manager, + redis_connector=self.connector, + metadata=msg.metadata, + instruction_handler=self.parent.queue_manager.instruction_handler, + scan_id=scan_id, + request_inputs=request_inputs, + **resolved_kwargs, + ) + return scan_instance + + def _assemble_request_inputs(self, scan_cls, args, kwargs) -> dict: + cls_input_args = [ name for name, val in inspect.signature(scan_cls).parameters.items() if val.default == inspect.Parameter.empty and name != "kwargs" ] - request_inputs = {} if scan_cls.arg_bundle_size["bundle"] > 0: request_inputs["arg_bundle"] = args @@ -100,15 +167,53 @@ def assemble_device_instructions( for key, val in kwargs.items(): if key not in cls_input_args: request_inputs["kwargs"][key] = val - - scan_instance = scan_cls( - *args, - device_manager=self.device_manager, - parameter=msg.content.get("parameter"), - metadata=msg.metadata, - instruction_handler=self.parent.queue_manager.instruction_handler, - scan_id=scan_id, - request_inputs=request_inputs, - **kwargs, - ) - return scan_instance + return request_inputs + + def _resolve_direct_scan_inputs(self, scan_cls, args, kwargs) -> tuple[list, dict]: + """Resolve v4 scan device arguments from names to device objects.""" + arg_input = getattr(scan_cls, "arg_input", {}) or {} + signature_annotations = self.input_validator.scan_signature_annotations(scan_cls) + kwarg_annotations = {**signature_annotations, **arg_input} + if not arg_input and not signature_annotations: + return args, kwargs + + resolved_args = list(args) + resolved_kwargs = kwargs.copy() + + if arg_input and scan_cls.arg_bundle_size["bundle"] > 0: + # Convert arg bundles if present + bundle_size = scan_cls.arg_bundle_size["bundle"] + arg_names = list(arg_input.keys()) + for bundle_start in range(0, len(resolved_args), bundle_size): + for offset, arg_name in enumerate(arg_names): + arg_index = bundle_start + offset + if arg_index >= len(resolved_args): + break + if self._is_device_arg(arg_input.get(arg_name)): + resolved_args[arg_index] = self._resolve_device(resolved_args[arg_index]) + else: + # Convert normal arg inputs + arg_names = list(signature_annotations.keys()) + for arg_index, arg_name in enumerate(arg_names): + if arg_index >= len(resolved_args): + break + if self._is_device_arg(signature_annotations.get(arg_name)): + resolved_args[arg_index] = self._resolve_device(resolved_args[arg_index]) + + # Convert kwarg inputs + for key, value in resolved_kwargs.items(): + if self._is_device_arg(kwarg_annotations.get(key)): + resolved_kwargs[key] = self._resolve_device(value) + + return resolved_args, resolved_kwargs + + def _is_device_arg(self, arg_type) -> bool: + converted = GUIInput.convert_to_legacy_scan_arg_type(arg_type) + if converted == ScanArgType.DEVICE: + return True + return inspect.isclass(converted) and issubclass(converted, DeviceBase) + + def _resolve_device(self, value): + if isinstance(value, DeviceBase): + return value + return self.device_manager.devices[value] diff --git a/bec_server/bec_server/scan_server/scan_guard.py b/bec_server/bec_server/scan_server/scan_guard.py index 9ffd0cca3..e837df906 100644 --- a/bec_server/bec_server/scan_server/scan_guard.py +++ b/bec_server/bec_server/scan_server/scan_guard.py @@ -304,7 +304,7 @@ def _handle_scan_order_change(self, msg: messages.ScanQueueOrderMessage): queue = self.parent.queue_manager.queues[target_queue] for scan in queue.queue: - if msg.scan_id in scan.queue.scan_id: + if msg.scan_id in scan.scan_id: break else: logger.error(f"Scan {msg.scan_id} not found in queue {target_queue}") diff --git a/bec_server/bec_server/scan_server/scan_gui_models.py b/bec_server/bec_server/scan_server/scan_gui_models.py index 17a398b30..c032d290f 100644 --- a/bec_server/bec_server/scan_server/scan_gui_models.py +++ b/bec_server/bec_server/scan_server/scan_gui_models.py @@ -10,7 +10,7 @@ from bec_lib.device import DeviceBase from bec_lib.signature_serializer import signature_to_dict -from bec_server.scan_server.scans import ScanArgType, ScanBase +from bec_server.scan_server.scans.legacy_scans import ScanArgType, ScanBase context_signature = ContextVar("context_signature") context_docstring = ContextVar("context_docstring") @@ -43,6 +43,10 @@ def convert_to_legacy_scan_arg_type(cls, value): if get_origin(value) is Annotated: value = get_args(value)[0] + origin = get_origin(value) + if origin is not None: + value = origin + if not inspect.isclass(value): return value diff --git a/bec_server/bec_server/scan_server/scan_input_validator.py b/bec_server/bec_server/scan_server/scan_input_validator.py new file mode 100644 index 000000000..11063b2bc --- /dev/null +++ b/bec_server/bec_server/scan_server/scan_input_validator.py @@ -0,0 +1,357 @@ +from __future__ import annotations + +import inspect +from collections.abc import Mapping, Sequence +from types import UnionType +from typing import Annotated, Any, TypeAlias, Union, get_args, get_origin, get_type_hints + +from pydantic import ConfigDict, TypeAdapter, ValidationError + +from bec_lib.device import DeviceBase +from bec_lib.scan_args import ScanArgument + +from .errors import ScanInputValidationError +from .scans.legacy_scans import ScanArgType + +ScanClass: TypeAlias = type[Any] +AnnotationMap: TypeAlias = dict[str, Any] + + +class ScanInputValidator: + """Validate scan inputs against supported scan input annotations. + + The validator checks input types and numeric bounds declared on ``ScanArgument`` + metadata inside ``typing.Annotated`` declarations. It supports both v4 scan + input styles: bundled positional arguments declared through ``scan_cls.arg_input`` + and fixed constructor inputs declared on ``scan_cls.__init__``. + """ + + def validate(self, scan_cls: ScanClass, args: Sequence[Any], kwargs: Mapping[str, Any]) -> None: + """Validate resolved scan inputs for a scan class. + + Args: + scan_cls (ScanClass): Scan class whose input annotations should be used. + args (Sequence[Any]): Positional scan arguments after device-name resolution. + kwargs (Mapping[str, Any]): Keyword scan arguments after device-name resolution. + + Raises: + ScanInputValidationError: If an input has the wrong type or violates a supported + ``ScanArgument`` bound. + """ + self._validate_arg_input_bundle(scan_cls, args) + self._validate_signature_inputs(scan_cls, args, kwargs) + + def _validate_arg_input_bundle(self, scan_cls: ScanClass, args: Sequence[Any]) -> None: + """Validate bundled positional arguments declared by ``scan_cls.arg_input``. + + Args: + scan_cls (ScanClass): Scan class that may define ``arg_input`` bundles. + args (Sequence[Any]): Positional scan arguments after device-name resolution. + + Raises: + ScanInputValidationError: If a bundled input has the wrong type or violates a + supported ``ScanArgument`` bound. + """ + arg_input = getattr(scan_cls, "arg_input", {}) or {} + bundle_size = getattr(scan_cls, "arg_bundle_size", {}).get("bundle", 0) + if not arg_input or bundle_size <= 0: + return + + arg_names = list(arg_input.keys()) + for bundle_start in range(0, len(args), bundle_size): + for offset, arg_name in enumerate(arg_names): + arg_index = bundle_start + offset + if arg_index >= len(args): + break + self._validate_value(arg_name, args[arg_index], arg_input[arg_name]) + + def _validate_signature_inputs( + self, scan_cls: ScanClass, args: Sequence[Any], kwargs: Mapping[str, Any] + ) -> None: + """Validate constructor-annotated positional and keyword inputs. + + Args: + scan_cls (ScanClass): Scan class whose constructor annotations should be used. + args (Sequence[Any]): Positional scan arguments after device-name resolution. + kwargs (Mapping[str, Any]): Keyword scan arguments after device-name resolution. + + Raises: + ScanInputValidationError: If a constructor input has the wrong type or violates a + supported ``ScanArgument`` bound. + """ + signature_annotations = self.scan_signature_annotations(scan_cls) + if not self._uses_arg_input_bundle(scan_cls): + for arg_index, arg_name in enumerate(signature_annotations): + if arg_index >= len(args): + break + self._validate_value(arg_name, args[arg_index], signature_annotations[arg_name]) + + for arg_name, value in kwargs.items(): + if arg_name in signature_annotations: + self._validate_value(arg_name, value, signature_annotations[arg_name]) + + def _uses_arg_input_bundle(self, scan_cls: ScanClass) -> bool: + """Return whether positional arguments are described by ``arg_input`` bundles. + + Args: + scan_cls (ScanClass): Scan class to inspect. + + Returns: + bool: True if the scan class uses bundled positional inputs. + """ + return bool(getattr(scan_cls, "arg_input", {}) or {}) and ( + getattr(scan_cls, "arg_bundle_size", {}).get("bundle", 0) > 0 + ) + + def scan_signature_annotations(self, scan_cls: ScanClass) -> AnnotationMap: + """Return constructor input annotations keyed by argument name. + + ``*args`` and ``**kwargs`` are intentionally excluded because they do not + describe individual user-facing scan inputs. + + Args: + scan_cls (ScanClass): Scan class whose constructor annotations should be inspected. + + Returns: + AnnotationMap: Constructor annotations keyed by argument name. + """ + type_hints = get_type_hints(scan_cls.__init__, include_extras=True) + return { + name: type_hints.get(name, parameter.annotation) + for name, parameter in inspect.signature(scan_cls).parameters.items() + if name not in {"args", "kwargs"} + and parameter.annotation is not inspect.Parameter.empty + } + + def _validate_value(self, arg_name: str, value: Any, annotation: Any) -> None: + """Validate a single input value against the annotated type and bounds. + + Args: + arg_name (str): Name of the scan input being validated. + value (Any): Input value after device-name resolution. + annotation (Any): Type annotation that may contain ``ScanArgument`` metadata. + + Raises: + ScanInputValidationError: If the input has the wrong type or violates a supported + ``ScanArgument`` bound. + """ + self._validate_type(arg_name, value, annotation) + scan_argument = self._scan_argument_from_annotation(annotation) + if scan_argument is None: + return + + for operator_name, limit in [ + ("gt", scan_argument.gt), + ("ge", scan_argument.ge), + ("lt", scan_argument.lt), + ("le", scan_argument.le), + ]: + if limit is None: + continue + if not self._satisfies_bound(value, operator_name, limit): + raise ScanInputValidationError( + f"Invalid value for scan argument '{arg_name}': {value!r}. Input must be " + f"{self._bound_description(operator_name)} {limit!r}." + ) + + def _scan_argument_from_annotation(self, annotation: Any) -> ScanArgument | None: + """Extract ``ScanArgument`` metadata from a ``typing.Annotated`` annotation. + + Args: + annotation (Any): Type annotation to inspect. + + Returns: + ScanArgument | None: Extracted scan argument metadata, if present. + """ + if get_origin(annotation) is not Annotated: + return None + for metadata in get_args(annotation)[1:]: + if isinstance(metadata, ScanArgument): + return metadata + return None + + def _validate_type(self, arg_name: str, value: Any, annotation: Any) -> None: + """Validate a single input value against the type part of an annotation. + + Args: + arg_name (str): Name of the scan input being validated. + value (Any): Input value after device-name resolution. + annotation (Any): Type annotation or legacy ``ScanArgType`` to check. + + Raises: + ScanInputValidationError: If the value does not match the annotated type. + """ + type_annotation = self._type_annotation(annotation) + if type_annotation is Any or type_annotation is inspect.Parameter.empty: + return + + value = self._normalize_tuple_payloads(value, type_annotation) + try: + TypeAdapter( + type_annotation, config=ConfigDict(arbitrary_types_allowed=True, strict=True) + ).validate_python(value) + except ValidationError: + raise ScanInputValidationError( + f"Invalid type for scan argument '{arg_name}': expected " + f"{self._type_description(type_annotation)}, got {type(value).__name__}." + ) from None + + def _type_annotation(self, annotation: Any) -> Any: + """Return the runtime-checkable type annotation for a scan input annotation. + + ``ScanArgument`` metadata is stripped from ``Annotated`` declarations + because bounds are validated explicitly by this component. + + Args: + annotation (Any): Type annotation or legacy ``ScanArgType`` to convert. + + Returns: + Any: Annotation suitable for Pydantic ``TypeAdapter`` validation. + """ + if isinstance(annotation, ScanArgType): + return self._legacy_scan_arg_type_annotation(annotation) + + origin = get_origin(annotation) + if origin is Annotated: + return self._type_annotation(get_args(annotation)[0]) + + return annotation + + def _legacy_scan_arg_type_annotation(self, scan_arg_type: ScanArgType) -> Any: + """Return the Python type represented by a legacy ``ScanArgType``. + + Args: + scan_arg_type (ScanArgType): Legacy scan argument type to convert. + + Returns: + Any: Python type represented by the legacy scan argument type. + """ + return { + ScanArgType.DEVICE: DeviceBase, + ScanArgType.FLOAT: float, + ScanArgType.INT: int, + ScanArgType.BOOL: bool, + ScanArgType.STR: str, + ScanArgType.LIST: list, + ScanArgType.DICT: dict, + }[scan_arg_type] + + def _normalize_tuple_payloads(self, value: Any, annotation: Any) -> Any: + """Convert list payloads to tuples where the annotation expects tuples. + + Scan request payloads are serialized through JSON-like message data where + tuple-shaped inputs arrive as lists. This normalization keeps type validation + strict while preserving the accepted request shape. + + Args: + value (Any): Input value after device-name resolution. + annotation (Any): Runtime-checkable type annotation. + + Returns: + Any: Value with tuple-shaped nested lists converted to tuples. + """ + origin = get_origin(annotation) + args = get_args(annotation) + + if origin is Annotated: + return self._normalize_tuple_payloads(value, self._type_annotation(annotation)) + + if origin in {Union, UnionType}: + return value + + if not args: + return value + + if origin is list and isinstance(value, list): + item_type = args[0] + return [self._normalize_tuple_payloads(item, item_type) for item in value] + + if origin is dict and isinstance(value, dict): + key_type, value_type = args if len(args) == 2 else (Any, Any) + return { + self._normalize_tuple_payloads(key, key_type): self._normalize_tuple_payloads( + item, value_type + ) + for key, item in value.items() + } + + if origin is tuple and isinstance(value, (list, tuple)): + if len(args) == 2 and args[1] is Ellipsis: + return tuple(self._normalize_tuple_payloads(item, args[0]) for item in value) + if len(args) == len(value): + return tuple( + self._normalize_tuple_payloads(item, item_type) + for item, item_type in zip(value, args, strict=True) + ) + + return value + + def _type_description(self, annotation: Any) -> str: + """Return a user-facing type description for an annotation. + + Args: + annotation (Any): Type annotation or legacy ``ScanArgType`` to describe. + + Returns: + str: Human-readable type description. + """ + origin = get_origin(annotation) + if origin is Annotated: + annotation = self._type_annotation(annotation) + origin = get_origin(annotation) + + if origin in {Union, UnionType}: + return " or ".join(self._type_description(arg) for arg in get_args(annotation)) + + if origin is not None: + return str(annotation).replace("typing.", "") + + if annotation is None or annotation is type(None): + return "None" + + return getattr(annotation, "__name__", str(annotation)) + + def _satisfies_bound(self, value: Any, operator_name: str, limit: float) -> bool: + """Return whether ``value`` satisfies the named numeric bound. + + Args: + value (Any): Input value to compare. + operator_name (str): Bound operator name, one of ``gt``, ``ge``, ``lt``, or ``le``. + limit (float): Numeric limit from ``ScanArgument`` metadata. + + Returns: + bool: True if the value satisfies the bound. + + Raises: + ScanInputValidationError: If the value cannot be compared to the limit. + """ + try: + if operator_name == "gt": + return value > limit + if operator_name == "ge": + return value >= limit + if operator_name == "lt": + return value < limit + if operator_name == "le": + return value <= limit + except TypeError as exc: + raise ScanInputValidationError( + f"Invalid value for scan argument: {value!r} cannot be compared to {limit!r}." + ) from exc + return True + + def _bound_description(self, operator_name: str) -> str: + """Return a user-facing description for a bound operator. + + Args: + operator_name (str): Bound operator name, one of ``gt``, ``ge``, ``lt``, or ``le``. + + Returns: + str: Human-readable bound description. + """ + return { + "gt": "greater than", + "ge": "greater than or equal to", + "lt": "less than", + "le": "less than or equal to", + }[operator_name] diff --git a/bec_server/bec_server/scan_server/scan_manager.py b/bec_server/bec_server/scan_server/scan_manager.py index a383bad52..f0531ab69 100644 --- a/bec_server/bec_server/scan_server/scan_manager.py +++ b/bec_server/bec_server/scan_server/scan_manager.py @@ -2,7 +2,9 @@ Scan Manager loads the available scans and publishes them to redis. """ +import importlib import inspect +import pkgutil from bec_lib import plugin_helper from bec_lib.device import DeviceBase @@ -12,7 +14,9 @@ from bec_lib.signature_serializer import serialize_dtype, signature_to_dict from bec_server.scan_server.scan_gui_models import GUIConfig -from . import scans as scans_module +from . import scans as scans_v4_module +from .scans import legacy_scans as scans_module +from .scans.scans_v4 import ScanBase as ScanBaseV4 logger = bec_logger.logger @@ -38,7 +42,7 @@ def __init__(self, *, parent): """ self.parent = parent self.available_scans = {} - self.scan_dict: dict[str, type[scans_module.RequestBase]] = {} + self.scan_dict: dict[str, type[scans_module.RequestBase] | type[ScanBaseV4]] = {} self._plugins = {} self.load_plugins() self.update_available_scans() @@ -50,9 +54,9 @@ def load_plugins(self): if not plugins: return for name, cls in plugins.items(): - if not issubclass(cls, scans_module.RequestBase): + if not issubclass(cls, (scans_module.RequestBase, ScanBaseV4)): logger.error( - f"Plugin {name} is not a valid scan plugin as it does not inherit from RequestBase. Skipping." + f"Plugin {name} is not a valid scan plugin as it does not inherit from RequestBase or ScanBase. Skipping." ) continue self._plugins[name] = cls @@ -63,13 +67,15 @@ def update_available_scans(self): members: list[tuple[str, type]] = inspect.getmembers( scans_module, predicate=inspect.isclass ) + members.extend(self._get_v4_scan_members()) members.extend((name, cls) for name, cls in self._plugins.items() if inspect.isclass(cls)) for name, scan_cls in members: - is_scan = issubclass(scan_cls, scans_module.RequestBase) + is_scan = issubclass(scan_cls, (scans_module.RequestBase, ScanBaseV4)) if not is_scan or not scan_cls.scan_name: logger.debug(f"Ignoring {name}") continue + if scan_cls.scan_name in self.available_scans: logger.error(f"{scan_cls.scan_name} already exists. Skipping.") continue @@ -85,7 +91,8 @@ def update_available_scans(self): for report_cls in report_classes: if issubclass(scan_cls, report_cls): base_cls = report_cls.__name__ - self.scan_dict[scan_cls.__name__] = scan_cls + + self.scan_dict[scan_cls.scan_name] = scan_cls gui_config = self.validate_gui_config(scan_cls) gui_visibility = {} if hasattr(scan_cls, "gui_visibility"): @@ -150,6 +157,22 @@ def convert_arg_input(self, arg_input) -> dict: converted_arg_input[key] = serialize_dtype(dtype) return converted_arg_input + def _get_v4_scan_members(self) -> list[tuple[str, type]]: + """Collect classes from all modules in the scans package.""" + members: list[tuple[str, type]] = [] + for module_info in pkgutil.iter_modules( + scans_v4_module.__path__, prefix=f"{scans_v4_module.__name__}." + ): + if module_info.name == f"{scans_v4_module.__name__}.legacy_scans": + continue + module = importlib.import_module(module_info.name) + members.extend( + (name, cls) + for name, cls in inspect.getmembers(module, predicate=inspect.isclass) + if cls.__module__ == module.__name__ + ) + return members + def publish_available_scans(self): """send all available scans to the broker""" self.parent.connector.set( diff --git a/bec_server/bec_server/scan_server/scan_plugins/otf_scan.py b/bec_server/bec_server/scan_server/scan_plugins/otf_scan.py index 93f3b4e75..b5ef2bc70 100644 --- a/bec_server/bec_server/scan_server/scan_plugins/otf_scan.py +++ b/bec_server/bec_server/scan_server/scan_plugins/otf_scan.py @@ -1,7 +1,7 @@ import time from bec_lib.logger import bec_logger -from bec_server.scan_server.scans import ScanArgType, ScanBase, SyncFlyScanBase +from bec_server.scan_server.scans.legacy_scans import ScanArgType, ScanBase, SyncFlyScanBase logger = bec_logger.logger diff --git a/bec_server/bec_server/scan_server/scan_queue.py b/bec_server/bec_server/scan_server/scan_queue.py index 732846be4..d7022a11e 100644 --- a/bec_server/bec_server/scan_server/scan_queue.py +++ b/bec_server/bec_server/scan_server/scan_queue.py @@ -26,6 +26,8 @@ if TYPE_CHECKING: from bec_server.scan_server.scan_server import ScanServer + from bec_server.scan_server.scan_worker import ScanWorker + from bec_server.scan_server.scans.scans_v4 import ScanBase as ScanBase_v4 def requires_queue(fcn): @@ -250,7 +252,7 @@ def _handle_scan_order_change(self, msg: messages.ScanQueueOrderMessage) -> None def _get_queue_item_by_scan_id( self, msg: messages.ScanQueueOrderMessage - ) -> InstructionQueueItem | None: + ) -> InstructionQueueItem | DirectInstructionQueueItem | None: """ Get the queue item by scan_id. @@ -259,7 +261,7 @@ def _get_queue_item_by_scan_id( """ queue = self.queues[msg.queue] for instruction_queue in queue.queue: - if msg.scan_id in instruction_queue.queue.scan_id: + if msg.scan_id in instruction_queue.scan_id: return instruction_queue return None @@ -359,8 +361,6 @@ def set_abort( with AutoResetCM(que): if que.queue: que.status = ScanQueueStatus.PAUSED - if que.worker_status == InstructionQueueStatus.STOPPED: - return instruction_queue = que.active_instruction_queue if not instruction_queue: return @@ -374,7 +374,7 @@ def set_abort( ) return que.worker_status = InstructionQueueStatus.STOPPED - if instruction_queue.scan_id[-1] is None: + if instruction_queue.scan_id and instruction_queue.scan_id[-1] is None: stop_id = instruction_queue.queue_id else: stop_id = instruction_queue.scan_id @@ -388,7 +388,10 @@ def set_halt( exit_info = ("halted", "user" if user_call else "alarm") instruction_queue = self.queues[queue].active_instruction_queue if instruction_queue: - instruction_queue.return_to_start = False + if isinstance(instruction_queue, DirectInstructionQueueItem): + instruction_queue.run_on_exception_hook = False + else: + instruction_queue.return_to_start = False self.set_abort(scan_id=scan_id, queue=queue, exit_info=exit_info) @requires_queue @@ -501,9 +504,14 @@ def set_release_lock( def _get_active_scan_id(self, queue): if len(self.queues[queue].queue) == 0: return None - if self.queues[queue].queue[0].active_request_block is None: + instr_queue = self.queues[queue].queue[0] + if instr_queue.active_request_block is None: return None - return self.queues[queue].queue[0].active_request_block.scan_id + if isinstance(instr_queue, DirectInstructionQueueItem): + if instr_queue.active_scan is None: + return None + return instr_queue.active_scan.scan_info.scan_id + return instr_queue.active_request_block.scan_id def _wait_for_queue_to_appear_in_history( self, scan_id, queue, timeout=60 @@ -617,20 +625,18 @@ def __init__( self, queue_manager: QueueManager, queue_name="primary", - instruction_queue_item_cls: type[InstructionQueueItem] | None = None, + instruction_queue_item_cls: ( + type[InstructionQueueItem] | type[DirectInstructionQueueItem] | None + ) = None, ) -> None: - self.queue: Deque[InstructionQueueItem] = collections.deque() + self.queue: Deque[InstructionQueueItem | DirectInstructionQueueItem] = collections.deque() self.queue_name = queue_name - self.history_queue: collections.deque[InstructionQueueItem] = collections.deque( - maxlen=self.MAX_HISTORY + self.history_queue: collections.deque[InstructionQueueItem | DirectInstructionQueueItem] = ( + collections.deque(maxlen=self.MAX_HISTORY) ) self.active_instruction_queue = None self.queue_manager = queue_manager - self._instruction_queue_item_cls = ( - instruction_queue_item_cls - if instruction_queue_item_cls is not None - else InstructionQueueItem - ) + self._instruction_queue_item_cls_override = instruction_queue_item_cls # self.open_instruction_queue = None self._status = self.DEFAULT_QUEUE_STATUS self.signal_event = threading.Event() @@ -841,11 +847,21 @@ def insert(self, msg: messages.ScanQueueMessage, position=-1, **_kwargs): queue_exists = True if not queue_exists: # create new queue element (InstructionQueueItem) - instruction_queue = self._instruction_queue_item_cls( + assembler = self.queue_manager.parent.scan_assembler + if assembler.is_direct_scan_message(msg): + iq_class = DirectInstructionQueueItem + else: + iq_class = InstructionQueueItem + iq_class = self._instruction_queue_item_cls_override or iq_class + + instruction_queue = iq_class( parent=self, assembler=self.queue_manager.parent.scan_assembler, worker=self.scan_worker, ) + if instruction_queue is None: + logger.error("Failed to create instruction queue item.") + return instruction_queue.append_scan_request(msg) if not queue_exists: instruction_queue.queue_group = target_group @@ -1337,3 +1353,219 @@ def stop(self): for blck in blcks: # pylint: disable=protected-access blck.scan._shutdown_event.set() + + +class DirectInstructionQueueItem: + """ + An instruction queue item for v4 scans. + """ + + def __init__(self, parent: ScanQueue, assembler: ScanAssembler, worker: ScanWorker) -> None: + self.parent = parent + self.assembler = assembler + self.worker = worker + self.exit_info: ExitInfoType | None = None + self.queue_id = str(uuid.uuid4()) + self.stopped = False + self._scan_id = str(uuid.uuid4()) + self.queue_group = None + self.queue_group_is_closed = False + + self._status = InstructionQueueStatus.PENDING + self._run_on_exception_hook = None + + self.active_scan: ScanBase_v4 | None = None + self.scans: list[ScanBase_v4] = [] + self.scan_msgs: list[messages.ScanQueueMessage] = [] + + @property + def status(self) -> InstructionQueueStatus: + """get the status of the instruction queue item""" + return self._status + + @status.setter + def status(self, val: InstructionQueueStatus) -> None: + """set the status of the instruction queue item and update the worker and queue status accordingly""" + logger.debug( + f"Setting status of direct instruction queue {self.parent.queue_name} to {val.name} from thread {threading.current_thread().name}" + ) + self._status = val + self.worker.status = val + if val == InstructionQueueStatus.STOPPED: + self.stop() + self.parent.queue_manager.send_queue_status() + + @property + def active_request_block(self) -> None | ScanBase_v4: + """there are no request blocks for direct instruction queue items""" + return self.active_scan + + @property + def scan_id(self) -> list[str | None]: + return [scan.scan_info.scan_id for scan in self.scans] + + @property + def is_scan(self) -> list[bool]: + return [scan.scan_info.scan_type is not None for scan in self.scans] + + @property + def scan_number(self) -> list[int | None]: + return [self._get_scan_number(scan) for scan in self.scans] + + def append_scan_request(self, msg: messages.ScanQueueMessage) -> None: + """ + Append a new scan from a scan queue message. The scan will be assembled but not executed until it becomes active. + + Args: + msg (ScanQueueMessage): the scan queue message containing the scan information + """ + scan = self.assembler.assemble_direct_scan(msg, scan_id=self._scan_id) + self.scans.append(scan) + self.scan_msgs.append(msg) + + def set_active(self): + """change the instruction queue status to RUNNING""" + if self.status == InstructionQueueStatus.PENDING: + self.status = InstructionQueueStatus.RUNNING + + @property + def run_on_exception_hook(self) -> bool: + """whether or not to run the direct scan on_exception hook after scan abortion""" + if self._run_on_exception_hook is not None: + return self._run_on_exception_hook + if self.active_scan is not None: + return bool(self.active_scan.scan_info.run_on_exception_hook) + return False + + @run_on_exception_hook.setter + def run_on_exception_hook(self, val: bool): + self._run_on_exception_hook = val + + def describe(self): + """description of the instruction queue""" + request_blocks = self.describe_scans() + content = messages.QueueInfoEntry( + queue_id=self.queue_id, + scan_id=self.scan_id, + is_scan=self.is_scan, + request_blocks=request_blocks, + scan_number=self.scan_number, + status=self.status.name, + active_request_block=self.describe_active_scan(), + ) + return content + + def describe_active_scan(self): + """description of the active scan""" + if self.active_scan is None: + return None + if self.active_scan not in self.scans: + return None + msg = self.scan_msgs[self.scans.index(self.active_scan)] + scan_info = self._get_request_block_message(self.active_scan, msg) + return scan_info + + def describe_scans(self): + """description of the scans in the instruction queue item""" + info = [] + for scan, msg in zip(self.scans, self.scan_msgs): + scan_info = self._get_request_block_message(scan, msg) + info.append(scan_info) + return info + + def _get_request_block_message( + self, scan: ScanBase_v4, msg: messages.ScanQueueMessage + ) -> messages.RequestBlock: + """ + Get the request block message for a given scan and scan queue message + + Args: + scan (ScanBase_v4): the scan for which to get the request block message + msg (ScanQueueMessage): the scan queue message containing the scan information + + Returns: + RequestBlock: the request block message containing the scan information + """ + return messages.RequestBlock( + msg=msg, + RID=msg.metadata["RID"], + scan_motors=scan.scan_info.readout_priority_modification.get("monitored", []), + readout_priority=scan.scan_info.readout_priority_modification, + is_scan=scan.scan_info.scan_type is not None, + scan_number=self._get_scan_number(scan), + scan_id=scan.scan_info.scan_id, + report_instructions=scan.scan_info.scan_report_instructions, + ) + + @property + def _scan_server_scan_number(self) -> int: + return self.parent.queue_manager.parent.scan_number + + def _get_scan_number(self, scan: ScanBase_v4) -> int | None: + if not scan.is_scan: + return None + if scan.scan_info.scan_number is not None: + # We've already assigned a scan number to this scan, return it + return scan.scan_info.scan_number + return self._scan_server_scan_number + self.scan_ids_head(scan) + + def scan_ids_head(self, target_scan: ScanBase_v4) -> int: + """Calculate the scan-number offset for a scan within the current queue.""" + offset = 1 + for queue in self.parent.queue: + if queue.status in [InstructionQueueStatus.COMPLETED, InstructionQueueStatus.RUNNING]: + continue + if queue.queue_id != self.queue_id: + offset += len([scan_id for scan_id in queue.scan_id if scan_id]) + continue + for scan in queue.scans: + if scan is target_scan: + return offset + if scan.scan_info.scan_id: + offset += 1 + return offset + return offset + + def move_to_next_scan(self): + """move to the next scan in the instruction queue item""" + if self.active_scan is None: + if len(self.scans) > 0: + self._set_scan_as_active(self.scans[0]) + return self.active_scan + raise StopIteration("No active scan and no scans in the queue.") + current_index = self.scans.index(self.active_scan) + if current_index + 1 < len(self.scans): + self._set_scan_as_active(self.scans[current_index + 1]) + return self.active_scan + raise StopIteration("No more scans in the queue.") + + def _set_scan_as_active(self, scan: ScanBase_v4): + """set a given scan as the active scan""" + self.active_scan = scan + if scan.scan_info.scan_number is None: + with self.parent.queue_manager._lock: + self.parent.queue_manager.parent.scan_number += 1 + if not self.scan_msgs[self.scans.index(scan)].metadata.get("dataset_id_on_hold"): + self.parent.queue_manager.parent.dataset_number += 1 + scan.scan_info.scan_number = self.parent.queue_manager.parent.scan_number + scan.scan_info.dataset_number = self.parent.queue_manager.parent.dataset_number + self.set_active() + + def append_to_queue_history(self): + """append a new queue item to the redis history buffer""" + msg = messages.ScanQueueHistoryMessage( + status=self.status.name, queue_id=self.queue_id, info=self.describe() + ) + self.parent.queue_manager.connector.lpush( + MessageEndpoints.scan_queue_history(), msg, max_size=100 + ) + + def stop(self): + """stop the instruction queue item and all active scans""" + for scan in self.scans: + scan._shutdown_event.set() + + def abort(self): + self.active_scan = None + self.scans = [] + self.scan_msgs = [] diff --git a/bec_server/bec_server/scan_server/scan_worker.py b/bec_server/bec_server/scan_server/scan_worker.py index ec661da4d..b16f5d103 100644 --- a/bec_server/bec_server/scan_server/scan_worker.py +++ b/bec_server/bec_server/scan_server/scan_worker.py @@ -1,21 +1,17 @@ from __future__ import annotations -import os import threading -import time import traceback -from string import Template -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING from bec_lib import messages from bec_lib.alarm_handler import Alarms -from bec_lib.endpoints import MessageEndpoints -from bec_lib.file_utils import compile_file_components from bec_lib.logger import bec_logger -from .errors import DeviceInstructionError, ScanAbortion, UserScanInterruption -from .scan_queue import InstructionQueueItem, InstructionQueueStatus, RequestBlock -from .scan_stubs import ScanStubStatus +from .direct_scan_worker import DirectScanWorker +from .errors import ScanAbortion +from .generator_scan_worker import GeneratorScanWorker +from .scan_queue import DirectInstructionQueueItem, InstructionQueueItem, InstructionQueueStatus logger = bec_logger.logger @@ -37,546 +33,41 @@ def __init__(self, *, parent: ScanServer, queue_name: str = "primary"): self.connector = self.parent.connector self.status = InstructionQueueStatus.IDLE self.signal_event = threading.Event() - self.scan_id = None - self.readout_priority = {} - self.scan_type = None - self.current_scan_id: str = "" - self.current_scan_info = None - self.max_point_id = 0 - self._exposure_time = None - self.current_instruction_queue_item: InstructionQueueItem | None = None - self.interception_msg = None - self.reset() + self.current_instruction_queue_item: ( + InstructionQueueItem | DirectInstructionQueueItem | None + ) = None - def open_scan(self, instr: messages.DeviceInstructionMessage) -> None: + def get_worker_for_queue( + self, queue: InstructionQueueItem + ) -> GeneratorScanWorker | DirectScanWorker: """ - Open a new scan and emit a scan status message. + Get the appropriate worker for the given queue. For now, we only have one worker type, + but this is where we will extend the functionality to support also direct ScanWorkers + that do not use the generator pattern and instead send instructions to the device server directly. - Args: - instr (DeviceInstructionMessage): Device instruction received from the scan assembler - - """ - if not self.scan_id: - self.scan_id = instr.metadata.get("scan_id") - self.readout_priority = instr.content["parameter"].get("readout_priority", {}) - self.scan_type = instr.content["parameter"].get("scan_type") - - if not instr.metadata.get("scan_def_id"): - self.max_point_id = 0 - instr_num_points = instr.content["parameter"].get("num_points", 0) - if instr_num_points is None: - instr_num_points = 0 - num_points = self.max_point_id + instr_num_points - if self.max_point_id: - num_points += 1 - - active_rb = self.current_instruction_queue_item.active_request_block - - self._initialize_scan_info(active_rb, instr, num_points) - - # only append the scan_progress if the scan is not using device_progress - if active_rb.scan.use_scan_progress_report: - if not self.scan_report_instructions or not self.scan_report_instructions[-1].get( - "device_progress" - ): - self.scan_report_instructions.append( - { - "scan_progress": { - "points": num_points, - "show_table": active_rb.scan.show_live_table, - } - } - ) - self.current_instruction_queue_item.parent.queue_manager.send_queue_status() - - self._send_scan_status("open") - - def close_scan(self, instr: messages.DeviceInstructionMessage, max_point_id: int) -> None: - """ - Close a scan and emit a scan status message. - - Args: - instr (DeviceInstructionMessage): Device instruction received from the scan assembler - max_point_id (int): Maximum point ID of the scan - """ - scan_id = instr.metadata.get("scan_id") - - if self.scan_id != scan_id: - return - - # reset the scan ID now that the scan will be closed - self.scan_id = None - - scan_info = self.current_scan_info - if scan_info.get("scan_type") == "fly": - # flyers do not increase the point_id but instead set the num_points directly - num_points = self.current_instruction_queue_item.active_request_block.scan.num_pos - self.current_scan_info["num_points"] = num_points - - else: - # point_id starts at 0 - scan_info["num_points"] = max_point_id + 1 - - self._send_scan_status("closed") - - def publish_data_as_read(self, instr: messages.DeviceInstructionMessage): - """ - Publish data as read by sending a DeviceMessage to the device_read endpoint. - This instruction replicates the behaviour of the device server when it receives a read instruction. - - Args: - instr (DeviceInstructionMessage): Device instruction received from the scan assembler - """ - connector = self.device_manager.connector - data = instr.content["parameter"]["data"] - devices = instr.content["device"] - if not isinstance(devices, list): - devices = [devices] - if not isinstance(data, list): - data = [data] - for device, dev_data in zip(devices, data): - msg = messages.DeviceMessage(signals=dev_data, metadata=instr.metadata) - connector.set_and_publish(MessageEndpoints.device_read(device), msg) - - def process_scan_report_instruction(self, instr): - """ - Process a scan report instruction by appending it to the scan_report_instructions list. + For now, it simply serves as a factory. Args: - instr (DeviceInstructionMessage): Device instruction received from the scan assembler - - """ - self.scan_report_instructions.append(instr.content["parameter"]) - self.current_instruction_queue_item.parent.queue_manager.send_queue_status() - - def forward_instruction(self, instr: messages.DeviceInstructionMessage) -> None: - """ - Forward an instruction to the device server. - - Args: - instr (DeviceInstructionMessage): Device instruction received from the scan assembler - - """ - self.connector.send(MessageEndpoints.device_instructions(), instr) - - @property - def scan_report_instructions(self): - """ - List of scan report instructions - """ - req_block = self.current_instruction_queue_item.active_request_block - return req_block.scan_report_instructions - - def _wait_for_device_server(self) -> None: - self.parent.wait_for_service("DeviceServer") - - def _check_for_interruption(self) -> None: - if self.status == InstructionQueueStatus.PAUSED: - self._send_scan_status("paused") - while self.status == InstructionQueueStatus.PAUSED: - time.sleep(0.1) - if self.status == InstructionQueueStatus.STOPPED: - item = self.current_instruction_queue_item - if item is None or item.exit_info is None: - raise ScanAbortion() - raise UserScanInterruption(exit_info=item.exit_info) - - def _initialize_scan_info( - self, active_rb: RequestBlock, instr: messages.DeviceInstructionMessage, num_points: int - ): - - metadata = active_rb.metadata - self.current_scan_info = {**instr.metadata, **instr.content["parameter"]} - self.current_scan_info.update(metadata) - self.current_scan_info.update( - { - "scan_number": self.parent.scan_number, - "dataset_number": self.parent.dataset_number, - "exp_time": self._exposure_time, - "frames_per_trigger": active_rb.scan.frames_per_trigger, - "settling_time": active_rb.scan.settling_time, - "readout_time": active_rb.scan.readout_time, - "scan_report_devices": active_rb.scan.scan_report_devices, - "monitor_sync": active_rb.scan.monitor_sync, - "num_points": num_points, - "scan_parameters": active_rb.scan.scan_parameters, - "request_inputs": active_rb.scan.request_inputs, - "file_components": compile_file_components( - base_path=self._get_file_base_path(), - scan_nr=self.parent.scan_number, - file_directory=active_rb.scan.scan_parameters["system_config"][ - "file_directory" - ], - user_suffix=active_rb.scan.scan_parameters["system_config"]["file_suffix"], - ), - } - ) - self.current_scan_info["scan_msgs"] = [ - str(scan_msg) for scan_msg in self.current_instruction_queue_item.scan_msgs - ] - self.current_scan_info["args"] = active_rb.scan.parameter["args"] - self.current_scan_info["kwargs"] = active_rb.scan.parameter["kwargs"] - self.current_scan_info["readout_priority"] = { - "monitored": [ - dev.full_name - for dev in self.device_manager.devices.monitored_devices( - readout_priority=self.readout_priority - ) - ], - "baseline": [ - dev.full_name - for dev in self.device_manager.devices.baseline_devices( - readout_priority=self.readout_priority - ) - ], - "async": [ - dev.full_name - for dev in self.device_manager.devices.async_devices( - readout_priority=self.readout_priority - ) - ], - "continuous": [ - dev.full_name - for dev in self.device_manager.devices.continuous_devices( - readout_priority=self.readout_priority - ) - ], - "on_request": [ - dev.full_name - for dev in self.device_manager.devices.on_request_devices( - readout_priority=self.readout_priority - ) - ], - } - - def _get_file_base_path(self) -> str: - """ - Get the file base path for the scan data. The base path can be a string or a template. - If it is a template, the account name will be substituted into the template. - The account name is retrieved from the current account message. - If the account name is not found, an empty string will be used. - """ - current_account_msg = self.connector.get_last(MessageEndpoints.account(), "data") - if current_account_msg: - current_account = current_account_msg.value - if not isinstance(current_account, str): - logger.warning( - f"Account name is not a string: {current_account}. " "Ignoring specified value." - ) - current_account = None - else: - if "/" in current_account: - raise ValueError( - f"Account name cannot contain a slash (/): {current_account}. " - ) - # _ and - are allowed - check_value = current_account.replace("_", "").replace("-", "") - if not check_value.isalnum() or not check_value.isascii(): - raise ValueError( - f"Account name can only contain alphanumeric characters: {current_account}. " - ) - - else: - current_account = None - - # pylint: disable=protected-access - file_base_path = self.parent._service_config.config["file_writer"]["base_path"] - if "$" not in file_base_path: - # we deal with a normal string - if current_account: - return os.path.abspath(os.path.join(file_base_path, current_account)) - # if there is no account, we return the base path with the data folder - return os.path.abspath(file_base_path) - - # we deal with a string template - file_base_path = Template(file_base_path) - - try: - # check if the template is valid - return os.path.abspath(file_base_path.substitute(account=current_account or "")) - except KeyError as exc: - raise ValueError( - f"Invalid template variable: {exc} in the file base path. " - "Please check your service config." - ) from exc - - def _send_scan_status( - self, - status: Literal["open", "paused", "closed", "aborted", "halted", "user_completed"], - reason: Literal["user", "alarm"] | None = None, - ) -> None: - if not self.current_scan_info: - return - current_scan_info_print = self.current_scan_info.copy() - if current_scan_info_print.get("positions", []): - current_scan_info_print["positions"] = "..." - logger.info( - f"New scan status: {self.current_scan_id} / {status} / {current_scan_info_print}" - ) - si = self.current_scan_info - update_fields = [ - "scan_name", - "scan_number", - "session_id", - "dataset_number", - "num_points", - "scan_type", - "scan_report_devices", - "user_metadata", - "readout_priority", - "scan_parameters", - "request_inputs", - ] - update = {k: si.get(k) for k in update_fields if si.get(k) is not None} - msg = messages.ScanStatusMessage( - scan_id=self.current_scan_id, - status=status, - reason=reason, - info=self.current_scan_info, - **update, - ) - if msg.readout_priority != (cur_rp := self.current_scan_info.get("readout_priority")): - raise RuntimeError( - f"Readout priority mismatch: expected {cur_rp}, got {msg.readout_priority}" - ) - expire = None if status in ["open", "paused"] else 1800 - pipe = self.device_manager.connector.pipeline() - self.device_manager.connector.set( - MessageEndpoints.public_scan_info(self.current_scan_id), msg, pipe=pipe, expire=expire - ) - self.device_manager.connector.set_and_publish( - MessageEndpoints.scan_status(), msg, pipe=pipe - ) - pipe.execute() - - def update_instr_with_scan_report(self, instr: messages.DeviceInstructionMessage): - if not self.scan_report_instructions: - return - for scan_report in self.scan_report_instructions: - if "readback" not in scan_report: - continue - readback = scan_report["readback"] - instr_device = ( - instr.content["device"] - if isinstance(instr.content["device"], list) - else [instr.content["device"]] - ) - - if set(readback.get("devices", [])) & set(instr_device): - instr.metadata["response"] = True - - def _get_metadata_for_alarm(self) -> dict: - """ - Get metadata for the alarm to be raised in case of an error. - This includes the scan ID and scan number if available. - + queue (InstructionQueueItem): The instruction queue item for which to get the worker. Returns: - dict: Metadata dictionary with scan ID and scan number. + GeneratorScanWorker: The worker that should be used to process the instructions in the given queue """ - metadata = {} - if not self.current_scan_info: - return metadata - - if self.current_scan_info.get("scan_id"): - metadata["scan_id"] = self.current_scan_info["scan_id"] - if self.current_scan_info.get("scan_number"): - metadata["scan_number"] = self.current_scan_info["scan_number"] - return metadata - - ############################# - # PROCESS INSTRUCTIONS LOOP # - ############################# - - def _init_instruction_loop(self, queue: InstructionQueueItem) -> float | None: - """Get ready to run the process instructions loop, and return the start time if successful.""" - if not queue: - return None - self.current_instruction_queue_item = queue - start = time.time() - self.max_point_id = 0 - # make sure the device server is ready to receive data - self._wait_for_device_server() - queue.is_active = True - return start - - def _propagate_pi_error(self, content: str, error_info: messages.ErrorInfo): - logger.error(content) - self.connector.raise_alarm( - severity=Alarms.MAJOR, info=error_info, metadata=self._get_metadata_for_alarm() - ) - - def _process_instructions(self, queue: InstructionQueueItem) -> None: - """ - Process scan instructions and send DeviceInstructions to OPAAS. - For now this is an in-memory communication. In the future however, - we might want to pass it through a dedicated Kafka topic. - Args: - queue: instruction queue - - Returns: - - """ - if (start := self._init_instruction_loop(queue)) is None: - return - try: - rpc_method = queue.queue.request_blocks[0].scan.stubs._rpc_call - with self.device_manager._rpc_method(rpc_method): - for instr in queue: - self._check_for_interruption() - if instr is None: - continue - self._exposure_time = getattr(queue.active_request_block.scan, "exp_time", None) - self._instruction_step(instr) - except ScanAbortion as exc: - if self.signal_event.is_set(): - return - if queue.stopped or not (queue.return_to_start and queue.active_request_block): - raise exc - queue.stopped = True - try: - cleanup = queue.active_request_block.scan.move_to_start() - rpc_method = queue.active_request_block.scan.stubs._rpc_call - self.status = InstructionQueueStatus.RUNNING - with self.device_manager._rpc_method(rpc_method): - for instr in cleanup: - self._check_for_interruption() - instr.metadata["scan_id"] = queue.queue.active_rb.scan_id - instr.metadata["queue_id"] = queue.queue_id - self._instruction_step(instr) - except DeviceInstructionError as exc_di: - self._propagate_pi_error(traceback.format_exc(), exc_di.error_info) - raise ScanAbortion from exc_di - except Exception as exc_return_to_start: - # if the return_to_start fails, raise the original exception - content = traceback.format_exc() - error_info = messages.ErrorInfo( - error_message=content, - compact_error_message=traceback.format_exc(limit=0), - exception_type=exc_return_to_start.__class__.__name__, - device=None, - ) - self._propagate_pi_error(content, error_info) - raise exc - raise exc - except DeviceInstructionError as exc_di: - self._propagate_pi_error(traceback.format_exc(), exc_di.error_info) - raise ScanAbortion from exc_di - except Exception as exc: - content = traceback.format_exc() - error_info = messages.ErrorInfo( - error_message=content, - compact_error_message=traceback.format_exc(limit=0), - exception_type=exc.__class__.__name__, - device=None, - ) - self._propagate_pi_error(content, error_info) - raise ScanAbortion from exc - queue.is_active = False - queue.status = InstructionQueueStatus.COMPLETED - self.current_instruction_queue_item = None - - logger.info(f"QUEUE ITEM finished after {time.time()-start:.2f} seconds") - self.reset() - - def _instruction_step(self, instr: messages.DeviceInstructionMessage): - logger.debug(instr) - action = instr.content.get("action") - scan_def_id = instr.metadata.get("scan_def_id") - self.current_scan_id = instr.metadata.get("scan_id", "") - - if "point_id" in instr.metadata: - self.max_point_id = instr.metadata["point_id"] - - logger.debug(f"Device instruction: {instr}") - self._check_for_interruption() - - if action == "open_scan": - self.open_scan(instr) - elif action == "close_scan" and scan_def_id is None: - self.close_scan(instr, self.max_point_id) - elif action == "close_scan" and scan_def_id is not None: - pass - elif action == "open_scan_def": - pass - elif action == "close_scan_def": - self.close_scan(instr, self.max_point_id) - elif action == "publish_data_as_read": - self.publish_data_as_read(instr) - elif action == "scan_report_instruction": - self.process_scan_report_instruction(instr) - elif action == "set": - self.update_instr_with_scan_report(instr) - self.forward_instruction(instr) - elif action in [ - "trigger", - "kickoff", - "complete", - "baseline_reading", - "pre_scan", - "rpc", - "read", - "stage", - "unstage", - ]: - self.forward_instruction(instr) - - else: - raise ValueError(f"Unknown device instruction: {instr}") - - def reset(self): - """reset the scan worker and its member variables""" - self.current_scan_id = "" - self.current_scan_info = {} - self.scan_id = None - self.interception_msg = None - self.current_instruction_queue_item = None - - def cleanup(self): - """perform cleanup instructions""" - status = ScanStubStatus(self.parent.queue_manager.instruction_handler) - staged_devices = [dev.root.name for dev in self.device_manager.devices.enabled_devices] - msg = messages.DeviceInstructionMessage( - device=staged_devices, - action="unstage", - parameter={}, - metadata={"device_instr_id": status._device_instr_id}, - ) - self.forward_instruction(msg) - # status.wait() - - def _handle_scan_abortion(self, queue: InstructionQueueItem, exc: ScanAbortion): - content = traceback.format_exc() - logger.error(content) - - exit_info = None - if isinstance(exc, UserScanInterruption): - exit_info = exc.exit_info - else: - exit_info = queue.exit_info - if exit_info: - self._send_scan_status(exit_info[0], reason=exit_info[1]) - else: - reason = "alarm" - if queue.return_to_start: - self._send_scan_status("aborted", reason=reason) - else: - self._send_scan_status("halted", reason=reason) - logger.info(f"Scan aborted: {queue.queue_id}") - queue.status = InstructionQueueStatus.STOPPED - queue.append_to_queue_history() - self.cleanup() - self.parent.queue_manager.queues[self.queue_name].abort() - self.reset() - self.status = InstructionQueueStatus.RUNNING + if isinstance(queue, DirectInstructionQueueItem): + return DirectScanWorker(worker=self) + return GeneratorScanWorker(worker=self) def run(self): try: while not self.signal_event.is_set(): try: for queue in self.parent.queue_manager.queues[self.queue_name]: - self._process_instructions(queue) if self.signal_event.is_set(): break + if not queue: + continue + self.current_instruction_queue_item = queue + worker = self.get_worker_for_queue(queue) + worker.process_instructions(queue) if not queue.stopped: queue.append_to_queue_history() @@ -584,7 +75,7 @@ def run(self): if not queue: # only for type checker; we should never get here continue - self._handle_scan_abortion(queue, exc) + worker._handle_scan_abortion(queue, exc) # pylint: disable=broad-except except Exception as exc: @@ -597,15 +88,15 @@ def run(self): device=None, ) self.connector.raise_alarm( - severity=Alarms.MAJOR, info=error_info, metadata=self._get_metadata_for_alarm() + severity=Alarms.MAJOR, info=error_info, metadata=worker.get_metadata_for_alarm() ) if self.queue_name in self.parent.queue_manager.queues: self.parent.queue_manager.queues[self.queue_name].abort() - self.reset() + worker.reset() logger.critical(f"Scan worker stopped: {exc}. Unrecoverable error.") def shutdown(self): """shutdown the scan worker""" self.signal_event.set() - if self._started.is_set(): + if self._started.is_set(): # type: ignore ; _started is defined in threading.Thread self.join() diff --git a/bec_server/bec_server/scan_server/scans/__init__.py b/bec_server/bec_server/scan_server/scans/__init__.py new file mode 100644 index 000000000..70cdf11d0 --- /dev/null +++ b/bec_server/bec_server/scan_server/scans/__init__.py @@ -0,0 +1,2 @@ +from .legacy_scans import * +from .scan_modifier import scan_hook diff --git a/bec_server/bec_server/scan_server/scans/acquire.py b/bec_server/bec_server/scan_server/scans/acquire.py new file mode 100644 index 000000000..2332f7e30 --- /dev/null +++ b/bec_server/bec_server/scan_server/scans/acquire.py @@ -0,0 +1,203 @@ +""" +Acquire scan implementation for taking one or more exposures without moving motors. + +Scan procedure: + - prepare_scan + - open_scan + - stage + - pre_scan + - scan_core + - at_each_point (optionally called by scan_core) + - post_scan + - unstage + - close_scan + - on_exception (called if any exception is raised during the scan) +""" + +from __future__ import annotations + +from typing import Annotated + +import numpy as np + +from bec_lib.scan_args import ScanArgument, Units +from bec_server.scan_server.scans.scan_modifier import scan_hook +from bec_server.scan_server.scans.scans_v4 import ScanBase, ScanType + + +class Acquire(ScanBase): + # Scan Type: Hardware triggered or software triggered? + # If the main trigger and readout logic is done within the at_each_point method in scan_core, choose SOFTWARE_TRIGGERED. + # If the main trigger and readout logic is implemented on a device that is simply kicked off in this scan, choose HARDWARE_TRIGGERED. + # This primarily serves as information for devices: The device may need to react differently if a software trigger is expected + # for every point. + scan_type = ScanType.SOFTWARE_TRIGGERED + + # Scan name: This is the name of the scan, e.g. "line_scan". This is used for display purposes and to identify the scan type in user interfaces. + # Choose a descriptive name that does not conflict with existing scan names. + scan_name = "_v4_acquire" + + gui_config = { + "Scan Parameters": [ + "exp_time", + "frames_per_trigger", + "settling_time", + "settling_time_after_trigger", + "readout_time", + "burst_at_each_point", + ] + } + + def __init__( + self, + exp_time: Annotated[ + float, ScanArgument(display_name="Exposure Time", units=Units.s, ge=0) + ] = 0, + frames_per_trigger: Annotated[ + int, ScanArgument(display_name="Frames per Trigger", ge=1) + ] = 1, + settling_time: Annotated[ + float, ScanArgument(display_name="Settling Time", units=Units.s, ge=0) + ] = 0, + settling_time_after_trigger: Annotated[ + float, ScanArgument(display_name="Settling Time After Trigger", units=Units.s, ge=0) + ] = 0, + readout_time: Annotated[ + float, ScanArgument(display_name="Readout Time", units=Units.s, ge=0) + ] = 0, + burst_at_each_point: Annotated[ + int, ScanArgument(display_name="Burst at Each Point", ge=1) + ] = 1, + **kwargs, + ): + """ + A simple acquisition at the current position. + + Args: + exp_time (float): exposure time in seconds. Default is 0. + frames_per_trigger (int): number of frames acquired per trigger. Default is 1. + settling_time (float): settling time before the trigger in seconds. Default is 0. + settling_time_after_trigger (float): settling time after the trigger in seconds. Default is 0. + readout_time (float): readout time after the trigger in seconds. Default is 0. + burst_at_each_point (int): number of acquisitions. Default is 1. + + Returns: + ScanReport + + Examples: + >>> scans.acquire(exp_time=0.1) + + """ + super().__init__(**kwargs) + self.motors = [] + self.exp_time = exp_time + self.frames_per_trigger = frames_per_trigger + self.settling_time = settling_time + self.settling_time_after_trigger = settling_time_after_trigger + self.readout_time = readout_time + self.burst_at_each_point = burst_at_each_point + + # Update the default scan info with provided parameters. + self.update_scan_info( + exp_time=exp_time, + frames_per_trigger=frames_per_trigger, + settling_time=settling_time, + settling_time_after_trigger=settling_time_after_trigger, + readout_time=readout_time, + burst_at_each_point=burst_at_each_point, + ) + + @scan_hook + def prepare_scan(self): + """ + Prepare the scan. This can include any steps that need to be executed + before the scan is opened, such as preparing the positions (if not done already) + or setting up the devices. + """ + self.update_scan_info( + positions=np.array([]), num_points=1, num_monitored_readouts=self.burst_at_each_point + ) + + self.actions.add_scan_report_instruction_scan_progress( + points=self.scan_info.num_monitored_readouts, show_table=False + ) + + self._baseline_readout_status = self.actions.read_baseline_devices(wait=False) + + @scan_hook + def open_scan(self): + """ + Open the scan. + This step must call self.actions.open_scan() to ensure that a new scan is + opened. Make sure to prepare the scan metadata before, either in + prepare_scan() or in open_scan() itself and call self.update_scan_info(...) + to update the scan metadata if needed. + """ + self.actions.open_scan() + + @scan_hook + def stage(self): + """ + Stage the devices for the upcoming scan. The stage logic is typically + implemented on the device itself (i.e. by the device's stage method). + However, if there are any additional steps that need to be executed before + staging the devices, they can be implemented here. + """ + self.actions.stage_all_devices() + + @scan_hook + def pre_scan(self): + """ + Pre-scan steps to be executed before the main scan logic. + This is typically the last chance to prepare the devices before the core scan + logic is executed. For example, this is a good place to initialize time-criticial + devices, e.g. devices that have a short timeout. + The pre-scan logic is typically implemented on the device itself. + """ + self.actions.pre_scan_all_devices() + + @scan_hook + def scan_core(self): + """ + Core scan logic to be executed during the scan. + This is where the main scan logic should be implemented. + """ + for _ in range(self.burst_at_each_point): + self.at_each_point() + + @scan_hook + def at_each_point(self): + """ + Logic to be executed at each acquisition point during the scan. + This hook allows concrete acquire-like scans to extend or override the + per-point behavior without reimplementing the full scan_core method. + """ + self.components.trigger_and_read() + + @scan_hook + def post_scan(self): + """ + Post-scan steps to be executed after the main scan logic. + """ + self.actions.complete_all_devices() + + @scan_hook + def unstage(self): + """Unstage the scan by executing post-scan steps.""" + self.actions.unstage_all_devices() + + @scan_hook + def close_scan(self): + """Close the scan.""" + if self._baseline_readout_status is not None: + self._baseline_readout_status.wait() + self.actions.close_scan() + self.actions.check_for_unchecked_statuses() + + @scan_hook + def on_exception(self, exception: Exception): + """ + Handle exceptions that occur during the scan. + This is a good place to implement any cleanup logic that needs to be executed in case of an exception, + such as returning the devices to a safe state or moving the motors back to their starting position. + """ diff --git a/bec_server/bec_server/scan_server/scans/cont_line_scan.py b/bec_server/bec_server/scan_server/scans/cont_line_scan.py new file mode 100644 index 000000000..cb3d58731 --- /dev/null +++ b/bec_server/bec_server/scan_server/scans/cont_line_scan.py @@ -0,0 +1,305 @@ +""" +Continuous line scan implementation for one motor with software-managed readout. + +Scan procedure: + - prepare_scan + - open_scan + - stage + - pre_scan + - scan_core + - at_each_point (optionally called by scan_core) + - post_scan + - unstage + - close_scan + - on_exception (called if any exception is raised during the scan) +""" + +from __future__ import annotations + +from typing import Annotated + +import numpy as np + +from bec_lib.device import DeviceBase +from bec_lib.scan_args import ScanArgument, Units +from bec_server.scan_server.errors import LimitError, ScanAbortion +from bec_server.scan_server.scans.scan_modifier import scan_hook +from bec_server.scan_server.scans.scans_v4 import ScanBase, ScanType + + +class ContLineScan(ScanBase): + # Scan Type: Hardware triggered or software triggered? + # If the main trigger and readout logic is done within the at_each_point method in scan_core, choose SOFTWARE_TRIGGERED. + # If the main trigger and readout logic is implemented on a device that is simply kicked off in this scan, choose HARDWARE_TRIGGERED. + # This primarily serves as information for devices: The device may need to react differently if a software trigger is expected + # for every point. + scan_type = ScanType.SOFTWARE_TRIGGERED + + # Scan name: This is the name of the scan, e.g. "line_scan". This is used for display purposes and to identify the scan type in user interfaces. + # Choose a descriptive name that does not conflict with existing scan names. + scan_name = "_v4_cont_line_scan" + required_kwargs = ["steps", "relative"] + gui_config = { + "Device": ["device", "start", "stop"], + "Movement Parameters": ["steps", "relative", "offset", "atol"], + "Acquisition Parameters": ["exp_time", "readout_time", "frames_per_trigger"], + } + + def __init__( + self, + device: DeviceBase, + start: Annotated[ + float, ScanArgument(display_name="Start Position", reference_units="device") + ], + stop: Annotated[ + float, ScanArgument(display_name="Stop Position", reference_units="device") + ], + steps: Annotated[int, ScanArgument(display_name="Number of Steps", ge=1)], + offset: Annotated[ + float | None, ScanArgument(display_name="Offset", reference_units="device") + ] = None, + atol: Annotated[ + float | None, ScanArgument(display_name="Tolerance", reference_units="device") + ] = None, + exp_time: Annotated[ + float, ScanArgument(display_name="Exposure Time", units=Units.s, ge=0) + ] = 0, + readout_time: Annotated[ + float, ScanArgument(display_name="Readout Time", units=Units.s, ge=0) + ] = 0, + frames_per_trigger: Annotated[ + int, ScanArgument(display_name="Frames per Trigger", ge=1) + ] = 1, + relative: bool = False, + **kwargs, + ): + """ + A continuous line scan. Use this scan if you want to move a motor continuously + from start to stop position while acquiring data at predefined positions. + + Args: + device (DeviceBase): motor to move continuously + start (float): start position + stop (float): stop position + offset (float | None): optional trigger offset from the nominal positions. + atol (float | None): optional tolerance used for position matching. + exp_time (Annotated[float, Units.s]): exposure time in seconds. Default is 0. + steps (int): number of acquisition points. Default is 10. + relative (bool): if True, interpret start and stop relative to the current motor position. + + Returns: + ScanReport + + Examples: + >>> scans.cont_line_scan(dev.motor1, -5, 5, steps=20, exp_time=0.05, relative=True) + """ + super().__init__(**kwargs) + self.device = device + self.motors = [device] + self.start = start + self.stop = stop + self.offset = offset + self.atol = atol + self.exp_time = exp_time + self.steps = steps + self.relative = relative + self.readout_time = readout_time + self.frames_per_trigger = frames_per_trigger + self.motor_acceleration = None + self.motor_velocity = None + self.dist_step = None + self.time_per_step = None + self._point_index = 0 + + self.update_scan_info( + exp_time=exp_time, + relative=relative, + readout_time=readout_time, + frames_per_trigger=frames_per_trigger, + scan_report_devices=self.motors, + ) + self.actions.set_device_readout_priority(self.motors, priority="monitored") + + @scan_hook + def prepare_scan(self): + """ + Prepare the scan. This can include any steps that need to be executed + before the scan is opened, such as preparing the positions (if not done already) + or setting up the devices. + """ + self._get_motor_attributes() + self.positions = np.linspace(self.start, self.stop, self.steps, dtype=float)[:, np.newaxis] + if self.relative: + self.start_positions = self.components.get_start_positions(self.motors) + self.positions += self.start_positions + self.dist_step = self.positions[1][0] - self.positions[0][0] + self._calculate_offset() + self._calculate_atol() + self.time_per_step = self.dist_step / self.motor_velocity + if self.time_per_step < self.exp_time: + raise ScanAbortion( + f"Motor {self.device} is moving too fast. Time per step: {self.time_per_step:.03f} < Exp_time: {self.exp_time:.03f}. Consider reducing speed {self.motor_velocity} or reducing exp_time {self.exp_time}" + ) + self._check_continuous_limits() + + self.update_scan_info( + positions=self.positions, + num_points=len(self.positions), + num_monitored_readouts=len(self.positions), + ) + + self.actions.add_scan_report_instruction_scan_progress( + points=self.scan_info.num_monitored_readouts, show_table=False + ) + + self._baseline_readout_status = self.actions.read_baseline_devices(wait=False) + + # Pre-move the motor to the start position + self._premove_motor_status = self.actions.set( + self.device, self.positions[0][0] - self.offset, wait=False + ) + + @scan_hook + def open_scan(self): + """ + Open the scan. + This step must call self.actions.open_scan() to ensure that a new scan is + opened. Make sure to prepare the scan metadata before, either in + prepare_scan() or in open_scan() itself and call self.update_scan_info(...) + to update the scan metadata if needed. + """ + self.actions.open_scan() + + @scan_hook + def stage(self): + """ + Stage the devices for the upcoming scan. The stage logic is typically + implemented on the device itself (i.e. by the device's stage method). + However, if there are any additional steps that need to be executed before + staging the devices, they can be implemented here. + """ + self.actions.stage_all_devices() + + @scan_hook + def pre_scan(self): + """ + Pre-scan steps to be executed before the main scan logic. + This is typically the last chance to prepare the devices before the core scan + logic is executed. For example, this is a good place to initialize time-criticial + devices, e.g. devices that have a short timeout. + The pre-scan logic is typically implemented on the device itself. + """ + self._premove_motor_status.wait() + self.actions.pre_scan_all_devices() + + @scan_hook + def scan_core(self): + """ + Core scan logic to be executed during the scan. + This is where the main scan logic should be implemented. + """ + self.actions.set(self.device, self.positions[0][0] - self.offset, wait=True) + status = self.actions.set(self.device, self.positions[-1][0], wait=False) + + while self._point_index < len(self.positions): + cont_motor_positions = self.device.read(cached=True) + if not cont_motor_positions: + continue + cont_motor_position = cont_motor_positions[self.device.full_name].get("value") + target_position = self.positions[self._point_index][0] + if np.isclose(cont_motor_position, target_position, atol=self.atol): + self.at_each_point() + self._point_index += 1 + continue + if cont_motor_position > target_position: + raise ScanAbortion( + f"Skipped point {self._point_index + 1}: Consider reducing speed {self.motor_velocity}, increasing the atol {self.atol}, or increasing the offset {self.offset}" + ) + status.wait() + + @scan_hook + def at_each_point(self): + """ + Logic to be executed at each acquisition point during the scan. + This hook allows concrete continuous-line variants to extend or override the + per-point behavior without reimplementing the full scan_core method. + """ + self.components.trigger_and_read() + + @scan_hook + def post_scan(self): + """ + Post-scan steps to be executed after the main scan logic. + """ + status = self.actions.complete_all_devices(wait=False) + if self.relative: + # Move the motors back to their starting position + self.components.move_and_wait(self.motors, self.start_positions) + status.wait() + + @scan_hook + def unstage(self): + """Unstage the scan by executing post-scan steps.""" + self.actions.unstage_all_devices() + + @scan_hook + def close_scan(self): + """Close the scan.""" + if self._baseline_readout_status is not None: + self._baseline_readout_status.wait() + self.actions.close_scan() + self.actions.check_for_unchecked_statuses() + + @scan_hook + def on_exception(self, exception: Exception): + """ + Handle exceptions that occur during the scan. + This is a good place to implement any cleanup logic that needs to be executed in case of an exception, + such as returning the devices to a safe state or moving the motors back to their starting position. + """ + if self.relative: + # Move the motors back to their starting position + self.components.move_and_wait(self.motors, self.start_positions) + + ####################################################### + ######### Helper methods for the scan logic ########### + ####################################################### + + def _get_motor_attributes(self): + if not hasattr(self.device, "velocity"): + raise ScanAbortion(f"Motor {self.device} does not have a velocity attribute.") + if not hasattr(self.device, "acceleration"): + raise ScanAbortion(f"Motor {self.device} does not have an acceleration attribute.") + self.motor_velocity = self.device.velocity.get() + self.motor_acceleration = self.device.acceleration.get() + + def _calculate_offset(self): + if self.offset is not None: + return + self.offset = 0.5 * self.motor_acceleration * self.motor_velocity + + def _calculate_atol(self): + update_freq = 10 + tolerance = 0.1 + precision = 10 ** (-self.device.precision) + if self.atol is not None: + return + self.atol = tolerance * self.motor_velocity * self.exp_time + self.atol = max(self.atol, 2 * precision) + if self.atol / update_freq > self.motor_velocity: + raise ScanAbortion( + f"Motor {self.device} is moving too fast with the calculated tolerance. Consider reducing speed {self.motor_velocity} or increasing the atol {self.atol}" + ) + self.atol = max(self.atol, 2 * 1 / update_freq * self.motor_velocity) + + def _check_continuous_limits(self): + low_limit, high_limit = self.device.limits + if low_limit >= high_limit: + return + for ii, pos in enumerate(self.positions): + pos_axis = pos[0] - self.offset if ii == 0 else pos[0] + if not low_limit <= pos_axis <= high_limit: + raise LimitError( + f"Target position including offset {pos_axis} (offset: {self.offset}) for motor {self.device} is outside of range: [{low_limit}, {high_limit}]", + device=self.device.name, + ) diff --git a/bec_server/bec_server/scan_server/scans/fermat_scan.py b/bec_server/bec_server/scan_server/scans/fermat_scan.py new file mode 100644 index 000000000..bcee9b8d7 --- /dev/null +++ b/bec_server/bec_server/scan_server/scans/fermat_scan.py @@ -0,0 +1,304 @@ +""" +Fermat spiral scan implementation for two-motor area scans. + +Scan procedure: + - prepare_scan + - open_scan + - stage + - pre_scan + - scan_core + - at_each_point (optionally called by scan_core) + - post_scan + - unstage + - close_scan + - on_exception (called if any exception is raised during the scan) +""" + +from __future__ import annotations + +from typing import Annotated, Literal + +import numpy as np + +from bec_lib.device import DeviceBase +from bec_lib.logger import bec_logger +from bec_lib.scan_args import ScanArgument, Units +from bec_server.scan_server.scans import position_generators +from bec_server.scan_server.scans.scan_modifier import scan_hook +from bec_server.scan_server.scans.scans_v4 import ScanBase, ScanType + +logger = bec_logger.logger + + +class FermatSpiralScan(ScanBase): + + # Scan Type: Hardware triggered or software triggered? + # If the main trigger and readout logic is done within the at_each_point method in scan_core, choose SOFTWARE_TRIGGERED. + # If the main trigger and readout logic is implemented on a device that is simply kicked off in this scan, choose HARDWARE_TRIGGERED. + # This primarily serves as information for devices: The device may need to react differently if a software trigger is expected + # for every point. + scan_type = ScanType.SOFTWARE_TRIGGERED + + # Scan name: This is the name of the scan, e.g. "line_scan". This is used for display purposes and to identify the scan type in user interfaces. + # Choose a descriptive name that does not conflict with existing scan names. + scan_name = "_v4_fermat_scan" + + required_kwargs = ["relative"] + + gui_config = { + "Device 1": ["motor1", "start_motor1", "stop_motor1"], + "Device 2": ["motor2", "start_motor2", "stop_motor2"], + "Movement Parameters": ["step", "spiral_type", "relative", "optim_trajectory"], + "Acquisition Parameters": [ + "exp_time", + "frames_per_trigger", + "settling_time", + "settling_time_after_trigger", + "readout_time", + "burst_at_each_point", + ], + } + + def __init__( + self, + motor1: DeviceBase, + start_motor1: Annotated[ + float, ScanArgument(display_name="Start Position", reference_units="motor1") + ], + stop_motor1: Annotated[ + float, ScanArgument(display_name="Stop Position", reference_units="motor1") + ], + motor2: DeviceBase, + start_motor2: Annotated[ + float, ScanArgument(display_name="Start Position", reference_units="motor2") + ], + stop_motor2: Annotated[ + float, ScanArgument(display_name="Stop Position", reference_units="motor2") + ], + step: Annotated[ + float, ScanArgument(display_name="Step Size", reference_units="motor1") + ] = 0.1, + exp_time: Annotated[ + float, ScanArgument(display_name="Exposure Time", units=Units.s, ge=0) + ] = 0, + frames_per_trigger: Annotated[ + int, ScanArgument(display_name="Frames per Trigger", ge=1) + ] = 1, + settling_time: Annotated[ + float, ScanArgument(display_name="Settling Time", units=Units.s, ge=0) + ] = 0, + settling_time_after_trigger: Annotated[ + float, ScanArgument(display_name="Settling Time After Trigger", units=Units.s, ge=0) + ] = 0, + readout_time: Annotated[ + float, ScanArgument(display_name="Readout Time", units=Units.s, ge=0) + ] = 0, + relative: bool = False, + spiral_type: Annotated[ + float, + ScanArgument( + display_name="Spiral Type", + description="Angular offset in radians that determines the shape of the spiral", + ge=0, + le=2, + units=Units.rad, + ), + ] = 0, + optim_trajectory: Annotated[ + Literal["corridor", "shell", "nearest", None], + ScanArgument( + display_name="Trajectory Optimization Method", + description="Method for optimizing the scan trajectory", + ), + ] = None, + burst_at_each_point: Annotated[ + int, ScanArgument(display_name="Burst at Each Point", ge=1) + ] = 1, + **kwargs, + ): + """ + A scan following Fermat's spiral. + + Args: + motor1 (DeviceBase): first motor + start_motor1 (float): start position motor 1 + stop_motor1 (float): end position motor 1 + motor2 (DeviceBase): second motor + start_motor2 (float): start position motor 2 + stop_motor2 (float): end position motor 2 + step (float): step size in motor units. Default is 0.1. + exp_time (float): exposure time in seconds. Default is 0. + frames_per_trigger (int): number of frames acquired per trigger. Default is 1. + settling_time (float): settling time in seconds. Default is 0. + settling_time_after_trigger (float): settling time after trigger in seconds. Default is 0. + readout_time (float): readout time in seconds. Default is 0. + relative (bool): if True, the motors will be moved relative to their current position. + burst_at_each_point (int): number of exposures at each point. Default is 1. + spiral_type (float): Angular offset (e.g. 0, 0.25,... ) in radians that determines the shape of the spiral. Default is 0. + optim_trajectory (str): trajectory optimization method. Default is None. Options are "corridor", "shell", "nearest". + + Returns: + ScanReport + + Examples: + >>> scans.fermat_scan(dev.motor1, -5, 5, dev.motor2, -5, 5, step=0.5, exp_time=0.1, relative=True, optim_trajectory="corridor") + + """ + super().__init__(**kwargs) + self.motors = [motor1, motor2] + self.relative = relative + self.motor1_start_stop = (start_motor1, stop_motor1) + self.motor2_start_stop = (start_motor2, stop_motor2) + self.step = step + self.spiral_type = spiral_type + self.optim_trajectory = optim_trajectory + self.burst_at_each_point = burst_at_each_point + + # Update the default scan info with provided parameters. + self.update_scan_info( + exp_time=exp_time, + frames_per_trigger=frames_per_trigger, + settling_time=settling_time, + settling_time_after_trigger=settling_time_after_trigger, + readout_time=readout_time, + relative=relative, + burst_at_each_point=burst_at_each_point, + spiral_type=spiral_type, + optim_trajectory=optim_trajectory, + scan_report_devices=self.motors, + ) + + # We elevate the readout priority of the scan motors to "monitored" to ensure + # that their positions are included in every readout of the step scan. + self.actions.set_device_readout_priority(self.motors, priority="monitored") + + @scan_hook + def prepare_scan(self): + """ + Prepare the scan. This can include any steps that need to be executed + before the scan is opened, such as preparing the positions (if not done already) + or setting up the devices. + """ + self.positions = position_generators.fermat_spiral_pos( + m1_start=self.motor1_start_stop[0], + m1_stop=self.motor1_start_stop[1], + m2_start=self.motor2_start_stop[0], + m2_stop=self.motor2_start_stop[1], + step=self.step, + spiral_type=self.spiral_type, + ) + + if self.relative: + self.start_positions = self.components.get_start_positions(self.motors) + self.positions += self.start_positions + + self.components.check_limits(self.motors, self.positions) + + self.update_scan_info( + positions=self.positions, + num_points=len(self.positions), + num_monitored_readouts=len(self.positions) * self.burst_at_each_point, + ) + + self.actions.add_scan_report_instruction_scan_progress( + points=self.scan_info.num_monitored_readouts, show_table=False + ) + + self._baseline_readout_status = self.actions.read_baseline_devices(wait=False) + self._premove_motor_status = self.actions.set(self.motors, self.positions[0], wait=False) + + @scan_hook + def open_scan(self): + """ + Open the scan. + This step must call self.actions.open_scan() to ensure that a new scan is + opened. Make sure to prepare the scan metadata before, either in + prepare_scan() or in open_scan() itself and call self.update_scan_info(...) + to update the scan metadata if needed. + """ + self.actions.open_scan() + + @scan_hook + def stage(self): + """ + Stage the devices for the upcoming scan. The stage logic is typically + implemented on the device itself (i.e. by the device's stage method). + However, if there are any additional steps that need to be executed before + staging the devices, they can be implemented here. + """ + self.actions.stage_all_devices() + + @scan_hook + def pre_scan(self): + """ + Pre-scan steps to be executed before the main scan logic. + This is typically the last chance to prepare the devices before the core scan + logic is executed. For example, this is a good place to initialize time-criticial + devices, e.g. devices that have a short timeout. + The pre-scan logic is typically implemented on the device itself. + """ + self._premove_motor_status.wait() + self.actions.pre_scan_all_devices() + + @scan_hook + def scan_core(self): + """ + Core scan logic to be executed during the scan. + This is where the main scan logic should be implemented. + """ + self.components.step_scan( + self.motors, + self.positions, + at_each_point=self.at_each_point, + last_positions=self.positions[0], + ) + + @scan_hook + def at_each_point( + self, motors: list[str], positions: np.ndarray, last_positions: np.ndarray | None + ): + """ + Logic to be executed at each point during the scan. This is called by the step_scan method at each point. + + Args: + motors (list[str | DeviceBase]): List of motor names or device instances being moved. + positions (np.ndarray): Current positions of the motors, shape (len(motors),). + last_positions (np.ndarray | None): Previous positions of the motors, shape (len(motors),) or None if this is the first point. + """ + self.components.step_scan_at_each_point(motors, positions, last_positions=last_positions) + + @scan_hook + def post_scan(self): + """ + Post-scan steps to be executed after the main scan logic. + """ + status = self.actions.complete_all_devices(wait=False) + + if self.relative: + # Move the motors back to their starting position + self.components.move_and_wait(self.motors, self.start_positions) + status.wait() + + @scan_hook + def unstage(self): + """Unstage the scan by executing post-scan steps.""" + self.actions.unstage_all_devices() + + @scan_hook + def close_scan(self): + """Close the scan.""" + if self._baseline_readout_status is not None: + self._baseline_readout_status.wait() + self.actions.close_scan() + self.actions.check_for_unchecked_statuses() + + @scan_hook + def on_exception(self, exception: Exception): + """ + Handle exceptions that occur during the scan. + This is a good place to implement any cleanup logic that needs to be executed in case of an exception, + such as returning the devices to a safe state or moving the motors back to their starting position. + """ + if self.relative: + # Move the motors back to their starting position + self.components.move_and_wait(self.motors, self.start_positions) diff --git a/bec_server/bec_server/scan_server/scans/grid_scan.py b/bec_server/bec_server/scan_server/scans/grid_scan.py new file mode 100644 index 000000000..705c5b4d9 --- /dev/null +++ b/bec_server/bec_server/scan_server/scans/grid_scan.py @@ -0,0 +1,273 @@ +""" +Grid scan implementation for rectilinear two-dimensional step scans. + +Scan procedure: + - prepare_scan + - open_scan + - stage + - pre_scan + - scan_core + - at_each_point (optionally called by scan_core) + - post_scan + - unstage + - close_scan + - on_exception (called if any exception is raised during the scan) +""" + +from __future__ import annotations + +from typing import Annotated + +import numpy as np + +from bec_lib.device import DeviceBase +from bec_lib.logger import bec_logger +from bec_lib.scan_args import ScanArgument, Units +from bec_server.scan_server.scans import position_generators +from bec_server.scan_server.scans.scan_modifier import scan_hook +from bec_server.scan_server.scans.scans_v4 import ScanBase, ScanType, bundle_args + +logger = bec_logger.logger + + +class GridScan(ScanBase): + + # Scan Type: Hardware triggered or software triggered? + # If the main trigger and readout logic is done within the at_each_point method in scan_core, choose SOFTWARE_TRIGGERED. + # If the main trigger and readout logic is implemented on a device that is simply kicked off in this scan, choose HARDWARE_TRIGGERED. + # This primarily serves as information for devices: The device may need to react differently if a software trigger is expected + # for every point. + scan_type = ScanType.SOFTWARE_TRIGGERED + + # Scan name: This is the name of the scan, e.g. "line_scan". This is used for display purposes and to identify the scan type in user interfaces. + # Choose a descriptive name that does not conflict with existing scan names. + scan_name = "_v4_grid_scan" + + # arg_input and arg_bundle_size are only relevant for scans that accept an arbitrary number of motor / position arguments (e.g. line scans, grid scans). + # For scans with a fixed set of parameters (e.g. Fermat spiral), these can be simply removed. + arg_input = { + "device": DeviceBase, + "start": Annotated[ + float, ScanArgument(display_name="Start Position", reference_units="device") + ], + "stop": Annotated[ + float, ScanArgument(display_name="Stop Position", reference_units="device") + ], + "steps": Annotated[int, ScanArgument(display_name="Number of Steps", ge=1)], + } + arg_bundle_size = {"bundle": len(arg_input), "min": 2, "max": None} + required_kwargs = ["relative"] + + gui_config = { + "Movement Parameters": ["relative", "snaked"], + "Acquisition Parameters": [ + "exp_time", + "frames_per_trigger", + "settling_time", + "settling_time_after_trigger", + "readout_time", + "burst_at_each_point", + ], + } + + def __init__( + self, + *args, + exp_time: Annotated[ + float, ScanArgument(display_name="Exposure Time", units=Units.s, ge=0) + ] = 0, + frames_per_trigger: Annotated[ + int, ScanArgument(display_name="Frames per Trigger", ge=1) + ] = 1, + settling_time: Annotated[ + float, ScanArgument(display_name="Settling Time", units=Units.s, ge=0) + ] = 0, + settling_time_after_trigger: Annotated[ + float, ScanArgument(display_name="Settling Time After Trigger", units=Units.s, ge=0) + ] = 0, + readout_time: Annotated[ + float, ScanArgument(display_name="Readout Time", units=Units.s, ge=0) + ] = 0, + relative: bool = False, + snaked: bool = True, + burst_at_each_point: Annotated[ + int, ScanArgument(display_name="Burst at Each Point", ge=1) + ] = 1, + **kwargs, + ): + """ + Scan two or more motors in a grid. + + Args: + *args (Device, float, float, int): pairs of device / start / stop / steps arguments + exp_time (Annotated[float, Units.s]): exposure time in seconds. Default is 0. + frames_per_trigger (int): number of frames acquired per trigger. Default is 1. + settling_time (Annotated[float, Units.s]): settling time in seconds. Default is 0. + settling_time_after_trigger (Annotated[float, Units.s]): settling time after trigger in seconds. Default is 0. + readout_time (Annotated[float, Units.s]): readout time in seconds. Default is 0. + relative (bool): if True, the motors will be moved relative to their current position. Default is False. + burst_at_each_point (int): number of exposures at each point. Default is 1. + snaked (bool): if True, the scan will be snaked. Default is True. + + Returns: + ScanReport + + Examples: + >>> scans.grid_scan(dev.motor1, -5, 5, 10, dev.motor2, -5, 5, 10, exp_time=0.1, relative=True) + + """ + super().__init__(**kwargs) + self.motor_args = args + self.motor_input_bundles = bundle_args(args, bundle_size=self.arg_bundle_size["bundle"]) + self.motors = list(self.motor_input_bundles.keys()) + self.exp_time = exp_time + self.settling_time = settling_time + self.relative = relative + self.snaked = snaked + self.burst_at_each_point = burst_at_each_point + + # Update the default scan info with provided parameters. + self.update_scan_info( + exp_time=exp_time, + frames_per_trigger=frames_per_trigger, + settling_time=settling_time, + settling_time_after_trigger=settling_time_after_trigger, + readout_time=readout_time, + relative=relative, + snaked=snaked, + burst_at_each_point=burst_at_each_point, + scan_report_devices=self.motors, + ) + + # We elevate the readout priority of the scan motors to "monitored" to ensure + # that their positions are included in every readout of the step scan. + self.actions.set_device_readout_priority(self.motors, priority="monitored") + + @scan_hook + def prepare_scan(self): + """ + Prepare the scan. This can include any steps that need to be executed + before the scan is opened, such as preparing the positions (if not done already) + or setting up the devices. + """ + self.positions = position_generators.nd_grid_positions( + self.motor_input_bundles.values(), snaked=self.snaked + ) + + if self.relative: + self.start_positions = self.components.get_start_positions(self.motors) + self.positions += self.start_positions + + self.components.check_limits(self.motors, self.positions) + + self.update_scan_info( + positions=self.positions, + num_points=len(self.positions), + num_monitored_readouts=len(self.positions) * self.burst_at_each_point, + ) + + self.actions.add_scan_report_instruction_scan_progress( + points=self.scan_info.num_monitored_readouts, show_table=False + ) + + self._baseline_readout_status = self.actions.read_baseline_devices(wait=False) + + self._premove_motor_status = self.actions.set(self.motors, self.positions[0], wait=False) + + @scan_hook + def open_scan(self): + """ + Open the scan. + This step must call self.actions.open_scan() to ensure that a new scan is + opened. Make sure to prepare the scan metadata before, either in + prepare_scan() or in open_scan() itself and call self.update_scan_info(...) + to update the scan metadata if needed. + """ + self.actions.open_scan() + + @scan_hook + def stage(self): + """ + Stage the devices for the upcoming scan. The stage logic is typically + implemented on the device itself (i.e. by the device's stage method). + However, if there are any additional steps that need to be executed before + staging the devices, they can be implemented here. + """ + self.actions.stage_all_devices() + + @scan_hook + def pre_scan(self): + """ + Pre-scan steps to be executed before the main scan logic. + This is typically the last chance to prepare the devices before the core scan + logic is executed. For example, this is a good place to initialize time-criticial + devices, e.g. devices that have a short timeout. + The pre-scan logic is typically implemented on the device itself. + """ + self._premove_motor_status.wait() + self.actions.pre_scan_all_devices() + + @scan_hook + def scan_core(self): + """ + Core scan logic to be executed during the scan. + This is where the main scan logic should be implemented. + """ + self.components.step_scan( + self.motors, + self.positions, + at_each_point=self.at_each_point, + last_positions=self.positions[0], + ) + + @scan_hook + def at_each_point( + self, + motors: list[str | DeviceBase], + positions: np.ndarray, + last_positions: np.ndarray | None, + ): + """ + Logic to be executed at each point during the scan. This is called by the step_scan method at each point. + + Args: + motors (list[str | DeviceBase]): List of motor names or device instances being moved. + positions (np.ndarray): Current positions of the motors, shape (len(motors),). + last_positions (np.ndarray | None): Previous positions of the motors, shape (len(motors),) or None if this is the first point. + """ + self.components.step_scan_at_each_point(motors, positions, last_positions=last_positions) + + @scan_hook + def post_scan(self): + """ + Post-scan steps to be executed after the main scan logic. + """ + status = self.actions.complete_all_devices(wait=False) + + if self.relative: + # Move the motors back to their starting position + self.components.move_and_wait(self.motors, self.start_positions) + status.wait() + + @scan_hook + def unstage(self): + """Unstage the scan by executing post-scan steps.""" + self.actions.unstage_all_devices() + + @scan_hook + def close_scan(self): + """Close the scan.""" + if self._baseline_readout_status is not None: + self._baseline_readout_status.wait() + self.actions.close_scan() + self.actions.check_for_unchecked_statuses() + + @scan_hook + def on_exception(self, exception: Exception): + """ + Handle exceptions that occur during the scan. + This is a good place to implement any cleanup logic that needs to be executed in case of an exception, + such as returning the devices to a safe state or moving the motors back to their starting position. + """ + if self.relative: + self.components.move_and_wait(self.motors, self.start_positions) diff --git a/bec_server/bec_server/scan_server/scans/hexagonal_scan.py b/bec_server/bec_server/scan_server/scans/hexagonal_scan.py new file mode 100644 index 000000000..4cb59a5ef --- /dev/null +++ b/bec_server/bec_server/scan_server/scans/hexagonal_scan.py @@ -0,0 +1,296 @@ +""" +Hexagonal scan implementation for two-motor area scans on a hexagonal lattice. + +Scan procedure: + - prepare_scan + - open_scan + - stage + - pre_scan + - scan_core + - at_each_point (optionally called by scan_core) + - post_scan + - unstage + - close_scan + - on_exception (called if any exception is raised during the scan) +""" + +from __future__ import annotations + +from typing import Annotated + +import numpy as np + +from bec_lib.device import DeviceBase +from bec_lib.scan_args import ScanArgument, Units +from bec_server.scan_server.scans import position_generators +from bec_server.scan_server.scans.scan_modifier import scan_hook +from bec_server.scan_server.scans.scans_v4 import ScanBase, ScanType + + +class HexagonalScan(ScanBase): + # Scan Type: Hardware triggered or software triggered? + # If the main trigger and readout logic is done within the at_each_point method in scan_core, choose SOFTWARE_TRIGGERED. + # If the main trigger and readout logic is implemented on a device that is simply kicked off in this scan, choose HARDWARE_TRIGGERED. + # This primarily serves as information for devices: The device may need to react differently if a software trigger is expected + # for every point. + scan_type = ScanType.SOFTWARE_TRIGGERED + + # Scan name: This is the name of the scan, e.g. "line_scan". This is used for display purposes and to identify the scan type in user interfaces. + # Choose a descriptive name that does not conflict with existing scan names. + scan_name = "_v4_hexagonal_scan" + required_kwargs = ["relative"] + + gui_config = { + "Device 1": ["motor1", "start_motor1", "stop_motor1", "step_motor1"], + "Device 2": ["motor2", "start_motor2", "stop_motor2", "step_motor2"], + "Movement Parameters": ["relative", "snaked"], + "Acquisition Parameters": [ + "exp_time", + "frames_per_trigger", + "settling_time", + "settling_time_after_trigger", + "readout_time", + "burst_at_each_point", + ], + } + + def __init__( + self, + motor1: DeviceBase, + start_motor1: Annotated[ + float, ScanArgument(display_name="Start Position", reference_units="motor1") + ], + stop_motor1: Annotated[ + float, ScanArgument(display_name="Stop Position", reference_units="motor1") + ], + step_motor1: Annotated[ + float, ScanArgument(display_name="Step Size", reference_units="motor1", gt=0) + ], + motor2: DeviceBase, + start_motor2: Annotated[ + float, ScanArgument(display_name="Start Position", reference_units="motor2") + ], + stop_motor2: Annotated[ + float, ScanArgument(display_name="Stop Position", reference_units="motor2") + ], + step_motor2: Annotated[ + float, ScanArgument(display_name="Step Size", reference_units="motor2", gt=0) + ], + exp_time: Annotated[ + float, ScanArgument(display_name="Exposure Time", units=Units.s, ge=0) + ] = 0, + frames_per_trigger: Annotated[ + int, ScanArgument(display_name="Frames per Trigger", ge=1) + ] = 1, + settling_time: Annotated[ + float, ScanArgument(display_name="Settling Time", units=Units.s, ge=0) + ] = 0, + settling_time_after_trigger: Annotated[ + float, ScanArgument(display_name="Settling Time After Trigger", units=Units.s, ge=0) + ] = 0, + readout_time: Annotated[ + float, ScanArgument(display_name="Readout Time", units=Units.s, ge=0) + ] = 0, + relative: bool = False, + burst_at_each_point: Annotated[ + int, ScanArgument(display_name="Burst at Each Point", ge=1) + ] = 1, + snaked: bool = True, + **kwargs, + ): + """ + Scan two motors in a hexagonal grid pattern. + + Points are arranged in a honeycomb pattern where alternate rows + are offset by half the horizontal step size, providing more uniform + spatial coverage than rectangular grids. + + Args: + motor1 (DeviceBase): first motor + start_motor1 (float): start position of the first motor + stop_motor1 (float): stop position of the first motor + step_motor1 (float): step size of the first motor + motor2 (DeviceBase): second motor + start_motor2 (float): start position of the second motor + stop_motor2 (float): stop position of the second motor + step_motor2 (float): step size of the second motor + exp_time (Annotated[float, Units.s]): exposure time in seconds. Default is 0. + frames_per_trigger (int): number of frames acquired per trigger. Default is 1. + settling_time (Annotated[float, Units.s]): settling time in seconds. Default is 0. + settling_time_after_trigger (Annotated[float, Units.s]): settling time after trigger in seconds. Default is 0. + readout_time (Annotated[float, Units.s]): readout time in seconds. Default is 0. + relative (bool): if True, interpret the scan positions relative to the current motor positions. + burst_at_each_point (int): number of exposures at each point. Default is 1. + snaked (bool): if True, alternate the traversal direction between neighboring rows. + + Returns: + ScanReport + + Examples: + >>> scans.hexagonal_scan(dev.motor1, -5, 5, 1, dev.motor2, -4, 4, 1, exp_time=0.1, relative=True) + """ + super().__init__(**kwargs) + self.motors = [motor1, motor2] + self.start_motor1 = start_motor1 + self.stop_motor1 = stop_motor1 + self.step_motor1 = step_motor1 + self.start_motor2 = start_motor2 + self.stop_motor2 = stop_motor2 + self.step_motor2 = step_motor2 + self.exp_time = exp_time + self.settling_time = settling_time + self.relative = relative + self.burst_at_each_point = burst_at_each_point + self.snaked = snaked + + # Update the default scan info with provided parameters. + self.update_scan_info( + exp_time=exp_time, + frames_per_trigger=frames_per_trigger, + settling_time=settling_time, + settling_time_after_trigger=settling_time_after_trigger, + readout_time=readout_time, + relative=relative, + burst_at_each_point=burst_at_each_point, + snaked=snaked, + scan_report_devices=self.motors, + ) + + # We elevate the readout priority of the scan motors to "monitored" to ensure + # that their positions are included in every readout of the step scan. + self.actions.set_device_readout_priority(self.motors, priority="monitored") + + @scan_hook + def prepare_scan(self): + """ + Prepare the scan. This can include any steps that need to be executed + before the scan is opened, such as preparing the positions (if not done already) + or setting up the devices. + """ + self.positions = position_generators.hex_grid_2d( + [ + (self.start_motor1, self.stop_motor1, self.step_motor1), + (self.start_motor2, self.stop_motor2, self.step_motor2), + ], + snaked=self.snaked, + ) + + if self.relative: + self.start_positions = self.components.get_start_positions(self.motors) + self.positions += self.start_positions + + self.components.check_limits(self.motors, self.positions) + + self.update_scan_info( + positions=self.positions, + num_points=len(self.positions), + num_monitored_readouts=len(self.positions) * self.burst_at_each_point, + ) + + self.actions.add_scan_report_instruction_scan_progress( + points=self.scan_info.num_monitored_readouts, show_table=False + ) + + self._baseline_readout_status = self.actions.read_baseline_devices(wait=False) + + self._premove_motor_status = self.actions.set(self.motors, self.positions[0], wait=False) + + @scan_hook + def open_scan(self): + """ + Open the scan. + This step must call self.actions.open_scan() to ensure that a new scan is + opened. Make sure to prepare the scan metadata before, either in + prepare_scan() or in open_scan() itself and call self.update_scan_info(...) + to update the scan metadata if needed. + """ + self.actions.open_scan() + + @scan_hook + def stage(self): + """ + Stage the devices for the upcoming scan. The stage logic is typically + implemented on the device itself (i.e. by the device's stage method). + However, if there are any additional steps that need to be executed before + staging the devices, they can be implemented here. + """ + self.actions.stage_all_devices() + + @scan_hook + def pre_scan(self): + """ + Pre-scan steps to be executed before the main scan logic. + This is typically the last chance to prepare the devices before the core scan + logic is executed. For example, this is a good place to initialize time-criticial + devices, e.g. devices that have a short timeout. + The pre-scan logic is typically implemented on the device itself. + """ + self._premove_motor_status.wait() + self.actions.pre_scan_all_devices() + + @scan_hook + def scan_core(self): + """ + Core scan logic to be executed during the scan. + This is where the main scan logic should be implemented. + """ + self.components.step_scan( + self.motors, + self.positions, + at_each_point=self.at_each_point, + last_positions=self.positions[0], + ) + + @scan_hook + def at_each_point( + self, + motors: list[str | DeviceBase], + positions: np.ndarray, + last_positions: np.ndarray | None, + ): + """ + Logic to be executed at each point during the scan. This is called by the step_scan method at each point. + + Args: + motors (list[str | DeviceBase]): List of motor names or device instances being moved. + positions (np.ndarray): Current positions of the motors, shape (len(motors),). + last_positions (np.ndarray | None): Previous positions of the motors, shape + (len(motors),) or None if this is the first point. + """ + self.components.step_scan_at_each_point(motors, positions, last_positions=last_positions) + + @scan_hook + def post_scan(self): + """ + Post-scan steps to be executed after the main scan logic. + """ + status = self.actions.complete_all_devices(wait=False) + + if self.relative: + # Move the motors back to their starting position + self.components.move_and_wait(self.motors, self.start_positions) + status.wait() + + @scan_hook + def unstage(self): + """Unstage the scan by executing post-scan steps.""" + self.actions.unstage_all_devices() + + @scan_hook + def close_scan(self): + """Close the scan.""" + if self._baseline_readout_status is not None: + self._baseline_readout_status.wait() + self.actions.close_scan() + self.actions.check_for_unchecked_statuses() + + @scan_hook + def on_exception(self, exception: Exception): + """ + Handle exceptions that occur during the scan. + This is a good place to implement any cleanup logic that needs to be executed in case of an exception, + such as returning the devices to a safe state or moving the motors back to their starting position. + """ + if self.relative: + # Move the motors back to their starting position + self.components.move_and_wait(self.motors, self.start_positions) diff --git a/bec_server/bec_server/scan_server/scans.py b/bec_server/bec_server/scan_server/scans/legacy_scans.py similarity index 99% rename from bec_server/bec_server/scan_server/scans.py rename to bec_server/bec_server/scan_server/scans/legacy_scans.py index 670d2ff69..00abe5c44 100644 --- a/bec_server/bec_server/scan_server/scans.py +++ b/bec_server/bec_server/scan_server/scans/legacy_scans.py @@ -18,9 +18,9 @@ from bec_lib.logger import bec_logger from bec_server.scan_server.instruction_handler import InstructionHandler -from .errors import LimitError, ScanAbortion -from .path_optimization import PathOptimizerMixin -from .scan_stubs import ScanStubs +from ..errors import LimitError, ScanAbortion +from ..path_optimization import PathOptimizerMixin +from ..scan_stubs import ScanStubs logger = bec_logger.logger diff --git a/bec_server/bec_server/scan_server/scans/line_scan.py b/bec_server/bec_server/scan_server/scans/line_scan.py new file mode 100644 index 000000000..b8b12af8d --- /dev/null +++ b/bec_server/bec_server/scan_server/scans/line_scan.py @@ -0,0 +1,266 @@ +""" +Line scan implementation for one or more motors with evenly spaced step positions. + +Scan procedure: + - prepare_scan + - open_scan + - stage + - pre_scan + - scan_core + - at_each_point (optionally called by scan_core) + - post_scan + - unstage + - close_scan + - on_exception (called if any exception is raised during the scan) +""" + +from __future__ import annotations + +from typing import Annotated + +import numpy as np + +from bec_lib.device import DeviceBase +from bec_lib.scan_args import ScanArgument, Units +from bec_server.scan_server.scans import position_generators +from bec_server.scan_server.scans.scan_modifier import scan_hook +from bec_server.scan_server.scans.scans_v4 import ScanBase, ScanType, bundle_args + + +class LineScan(ScanBase): + # Scan Type: Hardware triggered or software triggered? + # If the main trigger and readout logic is done within the at_each_point method in scan_core, choose SOFTWARE_TRIGGERED. + # If the main trigger and readout logic is implemented on a device that is simply kicked off in this scan, choose HARDWARE_TRIGGERED. + # This primarily serves as information for devices: The device may need to react differently if a software trigger is expected + # for every point. + scan_type = ScanType.SOFTWARE_TRIGGERED + + # Scan name: This is the name of the scan, e.g. "line_scan". This is used for display purposes and to identify the scan type in user interfaces. + # Choose a descriptive name that does not conflict with existing scan names. + scan_name = "_v4_line_scan" + + # arg_input and arg_bundle_size are only relevant for scans that accept an arbitrary number of motor / position arguments (e.g. line scans, grid scans). + # For scans with a fixed set of parameters (e.g. Fermat spiral), these can be simply removed. + arg_input = { + "device": DeviceBase, + "start": Annotated[ + float, ScanArgument(display_name="Start Position", reference_units="device") + ], + "stop": Annotated[ + float, ScanArgument(display_name="Stop Position", reference_units="device") + ], + } + arg_bundle_size = {"bundle": len(arg_input), "min": 1, "max": None} + required_kwargs = ["steps", "relative"] + + gui_config = { + "Movement Parameters": ["steps", "relative"], + "Acquisition Parameters": [ + "exp_time", + "frames_per_trigger", + "settling_time", + "settling_time_after_trigger", + "readout_time", + "burst_at_each_point", + ], + } + + def __init__( + self, + *args, + steps: Annotated[int, ScanArgument(display_name="Number of Steps", gt=0)], + exp_time: Annotated[ + float, ScanArgument(display_name="Exposure Time", units=Units.s, ge=0) + ] = 0, + frames_per_trigger: Annotated[ + int, ScanArgument(display_name="Frames per Trigger", ge=1) + ] = 1, + settling_time: Annotated[ + float, ScanArgument(display_name="Settling Time", units=Units.s, ge=0) + ] = 0, + settling_time_after_trigger: Annotated[ + float, ScanArgument(display_name="Settling Time After Trigger", units=Units.s, ge=0) + ] = 0, + readout_time: Annotated[ + float, ScanArgument(display_name="Readout Time", units=Units.s, ge=0) + ] = 0, + relative: bool = False, + burst_at_each_point: Annotated[ + int, ScanArgument(display_name="Burst at Each Point", ge=1) + ] = 1, + **kwargs, + ): + """ + A line scan for one or more motors. + + Args: + *args (Device, float, float): pairs of device / start / stop arguments + steps (int): number of points along the line + exp_time (Annotated[float, Units.s]): exposure time in seconds. Default is 0. + frames_per_trigger (int): number of frames acquired per trigger. Default is 1. + settling_time (Annotated[float, Units.s]): settling time in seconds. Default is 0. + settling_time_after_trigger (Annotated[float, Units.s]): settling time after trigger in seconds. Default is 0. + readout_time (Annotated[float, Units.s]): readout time in seconds. Default is 0. + relative (bool): if True, the motors will be moved relative to their current position. Default is False. + burst_at_each_point (int): number of exposures at each point. Default is 1. + + Returns: + ScanReport + + Examples: + >>> scans.line_scan(dev.motor1, -5, 5, dev.motor2, -5, 5, steps=10, exp_time=0.1, relative=True) + + """ + super().__init__(**kwargs) + self.motor_args = args + self.motor_input_bundles = bundle_args(args, bundle_size=self.arg_bundle_size["bundle"]) + self.motors = list(self.motor_input_bundles.keys()) + self.steps = steps + self.relative = relative + self.exp_time = exp_time + self.settling_time = settling_time + self.burst_at_each_point = burst_at_each_point + + # Update the default scan info with provided parameters. + self.update_scan_info( + exp_time=exp_time, + frames_per_trigger=frames_per_trigger, + settling_time=settling_time, + settling_time_after_trigger=settling_time_after_trigger, + readout_time=readout_time, + relative=relative, + burst_at_each_point=burst_at_each_point, + scan_report_devices=self.motors, + ) + + # We elevate the readout priority of the scan motors to "monitored" to ensure + # that their positions are included in every readout of the step scan. + self.actions.set_device_readout_priority(self.motors, priority="monitored") + + @scan_hook + def prepare_scan(self): + """ + Prepare the scan. This can include any steps that need to be executed + before the scan is opened, such as preparing the positions (if not done already) + or setting up the devices. + """ + self.positions = position_generators.line_scan_positions( + list(self.motor_input_bundles.values()), steps=self.steps + ) + + if self.relative: + self.start_positions = self.components.get_start_positions(self.motors) + self.positions += self.start_positions + + self.components.check_limits(self.motors, self.positions) + + self.update_scan_info( + positions=self.positions, + num_points=len(self.positions), + num_monitored_readouts=len(self.positions) * self.burst_at_each_point, + ) + + self.actions.add_scan_report_instruction_scan_progress( + points=self.scan_info.num_monitored_readouts, show_table=False + ) + self._premove_motor_status = self.actions.set(self.motors, self.positions[0], wait=False) + self._baseline_readout_status = self.actions.read_baseline_devices(wait=False) + + @scan_hook + def open_scan(self): + """ + Open the scan. + This step must call self.actions.open_scan() to ensure that a new scan is + opened. Make sure to prepare the scan metadata before, either in + prepare_scan() or in open_scan() itself and call self.update_scan_info(...) + to update the scan metadata if needed. + """ + self.actions.open_scan() + + @scan_hook + def stage(self): + """ + Stage the devices for the upcoming scan. The stage logic is typically + implemented on the device itself (i.e. by the device's stage method). + However, if there are any additional steps that need to be executed before + staging the devices, they can be implemented here. + """ + self.actions.stage_all_devices() + + @scan_hook + def pre_scan(self): + """ + Pre-scan steps to be executed before the main scan logic. + This is typically the last chance to prepare the devices before the core scan + logic is executed. For example, this is a good place to initialize time-criticial + devices, e.g. devices that have a short timeout. + The pre-scan logic is typically implemented on the device itself. + """ + self._premove_motor_status.wait() + self.actions.pre_scan_all_devices() + + @scan_hook + def scan_core(self): + """ + Core scan logic to be executed during the scan. + This is where the main scan logic should be implemented. + """ + self.components.step_scan( + self.motors, + self.positions, + at_each_point=self.at_each_point, + last_positions=self.positions[0], + ) + + @scan_hook + def at_each_point( + self, + motors: list[str | DeviceBase], + positions: np.ndarray, + last_positions: np.ndarray | None, + ): + """ + Logic to be executed at each point during the scan. This is called by the step_scan method at each point. + + Args: + motors (list[str | DeviceBase]): List of motor names or device instances being moved. + positions (np.ndarray): Current positions of the motors, shape (len(motors),). + last_positions (np.ndarray | None): Previous positions of the motors, shape (len(motors),) or None if this is the first point. + """ + self.components.step_scan_at_each_point(motors, positions, last_positions=last_positions) + + @scan_hook + def post_scan(self): + """ + Post-scan steps to be executed after the main scan logic. + """ + status = self.actions.complete_all_devices(wait=False) + + if self.relative: + # Move the motors back to their starting position + self.components.move_and_wait(self.motors, self.start_positions) + status.wait() + + @scan_hook + def unstage(self): + """Unstage the scan by executing post-scan steps.""" + self.actions.unstage_all_devices() + + @scan_hook + def close_scan(self): + """Close the scan.""" + if self._baseline_readout_status is not None: + self._baseline_readout_status.wait() + self.actions.close_scan() + self.actions.check_for_unchecked_statuses() + + @scan_hook + def on_exception(self, exception: Exception): + """ + Handle exceptions that occur during the scan. + This is a good place to implement any cleanup logic that needs to be executed in case of an exception, + such as returning the devices to a safe state or moving the motors back to their starting position. + """ + if self.relative: + # Move the motors back to their starting position + self.components.move_and_wait(self.motors, self.start_positions) diff --git a/bec_server/bec_server/scan_server/scans/line_sweep_scan.py b/bec_server/bec_server/scan_server/scans/line_sweep_scan.py new file mode 100644 index 000000000..83ff2e1f1 --- /dev/null +++ b/bec_server/bec_server/scan_server/scans/line_sweep_scan.py @@ -0,0 +1,265 @@ +""" +Line sweep scan implementation for acquiring data while observing a moving device. + +Scan procedure: + - prepare_scan + - open_scan + - stage + - pre_scan + - scan_core + - at_each_point (optionally called by scan_core) + - post_scan + - unstage + - close_scan + - on_exception (called if any exception is raised during the scan) +""" + +from __future__ import annotations + +import threading +import time +from typing import Annotated + +import numpy as np + +from bec_lib.connector import MessageObject +from bec_lib.device import DeviceBase +from bec_lib.endpoints import MessageEndpoints +from bec_lib.scan_args import ScanArgument, Units +from bec_server.scan_server.scans.scan_modifier import scan_hook +from bec_server.scan_server.scans.scans_v4 import ScanBase, ScanType + + +class LineSweepScan(ScanBase): + # Scan Type: Hardware triggered or software triggered? + # If the main trigger and readout logic is done within the at_each_point method in scan_core, choose SOFTWARE_TRIGGERED. + # If the main trigger and readout logic is implemented on a device that is simply kicked off in this scan, choose HARDWARE_TRIGGERED. + # This primarily serves as information for devices: The device may need to react differently if a software trigger is expected + # for every point. + scan_type = ScanType.SOFTWARE_TRIGGERED + + # Scan name: This is the name of the scan, e.g. "line_scan". This is used for display purposes and to identify the scan type in user interfaces. + # Choose a descriptive name that does not conflict with existing scan names. + scan_name = "_v4_line_sweep_scan" + required_kwargs = ["relative"] + gui_config = { + "Device": ["device", "start", "stop"], + "Scan Parameters": ["min_update", "relative"], + } + + def __init__( + self, + device: DeviceBase, + start: Annotated[ + float, ScanArgument(display_name="Start Position", reference_units="device") + ], + stop: Annotated[ + float, ScanArgument(display_name="Stop Position", reference_units="device") + ], + exp_time: Annotated[ + float, ScanArgument(display_name="Exposure Time", units=Units.s, ge=0) + ] = 0, + frames_per_trigger: Annotated[ + int, ScanArgument(display_name="Frames per Trigger", ge=1) + ] = 1, + min_update: Annotated[ + float, ScanArgument(display_name="Minimum Update", units=Units.s, ge=0) + ] = 0, + max_update: Annotated[ + float, ScanArgument(display_name="Maximum Update", units=Units.s, ge=0) + ] = 0, + relative: bool = False, + **kwargs, + ): + """ + Read out monitored devices while a single device moves continuously from + start to stop. The readout is triggered by updates on the readback signal of the moving device + and may be throttled by setting a minimum update time. If a maximum update time is set, the readout + will be triggered if the time since the last update exceeds the maximum update time, even if no new update has been received. + + Args: + device (DeviceBase): monitored device + start (float): start position + stop (float): stop position + exp_time (float): exposure time. Default is 0. + frames_per_trigger (int): number of frames per trigger. Default is 1. + min_update (float): minimum delay between readout updates. Default is 0. + max_update (float): maximum delay between readout updates. Default is 0. + relative (bool): if True, the start and stop positions are relative to the current position. Default is False. + + Returns: + ScanReport + + Examples: + >>> scans.line_sweep_scan(dev.motor1, -5, 5, min_update=0.1, relative=True) + + """ + super().__init__(**kwargs) + device = self.dev[device] if isinstance(device, str) else device + self.device = device + self.motors = [device] + self.start = start + self.stop = stop + self.min_update = min_update + self.max_update = max_update + self.relative = relative + self._readback_update_event = threading.Event() + + self.update_scan_info( + exp_time=exp_time, + frames_per_trigger=frames_per_trigger, + relative=relative, + scan_report_devices=self.motors, + ) + self.actions.set_device_readout_priority(self.motors, priority="monitored") + + @scan_hook + def prepare_scan(self): + """ + Prepare the scan. This can include any steps that need to be executed + before the scan is opened, such as preparing the positions (if not done already) + or setting up the devices. + """ + self.positions = np.array([[self.start], [self.stop]], dtype=float) + if self.relative: + self.start_positions = self.components.get_start_positions(self.motors) + self.positions += self.start_positions + self.components.check_limits(self.motors, self.positions) + self.actions.add_scan_report_instruction_scan_progress(points=0, show_table=False) + + self.update_scan_info(positions=self.positions, num_points=0, num_monitored_readouts=0) + self._baseline_readout_status = self.actions.read_baseline_devices(wait=False) + self._premove_motor_status = self.actions.set(self.motors, self.positions[0], wait=False) + + @scan_hook + def open_scan(self): + """ + Open the scan. + This step must call self.actions.open_scan() to ensure that a new scan is + opened. Make sure to prepare the scan metadata before, either in + prepare_scan() or in open_scan() itself and call self.update_scan_info(...) + to update the scan metadata if needed. + """ + self.actions.open_scan() + + @scan_hook + def stage(self): + """ + Stage the devices for the upcoming scan. The stage logic is typically + implemented on the device itself (i.e. by the device's stage method). + However, if there are any additional steps that need to be executed before + staging the devices, they can be implemented here. + """ + self.actions.stage_all_devices() + + @scan_hook + def pre_scan(self): + """ + Pre-scan steps to be executed before the main scan logic. + This is typically the last chance to prepare the devices before the core scan + logic is executed. For example, this is a good place to initialize time-criticial + devices, e.g. devices that have a short timeout. + The pre-scan logic is typically implemented on the device itself. + """ + self._premove_motor_status.wait() + self.actions.pre_scan_all_devices() + + @scan_hook + def scan_core(self): + """ + Core scan logic to be executed during the scan. + This is where the main scan logic should be implemented. + """ + self._register_readback_updates() + try: + status = self.device.set(self.positions[1][0]) + last_update_time = time.time() + + while not status.done: + update_received = self._readback_update_event.wait(timeout=0.05) + + # If no update has been received, check if the + # maximum update time has been exceeded. If not, + # continue waiting for updates. + if not update_received: + if self.max_update <= 0 or (time.time() - last_update_time) <= self.max_update: + continue + # We've exceeded the maximum update time, so pretend we've received an update + self._readback_update_event.set() + + if not self._readback_update_event.is_set(): + continue + self._readback_update_event.clear() + + # Trigger the main readout logic. + self.at_each_point() + + last_update_time = time.time() + if self.min_update: + time.sleep(self.min_update) + finally: + self._unregister_readback_updates() + + @scan_hook + def at_each_point(self): + """ + Logic to be executed at each acquisition point during the scan. + This hook allows concrete line-sweep variants to extend or override the + per-point behavior without reimplementing the full scan_core method. + """ + self.components.trigger_and_read() + + @scan_hook + def post_scan(self): + """ + Post-scan steps to be executed after the main scan logic. + """ + status = self.actions.complete_all_devices(wait=False) + if self.relative: + # Move the motors back to their starting position + self.components.move_and_wait(self.motors, self.start_positions) + status.wait() + + @scan_hook + def unstage(self): + """Unstage the scan by executing post-scan steps.""" + self.actions.unstage_all_devices() + + @scan_hook + def close_scan(self): + """Close the scan.""" + if self._baseline_readout_status is not None: + self._baseline_readout_status.wait() + self.actions.close_scan() + self.actions.check_for_unchecked_statuses() + + @scan_hook + def on_exception(self, exception: Exception): + """ + Handle exceptions that occur during the scan. + This is a good place to implement any cleanup logic that needs to be executed in case of an exception, + such as returning the devices to a safe state or moving the motors back to their starting position. + """ + if self.relative: + # Move the motors back to their starting position + self.components.move_and_wait(self.motors, self.start_positions) + + ####################################################### + ######### Helper methods for the scan logic ########### + ####################################################### + + def _register_readback_updates(self): + self._readback_update_event.clear() + self.redis_connector.register( + MessageEndpoints.device_readback(self.device.root.name), + cb=self._device_readback_callback, + ) + + def _unregister_readback_updates(self): + self.redis_connector.unregister( + MessageEndpoints.device_readback(self.device.root.name), + cb=self._device_readback_callback, + ) + + def _device_readback_callback(self, _msg: MessageObject): + self._readback_update_event.set() diff --git a/bec_server/bec_server/scan_server/scans/list_scan.py b/bec_server/bec_server/scan_server/scans/list_scan.py new file mode 100644 index 000000000..6d0823df7 --- /dev/null +++ b/bec_server/bec_server/scan_server/scans/list_scan.py @@ -0,0 +1,260 @@ +""" +List scan implementation for explicit point-by-point motor trajectories. + +Scan procedure: + - prepare_scan + - open_scan + - stage + - pre_scan + - scan_core + - at_each_point (optionally called by scan_core) + - post_scan + - unstage + - close_scan + - on_exception (called if any exception is raised during the scan) +""" + +from __future__ import annotations + +from typing import Annotated + +import numpy as np + +from bec_lib.device import DeviceBase +from bec_lib.scan_args import ScanArgument, Units +from bec_server.scan_server.scans.scan_modifier import scan_hook +from bec_server.scan_server.scans.scans_v4 import ScanBase, ScanType, bundle_args + + +class ListScan(ScanBase): + # Scan Type: Hardware triggered or software triggered? + # If the main trigger and readout logic is done within the at_each_point method in scan_core, choose SOFTWARE_TRIGGERED. + # If the main trigger and readout logic is implemented on a device that is simply kicked off in this scan, choose HARDWARE_TRIGGERED. + # This primarily serves as information for devices: The device may need to react differently if a software trigger is expected + # for every point. + scan_type = ScanType.SOFTWARE_TRIGGERED + + # Scan name: This is the name of the scan, e.g. "line_scan". This is used for display purposes and to identify the scan type in user interfaces. + # Choose a descriptive name that does not conflict with existing scan names. + scan_name = "_v4_list_scan" + + # arg_input and arg_bundle_size are only relevant for scans that accept an arbitrary number of motor / position arguments (e.g. line scans, grid scans). + # For scans with a fixed set of parameters (e.g. Fermat spiral), these can be simply removed. + arg_input = {"device": DeviceBase, "positions": list[float]} + arg_bundle_size = {"bundle": len(arg_input), "min": 1, "max": None} + required_kwargs = ["relative"] + + gui_config = { + "Movement Parameters": ["relative"], + "Acquisition Parameters": [ + "exp_time", + "frames_per_trigger", + "settling_time", + "settling_time_after_trigger", + "readout_time", + "burst_at_each_point", + ], + } + + def __init__( + self, + *args, + exp_time: Annotated[ + float, ScanArgument(display_name="Exposure Time", units=Units.s, ge=0) + ] = 0, + frames_per_trigger: Annotated[ + int, ScanArgument(display_name="Frames per Trigger", ge=1) + ] = 1, + settling_time: Annotated[ + float, ScanArgument(display_name="Settling Time", units=Units.s, ge=0) + ] = 0, + settling_time_after_trigger: Annotated[ + float, ScanArgument(display_name="Settling Time After Trigger", units=Units.s, ge=0) + ] = 0, + readout_time: Annotated[ + float, ScanArgument(display_name="Readout Time", units=Units.s, ge=0) + ] = 0, + relative: bool = False, + burst_at_each_point: Annotated[ + int, ScanArgument(display_name="Burst at Each Point", ge=1) + ] = 1, + **kwargs, + ): + """ + A scan following the positions specified in a list. + Please note that all lists must be of equal length. + + Args: + *args (Device, list[float]): pairs of device / positions arguments + exp_time (float): exposure time in seconds. Default is 0. + frames_per_trigger (int): number of frames acquired per trigger. Default is 1. + settling_time (float): settling time in seconds. Default is 0. + settling_time_after_trigger (float): settling time after trigger in seconds. Default is 0. + readout_time (float): readout time in seconds. Default is 0. + relative (bool): if True, the positions will be moved relative to their current position. Default is False. + burst_at_each_point (int): number of exposures at each point. Default is 1. + + Returns: + ScanReport + + Examples: + >>> scans.list_scan(dev.motor1, [0, 1, 2], dev.motor2, [4, 3, 2], exp_time=0.1, relative=True) + + """ + super().__init__(**kwargs) + self.motor_args = args + self.motor_input_bundles = bundle_args(args, bundle_size=self.arg_bundle_size["bundle"]) + self.motors = list(self.motor_input_bundles.keys()) + self.relative = relative + self.exp_time = exp_time + self.settling_time = settling_time + self.burst_at_each_point = burst_at_each_point + + lengths = {len(positions[0]) for positions in self.motor_input_bundles.values()} + if len(lengths) > 1: + raise ValueError("All position lists must be of equal length.") + + # Update the default scan info with provided parameters. + self.update_scan_info( + exp_time=exp_time, + frames_per_trigger=frames_per_trigger, + settling_time=settling_time, + settling_time_after_trigger=settling_time_after_trigger, + readout_time=readout_time, + relative=relative, + burst_at_each_point=burst_at_each_point, + scan_report_devices=self.motors, + ) + + # We elevate the readout priority of the scan motors to "monitored" to ensure + # that their positions are included in every readout of the step scan. + self.actions.set_device_readout_priority(self.motors, priority="monitored") + + @scan_hook + def prepare_scan(self): + """ + Prepare the scan. This can include any steps that need to be executed + before the scan is opened, such as preparing the positions (if not done already) + or setting up the devices. + """ + self.positions = np.vstack( + tuple(positions[0] for positions in self.motor_input_bundles.values()) + ).T.astype(float) + + if self.relative: + self.start_positions = self.components.get_start_positions(self.motors) + self.positions += self.start_positions + + self.components.check_limits(self.motors, self.positions) + + self.update_scan_info( + positions=self.positions, + num_points=len(self.positions), + num_monitored_readouts=len(self.positions) * self.burst_at_each_point, + ) + + self.actions.add_scan_report_instruction_scan_progress( + points=self.scan_info.num_monitored_readouts, show_table=False + ) + + self._premove_motor_status = self.actions.set(self.motors, self.positions[0], wait=False) + self._baseline_readout_status = self.actions.read_baseline_devices(wait=False) + + @scan_hook + def open_scan(self): + """ + Open the scan. + This step must call self.actions.open_scan() to ensure that a new scan is + opened. Make sure to prepare the scan metadata before, either in + prepare_scan() or in open_scan() itself and call self.update_scan_info(...) + to update the scan metadata if needed. + """ + self.actions.open_scan() + + @scan_hook + def stage(self): + """ + Stage the devices for the upcoming scan. The stage logic is typically + implemented on the device itself (i.e. by the device's stage method). + However, if there are any additional steps that need to be executed before + staging the devices, they can be implemented here. + """ + self.actions.stage_all_devices() + + @scan_hook + def pre_scan(self): + """ + Pre-scan steps to be executed before the main scan logic. + This is typically the last chance to prepare the devices before the core scan + logic is executed. For example, this is a good place to initialize time-criticial + devices, e.g. devices that have a short timeout. + The pre-scan logic is typically implemented on the device itself. + """ + self._premove_motor_status.wait() + self.actions.pre_scan_all_devices() + + @scan_hook + def scan_core(self): + """ + Core scan logic to be executed during the scan. + This is where the main scan logic should be implemented. + """ + self.components.step_scan( + self.motors, + self.positions, + at_each_point=self.at_each_point, + last_positions=self.positions[0], + ) + + @scan_hook + def at_each_point( + self, + motors: list[str | DeviceBase], + positions: np.ndarray, + last_positions: np.ndarray | None, + ): + """ + Logic to be executed at each point during the scan. This is called by the step_scan method at each point. + + Args: + motors (list[str | DeviceBase]): List of motor names or device instances being moved. + positions (np.ndarray): Current positions of the motors, shape (len(motors),). + last_positions (np.ndarray | None): Previous positions of the motors, shape (len(motors),) or None if this is the first point. + """ + self.components.step_scan_at_each_point(motors, positions, last_positions=last_positions) + + @scan_hook + def post_scan(self): + """ + Post-scan steps to be executed after the main scan logic. + """ + status = self.actions.complete_all_devices(wait=False) + + if self.relative: + # Move the motors back to their starting position + self.components.move_and_wait(self.motors, self.start_positions) + status.wait() + + @scan_hook + def unstage(self): + """Unstage the scan by executing post-scan steps.""" + self.actions.unstage_all_devices() + + @scan_hook + def close_scan(self): + """Close the scan.""" + if self._baseline_readout_status is not None: + self._baseline_readout_status.wait() + self.actions.close_scan() + self.actions.check_for_unchecked_statuses() + + @scan_hook + def on_exception(self, exception: Exception): + """ + Handle exceptions that occur during the scan. + This is a good place to implement any cleanup logic that needs to be executed in case of an exception, + such as returning the devices to a safe state or moving the motors back to their starting position. + """ + if self.relative: + # Move the motors back to their starting position + self.components.move_and_wait(self.motors, self.start_positions) diff --git a/bec_server/bec_server/scan_server/scans/log_scan.py b/bec_server/bec_server/scan_server/scans/log_scan.py new file mode 100644 index 000000000..f9b78bf21 --- /dev/null +++ b/bec_server/bec_server/scan_server/scans/log_scan.py @@ -0,0 +1,253 @@ +""" +Logarithmic line scan implementation for one or more motors. + +Scan procedure: + - prepare_scan + - open_scan + - stage + - pre_scan + - scan_core + - at_each_point (optionally called by scan_core) + - post_scan + - unstage + - close_scan + - on_exception (called if any exception is raised during the scan) +""" + +from __future__ import annotations + +from typing import Annotated + +import numpy as np + +from bec_lib.device import DeviceBase +from bec_lib.scan_args import ScanArgument, Units +from bec_server.scan_server.scans import position_generators +from bec_server.scan_server.scans.scan_modifier import scan_hook +from bec_server.scan_server.scans.scans_v4 import ScanBase, ScanType, bundle_args + + +class LogScan(ScanBase): + # Scan Type: Hardware triggered or software triggered? + # If the main trigger and readout logic is done within the at_each_point method in scan_core, choose SOFTWARE_TRIGGERED. + # If the main trigger and readout logic is implemented on a device that is simply kicked off in this scan, choose HARDWARE_TRIGGERED. + # This primarily serves as information for devices: The device may need to react differently if a software trigger is expected + # for every point. + scan_type = ScanType.SOFTWARE_TRIGGERED + + # Scan name: This is the name of the scan, e.g. "line_scan". This is used for display purposes and to identify the scan type in user interfaces. + # Choose a descriptive name that does not conflict with existing scan names. + scan_name = "_v4_log_scan" + + # arg_input and arg_bundle_size are only relevant for scans that accept an arbitrary number of motor / position arguments (e.g. line scans, grid scans). + # For scans with a fixed set of parameters (e.g. Fermat spiral), these can be simply removed. + arg_input = { + "device": DeviceBase, + "start": Annotated[ + float, ScanArgument(display_name="Start Position", reference_units="device") + ], + "stop": Annotated[ + float, ScanArgument(display_name="Stop Position", reference_units="device") + ], + } + arg_bundle_size = {"bundle": len(arg_input), "min": 1, "max": None} + required_kwargs = ["steps", "relative"] + + gui_config = { + "Movement Parameters": ["steps", "relative"], + "Acquisition Parameters": [ + "exp_time", + "frames_per_trigger", + "settling_time", + "settling_time_after_trigger", + "readout_time", + "burst_at_each_point", + ], + } + + def __init__( + self, + *args, + steps: Annotated[int, ScanArgument(display_name="Steps", ge=1)] = 1, + exp_time: Annotated[ + float, ScanArgument(display_name="Exposure Time", units=Units.s, ge=0) + ] = 0, + frames_per_trigger: Annotated[ + int, ScanArgument(display_name="Frames per Trigger", ge=1) + ] = 1, + settling_time: Annotated[ + float, ScanArgument(display_name="Settling Time", units=Units.s, ge=0) + ] = 0, + settling_time_after_trigger: Annotated[ + float, ScanArgument(display_name="Settling Time After Trigger", units=Units.s, ge=0) + ] = 0, + readout_time: Annotated[ + float, ScanArgument(display_name="Readout Time", units=Units.s, ge=0) + ] = 0, + relative: bool = False, + burst_at_each_point: Annotated[ + int, ScanArgument(display_name="Burst at Each Point", ge=1) + ] = 1, + **kwargs, + ): + """ + A scan for one or more motors with logarithmically spaced positions. + + Args: + *args (Device, float, float): pairs of device / start / stop arguments + steps (int): number of points along the trajectory + exp_time (float): exposure time in seconds. Default is 0. + frames_per_trigger (int): number of frames acquired per trigger. Default is 1. + settling_time (float): settling time in seconds. Default is 0. + settling_time_after_trigger (float): settling time after trigger in seconds. Default is 0. + readout_time (float): readout time in seconds. Default is 0. + relative (bool): if True, the positions are interpreted relative to the current position. Default is False. + burst_at_each_point (int): number of exposures at each point. Default is 1. + + Returns: + ScanReport + + Examples: + >>> scans.log_scan(dev.motor1, 1, 100, steps=10, exp_time=0.1, relative=False) + """ + super().__init__(**kwargs) + self.motor_args = args + self.motor_input_bundles = bundle_args(args, bundle_size=self.arg_bundle_size["bundle"]) + self.motors = list(self.motor_input_bundles.keys()) + self.steps = steps + self.relative = relative + self.exp_time = exp_time + self.settling_time = settling_time + self.burst_at_each_point = burst_at_each_point + + self.update_scan_info( + exp_time=exp_time, + frames_per_trigger=frames_per_trigger, + settling_time=settling_time, + settling_time_after_trigger=settling_time_after_trigger, + readout_time=readout_time, + relative=relative, + burst_at_each_point=burst_at_each_point, + scan_report_devices=self.motors, + ) + self.actions.set_device_readout_priority(self.motors, priority="monitored") + + @scan_hook + def prepare_scan(self): + """ + Prepare the logarithmically spaced scan trajectory before the scan starts. + This generates the point list, resolves relative coordinates if requested, + checks device limits, initializes progress reporting, and starts baseline readout. + """ + self.positions = position_generators.log_scan_positions( + list(self.motor_input_bundles.values()), steps=self.steps + ) + + if self.relative: + self.start_positions = self.components.get_start_positions(self.motors) + self.positions += self.start_positions + + self.components.check_limits(self.motors, self.positions) + + self.update_scan_info( + positions=self.positions, + num_points=len(self.positions), + num_monitored_readouts=len(self.positions) * self.burst_at_each_point, + ) + + self.actions.add_scan_report_instruction_scan_progress( + points=self.scan_info.num_monitored_readouts, show_table=False + ) + + self._premove_motor_status = self.actions.set(self.motors, self.positions[0], wait=False) + + self._baseline_readout_status = self.actions.read_baseline_devices(wait=False) + + @scan_hook + def open_scan(self): + """Open the scan.""" + self.actions.open_scan() + + @scan_hook + def stage(self): + """ + Stage all devices participating in the scan. + """ + self.actions.stage_all_devices() + + @scan_hook + def pre_scan(self): + """ + Execute pre-scan device logic before the point-by-point trajectory begins. + """ + self._premove_motor_status.wait() + self.actions.pre_scan_all_devices() + + @scan_hook + def scan_core(self): + """ + Execute the logarithmic step scan over the prepared trajectory. + """ + self.components.step_scan( + self.motors, + self.positions, + at_each_point=self.at_each_point, + last_positions=self.positions[0], + ) + + @scan_hook + def at_each_point( + self, + motors: list[str | DeviceBase], + positions: np.ndarray, + last_positions: np.ndarray | None, + ): + """ + Execute the acquisition logic for a single point on the logarithmic trajectory. + + Args: + motors (list[str | DeviceBase]): List of motor names or device instances being moved. + positions (np.ndarray): Current positions of the motors, shape (len(motors),). + last_positions (np.ndarray | None): Previous positions of the motors, shape + (len(motors),) or None if this is the first point. + """ + self.components.step_scan_at_each_point(motors, positions, last_positions=last_positions) + + @scan_hook + def post_scan(self): + """ + Complete device activity after the point-by-point trajectory finishes. + If the scan was configured as relative, the motors are returned to their starting positions. + """ + status = self.actions.complete_all_devices(wait=False) + + if self.relative: + # Move the motors back to their starting position + self.components.move_and_wait(self.motors, self.start_positions) + status.wait() + + @scan_hook + def unstage(self): + """Unstage all devices after the scan completes.""" + self.actions.unstage_all_devices() + + @scan_hook + def close_scan(self): + """ + Close the scan after any pending baseline readout has completed. + """ + if self._baseline_readout_status is not None: + self._baseline_readout_status.wait() + self.actions.close_scan() + self.actions.check_for_unchecked_statuses() + + @scan_hook + def on_exception(self, exception: Exception): + """ + Handle exceptions that occur during the scan. + This is a good place to implement any cleanup logic that needs to be executed in case of an exception, + such as returning the devices to a safe state or moving the motors back to their starting position. + """ + if self.relative: + # Move the motors back to their starting position + self.components.move_and_wait(self.motors, self.start_positions) diff --git a/bec_server/bec_server/scan_server/scans/move.py b/bec_server/bec_server/scan_server/scans/move.py new file mode 100644 index 000000000..6ca70b965 --- /dev/null +++ b/bec_server/bec_server/scan_server/scans/move.py @@ -0,0 +1,161 @@ +""" +Move scan implementation for repositioning one or more motors without acquisition. + +Scan procedure: + - prepare_scan + - open_scan + - stage + - pre_scan + - scan_core + - at_each_point (optionally called by scan_core) + - post_scan + - unstage + - close_scan + - on_exception (called if any exception is raised during the scan) +""" + +from __future__ import annotations + +from bec_lib.device import DeviceBase +from bec_lib.logger import bec_logger +from bec_server.scan_server.scans.scan_modifier import scan_hook +from bec_server.scan_server.scans.scans_v4 import ScanBase, bundle_args + +logger = bec_logger.logger + + +class MoveScan(ScanBase): + + # Scan Type: Hardware triggered or software triggered? + # If the main trigger and readout logic is done within the at_each_point method in scan_core, choose SOFTWARE_TRIGGERED. + # If the main trigger and readout logic is implemented on a device that is simply kicked off in this scan, choose HARDWARE_TRIGGERED. + # This primarily serves as information for devices: The device may need to react differently if a software trigger is expected + # for every point. + scan_type = None + + # Scan name: This is the name of the scan, e.g. "line_scan". This is used for display purposes and to identify the scan type in user interfaces. + # Choose a descriptive name that does not conflict with existing scan names. + scan_name = "_v4_mv" + + # arg_input and arg_bundle_size are only relevant for scans that accept an arbitrary number of motor / position arguments (e.g. line scans, grid scans). + # For scans with a fixed set of parameters (e.g. Fermat spiral), these can be simply removed. + arg_input = {"device": DeviceBase, "target": float} + arg_bundle_size = {"bundle": len(arg_input), "min": 1, "max": None} + required_kwargs = ["relative"] + + # We set is_scan to False to separate this class from the other scans in the user interface + is_scan = False + + def __init__(self, *args, relative: bool = False, **kwargs): + """ + Simple move command that moves one or more motors to the specified positions. + The mv command gives back control to the user immediately after sending the command. For a blocking call + with live updates, use the umv command instead. + + + Args: + *args (Device, float): pairs of device / target position arguments + relative (bool): if True, the motors will be moved relative to their current position. + + Returns: + ScanReport + + Examples: + >>> scans.mv(dev.motor1, -5, dev.motor2, 5, relative=True) + + """ + super().__init__(**kwargs) + self.motor_args = args + self.motor_args_bundles = bundle_args(args, self.arg_bundle_size["bundle"]) + self.motors = list(self.motor_args_bundles.keys()) + self.relative = relative + + # Update the default scan info with provided parameters. + self.update_scan_info(relative=relative, scan_report_devices=self.motors) + + @scan_hook + def prepare_scan(self): + """ + Prepare the scan. This can include any steps that need to be executed + before the scan is opened, such as preparing the positions (if not done already) + or setting up the devices. + """ + self.actions.add_device_with_required_response(self.motors) + + @scan_hook + def open_scan(self): + """ + Open the scan. + This step must call self.actions.open_scan() to ensure that a new scan is + opened. Make sure to prepare the scan metadata before, either in + prepare_scan() or in open_scan() itself and call self.update_scan_info(...) + to update the scan metadata if needed. + """ + + @scan_hook + def stage(self): + """ + Stage the devices for the upcoming scan. The stage logic is typically + implemented on the device itself (i.e. by the device's stage method). + However, if there are any additional steps that need to be executed before + staging the devices, they can be implemented here. + """ + + @scan_hook + def pre_scan(self): + """ + Pre-scan steps to be executed before the main scan logic. + This is typically the last chance to prepare the devices before the core scan + logic is executed. For example, this is a good place to initialize time-criticial + devices, e.g. devices that have a short timeout. + The pre-scan logic is typically implemented on the device itself. + """ + + @scan_hook + def scan_core(self): + """ + Core scan logic to be executed during the scan. + This is where the main scan logic should be implemented. + """ + target_positions = [pos[0] for pos in self.motor_args_bundles.values()] + if self.relative: + current_positions = self.components.get_start_positions(self.motors) + target_positions = [ + target + current + for target, current in zip(target_positions, current_positions, strict=False) + ] + + self.actions.set(self.motors, target_positions, wait=False) + + @scan_hook + def at_each_point(self): + """ + Logic to be executed at each point during the scan. This is called by the step_scan method at each point. + + Args: + motors (list[str | DeviceBase]): List of motor names or device instances being moved. + positions (np.ndarray): Current positions of the motors, shape (len(motors),). + last_positions (np.ndarray | None): Previous positions of the motors, shape (len(motors),) or None if this is the first point. + """ + + @scan_hook + def post_scan(self): + """ + Post-scan steps to be executed after the main scan logic. + """ + + @scan_hook + def unstage(self): + """Unstage the scan by executing post-scan steps.""" + + @scan_hook + def close_scan(self): + """Close the scan.""" + + @scan_hook + def on_exception(self, exception: Exception): + """ + Handle exceptions that occur during the scan. + This is a good place to implement any cleanup logic that needs to be executed in case of an exception, + such as returning the devices to a safe state or moving the motors back to their starting position. + """ diff --git a/bec_server/bec_server/scan_server/scans/multi_region_grid_scan.py b/bec_server/bec_server/scan_server/scans/multi_region_grid_scan.py new file mode 100644 index 000000000..f75b1e906 --- /dev/null +++ b/bec_server/bec_server/scan_server/scans/multi_region_grid_scan.py @@ -0,0 +1,257 @@ +""" +Multi-region grid scan implementation for two-motor scans with disjoint regions. + +Scan procedure: + - prepare_scan + - open_scan + - stage + - pre_scan + - scan_core + - at_each_point (optionally called by scan_core) + - post_scan + - unstage + - close_scan + - on_exception (called if any exception is raised during the scan) +""" + +from __future__ import annotations + +from typing import Annotated + +import numpy as np + +from bec_lib.device import DeviceBase +from bec_lib.scan_args import ScanArgument, Units +from bec_server.scan_server.scans import position_generators +from bec_server.scan_server.scans.scan_modifier import scan_hook +from bec_server.scan_server.scans.scans_v4 import ScanBase, ScanType + + +class MultiRegionGridScan(ScanBase): + # Scan Type: Hardware triggered or software triggered? + # If the main trigger and readout logic is done within the at_each_point method in scan_core, choose SOFTWARE_TRIGGERED. + # If the main trigger and readout logic is implemented on a device that is simply kicked off in this scan, choose HARDWARE_TRIGGERED. + # This primarily serves as information for devices: The device may need to react differently if a software trigger is expected + # for every point. + scan_type = ScanType.SOFTWARE_TRIGGERED + + # Scan name: This is the name of the scan, e.g. "line_scan". This is used for display purposes and to identify the scan type in user interfaces. + # Choose a descriptive name that does not conflict with existing scan names. + scan_name = "_v4_multi_region_grid_scan" + + required_kwargs = ["regions", "relative"] + + gui_config = { + "Motors": ["motor1", "motor2"], + "Movement Parameters": ["regions", "relative", "snaked"], + "Acquisition Parameters": ["exp_time", "settling_time", "burst_at_each_point"], + } + + def __init__( + self, + motor1: DeviceBase, + motor2: DeviceBase, + *, + regions: list[tuple[tuple[float, float, int], tuple[float, float, int]]], + exp_time: Annotated[ + float, ScanArgument(display_name="Exposure Time", units=Units.s, ge=0) + ] = 0, + frames_per_trigger: Annotated[ + int, ScanArgument(display_name="Frames per Trigger", ge=1) + ] = 1, + settling_time: Annotated[ + float, ScanArgument(display_name="Settling Time", units=Units.s, ge=0) + ] = 0, + settling_time_after_trigger: Annotated[ + float, ScanArgument(display_name="Settling Time After Trigger", units=Units.s, ge=0) + ] = 0, + readout_time: Annotated[ + float, ScanArgument(display_name="Readout Time", units=Units.s, ge=0) + ] = 0, + relative: bool = False, + snaked: bool = True, + burst_at_each_point: Annotated[ + int, ScanArgument(display_name="Burst at Each Point", ge=1) + ] = 1, + **kwargs, + ): + """ + Scan two motors on multiple independent rectangular sub-grids. + + For a single region, + ``scans.multi_region_grid_scan(motor1, motor2, regions=[((start1, stop1, steps1), (start2, stop2, steps2))], ...)`` + is equivalent to the standard scan + ``scans.grid_scan(motor1, start1, stop1, steps1, motor2, start2, stop2, steps2, ...)``. + + Args: + motor1 (DeviceBase): first motor + motor2 (DeviceBase): second motor + regions (list[tuple[tuple[float, float, int], tuple[float, float, int]]]): + sequence of paired region definitions. Each entry contains one + ``(start, stop, steps)`` tuple for ``motor1`` and one for ``motor2``. + exp_time (float): exposure time in seconds. Default is 0. + frames_per_trigger (int): number of frames acquired per trigger. Default is 1. + settling_time (float): settling time in seconds. Default is 0. + settling_time_after_trigger (float): settling time after trigger in seconds. Default is 0. + readout_time (float): readout time in seconds. Default is 0. + relative (bool): if True, the generated positions are interpreted relative to the + current motor positions. Default is False. + snaked (bool): if True, the second axis is traversed in alternating directions + within each sub-grid. Default is True. + burst_at_each_point (int): number of exposures at each point. Default is 1. + + Returns: + ScanReport + + Examples: + >>> scans.multi_region_grid_scan(dev.motor1, dev.motor2, regions=[((-5, -1, 5), (-4, 0, 5)), ((1, 5, 3), (-4, 0, 5))], exp_time=0.1, relative=True) + """ + super().__init__(**kwargs) + self.motors = [motor1, motor2] + self.regions = regions + self.exp_time = exp_time + self.settling_time = settling_time + self.relative = relative + self.snaked = snaked + self.burst_at_each_point = burst_at_each_point + + self.update_scan_info( + exp_time=exp_time, + frames_per_trigger=frames_per_trigger, + settling_time=settling_time, + settling_time_after_trigger=settling_time_after_trigger, + readout_time=readout_time, + relative=relative, + snaked=snaked, + burst_at_each_point=burst_at_each_point, + regions=regions, + scan_report_devices=self.motors, + ) + self.actions.set_device_readout_priority(self.motors, priority="monitored") + + @scan_hook + def prepare_scan(self): + """ + Prepare the scan. This can include any steps that need to be executed + before the scan is opened, such as preparing the positions (if not done already) + or setting up the devices. + """ + self.positions = position_generators.multi_region_grid_positions( + self.regions, snaked=self.snaked + ) + + if self.relative: + self.start_positions = self.components.get_start_positions(self.motors) + self.positions += self.start_positions + + self.components.check_limits(self.motors, self.positions) + + self.update_scan_info( + positions=self.positions, + num_points=len(self.positions), + num_monitored_readouts=len(self.positions) * self.burst_at_each_point, + ) + + self.actions.add_scan_report_instruction_scan_progress( + points=self.scan_info.num_monitored_readouts, show_table=False + ) + self._premove_motor_status = self.actions.set(self.motors, self.positions[0], wait=False) + self._baseline_readout_status = self.actions.read_baseline_devices(wait=False) + + @scan_hook + def open_scan(self): + """ + Open the scan. + This step must call self.actions.open_scan() to ensure that a new scan is + opened. Make sure to prepare the scan metadata before, either in + prepare_scan() or in open_scan() itself and call self.update_scan_info(...) + to update the scan metadata if needed. + """ + self.actions.open_scan() + + @scan_hook + def stage(self): + """ + Stage the devices for the upcoming scan. The stage logic is typically + implemented on the device itself (i.e. by the device's stage method). + However, if there are any additional steps that need to be executed before + staging the devices, they can be implemented here. + """ + self.actions.stage_all_devices() + + @scan_hook + def pre_scan(self): + """ + Pre-scan steps to be executed before the main scan logic. + This is typically the last chance to prepare the devices before the core scan + logic is executed. For example, this is a good place to initialize time-criticial + devices, e.g. devices that have a short timeout. + The pre-scan logic is typically implemented on the device itself. + """ + self._premove_motor_status.wait() + self.actions.pre_scan_all_devices() + + @scan_hook + def scan_core(self): + """ + Core scan logic to be executed during the scan. + This is where the main scan logic should be implemented. + """ + self.components.step_scan( + self.motors, + self.positions, + at_each_point=self.at_each_point, + last_positions=self.positions[0], + ) + + @scan_hook + def at_each_point( + self, + motors: list[str | DeviceBase], + positions: np.ndarray, + last_positions: np.ndarray | None, + ): + """ + Logic to be executed at each point during the scan. This is called by the step_scan method at each point. + + Args: + motors (list[str | DeviceBase]): List of motor names or device instances being moved. + positions (np.ndarray): Current positions of the motors, shape (len(motors),). + last_positions (np.ndarray | None): Previous positions of the motors, shape (len(motors),) or None if this is the first point. + """ + self.components.step_scan_at_each_point(motors, positions, last_positions=last_positions) + + @scan_hook + def post_scan(self): + """ + Post-scan steps to be executed after the main scan logic. + """ + status = self.actions.complete_all_devices(wait=False) + if self.relative: + # Move the motors back to their starting position + self.components.move_and_wait(self.motors, self.start_positions) + status.wait() + + @scan_hook + def unstage(self): + """Unstage the scan by executing post-scan steps.""" + self.actions.unstage_all_devices() + + @scan_hook + def close_scan(self): + """Close the scan.""" + if self._baseline_readout_status is not None: + self._baseline_readout_status.wait() + self.actions.close_scan() + self.actions.check_for_unchecked_statuses() + + @scan_hook + def on_exception(self, exception: Exception): + """ + Handle exceptions that occur during the scan. + This is a good place to implement any cleanup logic that needs to be executed in case of an exception, + such as returning the devices to a safe state or moving the motors back to their starting position. + """ + if self.relative: + # Move the motors back to their starting position + self.components.move_and_wait(self.motors, self.start_positions) diff --git a/bec_server/bec_server/scan_server/scans/multi_region_line_scan.py b/bec_server/bec_server/scan_server/scans/multi_region_line_scan.py new file mode 100644 index 000000000..934a7a2f8 --- /dev/null +++ b/bec_server/bec_server/scan_server/scans/multi_region_line_scan.py @@ -0,0 +1,255 @@ +""" +Multi-region line scan implementation for one motor with disjoint scan regions. + +Scan procedure: + - prepare_scan + - open_scan + - stage + - pre_scan + - scan_core + - at_each_point (optionally called by scan_core) + - post_scan + - unstage + - close_scan + - on_exception (called if any exception is raised during the scan) +""" + +from __future__ import annotations + +from typing import Annotated + +import numpy as np + +from bec_lib.device import DeviceBase +from bec_lib.scan_args import ScanArgument, Units +from bec_server.scan_server.scans import position_generators +from bec_server.scan_server.scans.scan_modifier import scan_hook +from bec_server.scan_server.scans.scans_v4 import ScanBase, ScanType + + +class MultiRegionLineScan(ScanBase): + # Scan Type: Hardware triggered or software triggered? + # If the main trigger and readout logic is done within the at_each_point method in scan_core, choose SOFTWARE_TRIGGERED. + # If the main trigger and readout logic is implemented on a device that is simply kicked off in this scan, choose HARDWARE_TRIGGERED. + # This primarily serves as information for devices: The device may need to react differently if a software trigger is expected + # for every point. + scan_type = ScanType.SOFTWARE_TRIGGERED + + # Scan name: This is the name of the scan, e.g. "line_scan". This is used for display purposes and to identify the scan type in user interfaces. + # Choose a descriptive name that does not conflict with existing scan names. + scan_name = "_v4_multi_region_line_scan" + + required_kwargs = ["regions", "relative"] + + gui_config = { + "Movement Parameters": ["regions", "relative"], + "Acquisition Parameters": [ + "exp_time", + "frames_per_trigger", + "settling_time", + "settling_time_after_trigger", + "readout_time", + "burst_at_each_point", + ], + } + + def __init__( + self, + motor: DeviceBase, + *, + regions: list[tuple[float, float, int]], + exp_time: Annotated[ + float, ScanArgument(display_name="Exposure Time", units=Units.s, ge=0) + ] = 0, + frames_per_trigger: Annotated[ + int, ScanArgument(display_name="Frames per Trigger", ge=1) + ] = 1, + settling_time: Annotated[ + float, ScanArgument(display_name="Settling Time", units=Units.s, ge=0) + ] = 0, + settling_time_after_trigger: Annotated[ + float, ScanArgument(display_name="Settling Time After Trigger", units=Units.s, ge=0) + ] = 0, + readout_time: Annotated[ + float, ScanArgument(display_name="Readout Time", units=Units.s, ge=0) + ] = 0, + relative: bool = False, + burst_at_each_point: Annotated[ + int, ScanArgument(display_name="Burst at Each Point", ge=1) + ] = 1, + **kwargs, + ): + """ + Scan one motor across multiple disjoint line regions. + + For a single region, + ``scans.multi_region_line_scan(motor, regions=[(start, stop, steps)], ...)`` + is equivalent to the standard scan + ``scans.line_scan(motor, start, stop, steps=steps, ...)``. + + Args: + motor (DeviceBase): motor to move + regions (list[tuple[float, float, int]]): sequence of ``(start, stop, steps)`` + region definitions + exp_time (float): exposure time in seconds. Default is 0. + frames_per_trigger (int): number of frames acquired per trigger. Default is 1. + settling_time (float): settling time in seconds. Default is 0. + settling_time_after_trigger (float): settling time after trigger in seconds. Default is 0. + readout_time (float): readout time in seconds. Default is 0. + relative (bool): if True, the generated positions are interpreted relative to the + current motor position. Default is False. + burst_at_each_point (int): number of exposures at each point. Default is 1. + + Returns: + ScanReport + + Examples: + >>> scans.multi_region_line_scan(dev.motor1, regions=[(-5, -2, 4), (1, 5, 3)], exp_time=0.1, relative=True) + """ + super().__init__(**kwargs) + self.motor = motor + self.motors = [motor] + self.regions = regions + self.exp_time = exp_time + self.settling_time = settling_time + self.relative = relative + self.burst_at_each_point = burst_at_each_point + + self.update_scan_info( + exp_time=exp_time, + frames_per_trigger=frames_per_trigger, + settling_time=settling_time, + settling_time_after_trigger=settling_time_after_trigger, + readout_time=readout_time, + relative=relative, + burst_at_each_point=burst_at_each_point, + regions=regions, + scan_report_devices=self.motors, + ) + self.actions.set_device_readout_priority(self.motors, priority="monitored") + + @scan_hook + def prepare_scan(self): + """ + Prepare the scan. This can include any steps that need to be executed + before the scan is opened, such as preparing the positions (if not done already) + or setting up the devices. + """ + self.positions = position_generators.multi_region_line_positions(self.regions) + + if self.relative: + self.start_positions = self.components.get_start_positions(self.motors) + self.positions += self.start_positions + + self.components.check_limits(self.motors, self.positions) + + self.update_scan_info( + positions=self.positions, + num_points=len(self.positions), + num_monitored_readouts=len(self.positions) * self.burst_at_each_point, + ) + + self.actions.add_scan_report_instruction_scan_progress( + points=self.scan_info.num_monitored_readouts, show_table=False + ) + + self._premove_motor_status = self.actions.set(self.motors, self.positions[0], wait=False) + self._baseline_readout_status = self.actions.read_baseline_devices(wait=False) + + @scan_hook + def open_scan(self): + """ + Open the scan. + This step must call self.actions.open_scan() to ensure that a new scan is + opened. Make sure to prepare the scan metadata before, either in + prepare_scan() or in open_scan() itself and call self.update_scan_info(...) + to update the scan metadata if needed. + """ + self.actions.open_scan() + + @scan_hook + def stage(self): + """ + Stage the devices for the upcoming scan. The stage logic is typically + implemented on the device itself (i.e. by the device's stage method). + However, if there are any additional steps that need to be executed before + staging the devices, they can be implemented here. + """ + self.actions.stage_all_devices() + + @scan_hook + def pre_scan(self): + """ + Pre-scan steps to be executed before the main scan logic. + This is typically the last chance to prepare the devices before the core scan + logic is executed. For example, this is a good place to initialize time-criticial + devices, e.g. devices that have a short timeout. + The pre-scan logic is typically implemented on the device itself. + """ + self._premove_motor_status.wait() + self.actions.pre_scan_all_devices() + + @scan_hook + def scan_core(self): + """ + Core scan logic to be executed during the scan. + This is where the main scan logic should be implemented. + """ + self.components.step_scan( + self.motors, + self.positions, + at_each_point=self.at_each_point, + last_positions=self.positions[0], + ) + + @scan_hook + def at_each_point( + self, + motors: list[str | DeviceBase], + positions: np.ndarray, + last_positions: np.ndarray | None, + ): + """ + Logic to be executed at each point during the scan. This is called by the step_scan method at each point. + + Args: + motors (list[str | DeviceBase]): List of motor names or device instances being moved. + positions (np.ndarray): Current positions of the motors, shape (len(motors),). + last_positions (np.ndarray | None): Previous positions of the motors, shape (len(motors),) or None if this is the first point. + """ + self.components.step_scan_at_each_point(motors, positions, last_positions=last_positions) + + @scan_hook + def post_scan(self): + """ + Post-scan steps to be executed after the main scan logic. + """ + status = self.actions.complete_all_devices(wait=False) + if self.relative: + # Move the motors back to their starting position + self.components.move_and_wait(self.motors, self.start_positions) + status.wait() + + @scan_hook + def unstage(self): + """Unstage the scan by executing post-scan steps.""" + self.actions.unstage_all_devices() + + @scan_hook + def close_scan(self): + """Close the scan.""" + if self._baseline_readout_status is not None: + self._baseline_readout_status.wait() + self.actions.close_scan() + self.actions.check_for_unchecked_statuses() + + @scan_hook + def on_exception(self, exception: Exception): + """ + Handle exceptions that occur during the scan. + This is a good place to implement any cleanup logic that needs to be executed in case of an exception, + such as returning the devices to a safe state or moving the motors back to their starting position. + """ + if self.relative: + # Move the motors back to their starting position + self.components.move_and_wait(self.motors, self.start_positions) diff --git a/bec_server/bec_server/scan_server/scans/position_generators.py b/bec_server/bec_server/scan_server/scans/position_generators.py new file mode 100644 index 000000000..585529e98 --- /dev/null +++ b/bec_server/bec_server/scan_server/scans/position_generators.py @@ -0,0 +1,541 @@ +from __future__ import annotations + +from collections.abc import Iterator, Sequence + +import numpy as np + + +def rotate_points( + points: np.ndarray, angle: float, center: tuple[float, float] | None = None +) -> np.ndarray: + """ + Rotate 2D points around a center. + + Args: + points (np.ndarray): Array of shape ``(N, 2)`` containing x/y positions. + angle (float): Rotation angle in radians. + center (tuple[float, float] | None): Optional center of rotation. If omitted, + the points are rotated around the origin. + + Returns: + np.ndarray: Rotated points with the same shape as the input. + """ + if points.size == 0 or angle == 0: + return points + + center_array = np.zeros(2, dtype=float) if center is None else np.asarray(center, dtype=float) + rotation = np.array( + [[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]], dtype=float + ) + return (points - center_array) @ rotation.T + center_array + + +def _filter_points_in_box( + points: np.ndarray, x_center: float, y_center: float, x_range: float, y_range: float +) -> np.ndarray: + """Keep points inside the centered rectangular scan bounds.""" + if points.size == 0: + return points.reshape(0, 2) + + half_x = x_range / 2 + half_y = y_range / 2 + mask = ( + (points[:, 0] >= x_center - half_x) + & (points[:, 0] <= x_center + half_x) + & (points[:, 1] >= y_center - half_y) + & (points[:, 1] <= y_center + half_y) + ) + return points[mask] + + +def spiral_positions( + x_center: float, + y_center: float, + x_range: float, + y_range: float, + dr: float, + nth: float, + dr_y: float | None = None, + tilt: float = 0.0, +) -> np.ndarray: + """ + Generate an Archimedean spiral scan trajectory. + + The spiral is centered at ``(x_center, y_center)`` and clipped to the rectangular + region defined by ``x_range`` and ``y_range``. + """ + if dr <= 0: + raise ValueError("dr must be positive") + if nth <= 0: + raise ValueError("nth must be positive") + + dr_y = dr if dr_y is None else dr_y + if dr_y <= 0: + raise ValueError("dr_y must be positive") + + half_x = x_range / 2 + half_y = y_range / 2 + max_radius = max(half_x / dr, half_y / dr_y) + max_theta = 2 * np.pi * max_radius + dtheta = 2 * np.pi / nth + + theta = np.arange(0.0, max_theta + dtheta, dtheta, dtype=float) + x = (dr / (2 * np.pi)) * theta * np.cos(theta) + y = (dr_y / (2 * np.pi)) * theta * np.sin(theta) + points = np.column_stack((x, y)) + points = rotate_points(points, tilt) + points[:, 0] += x_center + points[:, 1] += y_center + + return _filter_points_in_box(points, x_center, y_center, x_range, y_range) + + +def line_scan_positions( + axes: list[tuple[float, float]], steps: int, endpoint: bool = True +) -> np.ndarray: + """ + Generate linearly spaced positions for one or more axes. + + Args: + axes (list[tuple[float, float]]): Sequence of ``(start, stop)`` pairs, one per axis. + steps (int): Number of points to generate along the trajectory. + endpoint (bool): If True, include the stop value in the generated positions. + + Returns: + np.ndarray: Array of shape ``(steps, len(axes))`` containing the scan positions. + """ + if steps <= 0: + raise ValueError("steps must be positive") + + axis_positions = [ + np.linspace(start, stop, steps, dtype=float, endpoint=endpoint) for start, stop in axes + ] + return np.column_stack(axis_positions) + + +def log_scan_positions(axes: list[tuple[float, float]], steps: int) -> np.ndarray: + """ + Generate positions with logarithmically increasing step sizes. + + The logarithmic spacing is applied to the normalized distance between each + ``start`` and ``stop`` pair, not to the absolute position values. This means + ranges may include zero or cross zero. + + Args: + axes (list[tuple[float, float]]): Sequence of ``(start, stop)`` pairs, one per axis. + steps (int): Number of points to generate along the trajectory. + + Returns: + np.ndarray: Array of shape ``(steps, len(axes))`` containing the scan positions. + + Raises: + ValueError: If ``steps`` is not positive. + """ + if steps <= 0: + raise ValueError("steps must be positive") + + # Log spacing from 0 to 1 + log_progress = (np.logspace(0, 1, steps, dtype=float) - 1) / 9 + + axis_positions = [] + for start, stop in axes: + axis_positions.append(start + log_progress * (stop - start)) + return np.column_stack(axis_positions) + + +def oscillating_positions( + values: Sequence[float], repeat_turning_points: bool = False +) -> Iterator[float]: + """ + Yield values indefinitely in a back-and-forth pattern. + + For a single value, the same value is yielded repeatedly. For multiple values, + the sequence is traversed to the end and then back toward the beginning. By default + the turning points are not repeated twice in a row, but this can be enabled with + ``repeat_turning_points``. + + Args: + values (Sequence[float]): Ordered values to oscillate through. + repeat_turning_points (bool): If ``True``, repeat the end points before + reversing direction. Default is ``False``. + + Yields: + float: The next value in the oscillating sequence. + + Raises: + ValueError: If ``values`` is empty. + + Examples: + Call ``next(...)`` on the returned generator to retrieve the next value + in the oscillating sequence: + + >>> pos_generator = oscillating_positions([600.0, 620.0, 640.0]) + >>> for _ in range(6): + ... value = next(pos_generator) + ... print(value) + 600.0 + 620.0 + 640.0 + 620.0 + 600.0 + 620.0 + """ + if not values: + raise ValueError("values must contain at least one position") + + if len(values) == 1: + while True: + yield float(values[0]) + + index = 0 + direction = 1 + repeated_turning_point = False + while True: + yield float(values[index]) + if index == len(values) - 1: + if repeat_turning_points and not repeated_turning_point: + direction = -1 + repeated_turning_point = True + else: + direction = -1 + index += direction + repeated_turning_point = False + elif index == 0: + if repeat_turning_points and not repeated_turning_point: + direction = 1 + repeated_turning_point = True + else: + direction = 1 + index += direction + repeated_turning_point = False + else: + index += direction + repeated_turning_point = False + + +def _region_points(start: float, stop: float, steps: int) -> np.ndarray: + """ + Generate positions for one inclusive scan region. + + Args: + start (float): Region start position. + stop (float): Region stop position. + steps (int): Number of points in the region, including start and stop. + + Returns: + np.ndarray: 1D array of positions covering the region in scan order. + """ + if steps <= 0: + raise ValueError("steps must be positive") + return np.linspace(start, stop, steps, dtype=float) + + +def multi_region_line_positions(regions: list[tuple[float, float, int]]) -> np.ndarray: + """ + Generate a 1D trajectory across multiple disjoint scan regions. + + Args: + regions (list[tuple[float, float, int]]): Sequence of ``(start, stop, steps)`` + region definitions. + + Returns: + np.ndarray: Array of shape ``(N, 1)`` containing the concatenated positions. + """ + if not regions: + raise ValueError("regions must contain at least one region") + + concatenated = [] + for start, stop, steps in regions: + region_values = _region_points(start, stop, steps) + if concatenated and np.isclose(concatenated[-1], region_values[0]): + region_values = region_values[1:] + concatenated.extend(region_values.tolist()) + return np.asarray(concatenated, dtype=float)[:, np.newaxis] + + +def multi_region_grid_positions( + regions: list[tuple[tuple[float, float, int], tuple[float, float, int]]], snaked: bool = True +) -> np.ndarray: + """ + Generate multiple rectangular sub-grids from paired scan regions. + + Args: + regions (list[tuple[tuple[float, float, int], tuple[float, float, int]]]): Sequence + of paired region definitions. Each entry contains one + ``(start, stop, steps)`` tuple for the first motor and one for the second motor. + snaked (bool): If ``True``, reverse traversal of the second axis on alternating positions + within each sub-grid. + + Returns: + np.ndarray: Array of shape ``(N, 2)`` containing the concatenated scan positions. + """ + if not regions: + raise ValueError("regions must contain at least one paired region") + + positions: list[list[float]] = [] + for region1, region2 in regions: + axis1_positions = _region_points(*region1) + axis2_positions = _region_points(*region2) + + for index, value1 in enumerate(axis1_positions): + current_axis2 = ( + axis2_positions[::-1] if snaked and (index % 2 == 1) else axis2_positions + ) + for value2 in current_axis2: + positions.append([value1, value2]) + + return np.asarray(positions, dtype=float) + + +def nd_grid_positions(axes: list[tuple[float, float, int]], snaked: bool = True) -> np.ndarray: + """ + Generate N-dimensional grid positions. + It creates a grid of positions for N dimensions, with optional snaking behavior. + + snaked==True: + ->->->->- + -<-<-<-<- + ->->->->- + snaked==False: + ->->->->- + ->->->->- + ->->->->- + + Args: + axes (list of tuples): list of tuples (start, stop, step) for each axis + snaked (bool, optional): If True, the grid is generated in a "snaked" + pattern across all dimensions. + + Returns: + np.ndarray: shape (num_points, N) + """ + _axes_arrays = [] + for start, stop, step in axes: + if step <= 0: + raise ValueError("Step size must be positive") + _axes_arrays.append(np.linspace(start, stop, step, dtype=float)) + + def _get_positions_recursively(current_axes): + if len(current_axes) == 1: + return [[v] for v in current_axes[0]] + + positions = [] + for i, val in enumerate(current_axes[0]): + sub_positions = _get_positions_recursively(current_axes[1:]) + if snaked and (i % 2 == 1): + sub_positions.reverse() + positions.extend([[val] + sp for sp in sub_positions]) + return positions + + return np.array(_get_positions_recursively(_axes_arrays)) + + +def fermat_spiral_pos( + m1_start: float, + m1_stop: float, + m2_start: float, + m2_stop: float, + step: float = 1, + spiral_type: float = 0, + center: bool = False, +) -> np.ndarray: + """ + fermat_spiral_pos calculates and returns the positions for a Fermat spiral scan. + + Args: + m1_start (float): start position motor 1 + m1_stop (float): end position motor 1 + m2_start (float): start position motor 2 + m2_stop (float): end position motor 2 + step (float, optional): Step size. Defaults to 1. + spiral_type (float, optional): Angular offset in radians that determines the shape of the spiral. + A spiral with spiral_type=2 is the same as spiral_type=0. Defaults to 0. + center (bool, optional): Add a center point. Defaults to False. + + Returns: + np.ndarray: calculated positions in the form [[m1, m2], ...] + """ + if step <= 0: + raise ValueError("step must be positive") + + phi = np.pi * (3 - np.sqrt(5)) + spiral_type * np.pi + start_index = 0 if center else 1 + + x_center = (m1_start + m1_stop) / 2 + y_center = (m2_start + m2_stop) / 2 + x_range = abs(m1_stop - m1_start) + y_range = abs(m2_stop - m2_start) + half_x = x_range / 2 + half_y = y_range / 2 + + radial_scale = step / np.sqrt(np.pi) + max_index = max(1, int(np.ceil((max(half_x, half_y) / radial_scale) ** 2)) * 2) + + points = [] + for ii in range(start_index, max_index + 1): + radius = radial_scale * np.sqrt(ii) + x = x_center + radius * np.cos(ii * phi) + y = y_center + radius * np.sin(ii * phi) + if not (m1_start <= x <= m1_stop or m1_stop <= x <= m1_start): + continue + if not (m2_start <= y <= m2_stop or m2_stop <= y <= m2_start): + continue + points.append((x, y)) + + return np.asarray(points, dtype=float) + + +def round_scan_positions( + inner_radius: float, + outer_radius: float, + number_of_rings: int, + points_in_first_ring: int, + center_1: float = 0, + center_2: float = 0, +) -> np.ndarray: + """ + Calculate positions for a circular shell scan. + + Args: + inner_radius (float): inner radius + outer_radius (float): outer radius + number_of_rings (int): number of radii + points_in_first_ring (int): number of angles in the inner ring + center_1 (float, optional): center position for axis 1. Defaults to 0. + center_2 (float, optional): center position for axis 2. Defaults to 0. + + Returns: + np.ndarray: calculated positions in the form [[x, y], ...] + """ + positions = [] + radius_step = (inner_radius - outer_radius) / number_of_rings + for ring_index in range(1, number_of_rings + 2): + radius = inner_radius + ring_index * radius_step + points_on_ring = points_in_first_ring * ring_index + angular_step = 2 * np.pi / points_on_ring + positions.extend( + [ + ( + radius * np.sin(point_index * angular_step) + center_1, + radius * np.cos(point_index * angular_step) + center_2, + ) + for point_index in range(points_on_ring) + ] + ) + positions_array = np.array(positions, dtype=float) + return positions_array + + +def get_round_roi_scan_positions( + motor_1_start: float, + motor_1_stop: float, + motor_2_start: float, + motor_2_stop: float, + radial_step: float, + points_in_first_shell: int, + center_1: float = 0, + center_2: float = 0, +): + """ + Calculate round scan positions clipped to a rectangular region of interest. + + The circular shells are centered around ``center_1`` / ``center_2``. The center does + not need to be inside the rectangular ROI defined by the motor start/stop + bounds. + + Args: + motor_1_start (float): start position of the ROI for motor 1 + motor_1_stop (float): stop position of the ROI for motor 1 + motor_2_start (float): start position of the ROI for motor 2 + motor_2_stop (float): stop position of the ROI for motor 2 + radial_step (float): radial shell spacing + points_in_first_shell (int): number of angles in the first shell + center_1 (float, optional): center position for motor 1. Defaults to 0. + center_2 (float, optional): center position for motor 2. Defaults to 0. + + Returns: + np.ndarray: calculated positions in the form [[x, y], ...] + """ + motor_1_min, motor_1_max = sorted((motor_1_start, motor_1_stop)) + motor_2_min, motor_2_max = sorted((motor_2_start, motor_2_stop)) + corners = [ + (motor_1_min, motor_2_min), + (motor_1_min, motor_2_max), + (motor_1_max, motor_2_min), + (motor_1_max, motor_2_max), + ] + max_radius = max( + np.hypot(motor_1_position - center_1, motor_2_position - center_2) + for motor_1_position, motor_2_position in corners + ) + + positions = [] + number_of_shells = 1 + int(np.ceil(max_radius / radial_step)) + for shell_index in range(1, number_of_shells + 2): + radius = shell_index * radial_step + points_on_shell = points_in_first_shell * shell_index + angular_step = 2 * np.pi / points_on_shell + for point_index in range(points_on_shell): + angle = point_index * angular_step + local_position = np.array( + [[radius * np.cos(angle), radius * np.sin(angle)]], dtype=float + ) + motor_1_offset, motor_2_offset = local_position[0] + motor_1_position = motor_1_offset + center_1 + motor_2_position = motor_2_offset + center_2 + if not ( + motor_1_min <= motor_1_position <= motor_1_max + and motor_2_min <= motor_2_position <= motor_2_max + ): + continue + positions.append((motor_1_position, motor_2_position)) + return np.array(positions, dtype=float) + + +def hex_grid_2d(axes: list[tuple[float, float, float]], snaked: bool = True) -> np.ndarray: + """ + Generate a 2D hexagonal grid clipped to (start, stop) bounds. + + Args: + axes: [(x_start, x_stop, x_step), + (y_start, y_stop, y_step)] + x_step = horizontal spacing between columns + y_step = vertical spacing between rows + snaked: if True, reverse direction on alternate rows to minimize travel distance + + Returns: + np.ndarray of shape (N, 2) + """ + if len(axes) != 2: + raise ValueError("2D hex grid requires exactly 2 dimensions") + + (x0, x1, sx), (y0, y1, sy) = axes + + points = [] + + # Number of rows needed + n_rows = int(np.ceil((y1 - y0) / sy)) + 2 + + for row in range(n_rows): + y = y0 + row * sy + + # Alternate row offset - shift by half the x step + x_offset = (sx / 2) if (row % 2) else 0.0 + + # Number of columns needed + n_cols = int(np.ceil((x1 - x0) / sx)) + 2 + + row_points = [] + for col in range(n_cols): + x = x0 + x_offset + col * sx + + if x0 <= x <= x1 and y0 <= y <= y1: + row_points.append((x, y)) + + # Reverse every other row if snaking is enabled + if snaked and (row % 2 == 1): + row_points.reverse() + + points.extend(row_points) + + return np.asarray(points, dtype=float) diff --git a/bec_server/bec_server/scan_server/scans/round_roi_scan.py b/bec_server/bec_server/scan_server/scans/round_roi_scan.py new file mode 100644 index 000000000..1a7c68cbb --- /dev/null +++ b/bec_server/bec_server/scan_server/scans/round_roi_scan.py @@ -0,0 +1,303 @@ +""" +Round ROI scan implementation for circular region-of-interest area scans. + +Scan procedure: + - prepare_scan + - open_scan + - stage + - pre_scan + - scan_core + - at_each_point (optionally called by scan_core) + - post_scan + - unstage + - close_scan + - on_exception (called if any exception is raised during the scan) +""" + +from __future__ import annotations + +from typing import Annotated + +import numpy as np + +from bec_lib.device import DeviceBase +from bec_lib.scan_args import ScanArgument, Units +from bec_server.scan_server.scans import position_generators +from bec_server.scan_server.scans.scan_modifier import scan_hook +from bec_server.scan_server.scans.scans_v4 import ScanBase, ScanType + + +class RoundROIScan(ScanBase): + # Scan Type: Hardware triggered or software triggered? + # If the main trigger and readout logic is done within the at_each_point method in scan_core, choose SOFTWARE_TRIGGERED. + # If the main trigger and readout logic is implemented on a device that is simply kicked off in this scan, choose HARDWARE_TRIGGERED. + # This primarily serves as information for devices: The device may need to react differently if a software trigger is expected + # for every point. + scan_type = ScanType.SOFTWARE_TRIGGERED + + # Scan name: This is the name of the scan, e.g. "line_scan". This is used for display purposes and to identify the scan type in user interfaces. + # Choose a descriptive name that does not conflict with existing scan names. + scan_name = "_v4_round_roi_scan" + required_kwargs = ["shell_spacing", "pos_in_first_ring", "relative"] + + gui_config = { + "Motor 1": ["motor_1", "start_motor_1", "stop_motor_1", "center_1"], + "Motor 2": ["motor_2", "start_motor_2", "stop_motor_2", "center_2"], + "Shell Parameters": ["shell_spacing", "pos_in_first_ring"], + "Acquisition Parameters": [ + "exp_time", + "frames_per_trigger", + "settling_time", + "settling_time_after_trigger", + "readout_time", + "burst_at_each_point", + "relative", + ], + } + + def __init__( + self, + motor_1: DeviceBase, + start_motor_1: Annotated[ + float, ScanArgument(display_name="Start Position", reference_units="motor_1") + ], + stop_motor_1: Annotated[ + float, ScanArgument(display_name="Stop Position", reference_units="motor_1") + ], + motor_2: DeviceBase, + start_motor_2: Annotated[ + float, ScanArgument(display_name="Start Position", reference_units="motor_2") + ], + stop_motor_2: Annotated[ + float, ScanArgument(display_name="Stop Position", reference_units="motor_2") + ], + shell_spacing: Annotated[ + float, ScanArgument(display_name="Shell Spacing", reference_units="motor_1", gt=0) + ] = 1, + pos_in_first_ring: Annotated[ + int, ScanArgument(display_name="Number of Points in First Shell", ge=1) + ] = 5, + exp_time: Annotated[ + float, ScanArgument(display_name="Exposure Time", units=Units.s, ge=0) + ] = 0, + frames_per_trigger: Annotated[ + int, ScanArgument(display_name="Frames per Trigger", ge=1) + ] = 1, + settling_time: Annotated[ + float, ScanArgument(display_name="Settling Time", units=Units.s, ge=0) + ] = 0, + settling_time_after_trigger: Annotated[ + float, ScanArgument(display_name="Settling Time After Trigger", units=Units.s, ge=0) + ] = 0, + readout_time: Annotated[ + float, ScanArgument(display_name="Readout Time", units=Units.s, ge=0) + ] = 0, + relative: bool = False, + center_1: Annotated[ + float, ScanArgument(display_name="Center Motor 1", reference_units="motor_1") + ] = 0, + center_2: Annotated[ + float, ScanArgument(display_name="Center Motor 2", reference_units="motor_2") + ] = 0, + burst_at_each_point: Annotated[ + int, ScanArgument(display_name="Burst at Each Point", ge=1) + ] = 1, + **kwargs, + ): + """ + A scan following a round-roi-like pattern. + + Args: + motor_1 (DeviceBase): first motor + start_motor_1 (float): start position of the ROI for motor_1 + stop_motor_1 (float): stop position of the ROI for motor_1 + motor_2 (DeviceBase): second motor + start_motor_2 (float): start position of the ROI for motor_2 + stop_motor_2 (float): stop position of the ROI for motor_2 + shell_spacing (float): shell width. Default is 1. + pos_in_first_ring (int): number of points in the first shell. Default is 5. + exp_time (float): exposure time in seconds. Default is 0. + frames_per_trigger (int): number of frames acquired per trigger. Default is 1. + settling_time (float): settling time in seconds. Default is 0. + settling_time_after_trigger (float): settling time after trigger in seconds. Default is 0. + readout_time (float): readout time in seconds. Default is 0. + relative (bool): Start from an absolute or relative position. Default is False. + center_1 (float): center position for motor_1. The center may be outside the ROI. + Default is 0. + center_2 (float): center position for motor_2. The center may be outside the ROI. + Default is 0. + burst_at_each_point (int): number of acquisition per point. Default is 1. + + Returns: + ScanReport + + Examples: + >>> scans.round_roi_scan(dev.motor1, -10, 10, dev.motor2, -10, 10, shell_spacing=2, pos_in_first_ring=3, exp_time=0.1) + """ + super().__init__(**kwargs) + self.motors = [motor_1, motor_2] + self.start_motor_1 = start_motor_1 + self.stop_motor_1 = stop_motor_1 + self.start_motor_2 = start_motor_2 + self.stop_motor_2 = stop_motor_2 + self.center_1 = center_1 + self.center_2 = center_2 + self.shell_spacing = shell_spacing + self.pos_in_first_ring = pos_in_first_ring + self.relative = relative + self.exp_time = exp_time + self.settling_time = settling_time + self.burst_at_each_point = burst_at_each_point + + # Update the default scan info with provided parameters. + self.update_scan_info( + exp_time=exp_time, + frames_per_trigger=frames_per_trigger, + settling_time=settling_time, + settling_time_after_trigger=settling_time_after_trigger, + readout_time=readout_time, + relative=relative, + burst_at_each_point=burst_at_each_point, + scan_report_devices=self.motors, + ) + + # We elevate the readout priority of the scan motors to "monitored" to ensure + # that their positions are included in every readout of the step scan. + self.actions.set_device_readout_priority(self.motors, priority="monitored") + + @scan_hook + def prepare_scan(self): + """ + Prepare the scan. This can include any steps that need to be executed + before the scan is opened, such as preparing the positions (if not done already) + or setting up the devices. + """ + self.positions = position_generators.get_round_roi_scan_positions( + motor_1_start=self.start_motor_1, + motor_1_stop=self.stop_motor_1, + motor_2_start=self.start_motor_2, + motor_2_stop=self.stop_motor_2, + radial_step=self.shell_spacing, + points_in_first_shell=self.pos_in_first_ring, + center_1=self.center_1, + center_2=self.center_2, + ) + + if self.relative: + self.start_positions = self.components.get_start_positions(self.motors) + self.positions += self.start_positions + + self.components.check_limits(self.motors, self.positions) + + self.update_scan_info( + positions=self.positions, + num_points=len(self.positions), + num_monitored_readouts=len(self.positions) * self.burst_at_each_point, + ) + + self.actions.add_scan_report_instruction_scan_progress( + points=self.scan_info.num_monitored_readouts, show_table=False + ) + + self._premove_motor_status = self.actions.set(self.motors, self.positions[0], wait=False) + self._baseline_readout_status = self.actions.read_baseline_devices(wait=False) + + @scan_hook + def open_scan(self): + """ + Open the scan. + This step must call self.actions.open_scan() to ensure that a new scan is + opened. Make sure to prepare the scan metadata before, either in + prepare_scan() or in open_scan() itself and call self.update_scan_info(...) + to update the scan metadata if needed. + """ + self.actions.open_scan() + + @scan_hook + def stage(self): + """ + Stage the devices for the upcoming scan. The stage logic is typically + implemented on the device itself (i.e. by the device's stage method). + However, if there are any additional steps that need to be executed before + staging the devices, they can be implemented here. + """ + self.actions.stage_all_devices() + + @scan_hook + def pre_scan(self): + """ + Pre-scan steps to be executed before the main scan logic. + This is typically the last chance to prepare the devices before the core scan + logic is executed. For example, this is a good place to initialize time-criticial + devices, e.g. devices that have a short timeout. + The pre-scan logic is typically implemented on the device itself. + """ + self._premove_motor_status.wait() + self.actions.pre_scan_all_devices() + + @scan_hook + def scan_core(self): + """ + Core scan logic to be executed during the scan. + This is where the main scan logic should be implemented. + """ + self.components.step_scan( + self.motors, + self.positions, + at_each_point=self.at_each_point, + last_positions=self.positions[0], + ) + + @scan_hook + def at_each_point( + self, + motors: list[str | DeviceBase], + positions: np.ndarray, + last_positions: np.ndarray | None, + ): + """ + Logic to be executed at each point during the scan. This is called by the step_scan method at each point. + + Args: + motors (list[str | DeviceBase]): List of motor names or device instances being moved. + positions (np.ndarray): Current positions of the motors, shape (len(motors),). + last_positions (np.ndarray | None): Previous positions of the motors, shape + (len(motors),) or None if this is the first point. + """ + self.components.step_scan_at_each_point(motors, positions, last_positions=last_positions) + + @scan_hook + def post_scan(self): + """ + Post-scan steps to be executed after the main scan logic. + """ + status = self.actions.complete_all_devices(wait=False) + + if self.relative: + # Move the motors back to their starting position + self.components.move_and_wait(self.motors, self.start_positions) + status.wait() + + @scan_hook + def unstage(self): + """Unstage the scan by executing post-scan steps.""" + self.actions.unstage_all_devices() + + @scan_hook + def close_scan(self): + """Close the scan.""" + if self._baseline_readout_status is not None: + self._baseline_readout_status.wait() + self.actions.close_scan() + self.actions.check_for_unchecked_statuses() + + @scan_hook + def on_exception(self, exception: Exception): + """ + Handle exceptions that occur during the scan. + This is a good place to implement any cleanup logic that needs to be executed in case of an exception, + such as returning the devices to a safe state or moving the motors back to their starting position. + """ + if self.relative: + # Move the motors back to their starting position + self.components.move_and_wait(self.motors, self.start_positions) diff --git a/bec_server/bec_server/scan_server/scans/round_scan.py b/bec_server/bec_server/scan_server/scans/round_scan.py new file mode 100644 index 000000000..5320373e0 --- /dev/null +++ b/bec_server/bec_server/scan_server/scans/round_scan.py @@ -0,0 +1,294 @@ +""" +Round scan implementation for circular two-motor step trajectories. + +Scan procedure: + - prepare_scan + - open_scan + - stage + - pre_scan + - scan_core + - at_each_point (optionally called by scan_core) + - post_scan + - unstage + - close_scan + - on_exception (called if any exception is raised during the scan) +""" + +from __future__ import annotations + +from typing import Annotated + +import numpy as np + +from bec_lib.device import DeviceBase +from bec_lib.scan_args import ScanArgument, Units +from bec_server.scan_server.scans import position_generators +from bec_server.scan_server.scans.scan_modifier import scan_hook +from bec_server.scan_server.scans.scans_v4 import ScanBase, ScanType + + +class RoundScan(ScanBase): + # Scan Type: Hardware triggered or software triggered? + # If the main trigger and readout logic is done within the at_each_point method in scan_core, choose SOFTWARE_TRIGGERED. + # If the main trigger and readout logic is implemented on a device that is simply kicked off in this scan, choose HARDWARE_TRIGGERED. + # This primarily serves as information for devices: The device may need to react differently if a software trigger is expected + # for every point. + scan_type = ScanType.SOFTWARE_TRIGGERED + + # Scan name: This is the name of the scan, e.g. "line_scan". This is used for display purposes and to identify the scan type in user interfaces. + # Choose a descriptive name that does not conflict with existing scan names. + scan_name = "_v4_round_scan" + required_kwargs = ["relative"] + + gui_config = { + "Motors": ["motor_1", "motor_2"], + "Ring Parameters": [ + "inner_radius", + "outer_radius", + "center_1", + "center_2", + "number_of_rings", + "pos_in_first_ring", + ], + "Scan Parameters": ["relative", "burst_at_each_point"], + "Acquisition Parameters": [ + "exp_time", + "frames_per_trigger", + "settling_time", + "settling_time_after_trigger", + "readout_time", + "burst_at_each_point", + ], + } + + def __init__( + self, + motor_1: DeviceBase, + motor_2: DeviceBase, + inner_radius: Annotated[ + float, ScanArgument(display_name="Inner Radius", reference_units="motor_1", ge=0) + ], + outer_radius: Annotated[ + float, ScanArgument(display_name="Outer Radius", reference_units="motor_1", ge=0) + ], + number_of_rings: Annotated[int, ScanArgument(display_name="Number of Rings", ge=1)], + pos_in_first_ring: Annotated[ + int, ScanArgument(display_name="Positions in First Ring", ge=1) + ], + exp_time: Annotated[ + float, ScanArgument(display_name="Exposure Time", units=Units.s, ge=0) + ] = 0, + frames_per_trigger: Annotated[ + int, ScanArgument(display_name="Frames per Trigger", ge=1) + ] = 1, + settling_time: Annotated[ + float, ScanArgument(display_name="Settling Time", units=Units.s, ge=0) + ] = 0, + settling_time_after_trigger: Annotated[ + float, ScanArgument(display_name="Settling Time After Trigger", units=Units.s, ge=0) + ] = 0, + readout_time: Annotated[ + float, ScanArgument(display_name="Readout Time", units=Units.s, ge=0) + ] = 0, + relative: bool = False, + center_1: Annotated[ + float, ScanArgument(display_name="Center Motor 1", reference_units="motor_1") + ] = 0, + center_2: Annotated[ + float, ScanArgument(display_name="Center Motor 2", reference_units="motor_2") + ] = 0, + burst_at_each_point: Annotated[ + int, ScanArgument(display_name="Burst at Each Point", ge=1) + ] = 1, + **kwargs, + ): + """ + A scan following a round shell-like pattern with increasing number of points in each ring. The scan starts at the inner ring and moves outwards. + The user defines the inner and outer radius, the number of rings and the number of positions in the first ring. + + Args: + motor_1 (DeviceBase): first motor + motor_2 (DeviceBase): second motor + inner_radius (float): inner radius + outer_radius (float): outer radius + number_of_rings (int): number of rings + pos_in_first_ring (int): number of positions in the first ring + exp_time (Annotated[float, Units.s]): exposure time in seconds. Default is 0. + frames_per_trigger (Annotated[int]): number of frames acquired per trigger. Default is 1. + settling_time (Annotated[float, Units.s]): settling time in seconds. Default is 0. + settling_time_after_trigger (Annotated[float, Units.s]): settling time after trigger in seconds. Default is 0. + readout_time (Annotated[float, Units.s]): readout time in seconds. Default is 0. + relative (bool): if True, the motors will be moved relative to their current position. Default is False. + center_1 (float): center position for motor_1. Default is 0. + center_2 (float): center position for motor_2. Default is 0. + burst_at_each_point (int): number of exposures at each point. Default is 1. + + Returns: + ScanReport + + Examples: + >>> scans.round_scan(dev.motor1, dev.motor2, 0, 25, 5, 3, exp_time=0.1, relative=True) + """ + super().__init__(**kwargs) + self.motors = [motor_1, motor_2] + self.inner_radius = inner_radius + self.outer_radius = outer_radius + self.number_of_rings = number_of_rings + self.pos_in_first_ring = pos_in_first_ring + self.center_1 = center_1 + self.center_2 = center_2 + self.relative = relative + self.exp_time = exp_time + self.settling_time = settling_time + self.burst_at_each_point = burst_at_each_point + + # Update the default scan info with provided parameters. + self.update_scan_info( + exp_time=exp_time, + frames_per_trigger=frames_per_trigger, + settling_time=settling_time, + settling_time_after_trigger=settling_time_after_trigger, + readout_time=readout_time, + relative=relative, + burst_at_each_point=burst_at_each_point, + scan_report_devices=self.motors, + ) + + # We elevate the readout priority of the scan motors to "monitored" to ensure + # that their positions are included in every readout of the step scan. + self.actions.set_device_readout_priority(self.motors, priority="monitored") + + @scan_hook + def prepare_scan(self): + """ + Prepare the scan. This can include any steps that need to be executed + before the scan is opened, such as preparing the positions (if not done already) + or setting up the devices. + """ + self.positions = position_generators.round_scan_positions( + inner_radius=self.inner_radius, + outer_radius=self.outer_radius, + number_of_rings=self.number_of_rings, + points_in_first_ring=self.pos_in_first_ring, + center_1=self.center_1, + center_2=self.center_2, + ) + + if self.relative: + self.start_positions = self.components.get_start_positions(self.motors) + self.positions += self.start_positions + + self.components.check_limits(self.motors, self.positions) + + self.update_scan_info( + positions=self.positions, + num_points=len(self.positions), + num_monitored_readouts=len(self.positions) * self.burst_at_each_point, + ) + + self.actions.add_scan_report_instruction_scan_progress( + points=self.scan_info.num_monitored_readouts, show_table=False + ) + + self._premove_motor_status = self.actions.set(self.motors, self.positions[0], wait=False) + self._baseline_readout_status = self.actions.read_baseline_devices(wait=False) + + @scan_hook + def open_scan(self): + """ + Open the scan. + This step must call self.actions.open_scan() to ensure that a new scan is + opened. Make sure to prepare the scan metadata before, either in + prepare_scan() or in open_scan() itself and call self.update_scan_info(...) + to update the scan metadata if needed. + """ + self.actions.open_scan() + + @scan_hook + def stage(self): + """ + Stage the devices for the upcoming scan. The stage logic is typically + implemented on the device itself (i.e. by the device's stage method). + However, if there are any additional steps that need to be executed before + staging the devices, they can be implemented here. + """ + self.actions.stage_all_devices() + + @scan_hook + def pre_scan(self): + """ + Pre-scan steps to be executed before the main scan logic. + This is typically the last chance to prepare the devices before the core scan + logic is executed. For example, this is a good place to initialize time-criticial + devices, e.g. devices that have a short timeout. + The pre-scan logic is typically implemented on the device itself. + """ + self._premove_motor_status.wait() + self.actions.pre_scan_all_devices() + + @scan_hook + def scan_core(self): + """ + Core scan logic to be executed during the scan. + This is where the main scan logic should be implemented. + """ + self.components.step_scan( + self.motors, + self.positions, + at_each_point=self.at_each_point, + last_positions=self.positions[0], + ) + + @scan_hook + def at_each_point( + self, + motors: list[str | DeviceBase], + positions: np.ndarray, + last_positions: np.ndarray | None, + ): + """ + Logic to be executed at each point during the scan. This is called by the step_scan method at each point. + + Args: + motors (list[str | DeviceBase]): List of motor names or device instances being moved. + positions (np.ndarray): Current positions of the motors, shape (len(motors),). + last_positions (np.ndarray | None): Previous positions of the motors, shape + (len(motors),) or None if this is the first point. + """ + self.components.step_scan_at_each_point(motors, positions, last_positions=last_positions) + + @scan_hook + def post_scan(self): + """ + Post-scan steps to be executed after the main scan logic. + """ + status = self.actions.complete_all_devices(wait=False) + + if self.relative: + # Move the motors back to their starting position + self.components.move_and_wait(self.motors, self.start_positions) + status.wait() + + @scan_hook + def unstage(self): + """Unstage the scan by executing post-scan steps.""" + self.actions.unstage_all_devices() + + @scan_hook + def close_scan(self): + """Close the scan.""" + if self._baseline_readout_status is not None: + self._baseline_readout_status.wait() + self.actions.close_scan() + self.actions.check_for_unchecked_statuses() + + @scan_hook + def on_exception(self, exception: Exception): + """ + Handle exceptions that occur during the scan. + This is a good place to implement any cleanup logic that needs to be executed in case of an exception, + such as returning the devices to a safe state or moving the motors back to their starting position. + """ + if self.relative: + # Move the motors back to their starting position + self.components.move_and_wait(self.motors, self.start_positions) diff --git a/bec_server/bec_server/scan_server/scans/scan_actions.py b/bec_server/bec_server/scan_server/scans/scan_actions.py new file mode 100644 index 000000000..077d0c709 --- /dev/null +++ b/bec_server/bec_server/scan_server/scans/scan_actions.py @@ -0,0 +1,1215 @@ +from __future__ import annotations + +import os +import time +import uuid +from string import Template +from typing import TYPE_CHECKING, Any, Callable, Literal, TypeAlias + +import numpy as np + +from bec_lib import messages +from bec_lib.alarm_handler import Alarms +from bec_lib.device import DeviceBase +from bec_lib.endpoints import MessageEndpoints +from bec_lib.file_utils import compile_file_components +from bec_lib.logger import bec_logger +from bec_server.scan_server.scan_stubs import ScanStubStatus + +if TYPE_CHECKING: + from bec_server.scan_server.scans.scans_v4 import ScanBase, ScanInfo + +logger = bec_logger.logger + +ReadoutPriorityMap: TypeAlias = dict[ + Literal["monitored", "baseline", "async", "continuous", "on_request"], list[str] +] + + +class ScanActions: + """Class to handle the core actions for the scan logic.""" + + def __init__(self, scan: ScanBase): + self._scan = scan + self._connector = scan.redis_connector + self._device_manager = scan.device_manager + self._instruction_handler = scan._instruction_handler + self._status_registry = {} + self._shutdown_event = scan._shutdown_event + self._num_monitored_readouts = 0 + self._interruption_callback: Callable[[], None] | None = None + self._update_queue_info_callback: Callable[[], None] | None = None + self._devices_with_required_response = set() + self._readout_groups_read = False + self._metadata_suffix = "" + + @property + def readout_priority(self) -> dict: + return self._scan.scan_info.readout_priority_modification + + def open_scan(self): + """ + Open the scan. + We fetch all relevant metadata from the scan object and emit a new scan status. + """ + self._send_scan_status("open") + + def stage_all_devices( + self, wait=True, exclude: str | DeviceBase | list[str | DeviceBase] | None = None + ) -> ScanStubStatus: + """ + Stage all devices for the scan. This will call the "stage" method + on all devices. + + If you want to stage only specific devices, use the "stage" method. + + .. note :: + We exclude devices that are on_request or continuous as they are not expected to be staged for a scan. + + Args: + wait (bool, optional): if True, wait for the staging to complete. Defaults to True. + exclude (str | DeviceBase | list[str | DeviceBase] | None, optional): + device(s) to exclude from staging. Defaults to None. + + Returns: + ScanStubStatus: status object to track the staging process + """ + status = self._create_status(is_container=True, name="stage_all_devices") + + # We separate the staging of async devices and regular devices to optimize the staging process. + # Async devices are typically slower to stage and should be staged in parallel. + async_devices = self._device_manager.devices.async_devices( + readout_priority=self.readout_priority + ) + excluded_devices = [device.name for device in async_devices] + excluded_devices.extend( + device.name + for device in self._device_manager.devices.on_request_devices( + readout_priority=self.readout_priority + ) + ) + excluded_devices.extend( + device.name + for device in self._device_manager.devices.continuous_devices( + readout_priority=self.readout_priority + ) + ) + excluded_device_names = set(excluded_devices) + user_excluded_device_names = set() + if exclude is not None: + user_excluded_device_names = set(self._normalize_device_names(exclude)) + excluded_device_names.update(user_excluded_device_names) + + if async_devices: + async_devices = sorted(async_devices, key=lambda x: x.name) + async_devices = [ + device for device in async_devices if device.name not in user_excluded_device_names + ] + + for det in async_devices: + sub_status = self.stage(det, status_name=f"stage_{det.name}", wait=False) + status.add_status(sub_status) + + # Now we stage the remaining devices. This will be done sequentially, assuming that + # they are typically no-op or fast operations. + stage_device_names_without_async = [ + dev.root.name + for dev in self._device_manager.devices.enabled_devices + if dev.name not in excluded_device_names + ] + + if stage_device_names_without_async: + sub_status = self.stage( + stage_device_names_without_async, status_name="stage_sync_devices", wait=False + ) + status.add_status(sub_status) + if wait: + status.wait() + return status + + def stage( + self, + device: str | DeviceBase | list[str | DeviceBase], + status_name: str | None = None, + wait=True, + ) -> ScanStubStatus: + """ + Stage a device for the scan. This will call the "stage" method + on the specified device(s). + + If you want to stage all devices, use the `stage_all_devices` method. + + Args: + device (str or DeviceBase or list[str or DeviceBase]): device(s) to stage + status_name (str, optional): name for the status object. Defaults to None. + wait (bool, optional): if True, wait for the staging to complete. Defaults to True. + + Returns: + ScanStubStatus: status object to track the staging process + """ + + # We support str and DeviceBase inputs as well as lists of those. + # We convert them to a list of device names for easier processing. + if isinstance(device, list): + device_names = [] + for dev in device: + if isinstance(dev, DeviceBase): + device_names.append(dev.name) + else: + device_names.append(dev) + else: + device_names = [device.name if isinstance(device, DeviceBase) else device] + if len(device_names) == 1: + device_names = device_names[0] + status = self._create_status(name=status_name or f"stage_{device_names}") + + # If there are no devices to stage, we can immediately set the status to done and return. + if len(device_names) == 0: + status.set_done() + return status + + instr = messages.DeviceInstructionMessage( + device=device_names, + action="stage", + parameter={}, + metadata={"device_instr_id": status._device_instr_id}, + ) + self._send(instr) + if wait: + status.wait() + return status + + def pre_scan( + self, + device: str | DeviceBase | list[str | DeviceBase], + status_name: str | None = None, + wait=True, + ) -> ScanStubStatus: + """ + Run the pre-scan step for one or multiple devices. + + If you want to run pre-scan on all enabled devices, use the + `pre_scan_all_devices` method. + + Args: + device (str | DeviceBase | list[str | DeviceBase]): device(s) to run pre-scan for. + status_name (str, optional): name for the status object. Defaults to None. + wait (bool, optional): if True, wait for completion. Defaults to True. + + Returns: + ScanStubStatus: status object to track the pre-scan process. + """ + device_names = self._normalize_device_names(device) + if len(device_names) == 1: + device_names = device_names[0] + status = self._create_status(name=status_name or f"pre_scan_{device_names}") + + if len(device_names) == 0: + status.set_done() + return status + + instr = messages.DeviceInstructionMessage( + device=device_names, + action="pre_scan", + parameter={}, + metadata={"device_instr_id": status._device_instr_id}, + ) + self._send(instr) + if wait: + status.wait() + return status + + def pre_scan_all_devices( + self, wait=True, exclude: str | DeviceBase | list[str | DeviceBase] | None = None + ) -> ScanStubStatus: + """ + Pre-scan steps to be executed before the main scan logic. This will call + the "pre_scan" method all devices that implement it. + + This is typically the last chance to prepare the devices before the core scan + logic is executed. For example, this is a good place to initialize time-critical + devices, e.g. devices that have a short timeout. + + Args: + wait (bool, optional): if True, wait for the pre-scan steps to complete. Defaults to True. + exclude (str | DeviceBase | list[str | DeviceBase] | None, optional): + device(s) to exclude from pre-scan. Defaults to None. + + Returns: + ScanStubStatus: status object to track the pre-scan process + """ + status = self._create_status(name="pre_scan_all_devices") + + devices = [dev.root.name for dev in self._device_manager.devices.enabled_devices] + if exclude is not None: + excluded_device_names = set(self._normalize_device_names(exclude)) + devices = [ + device_name for device_name in devices if device_name not in excluded_device_names + ] + if devices: + devices = sorted(devices) + + instr = messages.DeviceInstructionMessage( + device=devices, + action="pre_scan", + parameter={}, + metadata={"device_instr_id": status._device_instr_id}, + ) + self._send(instr) + if wait: + status.wait() + return status + + def set( + self, + device: str | DeviceBase | list[str | DeviceBase] | list[str] | list[DeviceBase], + value: float | list[float], + wait=True, + ) -> ScanStubStatus: + """ + Set one or multiple devices to specific values. This will call the "set" method + on the specified device(s) with the given value(s). + + Args: + device (str or DeviceBase or list[str or DeviceBase] or list[str] or list[DeviceBase]): device(s) to set + value (float or list[float]): target value(s) for the device(s) + wait (bool, optional): if True, wait for the set operation to complete. Defaults to True. + + Returns: + ScanStubStatus: status object to track the set process + """ + devices = device if isinstance(device, list) else [device] + values = value.tolist() if isinstance(value, np.ndarray) else value + values = values if isinstance(values, list) else [values] + + if len(devices) != len(values): + raise ValueError("The number of devices and values must match.") + + status = self._create_status(is_container=True, name="set") + for dev, val in zip(devices, values, strict=False): + device_name = dev.name if isinstance(dev, DeviceBase) else dev + sub_status = self._create_status(name=f"set_{device_name}") + instr = messages.DeviceInstructionMessage( + device=device_name, + action="set", + parameter={"value": val}, + metadata={"device_instr_id": sub_status._device_instr_id}, + ) + self._send(instr) + status.add_status(sub_status) + + if wait: + status.wait() + return status + + def kickoff( + self, device: str | DeviceBase, parameters: dict | None = None, wait=True + ) -> ScanStubStatus: + """ + Kickoff a device with the given parameters. This will call the + "kickoff" method on the specified device with the given parameters. + + Args: + device (str or DeviceBase): device to kickoff + parameters (dict, optional): parameters for the kickoff. Defaults to None. + wait (bool, optional): if True, wait for the kickoff to complete. Defaults to True. + + Returns: + ScanStubStatus: status object to track the kickoff process + """ + device_name = device.name if isinstance(device, DeviceBase) else device + status = self._create_status(name=f"kickoff_{device_name}") + + instr = messages.DeviceInstructionMessage( + device=device_name, + action="kickoff", + parameter={"configure": parameters or {}}, + metadata={"device_instr_id": status._device_instr_id}, + ) + self._send(instr) + if wait: + status.wait() + return status + + def complete(self, device: str | DeviceBase, wait=True) -> ScanStubStatus: + """ + Complete a device. This will call the "complete" method on the device. + + To complete all devices, use the `complete_all_devices` method. + + Args: + device (str or DeviceBase): device to complete + wait (bool, optional): if True, wait for the completion to complete. Defaults to True. + + Returns: + ScanStubStatus: status object to track the completion process + """ + device_name = device.name if isinstance(device, DeviceBase) else device + status = self._create_status(name=f"complete_{device_name}") + + instr = messages.DeviceInstructionMessage( + device=device_name, + action="complete", + parameter={}, + metadata={"device_instr_id": status._device_instr_id}, + ) + self._send(instr) + if wait: + status.wait() + return status + + def complete_all_devices( + self, wait=True, exclude: str | DeviceBase | list[str | DeviceBase] | None = None + ) -> ScanStubStatus: + """ + Complete all devices for the scan. This will call the + "complete" method on all devices that are enabled for the scan. + + If you want to complete only specific devices, use the `complete` method. + + Args: + wait (bool, optional): if True, wait for the completion to complete. Defaults to True. + exclude (str | DeviceBase | list[str | DeviceBase] | None, optional): + device(s) to exclude from completion. Defaults to None. + + Returns: + ScanStubStatus: status object to track the completion process + """ + status = self._create_status(name="complete_all_devices") + device_names = [dev.root.name for dev in self._device_manager.devices.enabled_devices] + if exclude is not None: + excluded_device_names = set(self._normalize_device_names(exclude)) + device_names = [ + device_name + for device_name in device_names + if device_name not in excluded_device_names + ] + instr = messages.DeviceInstructionMessage( + device=device_names, + action="complete", + parameter={}, + metadata={"device_instr_id": status._device_instr_id}, + ) + self._send(instr) + if wait: + status.wait() + return status + + def read_monitored_devices(self, wait=True) -> ScanStubStatus: + """ + Read from the monitored devices. This will call the "read" method on + all devices that are currently configured with readout priority "monitored". + + Args: + wait (bool, optional): if True, wait for the read to complete. Defaults to True. + + Returns: + ScanStubStatus: status object to track the read process + """ + # We set a flag to indicate that we triggered the monitored devices. + # This is used to raise a warning if the scan definition tries to modify the + # readout groups after the monitored devices were read, which could lead to unexpected behavior. + self._readout_groups_read = True + + status = self._create_status(name="read_monitored_devices") + monitored_devices = [ + _dev.root.name + for _dev in self._device_manager.devices.monitored_devices( + readout_priority=self.readout_priority + ) + ] + if not monitored_devices: + status.set_done() + status.set_done_checked() + return status + monitored_devices = sorted(monitored_devices) + instr = messages.DeviceInstructionMessage( + device=monitored_devices, + action="read", + parameter={}, + metadata={ + "device_instr_id": status._device_instr_id, + "point_id": self._num_monitored_readouts, + }, + ) + self._send(instr) + self._num_monitored_readouts += 1 + if wait: + status.wait() + return status + + def read_manually( + self, devices: str | DeviceBase | list[str | DeviceBase], wait=True + ) -> Any | ScanStubStatus: + """ + Read the given devices and return the read data. This will call the + "read" method on the specified device(s). + + This action performs a regular device-server read and asks the device server + to include the read result in the instruction response. If ``wait`` is + False, the status object is returned instead of the read data. + + .. note :: + Reading manually is rarely the right choice; in almost all cases, + :meth:`read_monitored_devices` is the preferred and optimized action because it lets + the device server read and publish the monitored devices directly. Use ``read_manually`` + only when you need to intercept the read data for some reason before it is published and + cannot implement the interception on the device. + + Args: + devices (str | DeviceBase | list[str | DeviceBase]): device(s) to read. + wait (bool, optional): if True, wait for the read and return the read data. Defaults to True. + + Returns: + Any | ScanStubStatus: read data when ``wait`` is True, otherwise the status object. + """ + device_names = self._normalize_device_names(devices) + status = self._create_status(name=f"read_manually_{device_names}") + if not device_names: + status.set_done([]) + status.set_done_checked() + return status.result if wait else status + + instr = messages.DeviceInstructionMessage( + device=sorted(device_names), + action="read", + parameter={"return_result": True}, + metadata={"device_instr_id": status._device_instr_id}, + ) + self._send(instr) + if not wait: + return status + status.wait() + return status.result + + def publish_manual_read( + self, readings: dict[str, dict] | list[dict], wait=True + ) -> ScanStubStatus: + """ + Publish externally provided data as the next monitored-device readout. + + The provided readings must comply with the scan's currently configured + monitored devices. In almost all cases, :meth:`read_monitored_devices` is + the preferred and optimized action because it lets the device server read + and publish the monitored devices directly. ``publish_manual_read`` is + rarely the right choice; use it only when the scan has already acquired + equivalent monitored-device data manually and must attach that data to the + next scan point. + + Args: + readings (dict[str, dict] | list[dict]): readings for the currently + monitored devices. Dict keys must match the monitored device names. + A list may be provided either in monitored-device order or as + single-key dictionaries keyed by device name. + wait (bool, optional): retained for API consistency. Publishing is synchronous. + Defaults to True. + + Returns: + ScanStubStatus: status object to track the publish process. + """ + self._readout_groups_read = True + monitored_devices = self._get_monitored_device_names() + normalized_readings = self._normalize_manual_readings(readings, monitored_devices) + self._validate_manual_reading_signals(normalized_readings, monitored_devices) + + status = self._create_status(name="publish_manual_read") + if not monitored_devices: + status.set_done() + status.set_done_checked() + return status + + metadata = self._get_message_metadata() + metadata["point_id"] = self._num_monitored_readouts + if self._interruption_callback is not None: + self._interruption_callback() + pipe = self._connector.pipeline() + for device, signals in zip(monitored_devices, normalized_readings, strict=False): + msg = messages.DeviceMessage(signals=signals, metadata=metadata) + self._connector.set_and_publish(MessageEndpoints.device_read(device), msg, pipe=pipe) + pipe.execute() + self._num_monitored_readouts += 1 + status.set_done() + status.set_done_checked() + return status + + def read_baseline_devices(self, wait=True) -> ScanStubStatus: + """ + Read from the baseline devices. This will call the "read" method on all devices + that are configured with readout priority "baseline". + + Args: + wait (bool, optional): if True, wait for the read to complete. Defaults to True. + + Returns: + ScanStubStatus: status object to track the read process + """ + # We set a flag to indicate that we triggered the baseline devices + # This is used to raise a warning if the scan definition tries to modify the + # readout groups after the baseline devices were read, which could lead to unexpected behavior. + self._readout_groups_read = True + + status = self._create_status(name="read_baseline_devices") + baseline_devices = [ + _dev.root.name + for _dev in self._device_manager.devices.baseline_devices( + readout_priority=self.readout_priority + ) + ] + if not baseline_devices: + status.set_done() + status.set_done_checked() + return status + baseline_devices = sorted(baseline_devices) + instr = messages.DeviceInstructionMessage( + device=baseline_devices, + action="read", + parameter={}, + metadata={"device_instr_id": status._device_instr_id, "readout_priority": "baseline"}, + ) + self._send(instr) + if wait: + status.wait() + return status + + def trigger_all_devices(self, min_wait: float | None = None, wait=True) -> ScanStubStatus: + """ + Trigger all devices for the scan. The list of devices to trigger is determined automatically + based on their softwareTrigger configuration. + This will call the "trigger" method on all devices that are configured to be triggered for the scan. + + Args: + min_wait (float, optional): minimum time to wait before the trigger is executed. This can be used to ensure that the system has settled before the trigger is executed. Defaults to None. + wait (bool, optional): if True, wait for the trigger to complete. Defaults to True. + """ + status = self._create_status(name="trigger_all_devices") + devices = [ + dev.root.name for dev in self._device_manager.devices.get_software_triggered_devices() + ] + if not devices: + status.set_done() + status.set_done_checked() + return status + + devices = sorted(devices) + instr = messages.DeviceInstructionMessage( + device=devices, + action="trigger", + parameter={}, + metadata={"device_instr_id": status._device_instr_id}, + ) + self._send(instr) + if min_wait is not None: + time.sleep(min_wait) + if wait: + status.wait() + return status + + def unstage(self, device: str | DeviceBase, wait=True) -> ScanStubStatus: + """ + Unstage a device for the scan. This will call the "unstage" method on the specified device(s). + + If you want to unstage all devices, use the `unstage_all_devices` method. + + Args: + device (str or DeviceBase): device to unstage + wait (bool, optional): if True, wait for the unstaging to complete. Defaults to True. + + Returns: + ScanStubStatus: status object to track the unstaging process + """ + device_name = device.name if isinstance(device, DeviceBase) else device + status = self._create_status(name=f"unstage_{device_name}") + + instr = messages.DeviceInstructionMessage( + device=device_name, + action="unstage", + parameter={}, + metadata={"device_instr_id": status._device_instr_id}, + ) + self._send(instr) + if wait: + status.wait() + return status + + def unstage_all_devices( + self, wait=True, exclude: str | DeviceBase | list[str | DeviceBase] | None = None + ) -> ScanStubStatus: + """ + Unstage all devices for the scan. This will call the "unstage" method on all devices. + + If you want to unstage only specific devices, use the "unstage" method. + + Args: + wait (bool, optional): if True, wait for the unstaging to complete. Defaults to True. + exclude (str | DeviceBase | list[str | DeviceBase] | None, optional): + device(s) to exclude from unstaging. Defaults to None. + """ + + status = self._create_status(name="unstage_all_devices") + staged_devices = [dev.root.name for dev in self._device_manager.devices.enabled_devices] + if exclude is not None: + excluded_device_names = set(self._normalize_device_names(exclude)) + staged_devices = [ + device_name + for device_name in staged_devices + if device_name not in excluded_device_names + ] + instr = messages.DeviceInstructionMessage( + device=staged_devices, + action="unstage", + parameter={}, + metadata={"device_instr_id": status._device_instr_id}, + ) + self._send(instr) + if wait: + status.wait() + return status + + def add_scan_report_instruction_readback( + self, + devices: list[str | DeviceBase], + start: list[float], + stop: list[float], + request_id: str | None = None, + ): + """ + Add a readback report instruction to the instruction handler. + Readback instructions allow clients to subscribe to the readback of the given devices + and show a live update of their position during the scan as a progress bar. + + Args: + devices (list[str | DeviceBase]): list of device names or DeviceBase instances to report + start (list[float]): list of start positions for the devices + stop (list[float]): list of stop positions for the devices + request_id (str, optional): request ID to associate the readback instruction with. If None, the scan's RID will be used. Defaults to None. + """ + request_id = request_id or self._scan.scan_info.metadata["RID"] + device_names = [dev.name if isinstance(dev, DeviceBase) else dev for dev in devices] + scan_report_instruction = { + "readback": {"RID": request_id, "devices": device_names, "start": start, "end": stop} + } + self.add_device_with_required_response(device_names) + self._scan.scan_info.scan_report_instructions.append(scan_report_instruction) + if self._update_queue_info_callback is not None: + self._update_queue_info_callback() + + def add_scan_report_instruction_device_progress(self, device: str | DeviceBase): + """ + Add a device progress report instruction to the instruction handler. + Device progress instructions allow clients to subscribe to the progress signal of the given device + and show a live update of the progress during the scan as a progress bar. + + Args: + device (str | DeviceBase): name of the device or DeviceBase instance to report + """ + if isinstance(device, DeviceBase): + device_name = device.name + else: + device_name = device + scan_report_instruction = {"device_progress": [device_name]} + self._scan.scan_info.scan_report_instructions.append(scan_report_instruction) + if self._update_queue_info_callback is not None: + self._update_queue_info_callback() + + def add_scan_report_instruction_scan_progress(self, points: int = 0, show_table: bool = True): + """ + Add a scan progress report instruction to the instruction handler. + Scan progress instructions inform clients to print a table-like report of the scan progress. + If you don't know the number of points in advance, you can set points to 0. The progressbar will + not be able to estimate the remaining time in this case, but it will still show the elapsed time and the number of points completed. + + Args: + points (int, optional): total number of points in the scan, used to calculate the progress percentage. Defaults to 0. + show_table (bool, optional): if True, show a progress table with estimated time remaining. Defaults to True. + """ + scan_report_instruction = {"scan_progress": {"points": points, "show_table": show_table}} + self._scan.scan_info.scan_report_instructions.append(scan_report_instruction) + if self._update_queue_info_callback is not None: + self._update_queue_info_callback() + + def set_device_readout_priority( + self, + devices: list[DeviceBase] | list[str], + priority: Literal["baseline", "monitored", "on_request", "async"], + ): + """ + Set the readout priority for the given devices. This will determine when the devices are read out during the scan. + The provided list of devices is a modification to the existing readout priority. + + Adding device A that is by default a baseline device to priority "monitored" will move it from the baseline + readout to the monitored readout. All other devices will keep their default readout priority. + This method is particularly useful for adding scan motors to the monitored readouts so that their positions + are included in the scan report for each point. + + Args: + devices (list[str | DeviceBase]): List of device names or DeviceBase instances to set the readout priority for. + priority (str): Readout priority to set for the devices. Should be one of "baseline", "monitored", "on_request", or "async". + """ + if self._readout_groups_read: + msg = f"Warning: Modifying readout groups after they have been read can lead to unexpected behavior. Devices: {devices}, Priority: {priority}" + error_info = messages.ErrorInfo( + error_message=msg, + compact_error_message=msg, + exception_type="ReadoutGroupModificationWarning", + device=None, + ) + self._connector.raise_alarm(severity=Alarms.WARNING, info=error_info) + + if not isinstance(devices, list): + devices = [devices] + + for device in devices: + if isinstance(device, DeviceBase): + device_name = device.name + else: + device_name = device + self._scan.scan_info.readout_priority_modification[priority].append(device_name) + + def close_scan(self): + """Close the scan.""" + # We set the number of monitored readouts to the actual number of monitored + # readouts that were triggered during the scan. It will be broadcasted with + # the next scan status. + self._scan.scan_info.num_monitored_readouts = self._num_monitored_readouts + + self.check_for_unchecked_statuses() + + self._send_scan_status("closed") + + def check_for_unchecked_statuses(self): + """ + Check if there are any unchecked status objects left. + Their done status was not checked nor were they waited for. + While this is not an error, it is a warning that the scan + might not have completed as expected. + """ + + unchecked_status_objects = self._get_remaining_status_objects( + exclude_done=False, exclude_checked=True + ) + if unchecked_status_objects: + msg = f"Scan completed with unchecked status objects: {unchecked_status_objects}. Use .wait() or .done within the scan to check their status." + error_info = messages.ErrorInfo( + error_message=msg, + compact_error_message=msg, + exception_type="UncheckedStatusObjectsWarning", + device=None, + ) + self._connector.raise_alarm(severity=Alarms.WARNING, info=error_info) + + # Check if there are any remaining status objects that are not done. + # This is not an error but we send a warning and wait for them to complete. + remaining_status_objects = self._get_remaining_status_objects( + exclude_done=True, exclude_checked=False + ) + if remaining_status_objects: + msg = f"Scan completed with remaining status objects: {remaining_status_objects}" + error_info = messages.ErrorInfo( + error_message=msg, + compact_error_message=msg, + exception_type="ScanCleanupWarning", + device=None, + ) + self._connector.raise_alarm(severity=Alarms.WARNING, info=error_info) + for obj in remaining_status_objects: + obj.wait() + + def add_device_with_required_response( + self, device: str | DeviceBase | list[DeviceBase] | list[str] + ): + """ + Add a device to the set of devices with required response. + If a device is in this set, an additional "response" flag will be added to the metadata of the device instruction messages for this device. + The device server will then include a "response" message in the instruction response for this device, + which enabled clients to listen to the completion of the instruction more easily. + + If you are unsure whether a device needs to be added to this set, you probably don't need it. + It is mostly relevant for the simple mv and umv scans. + + Args: + device (str or DeviceBase or list[DeviceBase] or list[str]): device(s) to add to the set of devices with required response + """ + if isinstance(device, list): + for dev in device: + device_name = dev.name if isinstance(dev, DeviceBase) else dev + self._devices_with_required_response.add(device_name) + else: + device_name = device.name if isinstance(device, DeviceBase) else device + self._devices_with_required_response.add(device_name) + + def rpc_call(self, device: str, func_name: str, *args, **kwargs) -> Any | ScanStubStatus: + """ + Make an RPC call to a device. This will call the given function on the device with the given arguments. + The device server will execute the function and return the result in the instruction response. + This method is a low-level interface to call arbitrary functions on the device server and should be used with caution. + + Args: + device (str): name of the device to call the function on + func_name (str): name of the function to call on the device + *args: positional arguments to pass to the function + **kwargs: keyword arguments to pass to the function + + Example: + >>> # Call the "acquire_image" method on the "detector1" device with an exposure time of 1 second. + >>> # Similar to calling detector1.acquire_image(exposure_time=1.0) on the device server. + >>> result = self.actions.rpc_call("detector1", "acquire_image", exposure_time=1.0) + + >>> # Call the "start_interferometer" method on the "controller" sub-device of the "rt" device with some parameters. + >>> result = self.actions.rpc_call("rt.controller", "start_interferometer", param1=42, param2="foo") + + Returns: + Any | ScanStubStatus: The result of the RPC call or a ScanStubStatus object if the result is a status object. + + """ + status = self._create_status(name=f"rpc_{device}_{func_name}") + rpc_id = str(uuid.uuid4()) + parameter = { + "device": device, + "func": func_name, + "rpc_id": rpc_id, + "args": args, + "kwargs": kwargs, + } + msg = messages.DeviceInstructionMessage( + device=device, + action="rpc", + parameter=parameter, + metadata={"device_instr_id": status._device_instr_id}, + ) + self._send(msg) + status.wait(resolve_on_known_type=True) + if status._result_is_status: + return status + return status.result + + def send_client_info(self, message: str): + """ + Emit a new client info message. + Client info messages are meant to inform the user about the progress. They are shown in the GUI + statusbar. + + Args: + message (str): message to show in the statusbar + """ + self._connector.send_client_info( + message, rid=self._scan.scan_info.metadata.get("RID"), source="scan_server" + ) + + ######################################################################### + ############## Helper methods ########################################### + ######################################################################### + + def _create_status(self, is_container=False, name: str | None = None) -> ScanStubStatus: + """ + Helper method to create a status object and register it in the status registry. + + Args: + is_container (bool, optional): if True, the status object is merely a container for other status objects and should not be waited on directly. Defaults to False. + name (str, optional): name for the status object. Defaults to None. + """ + status = ScanStubStatus( + self._instruction_handler, + shutdown_event=self._shutdown_event, + registry=self._status_registry, + is_container=is_container, + name=name, + ) + self._status_registry[status._device_instr_id] = status + return status + + def _get_remaining_status_objects(self, exclude_done=True, exclude_checked=True): + """ + Get the remaining status objects. + + Args: + exclude_checked (bool, optional): Exclude checked status objects. Defaults to False. + exclude_done (bool, optional): Exclude done status objects. Defaults to True. + + Returns: + list: List of remaining status objects. + """ + objs = list(self._status_registry.values()) + if exclude_checked: + objs = [st for st in objs if not st._done_checked] + if exclude_done: + objs = [st for st in objs if not st.done] + return objs + + def _send(self, msg: messages.DeviceInstructionMessage): + """Send a message to the device server.""" + if self._interruption_callback is not None: + self._interruption_callback() + metadata = self._get_message_metadata() + msg.metadata = {**metadata, **msg.metadata} + instr_devices = msg.device if isinstance(msg.device, list) else [msg.device] + if set(instr_devices) & self._devices_with_required_response: + msg.metadata["response"] = True + self._connector.send(MessageEndpoints.device_instructions(), msg) + + def _get_message_metadata(self) -> dict: + metadata = {} + if self._scan.scan_info.scan_id is not None: + metadata["scan_id"] = self._scan.scan_info.scan_id + self._metadata_suffix + for key in ["RID", "queue_id"]: + value = self._scan.scan_info.metadata.get(key) + if value is not None: + metadata[key] = value + self._metadata_suffix + return metadata + + def _normalize_device_names( + self, devices: str | DeviceBase | list[str | DeviceBase] + ) -> list[str]: + if not isinstance(devices, list): + devices = [devices] + return [dev.name if isinstance(dev, DeviceBase) else dev for dev in devices] + + def _get_monitored_device_names(self) -> list[str]: + monitored_devices = [ + _dev.root.name + for _dev in self._device_manager.devices.monitored_devices( + readout_priority=self.readout_priority + ) + ] + return sorted(monitored_devices) + + @staticmethod + def _normalize_manual_readings( + readings: dict[str, dict] | list[dict], monitored_devices: list[str] + ) -> list[dict]: + if isinstance(readings, dict): + reading_devices = sorted(readings) + if reading_devices != monitored_devices: + missing_devices = sorted(set(monitored_devices) - set(reading_devices)) + unexpected_devices = sorted(set(reading_devices) - set(monitored_devices)) + raise ValueError( + "Manual read devices must match the currently monitored devices. " + f"Missing devices: {missing_devices}. " + f"Unexpected devices: {unexpected_devices}." + ) + return [readings[device] for device in monitored_devices] + + if not isinstance(readings, list): + raise TypeError("Manual readings must be provided as a dict or list of dictionaries.") + + if len(readings) != len(monitored_devices): + raise ValueError( + "Manual read count must match the currently monitored devices. " + f"Expected {len(monitored_devices)}, got {len(readings)}." + ) + + if all(isinstance(reading, dict) and len(reading) == 1 for reading in readings): + keyed_readings = {} + for reading in readings: + device, data = next(iter(reading.items())) + keyed_readings[device] = data + return ScanActions._normalize_manual_readings(keyed_readings, monitored_devices) + + if not all(isinstance(reading, dict) for reading in readings): + raise TypeError("Each manual reading must be a dictionary.") + return readings + + def _validate_manual_reading_signals( + self, readings: list[dict], monitored_devices: list[str] + ) -> None: + missing_signals = {} + for device, reading in zip(monitored_devices, readings, strict=False): + expected_signal_names = self._get_expected_read_signal_names(device) + missing = sorted(set(expected_signal_names) - set(reading)) + if missing: + missing_signals[device] = missing + + if missing_signals: + raise ValueError( + "Manual read data must include all signals from the currently monitored devices. " + f"Missing signals: {missing_signals}." + ) + + def _get_expected_read_signal_names(self, device: str) -> list[str]: + device_info = self._device_manager.devices[device]._info + signals = device_info.get("signals", {}) + signal_names = [ + signal_info.get("obj_name", signal_name) + for signal_name, signal_info in signals.items() + if self._signal_is_read_signal(signal_info) + ] + if not signal_names: + raise ValueError( + f"Cannot validate manual read data for monitored device {device!r}: " + "no read signals are configured in the device metadata." + ) + return signal_names + + @staticmethod + def _signal_is_read_signal(signal_info: dict) -> bool: + kind = signal_info.get("kind_str", "").lower() + if "config" in kind or "omitted" in kind: + return False + return True + + def _send_scan_status( + self, + status: Literal["open", "paused", "closed", "aborted", "halted", "user_completed"], + reason: Literal["user", "alarm"] | None = None, + ) -> None: + """Publish the current scan status for the active direct scan.""" + scan = self._scan + logger.info(f"New scan status: {scan.scan_info.scan_id} / {status} / {scan.scan_info}") + msg = self._build_scan_status_message(status=status, reason=reason) + + expire = None if status in ["open", "paused"] else 1800 + pipe = self._connector.pipeline() + self._connector.set( + MessageEndpoints.public_scan_info(scan.scan_info.scan_id), msg, pipe=pipe, expire=expire + ) + self._connector.set_and_publish(MessageEndpoints.scan_status(), msg, pipe=pipe) + pipe.execute() + + def _build_scan_status_message( + self, + status: Literal["open", "paused", "closed", "aborted", "halted", "user_completed"], + reason: Literal["user", "alarm"] | None = None, + ) -> messages.ScanStatusMessage: + """Build the scan status message for the active direct scan.""" + legacy_scan_parameters = self._get_legacy_scan_parameters(self._scan.scan_info) + resolved_readout_priority = self._get_resolved_readout_priority() + file_components = self._get_file_components(self._scan.scan_info) + info = self._build_scan_status_info( + legacy_scan_parameters=legacy_scan_parameters, + resolved_readout_priority=resolved_readout_priority, + file_components=file_components, + ) + scan_info = self._scan.scan_info + scan_type = scan_info.scan_type + return messages.ScanStatusMessage( + scan_id=scan_info.scan_id, + status=status, + reason=reason, + scan_name=scan_info.scan_name, + scan_number=scan_info.scan_number, + session_id=scan_info.metadata.get("session_id"), + dataset_number=scan_info.dataset_number, + num_points=scan_info.num_points, + scan_type=scan_type if scan_type in {"step", "fly"} else None, + scan_report_devices=scan_info.scan_report_devices, + user_metadata=scan_info.user_metadata, + readout_priority=resolved_readout_priority, + scan_parameters=legacy_scan_parameters, + request_inputs=scan_info.request_inputs, + num_monitored_readouts=scan_info.num_monitored_readouts, + info=info, + ) + + def _build_scan_status_info( + self, + legacy_scan_parameters: dict, + resolved_readout_priority: ReadoutPriorityMap, + file_components: tuple[str, str] | None, + ) -> dict: + """Build the compatibility-augmented info payload for scan status messages.""" + base_info = self._scan.scan_info.model_dump(mode="python") + if base_info.get("positions") is not None: + base_info["positions"] = base_info["positions"].tolist() + compatibility_fields = { + "scan_parameters": legacy_scan_parameters, + "readout_priority": resolved_readout_priority, + "file_components": file_components, + } + return {**base_info, **compatibility_fields} + + def _get_legacy_scan_parameters(self, scan_info: ScanInfo) -> dict: + scan_parameters = { + "exp_time": scan_info.exp_time, + "frames_per_trigger": scan_info.frames_per_trigger, + "settling_time": scan_info.settling_time, + "readout_time": scan_info.readout_time, + "relative": scan_info.relative, + } + scan_parameters.update(scan_info.additional_scan_parameters or {}) + if scan_info.system_config is not None: + scan_parameters["system_config"] = scan_info.system_config + return {key: value for key, value in scan_parameters.items() if value is not None} + + def _get_resolved_readout_priority(self) -> ReadoutPriorityMap: + readout_priority = self._scan.scan_info.readout_priority_modification + return { + "monitored": [ + dev.full_name + for dev in self._device_manager.devices.monitored_devices( + readout_priority=readout_priority + ) + ], + "baseline": [ + dev.full_name + for dev in self._device_manager.devices.baseline_devices( + readout_priority=readout_priority + ) + ], + "async": [ + dev.full_name + for dev in self._device_manager.devices.async_devices( + readout_priority=readout_priority + ) + ], + "continuous": [ + dev.full_name + for dev in self._device_manager.devices.continuous_devices( + readout_priority=readout_priority + ) + ], + "on_request": [ + dev.full_name + for dev in self._device_manager.devices.on_request_devices( + readout_priority=readout_priority + ) + ], + } + + def _get_file_components(self, scan_info: ScanInfo) -> tuple[str, str] | None: + scan_number = scan_info.scan_number + system_config = scan_info.system_config or {} + if scan_number is None or "file_directory" not in system_config: + return None + return compile_file_components( + base_path=self._get_file_base_path(), + scan_nr=scan_number, + file_directory=system_config["file_directory"], + user_suffix=system_config.get("file_suffix"), + ) + + def _get_file_base_path(self) -> str: + current_account_msg = self._connector.get_last(MessageEndpoints.account(), "data") + if current_account_msg: + current_account = current_account_msg.value + if not isinstance(current_account, str): + logger.warning( + f"Account name is not a string: {current_account}. Ignoring specified value." + ) + current_account = None + else: + if "/" in current_account: + raise ValueError( + f"Account name cannot contain a slash (/): {current_account}. " + ) + check_value = current_account.replace("_", "").replace("-", "") + if not check_value.isalnum() or not check_value.isascii(): + raise ValueError( + f"Account name can only contain alphanumeric characters: {current_account}. " + ) + else: + current_account = None + + file_base_path = self._device_manager.parent._service_config.config["file_writer"][ + "base_path" + ] + if "$" not in file_base_path: + if current_account: + return os.path.abspath(os.path.join(file_base_path, current_account)) + return os.path.abspath(file_base_path) + + file_base_path = Template(file_base_path) + try: + return os.path.abspath(file_base_path.substitute(account=current_account or "")) + except KeyError as exc: + raise ValueError( + f"Invalid template variable: {exc} in the file base path. Please check your service config." + ) from exc diff --git a/bec_server/bec_server/scan_server/scans/scan_components.py b/bec_server/bec_server/scan_server/scans/scan_components.py new file mode 100644 index 000000000..a64f1f749 --- /dev/null +++ b/bec_server/bec_server/scan_server/scans/scan_components.py @@ -0,0 +1,243 @@ +from __future__ import annotations + +import time +from typing import TYPE_CHECKING, Callable, Literal + +import numpy as np + +from bec_lib.device import DeviceBase +from bec_server.scan_server.errors import LimitError +from bec_server.scan_server.path_optimization import PathOptimizerMixin + +if TYPE_CHECKING: + from bec_server.scan_server.scans.scans_v4 import ScanBase + + +class ScanComponents: + """ + Class to handle the components for the scan logic. + The components are reusable building blocks for the scan logic, + such as step scans or grid scans. They use the ScanStubs to + execute the scan logic. + """ + + def __init__(self, scan: ScanBase): + self._scan = scan + self._actions = scan.actions + self._redis_connector = scan.redis_connector + self._device_manager = scan.device_manager + self._dev = self._device_manager.devices if self._device_manager else None + self._path_optimizer = PathOptimizerMixin() + + def move_and_wait( + self, + motors: list[str | DeviceBase] | list[str] | list[DeviceBase], + positions: np.ndarray | list[float], + last_positions: np.ndarray | None = None, + ): + """ + Move the given motors to the given positions and wait for the movement to complete. + If last_positions is provided, only the motors with changed positions will be moved. + + Args: + motors (list[str | DeviceBase] | list[str] | list[DeviceBase]): List of motor names or device instances to move. + positions (np.ndarray | list[float]): Array or list of positions to move to, shape (len(motors),). + last_positions (np.ndarray, optional): Array of last positions, shape (len(motors),). + If provided, only motors with changed positions will be moved. Defaults to None. + """ + motors_to_move = [] + positions_to_move = [] + for motor_index, motor in enumerate(motors): + if last_positions is not None: + if np.isclose(positions[motor_index], last_positions[motor_index]): + continue + motors_to_move.append(motor) + positions_to_move.append(positions[motor_index]) + + if motors_to_move: + self._actions.set(motors_to_move, positions_to_move, wait=True) + + def trigger_and_read(self): + """ + Trigger the devices and start the readout. This is typically used for step scans after the motors have been moved to the next position. + + The logic is as follows: + 1. Let the system settle before triggering + 2. Trigger the devices + 3. Let the system settle after the trigger + 4. Start the readout + + """ + # Let the system settle before triggering + time.sleep(self._scan.scan_info.settling_time) + trigger_time = self._scan.scan_info.exp_time * self._scan.scan_info.frames_per_trigger + + # Trigger the devices + self._actions.trigger_all_devices(min_wait=trigger_time) + + # Let the system settle after the trigger + time.sleep(self._scan.scan_info.settling_time_after_trigger) + + # Start the readout + self._actions.read_monitored_devices() + + def step_scan( + self, + motors: list[str | DeviceBase] | list[str] | list[DeviceBase], + positions: np.ndarray, + at_each_point: ( + Callable[[list[str | DeviceBase], np.ndarray, np.ndarray | None], None] | None + ) = None, + last_positions: np.ndarray | None = None, + ): + """ + Execute a step scan with the given positions. It is the core scan logic + for most step scans. + + Args: + motors (list[str | DeviceBase] | list[str] | list[DeviceBase]): List of motor names or device instances to move. + positions (np.ndarray): Array of positions to move to, shape (num_points, len(motors)). + at_each_point (Callable[[list[str | DeviceBase], np.ndarray, np.ndarray | None], None], optional): Function to call at each point. Defaults to None. + last_positions (np.ndarray, optional): Array of last positions, shape (num_points, len(motors)). If provided, only motors with changed positions will be moved. Defaults to None. + """ + at_each_point = at_each_point or self.step_scan_at_each_point + for pos in positions: + for _ in range(self._scan.scan_info.burst_at_each_point): + at_each_point(motors, pos, last_positions=last_positions) + last_positions = pos.copy() + + def step_scan_at_each_point( + self, + motors: list[str | DeviceBase] | list[str] | list[DeviceBase], + pos: np.ndarray, + last_positions: np.ndarray | None = None, + ): + """ + Execute a step scan at each point. This is the core logic that is executed at each point of the step scan. + It is separated from the step_scan method to allow scan hooks to override the logic. + + The logic is as follows: + 1. Move the motors to the next position without waiting for each motor to complete + 2. Wait for each motor to complete + 3. Let the system settle before triggering + 4. Trigger the devices + 5. Let the system settle after the trigger + 6. Start the readout + + Args: + motors (list[str | DeviceBase] | list[str] | list[DeviceBase]): List of motor names or device instances to move. + pos (np.ndarray): Array of positions to move to, shape (len(motors),). + last_positions (np.ndarray, optional): Array of last positions, shape (len(motors),). + If provided, only motors with changed positions will be moved. Defaults to None. + """ + self.move_and_wait(motors, pos, last_positions=last_positions) + self.trigger_and_read() + + def get_start_positions( + self, motors: list[str | DeviceBase] | list[str] | list[DeviceBase] + ) -> list[float]: + """ + Get the current position of the given motors. This can be used to make the positions relative to the current position of the motors. + + Args: + motors (list[str | DeviceBase] | list[str] | list[DeviceBase]): List of motor names or device instances. + + Returns: + list[float]: List of current positions of the motors. + """ + start_positions = [] + for motor in motors: + if isinstance(motor, str): + obj = self._dev[motor] + else: + obj = motor + val = obj.read() + start_positions.append(val[obj.full_name].get("value")) + return start_positions + + def optimize_trajectory( + self, + positions: np.ndarray, + optimization_type: Literal["corridor", "shell", "nearest"] = "corridor", + primary_axis: int = 1, + preferred_directions: list[int] | None = None, + corridor_size: int | None = None, + num_iterations: int = 5, + ) -> np.ndarray: + """ + Optimize the trajectory of the scan by reordering the positions. This can help to minimize the movement time of the motors. + The optimization can be done in different ways, depending on the optimization_type parameter: + - "corridor": optimize the trajectory in a corridor-like way, where the scan moves back and forth along the primary axis. This is typically a good choice for grid scans. If preferred_directions are provided, the optimizer will try to optimize the trajectory in a way that minimizes the movement in the non-preferred direction. + - "shell": optimize the trajectory in a shell-like way, where the scan moves in a spiral from the outside to the inside. This is typically a good choice for round scans. + - "nearest": optimize the trajectory by always moving to the nearest next point. This is typically a good choice for random scans. + + Args: + positions (np.ndarray): Array of positions to optimize, shape (num_points, num_motors). + optimization_type (str, optional): Type of optimization to perform. Defaults to "corridor". + primary_axis (int, optional): Primary axis for corridor optimization. Defaults to 1. + preferred_directions (list[int] | None, optional): List of preferred directions for the non-primary axes. Each entry should be -1, 0, or 1, indicating the preferred direction of movement along that axis. The length of the list should be equal to the number of non-primary axes. Defaults to None, which means no preferred directions. + corridor_size (int | None, optional): Size of the corridor for corridor optimization. Defaults to None, which means the default corridor size will be used. + Returns: + np.ndarray: Optimized array of positions, shape (num_points, num_motors). + """ + + if optimization_type == "corridor": + if preferred_directions is None or len(preferred_directions) == 0: + positions = self._path_optimizer.optimize_corridor( + positions, + num_iterations=num_iterations, + corridor_size=corridor_size, + sort_axis=primary_axis, + ) + else: + preferred_direction = ( + preferred_directions[primary_axis] + if len(preferred_directions) > primary_axis + else None + ) + positions = self._path_optimizer.optimize_corridor( + positions, + num_iterations=num_iterations, + sort_axis=primary_axis, + preferred_direction=preferred_direction, + corridor_size=corridor_size, + ) + + elif optimization_type == "shell": + positions = self._path_optimizer.optimize_shell( + positions, num_iterations=num_iterations + ) + elif optimization_type == "nearest": + positions = self._path_optimizer.optimize_nearest_neighbor(positions) + else: + raise ValueError(f"Invalid optimization type: {optimization_type}") + return positions + + def check_limits( + self, motors: list[str | DeviceBase] | list[str] | list[DeviceBase], positions: np.ndarray + ): + """ + Check if the given positions for the given motors are within the limits of the motors. + If not, raise a LimitError. + + Args: + motors (list[str | DeviceBase] | list[str] | list[DeviceBase]): List of motor names or device instances. + positions (np.ndarray): Array of positions to check, shape (num_points, len(motors)). + + Raises: + LimitError: If any of the positions are out of limits for the corresponding motor. + """ + for motor_index, motor in enumerate(motors): + if isinstance(motor, str): + low_limit, high_limit = self._dev[motor].limits + else: + low_limit, high_limit = motor.limits + if low_limit >= high_limit: + # if both limits are the same or low > high, no limits are set + continue + for pos in positions[:, motor_index]: + if not low_limit <= pos <= high_limit: + raise LimitError( + f"Target position {pos} for motor {motor} is out of limits ({low_limit}, {high_limit})", + device=motor if isinstance(motor, str) else motor.full_name, + ) diff --git a/bec_server/bec_server/scan_server/scans/scan_modifier.py b/bec_server/bec_server/scan_server/scans/scan_modifier.py new file mode 100644 index 000000000..2d9a46261 --- /dev/null +++ b/bec_server/bec_server/scan_server/scans/scan_modifier.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +from functools import wraps +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from bec_lib.device import DeviceBase + from bec_server.scan_server.scans.scans_v4 import ScanBase + + +def scan_hook(func): + """ + Decorator for scan hooks. It registers the decorated method as a scan hook and thus allows + scan modifiers to override or augment the scan logic. + """ + + @wraps(func) + def wrapper(self, *args, **kwargs): + return func(self, *args, **kwargs) + + # pylint: disable=protected-access + wrapper._scan_hook_info = {"method_name": func.__name__} # type: ignore + + return wrapper + + +def scan_hook_impl(hook_type: str): + """ + Decorator for scan hook implementations. It registers the decorated method as an implementation of the specified scan hook type. + The hook_type should be one of the following: "before", "after" or "replace". + This allows the scan modifier to specify whether the decorated method should be executed before, after or instead of the original scan hook method. + """ + + def decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + return func(self, *args, **kwargs) + + # pylint: disable=protected-access + wrapper._scan_hook_impl_info = {"hook_type": hook_type} # type: ignore + + return wrapper + + return decorator + + +# pragma: no cover +# def get_scan_hooks(cls) -> dict[str, str]: +# """ +# Get the scan hooks defined in the given class. It returns a dictionary mapping the hook method names to their corresponding scan hook types. +# """ +# hooks = {} +# for attr_name in dir(cls): +# attr = getattr(cls, attr_name) +# if callable(attr) and hasattr(attr, "_scan_hook_info"): +# hook_info = attr._scan_hook_info # type: ignore +# hooks[hook_info["method_name"]] = attr_name +# return hooks + + +# def prepare_eiger(scan: ScanBase): +# if "eiger" not in scan.dev: +# return +# print("Preparing Eiger for the scan...") +# eiger = scan.dev["eiger"] +# num_frames = scan.scan_info.frames_per_trigger * scan.scan_info.num_points +# eiger.num_frames.set(num_frames).wait() + + +# def prepare_falcon(scan: ScanBase): +# if "falcon" not in scan.dev: +# return +# print("Preparing Falcon for the scan...") +# falcon = scan.dev["falcon"] +# falcon.num_frames.set(100).wait() + + +# class ScanModifier: + +# @scan_hook_impl("after") +# def stage(self, scan: ScanBase): +# """ +# Stage the devices for the upcoming scan. The stage logic is typically +# implemented on the device itself (i.e. by the device's stage method). +# However, if there are any additional steps that need to be executed before +# staging the devices, they can be implemented here. +# """ +# prepare_eiger(scan) +# prepare_falcon(scan) diff --git a/bec_server/bec_server/scan_server/scans/scans_v4.py b/bec_server/bec_server/scan_server/scans/scans_v4.py new file mode 100644 index 000000000..4d872347e --- /dev/null +++ b/bec_server/bec_server/scan_server/scans/scans_v4.py @@ -0,0 +1,260 @@ +""" +Module for handling scans for v4 of BEC. In contrast to previous implementations, the scan logic does not rely on generators +executed on the worker but instead uses the RedisConnector to send commands directly to the devices. +""" + +from __future__ import annotations + +import enum +import threading +from collections.abc import Sequence +from typing import Annotated + +import numpy as np +import pint +from pydantic import BaseModel, ConfigDict, Field, field_validator +from toolz import partition + +from bec_lib.device import DeviceBase +from bec_lib.devicemanager import DeviceManagerBase as DeviceManager +from bec_lib.redis_connector import RedisConnector +from bec_server.scan_server.instruction_handler import InstructionHandler +from bec_server.scan_server.scans.scan_actions import ScanActions +from bec_server.scan_server.scans.scan_components import ScanComponents + +Units = pint.UnitRegistry() + + +class ScanType(str, enum.Enum): + HARDWARE_TRIGGERED = "hardware_triggered" + SOFTWARE_TRIGGERED = "software_triggered" + + +def bundle_args(args: tuple, bundle_size: int) -> dict: + """ + Bundle the given arguments into bundles of the given size. + + Args: + args (tuple): arguments to bundle + bundle_size (int): size of the bundles + + Returns: + dict: bundled arguments + + """ + params = {} + for cmds in partition(bundle_size, args): + params[cmds[0]] = list(cmds[1:]) + return params + + +class ScanInfo(BaseModel): + + # General scan information + scan_name: Annotated[str, Field(description="Name of the scan type, e.g. 'grid_scan'")] + scan_id: Annotated[str, Field(description="Unique identifier for the scan")] + scan_type: Annotated[ + ScanType | None, + Field( + None, description="Type of the scan, e.g. 'software_triggered' or 'hardware_triggered'" + ), + ] + scan_number: Annotated[int | None, Field(description="Scan number, if applicable")] = None + dataset_number: Annotated[int | None, Field(description="Dataset number, if applicable")] = None + + # Scan parameters + num_points: Annotated[int, Field(description="Number of points in the scan.")] = 0 + positions: Annotated[ + np.ndarray | None, + Field(description="Positions for the scan, shape (num_points, num_motors)"), + ] = None + exp_time: Annotated[float, Field(description="Exposure time for the scan", ge=0.0)] = 0.0 + frames_per_trigger: Annotated[int, Field(description="Number of frames per trigger", ge=1)] = 1 + settling_time: Annotated[ + float, Field(description="Settling time before the software trigger", ge=0.0) + ] = 0.0 + settling_time_after_trigger: Annotated[ + float, Field(description="Settling time after the software trigger", ge=0.0) + ] = 0.0 + readout_time: Annotated[float, Field(description="Readout time after the trigger", ge=0.0)] = ( + 0.0 + ) + burst_at_each_point: Annotated[ + int, Field(description="Number of bursts at each point", ge=1) + ] = 1 + relative: Annotated[ + bool, Field(description="Whether the positions are relative or absolute") + ] = False + run_on_exception_hook: Annotated[ + bool, Field(description="Whether to run the on_exception hook if the scan is interrupted") + ] = True + + request_inputs: Annotated[dict, Field(description="Request inputs")] = {} + readout_priority_modification: Annotated[ + dict, Field(description="Readout priority modification") + ] = {"baseline": [], "monitored": [], "on_request": [], "async": []} + scan_report_instructions: Annotated[ + list[dict], Field(description="List of scan report instructions") + ] = [] + scan_report_devices: Annotated[ + list[str], Field(description="List of devices to report during the scan") + ] = [] + monitor_sync: Annotated[ + str | None, + Field(description="Monitor synchronization mode for fly scans"), # Will be removed! + ] = None + additional_scan_parameters: Annotated[dict, Field(description="Additional scan parameters")] = ( + {} + ) + user_metadata: Annotated[dict, Field(description="User-provided metadata for the scan")] = {} + system_config: Annotated[dict, Field(description="System configuration for the scan")] = {} + scan_queue: Annotated[str, Field(description="Name of the queue the scan belongs to")] = ( + "primary" + ) + metadata: Annotated[dict, Field(description="Additional metadata for the scan")] = {} + + # progress tracking + num_monitored_readouts: Annotated[ + int, + Field( + description="Number of performed readouts of monitored devices. For a step scan, this is equal to num_points * burst_at_each_point." + ), + ] = 0 + + def __str__(self) -> str: + data = self.model_dump(mode="python") + positions = self.positions + if isinstance(positions, np.ndarray): + data["positions"] = np.array2string(positions, threshold=8, edgeitems=2, precision=4) + return f"{self.__class__.__name__}({data})" + + __repr__ = __str__ + + @field_validator("scan_report_devices", mode="before") + @classmethod + def _serialize_scan_report_devices(cls, value: object) -> object: + """ + Convert scan report devices to device names. + + Args: + value (object): List of device names or ``DeviceBase`` instances. + + Returns: + object: List with ``DeviceBase`` instances replaced by their names. + """ + if value is None: + return value + return [device.name if isinstance(device, DeviceBase) else device for device in value] + + model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True) + + +class ScanBase: + scan_type = ScanType.SOFTWARE_TRIGGERED + scan_name = "_v4_base_scan" + required_kwargs = [] + arg_input = {} + arg_bundle_size = {"bundle": len(arg_input), "min": None, "max": None} + is_scan = True + + def __init__( + self, + scan_id: str, + redis_connector: RedisConnector, + device_manager: DeviceManager, + instruction_handler: InstructionHandler, + request_inputs: dict, + system_config: dict, + user_metadata: dict | None = None, + metadata: dict | None = None, + scan_queue: str | None = None, + run_on_exception_hook: bool | None = None, + additional_scan_parameters: dict | None = None, + ): + """Base class for all scans.""" + self.redis_connector = redis_connector + self.device_manager = device_manager + self._instruction_handler = instruction_handler + self.dev = self.device_manager.devices + self._shutdown_event = threading.Event() + self.actions = ScanActions(scan=self) + self.components = ScanComponents(scan=self) + + optional_kwargs = {} + for kwarg in [ + "metadata", + "user_metadata", + "scan_queue", + "run_on_exception_hook", + "additional_scan_parameters", + ]: + data = locals()[kwarg] + if data is not None: + optional_kwargs[kwarg] = data + self.scan_info = ScanInfo( + scan_name=self.scan_name, scan_id=scan_id, scan_type=self.scan_type, **optional_kwargs + ) + self.scan_info.request_inputs = request_inputs + self.scan_info.system_config = system_config + self._baseline_readout_status = None + self._premove_motor_status = None + self.positions = np.array([]) + self.start_positions = [] + + def update_scan_info( + self, + num_points: int | None = None, + num_monitored_readouts: int | None = None, + positions: np.ndarray | None = None, + exp_time: float | None = None, + frames_per_trigger: int | None = None, + settling_time: float | None = None, + settling_time_after_trigger: float | None = None, + burst_at_each_point: int | None = None, + relative: bool | None = None, + run_on_exception_hook: bool | None = None, + scan_report_devices: Sequence[str | DeviceBase] | None = None, + **kwargs, + ): + """ + Update the scan info with the given keyword arguments. + If the scan info model has an attribute with the same name as the keyword argument, + it will be updated. Otherwise, the keyword argument will be added to the additional_scan_parameters dictionary. + This allows for flexible scan info management, where standard parameters can be defined as attributes of the + ScanInfo model, and any additional parameters can be stored in the additional_scan_parameters dictionary. + + Args: + num_points (int, optional): Number of points in the scan. Defaults to None. + num_monitored_readouts (int, optional): Number of monitored readouts that will be collected during the scan. Defaults to None. + positions (np.ndarray, optional): Positions for the scan, shape (num_points, num_motors). Defaults to None. + exp_time (float, optional): Exposure time for the scan. Defaults to None. + frames_per_trigger (int, optional): Number of frames per trigger. Defaults to None. + settling_time (float, optional): Settling time before the software trigger. Defaults to None. + settling_time_after_trigger (float, optional): Settling time after the software trigger. Defaults to None. + burst_at_each_point (int, optional): Number of bursts at each point. Defaults to None. + relative (bool, optional): Whether the positions are relative or absolute. Defaults to None. + run_on_exception_hook (bool, optional): Whether to run the on_exception hook if the scan is interrupted. Defaults to None. + scan_report_devices (Sequence[str | DeviceBase], optional): Devices to report + during the scan. Device objects are stored by name. Defaults to None. + **kwargs: Keyword arguments to update the scan info with. + """ + for attr_name, value in [ + ("num_points", num_points), + ("num_monitored_readouts", num_monitored_readouts), + ("positions", positions), + ("exp_time", exp_time), + ("frames_per_trigger", frames_per_trigger), + ("settling_time", settling_time), + ("settling_time_after_trigger", settling_time_after_trigger), + ("burst_at_each_point", burst_at_each_point), + ("relative", relative), + ("run_on_exception_hook", run_on_exception_hook), + ("scan_report_devices", scan_report_devices), + ]: + if value is not None: + setattr(self.scan_info, attr_name, value) + for key, value in kwargs.items(): + if hasattr(self.scan_info, key): + setattr(self.scan_info, key, value) + else: + self.scan_info.additional_scan_parameters[key] = value diff --git a/bec_server/bec_server/scan_server/scans/state_transition_scan.py b/bec_server/bec_server/scan_server/scans/state_transition_scan.py new file mode 100644 index 000000000..a9f8ff973 --- /dev/null +++ b/bec_server/bec_server/scan_server/scans/state_transition_scan.py @@ -0,0 +1,265 @@ +""" +Updated move scan implementation for coordinated motor repositioning commands. + +Scan procedure: + - prepare_scan + - open_scan + - stage + - pre_scan + - scan_core + - at_each_point (optionally called by scan_core) + - post_scan + - unstage + - close_scan + - on_exception (called if any exception is raised during the scan) +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Tuple + +from bec_lib.alarm_handler import AlarmBase, Alarms +from bec_lib.bl_states import AggregatedState, SubDeviceStateConfig +from bec_lib.device import DeviceBase, Positioner, Signal +from bec_lib.endpoints import MessageEndpoints +from bec_lib.logger import bec_logger +from bec_lib.messages import AlarmMessage, ErrorInfo +from bec_server.scan_server.scans.scan_modifier import scan_hook +from bec_server.scan_server.scans.scans_v4 import ScanBase + +if TYPE_CHECKING: + from bec_lib.bl_states import AggregatedStateConfig, ResolvedStateSignal + from bec_lib.messages import AvailableBeamlineStatesMessage + +logger = bec_logger.logger + + +class StateTransitionScanError(AlarmBase): + """Exception raised when an RPC call fails.""" + + def __init__(self, exc_type: str, message: str, compact_message: str) -> None: + alarm = AlarmMessage( + severity=Alarms.MAJOR, + info=ErrorInfo( + exception_type=exc_type, + error_message=message, + compact_error_message=compact_message, + ), + ) + super().__init__(alarm, Alarms.MAJOR, handled=False) + + +class StateTransitionScan(ScanBase): + + # Scan Type: Hardware triggered or software triggered? + # If the main trigger and readout logic is done within the at_each_point method in scan_core, choose SOFTWARE_TRIGGERED. + # If the main trigger and readout logic is implemented on a device that is simply kicked off in this scan, choose HARDWARE_TRIGGERED. + # This primarily serves as information for devices: The device may need to react differently if a software trigger is expected + # for every point. + scan_type = None + + # Scan name: This is the name of the scan, e.g. "line_scan". This is used for display purposes and to identify the scan type in user interfaces. + # Choose a descriptive name that does not conflict with existing scan names. + scan_name = "_v4_state_transition" + + # We set is_scan to False to separate this class from the other scans in the user interface + is_scan = False + + def __init__(self, *args, state_name: str, target_label: str, **kwargs): + """ + State transition scan that moves a motor in between two states. + The main purpose of this scan is to be used in conjunction with state + management in BEC, and transitioning the beamline in-between different aggregated states. + """ + super().__init__(**kwargs) + self.state_name = state_name + self.target_label = target_label + # Check if the state and the target label exists, if yes, fetch the configuration for the target state + self.config_for_label = self._fetch_config_for_label(state_name, target_label) + + # We need to sort the devices and signals in the config, and identify which of them are motor setpoint/readback pairs + # and which of them are just readouts and thereby can not be set within the transition. + self._signals_to_set: list[Tuple[Signal, Any]] = [] + self._limits_to_set: dict[str, Tuple[Positioner, float, float]] = {} + self._devices_to_set: list[Tuple[Positioner, float]] = [] + + # pylint: disable=protected-access + @scan_hook + def prepare_scan(self): + """ + Prepare the scan. This can include any steps that need to be executed + before the scan is opened, such as preparing the positions (if not done already) + or setting up the devices. + """ + requirements: list[ResolvedStateSignal] = AggregatedState.get_state_requirements( + self.target_label, self.config_for_label, self.device_manager, "StateTransitionScan" + ) + for req in requirements: + dev_obj: DeviceBase = self.device_manager.devices.get(req.device_name) + # Device not found + if dev_obj is None: + raise StateTransitionScanError( + exc_type="DeviceNotFound", + message=f"Device {req.device_name} not found in device manager.", + compact_message=f"Device {req.device_name} not found.", + ) + # First we handle Signals logic + if isinstance(dev_obj, Signal): + self._signals_to_set.append((dev_obj, req.expected_value)) + continue + # Positioner and Device logic. Devices must implement .set for this to work, otherwise we can not set them and we raise an error + if isinstance(dev_obj, DeviceBase): + # Handle motor-specific logic here + # First we handle logic for motions of the motor. Device_name and signal_name will be equivalent here + if req.signal_name == req.device_name: + self._devices_to_set.append((dev_obj, req.expected_value)) + continue + if req.signal_name in ["low", "high"]: + if req.device_name not in self._limits_to_set: + self._limits_to_set[req.device_name] = ( + dev_obj, + dev_obj.low_limit, + dev_obj.high_limit, + ) + if req.signal_name == "low_limit": + self._limits_to_set[req.device_name] = ( + dev_obj, + req.expected_value, + self._limits_to_set[req.device_name][2], + ) + else: + self._limits_to_set[req.device_name] = ( + dev_obj, + self._limits_to_set[req.device_name][1], + req.expected_value, + ) + continue + signal_obj = self._get_signal_object(dev_obj, req.signal_name) + if signal_obj is None: + raise StateTransitionScanError( + exc_type="SignalNotFound", + message=f"Signal {req.signal_name} for device {req.device_name} not found in device manager.", + compact_message=f"Signal {req.signal_name} for device {req.device_name} not found.", + ) + self._signals_to_set.append((signal_obj, req.expected_value)) + continue + + self.update_scan_info(scan_report_devices=[dev for dev, _ in self._devices_to_set]) + + def _get_signal_object(self, device_obj: DeviceBase, signal_name: str) -> Signal: + for component_name, info in device_obj._info["signals"].items(): + if info["obj_name"] == signal_name: + return getattr(device_obj, component_name) + + @scan_hook + def open_scan(self): + """ + Open the scan. + This step must call self.actions.open_scan() to ensure that a new scan is + opened. Make sure to prepare the scan metadata before, either in + prepare_scan() or in open_scan() itself and call self.update_scan_info(...) + to update the scan metadata if needed. + """ + + @scan_hook + def stage(self): + """ + Stage the devices for the upcoming scan. The stage logic is typically + implemented on the device itself (i.e. by the device's stage method). + However, if there are any additional steps that need to be executed before + staging the devices, they can be implemented here. + """ + + @scan_hook + def pre_scan(self): + """ + Pre-scan steps to be executed before the main scan logic. + This is typically the last chance to prepare the devices before the core scan + logic is executed. For example, this is a good place to initialize time-criticial + devices, e.g. devices that have a short timeout. + The pre-scan logic is typically implemented on the device itself. + """ + + @scan_hook + def scan_core(self): + """ + Core scan logic to be executed during the scan. + This is where the main scan logic should be implemented. + """ + motors = [element[0] for element in self._devices_to_set] + target_positions = [element[1] for element in self._devices_to_set] + current_positions = self.components.get_start_positions(motors) + + self.actions.add_scan_report_instruction_readback( + devices=motors, + start=current_positions, + stop=target_positions, + request_id=self.scan_info.metadata["RID"], + ) + + self.components.move_and_wait(motors, target_positions) + # After the move is completed, we set the limits and signals. + for dev_name, (dev_obj, low_limit, high_limit) in self._limits_to_set.items(): + dev_obj.limits = [low_limit, high_limit] + for signal_obj, target_value in self._signals_to_set: + signal_obj.set(target_value).wait() + + @scan_hook + def at_each_point(self): + """ + Logic to be executed at each point during the scan. This is called by the step_scan method at each point. + + Args: + motors (list[str | DeviceBase]): List of motor names or device instances being moved. + positions (np.ndarray): Current positions of the motors, shape (len(motors),). + last_positions (np.ndarray | None): Previous positions of the motors, shape (len(motors),) or None if this is the first point. + """ + + @scan_hook + def post_scan(self): + """ + Post-scan steps to be executed after the main scan logic. + """ + + @scan_hook + def unstage(self): + """Unstage the scan by executing post-scan steps.""" + + @scan_hook + def close_scan(self): + """Close the scan.""" + + @scan_hook + def on_exception(self, exception: Exception): + """ + Handle exceptions that occur during the scan. + This is a good place to implement any cleanup logic that needs to be executed in case of an exception, + such as returning the devices to a safe state or moving the motors back to their starting position. + """ + + ################# + ## Custom Methods + ################# + + def _fetch_config_for_label(self, state_name: str, target_label: str) -> SubDeviceStateConfig: + available_states_msg: AvailableBeamlineStatesMessage = self.redis_connector.get_last( + MessageEndpoints.available_beamline_states() + ) + configs = [ + state for state in available_states_msg["data"].states if state.name == state_name + ] + if len(configs) == 0: + raise ValueError(f"State {state_name} not found in available states.") + elif len(configs) > 1: # Should not be possible, but just in case + raise ValueError(f"Multiple states with name {state_name} found in available states.") + config: AggregatedStateConfig = configs[0] + if config.state_type != "AggregatedState": + raise ValueError( + f"State {state_name} is not an aggregated state. Transitions are only supported for aggregated states." + ) + available_labels = list(config.parameters["states"].keys()) + if target_label not in available_labels: + raise ValueError( + f"Target label {target_label} not found in state {state_name}. Available labels: {available_labels}" + ) + return SubDeviceStateConfig.model_validate(config.parameters["states"][target_label]) diff --git a/bec_server/bec_server/scan_server/scans/time_scan.py b/bec_server/bec_server/scan_server/scans/time_scan.py new file mode 100644 index 000000000..52b895967 --- /dev/null +++ b/bec_server/bec_server/scan_server/scans/time_scan.py @@ -0,0 +1,178 @@ +""" +Time scan implementation for repeated acquisition over a fixed duration or count. + +Scan procedure: + - prepare_scan + - open_scan + - stage + - pre_scan + - scan_core + - post_scan + - unstage + - close_scan + - on_exception (called if any exception is raised during the scan) +""" + +from __future__ import annotations + +import time +from typing import Annotated + +import numpy as np + +from bec_lib.scan_args import ScanArgument, Units +from bec_server.scan_server.scans.scan_modifier import scan_hook +from bec_server.scan_server.scans.scans_v4 import ScanBase, ScanType + + +class TimeScan(ScanBase): + # Scan Type: Hardware triggered or software triggered? + # If the main trigger and readout logic is done within the at_each_point method in scan_core, choose SOFTWARE_TRIGGERED. + # If the main trigger and readout logic is implemented on a device that is simply kicked off in this scan, choose HARDWARE_TRIGGERED. + # This primarily serves as information for devices: The device may need to react differently if a software trigger is expected + # for every point. + scan_type = ScanType.SOFTWARE_TRIGGERED + + # Scan name: This is the name of the scan, e.g. "line_scan". This is used for display purposes and to identify the scan type in user interfaces. + # Choose a descriptive name that does not conflict with existing scan names. + scan_name = "_v4_time_scan" + + gui_config = {"Scan Parameters": ["points", "interval", "exp_time", "settling_time"]} + + def __init__( + self, + points: int, + interval: Annotated[float, ScanArgument(display_name="Interval", units=Units.s, ge=0)], + exp_time: Annotated[ + float, ScanArgument(display_name="Exposure Time", units=Units.s, ge=0) + ] = 0, + settling_time: Annotated[ + float, ScanArgument(display_name="Settling Time", units=Units.s, ge=0) + ] = 0, + **kwargs, + ): + """ + Trigger and readout devices at a fixed interval. + Note that the interval time cannot be less than the exposure time. + The effective sleep time between points is: + sleep_time = max(interval - exp_time, 0) + + Args: + points (int): number of points + interval (float): time interval between points + exp_time (float): exposure time in seconds. Default is 0. + settling_time (float): settling time in seconds. Default is 0. + + Returns: + ScanReport + + Examples: + >>> scans.time_scan(10, 1.5, exp_time=0.1) + + """ + super().__init__(**kwargs) + self.motors = [] + self.points = points + self.interval = interval + self.exp_time = exp_time + self.settling_time = settling_time + self.sleep_time = max(interval - exp_time, 0) + + # Update the default scan info with provided parameters. + self.update_scan_info(exp_time=exp_time, settling_time=settling_time) + + @scan_hook + def prepare_scan(self): + """ + Prepare the scan. This can include any steps that need to be executed + before the scan is opened, such as preparing the positions (if not done already) + or setting up the devices. + """ + + self.update_scan_info( + positions=np.array([]), num_points=self.points, num_monitored_readouts=self.points + ) + + self.actions.add_scan_report_instruction_scan_progress(points=self.points, show_table=False) + + self._baseline_readout_status = self.actions.read_baseline_devices(wait=False) + + @scan_hook + def open_scan(self): + """ + Open the scan. + This step must call self.actions.open_scan() to ensure that a new scan is + opened. Make sure to prepare the scan metadata before, either in + prepare_scan() or in open_scan() itself and call self.update_scan_info(...) + to update the scan metadata if needed. + """ + self.actions.open_scan() + + @scan_hook + def stage(self): + """ + Stage the devices for the upcoming scan. The stage logic is typically + implemented on the device itself (i.e. by the device's stage method). + However, if there are any additional steps that need to be executed before + staging the devices, they can be implemented here. + """ + self.actions.stage_all_devices() + + @scan_hook + def pre_scan(self): + """ + Pre-scan steps to be executed before the main scan logic. + This is typically the last chance to prepare the devices before the core scan + logic is executed. For example, this is a good place to initialize time-criticial + devices, e.g. devices that have a short timeout. + The pre-scan logic is typically implemented on the device itself. + """ + self.actions.pre_scan_all_devices() + + @scan_hook + def scan_core(self): + """ + Core scan logic to be executed during the scan. + This is where the main scan logic should be implemented. + """ + for point_index in range(self.points): + self.at_each_point() + if point_index < self.points - 1 and self.sleep_time > 0: + time.sleep(self.sleep_time) + + @scan_hook + def at_each_point(self): + """ + Logic to be executed at each acquisition point during the scan. + This hook allows concrete time-scan variants to extend or override the + per-point behavior without reimplementing the full scan_core method. + """ + self.components.trigger_and_read() + + @scan_hook + def post_scan(self): + """ + Post-scan steps to be executed after the main scan logic. + """ + self.actions.complete_all_devices() + + @scan_hook + def unstage(self): + """Unstage the scan by executing post-scan steps.""" + self.actions.unstage_all_devices() + + @scan_hook + def close_scan(self): + """Close the scan.""" + if self._baseline_readout_status is not None: + self._baseline_readout_status.wait() + self.actions.close_scan() + self.actions.check_for_unchecked_statuses() + + @scan_hook + def on_exception(self, exception: Exception): + """ + Handle exceptions that occur during the scan. + This is a good place to implement any cleanup logic that needs to be executed in case of an exception, + such as returning the devices to a safe state or moving the motors back to their starting position. + """ diff --git a/bec_server/bec_server/scan_server/scans/updated_move.py b/bec_server/bec_server/scan_server/scans/updated_move.py new file mode 100644 index 000000000..f5a39988b --- /dev/null +++ b/bec_server/bec_server/scan_server/scans/updated_move.py @@ -0,0 +1,168 @@ +""" +Updated move scan implementation for coordinated motor repositioning commands. + +Scan procedure: + - prepare_scan + - open_scan + - stage + - pre_scan + - scan_core + - at_each_point (optionally called by scan_core) + - post_scan + - unstage + - close_scan + - on_exception (called if any exception is raised during the scan) +""" + +from __future__ import annotations + +from bec_lib.device import DeviceBase +from bec_lib.logger import bec_logger +from bec_server.scan_server.scans.scan_modifier import scan_hook +from bec_server.scan_server.scans.scans_v4 import ScanBase, bundle_args + +logger = bec_logger.logger + + +class UpdatedMoveScan(ScanBase): + + # Scan Type: Hardware triggered or software triggered? + # If the main trigger and readout logic is done within the at_each_point method in scan_core, choose SOFTWARE_TRIGGERED. + # If the main trigger and readout logic is implemented on a device that is simply kicked off in this scan, choose HARDWARE_TRIGGERED. + # This primarily serves as information for devices: The device may need to react differently if a software trigger is expected + # for every point. + scan_type = None + + # Scan name: This is the name of the scan, e.g. "line_scan". This is used for display purposes and to identify the scan type in user interfaces. + # Choose a descriptive name that does not conflict with existing scan names. + scan_name = "_v4_umv" + + # arg_input and arg_bundle_size are only relevant for scans that accept an arbitrary number of motor / position arguments (e.g. line scans, grid scans). + # For scans with a fixed set of parameters (e.g. Fermat spiral), these can be simply removed. + arg_input = {"device": DeviceBase, "target": float} + arg_bundle_size = {"bundle": len(arg_input), "min": 1, "max": None} + required_kwargs = ["relative"] + + # We set is_scan to False to separate this class from the other scans in the user interface + is_scan = False + + def __init__(self, *args, relative: bool = False, **kwargs): + """ + Simple move command that moves one or more motors to the specified positions. + The umv command is the blocking version of the mv command. + It waits for the motors to reach their target positions before returning control to the user. + + + Args: + *args (Device, float): pairs of device / target position arguments + relative (bool): if True, the motors will be moved relative to their current position. Default is False. + + Returns: + ScanReport + + Examples: + >>> scans.umv(dev.motor1, -5, dev.motor2, 5, relative=True) + + """ + super().__init__(**kwargs) + self.motor_args = args + self.motor_args_bundles = bundle_args(args, self.arg_bundle_size["bundle"]) + self.motors = list(self.motor_args_bundles.keys()) + self.relative = relative + + # Update the default scan info with provided parameters. + self.update_scan_info(relative=relative, scan_report_devices=self.motors) + + @scan_hook + def prepare_scan(self): + """ + Prepare the scan. This can include any steps that need to be executed + before the scan is opened, such as preparing the positions (if not done already) + or setting up the devices. + """ + + @scan_hook + def open_scan(self): + """ + Open the scan. + This step must call self.actions.open_scan() to ensure that a new scan is + opened. Make sure to prepare the scan metadata before, either in + prepare_scan() or in open_scan() itself and call self.update_scan_info(...) + to update the scan metadata if needed. + """ + + @scan_hook + def stage(self): + """ + Stage the devices for the upcoming scan. The stage logic is typically + implemented on the device itself (i.e. by the device's stage method). + However, if there are any additional steps that need to be executed before + staging the devices, they can be implemented here. + """ + + @scan_hook + def pre_scan(self): + """ + Pre-scan steps to be executed before the main scan logic. + This is typically the last chance to prepare the devices before the core scan + logic is executed. For example, this is a good place to initialize time-criticial + devices, e.g. devices that have a short timeout. + The pre-scan logic is typically implemented on the device itself. + """ + + @scan_hook + def scan_core(self): + """ + Core scan logic to be executed during the scan. + This is where the main scan logic should be implemented. + """ + current_positions = self.components.get_start_positions(self.motors) + target_positions = list(self.motor_args_bundles.values()) + target_positions = [pos[0] for pos in target_positions] + if self.relative: + target_positions = [ + target + current + for target, current in zip(target_positions, current_positions, strict=False) + ] + + self.actions.add_scan_report_instruction_readback( + devices=self.motors, + start=current_positions, + stop=target_positions, + request_id=self.scan_info.metadata["RID"], + ) + + self.components.move_and_wait(self.motors, target_positions) + + @scan_hook + def at_each_point(self): + """ + Logic to be executed at each point during the scan. This is called by the step_scan method at each point. + + Args: + motors (list[str | DeviceBase]): List of motor names or device instances being moved. + positions (np.ndarray): Current positions of the motors, shape (len(motors),). + last_positions (np.ndarray | None): Previous positions of the motors, shape (len(motors),) or None if this is the first point. + """ + + @scan_hook + def post_scan(self): + """ + Post-scan steps to be executed after the main scan logic. + """ + + @scan_hook + def unstage(self): + """Unstage the scan by executing post-scan steps.""" + + @scan_hook + def close_scan(self): + """Close the scan.""" + + @scan_hook + def on_exception(self, exception: Exception): + """ + Handle exceptions that occur during the scan. + This is a good place to implement any cleanup logic that needs to be executed in case of an exception, + such as returning the devices to a safe state or moving the motors back to their starting position. + """ diff --git a/bec_server/bec_server/scan_server/tests/scan_fixtures.py b/bec_server/bec_server/scan_server/tests/scan_fixtures.py new file mode 100644 index 000000000..19f51b0e0 --- /dev/null +++ b/bec_server/bec_server/scan_server/tests/scan_fixtures.py @@ -0,0 +1,414 @@ +import copy +import importlib +import inspect +import pkgutil +from collections.abc import Iterator +from types import SimpleNamespace +from typing import Any, get_type_hints +from unittest import mock + +import pytest + +from bec_lib import messages +from bec_lib.device import DeviceBase +from bec_lib.tests.utils import ConnectorMock +from bec_server.scan_server.instruction_handler import InstructionHandler +from bec_server.scan_server.scan_assembler import ScanAssembler +from bec_server.scan_server.scan_gui_models import GUIInput +from bec_server.scan_server.scans import ScanArgType +from bec_server.scan_server.scans.scans_v4 import ScanBase + + +class _DoneAfterNthCheckStatusMock: + def __init__(self, resolve_after: int = 1, result=None) -> None: + self.resolve_after = max(resolve_after, 1) + self.result = result + self.wait_calls = 0 + self._done_checks = 0 + + @property + def done(self): + self._done_checks += 1 + return self._done_checks >= self.resolve_after + + def wait(self, *args, **kwargs): + self.wait_calls += 1 + return self + + +@pytest.fixture +def nth_done_status_mock(): + def _build(resolve_after: int = 1, result=None): + return _DoneAfterNthCheckStatusMock(resolve_after=resolve_after, result=result) + + return _build + + +@pytest.fixture +def readout_priority(): + return SimpleNamespace( + monitored=[], baseline=["samx", "samy", "samz"], on_request=[], async_=[] + ) + + +@pytest.fixture +def device_manager(device_manager_class, session_from_test_config): + service_mock = mock.MagicMock() + service_mock.connector = ConnectorMock("", store_data=False) + dev_manager = device_manager_class(service_mock) + dev_manager._allow_override = True + dev_manager.config_update_handler = mock.MagicMock() + dev_manager._session = copy.deepcopy(session_from_test_config) + dev_manager._load_session() + dev_manager._v4_custom_devices = {} + + def _add_device(device: DeviceBase, replace=False): + if not isinstance(device, DeviceBase): + raise TypeError("device must be an instance of DeviceBase.") + if device.name in dev_manager.devices and not replace: + raise ValueError( + f"Device {device.name!r} already exists. Use replace=True to overwrite it." + ) + dev_manager.devices[device.name] = device + dev_manager._v4_custom_devices[device.name] = device + + dev_manager.add_device = _add_device + yield dev_manager + dev_manager.shutdown() + + +class _MockDevice(DeviceBase): + def __init__(self, name: str, limits=(-10.0, 10.0), value: float = 0.0): + info = { + "device_info": { + "signals": { + name: {"obj_name": name, "kind_str": "hinted", "describe": {"precision": 3}} + } + } + } + super().__init__(name=name, info=info) + self._limits = limits + self._value = value + self._enabled = True + self._precision = 3 + + def read(self, *args, **kwargs): + return {self.full_name: {"value": self._value}} + + @property + def root(self): + return self + + @property + def full_name(self): + return self.name + + @property + def limits(self): + return self._limits + + @property + def enabled(self): + return self._enabled + + @property + def precision(self): + return self._precision + + +class MockCustomDevice(DeviceBase): + def __init__( + self, + name: str, + device_info: dict, + signal_read_values: dict[str, float | Iterator] | None = None, + limits=(-10.0, 10.0), + precision: int = 3, + enabled: bool = True, + ): + """ + A mock device that implements the DeviceBase interface and allows for custom signal definitions and read values. + + Args: + name (str): The name of the device. + device_info (dict): A dictionary containing the device information, including signal definitions. Typically taken from + the "_info" field, e.g. dev.samx._info["signals"]. + signal_read_values (dict[str, float | Iterator], optional): A dictionary mapping signal names or their corresponding + obj_names to their read values. If a value is an iterator, the next value will be returned on each read. If not provided, + signals will return None. Note that the signal name should be the signal name, not the readout name (obj_name). + limits (tuple, optional): The limits of the device. Defaults to (-10.0, 10.0). + precision (int, optional): The precision of the device. Defaults to 3. + enabled (bool, optional): Whether the device is enabled. Defaults to True. + """ + info = {"device_info": device_info} + super().__init__(name=name, info=info) + self._limits = limits + self._precision = precision + self._enabled = enabled + self._signal_read_values = signal_read_values or {} + + for signal_name, signal_info in device_info.get("signals", {}).items(): + signal = getattr(self, signal_name, None) + if signal is None: + continue + obj_name = signal_info.get("obj_name", signal_name) + signal.get = mock.MagicMock( + side_effect=lambda signal_name=signal_name, obj_name=obj_name: ( + self._read_signal_value(signal_name, obj_name) + ) + ) + + def _read_signal_value(self, signal_name: str, obj_name: str): + value = self._signal_read_values.get(signal_name, self._signal_read_values.get(obj_name)) + if isinstance(value, Iterator): + return next(value) + return value + + def set_signal_value(self, signal_name: str, value: Any): + """ + Set the simulated read value for a signal. + + Args: + signal_name (str): The name of the signal to set the value for. This should be the signal name, not the obj_name. + value (Any): The value to set for the signal. This can be a single value or an iterator for multiple reads. + """ + if signal_name not in self._info.get("signals", {}): + raise ValueError(f"Signal {signal_name!r} is not defined in the device info.") + self._signal_read_values[signal_name] = value + + def read(self, *args, **kwargs): + data = {} + for signal_name, signal_info in self._info.get("signals", {}).items(): + kind = signal_info.get("kind_str", "").lower() + if kind in {"config", "omitted"}: + continue + obj_name = signal_info.get("obj_name", signal_name) + value = self._read_signal_value(signal_name, obj_name) + data[obj_name] = {"value": value} + return data + + def read_configuration(self, *args, **kwargs): + data = {} + for signal_name, signal_info in self._info.get("signals", {}).items(): + kind = signal_info.get("kind_str", "").lower() + if kind != "config": + continue + obj_name = signal_info.get("obj_name", signal_name) + value = self._read_signal_value(signal_name, obj_name) + data[obj_name] = {"value": value} + return data + + @property + def root(self): + return self + + @property + def full_name(self): + return self.name + + @property + def limits(self): + return self._limits + + @property + def enabled(self): + return self._enabled + + @property + def precision(self): + return self._precision + + +class _MockDevices(dict): + def __init__(self, devices: dict[str, DeviceBase], readout_priority: dict | None = None): + super().__init__(devices) + readout_priority = readout_priority or {} + self._base_readout_priority = { + "baseline": list(readout_priority.get("baseline", [])), + "monitored": list(readout_priority.get("monitored", [])), + "on_request": list(readout_priority.get("on_request", [])), + "async": list(readout_priority.get("async", [])), + } + + @property + def enabled_devices(self): + return list(self.values()) + + def _applied_readout_priority(self, readout_priority=None) -> dict[str, list[str]]: + groups = { + group_name: [device_name for device_name in device_names if device_name in self] + for group_name, device_names in self._base_readout_priority.items() + } + + for group_name in ["baseline", "monitored", "on_request", "async"]: + for device_name in (readout_priority or {}).get(group_name, []): + if device_name not in self: + continue + for existing_group in groups.values(): + if device_name in existing_group: + existing_group.remove(device_name) + groups[group_name].append(device_name) + + for group_name, device_names in groups.items(): + groups[group_name] = sorted(set(device_names)) + return groups + + def monitored_devices(self, readout_priority=None): + monitored = self._applied_readout_priority(readout_priority)["monitored"] + return [self[name] for name in monitored if name in self] + + def baseline_devices(self, readout_priority=None): + baseline = self._applied_readout_priority(readout_priority)["baseline"] + return [self[name] for name in baseline if name in self] + + def async_devices(self, readout_priority=None): + async_devices = self._applied_readout_priority(readout_priority)["async"] + return [self[name] for name in async_devices if name in self] + + def on_request_devices(self, readout_priority=None): + on_request = self._applied_readout_priority(readout_priority)["on_request"] + return [self[name] for name in on_request if name in self] + + def get_software_triggered_devices(self): + return [] + + +def _infer_v4_device_names(scan_cls, scan_args: tuple, scan_kwargs: dict) -> list[str]: + arg_input = getattr(scan_cls, "arg_input", {}) or {} + if not arg_input: + type_hints = get_type_hints(scan_cls.__init__) + signature = inspect.signature(scan_cls) + arg_input = { + name: type_hints.get(name, parameter.annotation) + for name, parameter in signature.parameters.items() + if name not in {"args", "kwargs"} + and parameter.annotation is not inspect.Parameter.empty + } + if not arg_input: + return [] + + device_names = [] + bundle_size = scan_cls.arg_bundle_size["bundle"] + + def _is_device_arg(arg_type) -> bool: + converted = GUIInput.convert_to_legacy_scan_arg_type(arg_type) + if converted == ScanArgType.DEVICE: + return True + return inspect.isclass(converted) and issubclass(converted, DeviceBase) + + if bundle_size > 0: + arg_names = list(arg_input.keys()) + for bundle_start in range(0, len(scan_args), bundle_size): + for offset, arg_name in enumerate(arg_names): + arg_index = bundle_start + offset + if arg_index >= len(scan_args): + break + if _is_device_arg(arg_input.get(arg_name)): + device_names.append(scan_args[arg_index]) + else: + bound = inspect.signature(scan_cls).bind_partial(*scan_args, **scan_kwargs) + for arg_name, value in bound.arguments.items(): + if arg_name == "args": + continue + if _is_device_arg(arg_input.get(arg_name)): + device_names.append(value) + + for arg_name, arg_type in arg_input.items(): + if _is_device_arg(arg_type) and arg_name in scan_kwargs: + device_names.append(scan_kwargs[arg_name]) + + return [name for name in device_names if isinstance(name, str)] + + +def _base_readout_priority(readout_priority) -> dict[str, list[str]]: + return { + "monitored": list(readout_priority.monitored), + "baseline": list(readout_priority.baseline), + "on_request": list(readout_priority.on_request), + "async": list(readout_priority.async_), + } + + +def _get_v4_scan_classes() -> dict[str, type[ScanBase]]: + import bec_server.scan_server.scans as scans_v4_module + + scan_classes = {} + for module_info in pkgutil.iter_modules( + scans_v4_module.__path__, prefix=f"{scans_v4_module.__name__}." + ): + module = importlib.import_module(module_info.name) + for _, scan_cls in inspect.getmembers(module, predicate=inspect.isclass): + if scan_cls.__module__ != module.__name__: + continue + if not issubclass(scan_cls, ScanBase): + continue + scan_name = getattr(scan_cls, "scan_name", None) + if not scan_name or scan_name == "_v4_base_scan": + continue + scan_classes[scan_name] = scan_cls + if scan_name.startswith("_v4_"): + scan_classes[scan_name.removeprefix("_v4_")] = scan_cls + return scan_classes + + +@pytest.fixture +def v4_scan_assembler(readout_priority, device_manager): + scan_classes = _get_v4_scan_classes() + + def _assemble_scan(scan_type, *scan_args, **scan_kwargs): + scan_id = scan_kwargs.pop("scan_id", "scan-id-test") + + try: + scan_cls = scan_classes[scan_type] + except KeyError as exc: + available = ", ".join(sorted(scan_classes)) + raise KeyError(f"Unknown v4 scan type '{scan_type}'. Available: {available}") from exc + + connector = ConnectorMock("") + instruction_handler = InstructionHandler(connector) + base_readout_priority = _base_readout_priority(readout_priority) + device_names = sorted( + set(_infer_v4_device_names(scan_cls, scan_args, scan_kwargs)) + | set(base_readout_priority["monitored"]) + | set(base_readout_priority["baseline"]) + | set(base_readout_priority["on_request"]) + | set(base_readout_priority["async"]) + ) + custom_devices = getattr(device_manager, "_v4_custom_devices", {}) + preloaded_devices = { + name: custom_devices[name] for name in device_names if name in custom_devices + } + devices = _MockDevices(preloaded_devices, readout_priority=base_readout_priority) + for name in device_names: + if name in devices: + continue + devices[name] = _MockDevice(name) + scan_device_manager = SimpleNamespace(devices=devices, connector=connector) + resolved_scan_kwargs = { + "system_config": {"file_directory": "/tmp/data/S00000"}, + **scan_kwargs, + } + + parent = mock.MagicMock() + parent.device_manager = scan_device_manager + parent.connector = connector + parent.queue_manager.instruction_handler = instruction_handler + parent.scan_manager = SimpleNamespace(scan_dict={scan_type: scan_cls}) + + assembler = ScanAssembler(parent=parent) + msg = messages.ScanQueueMessage( + metadata={"RID": "rid-test"}, + scan_type=scan_type, + parameter={"args": list(scan_args), "kwargs": resolved_scan_kwargs}, + queue="primary", + ) + scan = assembler.assemble_direct_scan(msg, scan_id) + scan._test = SimpleNamespace( + connector=connector, + instruction_handler=instruction_handler, + device_manager=scan_device_manager, + assembler=assembler, + ) + return scan + + return _assemble_scan diff --git a/bec_server/bec_server/scan_server/tests/scan_hook_tests.py b/bec_server/bec_server/scan_server/tests/scan_hook_tests.py new file mode 100644 index 000000000..b187efcde --- /dev/null +++ b/bec_server/bec_server/scan_server/tests/scan_hook_tests.py @@ -0,0 +1,138 @@ +from unittest import mock + + +def assert_prepare_scan_reads_baseline_devices(scan): + baseline_status = mock.MagicMock() + scan.actions.read_baseline_devices = mock.MagicMock(return_value=baseline_status) + + scan.prepare_scan() + + scan.actions.read_baseline_devices.assert_called_once_with(wait=False) + assert scan._baseline_readout_status is baseline_status + + +def assert_prepare_scan_starts_premove_move(scan): + premove_status = mock.MagicMock() + scan.actions.set = mock.MagicMock(return_value=premove_status) + + scan.prepare_scan() + + assert scan.actions.set.call_count >= 1 + assert any(call.kwargs.get("wait") is False for call in scan.actions.set.call_args_list) + assert scan._premove_motor_status is premove_status + + +def assert_scan_open_called(scan): + scan.actions.open_scan = mock.MagicMock() + + scan.open_scan() + + scan.actions.open_scan.assert_called_once_with() + + +def assert_stage_all_devices_called(scan): + scan.actions.stage_all_devices = mock.MagicMock() + + scan.stage() + + scan.actions.stage_all_devices.assert_called_once_with() + + +def assert_pre_scan_called(scan): + scan._premove_motor_status = mock.MagicMock() + scan.actions.pre_scan_all_devices = mock.MagicMock() + + scan.pre_scan() + + scan.actions.pre_scan_all_devices.assert_called_once_with() + + +def assert_pre_scan_waits_for_premove(scan): + premove_status = mock.MagicMock() + scan._premove_motor_status = premove_status + scan.actions.pre_scan_all_devices = mock.MagicMock() + + scan.pre_scan() + + premove_status.wait.assert_called_once_with() + scan.actions.pre_scan_all_devices.assert_called_once_with() + + +def assert_unstage_all_devices_called(scan): + scan.actions.unstage_all_devices = mock.MagicMock() + + scan.unstage() + + scan.actions.unstage_all_devices.assert_called_once_with() + + +def assert_close_scan_waits_for_baseline_and_closes(scan, nth_done_status_mock): + baseline_status = nth_done_status_mock(resolve_after=2) + scan._baseline_readout_status = baseline_status + scan.actions.close_scan = mock.MagicMock() + scan.actions.check_for_unchecked_statuses = mock.MagicMock() + + scan.close_scan() + + assert baseline_status.wait_calls == 1 + scan.actions.close_scan.assert_called_once_with() + scan.actions.check_for_unchecked_statuses.assert_called_once_with() + + +def assert_scan_core_delegates_to_step_scan(scan): + scan.prepare_scan() + scan.components.step_scan = mock.MagicMock() + + scan.scan_core() + + scan.components.step_scan.assert_called_once() + args, kwargs = scan.components.step_scan.call_args + assert args == (scan.motors, scan.scan_info.positions) + assert kwargs["at_each_point"] == scan.at_each_point + if "last_positions" in kwargs: + assert (kwargs["last_positions"] == scan.positions[0]).all() + + +def assert_post_scan_waits_for_completion_and_moves_back_when_relative(scan, nth_done_status_mock): + completion_status = nth_done_status_mock(resolve_after=3) + scan.relative = True + scan.start_positions = [1.2, -0.7] + scan.actions.complete_all_devices = mock.MagicMock(return_value=completion_status) + scan.components.move_and_wait = mock.MagicMock() + + scan.post_scan() + + scan.actions.complete_all_devices.assert_called_once_with(wait=False) + scan.components.move_and_wait.assert_called_once_with(scan.motors, scan.start_positions) + assert completion_status.wait_calls == 1 + + +DEFAULT_HOOK_TESTS = [ + ("prepare_scan", [assert_prepare_scan_reads_baseline_devices]), + ("open_scan", [assert_scan_open_called]), + ("stage", [assert_stage_all_devices_called]), + ("pre_scan", [assert_pre_scan_called]), + ("unstage", [assert_unstage_all_devices_called]), + ("close_scan", [assert_close_scan_waits_for_baseline_and_closes]), +] + + +PREMOVE_HOOK_TESTS = [ + ("prepare_scan", [assert_prepare_scan_starts_premove_move]), + ("pre_scan", [assert_pre_scan_waits_for_premove]), +] + + +STANDARD_STEP_SCAN_TESTS = [ + ("scan_core", [assert_scan_core_delegates_to_step_scan]), + ("post_scan", [assert_post_scan_waits_for_completion_and_moves_back_when_relative]), +] + + +def run_scan_tests(scan, tests, nth_done_status_mock=None): + for test_name, assertions in tests: + for assertion in assertions: + if test_name in {"close_scan", "post_scan"}: + assertion(scan, nth_done_status_mock) + else: + assertion(scan) diff --git a/bec_server/tests/tests_device_server/conftest.py b/bec_server/tests/tests_device_server/conftest.py index c760a9467..f2e39cc6d 100644 --- a/bec_server/tests/tests_device_server/conftest.py +++ b/bec_server/tests/tests_device_server/conftest.py @@ -1,3 +1,6 @@ +import os + +os.environ.setdefault("OPHYD_CONTROL_LAYER", "dummy") import fakeredis import pytest diff --git a/bec_server/tests/tests_device_server/test_device_manager_ds.py b/bec_server/tests/tests_device_server/test_device_manager_ds.py index f42401b62..01ec78df0 100644 --- a/bec_server/tests/tests_device_server/test_device_manager_ds.py +++ b/bec_server/tests/tests_device_server/test_device_manager_ds.py @@ -5,6 +5,7 @@ import numpy as np import pytest from ophyd_devices.devices.psi_motor import EpicsMotor +from ophyd_devices.tests.utils import patched_device from bec_lib import messages from bec_lib.bec_errors import DeviceConfigError @@ -451,9 +452,8 @@ def epics_motor_config(): @pytest.fixture def epics_motor(): - - motor = EpicsMotor(prefix="TEST:MOTOR", name="test_motor") - return motor + with patched_device(EpicsMotor, prefix="TEST:MOTOR", name="test_motor") as motor: + yield motor @pytest.mark.parametrize( diff --git a/bec_server/tests/tests_device_server/test_device_server.py b/bec_server/tests/tests_device_server/test_device_server.py index d4edb5bf5..aa9ca786c 100644 --- a/bec_server/tests/tests_device_server/test_device_server.py +++ b/bec_server/tests/tests_device_server/test_device_server.py @@ -732,6 +732,30 @@ def test_read_device(device_server_mock, instr): assert res[-1]["msg"].metadata["stream"] == "primary" +@pytest.mark.parametrize("device_manager_class", [DeviceManagerDS]) +def test_read_device_can_return_result(device_server_mock): + device_server = device_server_mock + instr = messages.DeviceInstructionMessage( + device=["samx", "samy"], + action="read", + parameter={"return_result": True}, + metadata={"stream": "primary", "device_instr_id": "diid", "RID": "test"}, + ) + + device_server._read_device(instr) + + responses = [ + msg["msg"] + for msg in device_server.connector.message_sent + if msg["queue"] == MessageEndpoints.device_instructions_response() + ] + response = responses[-1] + assert response.result is not None + assert len(response.result) == 2 + assert response.result[0].keys() == device_server.device_manager.devices.samx.obj.read().keys() + assert response.result[1].keys() == device_server.device_manager.devices.samy.obj.read().keys() + + @pytest.mark.parametrize("devices", [["samx", "samy"], ["samx"]]) @pytest.mark.parametrize("device_manager_class", [DeviceManagerDS]) def test_read_config_and_update_devices(device_server_mock, devices): diff --git a/bec_server/tests/tests_file_writer/test_file_writer_manager.py b/bec_server/tests/tests_file_writer/test_file_writer_manager.py index 98f736648..f3dd6565d 100644 --- a/bec_server/tests/tests_file_writer/test_file_writer_manager.py +++ b/bec_server/tests/tests_file_writer/test_file_writer_manager.py @@ -229,7 +229,7 @@ def test_scan_storage_append(scan_storage_mock): def test_scan_storage_ready_to_write(scan_storage_mock): storage = scan_storage_mock - storage.num_points = 1 + storage.num_monitored_readouts = 1 storage.scan_finished = True storage.append(1, {"data": "data"}) assert storage.ready_to_write() is True @@ -265,12 +265,12 @@ def test_ready_to_write(file_writer_manager_mock, scan_storage_mock): ) file_manager.scan_storage["scan_id"] = scan_storage_mock file_manager.scan_storage["scan_id"].scan_finished = True - file_manager.scan_storage["scan_id"].num_points = 1 + file_manager.scan_storage["scan_id"].num_monitored_readouts = 1 file_manager.scan_storage["scan_id"].scan_segments = {"0": {"data": np.zeros((10, 10))}} assert file_manager.scan_storage["scan_id"].ready_to_write() is True file_manager.scan_storage["scan_id1"] = scan_storage_mock file_manager.scan_storage["scan_id1"].scan_finished = True - file_manager.scan_storage["scan_id1"].num_points = 2 + file_manager.scan_storage["scan_id1"].num_monitored_readouts = 2 file_manager.scan_storage["scan_id1"].scan_segments = {"0": {"data": np.zeros((10, 10))}} assert file_manager.scan_storage["scan_id1"].ready_to_write() is False scan_storage_mock.status_msg = messages.ScanStatusMessage( diff --git a/bec_server/tests/tests_scan_server/conftest.py b/bec_server/tests/tests_scan_server/conftest.py index 75bb68c11..b20871b72 100644 --- a/bec_server/tests/tests_scan_server/conftest.py +++ b/bec_server/tests/tests_scan_server/conftest.py @@ -1,3 +1,7 @@ +import os + +os.environ.setdefault("OPHYD_CONTROL_LAYER", "dummy") + import fakeredis import pytest @@ -7,6 +11,8 @@ # overwrite threads_check fixture from bec_lib, # to have it in autouse +pytest_plugins = ["bec_server.scan_server.tests.scan_fixtures"] + @pytest.fixture(autouse=True) def threads_check(threads_check): diff --git a/bec_server/tests/tests_scan_server/scans_v4/__init__.py b/bec_server/tests/tests_scan_server/scans_v4/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/bec_server/tests/tests_scan_server/scans_v4/__init__.py @@ -0,0 +1 @@ + diff --git a/bec_server/tests/tests_scan_server/scans_v4/test_acquire.py b/bec_server/tests/tests_scan_server/scans_v4/test_acquire.py new file mode 100644 index 000000000..5e807e6ab --- /dev/null +++ b/bec_server/tests/tests_scan_server/scans_v4/test_acquire.py @@ -0,0 +1,69 @@ +from unittest import mock + +import pytest + +from bec_server.scan_server.tests.scan_hook_tests import ( + assert_close_scan_waits_for_baseline_and_closes, + assert_pre_scan_called, + assert_prepare_scan_reads_baseline_devices, + assert_scan_open_called, + assert_stage_all_devices_called, + assert_unstage_all_devices_called, + run_scan_tests, +) + +ACQUIRE_DEFAULT_HOOK_TESTS = [ + ("prepare_scan", [assert_prepare_scan_reads_baseline_devices]), + ("open_scan", [assert_scan_open_called]), + ("stage", [assert_stage_all_devices_called]), + ("pre_scan", [assert_pre_scan_called]), + ("unstage", [assert_unstage_all_devices_called]), + ("close_scan", [assert_close_scan_waits_for_baseline_and_closes]), +] + + +@pytest.mark.parametrize(("hook_name", "hook_tests"), ACQUIRE_DEFAULT_HOOK_TESTS) +def test_acquire_default_hooks(v4_scan_assembler, nth_done_status_mock, hook_name, hook_tests): + scan = v4_scan_assembler("acquire", exp_time=0.2, burst_at_each_point=3) + + run_scan_tests(scan, [(hook_name, hook_tests)], nth_done_status_mock=nth_done_status_mock) + + +def test_acquire_prepare_scan_updates_scan_info_and_queue(v4_scan_assembler): + scan = v4_scan_assembler("acquire", exp_time=0.2, burst_at_each_point=3) + + scan.prepare_scan() + + assert scan.scan_info.num_points == 1 + assert scan.scan_info.num_monitored_readouts == 3 + assert scan.scan_info.positions.size == 0 + assert scan.scan_info.scan_report_instructions == [ + {"scan_progress": {"points": 3, "show_table": False}} + ] + + +def test_acquire_scan_core_triggers_and_reads_for_each_burst(v4_scan_assembler): + scan = v4_scan_assembler("acquire", exp_time=0.2, burst_at_each_point=3) + scan.at_each_point = mock.MagicMock() + + scan.scan_core() + + assert scan.at_each_point.call_count == 3 + + +def test_acquire_at_each_point_triggers_and_reads(v4_scan_assembler): + scan = v4_scan_assembler("acquire", exp_time=0.2, burst_at_each_point=3) + scan.components.trigger_and_read = mock.MagicMock() + + scan.at_each_point() + + scan.components.trigger_and_read.assert_called_once_with() + + +def test_acquire_post_scan_completes_all_devices(v4_scan_assembler): + scan = v4_scan_assembler("acquire", exp_time=0.2, burst_at_each_point=3) + scan.actions.complete_all_devices = mock.MagicMock() + + scan.post_scan() + + scan.actions.complete_all_devices.assert_called_once_with() diff --git a/bec_server/tests/tests_scan_server/scans_v4/test_cont_line_scan.py b/bec_server/tests/tests_scan_server/scans_v4/test_cont_line_scan.py new file mode 100644 index 000000000..cfe6a996e --- /dev/null +++ b/bec_server/tests/tests_scan_server/scans_v4/test_cont_line_scan.py @@ -0,0 +1,260 @@ +from types import SimpleNamespace +from unittest import mock + +import numpy as np +import pytest + +from bec_server.scan_server.errors import ScanAbortion +from bec_server.scan_server.tests.scan_fixtures import MockCustomDevice +from bec_server.scan_server.tests.scan_hook_tests import ( + PREMOVE_HOOK_TESTS, + assert_close_scan_waits_for_baseline_and_closes, + assert_pre_scan_called, + assert_prepare_scan_reads_baseline_devices, + assert_scan_open_called, + assert_stage_all_devices_called, + assert_unstage_all_devices_called, + run_scan_tests, +) + +CONT_LINE_DEFAULT_HOOK_TESTS = [ + ("prepare_scan", [assert_prepare_scan_reads_baseline_devices]), + ("open_scan", [assert_scan_open_called]), + ("stage", [assert_stage_all_devices_called]), + ("pre_scan", [assert_pre_scan_called]), + ("unstage", [assert_unstage_all_devices_called]), + ("close_scan", [assert_close_scan_waits_for_baseline_and_closes]), + *PREMOVE_HOOK_TESTS, +] + + +def _assemble_cont_line_scan( + v4_scan_assembler, + device_manager, + *, + start=-1.0, + stop=1.0, + steps=3, + exp_time=0.1, + relative=False, + velocity=1.0, + acceleration=2.0, + precision=3, +): + device_info = { + "signals": { + "readback": {"obj_name": "samx", "kind_str": "hinted", "describe": {"precision": 3}}, + "velocity": { + "obj_name": "samx_velocity", + "kind_str": "config", + "describe": {"precision": 3}, + }, + "acceleration": { + "obj_name": "samx_acceleration", + "kind_str": "config", + "describe": {"precision": 3}, + }, + } + } + custom_samx = MockCustomDevice( + "samx", + device_info=device_info, + signal_read_values={ + "samx": 0.0, + "samx_velocity": velocity, + "samx_acceleration": acceleration, + }, + precision=precision, + ) + device_manager.add_device(custom_samx, replace=True) + return v4_scan_assembler( + "cont_line_scan", "samx", start, stop, steps=steps, exp_time=exp_time, relative=relative + ) + + +@pytest.mark.parametrize(("hook_name", "hook_tests"), CONT_LINE_DEFAULT_HOOK_TESTS) +def test_cont_line_scan_default_hooks( + v4_scan_assembler, device_manager, nth_done_status_mock, hook_name, hook_tests +): + scan = _assemble_cont_line_scan( + v4_scan_assembler, + device_manager, + start=-1.0, + stop=1.0, + steps=3, + exp_time=0.1, + relative=False, + ) + + run_scan_tests(scan, [(hook_name, hook_tests)], nth_done_status_mock=nth_done_status_mock) + + +def test_cont_line_scan_prepare_scan_updates_scan_info(v4_scan_assembler, device_manager): + scan = _assemble_cont_line_scan( + v4_scan_assembler, + device_manager, + start=-1.0, + stop=1.0, + steps=3, + exp_time=0.1, + relative=False, + ) + + scan.prepare_scan() + + assert np.array_equal(scan.positions, np.array([[-1.0], [0.0], [1.0]])) + assert scan.scan_info.num_points == 3 + assert scan.offset == 1.0 + + +def test_cont_line_scan_example_custom_device_manager_integration( + v4_scan_assembler, device_manager +): + custom_samx = MockCustomDevice( + "samx", + device_info={ + "signals": { + "readback": { + "obj_name": "samx", + "kind_str": "hinted", + "describe": {"precision": 3}, + }, + "velocity": { + "obj_name": "samx_velocity", + "kind_str": "config", + "describe": {"precision": 3}, + }, + "acceleration": { + "obj_name": "samx_acceleration", + "kind_str": "config", + "describe": {"precision": 3}, + }, + } + }, + signal_read_values={"samx": 2.5, "samx_velocity": 1.0, "samx_acceleration": 2.0}, + ) + device_manager.add_device(custom_samx, replace=True) + + scan = v4_scan_assembler( + "cont_line_scan", "samx", -1.0, 1.0, steps=3, exp_time=0.1, relative=False + ) + + assert scan.device is custom_samx + + +def test_mock_custom_device_supports_generated_signal_values(): + custom_samx = MockCustomDevice( + "samx", + device_info={ + "signals": { + "readback": { + "obj_name": "samx", + "kind_str": "hinted", + "describe": {"precision": 3}, + }, + "velocity": { + "obj_name": "samx_velocity", + "kind_str": "config", + "describe": {"precision": 3}, + }, + } + }, + signal_read_values={"samx": iter([0.0, 0.5]), "samx_velocity": iter([1.0, 2.0])}, + ) + + assert custom_samx.read()["samx"]["value"] == 0.0 + assert custom_samx.read()["samx"]["value"] == 0.5 + + assert custom_samx.velocity.get() == 1.0 + assert custom_samx.read_configuration()["samx_velocity"]["value"] == 2.0 + + custom_samx.set_signal_value("velocity", 5.0) + assert custom_samx.velocity.get() == 5.0 + + +def test_cont_line_scan_at_each_point_triggers_and_reads(v4_scan_assembler): + scan = v4_scan_assembler( + "cont_line_scan", "samx", -1.0, 1.0, steps=3, exp_time=0.1, relative=False + ) + scan.components.trigger_and_read = mock.MagicMock() + + scan.at_each_point() + + scan.components.trigger_and_read.assert_called_once_with() + + +def test_cont_line_scan_scan_core_moves_and_reads_at_matching_positions( + v4_scan_assembler, device_manager +): + scan = _assemble_cont_line_scan( + v4_scan_assembler, + device_manager, + start=-1.0, + stop=1.0, + steps=3, + exp_time=0.1, + relative=False, + ) + scan.prepare_scan() + start_status = SimpleNamespace(wait=mock.MagicMock()) + end_status = SimpleNamespace(wait=mock.MagicMock()) + scan.actions.set = mock.MagicMock(side_effect=[start_status, end_status]) + read_values = iter( + [{"samx": {"value": -1.0}}, {"samx": {"value": 0.0}}, {"samx": {"value": 1.0}}] + ) + scan.device.read = mock.MagicMock(side_effect=lambda **kwargs: next(read_values)) + scan.at_each_point = mock.MagicMock() + + scan.scan_core() + + assert scan.actions.set.call_args_list == [ + mock.call(scan.device, -2.0, wait=True), + mock.call(scan.device, 1.0, wait=False), + ] + end_status.wait.assert_called_once_with() + assert scan.at_each_point.call_count == 3 + + +def test_cont_line_scan_prepare_scan_raises_when_motor_too_fast(v4_scan_assembler, device_manager): + scan = _assemble_cont_line_scan( + v4_scan_assembler, + device_manager, + start=-1.0, + stop=1.0, + steps=3, + exp_time=10.0, + relative=False, + velocity=1.0, + acceleration=2.0, + ) + + with pytest.raises(ScanAbortion, match="moving too fast"): + scan.prepare_scan() + + +def test_cont_line_scan_post_scan_moves_back_when_relative(v4_scan_assembler, nth_done_status_mock): + scan = v4_scan_assembler( + "cont_line_scan", "samx", -1.0, 1.0, steps=3, exp_time=0.1, relative=True + ) + completion_status = nth_done_status_mock(resolve_after=2) + scan.start_positions = [1.5] + scan.actions.complete_all_devices = mock.MagicMock(return_value=completion_status) + scan.components.move_and_wait = mock.MagicMock() + + scan.post_scan() + + scan.actions.complete_all_devices.assert_called_once_with(wait=False) + scan.components.move_and_wait.assert_called_once_with(scan.motors, scan.start_positions) + assert completion_status.wait_calls == 1 + + +def test_cont_line_scan_on_exception_moves_back_when_relative(v4_scan_assembler): + scan = v4_scan_assembler( + "cont_line_scan", "samx", -1.0, 1.0, steps=3, exp_time=0.1, relative=True + ) + scan.start_positions = [1.5] + scan.components.move_and_wait = mock.MagicMock() + + scan.on_exception(RuntimeError("boom")) + + scan.components.move_and_wait.assert_called_once_with(scan.motors, scan.start_positions) diff --git a/bec_server/tests/tests_scan_server/scans_v4/test_fermat_scan.py b/bec_server/tests/tests_scan_server/scans_v4/test_fermat_scan.py new file mode 100644 index 000000000..b5bd4d49b --- /dev/null +++ b/bec_server/tests/tests_scan_server/scans_v4/test_fermat_scan.py @@ -0,0 +1,44 @@ +from unittest import mock + +import numpy as np +import pytest + +from bec_server.scan_server.tests.scan_hook_tests import ( + DEFAULT_HOOK_TESTS, + PREMOVE_HOOK_TESTS, + STANDARD_STEP_SCAN_TESTS, + run_scan_tests, +) + + +@pytest.mark.parametrize( + ("hook_name", "hook_tests"), + [*DEFAULT_HOOK_TESTS, *PREMOVE_HOOK_TESTS, *STANDARD_STEP_SCAN_TESTS], +) +def test_fermat_scan_default_hooks(v4_scan_assembler, nth_done_status_mock, hook_name, hook_tests): + scan = v4_scan_assembler("fermat_scan", "samx", -1.0, 1.0, "samy", -2.0, 2.0, step=0.5) + + run_scan_tests(scan, [(hook_name, hook_tests)], nth_done_status_mock=nth_done_status_mock) + + +def test_fermat_scan_prepare_scan_updates_scan_info_and_queue(v4_scan_assembler): + scan = v4_scan_assembler("fermat_scan", "samx", -1.0, 1.0, "samy", -2.0, 2.0, step=0.5) + + scan.prepare_scan() + + assert isinstance(scan.positions, np.ndarray) + assert scan.positions.shape[1] == 2 + assert scan.scan_info.num_points == len(scan.positions) + assert np.array_equal(scan.scan_info.positions, scan.positions) + assert scan.scan_info.scan_report_instructions == [ + {"scan_progress": {"points": len(scan.positions), "show_table": False}} + ] + + read_messages = [ + entry["msg"] + for entry in scan._test.connector.message_sent + if getattr(entry["msg"], "action", None) == "read" + ] + assert len(read_messages) == 1 + assert read_messages[0].device == ["samz"] + assert read_messages[0].metadata["readout_priority"] == "baseline" diff --git a/bec_server/tests/tests_scan_server/scans_v4/test_grid_scan.py b/bec_server/tests/tests_scan_server/scans_v4/test_grid_scan.py new file mode 100644 index 000000000..9a56242ac --- /dev/null +++ b/bec_server/tests/tests_scan_server/scans_v4/test_grid_scan.py @@ -0,0 +1,35 @@ +from unittest import mock + +import numpy as np +import pytest + +from bec_server.scan_server.tests.scan_hook_tests import ( + DEFAULT_HOOK_TESTS, + PREMOVE_HOOK_TESTS, + STANDARD_STEP_SCAN_TESTS, + run_scan_tests, +) + + +@pytest.mark.parametrize( + ("hook_name", "hook_tests"), + [*DEFAULT_HOOK_TESTS, *PREMOVE_HOOK_TESTS, *STANDARD_STEP_SCAN_TESTS], +) +def test_grid_scan_default_hooks(v4_scan_assembler, nth_done_status_mock, hook_name, hook_tests): + scan = v4_scan_assembler("grid_scan", "samx", -1.0, 1.0, 3, "samy", -2.0, 2.0, 5, snaked=True) + + run_scan_tests(scan, [(hook_name, hook_tests)], nth_done_status_mock=nth_done_status_mock) + + +def test_grid_scan_prepare_scan_updates_scan_info_and_queue(v4_scan_assembler): + scan = v4_scan_assembler("grid_scan", "samx", -1.0, 1.0, 3, "samy", -2.0, 2.0, 5, snaked=True) + + scan.prepare_scan() + + assert isinstance(scan.positions, np.ndarray) + assert scan.positions.shape == (15, 2) + assert scan.scan_info.num_points == 15 + assert np.array_equal(scan.scan_info.positions, scan.positions) + assert scan.scan_info.scan_report_instructions == [ + {"scan_progress": {"points": 15, "show_table": False}} + ] diff --git a/bec_server/tests/tests_scan_server/scans_v4/test_hexagonal_scan.py b/bec_server/tests/tests_scan_server/scans_v4/test_hexagonal_scan.py new file mode 100644 index 000000000..55daa9f71 --- /dev/null +++ b/bec_server/tests/tests_scan_server/scans_v4/test_hexagonal_scan.py @@ -0,0 +1,54 @@ +import numpy as np +import pytest + +from bec_server.scan_server.scans import position_generators +from bec_server.scan_server.tests.scan_hook_tests import ( + DEFAULT_HOOK_TESTS, + PREMOVE_HOOK_TESTS, + STANDARD_STEP_SCAN_TESTS, + run_scan_tests, +) + + +@pytest.mark.parametrize( + ("hook_name", "hook_tests"), + [*DEFAULT_HOOK_TESTS, *PREMOVE_HOOK_TESTS, *STANDARD_STEP_SCAN_TESTS], +) +def test_hexagonal_scan_default_hooks( + v4_scan_assembler, nth_done_status_mock, hook_name, hook_tests +): + scan = v4_scan_assembler( + "hexagonal_scan", "samx", -1.0, 1.0, 1.0, "samy", -1.0, 1.0, 1.0, relative=False + ) + + run_scan_tests(scan, [(hook_name, hook_tests)], nth_done_status_mock=nth_done_status_mock) + + +def test_hexagonal_scan_prepare_scan_updates_scan_info_and_queue(v4_scan_assembler): + scan = v4_scan_assembler( + "hexagonal_scan", "samx", -1.0, 1.0, 1.0, "samy", -1.0, 1.0, 1.0, relative=False + ) + + scan.prepare_scan() + + expected_positions = position_generators.hex_grid_2d( + [(-1.0, 1.0, 1.0), (-1.0, 1.0, 1.0)], snaked=True + ) + assert np.array_equal(scan.positions, expected_positions) + assert scan.scan_info.num_points == len(expected_positions) + assert np.array_equal(scan.scan_info.positions, expected_positions) + + +def test_hexagonal_scan_prepare_scan_offsets_positions_when_relative(v4_scan_assembler): + scan = v4_scan_assembler( + "hexagonal_scan", "samx", -1.0, 1.0, 1.0, "samy", -1.0, 1.0, 1.0, relative=True + ) + scan.components.get_start_positions = lambda motors: [5.0, -2.0] + + scan.prepare_scan() + + expected_positions = position_generators.hex_grid_2d( + [(-1.0, 1.0, 1.0), (-1.0, 1.0, 1.0)], snaked=True + ) + [5.0, -2.0] + assert scan.start_positions == [5.0, -2.0] + assert np.array_equal(scan.positions, expected_positions) diff --git a/bec_server/tests/tests_scan_server/scans_v4/test_line_scan.py b/bec_server/tests/tests_scan_server/scans_v4/test_line_scan.py new file mode 100644 index 000000000..ba5ab7ed9 --- /dev/null +++ b/bec_server/tests/tests_scan_server/scans_v4/test_line_scan.py @@ -0,0 +1,46 @@ +import numpy as np +import pytest + +from bec_server.scan_server.tests.scan_hook_tests import ( + DEFAULT_HOOK_TESTS, + PREMOVE_HOOK_TESTS, + STANDARD_STEP_SCAN_TESTS, + run_scan_tests, +) + + +@pytest.mark.parametrize( + ("hook_name", "hook_tests"), + [*DEFAULT_HOOK_TESTS, *PREMOVE_HOOK_TESTS, *STANDARD_STEP_SCAN_TESTS], +) +def test_line_scan_default_hooks(v4_scan_assembler, nth_done_status_mock, hook_name, hook_tests): + scan = v4_scan_assembler("line_scan", "samx", -1.0, 1.0, "samy", -2.0, 2.0, steps=5) + + run_scan_tests(scan, [(hook_name, hook_tests)], nth_done_status_mock=nth_done_status_mock) + + +def test_line_scan_prepare_scan_updates_scan_info_and_queue(v4_scan_assembler): + scan = v4_scan_assembler("line_scan", "samx", -1.0, 1.0, "samy", -2.0, 2.0, steps=5) + + scan.prepare_scan() + + expected_positions = np.array([[-1.0, -2.0], [-0.5, -1.0], [0.0, 0.0], [0.5, 1.0], [1.0, 2.0]]) + assert np.array_equal(scan.positions, expected_positions) + assert scan.scan_info.num_points == 5 + assert np.array_equal(scan.scan_info.positions, expected_positions) + assert scan.scan_info.scan_report_instructions == [ + {"scan_progress": {"points": 5, "show_table": False}} + ] + + +def test_line_scan_prepare_scan_offsets_positions_when_relative(v4_scan_assembler): + scan = v4_scan_assembler( + "line_scan", "samx", -1.0, 1.0, "samy", -2.0, 2.0, steps=5, relative=True + ) + scan.components.get_start_positions = lambda motors: [2.0, 3.0] + + scan.prepare_scan() + + expected_positions = np.array([[1.0, 1.0], [1.5, 2.0], [2.0, 3.0], [2.5, 4.0], [3.0, 5.0]]) + assert scan.start_positions == [2.0, 3.0] + assert np.array_equal(scan.positions, expected_positions) diff --git a/bec_server/tests/tests_scan_server/scans_v4/test_line_sweep_scan.py b/bec_server/tests/tests_scan_server/scans_v4/test_line_sweep_scan.py new file mode 100644 index 000000000..dac527d78 --- /dev/null +++ b/bec_server/tests/tests_scan_server/scans_v4/test_line_sweep_scan.py @@ -0,0 +1,235 @@ +from unittest import mock + +import numpy as np +import pytest + +from bec_lib import messages +from bec_lib.connector import MessageObject +from bec_lib.endpoints import MessageEndpoints +from bec_server.scan_server.tests.scan_hook_tests import ( + PREMOVE_HOOK_TESTS, + assert_close_scan_waits_for_baseline_and_closes, + assert_pre_scan_called, + assert_prepare_scan_reads_baseline_devices, + assert_scan_open_called, + assert_stage_all_devices_called, + assert_unstage_all_devices_called, + run_scan_tests, +) + +LINE_SWEEP_DEFAULT_HOOK_TESTS = [ + ("prepare_scan", [assert_prepare_scan_reads_baseline_devices]), + ("open_scan", [assert_scan_open_called]), + ("stage", [assert_stage_all_devices_called]), + ("pre_scan", [assert_pre_scan_called]), + ("unstage", [assert_unstage_all_devices_called]), + ("close_scan", [assert_close_scan_waits_for_baseline_and_closes]), + *PREMOVE_HOOK_TESTS, +] + + +def _device_readback_message(device_name: str, value: float) -> MessageObject: + endpoint = MessageEndpoints.device_readback(device_name) + return MessageObject( + topic=endpoint.endpoint, + value=messages.DeviceMessage(signals={device_name: {"value": value}}), + ) + + +@pytest.mark.parametrize(("hook_name", "hook_tests"), LINE_SWEEP_DEFAULT_HOOK_TESTS) +def test_line_sweep_scan_default_hooks( + v4_scan_assembler, nth_done_status_mock, hook_name, hook_tests +): + scan = v4_scan_assembler("line_sweep_scan", "samx", -5.0, 5.0, relative=True) + + run_scan_tests(scan, [(hook_name, hook_tests)], nth_done_status_mock=nth_done_status_mock) + + +def test_line_sweep_scan_prepare_scan_updates_scan_info(v4_scan_assembler): + scan = v4_scan_assembler( + "line_sweep_scan", + "samx", + -5.0, + 5.0, + exp_time=0.2, + frames_per_trigger=3, + max_update=0.4, + relative=False, + ) + + scan.prepare_scan() + + assert np.array_equal(scan.positions, np.array([[-5.0], [5.0]])) + assert scan.scan_info.num_points == 0 + assert scan.scan_info.exp_time == 0.2 + assert scan.scan_info.frames_per_trigger == 3 + assert scan.max_update == 0.4 + assert scan.scan_info.scan_report_instructions == [ + {"scan_progress": {"points": 0, "show_table": False}} + ] + + +def test_line_sweep_scan_at_each_point_triggers_and_reads(v4_scan_assembler): + scan = v4_scan_assembler("line_sweep_scan", "samx", -5.0, 5.0, relative=False) + scan.components.trigger_and_read = mock.MagicMock() + + scan.at_each_point() + + scan.components.trigger_and_read.assert_called_once_with() + + +def test_line_sweep_scan_scan_core_moves_and_reads_until_done( + v4_scan_assembler, nth_done_status_mock +): + scan = v4_scan_assembler("line_sweep_scan", "samx", -5.0, 5.0, min_update=0.1, relative=False) + scan.prepare_scan() + done_status = nth_done_status_mock(resolve_after=4) + scan.device.set = mock.MagicMock(return_value=done_status) + scan.at_each_point = mock.MagicMock() + scan.redis_connector.unregister = mock.MagicMock() + + def register_readback(endpoint, cb): + assert endpoint == MessageEndpoints.device_readback("samx") + cb(_device_readback_message("samx", 1.0)) + + scan.redis_connector.register = mock.MagicMock(side_effect=register_readback) + with mock.patch("bec_server.scan_server.scans.line_sweep_scan.time.sleep") as sleep_mock: + scan.scan_core() + + scan.device.set.assert_called_once_with(5.0) + scan.redis_connector.register.assert_called_once_with( + MessageEndpoints.device_readback("samx"), cb=scan._device_readback_callback + ) + scan.redis_connector.unregister.assert_called_once_with( + MessageEndpoints.device_readback("samx"), cb=scan._device_readback_callback + ) + scan.at_each_point.assert_called_once_with() + sleep_mock.assert_called_once_with(0.1) + + +def test_line_sweep_scan_scan_core_coalesces_multiple_readback_updates( + v4_scan_assembler, nth_done_status_mock +): + scan = v4_scan_assembler("line_sweep_scan", "samx", -5.0, 5.0, relative=False) + scan.prepare_scan() + done_status = nth_done_status_mock(resolve_after=2) + scan.device.set = mock.MagicMock(return_value=done_status) + scan.at_each_point = mock.MagicMock() + scan.redis_connector.unregister = mock.MagicMock() + + def register_readback(endpoint, cb): + assert endpoint == MessageEndpoints.device_readback("samx") + cb(_device_readback_message("samx", 1.0)) + cb(_device_readback_message("samx", 2.0)) + + scan.redis_connector.register = mock.MagicMock(side_effect=register_readback) + + scan.scan_core() + + scan.at_each_point.assert_called_once_with() + + +def test_line_sweep_scan_scan_core_reads_final_pending_update( + v4_scan_assembler, nth_done_status_mock +): + scan = v4_scan_assembler("line_sweep_scan", "samx", -5.0, 5.0, relative=False) + scan.prepare_scan() + done_status = nth_done_status_mock(resolve_after=2) + scan.device.set = mock.MagicMock(return_value=done_status) + scan.at_each_point = mock.MagicMock() + scan.redis_connector.unregister = mock.MagicMock() + + def register_readback(endpoint, cb): + assert endpoint == MessageEndpoints.device_readback("samx") + cb(_device_readback_message("samx", 1.0)) + + scan.redis_connector.register = mock.MagicMock(side_effect=register_readback) + + scan.scan_core() + + scan.at_each_point.assert_called_once_with() + + +def test_line_sweep_scan_scan_core_waits_for_event_when_no_update( + v4_scan_assembler, nth_done_status_mock +): + scan = v4_scan_assembler("line_sweep_scan", "samx", -5.0, 5.0, relative=False) + scan.prepare_scan() + done_status = nth_done_status_mock(resolve_after=2) + scan.device.set = mock.MagicMock(return_value=done_status) + scan.at_each_point = mock.MagicMock() + scan.redis_connector.unregister = mock.MagicMock() + wait_calls = [] + + def wait(timeout): + wait_calls.append(timeout) + return False + + scan._readback_update_event.wait = mock.MagicMock(side_effect=wait) + scan.redis_connector.register = mock.MagicMock() + + scan.scan_core() + + scan._readback_update_event.wait.assert_called_once_with(timeout=0.05) + assert wait_calls == [0.05] + scan.at_each_point.assert_not_called() + + +def test_line_sweep_scan_scan_core_triggers_read_when_max_update_expires( + v4_scan_assembler, nth_done_status_mock +): + scan = v4_scan_assembler("line_sweep_scan", "samx", -5.0, 5.0, max_update=0.1, relative=False) + scan.prepare_scan() + done_status = nth_done_status_mock(resolve_after=3) + scan.device.set = mock.MagicMock(return_value=done_status) + scan.at_each_point = mock.MagicMock() + scan.redis_connector.unregister = mock.MagicMock() + scan.redis_connector.register = mock.MagicMock() + scan._readback_update_event.wait = mock.MagicMock(side_effect=[False, False]) + + with mock.patch( + "bec_server.scan_server.scans.line_sweep_scan.time.time", + side_effect=[100.0, 100.05, 100.2, 100.2], + ): + scan.scan_core() + + assert scan._readback_update_event.wait.call_args_list == [ + mock.call(timeout=0.05), + mock.call(timeout=0.05), + ] + scan.at_each_point.assert_called_once_with() + + +def test_line_sweep_scan_device_readback_callback_sets_update_event(v4_scan_assembler): + scan = v4_scan_assembler("line_sweep_scan", "samx", -5.0, 5.0, relative=False) + assert scan._readback_update_event.is_set() is False + + scan._device_readback_callback(_device_readback_message("samx", 2.0)) + + assert scan._readback_update_event.is_set() is True + + +def test_line_sweep_scan_post_scan_moves_back_when_relative( + v4_scan_assembler, nth_done_status_mock +): + scan = v4_scan_assembler("line_sweep_scan", "samx", -5.0, 5.0, relative=True) + completion_status = nth_done_status_mock(resolve_after=2) + scan.start_positions = [1.0] + scan.actions.complete_all_devices = mock.MagicMock(return_value=completion_status) + scan.components.move_and_wait = mock.MagicMock() + + scan.post_scan() + + scan.actions.complete_all_devices.assert_called_once_with(wait=False) + scan.components.move_and_wait.assert_called_once_with(scan.motors, scan.start_positions) + assert completion_status.wait_calls == 1 + + +def test_line_sweep_scan_on_exception_moves_back_when_relative(v4_scan_assembler): + scan = v4_scan_assembler("line_sweep_scan", "samx", -5.0, 5.0, relative=True) + scan.start_positions = [1.0] + scan.components.move_and_wait = mock.MagicMock() + + scan.on_exception(RuntimeError("boom")) + + scan.components.move_and_wait.assert_called_once_with(scan.motors, scan.start_positions) diff --git a/bec_server/tests/tests_scan_server/scans_v4/test_list_scan.py b/bec_server/tests/tests_scan_server/scans_v4/test_list_scan.py new file mode 100644 index 000000000..d22996dcc --- /dev/null +++ b/bec_server/tests/tests_scan_server/scans_v4/test_list_scan.py @@ -0,0 +1,46 @@ +import numpy as np +import pytest + +from bec_server.scan_server.tests.scan_hook_tests import ( + DEFAULT_HOOK_TESTS, + PREMOVE_HOOK_TESTS, + STANDARD_STEP_SCAN_TESTS, + run_scan_tests, +) + + +@pytest.mark.parametrize( + ("hook_name", "hook_tests"), + [*DEFAULT_HOOK_TESTS, *PREMOVE_HOOK_TESTS, *STANDARD_STEP_SCAN_TESTS], +) +def test_list_scan_default_hooks(v4_scan_assembler, nth_done_status_mock, hook_name, hook_tests): + scan = v4_scan_assembler("list_scan", "samx", [0, 1, 2], "samy", [3, 4, 5], relative=False) + + run_scan_tests(scan, [(hook_name, hook_tests)], nth_done_status_mock=nth_done_status_mock) + + +def test_list_scan_prepare_scan_updates_scan_info_and_queue(v4_scan_assembler): + scan = v4_scan_assembler("list_scan", "samx", [0, 1, 2], "samy", [3, 4, 5], relative=False) + + scan.prepare_scan() + + expected_positions = np.array([[0.0, 3.0], [1.0, 4.0], [2.0, 5.0]]) + assert np.array_equal(scan.positions, expected_positions) + assert scan.scan_info.num_points == 3 + assert np.array_equal(scan.scan_info.positions, expected_positions) + + +def test_list_scan_prepare_scan_offsets_positions_when_relative(v4_scan_assembler): + scan = v4_scan_assembler("list_scan", "samx", [0, 1, 2], "samy", [3, 4, 5], relative=True) + scan.components.get_start_positions = lambda motors: [1.0, -1.0] + + scan.prepare_scan() + + expected_positions = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]]) + assert scan.start_positions == [1.0, -1.0] + assert np.array_equal(scan.positions, expected_positions) + + +def test_list_scan_raises_for_different_lengths(v4_scan_assembler): + with pytest.raises(ValueError, match="equal length"): + v4_scan_assembler("list_scan", "samx", [0, 1], "samy", [0, 1, 2], relative=False) diff --git a/bec_server/tests/tests_scan_server/scans_v4/test_log_scan.py b/bec_server/tests/tests_scan_server/scans_v4/test_log_scan.py new file mode 100644 index 000000000..6c2ed4e22 --- /dev/null +++ b/bec_server/tests/tests_scan_server/scans_v4/test_log_scan.py @@ -0,0 +1,52 @@ +import numpy as np +import pytest + +from bec_server.scan_server.tests.scan_hook_tests import ( + DEFAULT_HOOK_TESTS, + PREMOVE_HOOK_TESTS, + STANDARD_STEP_SCAN_TESTS, + run_scan_tests, +) + + +@pytest.mark.parametrize( + ("hook_name", "hook_tests"), + [*DEFAULT_HOOK_TESTS, *PREMOVE_HOOK_TESTS, *STANDARD_STEP_SCAN_TESTS], +) +def test_log_scan_default_hooks(v4_scan_assembler, nth_done_status_mock, hook_name, hook_tests): + scan = v4_scan_assembler("log_scan", "samx", 0.1, 10.0, "samy", 0.01, 1.0, steps=3) + + run_scan_tests(scan, [(hook_name, hook_tests)], nth_done_status_mock=nth_done_status_mock) + + +def test_log_scan_prepare_scan_updates_scan_info_and_queue(v4_scan_assembler): + scan = v4_scan_assembler("log_scan", "samx", 0.1, 10.0, "samy", 0.01, 1.0, steps=3) + + scan.prepare_scan() + + middle_progress = (np.sqrt(10) - 1) / 9 + expected_positions = np.array( + [[0.1, 0.01], [0.1 + middle_progress * 9.9, 0.01 + middle_progress * 0.99], [10.0, 1.0]] + ) + assert np.allclose(scan.positions, expected_positions) + assert scan.scan_info.num_points == 3 + assert np.allclose(scan.scan_info.positions, expected_positions) + assert scan.scan_info.scan_report_instructions == [ + {"scan_progress": {"points": 3, "show_table": False}} + ] + + +def test_log_scan_prepare_scan_offsets_positions_when_relative(v4_scan_assembler): + scan = v4_scan_assembler( + "log_scan", "samx", -1.0, 1.0, "samy", 0.0, 1.0, steps=3, relative=True + ) + scan.components.get_start_positions = lambda motors: [2.0, 3.0] + + scan.prepare_scan() + + middle_progress = (np.sqrt(10) - 1) / 9 + expected_positions = np.array( + [[1.0, 3.0], [1.0 + middle_progress * 2.0, 3.0 + middle_progress], [3.0, 4.0]] + ) + assert scan.start_positions == [2.0, 3.0] + assert np.allclose(scan.positions, expected_positions) diff --git a/bec_server/tests/tests_scan_server/scans_v4/test_move_scan.py b/bec_server/tests/tests_scan_server/scans_v4/test_move_scan.py new file mode 100644 index 000000000..766906998 --- /dev/null +++ b/bec_server/tests/tests_scan_server/scans_v4/test_move_scan.py @@ -0,0 +1,42 @@ +from unittest import mock + +import pytest + + +@pytest.mark.parametrize( + ("hook_name",), + [("open_scan",), ("stage",), ("pre_scan",), ("post_scan",), ("unstage",), ("close_scan",)], +) +def test_move_scan_default_noop_hooks_do_not_raise(v4_scan_assembler, hook_name): + scan = v4_scan_assembler("mv", "samx", 1.5, "samy", -2.0) + + getattr(scan, hook_name)() + + +def test_move_scan_prepare_scan_registers_required_response_devices(v4_scan_assembler): + scan = v4_scan_assembler("mv", "samx", 1.5, "samy", -2.0) + scan.actions.add_device_with_required_response = mock.MagicMock() + + scan.prepare_scan() + + scan.actions.add_device_with_required_response.assert_called_once_with(scan.motors) + + +def test_move_scan_scan_core_sets_absolute_targets_without_wait(v4_scan_assembler): + scan = v4_scan_assembler("mv", "samx", 1.5, "samy", -2.0) + scan.actions.set = mock.MagicMock() + + scan.scan_core() + + scan.actions.set.assert_called_once_with(scan.motors, [1.5, -2.0], wait=False) + + +def test_move_scan_scan_core_sets_relative_targets_without_wait(v4_scan_assembler): + scan = v4_scan_assembler("mv", "samx", 1.5, "samy", -2.0, relative=True) + scan.components.get_start_positions = mock.MagicMock(return_value=[0.5, 3.0]) + scan.actions.set = mock.MagicMock() + + scan.scan_core() + + scan.components.get_start_positions.assert_called_once_with(scan.motors) + scan.actions.set.assert_called_once_with(scan.motors, [2.0, 1.0], wait=False) diff --git a/bec_server/tests/tests_scan_server/scans_v4/test_multi_region_grid_scan.py b/bec_server/tests/tests_scan_server/scans_v4/test_multi_region_grid_scan.py new file mode 100644 index 000000000..769f75d1a --- /dev/null +++ b/bec_server/tests/tests_scan_server/scans_v4/test_multi_region_grid_scan.py @@ -0,0 +1,102 @@ +import numpy as np +import pytest + +from bec_server.scan_server.tests.scan_hook_tests import ( + DEFAULT_HOOK_TESTS, + PREMOVE_HOOK_TESTS, + STANDARD_STEP_SCAN_TESTS, + run_scan_tests, +) + + +@pytest.mark.parametrize( + ("hook_name", "hook_tests"), + [*DEFAULT_HOOK_TESTS, *PREMOVE_HOOK_TESTS, *STANDARD_STEP_SCAN_TESTS], +) +def test_multi_region_grid_scan_default_hooks( + v4_scan_assembler, nth_done_status_mock, hook_name, hook_tests +): + scan = v4_scan_assembler( + "multi_region_grid_scan", + "samx", + "samy", + regions=[[(-3.0, -1.0, 2), (-2.0, 2.0, 3)], [(1.0, 3.0, 2), (-2.0, 2.0, 3)]], + snaked=True, + ) + + run_scan_tests(scan, [(hook_name, hook_tests)], nth_done_status_mock=nth_done_status_mock) + + +def test_multi_region_grid_scan_prepare_scan_updates_scan_info_and_queue(v4_scan_assembler): + scan = v4_scan_assembler( + "multi_region_grid_scan", + "samx", + "samy", + regions=[[(-3.0, -1.0, 2), (-2.0, 2.0, 3)], [(1.0, 3.0, 2), (-2.0, 2.0, 3)]], + snaked=True, + ) + + scan.prepare_scan() + + expected_positions = np.array( + [ + [-3.0, -2.0], + [-3.0, 0.0], + [-3.0, 2.0], + [-1.0, 2.0], + [-1.0, 0.0], + [-1.0, -2.0], + [1.0, -2.0], + [1.0, 0.0], + [1.0, 2.0], + [3.0, 2.0], + [3.0, 0.0], + [3.0, -2.0], + ] + ) + assert np.allclose(scan.positions, expected_positions) + assert scan.scan_info.num_points == len(expected_positions) + assert np.allclose(scan.scan_info.positions, expected_positions) + assert scan.scan_info.scan_report_instructions == [ + {"scan_progress": {"points": len(expected_positions), "show_table": False}} + ] + + +def test_multi_region_grid_scan_prepare_scan_offsets_positions_when_relative(v4_scan_assembler): + scan = v4_scan_assembler( + "multi_region_grid_scan", + "samx", + "samy", + regions=[[(-3.0, -1.0, 2), (-2.0, 2.0, 3)], [(1.0, 3.0, 2), (-2.0, 2.0, 3)]], + snaked=True, + relative=True, + ) + scan.components.get_start_positions = lambda motors: [1.0, -1.0] + + scan.prepare_scan() + + expected_positions = np.array( + [ + [-2.0, -3.0], + [-2.0, -1.0], + [-2.0, 1.0], + [0.0, 1.0], + [0.0, -1.0], + [0.0, -3.0], + [2.0, -3.0], + [2.0, -1.0], + [2.0, 1.0], + [4.0, 1.0], + [4.0, -1.0], + [4.0, -3.0], + ] + ) + assert scan.start_positions == [1.0, -1.0] + assert np.allclose(scan.positions, expected_positions) + + +def test_multi_region_grid_scan_prepare_scan_rejects_empty_region_list(v4_scan_assembler): + scan = v4_scan_assembler("multi_region_grid_scan", "samx", "samy", regions=[], snaked=True) + + with pytest.raises(ValueError, match="at least one paired region"): + scan.prepare_scan() diff --git a/bec_server/tests/tests_scan_server/scans_v4/test_multi_region_line_scan.py b/bec_server/tests/tests_scan_server/scans_v4/test_multi_region_line_scan.py new file mode 100644 index 000000000..8e22c9c9d --- /dev/null +++ b/bec_server/tests/tests_scan_server/scans_v4/test_multi_region_line_scan.py @@ -0,0 +1,56 @@ +import numpy as np +import pytest + +from bec_server.scan_server.tests.scan_hook_tests import ( + DEFAULT_HOOK_TESTS, + PREMOVE_HOOK_TESTS, + STANDARD_STEP_SCAN_TESTS, + run_scan_tests, +) + + +@pytest.mark.parametrize( + ("hook_name", "hook_tests"), + [*DEFAULT_HOOK_TESTS, *PREMOVE_HOOK_TESTS, *STANDARD_STEP_SCAN_TESTS], +) +def test_multi_region_line_scan_default_hooks( + v4_scan_assembler, nth_done_status_mock, hook_name, hook_tests +): + scan = v4_scan_assembler( + "multi_region_line_scan", "samx", regions=[(-5.0, -2.0, 4), (-1.0, 6.0, 4)] + ) + + run_scan_tests(scan, [(hook_name, hook_tests)], nth_done_status_mock=nth_done_status_mock) + + +def test_multi_region_line_scan_prepare_scan_updates_scan_info_and_queue(v4_scan_assembler): + scan = v4_scan_assembler( + "multi_region_line_scan", "samx", regions=[(-5.0, -2.0, 4), (-1.0, 6.0, 4)] + ) + + scan.prepare_scan() + + expected_positions = np.array( + [[-5.0], [-4.0], [-3.0], [-2.0], [-1.0], [1.33333333], [3.66666667], [6.0]] + ) + assert np.allclose(scan.positions, expected_positions) + assert scan.scan_info.num_points == len(expected_positions) + assert np.allclose(scan.scan_info.positions, expected_positions) + assert scan.scan_info.scan_report_instructions == [ + {"scan_progress": {"points": len(expected_positions), "show_table": False}} + ] + + +def test_multi_region_line_scan_prepare_scan_offsets_positions_when_relative(v4_scan_assembler): + scan = v4_scan_assembler( + "multi_region_line_scan", "samx", regions=[(-5.0, -2.0, 4), (-1.0, 6.0, 4)], relative=True + ) + scan.components.get_start_positions = lambda motors: [2.0] + + scan.prepare_scan() + + expected_positions = np.array( + [[-3.0], [-2.0], [-1.0], [0.0], [1.0], [3.33333333], [5.66666667], [8.0]] + ) + assert scan.start_positions == [2.0] + assert np.allclose(scan.positions, expected_positions) diff --git a/bec_server/tests/tests_scan_server/scans_v4/test_position_generators.py b/bec_server/tests/tests_scan_server/scans_v4/test_position_generators.py new file mode 100644 index 000000000..f13210873 --- /dev/null +++ b/bec_server/tests/tests_scan_server/scans_v4/test_position_generators.py @@ -0,0 +1,156 @@ +import numpy as np + +from bec_server.scan_server.scans import position_generators + + +def test_rotate_points_rotates_2d_positions(): + points = np.array([[1.0, 0.0], [0.0, 2.0]]) + + rotated = position_generators.rotate_points(points, np.pi / 2) + + assert np.allclose(rotated, [[0.0, 1.0], [-2.0, 0.0]]) + + +def test_rotate_points_rotates_around_custom_center(): + points = np.array([[2.0, 1.0]]) + + rotated = position_generators.rotate_points(points, np.pi / 2, center=(1.0, 1.0)) + + assert np.allclose(rotated, [[1.0, 2.0]]) + + +def test_line_scan_positions_generates_linear_trajectory(): + positions = position_generators.line_scan_positions([(-1.0, 1.0), (-2.0, 2.0)], steps=5) + + assert np.allclose(positions, [[-1.0, -2.0], [-0.5, -1.0], [0.0, 0.0], [0.5, 1.0], [1.0, 2.0]]) + + +def test_log_scan_positions_generates_log_trajectory(): + positions = position_generators.log_scan_positions([(1.0, 100.0), (10.0, 1000.0)], steps=3) + middle_progress = (np.sqrt(10) - 1) / 9 + + assert np.allclose( + positions, + [ + [1.0, 10.0], + [1.0 + middle_progress * 99.0, 10.0 + middle_progress * 990.0], + [100.0, 1000.0], + ], + ) + + +def test_log_scan_positions_spans_distance_for_zero_crossing_ranges(): + positions = position_generators.log_scan_positions([(0.0, 10.0), (-5.0, 5.0)], steps=3) + middle_progress = (np.sqrt(10) - 1) / 9 + + assert np.allclose( + positions, + [[0.0, -5.0], [middle_progress * 10.0, -5.0 + middle_progress * 10.0], [10.0, 5.0]], + ) + + +def test_log_scan_positions_supports_reverse_ranges(): + positions = position_generators.log_scan_positions([(10.0, 0.0)], steps=3) + middle_progress = (np.sqrt(10) - 1) / 9 + + assert np.allclose(positions, [[10.0], [10.0 - middle_progress * 10.0], [0.0]]) + + +def test_oscillating_positions_cycles_back_and_forth(): + generator = position_generators.oscillating_positions([1.0, 2.0, 3.0]) + + values = [next(generator) for _ in range(7)] + + assert values == [1.0, 2.0, 3.0, 2.0, 1.0, 2.0, 3.0] + + +def test_oscillating_positions_repeats_single_value(): + generator = position_generators.oscillating_positions([5.0]) + + values = [next(generator) for _ in range(4)] + + assert values == [5.0, 5.0, 5.0, 5.0] + + +def test_oscillating_positions_can_repeat_turning_points(): + generator = position_generators.oscillating_positions( + [1.0, 2.0, 3.0], repeat_turning_points=True + ) + + values = [next(generator) for _ in range(9)] + + assert values == [1.0, 1.0, 2.0, 3.0, 3.0, 2.0, 1.0, 1.0, 2.0] + + +def test_multi_region_line_positions_concatenates_regions(): + positions = position_generators.multi_region_line_positions([(-5.0, -2.0, 4), (-1.0, 6.0, 4)]) + + assert np.allclose( + positions, [[-5.0], [-4.0], [-3.0], [-2.0], [-1.0], [1.33333333], [3.66666667], [6.0]] + ) + + +def test_multi_region_grid_positions_builds_snaked_grid(): + positions = position_generators.multi_region_grid_positions( + [((-3.0, -1.0, 2), (-2.0, 2.0, 3)), ((1.0, 3.0, 2), (-2.0, 2.0, 3))], snaked=True + ) + + assert np.allclose( + positions, + [ + [-3.0, -2.0], + [-3.0, 0.0], + [-3.0, 2.0], + [-1.0, 2.0], + [-1.0, 0.0], + [-1.0, -2.0], + [1.0, -2.0], + [1.0, 0.0], + [1.0, 2.0], + [3.0, 2.0], + [3.0, 0.0], + [3.0, -2.0], + ], + ) + + +def test_multi_region_grid_positions_rejects_empty_region_list(): + with np.testing.assert_raises_regex(ValueError, "at least one paired region"): + position_generators.multi_region_grid_positions([]) + + +def test_spiral_positions_starts_at_center_and_stays_in_bounds(): + positions = position_generators.spiral_positions( + x_center=2.0, y_center=-1.0, x_range=4.0, y_range=6.0, dr=0.5, nth=8 + ) + + assert len(positions) > 1 + assert np.allclose(positions[0], [2.0, -1.0]) + assert np.all(positions[:, 0] >= 0.0) + assert np.all(positions[:, 0] <= 4.0) + assert np.all(positions[:, 1] >= -4.0) + assert np.all(positions[:, 1] <= 2.0) + + +def test_spiral_positions_supports_tilt(): + untilted = position_generators.spiral_positions( + x_center=0.0, y_center=0.0, x_range=6.0, y_range=6.0, dr=0.5, nth=8, tilt=0.0 + ) + tilted = position_generators.spiral_positions( + x_center=0.0, y_center=0.0, x_range=6.0, y_range=6.0, dr=0.5, nth=8, tilt=np.pi / 4 + ) + + assert len(untilted) == len(tilted) + assert np.allclose(untilted[0], tilted[0]) + assert not np.allclose(untilted[1], tilted[1]) + + +def test_fermat_spiral_positions_are_centered_in_requested_box(): + positions = position_generators.fermat_spiral_pos(10.0, 14.0, -3.0, 1.0, step=0.5, center=True) + + assert len(positions) > 0 + assert np.allclose(positions[0], [12.0, -1.0]) + assert np.all(positions[:, 0] >= 10.0) + assert np.all(positions[:, 0] <= 14.0) + assert np.all(positions[:, 1] >= -3.0) + assert np.all(positions[:, 1] <= 1.0) diff --git a/bec_server/tests/tests_scan_server/scans_v4/test_relative_scan_return_to_start.py b/bec_server/tests/tests_scan_server/scans_v4/test_relative_scan_return_to_start.py new file mode 100644 index 000000000..cc40f79d3 --- /dev/null +++ b/bec_server/tests/tests_scan_server/scans_v4/test_relative_scan_return_to_start.py @@ -0,0 +1,47 @@ +from unittest import mock + +import pytest + +RELATIVE_SCAN_CASES = [ + ("cont_line_scan", ("samx", -1.0, 1.0), {"steps": 3, "exp_time": 0.1, "relative": True}), + ("fermat_scan", ("samx", -1.0, 1.0, "samy", -2.0, 2.0), {"step": 0.5, "relative": True}), + ("grid_scan", ("samx", -1.0, 1.0, 3, "samy", -2.0, 2.0, 5), {"snaked": True, "relative": True}), + ("hexagonal_scan", ("samx", -1.0, 1.0, 1.0, "samy", -1.0, 1.0, 1.0), {"relative": True}), + ("line_scan", ("samx", -1.0, 1.0, "samy", -2.0, 2.0), {"steps": 5, "relative": True}), + ("list_scan", ("samx", [0, 1, 2], "samy", [3, 4, 5]), {"relative": True}), + ("log_scan", ("samx", 0.1, 1.0, "samy", 0.01, 1.0), {"steps": 3, "relative": True}), + ("line_sweep_scan", ("samx", -5.0, 5.0), {"relative": True}), + ( + "multi_region_grid_scan", + ("samx", "samy"), + { + "regions": [((-5.0, -1.0, 5), (-4.0, 0.0, 5)), ((1.0, 5.0, 3), (-4.0, 0.0, 5))], + "snaked": True, + "relative": True, + }, + ), + ( + "multi_region_line_scan", + ("samx",), + {"regions": [(-5.0, -2.0, 4), (-1.0, 6.0, 4)], "relative": True}, + ), + ( + "round_roi_scan", + ("samx", -3.0, 3.0, "samy", -2.0, 2.0), + {"shell_spacing": 1.0, "pos_in_first_ring": 4, "relative": True}, + ), + ("round_scan", ("samx", "samy", 0.0, 2.0, 2, 3), {"relative": True}), +] + + +@pytest.mark.parametrize(("scan_type", "scan_args", "scan_kwargs"), RELATIVE_SCAN_CASES) +def test_relative_v4_scan_on_exception_moves_back_to_start( + v4_scan_assembler, scan_type, scan_args, scan_kwargs +): + scan = v4_scan_assembler(scan_type, *scan_args, **scan_kwargs) + scan.start_positions = [float(index + 1) for index, _motor in enumerate(scan.motors)] + scan.components.move_and_wait = mock.MagicMock() + + scan.on_exception(RuntimeError("boom")) + + scan.components.move_and_wait.assert_called_once_with(scan.motors, scan.start_positions) diff --git a/bec_server/tests/tests_scan_server/scans_v4/test_round_roi_scan.py b/bec_server/tests/tests_scan_server/scans_v4/test_round_roi_scan.py new file mode 100644 index 000000000..b46dacda0 --- /dev/null +++ b/bec_server/tests/tests_scan_server/scans_v4/test_round_roi_scan.py @@ -0,0 +1,81 @@ +import numpy as np +import pytest + +from bec_server.scan_server.scans import position_generators +from bec_server.scan_server.tests.scan_hook_tests import ( + DEFAULT_HOOK_TESTS, + PREMOVE_HOOK_TESTS, + STANDARD_STEP_SCAN_TESTS, + run_scan_tests, +) + + +@pytest.mark.parametrize( + ("hook_name", "hook_tests"), + [*DEFAULT_HOOK_TESTS, *PREMOVE_HOOK_TESTS, *STANDARD_STEP_SCAN_TESTS], +) +def test_round_roi_scan_default_hooks( + v4_scan_assembler, nth_done_status_mock, hook_name, hook_tests +): + scan = v4_scan_assembler( + "round_roi_scan", + "samx", + -3.0, + 3.0, + "samy", + -2.0, + 2.0, + shell_spacing=1.0, + pos_in_first_ring=4, + relative=False, + ) + + run_scan_tests(scan, [(hook_name, hook_tests)], nth_done_status_mock=nth_done_status_mock) + + +def test_round_roi_scan_prepare_scan_updates_scan_info_and_queue(v4_scan_assembler): + scan = v4_scan_assembler( + "round_roi_scan", + "samx", + -3.0, + 3.0, + "samy", + -2.0, + 2.0, + shell_spacing=1.0, + pos_in_first_ring=4, + relative=False, + ) + + scan.prepare_scan() + + expected_positions = position_generators.get_round_roi_scan_positions( + -3.0, 3.0, -2.0, 2.0, 1.0, 4 + ) + assert np.array_equal(scan.positions, expected_positions) + assert scan.scan_info.num_points == len(expected_positions) + assert np.array_equal(scan.scan_info.positions, expected_positions) + + +def test_round_roi_scan_prepare_scan_offsets_positions_when_relative(v4_scan_assembler): + scan = v4_scan_assembler( + "round_roi_scan", + "samx", + -3.0, + 3.0, + "samy", + -2.0, + 2.0, + shell_spacing=1.0, + pos_in_first_ring=4, + relative=True, + ) + scan.components.get_start_positions = lambda motors: [2.0, 3.0] + + scan.prepare_scan() + + expected_positions = position_generators.get_round_roi_scan_positions( + -3.0, 3.0, -2.0, 2.0, 1.0, 4 + ) + [2.0, 3.0] + assert scan.start_positions == [2.0, 3.0] + assert np.array_equal(scan.positions, expected_positions) diff --git a/bec_server/tests/tests_scan_server/scans_v4/test_round_scan.py b/bec_server/tests/tests_scan_server/scans_v4/test_round_scan.py new file mode 100644 index 000000000..51f3bf1ae --- /dev/null +++ b/bec_server/tests/tests_scan_server/scans_v4/test_round_scan.py @@ -0,0 +1,42 @@ +import numpy as np +import pytest + +from bec_server.scan_server.scans import position_generators +from bec_server.scan_server.tests.scan_hook_tests import ( + DEFAULT_HOOK_TESTS, + PREMOVE_HOOK_TESTS, + STANDARD_STEP_SCAN_TESTS, + run_scan_tests, +) + + +@pytest.mark.parametrize( + ("hook_name", "hook_tests"), + [*DEFAULT_HOOK_TESTS, *PREMOVE_HOOK_TESTS, *STANDARD_STEP_SCAN_TESTS], +) +def test_round_scan_default_hooks(v4_scan_assembler, nth_done_status_mock, hook_name, hook_tests): + scan = v4_scan_assembler("round_scan", "samx", "samy", 0.0, 2.0, 2, 3, relative=False) + + run_scan_tests(scan, [(hook_name, hook_tests)], nth_done_status_mock=nth_done_status_mock) + + +def test_round_scan_prepare_scan_updates_scan_info_and_queue(v4_scan_assembler): + scan = v4_scan_assembler("round_scan", "samx", "samy", 0.0, 2.0, 2, 3, relative=False) + + scan.prepare_scan() + + expected_positions = position_generators.round_scan_positions(0.0, 2.0, 2, 3) + assert np.array_equal(scan.positions, expected_positions) + assert scan.scan_info.num_points == len(expected_positions) + assert np.array_equal(scan.scan_info.positions, expected_positions) + + +def test_round_scan_prepare_scan_offsets_positions_when_relative(v4_scan_assembler): + scan = v4_scan_assembler("round_scan", "samx", "samy", 0.0, 2.0, 2, 3, relative=True) + scan.components.get_start_positions = lambda motors: [1.0, -1.0] + + scan.prepare_scan() + + expected_positions = position_generators.round_scan_positions(0.0, 2.0, 2, 3) + [1.0, -1.0] + assert scan.start_positions == [1.0, -1.0] + assert np.array_equal(scan.positions, expected_positions) diff --git a/bec_server/tests/tests_scan_server/scans_v4/test_scan_actions.py b/bec_server/tests/tests_scan_server/scans_v4/test_scan_actions.py new file mode 100644 index 000000000..06e518e85 --- /dev/null +++ b/bec_server/tests/tests_scan_server/scans_v4/test_scan_actions.py @@ -0,0 +1,696 @@ +import os +import threading +from dataclasses import dataclass +from unittest import mock + +import pytest + +from bec_lib import messages +from bec_lib.device import ReadoutPriority +from bec_lib.endpoints import MessageEndpoints +from bec_lib.tests.fixtures import dm_with_devices # noqa: F401 +from bec_lib.tests.utils import ConnectorMock +from bec_server.scan_server.instruction_handler import InstructionHandler +from bec_server.scan_server.scan_stubs import ScanStubStatus +from bec_server.scan_server.scans.scans_v4 import ScanBase, ScanInfo + + +class _TestScan(ScanBase): + scan_name = "_v4_test_scan" + scan_type = None + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.scan_info.scan_number = 1 + self.scan_info.dataset_number = 2 + self.scan_info.scan_report_devices = ["samx"] + self.scan_info.readout_priority_modification = { + "baseline": [], + "monitored": [], + "on_request": [], + "async": [], + "continuous": [], + } + self.scan_info.scan_report_instructions = [] + self.update_scan_info( + num_points=3, + num_monitored_readouts=0, + exp_time=0.1, + frames_per_trigger=1, + settling_time=0.2, + relative=False, + run_on_exception_hook=False, + ) + self.scan_info.readout_time = 0.3 + + +@dataclass +class _ActionContext: + actions: object + connector: object + device_manager: object + scan: ScanBase + + +class _TestServiceConfig: + def __init__(self, base_path: str): + self.config = {"file_writer": {"base_path": base_path}} + + +class _TestParent: + def __init__(self, base_path: str): + self._service_config = _TestServiceConfig(base_path) + + +@pytest.fixture +def action_context(dm_with_devices): + def _build(connector=None): + connector = connector or ConnectorMock("") + instruction_handler = InstructionHandler(connector) + dm_with_devices.connector = connector + scan = _TestScan( + scan_id="scan-id-test", + request_inputs={}, + system_config={}, + redis_connector=connector, + device_manager=dm_with_devices, + instruction_handler=instruction_handler, + ) + return _ActionContext( + actions=scan.actions, connector=connector, device_manager=dm_with_devices, scan=scan + ) + + return _build + + +def _sent_device_instructions(ctx, action): + return [ + entry["msg"] + for entry in ctx.connector.message_sent + if getattr(entry["msg"], "action", None) == action + ] + + +def test_scan_info_stores_scan_report_device_objects_as_names(dm_with_devices): + scan_info = ScanInfo( + scan_name="_v4_test_scan", + scan_id="scan-id-test", + scan_type=None, + scan_report_devices=[dm_with_devices.devices["samx"], "samy"], + ) + + assert scan_info.scan_report_devices == ["samx", "samy"] + + scan_info.scan_report_devices = [dm_with_devices.devices["samz"]] + + assert scan_info.scan_report_devices == ["samz"] + + +def _last_device_instruction(ctx, action): + return _sent_device_instructions(ctx, action)[-1] + + +def _enabled_device_names(ctx): + return [dev.root.name for dev in ctx.device_manager.devices.enabled_devices] + + +def _reading(device_name, value): + return { + device_name: {"value": value}, + f"{device_name}_setpoint": {"value": value}, + f"{device_name}_motor_is_moving": {"value": 0}, + } + + +def _set_readout_priority(ctx, **readout_groups): + for device in ctx.device_manager.devices.values(): + device.root._config["readoutPriority"] = ReadoutPriority.ON_REQUEST + for priority, device_names in readout_groups.items(): + readout_priority = ReadoutPriority[priority.upper()] + for device_name in device_names: + ctx.device_manager.devices[device_name].root._config[ + "readoutPriority" + ] = readout_priority + + +def _set_software_triggered(ctx, *device_names): + software_triggered = set(device_names) + for device in ctx.device_manager.devices.values(): + device.root._config["softwareTrigger"] = device.root.name in software_triggered + + +def test_open_close_scan_send_scan_status(action_context): + ctx = action_context() + ctx.actions.check_for_unchecked_statuses = mock.MagicMock() + ctx.actions._send_scan_status = mock.MagicMock() + + ctx.actions.open_scan() + ctx.actions.close_scan() + + assert ctx.actions._send_scan_status.mock_calls == [mock.call("open"), mock.call("closed")] + ctx.actions.check_for_unchecked_statuses.assert_called_once_with() + + +def test_build_scan_status_message(action_context): + ctx = action_context() + _set_readout_priority( + ctx, monitored=["samx"], baseline=["samy"], on_request=["bpm4i"], **{"async": ["samz"]} + ) + + msg = ctx.actions._build_scan_status_message("open") + + assert msg.scan_id == "scan-id-test" + assert msg.scan_name == "_v4_test_scan" + assert msg.scan_number == 1 + assert msg.dataset_number == 2 + assert msg.num_points == 3 + assert msg.num_monitored_readouts == 0 + assert msg.scan_type is None + assert msg.scan_parameters == { + "exp_time": 0.1, + "frames_per_trigger": 1, + "settling_time": 0.2, + "readout_time": 0.3, + "relative": False, + "system_config": {}, + } + assert msg.readout_priority["monitored"] == ["samx"] + assert msg.readout_priority["baseline"] == ["samy"] + assert msg.readout_priority["async"] == ["samz"] + assert "bpm4i" in msg.readout_priority["on_request"] + + +def test_device_instruction_actions_emit_expected_messages(action_context): + ctx = action_context() + ctx.scan.scan_info.metadata["RID"] = "rid-123" + ctx.scan.scan_info.metadata["queue_id"] = "queue-123" + + stage_status = ctx.actions.stage("samx", wait=False) + pre_scan_status = ctx.actions.pre_scan_all_devices(wait=False) + kickoff_status = ctx.actions.kickoff("samx", parameters={"frames": 3}, wait=False) + complete_status = ctx.actions.complete("samy", wait=False) + unstage_status = ctx.actions.unstage("samz", wait=False) + + stage_msg = _last_device_instruction(ctx, "stage") + pre_scan_msg = _last_device_instruction(ctx, "pre_scan") + kickoff_msg = _last_device_instruction(ctx, "kickoff") + complete_msg = _last_device_instruction(ctx, "complete") + unstage_msg = _last_device_instruction(ctx, "unstage") + + assert stage_msg.device == "samx" + assert stage_msg.metadata["device_instr_id"] == stage_status._device_instr_id + assert stage_msg.metadata["scan_id"] == "scan-id-test" + assert stage_msg.metadata["RID"] == "rid-123" + assert stage_msg.metadata["queue_id"] == "queue-123" + assert pre_scan_msg.device == sorted(_enabled_device_names(ctx)) + assert pre_scan_msg.metadata["device_instr_id"] == pre_scan_status._device_instr_id + assert kickoff_msg.device == "samx" + assert kickoff_msg.parameter == {"configure": {"frames": 3}} + assert kickoff_msg.metadata["device_instr_id"] == kickoff_status._device_instr_id + assert complete_msg.device == "samy" + assert complete_msg.metadata["device_instr_id"] == complete_status._device_instr_id + assert unstage_msg.device == "samz" + assert unstage_msg.metadata["device_instr_id"] == unstage_status._device_instr_id + + +def test_set_emits_one_instruction_per_device(action_context): + ctx = action_context() + + status = ctx.actions.set(["samx", "samy"], [1.5, 2.5], wait=False) + + set_messages = _sent_device_instructions(ctx, "set")[-2:] + assert ( + status._sub_status_objects[0]._device_instr_id + == set_messages[0].metadata["device_instr_id"] + ) + assert ( + status._sub_status_objects[1]._device_instr_id + == set_messages[1].metadata["device_instr_id"] + ) + assert [(msg.device, msg.parameter) for msg in set_messages] == [ + ("samx", {"value": 1.5}), + ("samy", {"value": 2.5}), + ] + + +def test_set_rejects_mismatched_device_and_value_counts(action_context): + ctx = action_context() + + with pytest.raises(ValueError, match="number of devices and values"): + ctx.actions.set(["samx", "samy"], [1.5], wait=False) + + +def test_pre_scan_emits_expected_messages(action_context): + ctx = action_context() + + status_single = ctx.actions.pre_scan("samx", wait=False) + status_multi = ctx.actions.pre_scan(["samx", "samy"], wait=False) + + pre_scan_messages = _sent_device_instructions(ctx, "pre_scan") + assert pre_scan_messages[-2].device == "samx" + assert pre_scan_messages[-2].metadata["device_instr_id"] == status_single._device_instr_id + assert pre_scan_messages[-1].device == ["samx", "samy"] + assert pre_scan_messages[-1].metadata["device_instr_id"] == status_multi._device_instr_id + + +def test_pre_scan_all_devices_respects_exclude(action_context): + ctx = action_context() + + status = ctx.actions.pre_scan_all_devices(wait=False, exclude=["samx"]) + + pre_scan_msg = _last_device_instruction(ctx, "pre_scan") + assert pre_scan_msg.device == sorted( + [device_name for device_name in _enabled_device_names(ctx) if device_name != "samx"] + ) + assert pre_scan_msg.metadata["device_instr_id"] == status._device_instr_id + + +def test_read_actions_emit_expected_messages_and_point_ids(action_context): + ctx = action_context() + _set_readout_priority(ctx, baseline=["samz"], monitored=["samx", "samy"]) + + baseline_status = ctx.actions.read_baseline_devices(wait=False) + monitored_status_1 = ctx.actions.read_monitored_devices(wait=False) + monitored_status_2 = ctx.actions.read_monitored_devices(wait=False) + + read_messages = _sent_device_instructions(ctx, "read") + baseline_msg = read_messages[-3] + monitored_msg_1 = read_messages[-2] + monitored_msg_2 = read_messages[-1] + assert baseline_msg.device == ["samz"] + assert baseline_msg.metadata["readout_priority"] == "baseline" + assert baseline_msg.metadata["device_instr_id"] == baseline_status._device_instr_id + assert monitored_msg_1.device == ["samx", "samy"] + assert monitored_msg_1.metadata["point_id"] == 0 + assert monitored_msg_1.metadata["device_instr_id"] == monitored_status_1._device_instr_id + assert monitored_msg_2.metadata["point_id"] == 1 + assert monitored_msg_2.metadata["device_instr_id"] == monitored_status_2._device_instr_id + + +def test_empty_read_and_trigger_actions_return_done_status(action_context): + ctx = action_context() + _set_readout_priority(ctx) + _set_software_triggered(ctx) + + baseline_status = ctx.actions.read_baseline_devices(wait=False) + monitored_status = ctx.actions.read_monitored_devices(wait=False) + trigger_status = ctx.actions.trigger_all_devices(wait=False) + + assert baseline_status.done + assert monitored_status.done + assert trigger_status.done + assert not _sent_device_instructions(ctx, "read") + assert not _sent_device_instructions(ctx, "trigger") + + +def test_trigger_all_devices_emits_software_triggered_devices(action_context): + ctx = action_context() + _set_software_triggered(ctx, "samy", "samx") + + status = ctx.actions.trigger_all_devices(wait=False) + + trigger_msg = _last_device_instruction(ctx, "trigger") + assert trigger_msg.device == ["samx", "samy"] + assert trigger_msg.metadata["device_instr_id"] == status._device_instr_id + + +def test_complete_and_unstage_all_devices_emit_enabled_devices(action_context): + ctx = action_context() + + complete_status = ctx.actions.complete_all_devices(wait=False) + unstage_status = ctx.actions.unstage_all_devices(wait=False) + + complete_msg = _last_device_instruction(ctx, "complete") + unstage_msg = _last_device_instruction(ctx, "unstage") + assert complete_msg.device == _enabled_device_names(ctx) + assert complete_msg.metadata["device_instr_id"] == complete_status._device_instr_id + assert unstage_msg.device == _enabled_device_names(ctx) + assert unstage_msg.metadata["device_instr_id"] == unstage_status._device_instr_id + + +def test_complete_and_unstage_all_devices_respect_exclude(action_context): + ctx = action_context() + + complete_status = ctx.actions.complete_all_devices(wait=False, exclude=["samx"]) + unstage_status = ctx.actions.unstage_all_devices(wait=False, exclude="samy") + + complete_msg = _last_device_instruction(ctx, "complete") + unstage_msg = _last_device_instruction(ctx, "unstage") + assert complete_msg.device == [ + device_name for device_name in _enabled_device_names(ctx) if device_name != "samx" + ] + assert complete_msg.metadata["device_instr_id"] == complete_status._device_instr_id + assert unstage_msg.device == [ + device_name for device_name in _enabled_device_names(ctx) if device_name != "samy" + ] + assert unstage_msg.metadata["device_instr_id"] == unstage_status._device_instr_id + + +def test_stage_all_devices_stages_async_and_sync_devices(action_context): + ctx = action_context() + async_dev = ctx.device_manager.devices["samx"] + on_request_dev = ctx.device_manager.devices["bpm4i"] + continuous_dev = ctx.device_manager.devices["samz"] + enabled_devices = [ + ctx.device_manager.devices["samx"], + ctx.device_manager.devices["samy"], + ctx.device_manager.devices["samz"], + ctx.device_manager.devices["bpm4i"], + ] + container_status = ScanStubStatus( + ctx.scan._instruction_handler, + shutdown_event=threading.Event(), + registry={}, + is_container=True, + name="stage_all_devices", + ) + container_status.add_status = mock.MagicMock(wraps=container_status.add_status) + container_status.wait = mock.MagicMock() + async_status = ScanStubStatus( + ctx.scan._instruction_handler, + shutdown_event=threading.Event(), + registry={}, + name="stage_samx", + ) + sync_status = ScanStubStatus( + ctx.scan._instruction_handler, + shutdown_event=threading.Event(), + registry={}, + name="stage_sync_devices", + ) + + ctx.actions._create_status = mock.MagicMock(return_value=container_status) + ctx.actions.stage = mock.MagicMock(side_effect=[async_status, sync_status]) + + with ( + mock.patch.object( + type(ctx.device_manager.devices), "async_devices", return_value=[async_dev] + ), + mock.patch.object( + type(ctx.device_manager.devices), "on_request_devices", return_value=[on_request_dev] + ), + mock.patch.object( + type(ctx.device_manager.devices), "continuous_devices", return_value=[continuous_dev] + ), + mock.patch.object( + type(ctx.device_manager.devices), + "enabled_devices", + new_callable=mock.PropertyMock, + return_value=enabled_devices, + ), + ): + status = ctx.actions.stage_all_devices(wait=True) + + assert status is container_status + assert ctx.actions.stage.mock_calls == [ + mock.call(async_dev, status_name="stage_samx", wait=False), + mock.call(["samy"], status_name="stage_sync_devices", wait=False), + ] + assert container_status.add_status.mock_calls == [ + mock.call(async_status), + mock.call(sync_status), + ] + container_status.wait.assert_called_once_with() + + +def test_stage_all_devices_respects_exclude(action_context): + ctx = action_context() + async_dev = ctx.device_manager.devices["samx"] + on_request_dev = ctx.device_manager.devices["bpm4i"] + continuous_dev = ctx.device_manager.devices["samz"] + enabled_devices = [ + ctx.device_manager.devices["samx"], + ctx.device_manager.devices["samy"], + ctx.device_manager.devices["samz"], + ctx.device_manager.devices["bpm4i"], + ] + container_status = ScanStubStatus( + ctx.scan._instruction_handler, + shutdown_event=threading.Event(), + registry={}, + is_container=True, + name="stage_all_devices", + ) + sync_status = ScanStubStatus( + ctx.scan._instruction_handler, + shutdown_event=threading.Event(), + registry={}, + name="stage_sync_devices", + ) + + ctx.actions._create_status = mock.MagicMock(return_value=container_status) + ctx.actions.stage = mock.MagicMock(return_value=sync_status) + + with ( + mock.patch.object( + type(ctx.device_manager.devices), "async_devices", return_value=[async_dev] + ), + mock.patch.object( + type(ctx.device_manager.devices), "on_request_devices", return_value=[on_request_dev] + ), + mock.patch.object( + type(ctx.device_manager.devices), "continuous_devices", return_value=[continuous_dev] + ), + mock.patch.object( + type(ctx.device_manager.devices), + "enabled_devices", + new_callable=mock.PropertyMock, + return_value=enabled_devices, + ), + ): + status = ctx.actions.stage_all_devices(wait=False, exclude="samx") + + assert status is container_status + assert ctx.actions.stage.mock_calls == [ + mock.call(["samy"], status_name="stage_sync_devices", wait=False) + ] + + +def test_report_instructions_update_scan_info_and_queue(action_context): + ctx = action_context() + ctx.actions._update_queue_info_callback = mock.MagicMock() + + ctx.actions.add_scan_report_instruction_readback(["samx"], [0], [1], "rid") + ctx.actions.add_scan_report_instruction_device_progress("samy") + ctx.actions.add_scan_report_instruction_scan_progress(points=5, show_table=False) + + assert ctx.scan.scan_info.scan_report_instructions == [ + {"readback": {"RID": "rid", "devices": ["samx"], "start": [0], "end": [1]}}, + {"device_progress": ["samy"]}, + {"scan_progress": {"points": 5, "show_table": False}}, + ] + assert "samx" in ctx.actions._devices_with_required_response + assert ctx.actions._update_queue_info_callback.call_count == 3 + + +def test_rpc_call_returns_result_or_status(action_context): + ctx = action_context() + status = ScanStubStatus( + ctx.scan._instruction_handler, + device_instr_id="device-instr-id", + shutdown_event=threading.Event(), + registry={}, + name="rpc_samx_kickoff", + ) + status.set_done({"ok": True}) + status.wait = mock.MagicMock() + status._result_is_status = False + ctx.actions._create_status = mock.MagicMock(return_value=status) + ctx.actions._send = mock.MagicMock() + + result = ctx.actions.rpc_call("samx", "kickoff", 1, test=True) + + assert result == {"ok": True} + sent_msg = ctx.actions._send.call_args.args[0] + assert sent_msg.device == "samx" + assert sent_msg.action == "rpc" + assert sent_msg.parameter["device"] == "samx" + assert sent_msg.parameter["func"] == "kickoff" + assert sent_msg.parameter["args"] == (1,) + assert sent_msg.parameter["kwargs"] == {"test": True} + assert sent_msg.metadata["device_instr_id"] == "device-instr-id" + status.wait.assert_called_once_with(resolve_on_known_type=True) + + status._result_is_status = True + status.wait.reset_mock() + ctx.actions._send.reset_mock() + + result = ctx.actions.rpc_call("samx", "kickoff") + + assert result is status + status.wait.assert_called_once_with(resolve_on_known_type=True) + + +def test_send_scan_status_publishes_message(action_context): + ctx = action_context() + pipe = mock.MagicMock() + ctx.connector.pipeline = mock.MagicMock(return_value=pipe) + ctx.connector.set = mock.MagicMock() + ctx.connector.set_and_publish = mock.MagicMock() + status_msg = messages.ScanStatusMessage(scan_id="scan-id-test", status="closed", info={}) + ctx.actions._build_scan_status_message = mock.MagicMock(return_value=status_msg) + + ctx.actions._send_scan_status("closed", reason="alarm") + + ctx.actions._build_scan_status_message.assert_called_once_with(status="closed", reason="alarm") + ctx.connector.set.assert_called_once_with( + MessageEndpoints.public_scan_info("scan-id-test"), status_msg, pipe=pipe, expire=1800 + ) + ctx.connector.set_and_publish.assert_called_once_with( + MessageEndpoints.scan_status(), status_msg, pipe=pipe + ) + pipe.execute.assert_called_once_with() + + +def test_get_file_base_path_uses_account_and_templates(action_context): + ctx = action_context() + ctx.device_manager.parent = _TestParent("/tmp/data") + ctx.connector.get_last = mock.MagicMock( + return_value=messages.VariableMessage(value="test_account") + ) + + assert ctx.actions._get_file_base_path() == os.path.abspath("/tmp/data/test_account") + + ctx.device_manager.parent._service_config.config["file_writer"][ + "base_path" + ] = "/tmp/$account/raw" + assert ctx.actions._get_file_base_path() == os.path.abspath("/tmp/test_account/raw") + + +def test_get_file_base_path_rejects_invalid_account_and_template(action_context): + ctx = action_context() + ctx.device_manager.parent = _TestParent("/tmp/$missing/raw") + ctx.connector.get_last = mock.MagicMock( + return_value=messages.VariableMessage(value="bad/account") + ) + + with pytest.raises(ValueError, match="cannot contain a slash"): + ctx.actions._get_file_base_path() + + ctx.connector.get_last = mock.MagicMock(return_value=None) + with pytest.raises(ValueError, match="Invalid template variable"): + ctx.actions._get_file_base_path() + + +def test_required_response_flag_is_added_for_registered_device(action_context): + ctx = action_context() + ctx.actions.add_device_with_required_response("samx") + + ctx.actions.stage(["samx", "samy"], wait=False) + + stage_msg = _last_device_instruction(ctx, "stage") + assert stage_msg.metadata["response"] is True + + +def test_set_device_readout_priority_warns_after_reads(action_context): + ctx = action_context() + _set_readout_priority(ctx, monitored=["samx"]) + ctx.connector.raise_alarm = mock.MagicMock() + + ctx.actions.read_monitored_devices(wait=False) + ctx.actions.set_device_readout_priority(["samy"], priority="monitored") + + assert ctx.scan.scan_info.readout_priority_modification["monitored"] == ["samy"] + ctx.connector.raise_alarm.assert_called_once() + + +def test_check_for_unchecked_statuses_raises_cleanup_warnings(action_context): + ctx = action_context() + ctx.connector.raise_alarm = mock.MagicMock() + unchecked_status = ctx.actions.stage("samx", wait=False) + remaining_status = ctx.actions.complete("samy", wait=False) + remaining_status.wait = mock.MagicMock() + unchecked_status.set_done() + + ctx.actions.check_for_unchecked_statuses() + + assert ctx.connector.raise_alarm.call_count == 2 + alarm_types = [ + call.kwargs["info"].exception_type for call in ctx.connector.raise_alarm.mock_calls + ] + assert alarm_types == ["UncheckedStatusObjectsWarning", "ScanCleanupWarning"] + remaining_status.wait.assert_called_once_with() + + +def test_read_manually_sends_read_with_return_result(action_context): + ctx = action_context() + + status = ctx.actions.read_manually(["samy", "samx"], wait=False) + + read_messages = _sent_device_instructions(ctx, "read") + msg = read_messages[-1] + assert msg.device == ["samx", "samy"] + assert msg.parameter == {"return_result": True} + assert msg.metadata["device_instr_id"] == status._device_instr_id + assert "point_id" not in msg.metadata + + +def test_publish_manual_read_validates_and_increments_point_id(action_context): + ctx = action_context() + _set_readout_priority(ctx, monitored=["samx", "samy"]) + readings = {"samy": _reading("samy", 2), "samx": _reading("samx", 1)} + ctx.connector.pipeline = mock.MagicMock(wraps=ctx.connector.pipeline) + + ctx.actions.publish_manual_read(readings, wait=False) + ctx.actions.publish_manual_read( + [{"samy": _reading("samy", 4)}, {"samx": _reading("samx", 3)}], wait=False + ) + + assert not _sent_device_instructions(ctx, "publish_data_as_read") + samx_read_messages = [ + entry["msg"] + for entry in ctx.connector.message_sent + if entry["queue"] == MessageEndpoints.device_read("samx").endpoint + ] + samy_read_messages = [ + entry["msg"] + for entry in ctx.connector.message_sent + if entry["queue"] == MessageEndpoints.device_read("samy").endpoint + ] + samx_readback_messages = [ + entry["msg"] + for entry in ctx.connector.message_sent + if entry["queue"] == MessageEndpoints.device_readback("samx").endpoint + ] + assert ctx.connector.pipeline.call_count == 2 + assert samx_read_messages[-2].signals == _reading("samx", 1) + assert samy_read_messages[-2].signals == _reading("samy", 2) + assert samx_read_messages[-2].metadata["point_id"] == 0 + assert samx_read_messages[-1].signals == _reading("samx", 3) + assert samy_read_messages[-1].signals == _reading("samy", 4) + assert samx_read_messages[-1].metadata["point_id"] == 1 + assert not samx_readback_messages + + +def test_publish_manual_read_uses_pipeline_with_fakeredis(action_context, connected_connector): + ctx = action_context(connector=connected_connector) + _set_readout_priority(ctx, monitored=["samx", "samy"]) + + ctx.actions.publish_manual_read( + {"samx": _reading("samx", 1), "samy": _reading("samy", 2)}, wait=False + ) + + samx_msg = connected_connector.get(MessageEndpoints.device_read("samx")) + samy_msg = connected_connector.get(MessageEndpoints.device_read("samy")) + assert samx_msg.signals == _reading("samx", 1) + assert samx_msg.metadata["point_id"] == 0 + assert samy_msg.signals == _reading("samy", 2) + assert samy_msg.metadata["point_id"] == 0 + assert connected_connector.get(MessageEndpoints.device_readback("samx")) is None + + +def test_publish_manual_read_rejects_wrong_devices(action_context): + ctx = action_context() + _set_readout_priority(ctx, monitored=["samx", "samy"]) + + with pytest.raises(ValueError, match=r"Missing devices: \['samy'\]"): + ctx.actions.publish_manual_read({"samx": _reading("samx", 1)}, wait=False) + + +def test_publish_manual_read_rejects_missing_signals(action_context): + ctx = action_context() + _set_readout_priority(ctx, monitored=["samx", "samy"]) + + readings = {"samx": {"other_signal": {"value": 1}}, "samy": _reading("samy", 2)} + with pytest.raises(ValueError, match=r"Missing signals: .*'samx': .*'samx'"): + ctx.actions.publish_manual_read(readings, wait=False) diff --git a/bec_server/tests/tests_scan_server/scans_v4/test_scan_components.py b/bec_server/tests/tests_scan_server/scans_v4/test_scan_components.py new file mode 100644 index 000000000..607b30cce --- /dev/null +++ b/bec_server/tests/tests_scan_server/scans_v4/test_scan_components.py @@ -0,0 +1,185 @@ +from unittest import mock + +import numpy as np +import pytest + +from bec_server.scan_server.errors import LimitError + + +def test_move_and_wait_only_moves_changed_motors(v4_scan_assembler): + scan = v4_scan_assembler("mv", "samx", 1.5, "samy", -2.0) + scan.actions.set = mock.MagicMock() + + scan.components.move_and_wait( + scan.motors, np.array([1.0, 2.5]), last_positions=np.array([1.0, 2.0]) + ) + + scan.actions.set.assert_called_once_with([scan.dev["samy"]], [2.5], wait=True) + + +def test_move_and_wait_skips_set_when_positions_do_not_change(v4_scan_assembler): + scan = v4_scan_assembler("mv", "samx", 1.5, "samy", -2.0) + scan.actions.set = mock.MagicMock() + + scan.components.move_and_wait( + scan.motors, np.array([1.0, 2.0]), last_positions=np.array([1.0, 2.0]) + ) + + scan.actions.set.assert_not_called() + + +def test_trigger_and_read_waits_triggers_and_reads(v4_scan_assembler): + scan = v4_scan_assembler("mv", "samx", 1.5, "samy", -2.0) + scan.scan_info.exp_time = 0.2 + scan.scan_info.frames_per_trigger = 3 + scan.scan_info.settling_time = 0.1 + scan.scan_info.settling_time_after_trigger = 0.4 + scan.actions.trigger_all_devices = mock.MagicMock() + scan.actions.read_monitored_devices = mock.MagicMock() + + with mock.patch("bec_server.scan_server.scans.scan_components.time.sleep") as sleep_mock: + scan.components.trigger_and_read() + + sleep_mock.assert_has_calls([mock.call(0.1), mock.call(0.4)]) + scan.actions.trigger_all_devices.assert_called_once() + assert scan.actions.trigger_all_devices.call_args.kwargs["min_wait"] == pytest.approx(0.6) + scan.actions.read_monitored_devices.assert_called_once_with() + + +def test_step_scan_reuses_previous_position_across_points_and_bursts(v4_scan_assembler): + scan = v4_scan_assembler("mv", "samx", 1.5, "samy", -2.0) + scan.scan_info.burst_at_each_point = 2 + at_each_point = mock.MagicMock() + positions = np.array([[1.0, 2.0], [3.0, 4.0]]) + + scan.components.step_scan(scan.motors, positions, at_each_point=at_each_point) + + assert at_each_point.call_count == 4 + first_args, first_kwargs = at_each_point.call_args_list[0] + assert first_args[0] == scan.motors + np.testing.assert_allclose(first_args[1], positions[0]) + assert first_kwargs["last_positions"] is None + + np.testing.assert_allclose(at_each_point.call_args_list[1].args[1], positions[0]) + np.testing.assert_allclose( + at_each_point.call_args_list[1].kwargs["last_positions"], positions[0] + ) + np.testing.assert_allclose(at_each_point.call_args_list[2].args[1], positions[1]) + np.testing.assert_allclose( + at_each_point.call_args_list[2].kwargs["last_positions"], positions[0] + ) + np.testing.assert_allclose(at_each_point.call_args_list[3].args[1], positions[1]) + np.testing.assert_allclose( + at_each_point.call_args_list[3].kwargs["last_positions"], positions[1] + ) + + +def test_step_scan_at_each_point_moves_then_triggers(v4_scan_assembler): + scan = v4_scan_assembler("mv", "samx", 1.5, "samy", -2.0) + scan.components.move_and_wait = mock.MagicMock() + scan.components.trigger_and_read = mock.MagicMock() + pos = np.array([1.0, 2.0]) + last_positions = np.array([0.5, 1.5]) + + scan.components.step_scan_at_each_point(scan.motors, pos, last_positions=last_positions) + + scan.components.move_and_wait.assert_called_once_with( + scan.motors, pos, last_positions=last_positions + ) + scan.components.trigger_and_read.assert_called_once_with() + + +def test_get_start_positions_supports_motor_names_and_instances(v4_scan_assembler): + scan = v4_scan_assembler("mv", "samx", 1.5, "samy", -2.0) + scan.dev["samx"]._value = 1.25 + scan.dev["samy"]._value = -3.5 + + start_positions = scan.components.get_start_positions(["samx", scan.dev["samy"]]) + + assert start_positions == [1.25, -3.5] + + +def test_optimize_trajectory_uses_corridor_defaults_without_preferred_direction(v4_scan_assembler): + scan = v4_scan_assembler("mv", "samx", 1.5, "samy", -2.0) + positions = np.array([[0.0, 1.0], [1.0, 0.0]]) + optimized = np.array([[1.0, 0.0], [0.0, 1.0]]) + scan.components._path_optimizer.optimize_corridor = mock.MagicMock(return_value=optimized) + + result = scan.components.optimize_trajectory( + positions, optimization_type="corridor", corridor_size=4, num_iterations=7 + ) + + scan.components._path_optimizer.optimize_corridor.assert_called_once_with( + positions, num_iterations=7, corridor_size=4, sort_axis=1 + ) + np.testing.assert_allclose(result, optimized) + + +def test_optimize_trajectory_passes_primary_axis_preference_for_corridor(v4_scan_assembler): + scan = v4_scan_assembler("mv", "samx", 1.5, "samy", -2.0) + positions = np.array([[0.0, 1.0], [1.0, 0.0]]) + scan.components._path_optimizer.optimize_corridor = mock.MagicMock(return_value=positions) + + scan.components.optimize_trajectory( + positions, + optimization_type="corridor", + primary_axis=1, + preferred_directions=[-1, 1], + corridor_size=2, + num_iterations=3, + ) + + scan.components._path_optimizer.optimize_corridor.assert_called_once_with( + positions, num_iterations=3, sort_axis=1, preferred_direction=1, corridor_size=2 + ) + + +@pytest.mark.parametrize( + ("optimization_type", "optimizer_name"), + [("shell", "optimize_shell"), ("nearest", "optimize_nearest_neighbor")], +) +def test_optimize_trajectory_dispatches_to_other_optimizers( + v4_scan_assembler, optimization_type, optimizer_name +): + scan = v4_scan_assembler("mv", "samx", 1.5, "samy", -2.0) + positions = np.array([[0.0, 1.0], [1.0, 0.0]]) + optimized = np.array([[1.0, 0.0], [0.0, 1.0]]) + optimizer = mock.MagicMock(return_value=optimized) + setattr(scan.components._path_optimizer, optimizer_name, optimizer) + + result = scan.components.optimize_trajectory( + positions, optimization_type=optimization_type, num_iterations=4 + ) + + if optimization_type == "shell": + optimizer.assert_called_once_with(positions, num_iterations=4) + else: + optimizer.assert_called_once_with(positions) + np.testing.assert_allclose(result, optimized) + + +def test_optimize_trajectory_rejects_unknown_optimization_type(v4_scan_assembler): + scan = v4_scan_assembler("mv", "samx", 1.5, "samy", -2.0) + + with pytest.raises(ValueError, match="Invalid optimization type"): + scan.components.optimize_trajectory(np.array([[0.0, 1.0]]), optimization_type="bad") + + +def test_check_limits_accepts_positions_inside_motor_limits(v4_scan_assembler): + scan = v4_scan_assembler("mv", "samx", 1.5, "samy", -2.0) + + scan.components.check_limits(scan.motors, np.array([[0.0, -1.0], [1.0, 2.0]])) + + +def test_check_limits_ignores_motors_without_configured_limits(v4_scan_assembler): + scan = v4_scan_assembler("mv", "samx", 1.5, "samy", -2.0) + scan.dev["samx"]._limits = (5.0, 5.0) + + scan.components.check_limits(scan.motors, np.array([[100.0, -1.0], [200.0, 2.0]])) + + +def test_check_limits_raises_limit_error_for_out_of_bounds_position(v4_scan_assembler): + scan = v4_scan_assembler("mv", "samx", 1.5, "samy", -2.0) + + with pytest.raises(LimitError, match="Target position 12.0"): + scan.components.check_limits(scan.motors, np.array([[12.0, 0.0]])) diff --git a/bec_server/tests/tests_scan_server/scans_v4/test_time_scan.py b/bec_server/tests/tests_scan_server/scans_v4/test_time_scan.py new file mode 100644 index 000000000..3b0c1e094 --- /dev/null +++ b/bec_server/tests/tests_scan_server/scans_v4/test_time_scan.py @@ -0,0 +1,70 @@ +from unittest import mock + +import pytest + +from bec_server.scan_server.tests.scan_hook_tests import ( + assert_close_scan_waits_for_baseline_and_closes, + assert_pre_scan_called, + assert_prepare_scan_reads_baseline_devices, + assert_scan_open_called, + assert_stage_all_devices_called, + assert_unstage_all_devices_called, + run_scan_tests, +) + +TIME_SCAN_DEFAULT_HOOK_TESTS = [ + ("prepare_scan", [assert_prepare_scan_reads_baseline_devices]), + ("open_scan", [assert_scan_open_called]), + ("stage", [assert_stage_all_devices_called]), + ("pre_scan", [assert_pre_scan_called]), + ("unstage", [assert_unstage_all_devices_called]), + ("close_scan", [assert_close_scan_waits_for_baseline_and_closes]), +] + + +@pytest.mark.parametrize(("hook_name", "hook_tests"), TIME_SCAN_DEFAULT_HOOK_TESTS) +def test_time_scan_default_hooks(v4_scan_assembler, nth_done_status_mock, hook_name, hook_tests): + scan = v4_scan_assembler("time_scan", 3, 1.5, exp_time=0.2) + + run_scan_tests(scan, [(hook_name, hook_tests)], nth_done_status_mock=nth_done_status_mock) + + +def test_time_scan_prepare_scan_updates_scan_info_and_queue(v4_scan_assembler): + scan = v4_scan_assembler("time_scan", 3, 1.5, exp_time=0.2) + + scan.prepare_scan() + + assert scan.scan_info.num_points == 3 + assert scan.scan_info.positions.size == 0 + assert scan.scan_info.scan_report_instructions == [ + {"scan_progress": {"points": 3, "show_table": False}} + ] + + +def test_time_scan_scan_core_triggers_reads_and_waits_between_points(v4_scan_assembler): + scan = v4_scan_assembler("time_scan", 3, 1.5, exp_time=0.2) + scan.at_each_point = mock.MagicMock() + + with mock.patch("bec_server.scan_server.scans.time_scan.time.sleep") as sleep_mock: + scan.scan_core() + + assert scan.at_each_point.call_count == 3 + assert sleep_mock.call_args_list == [mock.call(1.3), mock.call(1.3)] + + +def test_time_scan_at_each_point_triggers_and_reads(v4_scan_assembler): + scan = v4_scan_assembler("time_scan", 3, 1.5, exp_time=0.2) + scan.components.trigger_and_read = mock.MagicMock() + + scan.at_each_point() + + scan.components.trigger_and_read.assert_called_once_with() + + +def test_time_scan_post_scan_completes_all_devices(v4_scan_assembler): + scan = v4_scan_assembler("time_scan", 3, 1.5, exp_time=0.2) + scan.actions.complete_all_devices = mock.MagicMock() + + scan.post_scan() + + scan.actions.complete_all_devices.assert_called_once_with() diff --git a/bec_server/tests/tests_scan_server/scans_v4/test_updated_move_scan.py b/bec_server/tests/tests_scan_server/scans_v4/test_updated_move_scan.py new file mode 100644 index 000000000..9a9a5585e --- /dev/null +++ b/bec_server/tests/tests_scan_server/scans_v4/test_updated_move_scan.py @@ -0,0 +1,53 @@ +from unittest import mock + +import pytest + + +@pytest.mark.parametrize( + ("hook_name",), + [ + ("prepare_scan",), + ("open_scan",), + ("stage",), + ("pre_scan",), + ("post_scan",), + ("unstage",), + ("close_scan",), + ], +) +def test_updated_move_scan_default_noop_hooks_do_not_raise(v4_scan_assembler, hook_name): + scan = v4_scan_assembler("umv", "samx", 1.5, "samy", -2.0) + + getattr(scan, hook_name)() + + +def test_updated_move_scan_scan_core_adds_readback_and_moves_to_absolute_targets(v4_scan_assembler): + scan = v4_scan_assembler("umv", "samx", 1.5, "samy", -2.0) + scan.scan_info.metadata["RID"] = "rid-123" + scan.components.get_start_positions = mock.MagicMock(return_value=[0.5, 3.0]) + scan.actions.add_scan_report_instruction_readback = mock.MagicMock() + scan.components.move_and_wait = mock.MagicMock() + + scan.scan_core() + + scan.components.get_start_positions.assert_called_once_with(scan.motors) + scan.actions.add_scan_report_instruction_readback.assert_called_once_with( + devices=scan.motors, start=[0.5, 3.0], stop=[1.5, -2.0], request_id="rid-123" + ) + scan.components.move_and_wait.assert_called_once_with(scan.motors, [1.5, -2.0]) + + +def test_updated_move_scan_scan_core_adds_readback_and_moves_to_relative_targets(v4_scan_assembler): + scan = v4_scan_assembler("umv", "samx", 1.5, "samy", -2.0, relative=True) + scan.scan_info.metadata["RID"] = "rid-123" + scan.components.get_start_positions = mock.MagicMock(return_value=[0.5, 3.0]) + scan.actions.add_scan_report_instruction_readback = mock.MagicMock() + scan.components.move_and_wait = mock.MagicMock() + + scan.scan_core() + + scan.components.get_start_positions.assert_called_once_with(scan.motors) + scan.actions.add_scan_report_instruction_readback.assert_called_once_with( + devices=scan.motors, start=[0.5, 3.0], stop=[2.0, 1.0], request_id="rid-123" + ) + scan.components.move_and_wait.assert_called_once_with(scan.motors, [2.0, 1.0]) diff --git a/bec_server/tests/tests_scan_server/test_direct_scan_worker.py b/bec_server/tests/tests_scan_server/test_direct_scan_worker.py new file mode 100644 index 000000000..fe40d25f7 --- /dev/null +++ b/bec_server/tests/tests_scan_server/test_direct_scan_worker.py @@ -0,0 +1,540 @@ +from types import SimpleNamespace +from unittest import mock + +import pytest + +from bec_lib import messages +from bec_lib.alarm_handler import Alarms +from bec_server.scan_server.direct_scan_worker import DirectScanWorker +from bec_server.scan_server.errors import DeviceInstructionError, ScanAbortion, UserScanInterruption +from bec_server.scan_server.scan_queue import ( + DirectInstructionQueueItem, + InstructionQueueStatus, + ScanQueue, +) +from bec_server.scan_server.scans.scans_v4 import ScanBase +from bec_server.scan_server.tests.utils import ScanServerMock + + +class _TestDirectScan(ScanBase): + scan_name = "_v4_test_direct_scan" + scan_type = None + + def __init__(self, *args, called_steps=None, fail_step=None, **kwargs): + self.called_steps = called_steps if called_steps is not None else [] + self.fail_step = fail_step + super().__init__(*args, **kwargs) + self.scan_info.scan_number = 7 + + def _record_step(self, step_name: str): + self.called_steps.append(step_name) + if self.fail_step == step_name: + raise RuntimeError(f"{step_name} failed") + + def prepare_scan(self): + self._record_step("prepare_scan") + + def open_scan(self): + self._record_step("open_scan") + + def stage(self): + self._record_step("stage") + + def pre_scan(self): + self._record_step("pre_scan") + + def scan_core(self): + self._record_step("scan_core") + + def post_scan(self): + self._record_step("post_scan") + + def unstage(self): + self._record_step("unstage") + + def close_scan(self): + self._record_step("close_scan") + + +@pytest.fixture +def direct_worker_context(dm_with_devices): + scan_server = ScanServerMock(dm_with_devices) + queue_manager = scan_server.queue_manager + queue_manager.shutdown() + queue_manager.send_queue_status = mock.MagicMock() + scan_queue = ScanQueue(queue_manager, queue_name="primary") + queue_manager.queues["primary"] = scan_queue + scan_server.connector.raise_alarm = mock.MagicMock() + scan_server.connector.send_client_info = mock.MagicMock() + scan_queue.abort = mock.MagicMock() + + queue = DirectInstructionQueueItem(scan_queue, mock.MagicMock(), scan_queue.scan_worker) + queue.append_to_queue_history = mock.MagicMock() + scan_queue.queue.append(queue) + scan_queue.active_instruction_queue = queue + + yield SimpleNamespace( + connector=scan_server.connector, + device_manager=scan_server.device_manager, + direct_worker=DirectScanWorker(worker=scan_queue.scan_worker), + instruction_handler=queue_manager.instruction_handler, + queue=queue, + queue_manager=queue_manager, + queue_state=scan_queue, + scan_worker=scan_queue.scan_worker, + scan_server=scan_server, + ) + + scan_server.shutdown() + + +@pytest.fixture +def make_scan(direct_worker_context): + def _build(*, called_steps=None, fail_step=None): + scan = _TestDirectScan( + scan_id="scan-id", + redis_connector=direct_worker_context.connector, + device_manager=direct_worker_context.device_manager, + instruction_handler=direct_worker_context.instruction_handler, + request_inputs={}, + system_config={}, + called_steps=called_steps, + fail_step=fail_step, + ) + scan.actions._send_scan_status = mock.MagicMock() + scan.actions.send_client_info = mock.MagicMock() + scan._shutdown_event = mock.MagicMock() + return scan + + return _build + + +def _append_scan(queue: DirectInstructionQueueItem, scan: _TestDirectScan): + queue.scans.append(scan) + queue.scan_msgs.append( + messages.ScanQueueMessage( + scan_type=scan.scan_info.scan_name, + parameter={"args": {}, "kwargs": {}}, + queue="primary", + metadata={"RID": "rid-1"}, + ) + ) + + +def test_check_for_interruption_sends_paused_status_via_scan_actions( + direct_worker_context, make_scan +): + scan = make_scan() + direct_worker_context.direct_worker.scan = scan + direct_worker_context.scan_worker.status = InstructionQueueStatus.PAUSED + + def _resume(_seconds): + direct_worker_context.scan_worker.status = InstructionQueueStatus.RUNNING + + with mock.patch("bec_server.scan_server.direct_scan_worker.time.sleep", side_effect=_resume): + direct_worker_context.direct_worker.check_for_interruption() + + scan.actions._send_scan_status.assert_called_once_with("paused") + + +def test_check_for_interruption_raises_user_interruption_on_stop(direct_worker_context): + direct_worker_context.scan_worker.status = InstructionQueueStatus.STOPPED + direct_worker_context.scan_worker.current_instruction_queue_item = direct_worker_context.queue + direct_worker_context.queue.exit_info = ("user_completed", "user") + + with pytest.raises(UserScanInterruption) as exc: + direct_worker_context.direct_worker.check_for_interruption() + + assert exc.value.exit_info == ("user_completed", "user") + + +def test_check_for_interruption_raises_scan_abortion_without_exit_info(direct_worker_context): + direct_worker_context.scan_worker.status = InstructionQueueStatus.STOPPED + direct_worker_context.scan_worker.current_instruction_queue_item = direct_worker_context.queue + direct_worker_context.queue.exit_info = None + + with pytest.raises(ScanAbortion): + direct_worker_context.direct_worker.check_for_interruption() + + +def test_check_for_interruption_does_not_send_paused_without_scan(direct_worker_context): + direct_worker_context.scan_worker.status = InstructionQueueStatus.PAUSED + + def _resume(_seconds): + direct_worker_context.scan_worker.status = InstructionQueueStatus.RUNNING + + with mock.patch("bec_server.scan_server.direct_scan_worker.time.sleep", side_effect=_resume): + direct_worker_context.direct_worker.check_for_interruption() + + +def test_process_instructions_runs_scan_and_resets_state(direct_worker_context, make_scan): + scan = make_scan() + _append_scan(direct_worker_context.queue, scan) + + with mock.patch.object(direct_worker_context.direct_worker, "run") as run_mock: + with mock.patch.object(direct_worker_context.direct_worker, "reset") as reset_mock: + direct_worker_context.direct_worker.process_instructions(direct_worker_context.queue) + + run_mock.assert_called_once_with(scan) + assert direct_worker_context.queue.status == InstructionQueueStatus.COMPLETED + assert direct_worker_context.scan_worker.current_instruction_queue_item is None + reset_mock.assert_called_once_with() + + +def test_process_instructions_returns_when_queue_has_no_scan(direct_worker_context): + direct_worker_context.queue.move_to_next_scan = mock.MagicMock(return_value=None) + + with mock.patch("bec_server.scan_server.direct_scan_worker.logger.error") as log_error: + direct_worker_context.direct_worker.process_instructions(direct_worker_context.queue) + + log_error.assert_called_once_with("No scan found in the queue item to process.") + assert ( + direct_worker_context.scan_worker.current_instruction_queue_item + is direct_worker_context.queue + ) + + +def test_run_executes_full_scan_sequence_in_order(direct_worker_context, make_scan): + called_steps = [] + scan = make_scan(called_steps=called_steps) + direct_worker_context.queue.active_scan = scan + direct_worker_context.scan_worker.current_instruction_queue_item = direct_worker_context.queue + rpc_cm = mock.MagicMock() + direct_worker_context.device_manager._rpc_method = mock.MagicMock(return_value=rpc_cm) + + direct_worker_context.direct_worker.run(scan) + + assert ( + scan.actions._interruption_callback + == direct_worker_context.direct_worker.check_for_interruption + ) + assert ( + scan.actions._update_queue_info_callback + == direct_worker_context.direct_worker.update_queue_info + ) + direct_worker_context.device_manager._rpc_method.assert_called_once_with(scan.actions.rpc_call) + assert called_steps == [ + "prepare_scan", + "open_scan", + "stage", + "pre_scan", + "scan_core", + "post_scan", + "unstage", + "close_scan", + ] + assert direct_worker_context.queue.status == InstructionQueueStatus.COMPLETED + assert direct_worker_context.scan_worker.current_instruction_queue_item is None + assert direct_worker_context.direct_worker.scan is None + + +def test_run_returns_early_when_signal_event_is_set(direct_worker_context, make_scan): + scan = make_scan(fail_step="scan_core") + direct_worker_context.queue.active_scan = scan + direct_worker_context.scan_worker.current_instruction_queue_item = direct_worker_context.queue + direct_worker_context.scan_worker.signal_event.set() + direct_worker_context.device_manager._rpc_method = mock.MagicMock(return_value=mock.MagicMock()) + direct_worker_context.direct_worker._handle_exception = mock.MagicMock() + + direct_worker_context.direct_worker.run(scan) + + direct_worker_context.direct_worker._handle_exception.assert_not_called() + assert direct_worker_context.queue.status == InstructionQueueStatus.PENDING + direct_worker_context.scan_worker.signal_event.clear() + + +def test_run_returns_early_when_current_queue_is_none(direct_worker_context, make_scan): + scan = make_scan(fail_step="scan_core") + direct_worker_context.scan_worker.current_instruction_queue_item = None + direct_worker_context.device_manager._rpc_method = mock.MagicMock(return_value=mock.MagicMock()) + direct_worker_context.direct_worker._handle_exception = mock.MagicMock() + + direct_worker_context.direct_worker.run(scan) + + direct_worker_context.direct_worker._handle_exception.assert_not_called() + + +def test_run_reraises_when_queue_is_already_stopped(direct_worker_context, make_scan): + scan = make_scan(fail_step="scan_core") + direct_worker_context.queue.active_scan = scan + direct_worker_context.queue.stopped = True + direct_worker_context.scan_worker.current_instruction_queue_item = direct_worker_context.queue + direct_worker_context.device_manager._rpc_method = mock.MagicMock(return_value=mock.MagicMock()) + + with pytest.raises(RuntimeError, match="scan_core failed"): + direct_worker_context.direct_worker.run(scan) + + direct_worker_context.queue.stopped = False + + +def test_run_reraises_when_queue_has_no_active_request_block(direct_worker_context, make_scan): + scan = make_scan(fail_step="scan_core") + direct_worker_context.queue.active_scan = None + direct_worker_context.scan_worker.current_instruction_queue_item = direct_worker_context.queue + direct_worker_context.device_manager._rpc_method = mock.MagicMock(return_value=mock.MagicMock()) + + with pytest.raises(RuntimeError, match="scan_core failed"): + direct_worker_context.direct_worker.run(scan) + + +def test_run_uses_on_exception_cleanup_before_handling_error(direct_worker_context, make_scan): + scan = make_scan(fail_step="scan_core") + direct_worker_context.queue.active_scan = scan + direct_worker_context.scan_worker.current_instruction_queue_item = direct_worker_context.queue + direct_worker_context.device_manager._rpc_method = mock.MagicMock(return_value=mock.MagicMock()) + direct_worker_context.direct_worker._run_on_exception_hook = mock.MagicMock() + direct_worker_context.direct_worker._handle_exception = mock.MagicMock( + side_effect=ScanAbortion() + ) + + with pytest.raises(ScanAbortion): + direct_worker_context.direct_worker.run(scan) + + assert direct_worker_context.queue.stopped is True + assert direct_worker_context.scan_worker.status == InstructionQueueStatus.RUNNING + assert scan.actions._metadata_suffix == "__on-exception" + direct_worker_context.direct_worker._run_on_exception_hook.assert_called_once() + assert isinstance( + direct_worker_context.direct_worker._run_on_exception_hook.call_args.args[0], RuntimeError + ) + direct_worker_context.direct_worker._handle_exception.assert_called_once() + + +def test_run_handles_cleanup_exception_before_original_error(direct_worker_context, make_scan): + scan = make_scan(fail_step="scan_core") + cleanup_exc = UserScanInterruption(exit_info=("halted", "user")) + direct_worker_context.queue.active_scan = scan + direct_worker_context.scan_worker.current_instruction_queue_item = direct_worker_context.queue + direct_worker_context.device_manager._rpc_method = mock.MagicMock(return_value=mock.MagicMock()) + direct_worker_context.direct_worker._run_on_exception_hook = mock.MagicMock( + side_effect=cleanup_exc + ) + direct_worker_context.direct_worker._handle_exception = mock.MagicMock( + side_effect=ScanAbortion() + ) + + with pytest.raises(ScanAbortion): + direct_worker_context.direct_worker.run(scan) + + direct_worker_context.connector.send_client_info.assert_called_once_with("") + assert direct_worker_context.direct_worker._handle_exception.call_args.args[0] is cleanup_exc + direct_worker_context.queue.stopped = False + + +def test_handle_exception_raises_alarm_for_device_instruction_error( + direct_worker_context, make_scan +): + scan = make_scan() + direct_worker_context.direct_worker.scan = scan + error_info = messages.ErrorInfo( + error_message="device failed", + compact_error_message="DeviceInstructionError", + exception_type="DeviceInstructionError", + device="samx", + ) + exc = DeviceInstructionError(error_info) + + with pytest.raises(ScanAbortion): + direct_worker_context.direct_worker._handle_exception(exc) + + direct_worker_context.connector.raise_alarm.assert_called_once_with( + severity=Alarms.MAJOR, info=error_info, metadata={"scan_id": "scan-id", "scan_number": 7} + ) + + +def test_handle_exception_raises_alarm_for_generic_exception(direct_worker_context, make_scan): + scan = make_scan() + direct_worker_context.direct_worker.scan = scan + + try: + raise RuntimeError("boom") + except RuntimeError as exc: + with pytest.raises(ScanAbortion): + direct_worker_context.direct_worker._handle_exception(exc) + + direct_worker_context.connector.raise_alarm.assert_called_once() + assert ( + direct_worker_context.connector.raise_alarm.call_args.kwargs["info"].exception_type + == "RuntimeError" + ) + + +def test_propagate_error_raises_major_alarm_with_scan_metadata(direct_worker_context, make_scan): + scan = make_scan() + direct_worker_context.direct_worker.scan = scan + + direct_worker_context.direct_worker._propagate_error("traceback", RuntimeError("boom")) + + direct_worker_context.connector.raise_alarm.assert_called_once() + assert direct_worker_context.connector.raise_alarm.call_args.kwargs["severity"] == Alarms.MAJOR + assert direct_worker_context.connector.raise_alarm.call_args.kwargs["metadata"] == { + "scan_id": "scan-id", + "scan_number": 7, + } + assert ( + direct_worker_context.connector.raise_alarm.call_args.kwargs["info"].exception_type + == "RuntimeError" + ) + + +@pytest.mark.parametrize( + ("scan_id", "scan_number", "expected"), + [ + (None, None, {}), + ("scan-id", None, {"scan_id": "scan-id"}), + (None, 7, {"scan_number": 7}), + ("scan-id", 7, {"scan_id": "scan-id", "scan_number": 7}), + ], +) +def test_get_metadata_for_alarm(direct_worker_context, make_scan, scan_id, scan_number, expected): + direct_worker_context.direct_worker.scan = SimpleNamespace( + scan_info=SimpleNamespace(scan_id=scan_id, scan_number=scan_number) + ) + + assert direct_worker_context.direct_worker.get_metadata_for_alarm() == expected + + +def test_run_on_exception_hook_invokes_scan_hook_when_enabled(direct_worker_context, make_scan): + scan = make_scan() + scan.on_exception = mock.MagicMock() + direct_worker_context.direct_worker.scan = scan + direct_worker_context.scan_worker.current_instruction_queue_item = direct_worker_context.queue + direct_worker_context.queue.run_on_exception_hook = True + direct_worker_context.device_manager._rpc_method = mock.MagicMock(return_value=mock.MagicMock()) + exc = ScanAbortion() + + direct_worker_context.direct_worker._run_on_exception_hook(exc) + + scan._shutdown_event.clear.assert_called_once_with() + scan.on_exception.assert_called_once_with(exc) + + +def test_run_on_exception_hook_uses_root_cause(direct_worker_context, make_scan): + scan = make_scan() + scan.on_exception = mock.MagicMock() + direct_worker_context.direct_worker.scan = scan + direct_worker_context.scan_worker.current_instruction_queue_item = direct_worker_context.queue + direct_worker_context.queue.run_on_exception_hook = True + direct_worker_context.device_manager._rpc_method = mock.MagicMock(return_value=mock.MagicMock()) + root_cause = RuntimeError("root cause") + + try: + raise root_cause + except RuntimeError as cause: + exc = ScanAbortion() + exc.__cause__ = cause + direct_worker_context.direct_worker._run_on_exception_hook(exc) + + scan.on_exception.assert_called_once_with(root_cause) + + +def test_run_on_exception_hook_returns_when_scan_is_none(direct_worker_context): + direct_worker_context.direct_worker.scan = None + direct_worker_context.scan_worker.current_instruction_queue_item = direct_worker_context.queue + + direct_worker_context.direct_worker._run_on_exception_hook(ScanAbortion()) + + +def test_run_on_exception_hook_returns_when_on_exception_is_missing( + direct_worker_context, make_scan +): + scan = make_scan() + direct_worker_context.direct_worker.scan = scan + direct_worker_context.scan_worker.current_instruction_queue_item = direct_worker_context.queue + direct_worker_context.queue.run_on_exception_hook = True + + direct_worker_context.direct_worker._run_on_exception_hook(ScanAbortion()) + + +def test_run_on_exception_hook_sends_client_info_when_hook_fails(direct_worker_context, make_scan): + scan = make_scan() + + def _fail(_exc): + raise RuntimeError("cleanup failed") + + scan.on_exception = _fail + direct_worker_context.direct_worker.scan = scan + direct_worker_context.scan_worker.current_instruction_queue_item = direct_worker_context.queue + direct_worker_context.queue.run_on_exception_hook = True + direct_worker_context.device_manager._rpc_method = mock.MagicMock(return_value=mock.MagicMock()) + + direct_worker_context.direct_worker._run_on_exception_hook(ScanAbortion()) + + scan.actions.send_client_info.assert_called_once_with("") + + +def test_run_on_exception_hook_skips_when_disabled(direct_worker_context, make_scan): + scan = make_scan() + scan.on_exception = mock.MagicMock() + direct_worker_context.direct_worker.scan = scan + direct_worker_context.scan_worker.current_instruction_queue_item = direct_worker_context.queue + direct_worker_context.queue.run_on_exception_hook = False + + direct_worker_context.direct_worker._run_on_exception_hook(ScanAbortion()) + + scan.on_exception.assert_not_called() + + +def test_handle_scan_abortion_sends_abort_status_via_scan_actions(direct_worker_context, make_scan): + scan = make_scan() + direct_worker_context.queue.exit_info = None + direct_worker_context.queue.run_on_exception_hook = True + direct_worker_context.direct_worker.scan = scan + direct_worker_context.direct_worker.reset = mock.MagicMock() + + direct_worker_context.direct_worker._handle_scan_abortion( + direct_worker_context.queue, ScanAbortion() + ) + + scan.actions._send_scan_status.assert_called_once_with("aborted", reason="alarm") + assert direct_worker_context.queue.status == InstructionQueueStatus.STOPPED + direct_worker_context.queue.append_to_queue_history.assert_called_once_with() + direct_worker_context.queue_state.abort.assert_called_once_with() + direct_worker_context.direct_worker.reset.assert_called_once_with() + assert direct_worker_context.scan_worker.status == InstructionQueueStatus.RUNNING + + +def test_handle_scan_abortion_returns_when_scan_is_none(direct_worker_context): + direct_worker_context.direct_worker.scan = None + + direct_worker_context.direct_worker._handle_scan_abortion( + direct_worker_context.queue, ScanAbortion() + ) + + direct_worker_context.queue.append_to_queue_history.assert_not_called() + + +def test_handle_scan_abortion_sends_user_status_via_scan_actions(direct_worker_context, make_scan): + scan = make_scan() + direct_worker_context.queue.exit_info = None + direct_worker_context.direct_worker.scan = scan + + direct_worker_context.direct_worker._handle_scan_abortion( + direct_worker_context.queue, UserScanInterruption(exit_info=("user_completed", "user")) + ) + + scan.actions._send_scan_status.assert_called_once_with("user_completed", reason="user") + + +def test_handle_scan_abortion_halts_when_exception_hook_is_disabled( + direct_worker_context, make_scan +): + scan = make_scan() + direct_worker_context.queue.exit_info = None + direct_worker_context.queue.run_on_exception_hook = False + direct_worker_context.direct_worker.scan = scan + direct_worker_context.direct_worker.reset = mock.MagicMock() + + direct_worker_context.direct_worker._handle_scan_abortion( + direct_worker_context.queue, ScanAbortion() + ) + + scan.actions._send_scan_status.assert_called_once_with("halted", reason="alarm") + + +def test_update_queue_info_forwards_to_queue_manager(direct_worker_context): + direct_worker_context.scan_worker.current_instruction_queue_item = direct_worker_context.queue + + direct_worker_context.direct_worker.update_queue_info() + + direct_worker_context.queue_manager.send_queue_status.assert_called_once_with() diff --git a/bec_server/tests/tests_scan_server/test_generator_scan_worker.py b/bec_server/tests/tests_scan_server/test_generator_scan_worker.py new file mode 100644 index 000000000..bd183c96d --- /dev/null +++ b/bec_server/tests/tests_scan_server/test_generator_scan_worker.py @@ -0,0 +1,803 @@ +# pylint: skip-file +import os +import uuid +from types import SimpleNamespace +from unittest import mock + +import pytest + +from bec_lib import messages +from bec_lib.endpoints import MessageEndpoints +from bec_lib.tests.fixtures import dm_with_devices +from bec_lib.tests.utils import ConnectorMock +from bec_server.scan_server.errors import ScanAbortion, UserScanInterruption +from bec_server.scan_server.generator_scan_worker import GeneratorScanWorker +from bec_server.scan_server.scan_queue import ( + InstructionQueueItem, + InstructionQueueStatus, + RequestBlock, +) +from bec_server.scan_server.scan_worker import ScanWorker + + +@pytest.fixture +def scan_worker_mock(dm_with_devices) -> ScanWorker: + dm_with_devices.connector = mock.MagicMock() + dm_with_devices._rpc_method = mock.MagicMock() + queue_manager = SimpleNamespace( + instruction_handler=mock.MagicMock(), queues={"primary": mock.MagicMock()} + ) + parent = SimpleNamespace( + device_manager=dm_with_devices, + connector=mock.MagicMock(), + queue_manager=queue_manager, + wait_for_service=mock.MagicMock(), + _service_config=SimpleNamespace(config={"file_writer": {"base_path": "/tmp/data"}}), + scan_number=2, + dataset_number=3, + ) + scan_worker = ScanWorker(parent=parent) + yield scan_worker + + +@pytest.fixture +def generator_worker_mock(scan_worker_mock) -> GeneratorScanWorker: + return GeneratorScanWorker(worker=scan_worker_mock) + + +class InstructionQueueMock: + def __init__(self): + self.status = InstructionQueueStatus.PENDING + self.idx = 1 + self.is_active = False + self.stopped = False + self.return_to_start = True + self.queue_id = "queue-id" + self.scan_msgs = [] + self.parent = SimpleNamespace( + queue_manager=SimpleNamespace(send_queue_status=mock.MagicMock()) + ) + self.active_request_block = mock.MagicMock() + self.active_request_block.scan = mock.MagicMock() + self.active_request_block.scan.exp_time = 1 + self.active_request_block.scan.stubs._rpc_call = mock.MagicMock() + self.active_request_block.scan.move_to_start.return_value = [] + self.active_request_block.scan_report_instructions = [] + self.active_request_block.metadata = {} + self.queue = SimpleNamespace( + request_blocks=[ + mock.MagicMock( + scan=mock.MagicMock(stubs=SimpleNamespace(_rpc_call=mock.MagicMock())) + ) + ], + active_rb=SimpleNamespace(scan_id="scan-id"), + ) + + def append_to_queue_history(self): + pass + + def stop(self): + pass + + def __next__(self): + if ( + self.status + in [ + InstructionQueueStatus.RUNNING, + InstructionQueueStatus.DEFERRED_PAUSE, + InstructionQueueStatus.PENDING, + ] + and self.idx < 5 + ): + self.idx += 1 + return "instr_status" + + raise StopIteration + + def __iter__(self): + return self + + +def test_wait_for_device_server(generator_worker_mock): + worker = generator_worker_mock + with mock.patch.object(worker.worker.parent, "wait_for_service") as service_mock: + worker._wait_for_device_server() + service_mock.assert_called_once_with("DeviceServer") + + +def test_publish_data_as_read(generator_worker_mock): + worker = generator_worker_mock + instr = messages.DeviceInstructionMessage( + device=["samx"], + action="publish_data_as_read", + parameter={"data": {}}, + metadata={ + "readout_priority": "monitored", + "DIID": 3, + "scan_id": "scan_id", + "RID": "requestID", + }, + ) + with mock.patch.object(worker.worker.device_manager, "connector") as connector_mock: + worker.publish_data_as_read(instr) + msg = messages.DeviceMessage( + signals=instr.content["parameter"]["data"], metadata=instr.metadata + ) + connector_mock.set_and_publish.assert_called_once_with( + MessageEndpoints.device_read("samx"), msg + ) + + +def test_publish_data_as_read_multiple(generator_worker_mock): + worker = generator_worker_mock + data = [{"samx": {}}, {"samy": {}}] + devices = ["samx", "samy"] + instr = messages.DeviceInstructionMessage( + device=devices, + action="publish_data_as_read", + parameter={"data": data}, + metadata={ + "readout_priority": "monitored", + "DIID": 3, + "scan_id": "scan_id", + "RID": "requestID", + }, + ) + with mock.patch.object(worker.worker.device_manager, "connector") as connector_mock: + worker.publish_data_as_read(instr) + mock_calls = [] + for device, dev_data in zip(devices, data): + msg = messages.DeviceMessage(signals=dev_data, metadata=instr.metadata) + mock_calls.append(mock.call(MessageEndpoints.device_read(device), msg)) + assert connector_mock.set_and_publish.mock_calls == mock_calls + + +def test_check_for_interruption(generator_worker_mock): + worker = generator_worker_mock + worker.worker.status = InstructionQueueStatus.STOPPED + with pytest.raises(ScanAbortion): + worker._check_for_interruption() + + +@pytest.mark.parametrize( + "instr, corr_num_points, scan_id", + [ + ( + messages.DeviceInstructionMessage( + device=None, + action="open_scan", + parameter={ + "num_points": 150, + "readout_priority": { + "monitored": ["samx", "samy"], + "baseline": [], + "on_request": [], + "async": [], + "continuous": [], + }, + }, + metadata={ + "readout_priority": "monitored", + "DIID": 18, + "scan_id": "12345", + "scan_def_id": 100, + "point_id": 50, + "RID": 11, + }, + ), + 201, + False, + ), + ( + messages.DeviceInstructionMessage( + device=None, + action="open_scan", + parameter={"num_points": 150}, + metadata={ + "readout_priority": "monitored", + "DIID": 18, + "scan_id": "12345", + "RID": 11, + }, + ), + 150, + True, + ), + ], +) +def test_open_scan(generator_worker_mock, instr, corr_num_points, scan_id): + worker = generator_worker_mock + + if not scan_id: + assert worker.scan_id is None + else: + worker.scan_id = 111 + + if "point_id" in instr.metadata: + worker.max_point_id = instr.metadata["point_id"] + + queue_mock = mock.MagicMock() + queue_mock.active_request_block.scan_report_instructions = [] + queue_mock.active_request_block.scan.show_live_table = True + queue_mock.active_request_block.scan.use_scan_progress_report = True + queue_mock.parent.queue_manager.send_queue_status = mock.MagicMock() + worker.worker.current_instruction_queue_item = queue_mock + + with mock.patch.object(worker, "_initialize_scan_info") as init_mock: + with mock.patch.object(worker, "_send_scan_status") as send_mock: + worker.open_scan(instr) + + if not scan_id: + assert worker.scan_id == instr.metadata.get("scan_id") + else: + assert worker.scan_id == 111 + assert worker.readout_priority == instr.content["parameter"].get("readout_priority", {}) + init_mock.assert_called_once_with( + queue_mock.active_request_block, instr, corr_num_points + ) + assert queue_mock.active_request_block.scan_report_instructions == [ + {"scan_progress": {"points": corr_num_points, "show_table": True}} + ] + queue_mock.parent.queue_manager.send_queue_status.assert_called_once() + send_mock.assert_called_once_with("open") + + +@pytest.mark.parametrize( + "msg", + [ + messages.ScanQueueMessage( + scan_type="grid_scan", + parameter={ + "args": {"samx": (-5, 5, 5), "samy": (-1, 1, 2)}, + "kwargs": { + "exp_time": 1, + "relative": True, + "system_config": {"file_suffix": None, "file_directory": None}, + }, + "num_points": 10, + }, + queue="primary", + metadata={ + "RID": "something", + "system_config": {"file_suffix": None, "file_directory": None}, + }, + ), + messages.ScanQueueMessage( + scan_type="grid_scan", + parameter={ + "args": {"samx": (-5, 5, 5), "samy": (-1, 1, 2)}, + "kwargs": { + "exp_time": 1, + "relative": True, + "system_config": {"file_suffix": "test", "file_directory": "tmp"}, + }, + "num_points": 10, + }, + queue="primary", + metadata={ + "RID": "something", + "system_config": {"file_suffix": "test", "file_directory": "tmp"}, + }, + ), + messages.ScanQueueMessage( + scan_type="grid_scan", + parameter={ + "args": {"samx": (-5, 5, 5), "samy": (-1, 1, 2)}, + "kwargs": { + "exp_time": 1, + "relative": True, + "system_config": {"file_suffix": "test", "file_directory": None}, + }, + "num_points": 10, + }, + queue="primary", + metadata={ + "RID": "something", + "system_config": {"file_suffix": "test", "file_directory": None}, + }, + ), + ], +) +def test_initialize_scan_info(generator_worker_mock, msg): + worker = generator_worker_mock + rb = SimpleNamespace( + metadata=msg.metadata, + scan=SimpleNamespace( + frames_per_trigger=1, + settling_time=0, + readout_time=0, + scan_report_devices=[], + monitor_sync="bec", + scan_parameters=msg.parameter["kwargs"], + request_inputs=msg.parameter, + parameter=msg.parameter, + ), + ) + + worker.worker.current_instruction_queue_item = mock.MagicMock(scan_msgs=[]) + worker.readout_priority = { + "monitored": ["samx"], + "baseline": [], + "async": [], + "continuous": [], + "on_request": [], + } + open_scan_msg = messages.DeviceInstructionMessage( + device=None, + action="open_scan", + parameter={"num_points": msg.content["parameter"].get("num_points")}, + metadata=msg.metadata, + ) + worker._initialize_scan_info(rb, open_scan_msg, msg.content["parameter"].get("num_points")) + + assert worker.current_scan_info["RID"] == "something" + assert worker.current_scan_info["scan_number"] == 2 + assert worker.current_scan_info["dataset_number"] == 3 + assert worker.current_scan_info["scan_report_devices"] == rb.scan.scan_report_devices + assert worker.current_scan_info["num_points"] == 10 + assert worker.current_scan_info["scan_msgs"] == [] + assert worker.current_scan_info["monitor_sync"] == "bec" + assert worker.current_scan_info["frames_per_trigger"] == 1 + assert worker.current_scan_info["args"] == {"samx": (-5, 5, 5), "samy": (-1, 1, 2)} + assert worker.current_scan_info["kwargs"] == msg.parameter["kwargs"] + assert "samx" in worker.current_scan_info["readout_priority"]["monitored"] + assert "samy" in worker.current_scan_info["readout_priority"]["baseline"] + + base_path = worker.worker.parent._service_config.config["file_writer"]["base_path"] + file_dir = msg.parameter["kwargs"]["system_config"]["file_directory"] + suffix = msg.parameter["kwargs"]["system_config"]["file_suffix"] + if file_dir is None: + if suffix is None: + file_dir = "S00000-00999/S00002" + else: + file_dir = f"S00000-00999/S00002_{suffix}" + file_components = os.path.abspath(os.path.join(base_path, file_dir, "S00002")), "h5" + assert worker.current_scan_info["file_components"] == file_components + + +@pytest.mark.parametrize( + "msg,scan_id,max_point_id,exp_num_points", + [ + ( + messages.DeviceInstructionMessage( + device=None, + action="close_scan", + parameter={}, + metadata={"readout_priority": "monitored", "DIID": 18, "scan_id": "12345"}, + ), + "12345", + 19, + 20, + ), + ( + messages.DeviceInstructionMessage( + device=None, + action="close_scan", + parameter={}, + metadata={"readout_priority": "monitored", "DIID": 18, "scan_id": "12345"}, + ), + "0987", + 200, + 19, + ), + ], +) +def test_close_scan(generator_worker_mock, msg, scan_id, max_point_id, exp_num_points): + worker = generator_worker_mock + worker.scan_id = scan_id + worker.current_scan_info["num_points"] = 19 + + reset = bool(worker.scan_id == msg.metadata["scan_id"]) + with mock.patch.object(worker, "_send_scan_status") as send_scan_status_mock: + worker.close_scan(msg, max_point_id=max_point_id) + if reset: + send_scan_status_mock.assert_called_with("closed") + assert worker.scan_id is None + else: + assert worker.scan_id == scan_id + assert worker.current_scan_info["num_points"] == exp_num_points + + +@pytest.mark.parametrize("status,expire", [("open", None), ("closed", 1800), ("aborted", 1800)]) +def test_send_scan_status(generator_worker_mock, status, expire): + worker = generator_worker_mock + worker.worker.device_manager.connector = ConnectorMock() + worker.current_scan_id = str(uuid.uuid4()) + worker.current_scan_info = {"scan_number": 5} + worker._send_scan_status(status) + scan_info_msgs = [ + msg + for msg in worker.worker.device_manager.connector.message_sent + if msg["queue"] + == MessageEndpoints.public_scan_info(scan_id=worker.current_scan_id).endpoint + ] + assert len(scan_info_msgs) == 1 + assert scan_info_msgs[0]["expire"] == expire + + +@pytest.mark.parametrize("abortion", [False, True]) +def test_process_instructions(generator_worker_mock, abortion): + worker = generator_worker_mock + queue = InstructionQueueMock() + worker.worker.device_manager._rpc_method.return_value = mock.MagicMock( + __enter__=mock.MagicMock(return_value=None), __exit__=mock.MagicMock(return_value=None) + ) + + with mock.patch.object(worker, "_wait_for_device_server") as wait_mock: + with mock.patch.object(worker, "reset") as reset_mock: + with mock.patch.object(worker, "_check_for_interruption") as interruption_mock: + queue.queue.request_blocks.append(mock.MagicMock()) + with mock.patch.object(queue.queue, "active_rb") as rb_mock: + with mock.patch.object(worker, "_instruction_step") as step_mock: + if abortion: + interruption_mock.side_effect = ScanAbortion + with pytest.raises(ScanAbortion): + worker.process_instructions(queue) + else: + worker.process_instructions(queue) + + assert worker.max_point_id == 0 + wait_mock.assert_called_once() + + if not abortion: + assert interruption_mock.call_count == 4 + assert worker._exposure_time == 1 + assert step_mock.call_count == 4 + assert queue.is_active is False + assert queue.status == InstructionQueueStatus.COMPLETED + assert worker.worker.current_instruction_queue_item is None + reset_mock.assert_called_once() + + else: + assert queue.stopped is True + assert interruption_mock.call_count == 1 + assert queue.is_active is True + assert queue.status == InstructionQueueStatus.PENDING + assert worker.worker.current_instruction_queue_item == queue + + +@pytest.mark.parametrize( + "msg,method", + [ + ( + messages.DeviceInstructionMessage( + device=None, + action="open_scan", + parameter={"readout_priority": {"monitored": [], "baseline": [], "on_request": []}}, + metadata={"readout_priority": "monitored", "scan_id": "12345"}, + ), + "open_scan", + ), + ( + messages.DeviceInstructionMessage( + device=None, + action="close_scan", + parameter={}, + metadata={"readout_priority": "monitored", "scan_id": "12345"}, + ), + "close_scan", + ), + ( + messages.DeviceInstructionMessage( + device=None, + action="trigger", + parameter={"group": "trigger"}, + metadata={"readout_priority": "monitored", "point_id": 0}, + ), + "forward_instruction", + ), + ( + messages.DeviceInstructionMessage( + device="samx", + action="set", + parameter={"value": 1.3681828686580249}, + metadata={"readout_priority": "monitored"}, + ), + "forward_instruction", + ), + ( + messages.DeviceInstructionMessage( + device=None, + action="read", + parameter={"group": "monitored"}, + metadata={"readout_priority": "monitored", "point_id": 1}, + ), + "forward_instruction", + ), + ( + messages.DeviceInstructionMessage( + device=None, + action="stage", + parameter={}, + metadata={"readout_priority": "monitored"}, + ), + "forward_instruction", + ), + ( + messages.DeviceInstructionMessage( + device=None, + action="unstage", + parameter={}, + metadata={"readout_priority": "monitored"}, + ), + "forward_instruction", + ), + ( + messages.DeviceInstructionMessage( + device="samx", + action="rpc", + parameter={ + "device": "lsamy", + "func": "readback.get", + "rpc_id": "61a7376c-36cf-41af-94b1-76c1ba821d47", + "args": [], + "kwargs": {}, + }, + metadata={"readout_priority": "monitored"}, + ), + "forward_instruction", + ), + ( + messages.DeviceInstructionMessage( + device="samx", action="kickoff", parameter={}, metadata={} + ), + "forward_instruction", + ), + ( + messages.DeviceInstructionMessage( + device=None, + action="baseline_reading", + parameter={}, + metadata={"readout_priority": "baseline"}, + ), + "forward_instruction", + ), + ( + messages.DeviceInstructionMessage(device=None, action="close_scan_def", parameter={}), + "close_scan", + ), + ( + messages.DeviceInstructionMessage( + device=None, action="publish_data_as_read", parameter={} + ), + "publish_data_as_read", + ), + ( + messages.DeviceInstructionMessage( + device=None, action="scan_report_instruction", parameter={} + ), + "process_scan_report_instruction", + ), + ( + messages.DeviceInstructionMessage(device=None, action="pre_scan", parameter={}), + "forward_instruction", + ), + ( + messages.DeviceInstructionMessage(device=None, action="complete", parameter={}), + "forward_instruction", + ), + ], +) +def test_instruction_step(generator_worker_mock, msg, method): + worker = generator_worker_mock + with mock.patch( + f"bec_server.scan_server.generator_scan_worker.GeneratorScanWorker.{method}" + ) as instruction_method: + with mock.patch.object(worker, "update_instr_with_scan_report") as update_mock: + worker._instruction_step(msg) + instruction_method.assert_called_once() + if method == "set": + update_mock.assert_called_once_with(msg) + + +def test_reset(generator_worker_mock): + worker = generator_worker_mock + worker.current_scan_id = 1 + worker.current_scan_info = 1 + worker.scan_id = 1 + worker.interception_msg = 1 + worker.worker.current_instruction_queue_item = 1 + + worker.reset() + + assert worker.current_scan_id == "" + assert worker.current_scan_info == {} + assert worker.scan_id is None + assert worker.interception_msg is None + assert worker.worker.current_instruction_queue_item is None + + +def test_cleanup(generator_worker_mock): + worker = generator_worker_mock + with mock.patch.object(worker, "forward_instruction") as forward_mock: + worker.cleanup() + sent_message = forward_mock.mock_calls[0].args[0] + diid = sent_message.metadata["device_instr_id"] + devices = sent_message.device + msg = messages.DeviceInstructionMessage( + device=devices, action="unstage", parameter={}, metadata={"device_instr_id": diid} + ) + forward_mock.assert_called_once_with(msg) + + +@pytest.mark.parametrize( + "msg", + [ + messages.DeviceInstructionMessage( + device=["samx"], action="set", parameter={"value": 1}, metadata={"scan_id": "scan_id"} + ) + ], +) +def test_worker_update_instr_with_scan_report_no_update(msg, generator_worker_mock): + worker = generator_worker_mock + worker.worker.current_instruction_queue_item = mock.MagicMock(spec=InstructionQueueItem) + arb = worker.worker.current_instruction_queue_item.active_request_block = mock.MagicMock( + spec=RequestBlock + ) + arb.scan_report_instructions = [] + with mock.patch.object(worker, "forward_instruction") as forward_mock: + worker._instruction_step(msg) + worker.update_instr_with_scan_report(msg) + forward_mock.assert_called_once_with(msg) + + +@pytest.mark.parametrize( + "msg", + [ + messages.DeviceInstructionMessage( + device=["samx"], action="set", parameter={"value": 1}, metadata={"scan_id": "scan_id"} + ) + ], +) +def test_worker_update_instr_with_scan_report_no_update_with_report(msg, generator_worker_mock): + worker = generator_worker_mock + worker.worker.current_instruction_queue_item = mock.MagicMock(spec=InstructionQueueItem) + arb = worker.worker.current_instruction_queue_item.active_request_block = mock.MagicMock( + spec=RequestBlock + ) + arb.scan_report_instructions = [{"scan_progress": {"points": 10, "show_table": True}}] + with mock.patch.object(worker, "forward_instruction") as forward_mock: + worker._instruction_step(msg) + worker.update_instr_with_scan_report(msg) + forward_mock.assert_called_once_with(msg) + + +@pytest.mark.parametrize( + "msg", + [ + messages.DeviceInstructionMessage( + device=["samx"], action="set", parameter={"value": 1}, metadata={"scan_id": "scan_id"} + ) + ], +) +def test_worker_update_instr_with_scan_report_update(msg, generator_worker_mock): + worker = generator_worker_mock + worker.worker.current_instruction_queue_item = mock.MagicMock(spec=InstructionQueueItem) + arb = worker.worker.current_instruction_queue_item.active_request_block = mock.MagicMock( + spec=RequestBlock + ) + arb.scan_report_instructions = [ + {"readback": {"RID": "rid", "devices": ["samx"], "start": [0], "end": [1]}} + ] + with mock.patch.object(worker, "forward_instruction") as forward_mock: + worker._instruction_step(msg) + worker.update_instr_with_scan_report(msg) + forward_mock.assert_called_once_with(msg) + assert msg.metadata["response"] is True + + +@pytest.mark.parametrize( + "base_path, current_account_msg, expected_path, raises_error", + [ + ( + "/data/$account/raw", + messages.VariableMessage(value="test_account"), + "/data/test_account/raw", + False, + ), + ("/data/$account/raw", None, "/data/raw", False), + ( + "/data/raw", + messages.VariableMessage(value="test_account"), + "/data/raw/test_account", + False, + ), + ("/data/raw", None, "/data/raw", False), + ( + "/data/$account/$sub_dir/raw", + messages.VariableMessage(value="test_account"), + "/data/test_account/$sub_dir/raw", + True, + ), + ], +) +def test_worker_get_file_base_path( + generator_worker_mock, base_path, current_account_msg, expected_path, raises_error +): + worker = generator_worker_mock + file_writer_base_path_orig = worker.worker.parent._service_config.config["file_writer"][ + "base_path" + ] + try: + worker.worker.parent._service_config.config["file_writer"]["base_path"] = base_path + with mock.patch.object( + worker.worker.connector, "get_last", return_value=current_account_msg + ): + if raises_error: + with pytest.raises(ValueError): + worker._get_file_base_path() + else: + file_path = worker._get_file_base_path() + assert file_path == expected_path + worker.worker.connector.get_last.assert_called_once_with( + MessageEndpoints.account(), "data" + ) + finally: + worker.worker.parent._service_config.config["file_writer"][ + "base_path" + ] = file_writer_base_path_orig + + +@pytest.mark.parametrize( + "scan_info, out", + [ + (None, {}), + ({}, {}), + ({"scan_id": "12345"}, {"scan_id": "12345"}), + ({"scan_number": 1}, {"scan_number": 1}), + ({"scan_id": "12345", "scan_number": 1}, {"scan_id": "12345", "scan_number": 1}), + ], +) +def test_worker_get_metadata_for_alarm(generator_worker_mock, scan_info, out): + worker = generator_worker_mock + worker.current_scan_info = scan_info + metadata = worker.get_metadata_for_alarm() + assert metadata == out + + +def test_handle_scan_abortion(generator_worker_mock): + worker = generator_worker_mock + queue = mock.MagicMock(spec=InstructionQueueItem) + queue.exit_info = None + queue.queue_id = "id-12345" + with mock.patch.object(worker, "_send_scan_status") as send_status_mock: + abortion = ScanAbortion() + worker._handle_scan_abortion(queue, abortion) + send_status_mock.assert_called_once_with("aborted", reason="alarm") + + +def test_handle_scan_halt(generator_worker_mock): + worker = generator_worker_mock + queue = mock.MagicMock(spec=InstructionQueueItem) + queue.exit_info = None + queue.queue_id = "id-12345" + with mock.patch.object(worker, "_send_scan_status") as send_status_mock: + abortion = ScanAbortion() + queue.return_to_start = False + worker._handle_scan_abortion(queue, abortion) + send_status_mock.assert_called_once_with("halted", reason="alarm") + + +def test_handle_user_scan_interruption(generator_worker_mock): + worker = generator_worker_mock + queue = mock.MagicMock(spec=InstructionQueueItem) + queue.exit_info = None + queue.queue_id = "id-12345" + with mock.patch.object(worker, "_send_scan_status") as send_status_mock: + interruption = UserScanInterruption(exit_info=("user_completed", "user")) + worker._handle_scan_abortion(queue, interruption) + send_status_mock.assert_called_once_with("user_completed", reason="user") + + +def test_handle_user_scan_interruption_followed_by_abortion(generator_worker_mock): + worker = generator_worker_mock + queue = mock.MagicMock(spec=InstructionQueueItem) + queue.exit_info = "user_completed" + queue.queue_id = "id-12345" + with mock.patch.object(worker, "_send_scan_status") as send_status_mock: + interruption = UserScanInterruption(exit_info=("user_completed", "user")) + abortion = ScanAbortion() + worker._handle_scan_abortion(queue, interruption) + worker._handle_scan_abortion(queue, abortion) + send_status_mock.mock_calls[0].assert_called_with("user_completed", reason="user") + send_status_mock.mock_calls[1].assert_called_with("user_completed", reason="user") diff --git a/bec_server/tests/tests_scan_server/test_scans.py b/bec_server/tests/tests_scan_server/test_legacy_scans.py similarity index 99% rename from bec_server/tests/tests_scan_server/test_scans.py rename to bec_server/tests/tests_scan_server/test_legacy_scans.py index e563b899e..c13eabdfa 100644 --- a/bec_server/tests/tests_scan_server/test_scans.py +++ b/bec_server/tests/tests_scan_server/test_legacy_scans.py @@ -8,7 +8,7 @@ from bec_lib import messages from bec_server.device_server.tests.utils import DMMock from bec_server.scan_server.scan_plugins.otf_scan import OTFScan -from bec_server.scan_server.scans import ( +from bec_server.scan_server.scans.legacy_scans import ( Acquire, CloseInteractiveScan, ContLineFlyScan, diff --git a/bec_server/tests/tests_scan_server/test_path_optimization.py b/bec_server/tests/tests_scan_server/test_path_optimization.py index 92cbbcd0c..8fc819afb 100644 --- a/bec_server/tests/tests_scan_server/test_path_optimization.py +++ b/bec_server/tests/tests_scan_server/test_path_optimization.py @@ -2,7 +2,10 @@ import pytest from bec_server.scan_server.path_optimization import PathOptimizerMixin -from bec_server.scan_server.scans import get_fermat_spiral_pos, get_round_roi_scan_positions +from bec_server.scan_server.scans.legacy_scans import ( + get_fermat_spiral_pos, + get_round_roi_scan_positions, +) def test_shell_optimization(): diff --git a/bec_server/tests/tests_scan_server/test_scan_assembler.py b/bec_server/tests/tests_scan_server/test_scan_assembler.py index 0156fd57e..6e25228a4 100644 --- a/bec_server/tests/tests_scan_server/test_scan_assembler.py +++ b/bec_server/tests/tests_scan_server/test_scan_assembler.py @@ -1,10 +1,17 @@ +from typing import Annotated from unittest import mock import pytest from bec_lib import messages +from bec_lib.device import DeviceBase +from bec_lib.scan_args import ScanArgument +from bec_lib.tests.fixtures import dm_with_devices +from bec_server.scan_server.errors import ScanInputValidationError from bec_server.scan_server.scan_assembler import ScanAssembler -from bec_server.scan_server.scans import FermatSpiralScan, LineScan, RequestBase +from bec_server.scan_server.scans import ScanArgType +from bec_server.scan_server.scans.legacy_scans import FermatSpiralScan, LineScan, RequestBase +from bec_server.scan_server.scans.scans_v4 import ScanBase as ScanBaseV4 @pytest.fixture @@ -32,6 +39,72 @@ def run(self): pass +class CustomDirectScan(ScanBaseV4): + scan_name = "custom_direct_scan" + arg_input = {"device": ScanArgType.DEVICE, "target": ScanArgType.FLOAT} + arg_bundle_size = {"bundle": len(arg_input), "min": 1, "max": None} + is_scan = False + + def __init__(self, *args, **kwargs): + super().__init__(**kwargs) + self.received_args = args + + +class CustomFixedDirectScan(ScanBaseV4): + scan_name = "custom_fixed_direct_scan" + is_scan = False + + def __init__(self, device: DeviceBase, target: float, **kwargs): + super().__init__(**kwargs) + self.device = device + self.target = target + + +class CustomBoundedDirectScan(ScanBaseV4): + scan_name = "custom_bounded_direct_scan" + is_scan = False + + def __init__( + self, + value: Annotated[float, ScanArgument(display_name="Value", gt=0, ge=1, lt=10, le=9)], + **kwargs, + ): + super().__init__(**kwargs) + self.value = value + + +class CustomBundledBoundedDirectScan(ScanBaseV4): + scan_name = "custom_bundled_bounded_direct_scan" + arg_input = { + "device": ScanArgType.DEVICE, + "target": Annotated[int, ScanArgument(display_name="Target", ge=1)], + } + arg_bundle_size = {"bundle": len(arg_input), "min": 1, "max": None} + is_scan = False + + def __init__( + self, + *args, + scale: Annotated[float, ScanArgument(display_name="Scale", le=10)] = 1, + **kwargs, + ): + super().__init__(**kwargs) + self.received_args = args + self.scale = scale + + +class CustomBundledScanWithDeviceKwarg(ScanBaseV4): + scan_name = "custom_bundled_scan_with_device_kwarg" + arg_input = {"device": ScanArgType.DEVICE, "target": ScanArgType.FLOAT} + arg_bundle_size = {"bundle": len(arg_input), "min": 1, "max": None} + is_scan = False + + def __init__(self, *args, monitor: DeviceBase, **kwargs): + super().__init__(**kwargs) + self.received_args = args + self.monitor = monitor + + @pytest.mark.parametrize( "msg, request_inputs_expected", [ @@ -155,12 +228,246 @@ class MockScanManager: "custom_scan2": {"class": "CustomScan2"}, } scan_dict = { - "FermatSpiralScan": FermatSpiralScan, - "LineScan": LineScan, - "CustomScan": CustomScan, - "CustomScan2": CustomScan2, + "fermat_scan": FermatSpiralScan, + "line_scan": LineScan, + "custom_scan": CustomScan, + "custom_scan2": CustomScan2, } with mock.patch.object(scan_assembler, "scan_manager", MockScanManager()): request = scan_assembler.assemble_device_instructions(msg, "scan_id") assert request.request_inputs == request_inputs_expected + + +def test_scan_assembler_assemble_direct_scan_resolves_device_args(dm_with_devices): + parent = mock.MagicMock() + parent.device_manager = dm_with_devices + parent.connector = mock.MagicMock() + parent.queue_manager.instruction_handler = mock.MagicMock() + assembler = ScanAssembler(parent=parent) + + class MockScanManager: + scan_dict = {"custom_direct_scan": CustomDirectScan} + + msg = messages.ScanQueueMessage( + scan_type="custom_direct_scan", + parameter={ + "args": {"samx": (1,), "samy": (2,)}, + "kwargs": {"system_config": {"file_directory": "/tmp/data"}}, + }, + queue="primary", + ) + + with mock.patch.object(assembler, "scan_manager", MockScanManager()): + request = assembler.assemble_direct_scan(msg, "scan_id") + + assert request.received_args == ( + dm_with_devices.devices["samx"], + 1, + dm_with_devices.devices["samy"], + 2, + ) + assert request.scan_info.request_inputs["arg_bundle"] == ["samx", 1, "samy", 2] + + +def test_scan_assembler_assemble_direct_scan_resolves_annotated_device_args(dm_with_devices): + parent = mock.MagicMock() + parent.device_manager = dm_with_devices + parent.connector = mock.MagicMock() + parent.queue_manager.instruction_handler = mock.MagicMock() + assembler = ScanAssembler(parent=parent) + + class MockScanManager: + scan_dict = {"custom_fixed_direct_scan": CustomFixedDirectScan} + + msg = messages.ScanQueueMessage( + scan_type="custom_fixed_direct_scan", + parameter={ + "args": ["samx", 1], + "kwargs": {"system_config": {"file_directory": "/tmp/data"}}, + }, + queue="primary", + ) + + with mock.patch.object(assembler, "scan_manager", MockScanManager()): + request = assembler.assemble_direct_scan(msg, "scan_id") + + assert request.device is dm_with_devices.devices["samx"] + assert request.target == 1 + assert request.scan_info.request_inputs["inputs"] == {"device": "samx", "target": 1} + + +@pytest.mark.parametrize( + ("value", "message"), + [ + (0, "greater than"), + (0.5, "greater than or equal to"), + (10, "less than"), + (9.5, "less than or equal to"), + ], +) +def test_scan_assembler_validates_fixed_direct_scan_input_bounds(dm_with_devices, value, message): + parent = mock.MagicMock() + parent.device_manager = dm_with_devices + parent.connector = mock.MagicMock() + parent.queue_manager.instruction_handler = mock.MagicMock() + assembler = ScanAssembler(parent=parent) + + class MockScanManager: + scan_dict = {"custom_bounded_direct_scan": CustomBoundedDirectScan} + + msg = messages.ScanQueueMessage( + scan_type="custom_bounded_direct_scan", + parameter={"args": [value], "kwargs": {"system_config": {"file_directory": "/tmp/data"}}}, + queue="primary", + ) + + with mock.patch.object(assembler, "scan_manager", MockScanManager()): + with pytest.raises(ScanInputValidationError, match=message): + assembler.assemble_direct_scan(msg, "scan_id") + + +def test_scan_assembler_validates_fixed_direct_scan_input_type(dm_with_devices): + parent = mock.MagicMock() + parent.device_manager = dm_with_devices + parent.connector = mock.MagicMock() + parent.queue_manager.instruction_handler = mock.MagicMock() + assembler = ScanAssembler(parent=parent) + + class MockScanManager: + scan_dict = {"custom_fixed_direct_scan": CustomFixedDirectScan} + + msg = messages.ScanQueueMessage( + scan_type="custom_fixed_direct_scan", + parameter={ + "args": ["samx", "invalid"], + "kwargs": {"system_config": {"file_directory": "/tmp/data"}}, + }, + queue="primary", + ) + + with mock.patch.object(assembler, "scan_manager", MockScanManager()): + with pytest.raises(ScanInputValidationError, match="target.*expected float"): + assembler.assemble_direct_scan(msg, "scan_id") + + +def test_scan_assembler_validates_bundled_direct_scan_input_bounds(dm_with_devices): + parent = mock.MagicMock() + parent.device_manager = dm_with_devices + parent.connector = mock.MagicMock() + parent.queue_manager.instruction_handler = mock.MagicMock() + assembler = ScanAssembler(parent=parent) + + class MockScanManager: + scan_dict = {"custom_bundled_bounded_direct_scan": CustomBundledBoundedDirectScan} + + msg = messages.ScanQueueMessage( + scan_type="custom_bundled_bounded_direct_scan", + parameter={ + "args": {"samx": (0,)}, + "kwargs": {"system_config": {"file_directory": "/tmp/data"}}, + }, + queue="primary", + ) + + with mock.patch.object(assembler, "scan_manager", MockScanManager()): + with pytest.raises(ScanInputValidationError, match="target.*greater than or equal to"): + assembler.assemble_direct_scan(msg, "scan_id") + + +def test_scan_assembler_validates_bundled_direct_scan_input_type(dm_with_devices): + parent = mock.MagicMock() + parent.device_manager = dm_with_devices + parent.connector = mock.MagicMock() + parent.queue_manager.instruction_handler = mock.MagicMock() + assembler = ScanAssembler(parent=parent) + + class MockScanManager: + scan_dict = {"custom_direct_scan": CustomDirectScan} + + msg = messages.ScanQueueMessage( + scan_type="custom_direct_scan", + parameter={ + "args": {"samx": ("invalid",)}, + "kwargs": {"system_config": {"file_directory": "/tmp/data"}}, + }, + queue="primary", + ) + + with mock.patch.object(assembler, "scan_manager", MockScanManager()): + with pytest.raises(ScanInputValidationError, match="target.*expected float"): + assembler.assemble_direct_scan(msg, "scan_id") + + +def test_scan_assembler_validates_signature_kwargs_for_arg_input_scan(dm_with_devices): + parent = mock.MagicMock() + parent.device_manager = dm_with_devices + parent.connector = mock.MagicMock() + parent.queue_manager.instruction_handler = mock.MagicMock() + assembler = ScanAssembler(parent=parent) + + class MockScanManager: + scan_dict = {"custom_bundled_bounded_direct_scan": CustomBundledBoundedDirectScan} + + msg = messages.ScanQueueMessage( + scan_type="custom_bundled_bounded_direct_scan", + parameter={ + "args": {"samx": (1,)}, + "kwargs": {"scale": 11, "system_config": {"file_directory": "/tmp/data"}}, + }, + queue="primary", + ) + + with mock.patch.object(assembler, "scan_manager", MockScanManager()): + with pytest.raises(ScanInputValidationError, match="scale.*less than or equal to"): + assembler.assemble_direct_scan(msg, "scan_id") + + +def test_scan_assembler_validates_signature_kwargs_type_for_arg_input_scan(dm_with_devices): + parent = mock.MagicMock() + parent.device_manager = dm_with_devices + parent.connector = mock.MagicMock() + parent.queue_manager.instruction_handler = mock.MagicMock() + assembler = ScanAssembler(parent=parent) + + class MockScanManager: + scan_dict = {"custom_bundled_bounded_direct_scan": CustomBundledBoundedDirectScan} + + msg = messages.ScanQueueMessage( + scan_type="custom_bundled_bounded_direct_scan", + parameter={ + "args": {"samx": (1,)}, + "kwargs": {"scale": "invalid", "system_config": {"file_directory": "/tmp/data"}}, + }, + queue="primary", + ) + + with mock.patch.object(assembler, "scan_manager", MockScanManager()): + with pytest.raises(ScanInputValidationError, match="scale.*expected float"): + assembler.assemble_direct_scan(msg, "scan_id") + + +def test_scan_assembler_resolves_signature_device_kwargs_for_arg_input_scan(dm_with_devices): + parent = mock.MagicMock() + parent.device_manager = dm_with_devices + parent.connector = mock.MagicMock() + parent.queue_manager.instruction_handler = mock.MagicMock() + assembler = ScanAssembler(parent=parent) + + class MockScanManager: + scan_dict = {"custom_bundled_scan_with_device_kwarg": CustomBundledScanWithDeviceKwarg} + + msg = messages.ScanQueueMessage( + scan_type="custom_bundled_scan_with_device_kwarg", + parameter={ + "args": {"samx": (1,)}, + "kwargs": {"monitor": "samy", "system_config": {"file_directory": "/tmp/data"}}, + }, + queue="primary", + ) + + with mock.patch.object(assembler, "scan_manager", MockScanManager()): + request = assembler.assemble_direct_scan(msg, "scan_id") + + assert request.received_args == (dm_with_devices.devices["samx"], 1) + assert request.monitor is dm_with_devices.devices["samy"] diff --git a/bec_server/tests/tests_scan_server/test_scan_guard.py b/bec_server/tests/tests_scan_server/test_scan_guard.py index 92268cd4c..07700f906 100644 --- a/bec_server/tests/tests_scan_server/test_scan_guard.py +++ b/bec_server/tests/tests_scan_server/test_scan_guard.py @@ -451,6 +451,7 @@ def stop_worker(self): class MockInstructionItem: def __init__(self): self.queue = MockRequestBlockQueue() + self.scan_id = ["scan_id"] class MockRequestBlockQueue: def __init__(self): diff --git a/bec_server/tests/tests_scan_server/test_scan_gui_models.py b/bec_server/tests/tests_scan_server/test_scan_gui_models.py index 346231e10..6c1f4b9b8 100644 --- a/bec_server/tests/tests_scan_server/test_scan_gui_models.py +++ b/bec_server/tests/tests_scan_server/test_scan_gui_models.py @@ -119,6 +119,18 @@ def __init__(self, *args, **kwargs): super().__init__(**kwargs) +class GenericListArgInputScan(ScanBase): # pragma: no cover + scan_name = "generic_list_arg_input_scan" + required_kwargs = [] + arg_input = {"device": ScanArgType.DEVICE, "positions": list[float]} + arg_bundle_size = {"bundle": len(arg_input), "min": 1, "max": None} + gui_config = {"Scan Parameters": []} + + def __init__(self, *args, **kwargs): + """Scan with generic list arg_input typing for GUI compatibility tests.""" + super().__init__(**kwargs) + + def test_gui_config_good_scan_dump(): gui_config = GUIConfig.from_dict(GoodScan) expected_config = { @@ -329,3 +341,33 @@ def test_gui_config_rich_arg_input_is_converted_to_legacy_scan_arg_types(): "expert": False, }, ] + + +def test_gui_config_generic_list_arg_input_is_converted_to_legacy_scan_arg_types(): + gui_config = GUIConfig.from_dict(GenericListArgInputScan) + + assert gui_config.arg_group.model_dump()["arg_inputs"] == { + "device": ScanArgType.DEVICE, + "positions": ScanArgType.LIST, + } + + assert gui_config.arg_group.model_dump()["inputs"] == [ + { + "arg": True, + "name": "device", + "display_name": "Device", + "type": "device", + "tooltip": None, + "default": None, + "expert": False, + }, + { + "arg": True, + "name": "positions", + "display_name": "Positions", + "type": "list", + "tooltip": None, + "default": None, + "expert": False, + }, + ] diff --git a/bec_server/tests/tests_scan_server/test_scan_server_queue.py b/bec_server/tests/tests_scan_server/test_scan_server_queue.py index f3d12d159..b95906d1f 100644 --- a/bec_server/tests/tests_scan_server/test_scan_server_queue.py +++ b/bec_server/tests/tests_scan_server/test_scan_server_queue.py @@ -13,6 +13,7 @@ from bec_server.scan_server.errors import LimitError, ScanAbortion from bec_server.scan_server.scan_assembler import ScanAssembler from bec_server.scan_server.scan_queue import ( + DirectInstructionQueueItem, InstructionQueueItem, InstructionQueueStatus, QueueManager, @@ -22,6 +23,8 @@ ScanQueueStatus, ) from bec_server.scan_server.scan_worker import ScanWorker +from bec_server.scan_server.scans.scans_v4 import ScanBase as ScanBaseV4 +from bec_server.scan_server.scans.scans_v4 import ScanType from bec_server.scan_server.tests.fixtures import scan_server_mock # pylint: disable=missing-function-docstring @@ -70,6 +73,24 @@ def append_scan_request(self, msg): self.queue.append(msg) +class _DummyV4Scan(ScanBaseV4): + scan_name = "_v4_dummy_scan" + + +def _build_dummy_v4_scan(scan_id: str, scan_number: int | None = None) -> _DummyV4Scan: + scan = _DummyV4Scan( + scan_id=scan_id, + redis_connector=mock.MagicMock(), + device_manager=mock.MagicMock(), + instruction_handler=mock.MagicMock(), + request_inputs={}, + system_config={}, + ) + scan.scan_info.scan_type = ScanType.SOFTWARE_TRIGGERED + scan.scan_info.scan_number = scan_number + return scan + + def test_queuemanager_queue_contains_primary(queuemanager_mock): queue_manager = queuemanager_mock() assert "primary" in queue_manager.queues @@ -216,6 +237,234 @@ def test_set_halt_disables_return_to_start(queuemanager_mock): assert queue.return_to_start is False +def test_set_halt_disables_return_to_start_for_direct_instruction_queue(queuemanager_mock): + queue_manager = queuemanager_mock() + queue_manager.queues["primary"].active_instruction_queue = DirectInstructionQueueItem( + queue_manager.queues["primary"], mock.MagicMock(), mock.MagicMock() + ) + queue_manager.queues["primary"].active_instruction_queue.run_on_exception_hook = True + with mock.patch.object(queue_manager, "set_abort") as set_abort: + queue = queue_manager.queues["primary"].active_instruction_queue + queue_manager.set_halt(scan_id="dummy", parameter={}) + set_abort.assert_called_once_with( + scan_id="dummy", queue="primary", exit_info=("halted", "user") + ) + assert queue.run_on_exception_hook is False + + +def test_direct_instruction_queue_run_on_exception_hook_uses_scan_info(queuemanager_mock): + queue_manager = queuemanager_mock() + queue = DirectInstructionQueueItem( + queue_manager.queues["primary"], mock.MagicMock(), mock.MagicMock() + ) + scan = _build_dummy_v4_scan(scan_id="scan-id-test") + scan.scan_info.run_on_exception_hook = False + queue.active_scan = scan + + assert queue.run_on_exception_hook is False + + scan.scan_info.run_on_exception_hook = True + + assert queue.run_on_exception_hook is True + + +def test_direct_instruction_queue_status_updates_worker_and_sends_queue_status(queuemanager_mock): + queue_manager = queuemanager_mock() + worker = mock.MagicMock() + queue = DirectInstructionQueueItem(queue_manager.queues["primary"], mock.MagicMock(), worker) + queue.stop = mock.MagicMock() + queue_manager.send_queue_status = mock.MagicMock() + + queue.status = InstructionQueueStatus.RUNNING + queue.status = InstructionQueueStatus.STOPPED + + assert worker.status == InstructionQueueStatus.STOPPED + queue.stop.assert_called_once_with() + assert queue_manager.send_queue_status.call_count == 2 + + +def test_direct_instruction_queue_append_scan_request_assembles_and_stores_scan(queuemanager_mock): + queue_manager = queuemanager_mock() + assembler = mock.MagicMock() + queue = DirectInstructionQueueItem( + queue_manager.queues["primary"], assembler, queue_manager.queues["primary"].scan_worker + ) + msg = messages.ScanQueueMessage( + scan_type="mv", + parameter={"args": {"samx": (1,)}, "kwargs": {}}, + queue="primary", + metadata={"RID": "rid-1"}, + ) + scan = _build_dummy_v4_scan("scan-id-test") + assembler.assemble_direct_scan.return_value = scan + + queue.append_scan_request(msg) + + assembler.assemble_direct_scan.assert_called_once_with(msg, scan_id=queue._scan_id) + assert queue.scans == [scan] + assert queue.scan_msgs == [msg] + + +def test_direct_instruction_queue_describe_active_scan_returns_none_when_missing(queuemanager_mock): + queue_manager = queuemanager_mock() + queue = DirectInstructionQueueItem( + queue_manager.queues["primary"], + mock.MagicMock(), + queue_manager.queues["primary"].scan_worker, + ) + scan = _build_dummy_v4_scan("scan-id-test") + + assert queue.describe_active_scan() is None + + queue.active_scan = scan + + assert queue.describe_active_scan() is None + + +def test_direct_instruction_queue_describe_active_scan_returns_request_block(queuemanager_mock): + queue_manager = queuemanager_mock() + queue = DirectInstructionQueueItem( + queue_manager.queues["primary"], + mock.MagicMock(), + queue_manager.queues["primary"].scan_worker, + ) + scan = _build_dummy_v4_scan("scan-id-test") + scan.scan_info.readout_priority_modification = {"monitored": ["samx"]} + scan.scan_info.scan_report_instructions = [{"device": "samx"}] + msg = messages.ScanQueueMessage( + scan_type="mv", + parameter={"args": {"samx": (1,)}, "kwargs": {}}, + queue="primary", + metadata={"RID": "rid-1"}, + ) + queue.scans = [scan] + queue.scan_msgs = [msg] + queue.active_scan = scan + + info = queue.describe_active_scan() + + assert info.msg == msg + assert info.RID == "rid-1" + assert info.report_instructions == [{"device": "samx"}] + assert info.scan_id == "scan-id-test" + + +def test_direct_instruction_queue_move_to_next_scan_activates_and_assigns_numbers( + queuemanager_mock, +): + queue_manager = queuemanager_mock() + scan_queue = queue_manager.queues["primary"] + queue = DirectInstructionQueueItem(scan_queue, mock.MagicMock(), scan_queue.scan_worker) + scan_queue.queue.append(queue) + first_scan = _build_dummy_v4_scan("scan-1", scan_number=None) + second_scan = _build_dummy_v4_scan("scan-2", scan_number=None) + second_scan.scan_info.metadata["dataset_id_on_hold"] = True + msg1 = messages.ScanQueueMessage( + scan_type="mv", + parameter={"args": {"samx": (1,)}, "kwargs": {}}, + queue="primary", + metadata={"RID": "rid-1"}, + ) + msg2 = messages.ScanQueueMessage( + scan_type="mv", + parameter={"args": {"samx": (2,)}, "kwargs": {}}, + queue="primary", + metadata={"RID": "rid-2", "dataset_id_on_hold": True}, + ) + queue.scans = [first_scan, second_scan] + queue.scan_msgs = [msg1, msg2] + + active_scan = queue.move_to_next_scan() + + assert active_scan is first_scan + assert queue.active_request_block is first_scan + assert queue.status == InstructionQueueStatus.RUNNING + assert first_scan.scan_info.scan_number is not None + first_dataset_number = first_scan.scan_info.dataset_number + + active_scan = queue.move_to_next_scan() + + assert active_scan is second_scan + assert second_scan.scan_info.scan_number is not None + assert second_scan.scan_info.dataset_number == first_dataset_number + + +def test_direct_instruction_queue_move_to_next_scan_raises_when_empty_or_exhausted( + queuemanager_mock, +): + queue_manager = queuemanager_mock() + queue = DirectInstructionQueueItem( + queue_manager.queues["primary"], + mock.MagicMock(), + queue_manager.queues["primary"].scan_worker, + ) + + with pytest.raises(StopIteration, match="No active scan and no scans"): + queue.move_to_next_scan() + + scan = _build_dummy_v4_scan("scan-id-test") + msg = messages.ScanQueueMessage( + scan_type="mv", + parameter={"args": {"samx": (1,)}, "kwargs": {}}, + queue="primary", + metadata={"RID": "rid-1"}, + ) + queue.scans = [scan] + queue.scan_msgs = [msg] + queue.active_scan = scan + + with pytest.raises(StopIteration, match="No more scans"): + queue.move_to_next_scan() + + +def test_direct_instruction_queue_append_to_queue_history_pushes_message(queuemanager_mock): + queue_manager = queuemanager_mock() + connector = mock.MagicMock() + queue_manager.connector = connector + queue = DirectInstructionQueueItem( + queue_manager.queues["primary"], + mock.MagicMock(), + queue_manager.queues["primary"].scan_worker, + ) + queue.status = InstructionQueueStatus.COMPLETED + + queue.append_to_queue_history() + + connector.lpush.assert_called_once() + endpoint, msg = connector.lpush.call_args.args[:2] + assert endpoint == MessageEndpoints.scan_queue_history() + assert msg.status == "COMPLETED" + assert msg.queue_id == queue.queue_id + assert connector.lpush.call_args.kwargs["max_size"] == 100 + + +def test_direct_instruction_queue_stop_and_abort_update_internal_state(queuemanager_mock): + queue_manager = queuemanager_mock() + queue = DirectInstructionQueueItem( + queue_manager.queues["primary"], + mock.MagicMock(), + queue_manager.queues["primary"].scan_worker, + ) + first_scan = _build_dummy_v4_scan("scan-1") + second_scan = _build_dummy_v4_scan("scan-2") + first_scan._shutdown_event = mock.MagicMock() + second_scan._shutdown_event = mock.MagicMock() + queue.scans = [first_scan, second_scan] + queue.scan_msgs = [mock.MagicMock(), mock.MagicMock()] + queue.active_scan = first_scan + + queue.stop() + + first_scan._shutdown_event.set.assert_called_once_with() + second_scan._shutdown_event.set.assert_called_once_with() + + queue.abort() + + assert queue.active_scan is None + assert queue.scans == [] + assert queue.scan_msgs == [] + + def wait_to_reach_state(queue_manager, queue, state): while queue_manager.queues[queue].status != state: pass @@ -568,6 +817,42 @@ def test_request_block_scan_number(scan_server_mock, scan_queue_msg): assert request_block.scan_number == 5 +def test_direct_instruction_queue_item_scan_number_projection_within_item(queuemanager_mock): + queue_manager = queuemanager_mock() + scan_queue = queue_manager.queues["primary"] + base_scan_number = queue_manager.parent.scan_number + instruction_queue = DirectInstructionQueueItem( + scan_queue, mock.MagicMock(), scan_queue.scan_worker + ) + scan_queue.queue.append(instruction_queue) + + scan1 = _build_dummy_v4_scan("scan-1") + scan2 = _build_dummy_v4_scan("scan-2") + + instruction_queue.scans = [scan1, scan2] + + assert instruction_queue.scan_number == [base_scan_number + 1, base_scan_number + 2] + + +def test_direct_instruction_queue_item_scan_number_projection_across_queue_items(queuemanager_mock): + queue_manager = queuemanager_mock() + scan_queue = queue_manager.queues["primary"] + base_scan_number = queue_manager.parent.scan_number + + first_queue = DirectInstructionQueueItem(scan_queue, mock.MagicMock(), scan_queue.scan_worker) + second_queue = DirectInstructionQueueItem(scan_queue, mock.MagicMock(), scan_queue.scan_worker) + scan_queue.queue.extend([first_queue, second_queue]) + + first_scan = _build_dummy_v4_scan("scan-1") + second_scan = _build_dummy_v4_scan("scan-2") + + first_queue.scans = [first_scan] + second_queue.scans = [second_scan] + + assert first_queue.scan_number == [base_scan_number + 1] + assert second_queue.scan_number == [base_scan_number + 2] + + def test_remove_queue_item(queuemanager_mock): queue_manager = queuemanager_mock() msg = messages.ScanQueueMessage( diff --git a/bec_server/tests/tests_scan_server/test_scan_server_scan_manager.py b/bec_server/tests/tests_scan_server/test_scan_server_scan_manager.py index 293ac4110..5ff3d6ac9 100644 --- a/bec_server/tests/tests_scan_server/test_scan_server_scan_manager.py +++ b/bec_server/tests/tests_scan_server/test_scan_server_scan_manager.py @@ -1,3 +1,4 @@ +from typing import Annotated from unittest import mock import pytest @@ -32,6 +33,8 @@ def scan_manager(): ({"a": Device}, {"a": "DeviceBase"}), ({"a": Positioner}, {"a": "DeviceBase"}), ({"a": DeviceBase | str}, {"a": ["DeviceBase", "str"]}), + ({"a": Annotated[float, "device"]}, {"a": "float"}), + ({"a": list[float]}, {"a": {"Generic": {"origin": "list", "args": ["float"]}}}), ], ) def test_scan_manager_convert_arg_input(scan_manager, arg_input, arg_output): diff --git a/bec_server/tests/tests_scan_server/test_scan_worker.py b/bec_server/tests/tests_scan_server/test_scan_worker.py index 74b16036b..3efdd0f6a 100644 --- a/bec_server/tests/tests_scan_server/test_scan_worker.py +++ b/bec_server/tests/tests_scan_server/test_scan_worker.py @@ -1,834 +1,90 @@ # pylint: skip-file -import os -import uuid +from types import SimpleNamespace from unittest import mock import pytest -from bec_lib import messages -from bec_lib.endpoints import MessageEndpoints -from bec_lib.tests.utils import ConnectorMock -from bec_server.scan_server.errors import ScanAbortion, UserScanInterruption -from bec_server.scan_server.scan_assembler import ScanAssembler -from bec_server.scan_server.scan_queue import ( - InstructionQueueItem, - InstructionQueueStatus, - QueueManager, - RequestBlock, - RequestBlockQueue, - ScanQueue, -) +from bec_lib.tests.fixtures import dm_with_devices +from bec_server.scan_server.direct_scan_worker import DirectScanWorker +from bec_server.scan_server.errors import ScanAbortion +from bec_server.scan_server.generator_scan_worker import GeneratorScanWorker +from bec_server.scan_server.scan_queue import DirectInstructionQueueItem, InstructionQueueItem from bec_server.scan_server.scan_worker import ScanWorker -from bec_server.scan_server.tests.fixtures import scan_server_mock @pytest.fixture -def scan_worker_mock(scan_server_mock) -> ScanWorker: - scan_server_mock.device_manager.connector = mock.MagicMock() - scan_worker = ScanWorker(parent=scan_server_mock) - yield scan_worker - - -class RequestBlockQueueMock(RequestBlockQueue): - request_blocks = [] - _scan_id = [] - - @property - def scan_id(self): - return self._scan_id - - def append(self, msg): - pass - - -class InstructionQueueMock(InstructionQueueItem): - def __init__(self, parent: ScanQueue, assembler: ScanAssembler, worker: ScanWorker) -> None: - super().__init__(parent, assembler, worker) - self.queue = RequestBlockQueueMock(self, assembler) - # self.queue.active_rb = [] - self.idx = 1 - - def append_scan_request(self, msg): - self.scan_msgs.append(msg) - self.queue.append(msg) - - def __next__(self): - if ( - self.status - in [ - InstructionQueueStatus.RUNNING, - InstructionQueueStatus.DEFERRED_PAUSE, - InstructionQueueStatus.PENDING, - ] - and self.idx < 5 - ): - self.idx += 1 - return "instr_status" - - else: - raise StopIteration - - -def test_wait_for_device_server(scan_worker_mock): - worker = scan_worker_mock - with mock.patch.object(worker.parent, "wait_for_service") as service_mock: - worker._wait_for_device_server() - service_mock.assert_called_once_with("DeviceServer") - - -def test_publish_data_as_read(scan_worker_mock): - worker = scan_worker_mock - instr = messages.DeviceInstructionMessage( - device=["samx"], - action="publish_data_as_read", - parameter={"data": {}}, - metadata={ - "readout_priority": "monitored", - "DIID": 3, - "scan_id": "scan_id", - "RID": "requestID", - }, +def scan_worker_mock(dm_with_devices) -> ScanWorker: + parent = SimpleNamespace( + device_manager=dm_with_devices, + connector=mock.MagicMock(), + queue_manager=SimpleNamespace(queues={}), ) - with mock.patch.object(worker.device_manager, "connector") as connector_mock: - worker.publish_data_as_read(instr) - msg = messages.DeviceMessage( - signals=instr.content["parameter"]["data"], metadata=instr.metadata - ) - connector_mock.set_and_publish.assert_called_once_with( - MessageEndpoints.device_read("samx"), msg - ) - - -def test_publish_data_as_read_multiple(scan_worker_mock): - worker = scan_worker_mock - data = [{"samx": {}}, {"samy": {}}] - devices = ["samx", "samy"] - instr = messages.DeviceInstructionMessage( - device=devices, - action="publish_data_as_read", - parameter={"data": data}, - metadata={ - "readout_priority": "monitored", - "DIID": 3, - "scan_id": "scan_id", - "RID": "requestID", - }, - ) - with mock.patch.object(worker.device_manager, "connector") as connector_mock: - worker.publish_data_as_read(instr) - mock_calls = [] - for device, dev_data in zip(devices, data): - msg = messages.DeviceMessage(signals=dev_data, metadata=instr.metadata) - mock_calls.append(mock.call(MessageEndpoints.device_read(device), msg)) - assert connector_mock.set_and_publish.mock_calls == mock_calls - - -def test_check_for_interruption(scan_worker_mock): - worker = scan_worker_mock - worker.status = InstructionQueueStatus.STOPPED - with pytest.raises(ScanAbortion) as exc_info: - worker._check_for_interruption() - - -@pytest.mark.parametrize( - "instr, corr_num_points, scan_id", - [ - ( - messages.DeviceInstructionMessage( - device=None, - action="open_scan", - parameter={"num_points": 150}, - metadata={ - "readout_priority": "monitored", - "DIID": 18, - "scan_id": "12345", - "scan_def_id": 100, - "point_id": 50, - "RID": 11, - }, - ), - 201, - False, - ), - ( - messages.DeviceInstructionMessage( - device=None, - action="open_scan", - parameter={"num_points": 150}, - metadata={ - "readout_priority": "monitored", - "DIID": 18, - "scan_id": "12345", - "RID": 11, - }, - ), - 150, - True, - ), - ], -) -def test_open_scan(scan_worker_mock, instr, corr_num_points, scan_id): - worker = scan_worker_mock - - if not scan_id: - assert worker.scan_id == None - else: - worker.scan_id = 111 - - if "point_id" in instr.metadata: - worker.max_point_id = instr.metadata["point_id"] - - assert worker.parent.connector.get(MessageEndpoints.scan_number()) == None - - with mock.patch.object(worker, "current_instruction_queue_item") as queue_mock: - with mock.patch.object(worker, "_initialize_scan_info") as init_mock: - with mock.patch.object(worker.scan_report_instructions, "append") as instr_append_mock: - with mock.patch.object(worker, "_send_scan_status") as send_mock: - with mock.patch.object( - worker.current_instruction_queue_item.parent.queue_manager, - "send_queue_status", - ) as queue_status_mock: - active_rb = queue_mock.active_request_block - active_rb.scan_report_instructions = [] - active_rb.scan.show_live_table = True - worker.open_scan(instr) - - if not scan_id: - assert worker.scan_id == instr.metadata.get("scan_id") - else: - assert worker.scan_id == 111 - init_mock.assert_called_once_with(active_rb, instr, corr_num_points) - assert active_rb.scan_report_instructions == [ - {"scan_progress": {"points": corr_num_points, "show_table": True}} - ] - queue_status_mock.assert_called_once() - send_mock.assert_called_once_with("open") - - -@pytest.mark.parametrize( - "msg", - [ - messages.ScanQueueMessage( - scan_type="grid_scan", - parameter={ - "args": {"samx": (-5, 5, 5), "samy": (-1, 1, 2)}, - "kwargs": { - "exp_time": 1, - "relative": True, - "system_config": {"file_suffix": None, "file_directory": None}, - }, - "num_points": 10, - }, - queue="primary", - metadata={ - "RID": "something", - "system_config": {"file_suffix": None, "file_directory": None}, - }, - ), - messages.ScanQueueMessage( - scan_type="grid_scan", - parameter={ - "args": {"samx": (-5, 5, 5), "samy": (-1, 1, 2)}, - "kwargs": { - "exp_time": 1, - "relative": True, - "system_config": {"file_suffix": "test", "file_directory": "tmp"}, - }, - "num_points": 10, - }, - queue="primary", - metadata={ - "RID": "something", - "system_config": {"file_suffix": "test", "file_directory": "tmp"}, - }, - ), - messages.ScanQueueMessage( - scan_type="grid_scan", - parameter={ - "args": {"samx": (-5, 5, 5), "samy": (-1, 1, 2)}, - "kwargs": { - "exp_time": 1, - "relative": True, - "system_config": {"file_suffix": "test", "file_directory": None}, - }, - "num_points": 10, - }, - queue="primary", - metadata={ - "RID": "something", - "system_config": {"file_suffix": "test", "file_directory": None}, - }, - ), - ], -) -def test_initialize_scan_info(scan_worker_mock, msg): - worker = scan_worker_mock - scan_server = scan_worker_mock.parent - rb = RequestBlock(msg, assembler=ScanAssembler(parent=scan_server), parent=mock.MagicMock()) - assert rb.metadata == msg.metadata - - with mock.patch.object(worker, "current_instruction_queue_item"): - worker.readout_priority = { - "monitored": ["samx"], - "baseline": [], - "async": [], - "continuous": [], - "on_request": [], - } - open_scan_msg = list(rb.scan.open_scan())[0] - worker._initialize_scan_info(rb, open_scan_msg, msg.content["parameter"].get("num_points")) - - assert worker.current_scan_info["RID"] == "something" - assert worker.current_scan_info["scan_number"] == 2 - assert worker.current_scan_info["dataset_number"] == 3 - assert worker.current_scan_info["scan_report_devices"] == rb.scan.scan_report_devices - assert worker.current_scan_info["num_points"] == 10 - assert worker.current_scan_info["scan_msgs"] == [] - assert worker.current_scan_info["monitor_sync"] == "bec" - assert worker.current_scan_info["frames_per_trigger"] == 1 - assert worker.current_scan_info["args"] == {"samx": (-5, 5, 5), "samy": (-1, 1, 2)} - assert worker.current_scan_info["kwargs"] == msg.parameter["kwargs"] - assert "samx" in worker.current_scan_info["readout_priority"]["monitored"] - assert "samy" in worker.current_scan_info["readout_priority"]["baseline"] - - base_path = worker.parent._service_config.config["file_writer"]["base_path"] - file_dir = msg.parameter["kwargs"]["system_config"]["file_directory"] - suffix = msg.parameter["kwargs"]["system_config"]["file_suffix"] - if file_dir is None: - if suffix is None: - file_dir = "S00000-00999/S00002" - else: - file_dir = f"S00000-00999/S00002_{suffix}" - file_components = os.path.abspath(os.path.join(base_path, file_dir, "S00002")), "h5" - assert worker.current_scan_info["file_components"] == file_components - + scan_worker = ScanWorker(parent=parent) + yield scan_worker -@pytest.mark.parametrize( - "msg,scan_id,max_point_id,exp_num_points", - [ - ( - messages.DeviceInstructionMessage( - device=None, - action="close_scan", - parameter={}, - metadata={"readout_priority": "monitored", "DIID": 18, "scan_id": "12345"}, - ), - "12345", - 19, - 20, - ), - ( - messages.DeviceInstructionMessage( - device=None, - action="close_scan", - parameter={}, - metadata={"readout_priority": "monitored", "DIID": 18, "scan_id": "12345"}, - ), - "0987", - 200, - 19, - ), - ], -) -def test_close_scan(scan_worker_mock, msg, scan_id, max_point_id, exp_num_points): - worker = scan_worker_mock - worker.scan_id = scan_id - worker.current_scan_info["num_points"] = 19 - reset = bool(worker.scan_id == msg.metadata["scan_id"]) - with mock.patch.object(worker, "_send_scan_status") as send_scan_status_mock: - worker.close_scan(msg, max_point_id=max_point_id) - if reset: - send_scan_status_mock.assert_called_with("closed") - assert worker.scan_id == None - else: - assert worker.scan_id == scan_id - assert worker.current_scan_info["num_points"] == exp_num_points +def test_get_worker_for_instruction_queue_item(scan_worker_mock): + queue = InstructionQueueItem.__new__(InstructionQueueItem) + worker = scan_worker_mock.get_worker_for_queue(queue) -# @pytest.mark.parametrize( -# "msg", -# [ -# messages.DeviceInstructionMessage( -# device=None, -# action="stage", -# parameter={}, -# metadata={"readout_priority": "async", "DIID": 18, "scan_id": "12345"}, -# ) -# ], -# ) -# def test_stage_device(scan_worker_mock, msg): -# worker = scan_worker_mock -# worker.device_manager.devices["eiger"]._config["readoutPriority"] = "async" -# worker.device_manager.devices["flyer_sim"]._config["readoutPriority"] = "on_request" + assert isinstance(worker, GeneratorScanWorker) -# with mock.patch.object(worker, "_wait_for_stage") as wait_mock: -# with mock.patch.object(worker.device_manager.connector, "send") as send_mock: -# worker.stage_devices(msg) -# on_request_device_names = [ -# dev.name for dev in worker.device_manager.devices.on_request_devices() -# ] -# async_devices = worker.device_manager.devices.async_devices() -# async_device_names = [dev.name for dev in async_devices] -# excluded_devices = async_devices -# excluded_devices.extend(worker.device_manager.devices.on_request_devices()) -# excluded_devices.extend(worker.device_manager.devices.continuous_devices()) -# devices = [ -# dev.name -# for dev in worker.device_manager.devices.enabled_devices -# if dev not in excluded_devices -# ] +def test_get_worker_for_direct_instruction_queue_item(scan_worker_mock): + queue = DirectInstructionQueueItem.__new__(DirectInstructionQueueItem) -# for dev in [ -# *worker.device_manager.devices.monitored_devices(), -# *worker.device_manager.devices.baseline_devices(), -# *worker.device_manager.devices.async_devices(), -# ]: -# assert dev.name in worker._staged_devices -# for async_dev in async_devices: -# assert ( -# mock.call( -# MessageEndpoints.device_instructions(), -# messages.DeviceInstructionMessage( -# device=async_dev.name, -# action="stage", -# parameter=msg.content["parameter"], -# metadata=msg.metadata, -# ), -# ) -# in send_mock.mock_calls -# ) -# assert ( -# mock.call( -# MessageEndpoints.device_instructions(), -# messages.DeviceInstructionMessage( -# device=devices, -# action="stage", -# parameter=msg.content["parameter"], -# metadata=msg.metadata, -# ), -# ) -# in send_mock.mock_calls -# ) -# assert ( -# mock.call(staged=True, devices=devices, metadata=msg.metadata) -# in wait_mock.mock_calls -# ) -# assert ( -# mock.call(staged=True, devices=async_device_names, metadata=msg.metadata) -# in wait_mock.mock_calls -# ) -# for dev in on_request_device_names: -# assert dev not in worker._staged_devices + worker = scan_worker_mock.get_worker_for_queue(queue) + assert isinstance(worker, DirectScanWorker) -@pytest.mark.parametrize("status,expire", [("open", None), ("closed", 1800), ("aborted", 1800)]) -def test_send_scan_status(scan_worker_mock, status, expire): - worker = scan_worker_mock - worker.device_manager.connector = ConnectorMock() - worker.current_scan_id = str(uuid.uuid4()) - worker.current_scan_info = {"scan_number": 5} - worker._send_scan_status(status) - scan_info_msgs = [ - msg - for msg in worker.device_manager.connector.message_sent - if msg["queue"] - == MessageEndpoints.public_scan_info(scan_id=worker.current_scan_id).endpoint - ] - assert len(scan_info_msgs) == 1 - assert scan_info_msgs[0]["expire"] == expire +def test_run_delegates_to_selected_worker(scan_worker_mock): + queue = mock.MagicMock() + queue.stopped = False -@pytest.mark.parametrize("abortion", [False, True]) -def test_process_instructions(scan_worker_mock, abortion): - worker = scan_worker_mock - scan_server = scan_worker_mock.parent - scan_queue = ScanQueue(QueueManager(scan_server)) - queue = InstructionQueueMock( - parent=scan_queue, assembler=ScanAssembler(parent=scan_server), worker=worker - ) + delegated_worker = mock.MagicMock() - with mock.patch.object(worker, "_wait_for_device_server") as wait_mock: - with mock.patch.object(worker, "reset") as reset_mock: - with mock.patch.object(worker, "_check_for_interruption") as interruption_mock: - queue.queue.request_blocks.append(mock.MagicMock()) - with mock.patch.object(queue.queue, "active_rb") as rb_mock: - with mock.patch.object(worker, "_instruction_step") as step_mock: - if abortion: - interruption_mock.side_effect = ScanAbortion - with pytest.raises(ScanAbortion) as exc_info: - worker._process_instructions(queue) - else: - worker._process_instructions(queue) + def _process(_queue): + scan_worker_mock.signal_event.set() - assert worker.max_point_id == 0 - wait_mock.assert_called_once() + delegated_worker.process_instructions.side_effect = _process - if not abortion: - assert interruption_mock.call_count == 4 - assert worker._exposure_time == getattr(rb_mock.scan, "exp_time", None) - assert step_mock.call_count == 4 - assert queue.is_active == False - assert queue.status == InstructionQueueStatus.COMPLETED - assert worker.current_instruction_queue_item == None - reset_mock.assert_called_once() + with mock.patch.object( + scan_worker_mock, "get_worker_for_queue", return_value=delegated_worker + ) as get_worker: + scan_worker_mock.parent.queue_manager.queues[scan_worker_mock.queue_name] = [queue] - else: - assert queue.stopped == True - assert interruption_mock.call_count == 1 - assert queue.is_active == True - assert queue.status == InstructionQueueStatus.PENDING - assert worker.current_instruction_queue_item == queue + scan_worker_mock.run() + get_worker.assert_called_once_with(queue) + delegated_worker.process_instructions.assert_called_once_with(queue) + queue.append_to_queue_history.assert_called_once() -@pytest.mark.parametrize( - "msg,method", - [ - ( - messages.DeviceInstructionMessage( - device=None, - action="open_scan", - parameter={"readout_priority": {"monitored": [], "baseline": [], "on_request": []}}, - metadata={"readout_priority": "monitored", "scan_id": "12345"}, - ), - "open_scan", - ), - ( - messages.DeviceInstructionMessage( - device=None, - action="close_scan", - parameter={}, - metadata={"readout_priority": "monitored", "scan_id": "12345"}, - ), - "close_scan", - ), - ( - messages.DeviceInstructionMessage( - device=None, - action="trigger", - parameter={"group": "trigger"}, - metadata={"readout_priority": "monitored", "point_id": 0}, - ), - "forward_instruction", - ), - ( - messages.DeviceInstructionMessage( - device="samx", - action="set", - parameter={"value": 1.3681828686580249}, - metadata={"readout_priority": "monitored"}, - ), - "forward_instruction", - ), - ( - messages.DeviceInstructionMessage( - device=None, - action="read", - parameter={"group": "monitored"}, - metadata={"readout_priority": "monitored", "point_id": 1}, - ), - "forward_instruction", - ), - ( - messages.DeviceInstructionMessage( - device=None, - action="stage", - parameter={}, - metadata={"readout_priority": "monitored"}, - ), - "forward_instruction", - ), - ( - messages.DeviceInstructionMessage( - device=None, - action="unstage", - parameter={}, - metadata={"readout_priority": "monitored"}, - ), - "forward_instruction", - ), - ( - messages.DeviceInstructionMessage( - device="samx", - action="rpc", - parameter={ - "device": "lsamy", - "func": "readback.get", - "rpc_id": "61a7376c-36cf-41af-94b1-76c1ba821d47", - "args": [], - "kwargs": {}, - }, - metadata={"readout_priority": "monitored"}, - ), - "forward_instruction", - ), - ( - messages.DeviceInstructionMessage( - device="samx", action="kickoff", parameter={}, metadata={} - ), - "forward_instruction", - ), - ( - messages.DeviceInstructionMessage( - device=None, - action="baseline_reading", - parameter={}, - metadata={"readout_priority": "baseline"}, - ), - "forward_instruction", - ), - ( - messages.DeviceInstructionMessage(device=None, action="close_scan_def", parameter={}), - "close_scan", - ), - ( - messages.DeviceInstructionMessage( - device=None, action="publish_data_as_read", parameter={} - ), - "publish_data_as_read", - ), - ( - messages.DeviceInstructionMessage( - device=None, action="scan_report_instruction", parameter={} - ), - "process_scan_report_instruction", - ), - ( - messages.DeviceInstructionMessage(device=None, action="pre_scan", parameter={}), - "forward_instruction", - ), - ( - messages.DeviceInstructionMessage(device=None, action="complete", parameter={}), - "forward_instruction", - ), - ], -) -def test_instruction_step(scan_worker_mock, msg, method): - worker = scan_worker_mock - with mock.patch( - f"bec_server.scan_server.scan_worker.ScanWorker.{method}" - ) as instruction_method: - with mock.patch.object(worker, "update_instr_with_scan_report") as update_mock: - worker._instruction_step(msg) - instruction_method.assert_called_once() - if method == "set": - update_mock.assert_called_once_with(msg) +def test_run_delegates_scan_abortion_handling_to_selected_worker(scan_worker_mock): + queue = mock.MagicMock() + delegated_worker = mock.MagicMock() + delegated_worker.process_instructions.side_effect = ScanAbortion() -def test_reset(scan_worker_mock): - worker = scan_worker_mock - worker.current_scan_id = 1 - worker.current_scan_info = 1 - worker.scan_id = 1 - worker.interception_msg = 1 + def _handle(_queue, _exc): + scan_worker_mock.signal_event.set() - worker.reset() + delegated_worker._handle_scan_abortion.side_effect = _handle - assert worker.current_scan_id == "" - assert worker.current_scan_info == {} - assert worker.scan_id == None - assert worker.interception_msg == None + with mock.patch.object(scan_worker_mock, "get_worker_for_queue", return_value=delegated_worker): + scan_worker_mock.parent.queue_manager.queues[scan_worker_mock.queue_name] = [queue] + scan_worker_mock.run() -def test_cleanup(scan_worker_mock): - worker = scan_worker_mock - with mock.patch.object(worker, "forward_instruction") as forward_mock: - worker.cleanup() - sent_message = forward_mock.mock_calls[0].args[0] - diid = sent_message.metadata["device_instr_id"] - devices = sent_message.device - msg = messages.DeviceInstructionMessage( - device=devices, action="unstage", parameter={}, metadata={"device_instr_id": diid} - ) - forward_mock.assert_called_once_with(msg) + delegated_worker._handle_scan_abortion.assert_called_once() def test_shutdown(scan_worker_mock): - worker = scan_worker_mock - with mock.patch.object(worker.signal_event, "set") as set_mock: - worker._started = mock.MagicMock() - worker._started.is_set.return_value = True - with mock.patch.object(worker, "join") as join_mock: - worker.shutdown() + with mock.patch.object(scan_worker_mock.signal_event, "set") as set_mock: + scan_worker_mock._started = mock.MagicMock() + scan_worker_mock._started.is_set.return_value = True + with mock.patch.object(scan_worker_mock, "join") as join_mock: + scan_worker_mock.shutdown() set_mock.assert_called_once() join_mock.assert_called_once() - - -@pytest.mark.parametrize( - "msg", - [ - messages.DeviceInstructionMessage( - device=["samx"], action="set", parameter={"value": 1}, metadata={"scan_id": "scan_id"} - ) - ], -) -def test_worker_update_instr_with_scan_report_no_update(msg, scan_worker_mock): - worker = scan_worker_mock - worker.current_instruction_queue_item = mock.MagicMock(spec=InstructionQueueItem) - arb = worker.current_instruction_queue_item.active_request_block = mock.MagicMock( - spec=RequestBlock - ) - arb.scan_report_instructions = [] - with mock.patch.object(worker, "forward_instruction") as forward_mock: - worker._instruction_step(msg) - worker.update_instr_with_scan_report(msg) - forward_mock.assert_called_once_with(msg) - - -@pytest.mark.parametrize( - "msg", - [ - messages.DeviceInstructionMessage( - device=["samx"], action="set", parameter={"value": 1}, metadata={"scan_id": "scan_id"} - ) - ], -) -def test_worker_update_instr_with_scan_report_no_update_with_report(msg, scan_worker_mock): - worker = scan_worker_mock - worker.current_instruction_queue_item = mock.MagicMock(spec=InstructionQueueItem) - arb = worker.current_instruction_queue_item.active_request_block = mock.MagicMock( - spec=RequestBlock - ) - arb.scan_report_instructions = [{"scan_progress": {"points": 10, "show_table": True}}] - with mock.patch.object(worker, "forward_instruction") as forward_mock: - worker._instruction_step(msg) - worker.update_instr_with_scan_report(msg) - forward_mock.assert_called_once_with(msg) - - -@pytest.mark.parametrize( - "msg", - [ - messages.DeviceInstructionMessage( - device=["samx"], action="set", parameter={"value": 1}, metadata={"scan_id": "scan_id"} - ) - ], -) -def test_worker_update_instr_with_scan_report_update(msg, scan_worker_mock): - worker = scan_worker_mock - worker.current_instruction_queue_item = mock.MagicMock(spec=InstructionQueueItem) - arb = worker.current_instruction_queue_item.active_request_block = mock.MagicMock( - spec=RequestBlock - ) - arb.scan_report_instructions = [ - {"readback": {"RID": "rid", "devices": ["samx"], "start": [0], "end": [1]}} - ] - with mock.patch.object(worker, "forward_instruction") as forward_mock: - worker._instruction_step(msg) - worker.update_instr_with_scan_report(msg) - forward_mock.assert_called_once_with(msg) - assert msg.metadata["response"] is True - - -@pytest.mark.parametrize( - "base_path, current_account_msg, expected_path, raises_error", - [ - ( - "/data/$account/raw", - messages.VariableMessage(value="test_account"), - "/data/test_account/raw", - False, - ), - ("/data/$account/raw", None, "/data/raw", False), - ( - "/data/raw", - messages.VariableMessage(value="test_account"), - "/data/raw/test_account", - False, - ), - ("/data/raw", None, "/data/raw", False), - ( - "/data/$account/$sub_dir/raw", - messages.VariableMessage(value="test_account"), - "/data/test_account/$sub_dir/raw", - True, - ), - ], -) -def test_worker_get_file_base_path( - scan_worker_mock, base_path, current_account_msg, expected_path, raises_error -): - worker = scan_worker_mock - file_writer_base_path_orig = worker.parent._service_config.config["file_writer"]["base_path"] - try: - worker.parent._service_config.config["file_writer"]["base_path"] = base_path - with mock.patch.object(worker.connector, "get_last", return_value=current_account_msg): - if raises_error: - with pytest.raises(ValueError) as exc_info: - worker._get_file_base_path() - else: - file_path = worker._get_file_base_path() - assert file_path == expected_path - worker.connector.get_last.assert_called_once_with( - MessageEndpoints.account(), "data" - ) - finally: - worker.parent._service_config.config["file_writer"][ - "base_path" - ] = file_writer_base_path_orig - - -@pytest.mark.parametrize( - "scan_info, out", - [ - (None, {}), - ({}, {}), - ({"scan_id": "12345"}, {"scan_id": "12345"}), - ({"scan_number": 1}, {"scan_number": 1}), - ({"scan_id": "12345", "scan_number": 1}, {"scan_id": "12345", "scan_number": 1}), - ], -) -def test_worker_get_metadata_for_alarm(scan_worker_mock, scan_info, out): - worker = scan_worker_mock - worker.current_scan_info = scan_info - metadata = worker._get_metadata_for_alarm() - assert metadata == out - - -def test_handle_scan_abortion(scan_worker_mock): - worker = scan_worker_mock - queue = mock.MagicMock(spec=InstructionQueueItem) - queue.exit_info = None - queue.queue_id = "id-12345" - with mock.patch.object(worker, "_send_scan_status") as send_status_mock: - abortion = ScanAbortion() - worker._handle_scan_abortion(queue, abortion) - send_status_mock.assert_called_once_with("aborted", reason="alarm") - - -def test_handle_scan_halt(scan_worker_mock): - worker = scan_worker_mock - queue = mock.MagicMock(spec=InstructionQueueItem) - queue.exit_info = None - queue.queue_id = "id-12345" - with mock.patch.object(worker, "_send_scan_status") as send_status_mock: - abortion = ScanAbortion() - queue.return_to_start = False - worker._handle_scan_abortion(queue, abortion) - send_status_mock.assert_called_once_with("halted", reason="alarm") - - -def test_handle_user_scan_interruption(scan_worker_mock): - worker = scan_worker_mock - queue = mock.MagicMock(spec=InstructionQueueItem) - queue.exit_info = None - queue.queue_id = "id-12345" - with mock.patch.object(worker, "_send_scan_status") as send_status_mock: - interruption = UserScanInterruption(exit_info=("user_completed", "user")) - worker._handle_scan_abortion(queue, interruption) - send_status_mock.assert_called_once_with("user_completed", reason="user") - - -def test_handle_user_scan_interruption_followed_by_abortion(scan_worker_mock): - worker = scan_worker_mock - queue = mock.MagicMock(spec=InstructionQueueItem) - queue.exit_info = "user_completed" - queue.queue_id = "id-12345" - with mock.patch.object(worker, "_send_scan_status") as send_status_mock: - interruption = UserScanInterruption(exit_info=("user_completed", "user")) - abortion = ScanAbortion() - worker._handle_scan_abortion(queue, interruption) - worker._handle_scan_abortion(queue, abortion) - send_status_mock.mock_calls[0].assert_called_with("user_completed", reason="user") - send_status_mock.mock_calls[1].assert_called_with("user_completed", reason="user")