From f9c54620684e048c29bde29903c420a7794ba30f Mon Sep 17 00:00:00 2001 From: wakonig_k Date: Sat, 16 May 2026 13:32:27 +0200 Subject: [PATCH] feat(bec_emitter, scan_bundler): forward device progress as scan progress --- .../bec_server/scan_bundler/bec_emitter.py | 114 ++++++++-- .../bec_server/scan_bundler/scan_bundler.py | 41 +++- .../tests_scan_bundler/test_bec_emitter.py | 207 +++++++++++++++++- 3 files changed, 335 insertions(+), 27 deletions(-) diff --git a/bec_server/bec_server/scan_bundler/bec_emitter.py b/bec_server/bec_server/scan_bundler/bec_emitter.py index 69d7ddb23..a2f8fd148 100644 --- a/bec_server/bec_server/scan_bundler/bec_emitter.py +++ b/bec_server/bec_server/scan_bundler/bec_emitter.py @@ -3,7 +3,7 @@ import threading import time from queue import Queue -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, cast from bec_lib import messages from bec_lib.endpoints import MessageEndpoints @@ -14,6 +14,8 @@ logger = bec_logger.logger if TYPE_CHECKING: + from bec_lib.redis_connector import MessageObject + from .scan_bundler import ScanBundler @@ -22,6 +24,7 @@ def __init__(self, scan_bundler: ScanBundler) -> None: super().__init__(scan_bundler.connector) self._send_buffer = Queue() self.scan_bundler = scan_bundler + self._device_progress_subscriptions: dict[str, dict[str, Any]] = {} self._buffered_connector_thread = None self._buffered_publisher_stop_event = threading.Event() self._start_buffered_connector() @@ -96,7 +99,8 @@ def _send_bec_scan_point(self, scan_id: str, point_id: int) -> None: MessageEndpoints.scan_segment(), MessageEndpoints.public_scan_segment(scan_id=scan_id, point_id=point_id), ) - self._update_scan_progress(scan_id, point_id) + if not self._has_device_progress_subscription(scan_id): + self._update_scan_progress(scan_id, point_id) def _update_scan_progress(self, scan_id: str, point_id: int, done=False) -> None: if scan_id not in self.scan_bundler.sync_storage: @@ -107,18 +111,36 @@ def _update_scan_progress(self, scan_id: str, point_id: int, done=False) -> None info = self.scan_bundler.sync_storage[scan_id]["info"] num_monitored_readouts = info.get("num_monitored_readouts", info.get("num_points", 0)) - + value = point_id + 1 + max_value = num_monitored_readouts or point_id + 1 + self.send_scan_progress(scan_id, value=value, max_value=max_value, done=done) + + def send_scan_progress(self, scan_id: str, value: float, max_value: float, done=False) -> None: + """ + Send a scan progress update. + + Args: + scan_id (str): The ID of the scan. + value (float): The current progress value. + max_value (float): The maximum progress value. + done (bool): Whether the scan is done. + """ + storage = self.scan_bundler.sync_storage.get(scan_id) + if not storage: + return + info = storage["info"] msg = messages.ProgressMessage( - value=point_id + 1, - max_value=num_monitored_readouts or point_id + 1, + value=value, + max_value=max_value, done=done, metadata={ "scan_id": scan_id, "RID": info.get("RID", ""), "queue_id": info.get("queue_id", ""), - "status": self.scan_bundler.sync_storage[scan_id]["status"], + "status": storage["status"], }, ) + storage["last_progress_sent"] = msg self.scan_bundler.connector.set_and_publish(MessageEndpoints.scan_progress(), msg) def _send_baseline(self, scan_id: str) -> None: @@ -141,29 +163,93 @@ def _send_baseline(self, scan_id: str) -> None: pipe.execute() def on_scan_status_update(self, status_msg: messages.ScanStatusMessage): + sb = self.scan_bundler + if status_msg.scan_id not in sb.sync_storage: + logger.warning( + f"Cannot update scan progress: Scan {status_msg.scan_id} not found in sync storage." + ) + return + if status_msg.status == "open": - # No need to update progress for an open scan. This is handled by the scan point emit. + # Update progress subscription: + # - If the scan report instruction contains "scan_progress", we simply emit + # progress updates as they come in. + # - If the scan report instruction contains "device_progress", we subscribe + # to the progress of the first device and use that as the progress for the whole scan. + self._update_device_progress_subscription(status_msg.scan_id) return num_points = max(status_msg.info.get("num_points", 0) - 1, 0) num_monitored_readouts = status_msg.info.get("num_monitored_readouts", num_points) if status_msg.status == "closed": - self._update_scan_progress(status_msg.scan_id, num_monitored_readouts, done=True) - return + if not self._has_device_progress_subscription(status_msg.scan_id): + self._update_scan_progress(status_msg.scan_id, num_monitored_readouts, done=True) + return - sb = self.scan_bundler - if status_msg.scan_id not in sb.sync_storage: - logger.warning( - f"Cannot update scan progress: Scan {status_msg.scan_id} not found in sync storage." - ) + self._unregister_device_progress_subscription(status_msg.scan_id) + self._emit_last_progress(status_msg.scan_id) return + + # Scan is not open or closed but instead in ["aborted", "halted", "user_completed"] storage = sb.sync_storage[status_msg.scan_id] + if self._has_device_progress_subscription(status_msg.scan_id): + self._unregister_device_progress_subscription(status_msg.scan_id) + self._emit_last_progress(status_msg.scan_id) + return sent_vals = storage.get("sent", {0}) or {0} max_point = max(sent_vals) self._update_scan_progress(status_msg.scan_id, max_point, done=True) + def on_cleanup(self, scan_id: str): + self._unregister_device_progress_subscription(scan_id) + def shutdown(self): if self._buffered_connector_thread: self._buffered_publisher_stop_event.set() self._buffered_connector_thread.join() self._buffered_connector_thread = None + + ############################################################# + ################# Device Progress Helpers ################### + ############################################################# + + def _update_device_progress_subscription(self, scan_id: str): + sb = self.scan_bundler + instructions = sb.scan_report_instructions.get(scan_id, []) + if self._has_device_progress_subscription(scan_id): + return + for instruction in instructions: + if "device_progress" in instruction: + device = instruction["device_progress"][0] + sub = { + "topics": MessageEndpoints.device_progress(device=device), + "cb": lambda msg_obj, _scan_id=scan_id: self._on_device_progress( + msg_obj, _scan_id + ), + } + self._device_progress_subscriptions[scan_id] = sub + self.connector.register(**sub) + return + + def _emit_last_progress(self, scan_id: str): + storage = self.scan_bundler.sync_storage.get(scan_id, {}) + msg = storage.get("last_progress_sent") + value = msg.value if msg else 0 + max_value = msg.max_value if msg else 0 + self.send_scan_progress(scan_id, value=value, max_value=max_value, done=True) + + def _on_device_progress(self, msg_obj: MessageObject, scan_id: str): + msg = cast(messages.ProgressMessage, msg_obj.value) + if msg.metadata.get("scan_id") != scan_id: + return + if msg.done: + self._unregister_device_progress_subscription(scan_id) + self.send_scan_progress(scan_id, value=msg.value, max_value=msg.max_value, done=msg.done) + + def _has_device_progress_subscription(self, scan_id: str) -> bool: + return scan_id in self._device_progress_subscriptions + + def _unregister_device_progress_subscription(self, scan_id: str) -> None: + sub_info = self._device_progress_subscriptions.pop(scan_id, None) + if sub_info: + self.connector.unregister(**sub_info) diff --git a/bec_server/bec_server/scan_bundler/scan_bundler.py b/bec_server/bec_server/scan_bundler/scan_bundler.py index 78fd7a0be..27ecd75d4 100644 --- a/bec_server/bec_server/scan_bundler/scan_bundler.py +++ b/bec_server/bec_server/scan_bundler/scan_bundler.py @@ -6,7 +6,7 @@ import traceback from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from bec_lib import messages from bec_lib.bec_service import BECService @@ -17,7 +17,7 @@ from .bec_emitter import BECEmitter if TYPE_CHECKING: - from bec_lib.redis_connector import RedisConnector + from bec_lib.redis_connector import MessageObject, RedisConnector logger = bec_logger.logger @@ -35,12 +35,13 @@ def __init__(self, config, connector_cls: type[RedisConnector]) -> None: name="device_read_register", ) self.connector.register(MessageEndpoints.scan_status(), cb=self._scan_status_callback) - self.sync_storage = {} self.monitored_devices = {} self.baseline_devices = {} self.device_storage = {} self.readout_priority = {} + self.scan_queue: messages.ScanQueueStatusMessage | None = None + self.scan_report_instructions: dict[str, list] = {} self.storage_initialized = set() self.executor = ThreadPoolExecutor(max_workers=4) self.executor_tasks = collections.deque(maxlen=100) @@ -48,6 +49,9 @@ def __init__(self, config, connector_cls: type[RedisConnector]) -> None: self._lock = threading.Lock() self._emitter = [] self._initialize_emitters() + self.connector.register( + MessageEndpoints.scan_queue_status(), cb=self.on_scan_queue_status_update + ) self.status = messages.BECStatus.RUNNING def _initialize_emitters(self): @@ -95,6 +99,35 @@ def handle_scan_status_message(self, msg: messages.ScanStatusMessage) -> None: self._scan_status_modification(msg) self.run_emitter("on_scan_status_update", msg) + def on_scan_queue_status_update(self, msg_obj: MessageObject): + """ + Update the scan_report_instructions based on the active request block + in the scan queue status message. + + Args: + status_msg (messages.ScanQueueStatusMessage): The scan queue status message + containing the active request block. + """ + status_msg = cast(messages.ScanQueueStatusMessage, msg_obj.value) + for scan_queue_status in status_msg.queue.values(): + if not scan_queue_status.info: + continue + info = scan_queue_status.info[0] + active_request_block = info.active_request_block + if not active_request_block: + continue + scan_id = active_request_block.scan_id + if scan_id is None: + continue + report_instructions = active_request_block.report_instructions + if not report_instructions: + continue + + self.scan_report_instructions[scan_id] = report_instructions + logger.debug( + f"Updated report instructions for scan_id {scan_id}: {report_instructions}" + ) + def _scan_status_modification(self, msg: messages.ScanStatusMessage): status = msg.content.get("status") if status not in ["closed", "aborted", "paused", "halted", "user_completed"]: @@ -358,6 +391,7 @@ def cleanup_storage(self): remove_scan_ids.append(scan_id) for scan_id in remove_scan_ids: + self.run_emitter("on_cleanup", scan_id) for storage in [ "sync_storage", "monitored_devices", @@ -368,7 +402,6 @@ def cleanup_storage(self): getattr(self, storage).pop(scan_id) except KeyError: logger.warning(f"Failed to remove {scan_id} from {storage}.") - self.run_emitter("on_cleanup", scan_id) self.storage_initialized.remove(scan_id) def _send_scan_point(self, scan_id, point_id) -> None: diff --git a/bec_server/tests/tests_scan_bundler/test_bec_emitter.py b/bec_server/tests/tests_scan_bundler/test_bec_emitter.py index 5c21852ce..6d5877361 100644 --- a/bec_server/tests/tests_scan_bundler/test_bec_emitter.py +++ b/bec_server/tests/tests_scan_bundler/test_bec_emitter.py @@ -3,6 +3,7 @@ import pytest from bec_lib import messages +from bec_lib.connector import MessageObject from bec_lib.endpoints import MessageEndpoints from bec_server.scan_bundler.bec_emitter import BECEmitter @@ -52,6 +53,26 @@ def test_send_bec_scan_point(bec_emitter_mock): ) +def test_send_bec_scan_point_skips_point_progress_with_device_progress_sub(bec_emitter_mock): + sb = bec_emitter_mock.scan_bundler + scan_id = "lkajsdlkj" + point_id = 2 + sb.sync_storage[scan_id] = {"info": {}, "status": "open", "sent": set()} + sb.sync_storage[scan_id][point_id] = {} + bec_emitter_mock._device_progress_subscriptions[scan_id] = { + "topics": MessageEndpoints.device_progress("samx"), + "cb": mock.Mock(), + } + + with ( + mock.patch.object(bec_emitter_mock, "add_message") as send, + mock.patch.object(bec_emitter_mock, "_update_scan_progress") as update_progress, + ): + bec_emitter_mock._send_bec_scan_point(scan_id, point_id) + send.assert_called_once() + update_progress.assert_not_called() + + def test_send_baseline_BEC(bec_emitter_mock): sb = bec_emitter_mock.scan_bundler scan_id = "lkajsdlkj" @@ -161,6 +182,19 @@ def test_add_message(msg, endpoint, public): emitter.shutdown() +def test_bec_emitter_scan_status_update_open_updates_subscription(bec_emitter_mock): + bec_emitter_mock.scan_bundler.sync_storage["lkajsdlkj"] = { + "info": {}, + "status": "open", + "sent": set(), + "baseline": {}, + } + msg = messages.ScanStatusMessage(scan_id="lkajsdlkj", status="open", info={"num_points": 10}) + with mock.patch.object(bec_emitter_mock, "_update_device_progress_subscription") as update_sub: + bec_emitter_mock.on_scan_status_update(msg) + update_sub.assert_called_once_with("lkajsdlkj") + + @pytest.mark.parametrize( "msg, sent, progress, ref_scan_id", [ @@ -175,7 +209,7 @@ def test_add_message(msg, endpoint, public): scan_id="lkajsdlkj", status="closed", info={"num_points": 10} ), {0, 1}, - 9, # 10 points, but sent 0 and 1, so progress is 9 + 9, "lkajsdlkj", ), ( @@ -192,7 +226,7 @@ def test_add_message(msg, endpoint, public): ), {0, 1}, 1, - "lkajsdlkj", # This is a different scan_id, should not update progress + "lkajsdlkj", ), ( messages.ScanStatusMessage( @@ -200,19 +234,174 @@ def test_add_message(msg, endpoint, public): ), {}, 0, - "lkajsdlkj", # This is a different scan_id, should not update progress + "lkajsdlkj", ), ], ) -def test_bec_emitter_scan_status_update(bec_emitter_mock, msg, sent, progress, ref_scan_id): - +def test_bec_emitter_scan_status_update_point_progress_path( + bec_emitter_mock, msg, sent, progress, ref_scan_id +): sb = bec_emitter_mock.scan_bundler - sb.sync_storage[ref_scan_id] = {"info": {}, "status": msg.status, "sent": sent} - sb.sync_storage[ref_scan_id]["baseline"] = {} + sb.sync_storage[ref_scan_id] = {"info": {}, "status": msg.status, "sent": sent, "baseline": {}} - with mock.patch.object(bec_emitter_mock, "_update_scan_progress") as update: + with ( + mock.patch.object(bec_emitter_mock, "_update_scan_progress") as update, + mock.patch.object(bec_emitter_mock, "_update_device_progress_subscription") as update_sub, + ): bec_emitter_mock.on_scan_status_update(msg) - if msg.status == "open" or msg.scan_id != ref_scan_id: + if msg.status == "open": + update.assert_not_called() + update_sub.assert_called_once_with(msg.scan_id) + elif msg.scan_id != ref_scan_id: update.assert_not_called() + update_sub.assert_not_called() else: update.assert_called_once_with(msg.scan_id, progress, done=True) + update_sub.assert_not_called() + + +def test_bec_emitter_scan_status_update_missing_scan_id_does_not_update(bec_emitter_mock): + msg = messages.ScanStatusMessage( + scan_id="wrong_scan_id", status="aborted", info={"num_points": 10} + ) + with mock.patch.object(bec_emitter_mock, "_update_scan_progress") as update: + bec_emitter_mock.on_scan_status_update(msg) + update.assert_not_called() + + +@pytest.mark.parametrize("status", ["closed", "aborted"]) +def test_bec_emitter_scan_status_update_wrong_scan_id_does_not_emit_progress( + bec_emitter_mock, status +): + msg = messages.ScanStatusMessage( + scan_id="wrong_scan_id", status=status, info={"num_points": 10} + ) + with ( + mock.patch.object(bec_emitter_mock, "_update_scan_progress") as update, + mock.patch.object(bec_emitter_mock, "send_scan_progress") as send_scan_progress, + ): + bec_emitter_mock.on_scan_status_update(msg) + update.assert_not_called() + send_scan_progress.assert_not_called() + + +def test_update_device_progress_subscription_registers_device_progress(bec_emitter_mock): + sb = bec_emitter_mock.scan_bundler + scan_id = "scan_id" + sb.sync_storage[scan_id] = {"info": {}, "status": "open", "sent": set()} + sb.scan_report_instructions[scan_id] = [{"device_progress": ["samx"]}] + + with mock.patch.object(bec_emitter_mock.connector, "register") as register: + bec_emitter_mock._update_device_progress_subscription(scan_id) + + registered_sub = bec_emitter_mock._device_progress_subscriptions[scan_id] + assert registered_sub["topics"] == MessageEndpoints.device_progress(device="samx") + assert callable(registered_sub["cb"]) + register.assert_called_once_with(**registered_sub) + + +def test_on_device_progress_done_unregisters_and_emits_progress(bec_emitter_mock): + scan_id = "scan_id" + sub = {"topics": MessageEndpoints.device_progress(device="samx"), "cb": mock.Mock()} + bec_emitter_mock.scan_bundler.sync_storage[scan_id] = { + "info": {}, + "status": "open", + "sent": set(), + } + bec_emitter_mock._device_progress_subscriptions[scan_id] = sub + progress_msg = messages.ProgressMessage( + value=3, max_value=7, done=True, metadata={"scan_id": scan_id} + ) + msg_obj = MessageObject(MessageEndpoints.device_progress("samx").endpoint, progress_msg) + + with ( + mock.patch.object(bec_emitter_mock.connector, "unregister") as unregister, + mock.patch.object(bec_emitter_mock, "send_scan_progress") as send_scan_progress, + ): + bec_emitter_mock._on_device_progress(msg_obj, scan_id) + + unregister.assert_called_once_with(**sub) + send_scan_progress.assert_called_once_with(scan_id, value=3, max_value=7, done=True) + + +def test_on_device_progress_ignores_other_scan_progress(bec_emitter_mock): + scan_id = "scan_id" + sub = {"topics": MessageEndpoints.device_progress(device="samx"), "cb": mock.Mock()} + bec_emitter_mock.scan_bundler.sync_storage[scan_id] = { + "info": {}, + "status": "open", + "sent": set(), + } + bec_emitter_mock._device_progress_subscriptions[scan_id] = sub + progress_msg = messages.ProgressMessage( + value=3, max_value=7, done=True, metadata={"scan_id": "other_scan_id"} + ) + msg_obj = MessageObject(MessageEndpoints.device_progress("samx").endpoint, progress_msg) + + with ( + mock.patch.object(bec_emitter_mock.connector, "unregister") as unregister, + mock.patch.object(bec_emitter_mock, "send_scan_progress") as send_scan_progress, + ): + bec_emitter_mock._on_device_progress(msg_obj, scan_id) + + unregister.assert_not_called() + send_scan_progress.assert_not_called() + assert bec_emitter_mock._device_progress_subscriptions[scan_id] == sub + + +def test_scan_status_update_closed_with_device_progress_unsubscribes_and_emits_last_progress( + bec_emitter_mock, +): + sb = bec_emitter_mock.scan_bundler + scan_id = "scan_id" + sub = {"topics": MessageEndpoints.device_progress(device="samx"), "cb": mock.Mock()} + sb.sync_storage[scan_id] = { + "info": {}, + "status": "closed", + "sent": {0, 1}, + "baseline": {}, + "last_progress_sent": messages.ProgressMessage(value=4, max_value=9, done=False), + } + bec_emitter_mock._device_progress_subscriptions[scan_id] = sub + msg = messages.ScanStatusMessage(scan_id=scan_id, status="closed", info={"num_points": 10}) + + with ( + mock.patch.object(bec_emitter_mock.connector, "unregister") as unregister, + mock.patch.object(bec_emitter_mock, "send_scan_progress") as send_scan_progress, + ): + bec_emitter_mock.on_scan_status_update(msg) + + unregister.assert_called_once_with(**sub) + send_scan_progress.assert_called_once_with(scan_id, value=4, max_value=9, done=True) + + +@pytest.mark.parametrize("status", ["closed", "aborted"]) +def test_scan_status_update_device_progress_without_last_progress_emits_done_message( + bec_emitter_mock, status +): + sb = bec_emitter_mock.scan_bundler + scan_id = "scan_id" + sub = {"topics": MessageEndpoints.device_progress(device="samx"), "cb": mock.Mock()} + sb.sync_storage[scan_id] = {"info": {}, "status": status, "sent": {0, 1}, "baseline": {}} + bec_emitter_mock._device_progress_subscriptions[scan_id] = sub + msg = messages.ScanStatusMessage(scan_id=scan_id, status=status, info={"num_points": 10}) + + with ( + mock.patch.object(bec_emitter_mock.connector, "unregister") as unregister, + mock.patch.object(bec_emitter_mock, "send_scan_progress") as send_scan_progress, + ): + bec_emitter_mock.on_scan_status_update(msg) + + unregister.assert_called_once_with(**sub) + send_scan_progress.assert_called_once_with(scan_id, value=0, max_value=0, done=True) + + +def test_on_cleanup_unregisters_device_progress_subscription(bec_emitter_mock): + scan_id = "scan_id" + sub = {"topics": MessageEndpoints.device_progress(device="samx"), "cb": mock.Mock()} + bec_emitter_mock._device_progress_subscriptions[scan_id] = sub + + with mock.patch.object(bec_emitter_mock.connector, "unregister") as unregister: + bec_emitter_mock.on_cleanup(scan_id) + + unregister.assert_called_once_with(**sub)