Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 108 additions & 29 deletions bec_widgets/widgets/control/scan_control/scan_group_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,47 @@ def __init__(
self.setChecked(default)


class ScanOptionalWidget(QGroupBox):
def __init__(self, widget, parent=None):
super().__init__(parent=parent)
self.inner_widget = widget
self.arg_name = getattr(widget, "arg_name", None)
self.setFlat(True)

layout = QHBoxLayout(self)
layout.setContentsMargins(0, 0, 0, 0)
layout.addWidget(widget)

self.none_checkbox = QCheckBox(self)
self.none_checkbox.setToolTip("Set this value to None.")
self.none_checkbox.toggled.connect(self._on_none_toggled)
layout.addWidget(self.none_checkbox)

def _on_none_toggled(self, checked: bool) -> None:
self.inner_widget.setEnabled(not checked)

def set_none(self, checked: bool) -> None:
self.none_checkbox.setChecked(checked)

def is_none(self) -> bool:
return self.none_checkbox.isChecked()

def setToolTip(self, text: str) -> None: # noqa: N802
super().setToolTip(text)
self.inner_widget.setToolTip(text)
checkbox_tooltip = "Set this value to None."
if text:
checkbox_tooltip = f"{text}\n{checkbox_tooltip}"
self.none_checkbox.setToolTip(checkbox_tooltip)

def toolTip(self) -> str: # noqa: N802
return self.inner_widget.toolTip()

def setSuffix(self, suffix: str) -> None: # noqa: N802
if hasattr(self.inner_widget, "setSuffix"):
self.inner_widget.setSuffix(suffix)


class ScanGroupBox(QGroupBox):
WIDGET_HANDLER = {
ScanArgType.DEVICE: DeviceComboBox,
Expand Down Expand Up @@ -211,6 +252,7 @@ def __init__(
self.labels = []
self.widgets = []
self._widget_configs = {}
self._wrapped_widgets = {}
self._column_labels = {}
self.selected_devices = {}

Expand Down Expand Up @@ -297,10 +339,11 @@ def add_input_widgets(self, group_inputs: dict, row) -> None:
)
if isinstance(widget, ScanLiteralsComboBox):
widget.set_literals(item["type"].get("Literal", []))
self._widget_configs[widget] = item
self._apply_unit_metadata(widget, item)
self.layout.addWidget(widget, row, column_index)
self.widgets.append(widget)
display_widget = self._wrap_optional_widget(widget, item, default)
self._widget_configs[display_widget] = item
self._apply_unit_metadata(display_widget, item)
self.layout.addWidget(display_widget, row, column_index)
self.widgets.append(display_widget)

@Slot(str)
def emit_device_selected(self, device_name):
Expand Down Expand Up @@ -334,8 +377,10 @@ def remove_widget_bundle(self):
return

for widget in self.widgets[-len(self.inputs) :]:
if isinstance(widget, DeviceComboBox):
self.selected_devices[widget] = ""
inner_widget = self._inner_widget(widget)
if isinstance(inner_widget, DeviceComboBox):
self.selected_devices[inner_widget] = ""
self._wrapped_widgets.pop(inner_widget, None)
self._widget_configs.pop(widget, None)
widget.close()
widget.deleteLater()
Expand All @@ -347,8 +392,10 @@ def remove_widget_bundle(self):
def remove_all_widget_bundles(self):
"""Remove every widget bundle from the scan control layout."""
for widget in list(self.widgets):
if isinstance(widget, DeviceComboBox):
self.selected_devices.pop(widget, None)
inner_widget = self._inner_widget(widget)
if isinstance(inner_widget, DeviceComboBox):
self.selected_devices.pop(inner_widget, None)
self._wrapped_widgets.pop(inner_widget, None)
self._widget_configs.pop(widget, None)
widget.close()
widget.deleteLater()
Expand Down Expand Up @@ -385,12 +432,7 @@ def _get_arg_parameters(self, device_object: bool = True):
for j in range(self.layout.columnCount()):
try: # In case that the bundle size changes
widget = self.layout.itemAtPosition(i, j).widget()
if isinstance(widget, DeviceComboBox) and device_object:
value = widget.get_current_device()
elif isinstance(widget, DeviceComboBox):
value = widget.currentText()
else:
value = WidgetIO.get_value(widget)
value = self._widget_value(widget, device_object=device_object)
args.append(value)
except AttributeError:
continue
Expand All @@ -400,27 +442,23 @@ def _get_kwarg_parameters(self, device_object: bool = True):
kwargs = {}
for i in range(self.layout.columnCount()):
widget = self.layout.itemAtPosition(1, i).widget()
if isinstance(widget, DeviceComboBox) and device_object:
value = widget.get_current_device().name
elif isinstance(widget, DeviceComboBox):
value = widget.currentText()
elif isinstance(widget, ScanLiteralsComboBox):
value = widget.get_value()
else:
value = WidgetIO.get_value(widget)
value = self._widget_value(widget, device_object=device_object)
inner_widget = self._inner_widget(widget)
if isinstance(inner_widget, DeviceComboBox) and value is not None and device_object:
value = value.name
kwargs[widget.arg_name] = value
return kwargs

def count_arg_rows(self):
widget_rows = 0
for row in range(self.layout.rowCount()):
if row == 0:
continue
for col in range(self.layout.columnCount()):
item = self.layout.itemAtPosition(row, col)
if item is not None:
widget = item.widget()
if widget is not None:
if isinstance(widget, DeviceComboBox):
widget_rows += 1
if item is not None and item.widget() is not None:
widget_rows += 1
break
return widget_rows

def set_parameters(self, parameters: list | dict):
Expand All @@ -444,13 +482,13 @@ def _set_arg_parameters(self, parameters: list):
self.add_input_widgets(self.inputs, row)

for i, value in enumerate(parameters):
WidgetIO.set_value(self.widgets[i], value)
self._set_widget_value(self.widgets[i], value)

def _set_kwarg_parameters(self, parameters: dict):
for widget in self.widgets:
for key, value in parameters.items():
if widget.arg_name == key:
WidgetIO.set_value(widget, value)
self._set_widget_value(widget, value)
break

@staticmethod
Expand Down Expand Up @@ -505,6 +543,7 @@ def _device_units(device) -> str | None:
return None

def _widget_position(self, widget) -> tuple[int, int] | None:
widget = self._display_widget(widget)
for row in range(self.layout.rowCount()):
for column in range(self.layout.columnCount()):
item = self.layout.itemAtPosition(row, column)
Expand Down Expand Up @@ -608,3 +647,43 @@ def _apply_numeric_limits(widget: ScanDoubleSpinBox | ScanSpinBox, item: dict) -
if item.get("lt") is not None:
maximum = float(item["lt"]) - step
widget.setRange(minimum, maximum)

def _wrap_optional_widget(self, widget, item: dict, default):
if not item.get("optional", False):
return widget

wrapped_widget = ScanOptionalWidget(widget, parent=self)
wrapped_widget.set_none(default is None)
self._wrapped_widgets[widget] = wrapped_widget
return wrapped_widget

@staticmethod
def _inner_widget(widget):
if isinstance(widget, ScanOptionalWidget):
return widget.inner_widget
return widget

def _display_widget(self, widget):
return self._wrapped_widgets.get(widget, widget)

def _widget_value(self, widget, *, device_object: bool = True):
if isinstance(widget, ScanOptionalWidget) and widget.is_none():
return None

inner_widget = self._inner_widget(widget)
if isinstance(inner_widget, DeviceComboBox) and device_object:
return inner_widget.get_current_device()
if isinstance(inner_widget, DeviceComboBox):
return inner_widget.currentText()
if isinstance(inner_widget, ScanLiteralsComboBox):
return inner_widget.get_value()
return WidgetIO.get_value(inner_widget)

def _set_widget_value(self, widget, value) -> None:
if isinstance(widget, ScanOptionalWidget):
widget.set_none(value is None)
if value is None:
return
WidgetIO.set_value(widget.inner_widget, value)
return
WidgetIO.set_value(widget, value)
31 changes: 21 additions & 10 deletions bec_widgets/widgets/control/scan_control/scan_info_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,27 +92,34 @@ def resolve_tooltip(scan_argument: ScanArgumentMetadata) -> str | None:
@staticmethod
def parse_annotation(
annotation: AnnotationValue,
) -> tuple[AnnotationValue, ScanArgumentMetadata]:
) -> tuple[AnnotationValue, ScanArgumentMetadata, bool]:
"""Extract the serialized base annotation and ``ScanArgument`` metadata.

Args:
annotation (AnnotationValue): Serialized annotation payload from BEC.

Returns:
tuple[AnnotationValue, ScanArgumentMetadata]: The unwrapped annotation and parsed
``ScanArgument`` metadata.
tuple[AnnotationValue, ScanArgumentMetadata, bool]: The unwrapped annotation,
parsed ``ScanArgument`` metadata, and whether ``None`` is an allowed value.
"""
scan_argument: ScanArgumentMetadata = {}
if isinstance(annotation, dict) and "Annotated" in annotation:
annotated = annotation["Annotated"]
annotation = annotated.get("type", "_empty")
scan_argument = annotated.get("metadata", {}).get("ScanArgument", {}) or {}

allows_none = False
if isinstance(annotation, list):
allows_none = "NoneType" in annotation
annotation = next(
(entry for entry in annotation if entry != "NoneType"),
annotation[0] if annotation else "_empty",
)
if isinstance(annotation, dict) and "Annotated" in annotation:
annotated = annotation["Annotated"]
annotation = annotated.get("type", "_empty")
scan_argument = annotated.get("metadata", {}).get("ScanArgument", {}) or {}
return annotation, scan_argument
elif annotation == "NoneType":
allows_none = True
annotation = "_empty"

return annotation, scan_argument, allows_none

@staticmethod
def scan_arg_type_from_annotation(annotation: AnnotationValue) -> AnnotationValue:
Expand Down Expand Up @@ -142,13 +149,14 @@ def scan_input_from_signature(
Returns:
ScanInputConfig: Normalized input configuration for ``ScanControl``.
"""
annotation, scan_argument = self.parse_annotation(param.get("annotation"))
annotation, scan_argument, allows_none = self.parse_annotation(param.get("annotation"))
return self._build_scan_input(
name=param["name"],
annotation=annotation,
scan_argument=scan_argument,
arg=arg,
default=None if arg else param.get("default", None),
optional=allows_none,
)

def scan_input_from_arg_input(
Expand All @@ -171,13 +179,14 @@ def scan_input_from_arg_input(
self.parse_annotation(signature_by_name[name].get("annotation"))[0]
)
else:
annotation, scan_argument = self.parse_annotation(item_type)
annotation, scan_argument, allows_none = self.parse_annotation(item_type)
scan_input = self._build_scan_input(
name=name,
annotation=annotation,
scan_argument=scan_argument,
arg=True,
default=None,
optional=allows_none,
)
if scan_input["type"] in ("_empty", None):
scan_input["type"] = item_type
Expand All @@ -191,6 +200,7 @@ def _build_scan_input(
*,
arg: bool,
default: Any,
optional: bool,
) -> ScanInputConfig:
"""Build one normalized ScanControl input configuration.

Expand All @@ -211,6 +221,7 @@ def _build_scan_input(
"display_name": scan_argument.get("display_name") or self.format_display_name(name),
"tooltip": self.resolve_tooltip(scan_argument),
"default": default,
"optional": optional,
"expert": scan_argument.get("expert", False),
"hidden": scan_argument.get("hidden", False),
"precision": scan_argument.get("precision"),
Expand Down
38 changes: 38 additions & 0 deletions tests/unit_tests/test_scan_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,44 @@ def test_scan_info_adapter_skips_duplicate_visible_kwargs():
}


def test_scan_info_adapter_supports_optional_annotated_types():
scan_info = {
"class": "OptionalScan",
"base_class": "ScanBaseV4",
"arg_input": {},
"arg_bundle_size": {"bundle": 0, "min": None, "max": None},
"gui_visibility": {"Matching": ["atol"]},
"signature": [
{
"arg": False,
"name": "atol",
"annotation": {
"Annotated": {
"type": ["float", "NoneType"],
"metadata": {
"ScanArgument": {
"display_name": "Tolerance",
"tooltip": "Optional tolerance used for position matching",
}
},
}
},
"default": None,
"kind": "KEYWORD_ONLY",
}
],
}

gui_config = ScanInfoAdapter().build_scan_ui_config(scan_info)
input_spec = gui_config["kwarg_groups"][0]["inputs"][0]

assert input_spec["name"] == "atol"
assert input_spec["type"] == "float"
assert input_spec["optional"] is True
assert input_spec["default"] is None
assert input_spec["display_name"] == "Tolerance"


def test_scan_info_adapter_rejects_unsupported_visible_inputs():
scan_info = {
"class": "UnsupportedScan",
Expand Down
Loading
Loading