Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/actions/bec_e2e_install/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ runs:
cd ./_e2e_test_checkout_/bec
source ./bin/install_bec_dev.sh -t
pip install -e ../ophyd_devices
podman pod create --net host local_bec
podman pod create --network=host local_bec
python ./bec_ipython_client/tests/end-2-end/_ensure_requirements_container.py
pytest -v --files-path ./ --start-servers --random-order ./bec_ipython_client/tests/end-2-end/

31 changes: 22 additions & 9 deletions bec_ipython_client/tests/end-2-end/test_procedures_e2e.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import time
from dataclasses import dataclass
from importlib.metadata import version
from typing import TYPE_CHECKING, Callable, Generator
from unittest.mock import MagicMock, patch
Expand All @@ -11,7 +12,7 @@
from bec_lib import messages
from bec_lib.endpoints import MessageEndpoints
from bec_lib.logger import bec_logger
from bec_server.scan_server.procedures.constants import PROCEDURE
from bec_server.scan_server.procedures.constants import _CONTAINER, _WORKER
from bec_server.scan_server.procedures.container_utils import get_backend
from bec_server.scan_server.procedures.container_worker import ContainerProcedureWorker
from bec_server.scan_server.procedures.manager import ProcedureManager
Expand All @@ -28,6 +29,15 @@
pytestmark = pytest.mark.random_order(disabled=True)


@dataclass(frozen=True)
class PATCHED_CONSTANTS:
WORKER = _WORKER()
CONTAINER = _CONTAINER()
MANAGER_SHUTDOWN_TIMEOUT_S = 2
BEC_VERSION = version("bec_lib")
REDIS_HOST = "localhost"


