diff --git a/bec_widgets/widgets/control/scan_control/scan_group_box.py b/bec_widgets/widgets/control/scan_control/scan_group_box.py index 0ce8d8643..aec9d96d5 100644 --- a/bec_widgets/widgets/control/scan_control/scan_group_box.py +++ b/bec_widgets/widgets/control/scan_control/scan_group_box.py @@ -162,6 +162,47 @@ def __init__( self.setChecked(default) +class ScanOptionalWidget(QGroupBox): + def __init__(self, widget, parent=None): + super().__init__(parent=parent) + self.inner_widget = widget + self.arg_name = getattr(widget, "arg_name", None) + self.setFlat(True) + + layout = QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(widget) + + self.none_checkbox = QCheckBox(self) + self.none_checkbox.setToolTip("Set this value to None.") + self.none_checkbox.toggled.connect(self._on_none_toggled) + layout.addWidget(self.none_checkbox) + + def _on_none_toggled(self, checked: bool) -> None: + self.inner_widget.setEnabled(not checked) + + def set_none(self, checked: bool) -> None: + self.none_checkbox.setChecked(checked) + + def is_none(self) -> bool: + return self.none_checkbox.isChecked() + + def setToolTip(self, text: str) -> None: # noqa: N802 + super().setToolTip(text) + self.inner_widget.setToolTip(text) + checkbox_tooltip = "Set this value to None." + if text: + checkbox_tooltip = f"{text}\n{checkbox_tooltip}" + self.none_checkbox.setToolTip(checkbox_tooltip) + + def toolTip(self) -> str: # noqa: N802 + return self.inner_widget.toolTip() + + def setSuffix(self, suffix: str) -> None: # noqa: N802 + if hasattr(self.inner_widget, "setSuffix"): + self.inner_widget.setSuffix(suffix) + + class ScanGroupBox(QGroupBox): WIDGET_HANDLER = { ScanArgType.DEVICE: DeviceComboBox, @@ -211,6 +252,7 @@ def __init__( self.labels = [] self.widgets = [] self._widget_configs = {} + self._wrapped_widgets = {} self._column_labels = {} self.selected_devices = {} @@ -297,10 +339,11 @@ def add_input_widgets(self, group_inputs: dict, row) -> None: ) if isinstance(widget, ScanLiteralsComboBox): widget.set_literals(item["type"].get("Literal", [])) - self._widget_configs[widget] = item - self._apply_unit_metadata(widget, item) - self.layout.addWidget(widget, row, column_index) - self.widgets.append(widget) + display_widget = self._wrap_optional_widget(widget, item, default) + self._widget_configs[display_widget] = item + self._apply_unit_metadata(display_widget, item) + self.layout.addWidget(display_widget, row, column_index) + self.widgets.append(display_widget) @Slot(str) def emit_device_selected(self, device_name): @@ -334,8 +377,10 @@ def remove_widget_bundle(self): return for widget in self.widgets[-len(self.inputs) :]: - if isinstance(widget, DeviceComboBox): - self.selected_devices[widget] = "" + inner_widget = self._inner_widget(widget) + if isinstance(inner_widget, DeviceComboBox): + self.selected_devices[inner_widget] = "" + self._wrapped_widgets.pop(inner_widget, None) self._widget_configs.pop(widget, None) widget.close() widget.deleteLater() @@ -347,8 +392,10 @@ def remove_widget_bundle(self): def remove_all_widget_bundles(self): """Remove every widget bundle from the scan control layout.""" for widget in list(self.widgets): - if isinstance(widget, DeviceComboBox): - self.selected_devices.pop(widget, None) + inner_widget = self._inner_widget(widget) + if isinstance(inner_widget, DeviceComboBox): + self.selected_devices.pop(inner_widget, None) + self._wrapped_widgets.pop(inner_widget, None) self._widget_configs.pop(widget, None) widget.close() widget.deleteLater() @@ -385,12 +432,7 @@ def _get_arg_parameters(self, device_object: bool = True): for j in range(self.layout.columnCount()): try: # In case that the bundle size changes widget = self.layout.itemAtPosition(i, j).widget() - if isinstance(widget, DeviceComboBox) and device_object: - value = widget.get_current_device() - elif isinstance(widget, DeviceComboBox): - value = widget.currentText() - else: - value = WidgetIO.get_value(widget) + value = self._widget_value(widget, device_object=device_object) args.append(value) except AttributeError: continue @@ -400,27 +442,23 @@ def _get_kwarg_parameters(self, device_object: bool = True): kwargs = {} for i in range(self.layout.columnCount()): widget = self.layout.itemAtPosition(1, i).widget() - if isinstance(widget, DeviceComboBox) and device_object: - value = widget.get_current_device().name - elif isinstance(widget, DeviceComboBox): - value = widget.currentText() - elif isinstance(widget, ScanLiteralsComboBox): - value = widget.get_value() - else: - value = WidgetIO.get_value(widget) + value = self._widget_value(widget, device_object=device_object) + inner_widget = self._inner_widget(widget) + if isinstance(inner_widget, DeviceComboBox) and value is not None and device_object: + value = value.name kwargs[widget.arg_name] = value return kwargs def count_arg_rows(self): widget_rows = 0 for row in range(self.layout.rowCount()): + if row == 0: + continue for col in range(self.layout.columnCount()): item = self.layout.itemAtPosition(row, col) - if item is not None: - widget = item.widget() - if widget is not None: - if isinstance(widget, DeviceComboBox): - widget_rows += 1 + if item is not None and item.widget() is not None: + widget_rows += 1 + break return widget_rows def set_parameters(self, parameters: list | dict): @@ -444,13 +482,13 @@ def _set_arg_parameters(self, parameters: list): self.add_input_widgets(self.inputs, row) for i, value in enumerate(parameters): - WidgetIO.set_value(self.widgets[i], value) + self._set_widget_value(self.widgets[i], value) def _set_kwarg_parameters(self, parameters: dict): for widget in self.widgets: for key, value in parameters.items(): if widget.arg_name == key: - WidgetIO.set_value(widget, value) + self._set_widget_value(widget, value) break @staticmethod @@ -505,6 +543,7 @@ def _device_units(device) -> str | None: return None def _widget_position(self, widget) -> tuple[int, int] | None: + widget = self._display_widget(widget) for row in range(self.layout.rowCount()): for column in range(self.layout.columnCount()): item = self.layout.itemAtPosition(row, column) @@ -608,3 +647,43 @@ def _apply_numeric_limits(widget: ScanDoubleSpinBox | ScanSpinBox, item: dict) - if item.get("lt") is not None: maximum = float(item["lt"]) - step widget.setRange(minimum, maximum) + + def _wrap_optional_widget(self, widget, item: dict, default): + if not item.get("optional", False): + return widget + + wrapped_widget = ScanOptionalWidget(widget, parent=self) + wrapped_widget.set_none(default is None) + self._wrapped_widgets[widget] = wrapped_widget + return wrapped_widget + + @staticmethod + def _inner_widget(widget): + if isinstance(widget, ScanOptionalWidget): + return widget.inner_widget + return widget + + def _display_widget(self, widget): + return self._wrapped_widgets.get(widget, widget) + + def _widget_value(self, widget, *, device_object: bool = True): + if isinstance(widget, ScanOptionalWidget) and widget.is_none(): + return None + + inner_widget = self._inner_widget(widget) + if isinstance(inner_widget, DeviceComboBox) and device_object: + return inner_widget.get_current_device() + if isinstance(inner_widget, DeviceComboBox): + return inner_widget.currentText() + if isinstance(inner_widget, ScanLiteralsComboBox): + return inner_widget.get_value() + return WidgetIO.get_value(inner_widget) + + def _set_widget_value(self, widget, value) -> None: + if isinstance(widget, ScanOptionalWidget): + widget.set_none(value is None) + if value is None: + return + WidgetIO.set_value(widget.inner_widget, value) + return + WidgetIO.set_value(widget, value) diff --git a/bec_widgets/widgets/control/scan_control/scan_info_adapter.py b/bec_widgets/widgets/control/scan_control/scan_info_adapter.py index ba6d7a529..902da6629 100644 --- a/bec_widgets/widgets/control/scan_control/scan_info_adapter.py +++ b/bec_widgets/widgets/control/scan_control/scan_info_adapter.py @@ -92,27 +92,34 @@ def resolve_tooltip(scan_argument: ScanArgumentMetadata) -> str | None: @staticmethod def parse_annotation( annotation: AnnotationValue, - ) -> tuple[AnnotationValue, ScanArgumentMetadata]: + ) -> tuple[AnnotationValue, ScanArgumentMetadata, bool]: """Extract the serialized base annotation and ``ScanArgument`` metadata. Args: annotation (AnnotationValue): Serialized annotation payload from BEC. Returns: - tuple[AnnotationValue, ScanArgumentMetadata]: The unwrapped annotation and parsed - ``ScanArgument`` metadata. + tuple[AnnotationValue, ScanArgumentMetadata, bool]: The unwrapped annotation, + parsed ``ScanArgument`` metadata, and whether ``None`` is an allowed value. """ scan_argument: ScanArgumentMetadata = {} + if isinstance(annotation, dict) and "Annotated" in annotation: + annotated = annotation["Annotated"] + annotation = annotated.get("type", "_empty") + scan_argument = annotated.get("metadata", {}).get("ScanArgument", {}) or {} + + allows_none = False if isinstance(annotation, list): + allows_none = "NoneType" in annotation annotation = next( (entry for entry in annotation if entry != "NoneType"), annotation[0] if annotation else "_empty", ) - if isinstance(annotation, dict) and "Annotated" in annotation: - annotated = annotation["Annotated"] - annotation = annotated.get("type", "_empty") - scan_argument = annotated.get("metadata", {}).get("ScanArgument", {}) or {} - return annotation, scan_argument + elif annotation == "NoneType": + allows_none = True + annotation = "_empty" + + return annotation, scan_argument, allows_none @staticmethod def scan_arg_type_from_annotation(annotation: AnnotationValue) -> AnnotationValue: @@ -142,13 +149,14 @@ def scan_input_from_signature( Returns: ScanInputConfig: Normalized input configuration for ``ScanControl``. """ - annotation, scan_argument = self.parse_annotation(param.get("annotation")) + annotation, scan_argument, allows_none = self.parse_annotation(param.get("annotation")) return self._build_scan_input( name=param["name"], annotation=annotation, scan_argument=scan_argument, arg=arg, default=None if arg else param.get("default", None), + optional=allows_none, ) def scan_input_from_arg_input( @@ -171,13 +179,14 @@ def scan_input_from_arg_input( self.parse_annotation(signature_by_name[name].get("annotation"))[0] ) else: - annotation, scan_argument = self.parse_annotation(item_type) + annotation, scan_argument, allows_none = self.parse_annotation(item_type) scan_input = self._build_scan_input( name=name, annotation=annotation, scan_argument=scan_argument, arg=True, default=None, + optional=allows_none, ) if scan_input["type"] in ("_empty", None): scan_input["type"] = item_type @@ -191,6 +200,7 @@ def _build_scan_input( *, arg: bool, default: Any, + optional: bool, ) -> ScanInputConfig: """Build one normalized ScanControl input configuration. @@ -211,6 +221,7 @@ def _build_scan_input( "display_name": scan_argument.get("display_name") or self.format_display_name(name), "tooltip": self.resolve_tooltip(scan_argument), "default": default, + "optional": optional, "expert": scan_argument.get("expert", False), "hidden": scan_argument.get("hidden", False), "precision": scan_argument.get("precision"), diff --git a/tests/unit_tests/test_scan_control.py b/tests/unit_tests/test_scan_control.py index b6da2f962..c09f25dd7 100644 --- a/tests/unit_tests/test_scan_control.py +++ b/tests/unit_tests/test_scan_control.py @@ -442,6 +442,44 @@ def test_scan_info_adapter_skips_duplicate_visible_kwargs(): } +def test_scan_info_adapter_supports_optional_annotated_types(): + scan_info = { + "class": "OptionalScan", + "base_class": "ScanBaseV4", + "arg_input": {}, + "arg_bundle_size": {"bundle": 0, "min": None, "max": None}, + "gui_visibility": {"Matching": ["atol"]}, + "signature": [ + { + "arg": False, + "name": "atol", + "annotation": { + "Annotated": { + "type": ["float", "NoneType"], + "metadata": { + "ScanArgument": { + "display_name": "Tolerance", + "tooltip": "Optional tolerance used for position matching", + } + }, + } + }, + "default": None, + "kind": "KEYWORD_ONLY", + } + ], + } + + gui_config = ScanInfoAdapter().build_scan_ui_config(scan_info) + input_spec = gui_config["kwarg_groups"][0]["inputs"][0] + + assert input_spec["name"] == "atol" + assert input_spec["type"] == "float" + assert input_spec["optional"] is True + assert input_spec["default"] is None + assert input_spec["display_name"] == "Tolerance" + + def test_scan_info_adapter_rejects_unsupported_visible_inputs(): scan_info = { "class": "UnsupportedScan", diff --git a/tests/unit_tests/test_scan_control_group_box.py b/tests/unit_tests/test_scan_control_group_box.py index e47fb3b09..4204b0d71 100644 --- a/tests/unit_tests/test_scan_control_group_box.py +++ b/tests/unit_tests/test_scan_control_group_box.py @@ -1,7 +1,7 @@ # pylint: disable = no-name-in-module,missing-class-docstring, missing-module-docstring from bec_widgets.utils.widget_io import WidgetIO -from bec_widgets.widgets.control.scan_control.scan_group_box import ScanGroupBox +from bec_widgets.widgets.control.scan_control.scan_group_box import ScanGroupBox, ScanOptionalWidget def test_kwarg_box(qtbot): @@ -235,3 +235,41 @@ def test_spinbox_limits_from_scan_info(qtbot): assert settling_time.maximum() == 3.5 assert steps.minimum() == 1 assert steps.maximum() == 10 + + +def test_optional_kwarg_widget_round_trips_none(qtbot): + group_input = { + "name": "Kwarg Test", + "inputs": [ + { + "arg": False, + "name": "atol", + "type": "float", + "display_name": "Tolerance", + "tooltip": "Optional tolerance used for position matching", + "default": None, + "optional": True, + "expert": False, + } + ], + } + + kwarg_box = ScanGroupBox(box_type="kwargs", config=group_input) + + assert isinstance(kwarg_box.widgets[0], ScanOptionalWidget) + assert kwarg_box.widgets[0].none_checkbox.text() == "" + assert kwarg_box.widgets[0].is_none() is True + assert kwarg_box.widgets[0].inner_widget.isEnabled() is False + assert kwarg_box.get_parameters() == {"atol": None} + + kwarg_box.set_parameters({"atol": 1.25}) + + assert kwarg_box.widgets[0].is_none() is False + assert kwarg_box.widgets[0].inner_widget.isEnabled() is True + assert WidgetIO.get_value(kwarg_box.widgets[0].inner_widget) == 1.25 + assert kwarg_box.get_parameters() == {"atol": 1.25} + + kwarg_box.set_parameters({"atol": None}) + + assert kwarg_box.widgets[0].is_none() is True + assert kwarg_box.get_parameters() == {"atol": None}