Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion bec_server/bec_server/scan_server/scans/scan_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import enum
import threading
from collections.abc import Sequence
from typing import TYPE_CHECKING, Annotated, Type
from typing import Annotated, Callable, Type

import numpy as np
import pint
Expand Down Expand Up @@ -201,6 +201,7 @@ def __init__(
self._premove_motor_status = None
self.positions = np.array([])
self.start_positions = []
self._scan_original_hooks = self._collect_original_scan_hooks()
self._scan_modifier_hooks = (
get_scan_hooks_impl(scan_modifier) if scan_modifier is not None else {}
)
Expand Down Expand Up @@ -263,3 +264,20 @@ def update_scan_info(
setattr(self.scan_info, key, value)
else:
self.scan_info.additional_scan_parameters[key] = value

def _collect_original_scan_hooks(self) -> dict[str, Callable]:
"""
Bind the undecorated scan hook implementations to this scan instance.

Returns:
dict[str, Callable]: Mapping from hook name to the original bound method.
"""
original_hooks = {}
for attr_name in dir(type(self)):
attr = getattr(type(self), attr_name)
hook_info = getattr(attr, "_scan_hook_info", None)
original_func = getattr(attr, "_scan_hook_original", None)
if hook_info is None or original_func is None:
continue
original_hooks[hook_info["method_name"]] = original_func.__get__(self, type(self))
return original_hooks
123 changes: 107 additions & 16 deletions bec_server/bec_server/scan_server/scans/scan_modifier.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from fnmatch import fnmatchcase
from functools import wraps
from typing import TYPE_CHECKING, Annotated, Any, Literal, TypeAlias, get_args

Expand Down Expand Up @@ -27,6 +28,44 @@
VALID_SCAN_HOOKS = set(get_args(ScanHookName))


def _matches_scan_name(scan_name: str | None, patterns: list[str] | None) -> bool:
if not patterns:
return True
if scan_name is None:
return False
return any(fnmatchcase(scan_name, pattern) for pattern in patterns)


def _get_hook_method_name(
hook_name: str,
hook_info: dict[str, str | dict[str, str | list[str]] | list[str | dict[str, str | list[str]]]],
hook_type: str,
scan_name: str | None,
) -> str | None:
hook_config = hook_info.get(hook_type)
if hook_config is None:
return None
if isinstance(hook_config, list):
matched_method_names = []
for config in hook_config:
if isinstance(config, str):
matched_method_names.append(config)
continue
if _matches_scan_name(scan_name, config.get("scan_names")):
matched_method_names.append(config["method_name"])
if len(matched_method_names) > 1:
raise ValueError(
f"Multiple scan modifier implementations matched hook '{hook_name}' "
f"for lifecycle '{hook_type}' and scan '{scan_name}'"
)
return matched_method_names[0] if matched_method_names else None
if isinstance(hook_config, str):
return hook_config
if not _matches_scan_name(scan_name, hook_config.get("scan_names")):
return None
return hook_config["method_name"]


def scan_hook(func):
"""
Decorator for scan hooks. It registers the decorated method as a scan hook and thus allows
Expand All @@ -46,62 +85,82 @@ def wrapper(self, *args, **kwargs):
return func(self, *args, **kwargs)

hook_info = self._scan_modifier_hooks[func.__name__]
if "before" in hook_info:
before_method = getattr(self._scan_modifier, hook_info["before"])
scan_name = getattr(
getattr(self, "scan_info", None), "scan_name", getattr(self, "scan_name", None)
)

before_method_name = _get_hook_method_name(func.__name__, hook_info, "before", scan_name)
if before_method_name is not None:
before_method = getattr(self._scan_modifier, before_method_name)
before_method(*args, **kwargs)

if "replace" in hook_info:
replace_method = getattr(self._scan_modifier, hook_info["replace"])
replace_method_name = _get_hook_method_name(func.__name__, hook_info, "replace", scan_name)
if replace_method_name is not None:
replace_method = getattr(self._scan_modifier, replace_method_name)
replace_method(*args, **kwargs)
else:
func(self, *args, **kwargs)

if "after" in hook_info:
after_method = getattr(self._scan_modifier, hook_info["after"])
after_method_name = _get_hook_method_name(func.__name__, hook_info, "after", scan_name)
if after_method_name is not None:
after_method = getattr(self._scan_modifier, after_method_name)
after_method(*args, **kwargs)

return

# pylint: disable=protected-access
wrapper._scan_hook_info = {"method_name": func.__name__} # type: ignore
wrapper._scan_hook_original = func # type: ignore[attr-defined]

return wrapper


def scan_hook_impl(
hook_name: ScanHookName, hook_type: Literal["before", "after", "replace"] = "before"
hook_name: ScanHookName,
hook_type: Literal["before", "after", "replace"] = "before",
scan_names: list[str] | None = None,
):
"""
Decorator for scan hook implementations. It registers the decorated method as an implementation of the specified scan hook.
The hook_name must refer to an existing scan hook.
The hook_type should be one of the following: "before", "after" or "replace".
The optional scan_names list can be used to restrict the implementation to matching scan names.
Wildcards are supported using shell-style patterns such as ``*_line_scan``.
This allows the scan modifier to specify whether the decorated method should be executed before, after or instead of the original scan hook method.
"""
if hook_name not in VALID_SCAN_HOOKS:
raise ValueError(f"Invalid scan hook: {hook_name}")
if hook_type not in {"before", "after", "replace"}:
raise ValueError(f"Invalid scan hook type: {hook_type}")
if scan_names is not None and not isinstance(scan_names, list):
raise ValueError("scan_names must be a list of scan name patterns")

def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
return func(self, *args, **kwargs)

# pylint: disable=protected-access
wrapper._scan_hook_impl_info = {"hook_name": hook_name, "hook_type": hook_type} # type: ignore
wrapper._scan_hook_impl_info = {
"hook_name": hook_name,
"hook_type": hook_type,
"scan_names": scan_names,
} # type: ignore

return wrapper

return decorator


def get_scan_hooks_impl(cls) -> dict[str, dict[str, str]]:
def get_scan_hooks_impl(
cls,
) -> dict[
str, dict[str, str | dict[str, str | list[str]] | list[str | dict[str, str | list[str]]]]
]:
"""
Get the scan hooks implemented by the given class. It returns
a dictionary mapping the original scan hook names to the corresponding method names and hook types in the scan modifier.

Raises:
ValueError: If the class implements multiple hooks for the same hook_type (before, after, replace) for the same scan hook.
"""
hooks = {}
for attr_name in dir(cls):
Expand All @@ -112,11 +171,19 @@ def get_scan_hooks_impl(cls) -> dict[str, dict[str, str]]:
hook_type = info["hook_type"]
if hook_name not in hooks:
hooks[hook_name] = {}
if hook_type in hooks[hook_name]:
raise ValueError(
f"Multiple implementations for the same hook type '{hook_type}' for the scan hook '{hook_name}' in class '{cls.__name__}'"
)
hooks[hook_name][hook_type] = attr_name
scan_names = info.get("scan_names")
hook_config: str | dict[str, str | list[str]]
if scan_names is None:
hook_config = attr_name
else:
hook_config = {"method_name": attr_name, "scan_names": scan_names}
existing_hook_config = hooks[hook_name].get(hook_type)
if existing_hook_config is None:
hooks[hook_name][hook_type] = hook_config
elif isinstance(existing_hook_config, list):
existing_hook_config.append(hook_config)
else:
hooks[hook_name][hook_type] = [existing_hook_config, hook_config]
return hooks


Expand Down Expand Up @@ -225,3 +292,27 @@ def device_is_available(self, device: list[str] | str, check_enabled: bool = Tru
if check_enabled and not self.dev[dev_name].enabled:
return False
return True

def call_original(self, hook_name: ScanHookName, *args, **kwargs):
"""
Call the scan's original hook implementation directly, bypassing scan modifier dispatch.

Args:
hook_name (ScanHookName): Name of the original scan hook to call.
*args: Positional arguments forwarded to the original hook.
**kwargs: Keyword arguments forwarded to the original hook.

Returns:
Any: The return value of the original hook implementation.

Raises:
AttributeError: If the scan does not expose an original implementation for the hook.
"""
original_hooks = getattr(self.scan, "_scan_original_hooks", {})
try:
original_hook = original_hooks[hook_name]
except KeyError as exc:
raise AttributeError(
f"Scan {type(self.scan).__name__!r} does not expose an original hook for {hook_name!r}"
) from exc
return original_hook(*args, **kwargs)
Loading
Loading