@pytest.fixture
def client_logtool_and_manager(
bec_ipython_client_fixture_with_logtool: tuple[BECIPythonClient, "LogTestTool"],
Expand All @@ -52,7 +62,7 @@ def _wait_while(cond: Callable[[], bool], timeout_s):
def test_building_worker_image():
podman_utils = get_backend()
build = podman_utils.build_worker_image()
assert len(build._command_output.splitlines()[-1]) == 64
assert len(build._command_output.splitlines()[-1]) == 64 # type: ignore
assert podman_utils.image_exists(f"bec_procedure_worker:v{version('bec_lib')}")


Expand All @@ -62,7 +72,7 @@ def test_procedure_runner_spawns_worker(
client_logtool_and_manager: tuple[BECIPythonClient, "LogTestTool", ProcedureManager],
):
client, _, manager = client_logtool_and_manager
assert manager.active_workers == {}
assert manager._active_workers == {}
endpoint = MessageEndpoints.procedure_request()
msg = messages.ProcedureRequestMessage(
identifier="sleep", args_kwargs=((), {"time_s": 2}), queue="test"
Expand All @@ -77,34 +87,37 @@ def cb(worker: ContainerProcedureWorker):
manager.add_callback("test", cb)
client.connector.xadd(topic=endpoint, msg_dict=msg.model_dump())

_wait_while(lambda: manager.active_workers == {}, 5)
_wait_while(lambda: manager.active_workers != {}, 20)
_wait_while(lambda: manager._active_workers == {}, 5)
_wait_while(lambda: manager._active_workers != {}, 20)

assert logs != []


@pytest.mark.timeout(100)
@patch("bec_server.scan_server.procedures.manager.procedure_registry.is_registered", lambda _: True)
@patch("bec_server.scan_server.procedures.container_worker.PROCEDURE", PATCHED_CONSTANTS())
def test_happy_path_container_procedure_runner(
client_logtool_and_manager: tuple[BECIPythonClient, "LogTestTool", ProcedureManager],
):
test_args = (1, 2, 3)
test_kwargs = {"a": "b", "c": "d"}
client, logtool, manager = client_logtool_and_manager
assert manager.active_workers == {}
assert manager._active_workers == {}
conn = client.connector
endpoint = MessageEndpoints.procedure_request()
msg = messages.ProcedureRequestMessage(
identifier="log execution message args", args_kwargs=(test_args, test_kwargs)
)
conn.xadd(topic=endpoint, msg_dict=msg.model_dump())

_wait_while(lambda: manager.active_workers == {}, 5)
_wait_while(lambda: manager.active_workers != {}, 20)
_wait_while(lambda: manager._active_workers == {}, 5)
_wait_while(lambda: manager._active_workers != {}, 20)

logtool.fetch()
assert logtool.is_present_in_any_message("procedure accepted: True, message:")
assert logtool.is_present_in_any_message("ContainerWorker started container for queue primary")
assert logtool.is_present_in_any_message(
"ContainerWorker started container for queue primary"
), f"Log content relating to procedures: {manager}"
res, msg = logtool.are_present_in_order(
[
"Container worker 'primary' status update: IDLE",
Expand Down
85 changes: 80 additions & 5 deletions bec_lib/bec_lib/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ class MessageOp(list[str], enum.Enum):
SET_PUBLISH = ["register", "set_and_publish", "delete", "get", "keys"]
SEND = ["send", "register"]
STREAM = ["xadd", "xrange", "xread", "register_stream", "keys", "get_last", "delete"]
LIST = ["lpush", "lrange", "rpush", "ltrim", "keys", "delete", "blocking_list_pop"]
LIST = ["lpush", "lrange", "lrem", "rpush", "ltrim", "keys", "delete", "blocking_list_pop"]
KEY_VALUE = ["set", "get", "delete", "keys"]
SET = ["remove_from_set", "get_set_members"]
SET = ["remove_from_set", "get_set_members", "delete"]


MessageType = TypeVar("MessageType", bound="type[messages.BECMessage]")
Expand Down Expand Up @@ -1439,8 +1439,8 @@ def available_procedures() -> EndpointInfo:
@staticmethod
def procedure_request() -> EndpointInfo:
"""
Endpoint for scan queue request. This endpoint is used to request the new scans.
The request is sent using a messages.ScanQueueMessage message.
Endpoint for requesting new procedures.
The request is sent using a messages.ProcedureRequestMessage message.

Returns:
EndpointInfo: Endpoint for scan queue request.
Expand Down Expand Up @@ -1479,7 +1479,24 @@ def procedure_execution(queue_id: str):
Returns:
EndpointInfo: Endpoint for scan queue request.
"""
endpoint = f"{EndpointType.INTERNAL.value}/procedures/procedure_execution/{queue_id}"
endpoint = f"{EndpointType.INFO.value}/procedures/procedure_execution/{queue_id}"
return EndpointInfo(
endpoint=endpoint,
message_type=messages.ProcedureExecutionMessage,
message_op=MessageOp.LIST,
)

@staticmethod
def unhandled_procedure_execution(queue_id: str):
"""
Endpoint for procedure executions which were pending when the manager was shutdown.
Messages from procedure_execution are moved here on manager startup.
The request is sent using a messages.ProcedureExecutionMessage message.

Returns:
EndpointInfo: Endpoint for scan queue request.
"""
endpoint = f"{EndpointType.INFO.value}/procedures/unhandled_procedure_execution/{queue_id}"
return EndpointInfo(
endpoint=endpoint,
message_type=messages.ProcedureExecutionMessage,
Expand All @@ -1502,6 +1519,36 @@ def active_procedure_executions():
message_op=MessageOp.SET,
)

@staticmethod
def procedure_abort():
"""
Endpoint to request aborting a running procedure

Returns:
EndpointInfo: Endpoint for set of active procedure executions.
"""
endpoint = f"{EndpointType.USER.value}/procedures/abort"
return EndpointInfo(
endpoint=endpoint,
message_type=messages.ProcedureAbortMessage,
message_op=MessageOp.STREAM,
)

@staticmethod
def procedure_clear_unhandled():
"""
Endpoint to request removing an aborted procedure

Returns:
EndpointInfo: Endpoint for set of active procedure executions.
"""
endpoint = f"{EndpointType.USER.value}/procedures/clear_unhandled"
return EndpointInfo(
endpoint=endpoint,
message_type=messages.ProcedureClearUnhandledMessage,
message_op=MessageOp.STREAM,
)

@staticmethod
def procedure_worker_status_update(queue_id: str):
"""
Expand All @@ -1517,6 +1564,34 @@ def procedure_worker_status_update(queue_id: str):
message_op=MessageOp.LIST,
)

@staticmethod
def procedure_queue_notif():
"""
PubSub channel for a consumer (e.g. BEC widgets) to be notified of changes to a procedure queue

Returns:
EndpointInfo: Endpoint for procedure queue updates for given queue.
"""
endpoint = f"{EndpointType.INFO.value}/procedures/queue_notif"
return EndpointInfo(
endpoint=endpoint,
message_type=messages.ProcedureQNotifMessage,
message_op=MessageOp.SEND,
)

@staticmethod
def procedure_logs(queue: str):
"""
Endpoint for logs for a given procedure queue

Returns:
EndpointInfo: Endpoint for procedure queue updates for given queue.
"""
endpoint = f"{EndpointType.INFO.value}/procedures/logs/{queue}"
return EndpointInfo(
endpoint=endpoint, message_type=messages.RawMessage, message_op=MessageOp.STREAM
)

@staticmethod
def gui_registry_state(gui_id: str):
"""
Expand Down
13 changes: 9 additions & 4 deletions bec_lib/bec_lib/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class BECLogger:
"<green>{service_name} | {{time:YYYY-MM-DD HH:mm:ss.SSS}}</green> | <level>{{level}}</level> |"
" <level>{{thread.name}} ({{thread.id}})</level> | <cyan>{{extra[stack]}}</cyan> - <level>{{message}}</level>\n"
)
CONTAINER_FORMAT = "{{time:YYYY-MM-DD HH:mm:ss.SSS}} | {{level}} | {{message}}\n"
LOGLEVEL = LogLevel

_logger = None
Expand Down Expand Up @@ -224,18 +225,21 @@ def _logger_callback(self, msg):
# because it depends on the connector
pass

def get_format(self, level: LogLevel = None, is_stderr=False) -> str:
def get_format(self, level: LogLevel = None, is_stderr=False, is_container=False) -> str:
"""
Get the format for a specific log level.

Args:
level (LogLevel, optional): Log level. Defaults to None. If None, the current log level will be used.
is_stderr (bool, optional): Whether the log is for stderr. Defaults to False.
is_container (bool, optional): Simple logging for procedure container. Defaults to False.

Returns:
str: Log format.
"""
service_name = self.service_name if self.service_name else ""
if is_container:
return self.CONTAINER_FORMAT.format()
if level is None:
level = self.level
if level > self.LOGLEVEL.DEBUG:
Expand All @@ -246,15 +250,16 @@ def get_format(self, level: LogLevel = None, is_stderr=False) -> str:
return self.DEBUG_FORMAT.format(service_name=service_name)
return self.TRACE_FORMAT.format(service_name=service_name)

def formatting(self, is_stderr=False):
def formatting(self, is_stderr=False, is_container=False):
"""
Format the log message.

Args:
record (dict): Log record.
is_container (bool, optional): Simple logging for procedure container. Defaults to False.

Returns:
dict: Formatted log record.
str: Log format.
"""

def _update_record(record):
Expand All @@ -269,7 +274,7 @@ def _update_record(record):

def _format(record):
level = _update_record(record)
return self.get_format(level)
return self.get_format(level, is_container=is_container)

def _format_stderr(record):
level = _update_record(record)
Expand Down
64 changes: 62 additions & 2 deletions bec_lib/bec_lib/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import warnings
from copy import deepcopy
from enum import Enum, auto
from typing import Any, ClassVar, Literal
from typing import Any, ClassVar, Literal, Self
from uuid import uuid4

import numpy as np
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator, model_validator
Expand All @@ -18,6 +19,7 @@ class ProcedureWorkerStatus(Enum):
IDLE = auto()
FINISHED = auto()
DEAD = auto() # worker lost communication with the container
NONE = auto() # worker doesn't exist in manager, caught during creation and cleanup


class BECStatus(Enum):
Expand Down Expand Up @@ -1110,18 +1112,67 @@ class ProcedureRequestMessage(BECMessage):
queue: str | None = None


class ProcedureQNotifMessage(BECMessage):
"""Message type for notifying watchers of changes to queues"""

msg_type: ClassVar[str] = "procedure_queue_notif_message"
queue_name: str
queue_type: Literal["execution", "unhandled"]


class ProcedureExecutionMessage(BECMessage):
"""Message type for sending procedure execution instructions to the scheduler

Sent by the user to the procedure_request topic. It will be consumed by the scan server.
Args:
identifier (str): name of the procedure registered with the server
queue (str): the procedure queue this execution belongs to
args_kwargs (tuple[tuple[Any, ...], dict[str, Any]]): arguments for the procedure function
"""

msg_type: ClassVar[str] = "procedure_execution_message"
identifier: str
queue: str
args_kwargs: tuple[tuple[Any, ...], dict[str, Any]] = (), {}
execution_id: str = Field(default_factory=lambda: str(uuid4()))


class ProcedureAbortMessage(BECMessage):
"""Message type to request aborting a procedure or procedure queue

One and only one of the args should be supplied.
Args:
queue (str | None): the procedure queue to abort
execution_id (str | None): the procedure execution to abort
abort_all (bool | None): abort all procedures if true
"""

msg_type: ClassVar[str] = "procedure_abort_message"
queue: str | None = None
execution_id: str | None = None
abort_all: bool | None = None

@model_validator(mode="after")
def mutually_exclusive(self) -> Self:
if (self.queue, self.execution_id, self.abort_all).count(None) != 2:
raise ValueError(
"Please only supply one argument! Supplied: \n"
f" {self.queue=}, {self.execution_id=}, {self.abort_all=}"
)
return self


class ProcedureClearUnhandledMessage(ProcedureAbortMessage):
"""Message type to request clearing an unhandled procedure or procedure queue

One and only one of the args should be supplied.
Args:
queue (str | None): the procedure queue to abort
execution_id (str | None): the procedure execution to abort
abort_all (bool | None): abort all procedures if true
"""

...


class ProcedureWorkerStatusMessage(BECMessage):
Expand All @@ -1130,12 +1181,21 @@ class ProcedureWorkerStatusMessage(BECMessage):
Args:
worker_queue (str): Worker queue ID
status (str): Worker status

current_execution_id (str | None): ID of the current job, only allowed for RUNNING
"""

msg_type: ClassVar[str] = "procedure_worker_status_message"
worker_queue: str
status: ProcedureWorkerStatus
current_execution_id: str | None = None

@model_validator(mode="after")
def check_id(self) -> Self:
if self.current_execution_id is not None and self.status != ProcedureWorkerStatus.RUNNING:
raise ValueError("Adding an execution ID is only valid for the RUNNING status")
if self.current_execution_id is None and self.status == ProcedureWorkerStatus.RUNNING:
raise ValueError("Adding an execution ID is mandatory for the RUNNING status")
return self


class LoginInfoMessage(BECMessage):
Expand Down
Loading
Loading