Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 100 additions & 14 deletions bec_server/bec_server/scan_bundler/bec_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import threading
import time
from queue import Queue
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, cast

from bec_lib import messages
from bec_lib.endpoints import MessageEndpoints
Expand All @@ -14,6 +14,8 @@
logger = bec_logger.logger

if TYPE_CHECKING:
from bec_lib.redis_connector import MessageObject

from .scan_bundler import ScanBundler


Expand All @@ -22,6 +24,7 @@ def __init__(self, scan_bundler: ScanBundler) -> None:
super().__init__(scan_bundler.connector)
self._send_buffer = Queue()
self.scan_bundler = scan_bundler
self._device_progress_subscriptions: dict[str, dict[str, Any]] = {}
self._buffered_connector_thread = None
self._buffered_publisher_stop_event = threading.Event()
self._start_buffered_connector()
Expand Down Expand Up @@ -96,7 +99,8 @@ def _send_bec_scan_point(self, scan_id: str, point_id: int) -> None:
MessageEndpoints.scan_segment(),
MessageEndpoints.public_scan_segment(scan_id=scan_id, point_id=point_id),
)
self._update_scan_progress(scan_id, point_id)
if not self._has_device_progress_subscription(scan_id):
self._update_scan_progress(scan_id, point_id)

def _update_scan_progress(self, scan_id: str, point_id: int, done=False) -> None:
if scan_id not in self.scan_bundler.sync_storage:
Expand All @@ -107,18 +111,36 @@ def _update_scan_progress(self, scan_id: str, point_id: int, done=False) -> None
info = self.scan_bundler.sync_storage[scan_id]["info"]

num_monitored_readouts = info.get("num_monitored_readouts", info.get("num_points", 0))

value = point_id + 1
max_value = num_monitored_readouts or point_id + 1
self.send_scan_progress(scan_id, value=value, max_value=max_value, done=done)

def send_scan_progress(self, scan_id: str, value: float, max_value: float, done=False) -> None:
"""
Send a scan progress update.

Args:
scan_id (str): The ID of the scan.
value (float): The current progress value.
max_value (float): The maximum progress value.
done (bool): Whether the scan is done.
"""
storage = self.scan_bundler.sync_storage.get(scan_id)
if not storage:
return
info = storage["info"]
msg = messages.ProgressMessage(
value=point_id + 1,
max_value=num_monitored_readouts or point_id + 1,
value=value,
max_value=max_value,
done=done,
metadata={
"scan_id": scan_id,
"RID": info.get("RID", ""),
"queue_id": info.get("queue_id", ""),
"status": self.scan_bundler.sync_storage[scan_id]["status"],
"status": storage["status"],
},
)
storage["last_progress_sent"] = msg
self.scan_bundler.connector.set_and_publish(MessageEndpoints.scan_progress(), msg)

def _send_baseline(self, scan_id: str) -> None:
Expand All @@ -141,29 +163,93 @@ def _send_baseline(self, scan_id: str) -> None:
pipe.execute()

def on_scan_status_update(self, status_msg: messages.ScanStatusMessage):
sb = self.scan_bundler
if status_msg.scan_id not in sb.sync_storage:
logger.warning(
f"Cannot update scan progress: Scan {status_msg.scan_id} not found in sync storage."
)
return

if status_msg.status == "open":
# No need to update progress for an open scan. This is handled by the scan point emit.
# Update progress subscription:
# - If the scan report instruction contains "scan_progress", we simply emit
# progress updates as they come in.
# - If the scan report instruction contains "device_progress", we subscribe
# to the progress of the first device and use that as the progress for the whole scan.
self._update_device_progress_subscription(status_msg.scan_id)
return

num_points = max(status_msg.info.get("num_points", 0) - 1, 0)
num_monitored_readouts = status_msg.info.get("num_monitored_readouts", num_points)
if status_msg.status == "closed":
self._update_scan_progress(status_msg.scan_id, num_monitored_readouts, done=True)
return
if not self._has_device_progress_subscription(status_msg.scan_id):
self._update_scan_progress(status_msg.scan_id, num_monitored_readouts, done=True)
return

sb = self.scan_bundler
if status_msg.scan_id not in sb.sync_storage:
logger.warning(
f"Cannot update scan progress: Scan {status_msg.scan_id} not found in sync storage."
)
self._unregister_device_progress_subscription(status_msg.scan_id)
self._emit_last_progress(status_msg.scan_id)
return

# Scan is not open or closed but instead in ["aborted", "halted", "user_completed"]
storage = sb.sync_storage[status_msg.scan_id]
if self._has_device_progress_subscription(status_msg.scan_id):
self._unregister_device_progress_subscription(status_msg.scan_id)
self._emit_last_progress(status_msg.scan_id)
return
sent_vals = storage.get("sent", {0}) or {0}
max_point = max(sent_vals)
self._update_scan_progress(status_msg.scan_id, max_point, done=True)

def on_cleanup(self, scan_id: str):
self._unregister_device_progress_subscription(scan_id)

def shutdown(self):
if self._buffered_connector_thread:
self._buffered_publisher_stop_event.set()
self._buffered_connector_thread.join()
self._buffered_connector_thread = None

#############################################################
################# Device Progress Helpers ###################
#############################################################

def _update_device_progress_subscription(self, scan_id: str):
sb = self.scan_bundler
instructions = sb.scan_report_instructions.get(scan_id, [])
if self._has_device_progress_subscription(scan_id):
return
for instruction in instructions:
if "device_progress" in instruction:
device = instruction["device_progress"][0]
sub = {
"topics": MessageEndpoints.device_progress(device=device),
"cb": lambda msg_obj, _scan_id=scan_id: self._on_device_progress(
msg_obj, _scan_id
),
}
self._device_progress_subscriptions[scan_id] = sub
self.connector.register(**sub)
return

def _emit_last_progress(self, scan_id: str):
storage = self.scan_bundler.sync_storage.get(scan_id, {})
msg = storage.get("last_progress_sent")
value = msg.value if msg else 0
max_value = msg.max_value if msg else 0
self.send_scan_progress(scan_id, value=value, max_value=max_value, done=True)

def _on_device_progress(self, msg_obj: MessageObject, scan_id: str):
msg = cast(messages.ProgressMessage, msg_obj.value)
if msg.metadata.get("scan_id") != scan_id:
return
if msg.done:
self._unregister_device_progress_subscription(scan_id)
self.send_scan_progress(scan_id, value=msg.value, max_value=msg.max_value, done=msg.done)

def _has_device_progress_subscription(self, scan_id: str) -> bool:
return scan_id in self._device_progress_subscriptions

def _unregister_device_progress_subscription(self, scan_id: str) -> None:
sub_info = self._device_progress_subscriptions.pop(scan_id, None)
if sub_info:
self.connector.unregister(**sub_info)
41 changes: 37 additions & 4 deletions bec_server/bec_server/scan_bundler/scan_bundler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import traceback
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

from bec_lib import messages
from bec_lib.bec_service import BECService
Expand All @@ -17,7 +17,7 @@
from .bec_emitter import BECEmitter

if TYPE_CHECKING:
from bec_lib.redis_connector import RedisConnector
from bec_lib.redis_connector import MessageObject, RedisConnector


logger = bec_logger.logger
Expand All @@ -35,19 +35,23 @@ def __init__(self, config, connector_cls: type[RedisConnector]) -> None:
name="device_read_register",
)
self.connector.register(MessageEndpoints.scan_status(), cb=self._scan_status_callback)

self.sync_storage = {}
self.monitored_devices = {}
self.baseline_devices = {}
self.device_storage = {}
self.readout_priority = {}
self.scan_queue: messages.ScanQueueStatusMessage | None = None
self.scan_report_instructions: dict[str, list] = {}
self.storage_initialized = set()
self.executor = ThreadPoolExecutor(max_workers=4)
self.executor_tasks = collections.deque(maxlen=100)
self.scan_id_history = collections.deque(maxlen=10)
self._lock = threading.Lock()
self._emitter = []
self._initialize_emitters()
self.connector.register(
MessageEndpoints.scan_queue_status(), cb=self.on_scan_queue_status_update
)
self.status = messages.BECStatus.RUNNING

def _initialize_emitters(self):
Expand Down Expand Up @@ -95,6 +99,35 @@ def handle_scan_status_message(self, msg: messages.ScanStatusMessage) -> None:
self._scan_status_modification(msg)
self.run_emitter("on_scan_status_update", msg)

def on_scan_queue_status_update(self, msg_obj: MessageObject):
"""
Update the scan_report_instructions based on the active request block
in the scan queue status message.

Args:
status_msg (messages.ScanQueueStatusMessage): The scan queue status message
containing the active request block.
"""
status_msg = cast(messages.ScanQueueStatusMessage, msg_obj.value)
for scan_queue_status in status_msg.queue.values():
if not scan_queue_status.info:
continue
info = scan_queue_status.info[0]
active_request_block = info.active_request_block
if not active_request_block:
continue
scan_id = active_request_block.scan_id
if scan_id is None:
continue
report_instructions = active_request_block.report_instructions
if not report_instructions:
continue

self.scan_report_instructions[scan_id] = report_instructions
logger.debug(
f"Updated report instructions for scan_id {scan_id}: {report_instructions}"
)

def _scan_status_modification(self, msg: messages.ScanStatusMessage):
status = msg.content.get("status")
if status not in ["closed", "aborted", "paused", "halted", "user_completed"]:
Expand Down Expand Up @@ -358,6 +391,7 @@ def cleanup_storage(self):
remove_scan_ids.append(scan_id)

for scan_id in remove_scan_ids:
self.run_emitter("on_cleanup", scan_id)
for storage in [
"sync_storage",
"monitored_devices",
Expand All @@ -368,7 +402,6 @@ def cleanup_storage(self):
getattr(self, storage).pop(scan_id)
except KeyError:
logger.warning(f"Failed to remove {scan_id} from {storage}.")
self.run_emitter("on_cleanup", scan_id)
self.storage_initialized.remove(scan_id)

def _send_scan_point(self, scan_id, point_id) -> None:
Expand Down
Loading
Loading