From 5cc24509e1166744b7d68983c1ec171417b7e70f Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Sat, 9 May 2026 20:56:50 -0700 Subject: [PATCH 01/21] feat(types): add labels to capabilities, functions, and events Add an optional labels: dict[str, str | list[str]] field on DeviceCapabilities, FunctionDef, and EventDef. Drivers populate them either via class-level DeviceDriver.labels = {...} (device metadata) or @rpc(labels=...) / @emit(labels=...) decorator kwargs. List values express composite identity (a device that is both camera and inference). These labels are the foundation for selector-based discovery and operations: the discover/invoke/broadcast tools filter on them. --- .../device_connect_edge/drivers/base.py | 19 +++- .../device_connect_edge/drivers/decorators.py | 23 ++++- .../device_connect_edge/types.py | 25 +++++- .../device-connect-edge/tests/test_drivers.py | 86 +++++++++++++++++++ .../device-connect-edge/tests/test_types.py | 64 ++++++++++++++ 5 files changed, 211 insertions(+), 6 deletions(-) diff --git a/packages/device-connect-edge/device_connect_edge/drivers/base.py b/packages/device-connect-edge/device_connect_edge/drivers/base.py index 2f5228b..6fc013d 100644 --- a/packages/device-connect-edge/device_connect_edge/drivers/base.py +++ b/packages/device-connect-edge/device_connect_edge/drivers/base.py @@ -70,7 +70,7 @@ async def disconnect(self) -> None: import logging import time from abc import ABC -from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING from device_connect_edge.types import ( FunctionDef, @@ -129,6 +129,14 @@ class DeviceDriver(ABC): # starting background tasks. Example: depends_on = ("robot", "speaker") depends_on: Tuple[str, ...] = () + # Override in subclasses to attach discovery metadata to the device. Carried on + # DeviceCapabilities. Values may be a single string or a list of strings (composite + # identity). Well-known keys: category (camera|robot|hub|sensor|actuator|inference), + # location (e.g. 'warehouse1/loading-dock'). Custom keys are allowed. + # Example: + # labels = {"category": ["camera", "inference"], "location": "warehouse1/dock-3"} + labels: Optional[Dict[str, Union[str, List[str]]]] = None + # Type alias for event callback EventCallback = Callable[[str, Dict[str, Any]], Any] @@ -249,7 +257,8 @@ def capabilities(self) -> DeviceCapabilities: return DeviceCapabilities( description=self.__class__.__doc__ or "", functions=self.functions, - events=self.events + events=self.events, + labels=self.labels, ) @property @@ -347,7 +356,7 @@ async def invoke(self, function_name: str, **params: Any) -> Any: # Properties to skip during attribute scanning to avoid recursion _SKIP_ATTRS = frozenset([ "capabilities", "functions", "events", "identity", "status", - "device_type" + "device_type", "labels" ]) def _collect_functions(self) -> List[FunctionDef]: @@ -379,11 +388,13 @@ def _collect_functions(self) -> List[FunctionDef]: func_name = getattr(attr, "_function_name", attr_name) description = getattr(attr, "_description", "") parameters = build_function_schema(attr) + labels = getattr(attr, "_labels", None) functions.append(FunctionDef( name=func_name, description=description, parameters=parameters, + labels=labels, tags=[] )) @@ -418,11 +429,13 @@ def _collect_events(self) -> List[EventDef]: event_name = getattr(attr, "_event_name", attr_name) description = getattr(attr, "_event_description", "") payload_schema = build_event_schema(attr) + labels = getattr(attr, "_labels", None) events.append(EventDef( name=event_name, description=description, payload_schema=payload_schema, + labels=labels, tags=[] )) diff --git a/packages/device-connect-edge/device_connect_edge/drivers/decorators.py b/packages/device-connect-edge/device_connect_edge/drivers/decorators.py index b59a3a5..4237699 100644 --- a/packages/device-connect-edge/device_connect_edge/drivers/decorators.py +++ b/packages/device-connect-edge/device_connect_edge/drivers/decorators.py @@ -57,7 +57,7 @@ async def detection_loop(self): import re import time import uuid -from typing import Any, Callable, Dict, Optional, get_type_hints, get_origin, get_args +from typing import Any, Callable, Dict, List, Optional, Union, get_type_hints, get_origin, get_args from device_connect_edge.telemetry.tracer import get_tracer, get_current_trace_id, SpanKind, StatusCode from device_connect_edge.telemetry.metrics import get_metrics @@ -345,6 +345,7 @@ def _get_integration_logger(obj: Any) -> Optional[Callable[[dict], None]]: def rpc( name: Optional[str] = None, description: Optional[str] = None, + labels: Optional[Dict[str, Union[str, List[str]]]] = None, ) -> Callable: """Decorator to expose a method as an RPC-callable function. @@ -355,6 +356,10 @@ def rpc( Args: name: Override function name (default: method __name__) description: Override description (default: first line of docstring) + labels: Discovery metadata as key:value pairs. Values may be a single + string or a list of strings (composite identity). Well-known keys: + direction (read|write), safety (critical|informational), modality + (rgb|thermal|...). Custom keys are allowed. Returns: Decorated method with function metadata attached @@ -372,6 +377,11 @@ async def my_function(self, param: str = "default") -> dict: @rpc(name="customName", description="Custom description") async def another_function(self, x: int) -> dict: return {"x": x} + + @rpc(labels={"direction": "write", "modality": ["rgb", "4k"]}) + async def capture_frame(self, resolution: str = "1080p") -> dict: + '''Capture a frame.''' + return {} """ def decorator(func: Callable) -> Callable: func_name = name or func.__name__ @@ -380,6 +390,7 @@ def decorator(func: Callable) -> Callable: summary, arg_docs = _parse_docstring(func.__doc__) func._description = description or summary func._arg_descriptions = arg_docs + func._labels = labels @functools.wraps(func) async def wrapper(self, *args, **kwargs): @@ -499,6 +510,7 @@ async def wrapper(self, *args, **kwargs): wrapper._function_name = func_name wrapper._description = func._description wrapper._arg_descriptions = func._arg_descriptions + wrapper._labels = func._labels wrapper._original_func = func # For schema extraction return wrapper @@ -508,7 +520,8 @@ async def wrapper(self, *args, **kwargs): def emit( name: Optional[str] = None, - description: Optional[str] = None + description: Optional[str] = None, + labels: Optional[Dict[str, Union[str, List[str]]]] = None, ) -> Callable: """Decorator to declare an event this driver can emit. @@ -524,6 +537,10 @@ def emit( Args: name: Override event name (default: method __name__) description: Event description (default: first line of docstring) + labels: Discovery metadata as key:value pairs. Values may be a single + string or a list of strings (composite identity). Well-known keys: + safety (critical|informational), modality (rgb|thermal|motion|...). + Custom keys are allowed. Returns: Decorated method that emits event when called @@ -550,6 +567,7 @@ def decorator(func: Callable) -> Callable: summary, arg_docs = _parse_docstring(func.__doc__) func._event_description = description or summary func._payload_descriptions = arg_docs + func._labels = labels @functools.wraps(func) async def wrapper(self, *args, **kwargs): @@ -624,6 +642,7 @@ async def wrapper(self, *args, **kwargs): wrapper._event_name = event_name wrapper._event_description = func._event_description wrapper._payload_descriptions = func._payload_descriptions + wrapper._labels = func._labels wrapper._original_func = func # For schema extraction return wrapper diff --git a/packages/device-connect-edge/device_connect_edge/types.py b/packages/device-connect-edge/device_connect_edge/types.py index 00296bb..5440708 100644 --- a/packages/device-connect-edge/device_connect_edge/types.py +++ b/packages/device-connect-edge/device_connect_edge/types.py @@ -12,7 +12,7 @@ from datetime import datetime from enum import Enum -from typing import Any, Awaitable, Callable, Dict, List, Optional +from typing import Any, Awaitable, Callable, Dict, List, Optional, Union from pydantic import BaseModel, Field @@ -51,6 +51,7 @@ class FunctionDef(BaseModel): }, "required": [] }, + labels={"direction": "write", "modality": ["rgb", "4k"]}, tags=["vision", "capture"] ) """ @@ -60,6 +61,13 @@ class FunctionDef(BaseModel): default_factory=lambda: {"type": "object", "properties": {}, "required": []}, description="JSON Schema for function parameters" ) + labels: Optional[Dict[str, Union[str, List[str]]]] = Field( + default=None, + description="Discovery metadata as key:value pairs. Values may be a single string " + "or a list of strings (composite identity). Well-known keys: direction " + "(read|write), safety (critical|informational), modality (rgb|thermal|...). " + "Custom keys are allowed." + ) tags: List[str] = Field( default_factory=list, description="Tags for categorization (e.g., ['vision', 'capture'])" @@ -92,6 +100,13 @@ class EventDef(BaseModel): default=None, description="JSON Schema for event payload (optional)" ) + labels: Optional[Dict[str, Union[str, List[str]]]] = Field( + default=None, + description="Discovery metadata as key:value pairs. Values may be a single string " + "or a list of strings (composite identity). Well-known keys: safety " + "(critical|informational), modality (rgb|thermal|motion|...). Custom keys " + "are allowed." + ) tags: List[str] = Field( default_factory=list, description="Tags for categorization" @@ -125,6 +140,14 @@ class DeviceCapabilities(BaseModel): default_factory=list, description="Events the device can emit" ) + labels: Optional[Dict[str, Union[str, List[str]]]] = Field( + default=None, + description="Discovery metadata for the device as key:value pairs. Values may be a " + "single string or a list of strings (composite identity). Well-known keys: " + "category (camera|robot|hub|sensor|actuator|inference; multi-valued for " + "composite devices), location (e.g. 'warehouse1/loading-dock'; '/' for " + "hierarchy, multi-valued for mobile devices). Custom keys are allowed." + ) class DeviceIdentity(BaseModel): diff --git a/packages/device-connect-edge/tests/test_drivers.py b/packages/device-connect-edge/tests/test_drivers.py index 9b1fb65..5763be6 100644 --- a/packages/device-connect-edge/tests/test_drivers.py +++ b/packages/device-connect-edge/tests/test_drivers.py @@ -177,6 +177,92 @@ async def test_rpc_callable(self): result = await driver.do_something(value=5) assert result == {"result": 10} + +# -- Discovery labels (Phase 1) ------------------------------------ + +class TestRpcLabels: + def test_default_none(self): + @rpc() + async def f(self) -> dict: + """f.""" + return {} + + assert f._labels is None + + def test_explicit_labels(self): + @rpc(labels={"direction": "write", "modality": ["rgb", "4k"]}) + async def capture(self, resolution: str = "1080p") -> dict: + """Capture.""" + return {} + + assert capture._labels == {"direction": "write", "modality": ["rgb", "4k"]} + + +class TestEmitLabels: + def test_default_none(self): + @emit() + async def heartbeat(self): + """heartbeat.""" + pass + + assert heartbeat._labels is None + + def test_explicit_labels(self): + @emit(labels={"modality": "motion", "safety": "informational"}) + async def motion_detected(self, zone: str): + """Motion.""" + pass + + assert motion_detected._labels == {"modality": "motion", "safety": "informational"} + + +class LabeledDriver(DeviceDriver): + """Driver with class-level labels and per-method labels.""" + device_type = "camera" + labels = { + "category": ["camera", "inference"], + "location": "warehouse1/loading-dock", + } + + @rpc(labels={"direction": "write", "modality": ["rgb", "4k"]}) + async def capture_frame(self, resolution: str = "1080p") -> dict: + """Capture a frame.""" + return {} + + @rpc() + async def ping(self) -> dict: + """Ping.""" + return {} + + @emit(labels={"modality": "motion", "safety": "informational"}) + async def motion_detected(self, zone: str, confidence: float): + """Motion in zone.""" + pass + + +class TestDriverLabels: + def test_class_level_labels_on_capabilities(self): + caps = LabeledDriver().capabilities + assert caps.labels == { + "category": ["camera", "inference"], + "location": "warehouse1/loading-dock", + } + + def test_function_labels_propagated(self): + caps = LabeledDriver().capabilities + fns = {f.name: f for f in caps.functions} + assert fns["capture_frame"].labels == {"direction": "write", "modality": ["rgb", "4k"]} + assert fns["ping"].labels is None + + def test_event_labels_propagated(self): + caps = LabeledDriver().capabilities + evs = {e.name: e for e in caps.events} + assert evs["motion_detected"].labels == {"modality": "motion", "safety": "informational"} + + def test_no_class_labels_defaults_to_none(self): + # SampleDriver above does NOT define `labels` -- inherits None from DeviceDriver + assert SampleDriver().capabilities.labels is None + def test_capabilities_detected(self): """Driver should have functions and events detectable via introspection.""" driver = SampleDriver() diff --git a/packages/device-connect-edge/tests/test_types.py b/packages/device-connect-edge/tests/test_types.py index 0d580fe..bb34151 100644 --- a/packages/device-connect-edge/tests/test_types.py +++ b/packages/device-connect-edge/tests/test_types.py @@ -10,6 +10,7 @@ DeviceStatus, FunctionDef, EventDef, + DeviceCapabilities, ) @@ -75,3 +76,66 @@ def test_create(self): parameters={"type": "object", "properties": {"zone": {"type": "string"}}}, ) assert event.name == "motion_detected" + + +class TestLabels: + """Discovery labels on FunctionDef, EventDef, DeviceCapabilities (Phase 1).""" + + def test_function_labels_default_none(self): + f = FunctionDef(name="ping") + assert f.labels is None + + def test_function_single_value_label(self): + f = FunctionDef(name="get_status", labels={"direction": "read"}) + assert f.labels == {"direction": "read"} + + def test_function_multivalued_label(self): + f = FunctionDef(name="capture", labels={"modality": ["rgb", "4k"]}) + assert f.labels == {"modality": ["rgb", "4k"]} + + def test_function_labels_roundtrip(self): + f = FunctionDef( + name="set_threshold", + labels={"direction": "write", "modality": ["rgb", "4k"], "safety": "critical"}, + ) + f2 = FunctionDef.model_validate_json(f.model_dump_json()) + assert f2.labels == f.labels + + def test_event_labels_default_none(self): + e = EventDef(name="heartbeat") + assert e.labels is None + + def test_event_labels_roundtrip(self): + e = EventDef( + name="motion_detected", + labels={"modality": "motion", "safety": "informational"}, + ) + e2 = EventDef.model_validate_json(e.model_dump_json()) + assert e2.labels == e.labels + + def test_capabilities_labels_default_none(self): + c = DeviceCapabilities() + assert c.labels is None + + def test_capabilities_labels_composite_identity(self): + # category multi-valued for composite devices (camera + inference) + c = DeviceCapabilities( + labels={ + "category": ["camera", "inference"], + "location": "warehouse1/loading-dock", + } + ) + assert c.labels["category"] == ["camera", "inference"] + assert c.labels["location"] == "warehouse1/loading-dock" + + def test_capabilities_labels_roundtrip(self): + c = DeviceCapabilities( + description="Smart cam", + functions=[FunctionDef(name="capture", labels={"direction": "write"})], + events=[EventDef(name="motion", labels={"modality": "motion"})], + labels={"category": ["camera"], "location": "warehouse1/dock-3"}, + ) + c2 = DeviceCapabilities.model_validate_json(c.model_dump_json()) + assert c2.labels == c.labels + assert c2.functions[0].labels == {"direction": "write"} + assert c2.events[0].labels == {"modality": "motion"} From aedf561ae6aa07689d267955e036483c7235969f Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Sat, 9 May 2026 20:57:50 -0700 Subject: [PATCH 02/21] feat(selector): add selector DSL parser and matcher Add a pure-Python parser at device_connect_edge.selector that maps a structured selector string onto a parsed Selector dataclass with five scope shapes: device() device().function() device().event() function() event() Inside (...): key:value, key:[v1,v2] (OR within key), key:pattern* (anchored glob), k1:v1,k2:v2 (AND across keys), bare-string id/name match, or * to match all. Parse errors carry source + caret position for diagnostics. The matcher is dependency-free (stdlib only) and applies vacuous-True semantics on unset axes so callers can iterate without scope branching. --- .../device_connect_edge/selector.py | 467 ++++++++++++++++++ .../tests/test_selector.py | 348 +++++++++++++ 2 files changed, 815 insertions(+) create mode 100644 packages/device-connect-edge/device_connect_edge/selector.py create mode 100644 packages/device-connect-edge/tests/test_selector.py diff --git a/packages/device-connect-edge/device_connect_edge/selector.py b/packages/device-connect-edge/device_connect_edge/selector.py new file mode 100644 index 0000000..f2218e3 --- /dev/null +++ b/packages/device-connect-edge/device_connect_edge/selector.py @@ -0,0 +1,467 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Selector DSL for hierarchical device + function discovery. + +This module parses selector expressions used by the discovery, invocation, and +subscription APIs into a structured form that can be matched against device, +function, and event records. + +Placement note: this module is dependency-free (stdlib only) and is consumed +by callers outside this package -- notably the discovery tools in +``device_connect_agent_tools``. It lives here as the lowest common ancestor +in the package dependency graph, not as edge-runtime code; ``DeviceRuntime`` +and the driver framework do not import it. + +Grammar overview: + + device() # filter on device labels + device().function() # functions on a device subset (RPCs) + device().event() # events on a device subset + function() # all RPCs across the fleet + event() # all events across the fleet + +Inside ``(...)``: + + key:value single value match + key:[v1,v2] OR within a key (matches if label contains any value) + key:pattern* glob (``*``, ``?``) + k1:v1,k2:v2 AND across keys + bare-string id/name match: ``device(robot-001)`` + * match all +""" +from __future__ import annotations + +import fnmatch +from dataclasses import dataclass, field +from enum import Enum +from typing import Dict, List, Optional, Tuple, Union + +# A label value is either a single string or a list of strings (composite identity). +LabelValue = Union[str, List[str]] +Labels = Dict[str, LabelValue] + + +class SelectorParseError(ValueError): + """Raised when a selector string cannot be parsed.""" + + def __init__(self, message: str, source: str = "", position: Optional[int] = None): + if position is not None and source: + caret = " " * position + "^" + full = f"{message} at position {position}\n {source}\n {caret}" + elif source: + full = f"{message}: {source!r}" + else: + full = message + super().__init__(full) + self.source = source + self.position = position + + +class Scope(str, Enum): + """Which entities a selector matches. + + DEVICE_ONLY - device(...) + DEVICE_FUNCTION - device(...).function(...) + DEVICE_EVENT - device(...).event(...) + FUNCTION_ONLY - function(...) + EVENT_ONLY - event(...) + """ + DEVICE_ONLY = "device_only" + DEVICE_FUNCTION = "device_function" + DEVICE_EVENT = "device_event" + FUNCTION_ONLY = "function_only" + EVENT_ONLY = "event_only" + + +@dataclass(frozen=True) +class KeyFilter: + """Filter on a single label key. + + Values are OR'd: any matching value is sufficient. Each value may contain + glob characters (``*`` and ``?``) per ``fnmatch`` semantics. + + ``children`` is reserved for grammar extensions (nested boolean + expressions, AND-within-key, negation) and is empty in the current + parser. Carrying the field on the dataclass now lets future versions + populate it without breaking the public type shape. + """ + key: str + values: Tuple[str, ...] + children: Tuple["KeyFilter", ...] = field(default_factory=tuple) + + def matches(self, label_value: Optional[LabelValue]) -> bool: + """True iff the label value satisfies this key filter. + + For multi-valued labels (list), passes if any element matches any of + this filter's values. + """ + if label_value is None: + return False + actual: Tuple[str, ...] + if isinstance(label_value, list): + actual = tuple(label_value) + else: + actual = (label_value,) + for pattern in self.values: + if "*" in pattern or "?" in pattern: + for a in actual: + if fnmatch.fnmatchcase(a, pattern): + return True + else: + if pattern in actual: + return True + return False + + +@dataclass(frozen=True) +class Filter: + """One axis of a selector - matches a single entity (device, function, or event). + + Combines an optional bare-string name match with AND-across-keys label + filters. An empty Filter (no name match, no key filters) matches every + entity, so ``*`` and empty parens both reduce to that case. + """ + name_match: Optional[str] = None + key_filters: Tuple[KeyFilter, ...] = field(default_factory=tuple) + + def matches(self, name: str, labels: Optional[Labels]) -> bool: + """True iff this filter matches the given entity.""" + if self.name_match is not None: + pattern = self.name_match + if "*" in pattern or "?" in pattern: + if not fnmatch.fnmatchcase(name, pattern): + return False + elif name != pattern: + return False + for kf in self.key_filters: + label_value = labels.get(kf.key) if labels else None + if not kf.matches(label_value): + return False + return True + + +@dataclass(frozen=True) +class Selector: + """Parsed selector expression. + + Each axis is an optional :class:`Filter`. A ``None`` axis is vacuously + True - ``matches_function`` on a device-only selector returns True so the + caller can write a single-pass enumeration without scope branching. + """ + scope: Scope + device: Optional[Filter] = None + function: Optional[Filter] = None + event: Optional[Filter] = None + raw: str = "" + + def matches_device(self, name: str, labels: Optional[Labels]) -> bool: + if self.device is None: + return True + return self.device.matches(name, labels) + + def matches_function(self, name: str, labels: Optional[Labels]) -> bool: + if self.function is None: + return True + return self.function.matches(name, labels) + + def matches_event(self, name: str, labels: Optional[Labels]) -> bool: + if self.event is None: + return True + return self.event.matches(name, labels) + + +# -- Parsing ------------------------------------------------------- + + +def _split_top_commas(body: str, source: str, base_offset: int) -> List[Tuple[str, int]]: + """Split a filter body on top-level commas. + + Respects ``[...]`` bracket nesting: commas inside brackets are part of the + value list, not term separators. Returns ``(term, abs_offset_of_term_start)`` + pairs to support precise error positioning. + """ + terms: List[Tuple[str, int]] = [] + depth = 0 + start = 0 + for i, ch in enumerate(body): + if ch == "[": + depth += 1 + elif ch == "]": + if depth == 0: + raise SelectorParseError( + "Unmatched ']'", source=source, position=base_offset + i + ) + depth -= 1 + elif ch == "," and depth == 0: + terms.append((body[start:i], base_offset + start)) + start = i + 1 + if depth != 0: + raise SelectorParseError( + "Unmatched '['", source=source, position=base_offset + body.rfind("[") + ) + terms.append((body[start:], base_offset + start)) + return terms + + +def _parse_value_part(value: str, source: str, base_offset: int) -> Tuple[str, ...]: + """Parse the right-hand side of ``key:``. + + Returns a tuple of value strings (one element for single value, multiple for + bracketed OR list). Each value may contain glob characters. + """ + value = value.strip() + if not value: + raise SelectorParseError( + "Empty value after ':'", source=source, position=base_offset + ) + if value.startswith("["): + if not value.endswith("]"): + raise SelectorParseError( + "Unclosed '['", source=source, position=base_offset + ) + inner = value[1:-1].strip() + if not inner: + raise SelectorParseError( + "Empty value list '[]'", source=source, position=base_offset + ) + # Bracket bodies are flat (Phase 2 grammar); split on commas, strip, reject empties + out: List[str] = [] + for raw in inner.split(","): + v = raw.strip() + if not v: + raise SelectorParseError( + "Empty value in list", source=source, position=base_offset + ) + if "[" in v or "]" in v: + raise SelectorParseError( + "Nested brackets are not supported in this DSL version", + source=source, + position=base_offset, + ) + out.append(v) + return tuple(out) + if "[" in value or "]" in value: + raise SelectorParseError( + "Stray bracket in value", source=source, position=base_offset + ) + return (value,) + + +_KEY_PATTERN = ("0123456789" + "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "_-.") + + +def _is_valid_key(key: str) -> bool: + """Label keys are conservative identifiers: alnum, '_', '-', '.'.""" + return bool(key) and all(c in _KEY_PATTERN for c in key) + + +def _parse_filter_body(body: str, source: str, base_offset: int) -> Filter: + """Parse the contents of one ``(...)`` block into a :class:`Filter`. + + Supports: + ``*`` or empty body -> match-all (empty Filter) + ``key:value`` -> single-value key filter + ``key:[v1,v2]`` -> OR within a key + ``key:pattern*`` -> glob value + ``k1:v1,k2:v2`` -> AND across keys + bare string -> name match (id/name) + bare + key:value -> name AND key constraints + """ + stripped = body.strip() + if not stripped or stripped == "*": + return Filter() + + name_match: Optional[str] = None + key_filters: List[KeyFilter] = [] + + for term, term_offset in _split_top_commas(body, source, base_offset): + # Account for leading whitespace inside the term when reporting positions. + leading = len(term) - len(term.lstrip()) + term_stripped = term.strip() + term_abs = term_offset + leading + if not term_stripped: + raise SelectorParseError( + "Empty term (extra comma?)", source=source, position=term_abs + ) + + # Find a top-level ':' (one not inside the value brackets) to classify + # bare-name vs key:value. + colon_pos = -1 + depth = 0 + for j, ch in enumerate(term_stripped): + if ch == "[": + depth += 1 + elif ch == "]": + depth -= 1 + elif ch == ":" and depth == 0: + colon_pos = j + break + + if colon_pos < 0: + # Bare term: name match or '*' + if term_stripped == "*": + continue # vacuous, contributes nothing + if name_match is not None: + raise SelectorParseError( + f"Multiple bare-name terms ({name_match!r} and {term_stripped!r})", + source=source, + position=term_abs, + ) + name_match = term_stripped + continue + + key = term_stripped[:colon_pos].strip() + value_part = term_stripped[colon_pos + 1:] + value_offset = term_abs + colon_pos + 1 + if not _is_valid_key(key): + raise SelectorParseError( + f"Invalid key {key!r} (allowed: alphanumeric, '_', '-', '.')", + source=source, + position=term_abs, + ) + values = _parse_value_part(value_part, source, value_offset) + key_filters.append(KeyFilter(key=key, values=values)) + + return Filter(name_match=name_match, key_filters=tuple(key_filters)) + + +_VALID_SCOPES = ("device", "function", "event") + + +def _consume_scope(s: str, source: str, start: int) -> Tuple[str, Filter, int]: + """Consume one ``()`` from ``s`` starting at ``start``. + + Returns ``(scope_name, filter, position_after_closing_paren)``. Skips + leading whitespace. + """ + i = start + n = len(s) + while i < n and s[i].isspace(): + i += 1 + name_start = i + while i < n and s[i] not in "( \t": + i += 1 + name = s[name_start:i] + if not name: + raise SelectorParseError( + "Expected scope name (device|function|event)", source=source, position=name_start + ) + if name not in _VALID_SCOPES: + raise SelectorParseError( + f"Unknown scope {name!r} (expected one of {_VALID_SCOPES})", + source=source, + position=name_start, + ) + while i < n and s[i].isspace(): + i += 1 + if i >= n or s[i] != "(": + raise SelectorParseError( + f"Expected '(' after scope {name!r}", source=source, position=i + ) + body_start = i + 1 + # Find matching ')', tracking [...] nesting so a stray ')' inside brackets + # would not be treated as the scope close. (Reserved chars rule out ')' + # in valid values, but be defensive.) + depth = 0 + last_open_bracket = -1 + j = body_start + while j < n: + ch = s[j] + if ch == "[": + depth += 1 + last_open_bracket = j + elif ch == "]": + depth -= 1 + elif ch == ")" and depth == 0: + break + j += 1 + if j >= n: + if depth > 0: + raise SelectorParseError( + "Unclosed '['", source=source, position=last_open_bracket + ) + raise SelectorParseError( + f"Unclosed '(' for scope {name!r}", source=source, position=body_start - 1 + ) + body = s[body_start:j] + flt = _parse_filter_body(body, source=source, base_offset=body_start) + return name, flt, j + 1 + + +def parse_selector(s: str) -> Selector: + """Parse a selector string into a :class:`Selector`. + + Examples:: + + parse_selector("device(category:camera)") + parse_selector("device(category:[camera,robot], location:warehouse1/*)") + parse_selector("device(*).function(direction:write)") + parse_selector("function(safety:critical)") + + Raises :class:`SelectorParseError` on malformed input. + """ + if not isinstance(s, str): + raise SelectorParseError(f"Selector must be a string, got {type(s).__name__}") + raw = s + if not s.strip(): + raise SelectorParseError("Empty selector", source=raw, position=0) + + name1, filter1, after1 = _consume_scope(s, source=raw, start=0) + + # Optional ".scope(...)" extension + i = after1 + n = len(s) + while i < n and s[i].isspace(): + i += 1 + + if i >= n: + # Single-scope selector + if name1 == "device": + return Selector(scope=Scope.DEVICE_ONLY, device=filter1, raw=raw) + if name1 == "function": + return Selector(scope=Scope.FUNCTION_ONLY, function=filter1, raw=raw) + if name1 == "event": + return Selector(scope=Scope.EVENT_ONLY, event=filter1, raw=raw) + # _consume_scope already validated name1 + raise SelectorParseError(f"Internal: unhandled scope {name1!r}", source=raw) + + if s[i] != ".": + raise SelectorParseError( + f"Unexpected character {s[i]!r} after scope", source=raw, position=i + ) + + name2, filter2, after2 = _consume_scope(s, source=raw, start=i + 1) + + # Trailing content? + j = after2 + while j < n and s[j].isspace(): + j += 1 + if j < n: + raise SelectorParseError( + f"Unexpected trailing content {s[j:]!r}", source=raw, position=j + ) + + if name1 != "device": + raise SelectorParseError( + f"Chained scopes must start with 'device', got {name1!r}", + source=raw, + position=0, + ) + if name2 == "function": + return Selector( + scope=Scope.DEVICE_FUNCTION, device=filter1, function=filter2, raw=raw + ) + if name2 == "event": + return Selector( + scope=Scope.DEVICE_EVENT, device=filter1, event=filter2, raw=raw + ) + raise SelectorParseError( + f"Cannot chain device(...).{name2}(...); expected 'function' or 'event'", + source=raw, + position=i + 1, + ) diff --git a/packages/device-connect-edge/tests/test_selector.py b/packages/device-connect-edge/tests/test_selector.py new file mode 100644 index 0000000..d50a78e --- /dev/null +++ b/packages/device-connect-edge/tests/test_selector.py @@ -0,0 +1,348 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for the selector DSL parser and matcher. + +Parses selector strings like +``device(category:camera, location:warehouse1/*).function(direction:write)`` +into a structured Selector and matches it against label dicts. +""" +import pytest + +from device_connect_edge.selector import ( + Filter, + KeyFilter, + Scope, + Selector, + SelectorParseError, + parse_selector, +) + + +# -- KeyFilter ----------------------------------------------------- + + +class TestKeyFilter: + def test_single_value_str_label(self): + kf = KeyFilter("direction", ("write",)) + assert kf.matches("write") + assert not kf.matches("read") + + def test_none_label_never_matches(self): + assert not KeyFilter("direction", ("write",)).matches(None) + + def test_list_label_any_member_matches(self): + kf = KeyFilter("category", ("camera",)) + assert kf.matches(["camera", "inference"]) + assert not kf.matches(["robot", "inference"]) + + def test_or_within_key(self): + kf = KeyFilter("category", ("camera", "robot")) + assert kf.matches("camera") + assert kf.matches("robot") + assert not kf.matches("hub") + assert kf.matches(["camera", "inference"]) + + def test_glob_value(self): + kf = KeyFilter("location", ("warehouse1/*",)) + assert kf.matches("warehouse1/loading-dock") + assert kf.matches("warehouse1/yard") + assert not kf.matches("warehouse2/dock") + + def test_subtree_glob_matches_exact_and_descendants(self): + # ``lab-A*`` matches both the exact location and any descendants. + kf = KeyFilter("location", ("lab-A*",)) + assert kf.matches("lab-A") + assert kf.matches("lab-A/optics-bench") + assert not kf.matches("lab-B") + + +# -- Filter -------------------------------------------------------- + + +class TestFilter: + def test_empty_filter_matches_anything(self): + f = Filter() + assert f.matches("anything", None) + assert f.matches("foo", {"k": "v"}) + + def test_name_match_exact(self): + f = Filter(name_match="robot-001") + assert f.matches("robot-001", None) + assert not f.matches("robot-002", None) + + def test_name_match_glob(self): + f = Filter(name_match="set_*") + assert f.matches("set_threshold", {}) + assert f.matches("set_location", {}) + assert not f.matches("get_reading", {}) + + def test_and_across_keys(self): + f = Filter( + key_filters=( + KeyFilter("category", ("camera",)), + KeyFilter("location", ("warehouse1/*",)), + ) + ) + assert f.matches("cam1", {"category": "camera", "location": "warehouse1/dock"}) + assert not f.matches("cam1", {"category": "camera", "location": "warehouse2/dock"}) + assert not f.matches("cam1", {"category": "robot", "location": "warehouse1/dock"}) + + def test_name_and_label_combined(self): + f = Filter( + name_match="set_*", + key_filters=(KeyFilter("direction", ("write",)),), + ) + assert f.matches("set_threshold", {"direction": "write"}) + assert not f.matches("set_threshold", {"direction": "read"}) + assert not f.matches("get_reading", {"direction": "write"}) + + def test_missing_label_means_no_match(self): + f = Filter(key_filters=(KeyFilter("safety", ("critical",)),)) + assert not f.matches("foo", {}) + assert not f.matches("foo", None) + + +# -- Selector vacuous axes ----------------------------------------- + + +class TestSelectorVacuous: + """Unset axes return True so callers can iterate without scope branching.""" + + def test_device_only_function_vacuous(self): + s = Selector(scope=Scope.DEVICE_ONLY, device=Filter()) + assert s.matches_function("anything", {"direction": "write"}) + assert s.matches_event("anything", None) + + def test_function_only_device_vacuous(self): + s = Selector(scope=Scope.FUNCTION_ONLY, function=Filter()) + assert s.matches_device("any-id", None) + + +# -- parse_selector: scope shapes --------------------------------- + + +class TestParseScope: + def test_device_only(self): + s = parse_selector("device(category:camera)") + assert s.scope == Scope.DEVICE_ONLY + assert s.device == Filter(key_filters=(KeyFilter("category", ("camera",)),)) + assert s.function is None + assert s.event is None + + def test_function_only(self): + s = parse_selector("function(safety:critical)") + assert s.scope == Scope.FUNCTION_ONLY + assert s.function.key_filters == (KeyFilter("safety", ("critical",)),) + + def test_event_only(self): + s = parse_selector("event(modality:motion)") + assert s.scope == Scope.EVENT_ONLY + assert s.event.key_filters == (KeyFilter("modality", ("motion",)),) + + def test_device_function(self): + s = parse_selector("device(*).function(direction:write)") + assert s.scope == Scope.DEVICE_FUNCTION + assert s.device == Filter() + assert s.function.key_filters == (KeyFilter("direction", ("write",)),) + + def test_device_event(self): + s = parse_selector("device(*).event(modality:motion)") + assert s.scope == Scope.DEVICE_EVENT + + def test_bare_id_match(self): + s = parse_selector("device(robot-001)") + assert s.device.name_match == "robot-001" + + def test_function_name_match(self): + s = parse_selector("function(estop)") + assert s.function.name_match == "estop" + + def test_wildcard_matches_anything(self): + s = parse_selector("device(*)") + assert s.device == Filter() + + def test_raw_preserved(self): + sel = "device(category:camera)" + assert parse_selector(sel).raw == sel + + def test_whitespace_tolerated(self): + s = parse_selector( + " device( category : camera ) . function( direction : write ) " + ) + assert s.scope == Scope.DEVICE_FUNCTION + assert s.device.key_filters == (KeyFilter("category", ("camera",)),) + assert s.function.key_filters == (KeyFilter("direction", ("write",)),) + + +# -- parse_selector: filter body grammar --------------------------- + + +class TestParseFilterBody: + def test_or_within_key(self): + s = parse_selector("device(category:[camera,robot])") + assert s.device.key_filters == (KeyFilter("category", ("camera", "robot")),) + + def test_and_across_keys(self): + s = parse_selector("device(category:camera, location:warehouse1/*)") + assert s.device.key_filters == ( + KeyFilter("category", ("camera",)), + KeyFilter("location", ("warehouse1/*",)), + ) + + def test_combined_or_and_glob(self): + s = parse_selector("device(category:[camera,robot], location:warehouse1/*)") + assert s.device.key_filters == ( + KeyFilter("category", ("camera", "robot")), + KeyFilter("location", ("warehouse1/*",)), + ) + + def test_bare_name_plus_keys(self): + s = parse_selector("device(temperature_sensor).function(direction:write, set_*)") + assert s.device.name_match == "temperature_sensor" + assert s.function.name_match == "set_*" + assert s.function.key_filters == (KeyFilter("direction", ("write",)),) + + +# -- parse_selector: errors ---------------------------------------- + + +class TestParseErrors: + @pytest.mark.parametrize("bad,expected", [ + ("", "empty"), + (" ", "empty"), + ("device", "expected '('"), + ("device(", "unclosed"), + ("foo(x)", "unknown scope"), + ("function(*).device(*)", "must start with"), + ("device(*).device(*)", "expected 'function' or 'event'"), + ("device(*).function(*).event(*)", "unexpected trailing"), + ("device(*) extra", "unexpected character"), + ("device(robot-001, robot-002)", "multiple bare-name"), + ("device(key:)", "empty value"), + ("device(:value)", "invalid key"), + ("device(,)", "empty term"), + ("device(key:[)", "unclosed '['"), + ("device(key:[])", "empty value list"), + ("device(key:[a,])", "empty value in list"), + ("device(key:[[a]])", "nested"), + ("device(bad key:val)", "invalid key"), + ]) + def test_error_messages(self, bad, expected): + with pytest.raises(SelectorParseError) as exc: + parse_selector(bad) + assert expected.lower() in str(exc.value).lower() + + def test_non_string_input(self): + with pytest.raises(SelectorParseError): + parse_selector(123) # type: ignore[arg-type] + + def test_error_includes_position_caret(self): + with pytest.raises(SelectorParseError) as exc: + parse_selector("device(foo, bad key:v)") + msg = str(exc.value) + assert "device(foo, bad key:v)" in msg + assert "^" in msg + + +# -- Worked examples ----------------------------------------------- + + +class TestWorkedExamples: + """End-to-end parse + match using DC-native device kinds (camera, robot, + sensor) and the labels that drivers would carry.""" + + def test_all_cameras(self): + s = parse_selector("device(category:camera)") + assert s.matches_device("cam-001", {"category": "camera"}) + # composite identity: camera that also runs inference + assert s.matches_device("cam-002", {"category": ["camera", "inference"]}) + assert not s.matches_device("robot-001", {"category": "robot"}) + + def test_or_within_key_with_zone_filter(self): + # cameras or robots in zone-A + s = parse_selector("device(category:[camera,robot], location:zone-A/*)") + assert s.matches_device( + "cam-1", {"category": "camera", "location": "zone-A/loading-dock"} + ) + assert s.matches_device( + "robot-1", {"category": "robot", "location": "zone-A/yard"} + ) + assert not s.matches_device( + "hub-1", {"category": "hub", "location": "zone-A/dock"} + ) + assert not s.matches_device( + "cam-2", {"category": "camera", "location": "zone-B/dock"} + ) + + def test_zone_subtree(self): + # ``zone-A*`` glob matches both ``zone-A`` exactly and any descendant. + s = parse_selector("device(location:zone-A*)") + assert s.matches_device("d", {"location": "zone-A"}) + assert s.matches_device("d", {"location": "zone-A/dock"}) + assert not s.matches_device("d", {"location": "zone-B"}) + + def test_capture_writes_fleet_wide(self): + # ``capture_image`` is DC's canonical camera RPC. Filtering for write + # direction + rgb modality across the fleet picks it up. + s = parse_selector("device(*).function(direction:write, modality:rgb)") + assert s.scope == Scope.DEVICE_FUNCTION + assert s.matches_device("anything", None) + assert s.matches_function( + "capture_image", {"direction": "write", "modality": "rgb"} + ) + assert s.matches_function( + "capture_image", {"direction": "write", "modality": ["rgb", "4k"]} + ) + assert not s.matches_function( + "get_status", {"direction": "read", "modality": "rgb"} + ) + assert not s.matches_function( + "capture_image", {"direction": "write", "modality": "thermal"} + ) + + def test_object_detection_events_fleet_wide(self): + # The ``test_camera`` driver emits ``object_detected`` events; subscribe + # to it across the fleet via a bare-name event match. + s = parse_selector("device(*).event(object_detected)") + assert s.scope == Scope.DEVICE_EVENT + assert s.matches_event("object_detected", None) + assert not s.matches_event("state_change_detected", None) + + def test_critical_rpcs_fleetwide(self): + s = parse_selector("function(safety:critical)") + assert s.matches_function("estop", {"safety": "critical"}) + assert not s.matches_function("get_reading", {"safety": "informational"}) + + def test_estop_name_match_ignores_labels(self): + # Fleet-wide ESTOP target by reserved name, regardless of labels. + s = parse_selector("function(estop)") + assert s.matches_function("estop", None) + assert s.matches_function("estop", {"safety": "critical"}) + assert not s.matches_function("get_reading", {"safety": "critical"}) + + def test_chained_sensor_writes_with_name_glob(self): + # The ``temperature_sensor`` driver exposes ``set_threshold`` and + # ``set_location`` (writes) plus ``get_reading`` (read). The anchored + # glob ``set_*`` selects only the writers. + s = parse_selector( + "device(temperature_sensor).function(direction:write, set_*)" + ) + assert s.matches_device("temperature_sensor", None) + assert not s.matches_device("test_camera", None) + assert s.matches_function("set_threshold", {"direction": "write"}) + assert s.matches_function("set_location", {"direction": "write"}) + # Anchored glob: a function whose name does NOT start with ``set_`` + # never matches, regardless of direction. + assert not s.matches_function("get_reading", {"direction": "read"}) + # Right name shape, wrong direction -> rejected. + assert not s.matches_function("set_threshold", {"direction": "read"}) + + def test_substring_glob_finds_reading_in_either_direction(self): + # Anchored globs are the default; for substring intent callers wrap + # with ``*...*``. ``*reading*`` finds the sensor's getter and the event. + s = parse_selector("function(*reading*)") + assert s.matches_function("get_reading", {"direction": "read"}) + assert s.matches_function("readings_summary", None) + assert not s.matches_function("set_threshold", {"direction": "write"}) From 48b4386299b82ce324713e6c534c8bee9d60108f Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Sat, 9 May 2026 20:58:46 -0700 Subject: [PATCH 03/21] feat(discovery): selector-driven discover and discover_labels Add two new agent tools that replace the hierarchical trio: - discover(selector, offset, limit) resolves a selector to matched devices, function tuples, or event tuples. Adaptive response shape: small result sets include full schemas inline; large sets paginate with name-and-labels summaries (DC_FUNCTION_THRESHOLD=20). - discover_labels(key, offset, limit) returns the label vocabulary, per axis (no key) or paginated values for one key. Response envelope: {scope, matched, returned, offset, next_offset, results, label_histogram}. The label_histogram describes the matched set (pre-pagination) so callers can choose how to narrow next without a second call. On the device axis, multi-valued keys also expose unique_devices for cardinality. flatten_device now mirrors the legacy DeviceStatus.location into labels["location"] when capabilities.labels does not declare one, so drivers populating only the heartbeat field remain discoverable via selector queries on location. Migrate first-party adapters (Claude Agent SDK, Strands, LangChain, the in-tree StrandsOpenAIDeviceConnectAgent) to discover/discover_labels. The legacy describe_fleet/list_devices/get_device_functions trio remains for one release as advisory-deprecated wrappers; each call emits a DeprecationWarning pointing to the equivalent discover() invocation. Test drivers carry category, direction, modality, and safety labels so integration tests can exercise the full selector grammar end-to-end. --- docs/adr/0001-selector-driven-discovery.md | 236 ++++++++ .../device_connect_agent_tools/__init__.py | 47 +- .../device_connect_agent_tools/_normalize.py | 57 ++ .../adapters/__init__.py | 4 +- .../adapters/claude.py | 99 ++- .../adapters/langchain.py | 27 +- .../adapters/strands.py | 27 +- .../adapters/strands_agent.py | 28 +- .../device_connect_agent_tools/connection.py | 17 +- .../device_connect_agent_tools/tools.py | 453 +++++++++++++- .../tests/test_claude_adapter.py | 5 +- .../tests/test_discover.py | 361 +++++++++++ .../tests/test_langchain_adapter.py | 16 +- .../tests/test_strands_adapter.py | 16 +- .../portal/views/devices.py | 23 +- tests/drivers/camera.py | 7 +- tests/drivers/robot.py | 5 +- tests/drivers/sensor.py | 11 +- tests/tests/test_tools_selector.py | 570 ++++++++++++++++++ 19 files changed, 1832 insertions(+), 177 deletions(-) create mode 100644 docs/adr/0001-selector-driven-discovery.md create mode 100644 packages/device-connect-agent-tools/tests/test_discover.py create mode 100644 tests/tests/test_tools_selector.py diff --git a/docs/adr/0001-selector-driven-discovery.md b/docs/adr/0001-selector-driven-discovery.md new file mode 100644 index 0000000..635ed67 --- /dev/null +++ b/docs/adr/0001-selector-driven-discovery.md @@ -0,0 +1,236 @@ +# ADR 0001: Selector-driven discovery and operations + +- **Status:** Accepted + +## Summary + +Device Connect exposes one selector grammar that addresses devices, +functions, and events. The same selector string drives every discovery and +operation tool: it tells the system **which** entities you mean. Labels +attached to devices, functions, and events provide the dimensions to filter +on. + +Two reasons this matters in practice: + +- **Agent context budgets.** Loading every device's full schema into an LLM + context exhausts the budget on fleets of more than a few dozen devices. + Selectors let an agent narrow first and load schemas only for what it + actually needs. +- **Cross-cutting queries.** Real questions are rarely "list this one + device" - they are "every camera in lab-A", "all critical RPCs", + "any motion event in zone-B". One grammar covers all of them. + +## Labels + +Labels are key/value metadata on devices, functions, and events. Values are +strings or lists of strings. Lists express composite identity (a smart +camera that is both `camera` and `inference`). + +Drivers declare labels in two places: + +```python +class SmartCamera(DeviceDriver): + labels = { + "category": ["camera", "inference"], + "location": "lab-A/optics-bench", + } + + @rpc(labels={"direction": "write", "modality": ["rgb", "4k"]}) + async def capture_image(self, resolution: str = "1080p") -> dict: + ... + + @emit(labels={"modality": "motion"}) + async def state_change_detected(self, zone_id: str, state_class: str): + ... +``` + +### Well-known keys + +These keys carry conventional meaning. Custom keys are always allowed +alongside them. + +| Question the agent asks | Key | Applies to | Example values | +| --- | --- | --- | --- | +| What is it? | `category` | device | `camera`, `robot`, `hub`, `sensor`, `actuator`, `inference` | +| Where is it? | `location` | device | `lab-A`, `zone-A/dock` (`/`-hierarchical, glob-able) | +| Read or write? | `direction` | function (RPC) | `read`, `write` | +| Is it dangerous? | `safety` | function + event | `critical`, `informational` | +| What kind of signal? | `modality` | function + event | `rgb`, `thermal`, `infrared`, `motion`, `4k`, ... | + +The RPC-vs-event distinction is structural (FunctionDef vs EventDef) and is +expressed by the selector scope, not by a label. + +## Selector grammar + +``` +device() device-only +device().function() RPCs on a device subset +device().event() events on a device subset +function() all RPCs across the fleet +event() all events across the fleet +``` + +Inside `(...)`: + +- `key:value` - single-value match +- `key:[v1,v2]` - OR within a key (matches if the label value contains any + of the listed values; multi-valued labels match if any element is in the + list) +- `key:pattern*` - anchored glob (`*`, `?`); `set_*` matches `set_threshold` + but not `unset_threshold`. Use `*set*` for substring. +- `k1:v1,k2:v2` - AND across keys +- bare string (no colon) - id/name match: `device(robot-001)`, + `function(capture_image)`. Globs allowed: `device(cam-*)`. +- `*` or empty - match all + +Keys inside `device(...)` resolve against device labels; keys inside +`function(...)` resolve against function labels; keys inside `event(...)` +resolve against event labels. The `.` chains: "narrow to these devices, +then narrow to these functions or events on them." + +### Examples + +``` +device(category:camera) all cameras +device(category:[camera,robot], location:lab-A/*) cameras or robots in lab-A +device(location:lab-A*) lab-A and any descendant +device(*).function(direction:write, modality:rgb) rgb-producing writes fleet-wide +device(*).event(modality:motion) all motion events +function(safety:critical) critical RPCs fleet-wide +function(estop) fleet emergency-stop targets +``` + +## Tools + +### Discovery + +| Tool | What it returns | +| --- | --- | +| `discover_labels(key=None, offset=0, limit=50)` | Fleet label vocabulary. With no `key`, returns top values per key across each axis (device, function, event). With `key="device.location"` (etc.), paginates one key's values. Use this first when you do not know which dimensions are available. | +| `discover(selector, offset=0, limit=200)` | Resolves a selector to matched entities. Returns devices, function tuples, or event tuples depending on the selector scope. Includes a `label_histogram` so you can see which dimensions to narrow on next without a separate call. | + +`discover()` includes full schemas inline when the matched set is small, +and switches to a name-and-labels summary above +`DEVICE_CONNECT_FUNCTION_THRESHOLD` (default 20). The threshold is +configurable. + +### Operations + +Calling a function on devices is one logical operation; the only choice is +whether you want to wait for replies and how they are surfaced. + +| Tool | Selector resolves to | Reply mode | +| --- | --- | --- | +| `invoke(selector, params)` | exactly one RPC tuple | sync, single result | +| `invoke_many(selector, params, where=, bindings=)` | any number of RPC tuples | sync, aggregated | +| `broadcast(selector, function, params, where=, bindings=, fire_at=, on_late=)` | any number of RPC tuples | async; correlation-tagged replies stream as events | +| `subscribe(selector)` | events, or `correlation:` for a broadcast's replies | subscription handle | +| `await_replies(correlation_id, timeout=, until=)` | replies for one broadcast | sync helper that subscribes, collects, returns | + +`invoke_many` and `broadcast` accept an optional `where` predicate +evaluated at the edge against each candidate's identity, labels, and shared +`bindings`. Use `where` for self-knowable state ("battery > 50%") and +shared `bindings` for dispatcher-computed selection masks (spatial regions, +ML score top-K, random samples). + +`broadcast` accepts `fire_at` (wall-clock epoch seconds) for synchronized +fan-out: each device holds the message and fires from its own clock at the +target time. `on_late` (`"skip"` or `"fire"`) controls behaviour when a +device receives the message after the deadline. + +## Pagination + +`discover` and `discover_labels` accept `offset` and `limit`. Responses +follow a stable envelope: + +```json +{ + "matched": 7421, + "returned": 200, + "offset": 0, + "next_offset": 200, + "results": [...] +} +``` + +`next_offset` is `null` when there are no more pages. The hard ceiling on +`limit` is 1000 to prevent runaway responses; ask for more pages instead. + +Operation tools (`invoke_many`, `broadcast`) do not paginate - that is a +streaming-dispatch concern. Subscribe to the result channel for per-target +detail at large fan-out. + +## Worked examples + +### Find every camera in lab-A and capture an image from each + +```python +result = invoke_many( + selector="device(category:camera, location:lab-A).function(capture_image)", + params={"resolution": "1080p"}, +) +# {"candidates": 12, "matched": 12, "succeeded": 12, "results": [...], "errors": []} +``` + +### Async fleet emergency-stop + +```python +broadcast("function(estop)") +# {"correlation_id": "br-7f3a91", "candidates": 240} + +# Optionally wait for replies: +replies = await_replies("br-7f3a91", timeout=5.0) +``` + +### Synchronized actuation across a phone fleet + +```python +broadcast( + selector="device(category:phone, location:auditorium-A)", + function="set_flashlight", + params={"on": True, "color": "white"}, + where="mask[seat_row][seat_col] == 1", + bindings={"mask": }, + fire_at=time.time() + 0.500, + on_late="skip", +) +``` + +### Browse the fleet vocabulary first + +```python +vocab = discover_labels() +# {"total_devices": 1247, "total_functions": 7100, +# "device_keys": {"category": {...}, "location": {...}}, +# "function_keys": {"direction": {...}, "modality": {...}, "safety": {...}}, +# "event_keys": {"modality": {...}}} + +# Then narrow to one dimension: +locations = discover_labels(key="device.location", limit=50) +``` + +### Subscribe to motion events in lab-A + +```python +sub = subscribe("device(location:lab-A/*).event(modality:motion)") +# {"subscription_id": "sub-abc123", "matched": 8} +``` + +## CLI + +The same selector syntax drives the operator CLIs. Every CLI command maps +to the matching tool call. + +``` +devctl discover-labels [--key K] [--offset N] [--limit M] +devctl discover "" [--offset N] [--limit M] + +statectl invoke "" [--param k=v] +statectl invoke-many "" [--param k=v] [--where E] +statectl broadcast "" [--param k=v] [--where E] [--fire-at T] +statectl subscribe "" +statectl await "" [--timeout T] +``` + +CLI flags `--param k=v` and `--where E` pack into the tool arguments; the +CLIs are thin shell wrappers over the Python tools. diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py b/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py index fb5b198..de79913 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py @@ -4,37 +4,36 @@ """Device Connect Tools — framework-agnostic SDK for Device Connect IoT. -Hierarchical discovery keeps LLM context small: +Selector-driven discovery keeps LLM context small: - from device_connect_agent_tools import connect, describe_fleet, list_devices + from device_connect_agent_tools import connect, discover, discover_labels connect() - fleet = describe_fleet() # bird's-eye summary (~200 tokens) - cameras = list_devices(device_type="camera") # compact roster - info = get_device_functions("camera-001") # full schemas for one device + vocab = discover_labels() # fleet vocabulary + cams = discover("device(category:camera, location:zone-A/*)") # device roster + writes = discover("device(*).function(direction:write)") # function tuples result = invoke_device("camera-001", "capture_image", {"resolution": "1080p"}) -Strands: - from device_connect_agent_tools import connect - from device_connect_agent_tools.adapters.strands import ( - describe_fleet, list_devices, get_device_functions, invoke_device, - ) - from strands import Agent - - connect() - agent = Agent(tools=[describe_fleet, list_devices, get_device_functions, invoke_device]) +The older ``describe_fleet`` / ``list_devices`` / ``get_device_functions`` +trio remains available for one release as advisory-deprecated wrappers -- +prefer ``discover`` / ``discover_labels`` for new code. """ from device_connect_agent_tools.agent import DeviceConnectAgent from device_connect_agent_tools.connection import connect, disconnect, get_connection from device_connect_agent_tools.tools import ( + # Selector-driven discovery (preferred) + discover, + discover_labels, + # Invocation + invoke_device, + invoke_device_with_fallback, + get_device_status, + # Advisory-deprecated discovery wrappers (one-release transition) describe_fleet, list_devices, get_device_functions, discover_devices, - invoke_device, - invoke_device_with_fallback, - get_device_status, ) __all__ = [ @@ -44,14 +43,16 @@ "get_connection", # High-level agent "DeviceConnectAgent", - # Hierarchical discovery tools (recommended) - "describe_fleet", - "list_devices", - "get_device_functions", - # Invocation tools + # Selector-driven discovery (preferred) + "discover", + "discover_labels", + # Invocation "invoke_device", "invoke_device_with_fallback", "get_device_status", - # Backward-compatible (deprecated — use hierarchical tools instead) + # Advisory-deprecated -- use discover() / discover_labels() instead + "describe_fleet", + "list_devices", + "get_device_functions", "discover_devices", ] diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/_normalize.py b/packages/device-connect-agent-tools/device_connect_agent_tools/_normalize.py index 335a5c4..1dd38d1 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/_normalize.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/_normalize.py @@ -143,3 +143,60 @@ def group_devices( key = d.get(group_by) or "unknown" groups[key].append(summary) return {"groups": dict(sorted(groups.items())), "total": len(devices)} + + +# -- Label histograms --------------------------------------------------- + + +def _accumulate_label( + histogram: dict, multivalued_keys: set, label_key: str, label_value: Any +) -> None: + """Record one ``label_key -> label_value`` observation in ``histogram``. + + ``label_value`` may be a string or a list of strings. Lists are flagged + in ``multivalued_keys`` so the caller can annotate them in the response. + """ + if isinstance(label_value, list): + multivalued_keys.add(label_key) + for v in label_value: + histogram[label_key][str(v)] = histogram[label_key].get(str(v), 0) + 1 + else: + histogram[label_key][str(label_value)] = histogram[label_key].get(str(label_value), 0) + 1 + + +def label_histogram( + items: list[dict], *, count_unique: bool = False +) -> tuple: + """Build ``{key: {value: count}}`` histograms across item labels. + + Multi-valued labels (list values) increment the histogram for each + member -- a device with ``category: [camera, inference]`` adds 1 to + both ``camera`` and ``inference``. Keys observed with any list value + are surfaced via ``multivalued_keys`` so callers can annotate the + response. + + Args: + items: Records with optional ``labels`` field (devices, functions, + or events). + count_unique: When True, also tracks how many distinct items + declared each key. Useful only for the device axis, where a + multi-valued label can otherwise mask the unique-device count. + + Returns: + ``(histogram, multivalued_keys)`` when ``count_unique=False``; + ``(histogram, multivalued_keys, unique_per_key)`` when + ``count_unique=True``. + """ + histogram: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) + multivalued: set[str] = set() + unique: dict[str, int] | None = defaultdict(int) if count_unique else None + for item in items: + labels = item.get("labels") or {} + for k, v in labels.items(): + if unique is not None: + unique[k] += 1 + _accumulate_label(histogram, multivalued, k, v) + flat = {k: dict(vals) for k, vals in histogram.items()} + if unique is not None: + return flat, multivalued, dict(unique) + return flat, multivalued diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/__init__.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/__init__.py index a9d65f6..d54403e 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/__init__.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/__init__.py @@ -8,12 +8,12 @@ # Strands from device_connect_agent_tools.adapters.strands import ( - describe_fleet, list_devices, get_device_functions, invoke_device, + discover_labels, discover, invoke_device, ) # LangChain from device_connect_agent_tools.adapters.langchain import ( - describe_fleet, list_devices, get_device_functions, invoke_device, + discover_labels, discover, invoke_device, ) # Claude Agent SDK diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py index 70ad42f..807abcb 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py @@ -4,7 +4,7 @@ """Claude Agent SDK adapter — exposes Device Connect tools to claude-agent-sdk. -Hierarchical discovery keeps LLM context small:: +Selector-driven discovery keeps LLM context small:: import anyio from claude_agent_sdk import ClaudeSDKClient, ClaudeAgentOptions, AssistantMessage, TextBlock @@ -42,9 +42,8 @@ async def main(): from claude_agent_sdk import tool, create_sdk_mcp_server from device_connect_agent_tools.tools import ( - describe_fleet as _describe_fleet, - list_devices as _list_devices, - get_device_functions as _get_device_functions, + discover as _discover, + discover_labels as _discover_labels, discover_devices as _discover_devices, invoke_device as _invoke_device, invoke_device_with_fallback as _invoke_device_with_fallback, @@ -56,55 +55,50 @@ def _text(result: Any) -> dict[str, Any]: return {"content": [{"type": "text", "text": json.dumps(result, default=str)}]} -# Hierarchical discovery tools (recommended) +# Selector-driven discovery tools (recommended) @tool( - "describe_fleet", - "Get a high-level summary of all available devices, grouped by type and " - "location. Use this first to understand what is available, then call " - "list_devices to browse specific types or locations.", - {}, + "discover_labels", + "Browse the label vocabulary across the fleet. Returns label keys " + "(category, location, direction, modality, ...) with their values and " + "counts. Call with no arguments to see all keys, or with key=" + "'device.location' / 'function.direction' / etc. to paginate one key. " + "Use this first to learn what dimensions are available before calling " + "discover().", + {"key": str, "offset": int, "limit": int}, ) -async def describe_fleet(args: dict[str, Any]) -> dict[str, Any]: - return _text(_describe_fleet()) - - -@tool( - "list_devices", - "Browse available devices with filtering and pagination. Returns compact " - "device summaries (no full schemas). Use get_device_functions for details.", - { - "device_type": str, - "location": str, - "status": str, - "group_by": str, - "offset": int, - "limit": int, - }, -) -async def list_devices(args: dict[str, Any]) -> dict[str, Any]: +async def discover_labels(args: dict[str, Any]) -> dict[str, Any]: return _text( - _list_devices( - device_type=args.get("device_type"), - location=args.get("location"), - status=args.get("status"), - group_by=args.get("group_by"), + _discover_labels( + key=args.get("key"), offset=int(args.get("offset", 0)), - limit=int(args.get("limit", 20)), + limit=int(args.get("limit", 50)), ) ) @tool( - "get_device_functions", - "Get full function schemas for a specific device. Call this after " - "list_devices to see what a device can do and what parameters each " - "function accepts.", - {"device_id": str}, + "discover", + "Resolve a selector to matched devices, functions, or events. Selector " + "grammar: device(), device().function(), " + "device().event(), function(), or " + "event(). Filters are key:value pairs (AND across keys with " + "commas, OR within a key with bracket lists, glob with *). Examples: " + "'device(category:camera, location:zone-A/*)', " + "'device(*).function(direction:write)', 'event(modality:motion)'. " + "Response includes a label_histogram (per-key vocabulary across the " + "matched set) so the agent can narrow next.", + {"selector": str, "offset": int, "limit": int}, ) -async def get_device_functions(args: dict[str, Any]) -> dict[str, Any]: - return _text(_get_device_functions(device_id=args["device_id"])) +async def discover(args: dict[str, Any]) -> dict[str, Any]: + return _text( + _discover( + selector=args["selector"], + offset=int(args.get("offset", 0)), + limit=int(args.get("limit", 200)), + ) + ) # Invocation tools @@ -112,8 +106,9 @@ async def get_device_functions(args: dict[str, Any]) -> dict[str, Any]: @tool( "invoke_device", - "Call a function on a Device Connect device. Use get_device_functions " - "first to learn available functions and parameters.", + "Call a function on a Device Connect device. Use discover() with a " + "function-scoped selector first to learn available functions and " + "parameters.", {"device_id": str, "function": str, "params": dict, "llm_reasoning": str}, ) async def invoke_device(args: dict[str, Any]) -> dict[str, Any]: @@ -153,13 +148,13 @@ async def get_device_status(args: dict[str, Any]) -> dict[str, Any]: return _text(_get_device_status(device_id=args["device_id"])) -# Backward-compatible (deprecated — use hierarchical tools instead) +# Backward-compatible (long-deprecated — prefer discover() / discover_labels()) @tool( "discover_devices", - "Deprecated — use describe_fleet, list_devices, and get_device_functions " - "instead. Discover all devices with full function schemas.", + "Deprecated — use discover() and discover_labels() instead. Discovers " + "all devices with full function schemas.", {"device_type": str, "refresh": bool}, ) async def discover_devices(args: dict[str, Any]) -> dict[str, Any]: @@ -179,9 +174,8 @@ def create_device_connect_server(name: str = "device-connect"): return create_sdk_mcp_server( name, tools=[ - describe_fleet, - list_devices, - get_device_functions, + discover_labels, + discover, invoke_device, invoke_device_with_fallback, get_device_status, @@ -191,12 +185,11 @@ def create_device_connect_server(name: str = "device-connect"): __all__ = [ - "describe_fleet", - "list_devices", - "get_device_functions", - "discover_devices", + "discover_labels", + "discover", "invoke_device", "invoke_device_with_fallback", "get_device_status", + "discover_devices", "create_device_connect_server", ] diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py index 6e0b8a3..f934024 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py @@ -4,16 +4,16 @@ """LangChain adapter — wraps Device Connect tools as LangChain StructuredTools. -Hierarchical discovery keeps LLM context small: +Selector-driven discovery keeps LLM context small: from device_connect_agent_tools import connect from device_connect_agent_tools.adapters.langchain import ( - describe_fleet, list_devices, get_device_functions, invoke_device, + discover_labels, discover, invoke_device, ) from langgraph.prebuilt import create_react_agent connect() - agent = create_react_agent(model, [describe_fleet, list_devices, get_device_functions, invoke_device]) + agent = create_react_agent(model, [discover_labels, discover, invoke_device]) Requires: pip install device-connect-agent-tools[langchain] """ @@ -21,34 +21,31 @@ from langchain_core.tools import StructuredTool from device_connect_agent_tools.tools import ( - describe_fleet as _describe_fleet, - list_devices as _list_devices, - get_device_functions as _get_device_functions, + discover as _discover, + discover_labels as _discover_labels, discover_devices as _discover_devices, invoke_device as _invoke_device, invoke_device_with_fallback as _invoke_device_with_fallback, get_device_status as _get_device_status, ) -# Hierarchical discovery tools (recommended) -describe_fleet = StructuredTool.from_function(_describe_fleet) -list_devices = StructuredTool.from_function(_list_devices) -get_device_functions = StructuredTool.from_function(_get_device_functions) +# Selector-driven discovery tools (recommended) +discover_labels = StructuredTool.from_function(_discover_labels) +discover = StructuredTool.from_function(_discover) # Invocation tools invoke_device = StructuredTool.from_function(_invoke_device) invoke_device_with_fallback = StructuredTool.from_function(_invoke_device_with_fallback) get_device_status = StructuredTool.from_function(_get_device_status) -# Backward-compatible (deprecated — use hierarchical tools instead) +# Backward-compatible (long-deprecated — prefer discover() / discover_labels()) discover_devices = StructuredTool.from_function(_discover_devices) __all__ = [ - "describe_fleet", - "list_devices", - "get_device_functions", - "discover_devices", + "discover_labels", + "discover", "invoke_device", "invoke_device_with_fallback", "get_device_status", + "discover_devices", ] diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py index 308c2a7..848f362 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py @@ -4,16 +4,16 @@ """Strands adapter — wraps Device Connect tools with @strands.tool. -Hierarchical discovery keeps LLM context small: +Selector-driven discovery keeps LLM context small: from device_connect_agent_tools import connect from device_connect_agent_tools.adapters.strands import ( - describe_fleet, list_devices, get_device_functions, invoke_device, + discover_labels, discover, invoke_device, ) from strands import Agent connect() - agent = Agent(tools=[describe_fleet, list_devices, get_device_functions, invoke_device]) + agent = Agent(tools=[discover_labels, discover, invoke_device]) agent("What devices are online?") Requires: pip install device-connect-agent-tools[strands] @@ -22,34 +22,31 @@ from strands import tool as strands_tool from device_connect_agent_tools.tools import ( - describe_fleet as _describe_fleet, - list_devices as _list_devices, - get_device_functions as _get_device_functions, + discover as _discover, + discover_labels as _discover_labels, discover_devices as _discover_devices, invoke_device as _invoke_device, invoke_device_with_fallback as _invoke_device_with_fallback, get_device_status as _get_device_status, ) -# Hierarchical discovery tools (recommended) -describe_fleet = strands_tool(_describe_fleet) -list_devices = strands_tool(_list_devices) -get_device_functions = strands_tool(_get_device_functions) +# Selector-driven discovery tools (recommended) +discover_labels = strands_tool(_discover_labels) +discover = strands_tool(_discover) # Invocation tools invoke_device = strands_tool(_invoke_device) invoke_device_with_fallback = strands_tool(_invoke_device_with_fallback) get_device_status = strands_tool(_get_device_status) -# Backward-compatible (deprecated — use hierarchical tools instead) +# Backward-compatible (long-deprecated — prefer discover() / discover_labels()) discover_devices = strands_tool(_discover_devices) __all__ = [ - "describe_fleet", - "list_devices", - "get_device_functions", - "discover_devices", + "discover_labels", + "discover", "invoke_device", "invoke_device_with_fallback", "get_device_status", + "discover_devices", ] diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands_agent.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands_agent.py index 7b6c532..a3f0cf5 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands_agent.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands_agent.py @@ -60,9 +60,8 @@ async def prepare(self) -> Dict[str, Any]: from strands import Agent from strands.models import AnthropicModel from device_connect_agent_tools.adapters.strands import ( - describe_fleet, - list_devices, - get_device_functions, + discover_labels, + discover, invoke_device, invoke_device_with_fallback, get_device_status, @@ -74,7 +73,7 @@ async def prepare(self) -> Dict[str, Any]: self._agent = Agent( model=AnthropicModel(model_id=self._model_id, max_tokens=self._max_tokens), tools=[ - describe_fleet, list_devices, get_device_functions, + discover_labels, discover, invoke_device, invoke_device_with_fallback, get_device_status, ], system_prompt=system_prompt, @@ -92,8 +91,8 @@ def _build_system_prompt(self) -> str: """Build a system prompt from discovered devices. Uses a compact fleet summary instead of dumping all device schemas. - The agent can use describe_fleet(), list_devices(), and - get_device_functions() to drill into details as needed. + The agent can use discover_labels() and discover() to drill into + details as needed. """ # Build compact fleet summary (type counts + locations) from collections import defaultdict @@ -108,21 +107,26 @@ def _build_system_prompt(self) -> str: for dt, info in sorted(by_type.items()): locs = ", ".join(sorted(info["locations"])) type_lines.append(f" - {info['count']}x {dt} (at: {locs})") - fleet_summary = "\n".join(type_lines) or " (none yet — call describe_fleet() to refresh)" + fleet_summary = "\n".join(type_lines) or " (none yet -- call discover() to refresh)" return ( f"You are an AI agent connected to the Device Connect IoT network.\n\n" f"YOUR GOAL: {self.goal}\n\n" f"FLEET OVERVIEW ({len(self.devices)} devices):\n{fleet_summary}\n\n" f"DISCOVERY TOOLS:\n" - f" - describe_fleet() — fleet summary (what you see above)\n" - f" - list_devices(device_type=..., location=...) — browse devices\n" - f" - get_device_functions(device_id) — see what a device can do\n" - f" - invoke_device(device_id, function, params) — call a device function\n\n" + f" - discover_labels(key=None) -- fleet label vocabulary " + f"(category, location, direction, modality, ...)\n" + f" - discover(selector) -- resolve a selector to devices, " + f"functions, or events. Examples:\n" + f" device(category:camera, location:zone-A/*)\n" + f" device(robot-001).function(direction:write)\n" + f" function(safety:critical)\n" + f" - invoke_device(device_id, function, params) -- call a device function\n\n" f"INSTRUCTIONS:\n" f"When you receive device events, you MUST:\n" f"1. Analyze the events\n" - f"2. Use get_device_functions() to check available functions if needed\n" + f"2. Use discover() with a function-scoped selector to check " + f"available functions if needed\n" f"3. Use invoke_device() to interact with devices\n" f"4. Report what you found and what actions you took\n\n" f"Always provide llm_reasoning when invoking devices to explain your decision.\n" diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py b/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py index 2dba8fc..dae997c 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py @@ -125,6 +125,16 @@ def flatten_device(raw: Dict[str, Any]) -> Dict[str, Any]: status = raw.get("status") or {} caps = raw.get("capabilities") or {} + # Mirror the legacy DeviceStatus.location field into labels["location"] + # when the driver did not declare it via DeviceCapabilities.labels. Drivers + # using only the legacy field would otherwise be invisible to selector + # queries on location. + legacy_location = raw.get("location") or status.get("location") + caps_labels = caps.get("labels") + merged_labels = caps_labels + if legacy_location and (not caps_labels or "location" not in caps_labels): + merged_labels = {**(caps_labels or {}), "location": legacy_location} + # NOTE: The raw ``capabilities`` dict is intentionally NOT included in # the flattened output. ``functions`` and ``events`` are extracted to # the top level for direct access. Including both would duplicate data @@ -132,11 +142,16 @@ def flatten_device(raw: Dict[str, Any]) -> Dict[str, Any]: return { "device_id": raw.get("device_id"), "device_type": raw.get("device_type") or identity.get("device_type"), - "location": raw.get("location") or status.get("location"), + "location": legacy_location, "status": status, "identity": identity, "functions": caps.get("functions", []), "events": caps.get("events", []), + # Discovery labels declared by the driver (DeviceCapabilities.labels), + # with status.location mirrored in when caps did not carry it. None + # when neither source provided any label -- discover() treats that + # as "no label-based match," not "matches everything." + "labels": merged_labels, } diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py b/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py index 803160e..0d3aa73 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py @@ -4,27 +4,16 @@ """Device Connect device operations — framework-agnostic tool functions. -Hierarchical discovery tools that keep LLM context small: +Discovery is selector-driven. ``discover()`` and ``discover_labels()`` cover +both fleet-wide and entity-scoped queries; the older ``describe_fleet`` / +``list_devices`` / ``get_device_functions`` trio remains as advisory-deprecated +wrappers for one release while callers migrate. -1. ``describe_fleet()`` — bird's-eye summary (types, locations, counts) -2. ``list_devices(...)`` — paginated compact roster (no schemas) -3. ``get_device_functions(id)`` — full schemas for ONE device -4. ``invoke_device(...)`` — call a function on a device - -Plain Python functions with type hints and docstrings. Use them directly -or wrap with a framework adapter: - - # Plain Python - from device_connect_agent_tools import connect, describe_fleet, list_devices + from device_connect_agent_tools import connect, discover, discover_labels connect() - fleet = describe_fleet() - devices = list_devices(device_type="camera") - - # Strands - from device_connect_agent_tools.adapters.strands import ( - describe_fleet, list_devices, get_device_functions, invoke_device, - ) - agent = Agent(tools=[describe_fleet, list_devices, get_device_functions, invoke_device]) + cams = discover("device(category:camera)") + rgb_writes = discover("device(*).function(direction:write, modality:rgb)") + vocab = discover_labels() """ from __future__ import annotations @@ -32,12 +21,20 @@ import logging import os import uuid +import warnings from typing import Any +from device_connect_edge.selector import ( + Scope, + Selector, + SelectorParseError, + parse_selector, +) from device_connect_agent_tools.connection import get_connection from device_connect_agent_tools._normalize import ( full_device, compact_device, fuzzy_filter_by_type, extract_status, aggregate_fleet, group_devices, + label_histogram, ) logger = logging.getLogger(__name__) @@ -55,10 +52,382 @@ ) SMALL_FLEET_THRESHOLD = 5 +# When ``discover()`` resolves a selector to this many functions or events +# or fewer, the response includes full schemas inline. Above the threshold +# it returns a compact ``(device_id, name, labels)`` summary so the agent +# can narrow further via ``discover_labels()`` or a tighter selector. +try: + DC_FUNCTION_THRESHOLD = min(max(int(os.getenv("DEVICE_CONNECT_FUNCTION_THRESHOLD", "20")), 0), 200) +except (ValueError, TypeError): + logger.warning( + "Invalid DEVICE_CONNECT_FUNCTION_THRESHOLD value %r, defaulting to 20", + os.getenv("DEVICE_CONNECT_FUNCTION_THRESHOLD"), + ) + DC_FUNCTION_THRESHOLD = 20 + +# Hard ceiling on per-call ``limit`` to prevent runaway responses in large +# fleets. A caller asking for limit=100000 still gets at most this many +# rows per page (with ``next_offset`` to continue). +DISCOVER_HARD_LIMIT = 1000 + +# Default limits per the discovery design (different defaults for the two +# tools because they answer different questions: ``discover`` returns rows, +# ``discover_labels`` returns vocabulary). +DEFAULT_DISCOVER_LIMIT = 200 +DEFAULT_DISCOVER_LABELS_LIMIT = 50 + # ── Shared helpers ────────────────────────────────────────────── +def _normalize_pagination(offset: int, limit: int, default_limit: int) -> tuple[int, int]: + """Clamp offset and limit to safe ranges. + + Negative offset rounds to 0, non-positive limit falls back to the default, + and limit is capped at ``DISCOVER_HARD_LIMIT``. + """ + safe_offset = max(0, int(offset or 0)) + if not limit or limit <= 0: + safe_limit = default_limit + else: + safe_limit = min(int(limit), DISCOVER_HARD_LIMIT) + return safe_offset, safe_limit + + +def _empty_envelope(scope: str | None = None, error: str | None = None) -> dict[str, Any]: + """Build the canonical zero-result response envelope.""" + out: dict[str, Any] = { + "matched": 0, + "returned": 0, + "offset": 0, + "next_offset": None, + "results": [], + } + if scope is not None: + out["scope"] = scope + if error is not None: + out["error"] = error + return out + + +def _paginate(items: list, offset: int, limit: int) -> tuple[list, int | None]: + """Slice ``items`` to one page; return ``(page, next_offset)``.""" + end = offset + limit + page = items[offset:end] + next_offset = end if end < len(items) else None + return page, next_offset + + +def _device_summary_for_discover(d: dict, expand: bool) -> dict[str, Any]: + """Compact device row for ``discover()``, with labels surfaced.""" + summary = compact_device(d, expand) + summary["status"] = extract_status(d) + summary["labels"] = d.get("labels") + return summary + + +def _function_row(d: dict, fn: dict, expand: bool) -> dict[str, Any]: + """Build one row for a function-scoped discover result. + + Below the threshold, ``expand`` is True and the row includes the full + JSON Schema. Above threshold, only name + labels travel back so the + agent can narrow without paying for parameter schemas. + """ + name = fn.get("name") if isinstance(fn, dict) else fn + labels = fn.get("labels") if isinstance(fn, dict) else None + if expand and isinstance(fn, dict): + return { + "device_id": d.get("device_id"), + "name": name, + "description": fn.get("description", ""), + "parameters": fn.get("parameters", {}), + "labels": labels, + } + return { + "device_id": d.get("device_id"), + "name": name, + "labels": labels, + } + + +def _event_row(d: dict, ev: dict, expand: bool) -> dict[str, Any]: + """Build one row for an event-scoped discover result.""" + name = ev.get("name") if isinstance(ev, dict) else ev + labels = ev.get("labels") if isinstance(ev, dict) else None + if expand and isinstance(ev, dict): + return { + "device_id": d.get("device_id"), + "name": name, + "description": ev.get("description", ""), + "payload_schema": ev.get("payload_schema"), + "labels": labels, + } + return { + "device_id": d.get("device_id"), + "name": name, + "labels": labels, + } + + +# ── Selector-driven discovery (preferred) ──────────────────────── + + +def discover( + selector: str, + offset: int = 0, + limit: int = DEFAULT_DISCOVER_LIMIT, +) -> dict[str, Any]: + """Resolve a selector to matched devices, functions, or events. + + The selector DSL supports five scope shapes: + + device() all matching devices + device().function() RPCs on a device subset + device().event() events on a device subset + function() all RPCs across the fleet + event() all events across the fleet + + Inside ``(...)``: ``key:value``, ``key:[v1,v2]`` (OR within a key), + ``key:pattern*`` (glob), ``k1:v1,k2:v2`` (AND across keys), bare-string + id/name match, or ``*`` to match all. + + Args: + selector: A selector expression string. + offset: Pagination offset (rows skipped). + limit: Max rows per page (capped at DISCOVER_HARD_LIMIT). + + Returns: + A response envelope: + ``{"scope", "matched", "returned", "offset", "next_offset", "results", + "label_histogram"}``. + ``label_histogram`` is the per-key vocabulary across the **matched** + set (pre-pagination), not the returned page; on the device axis it + tracks unique device counts per key (``unique_devices``), on + function/event axes it counts occurrences (a function appearing on N + devices contributes N entries). + For function- and event-scoped selectors, ``results`` rows include + full schemas when the matched count is at or below + ``DC_FUNCTION_THRESHOLD``; otherwise rows are name-and-labels summaries. + + Example: + >>> discover("device(category:camera, location:zone-A/*)") + {"scope": "device_only", "matched": 4, ...} + >>> discover("device(*).function(direction:write, modality:rgb)") + {"scope": "device_function", "matched": 8, ...} + """ + safe_offset, safe_limit = _normalize_pagination(offset, limit, DEFAULT_DISCOVER_LIMIT) + + # Parse the selector at the system boundary; surface a clean error to + # the caller rather than raising into agent code. + try: + sel: Selector = parse_selector(selector) + except SelectorParseError as e: + return _empty_envelope(error=str(e)) + except (TypeError, ValueError) as e: + return _empty_envelope(error=f"Invalid selector: {e}") + + try: + conn = get_connection() + devices = conn.list_devices() + except Exception as e: + logger.error("discover(%r) failed loading fleet: %s", selector, e) + return _empty_envelope(scope=sel.scope.value, error=str(e)) + + # Apply the device-axis filter (vacuously True when sel.device is None). + matched_devices = [ + d for d in devices + if sel.device is None + or sel.device.matches(d.get("device_id") or "", d.get("labels")) + ] + + # Branch on scope. Each branch produces (results_full, page, histogram, total). + if sel.scope == Scope.DEVICE_ONLY: + total = len(matched_devices) + page_devices, next_offset = _paginate(matched_devices, safe_offset, safe_limit) + expand = SMALL_FLEET_THRESHOLD > 0 and total <= SMALL_FLEET_THRESHOLD + results = [_device_summary_for_discover(d, expand) for d in page_devices] + histogram, multivalued, unique = label_histogram(matched_devices, count_unique=True) + formatted_histogram = _format_label_histogram(histogram, multivalued, unique) + return { + "scope": sel.scope.value, + "matched": total, + "returned": len(results), + "offset": safe_offset, + "next_offset": next_offset, + "results": results, + "label_histogram": formatted_histogram, + } + + # Function- or event-scoped selectors enumerate (device, entity) tuples. + is_function_scope = sel.scope in (Scope.DEVICE_FUNCTION, Scope.FUNCTION_ONLY) + entity_filter = sel.function if is_function_scope else sel.event + + matched_rows: list[tuple[dict, dict]] = [] + for d in matched_devices: + entities = d.get("functions" if is_function_scope else "events", []) + for entity in entities: + if not isinstance(entity, dict): + # Best-effort: lift bare-name list items into a stub dict so the + # filter can still match by name. + entity = {"name": str(entity), "labels": None} + if entity_filter is None or entity_filter.matches( + entity.get("name") or "", entity.get("labels") + ): + matched_rows.append((d, entity)) + + total = len(matched_rows) + page_rows, next_offset = _paginate(matched_rows, safe_offset, safe_limit) + expand = DC_FUNCTION_THRESHOLD > 0 and total <= DC_FUNCTION_THRESHOLD + if is_function_scope: + results = [_function_row(d, fn, expand) for d, fn in page_rows] + else: + results = [_event_row(d, ev, expand) for d, ev in page_rows] + + matched_entities = [entity for _, entity in matched_rows] + histogram, multivalued = label_histogram(matched_entities) + formatted_histogram = _format_label_histogram(histogram, multivalued) + + return { + "scope": sel.scope.value, + "matched": total, + "returned": len(results), + "offset": safe_offset, + "next_offset": next_offset, + "results": results, + "label_histogram": formatted_histogram, + } + + +def _format_label_histogram( + histogram: dict, + multivalued: set, + unique: dict | None = None, +) -> dict[str, Any]: + """Format a histogram for response, annotating multi-valued keys. + + Multi-valued keys are flagged so an agent reading + ``{camera: 312, inference: 200}`` knows the counts overlap. When + ``unique`` is supplied (device axis only), the per-key unique device + count is exposed as ``unique_devices`` so the agent can reconcile + histogram totals with the underlying device cardinality. + """ + out: dict[str, Any] = {} + for key, counts in histogram.items(): + entry: dict[str, Any] = { + # Sort values most-frequent first; alphabetical tie-break for stability. + "values": dict(sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))), + } + if key in multivalued: + entry["multivalued"] = True + if unique is not None and key in unique: + entry["unique_devices"] = unique[key] + out[key] = entry + return out + + +def discover_labels( + key: str | None = None, + offset: int = 0, + limit: int = DEFAULT_DISCOVER_LABELS_LIMIT, +) -> dict[str, Any]: + """Return the fleet's label vocabulary. + + Without ``key``: returns one entry per axis (``device_keys``, + ``function_keys``, ``event_keys``) with all keys and their top values. + With ``key`` (e.g. ``"device.location"``, ``"function.direction"``): + paginates the full value list for that one key. + + Args: + key: Optional dotted axis.key (``device.``, ``function.``, + ``event.``). When given, the response paginates that one key's + values rather than returning a multi-axis vocabulary. + offset: Pagination offset for the per-key value list. + limit: Max values per page when ``key`` is given (capped at + ``DISCOVER_HARD_LIMIT``). + + Returns: + Multi-axis form (no ``key``): + ``{"total_devices", "total_functions", "total_events", + "device_keys": {key: {"values": {...}, "multivalued"?: True, + "unique_devices"?: N}}, + "function_keys": {...}, "event_keys": {...}}`` + Per-key form (``key`` provided): + ``{"axis", "key", "matched", "returned", "offset", "next_offset", + "values", "multivalued"?: True}`` + """ + safe_offset, safe_limit = _normalize_pagination(offset, limit, DEFAULT_DISCOVER_LABELS_LIMIT) + + try: + conn = get_connection() + devices = conn.list_devices() + except Exception as e: + logger.error("discover_labels failed loading fleet: %s", e) + return _empty_envelope(error=str(e)) + + # Aggregate function and event entities once. + functions: list[dict] = [] + events: list[dict] = [] + for d in devices: + for fn in d.get("functions", []) or []: + if isinstance(fn, dict): + functions.append(fn) + for ev in d.get("events", []) or []: + if isinstance(ev, dict): + events.append(ev) + + dev_hist, dev_mv, dev_unique = label_histogram(devices, count_unique=True) + fn_hist, fn_mv = label_histogram(functions) + ev_hist, ev_mv = label_histogram(events) + + if key is None: + return { + "total_devices": len(devices), + "total_functions": len(functions), + "total_events": len(events), + "device_keys": _format_label_histogram(dev_hist, dev_mv, dev_unique), + "function_keys": _format_label_histogram(fn_hist, fn_mv), + "event_keys": _format_label_histogram(ev_hist, ev_mv), + } + + # Per-key form: split on the first dot to pick an axis. + if "." not in key: + return _empty_envelope( + error=f"Key must be axis-qualified (device., function., event.): {key!r}" + ) + axis, label_key = key.split(".", 1) + if axis == "device": + source, multivalued = dev_hist, dev_mv + total = len(devices) + elif axis == "function": + source, multivalued = fn_hist, fn_mv + total = len(functions) + elif axis == "event": + source, multivalued = ev_hist, ev_mv + total = len(events) + else: + return _empty_envelope( + error=f"Unknown axis {axis!r} (expected device|function|event)" + ) + + counts = source.get(label_key, {}) + sorted_values = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])) + page = sorted_values[safe_offset:safe_offset + safe_limit] + next_offset = safe_offset + safe_limit if safe_offset + safe_limit < len(sorted_values) else None + out: dict[str, Any] = { + "axis": axis, + "key": label_key, + "matched": len(sorted_values), + "returned": len(page), + "offset": safe_offset, + "next_offset": next_offset, + "values": dict(page), + "axis_total": total, + } + if label_key in multivalued: + out["multivalued"] = True + return out + + # ── Hierarchical discovery tools ───────────────────────────────── @@ -81,7 +450,18 @@ def describe_fleet() -> dict[str, Any]: Example: fleet = describe_fleet() # {"total_devices": 47, "by_type": {"camera": {"count": 12, ...}}, ...} + + .. deprecated:: + Prefer ``discover_labels()`` (vocabulary) and + ``discover("device(*)")`` (roster). This wrapper will be removed in + a future release. """ + warnings.warn( + "describe_fleet() is deprecated; use discover_labels() for vocabulary " + "or discover('device(*)') for the roster.", + DeprecationWarning, + stacklevel=2, + ) try: conn = get_connection() devices = conn.list_devices() @@ -134,7 +514,18 @@ def list_devices( # Group by location result = list_devices(group_by="location") + + .. deprecated:: + Prefer ``discover("device(category:camera, location:zone-A/*)")`` -- + the selector DSL covers type/location/group-by uniformly. This + wrapper will be removed in a future release. """ + warnings.warn( + "list_devices() is deprecated; use discover() with a selector " + "(e.g. discover('device(category:camera)')).", + DeprecationWarning, + stacklevel=2, + ) try: conn = get_connection() devices = conn.list_devices(location=location) @@ -194,7 +585,17 @@ def get_device_functions(device_id: str) -> dict[str, Any]: Example: info = get_device_functions("camera-001") # {"device_id": "camera-001", "functions": [{"name": "capture_image", ...}]} + + .. deprecated:: + Prefer ``discover("device().function(*)")``. This wrapper + will be removed in a future release. """ + warnings.warn( + "get_device_functions() is deprecated; use " + "discover('device().function(*)').", + DeprecationWarning, + stacklevel=2, + ) try: conn = get_connection() device = conn.get_device(device_id) @@ -323,13 +724,7 @@ def discover_devices( device_type: str | None = None, refresh: bool = False, ) -> list[dict[str, Any]]: - """Discover available devices (deprecated — use list_devices instead). - - Returns all devices with their function schemas. For large fleets, - prefer the hierarchical approach: - 1. describe_fleet() — see what's available - 2. list_devices(...) — browse with filters - 3. get_device_functions(id) — get schemas for one device + """Discover available devices (deprecated; use discover() instead). Args: device_type: Optional filter (e.g., "robot", "camera"). Fuzzy matching. @@ -338,6 +733,12 @@ def discover_devices( Returns: List of devices with device_id, device_type, functions, events. """ + warnings.warn( + "discover_devices() is deprecated; use discover() with a selector " + "(e.g. discover('device(*)') or discover('device(category:camera)')).", + DeprecationWarning, + stacklevel=2, + ) try: conn = get_connection() # Invalidate cache when refresh is requested diff --git a/packages/device-connect-agent-tools/tests/test_claude_adapter.py b/packages/device-connect-agent-tools/tests/test_claude_adapter.py index ba6fafc..b0e2ac6 100644 --- a/packages/device-connect-agent-tools/tests/test_claude_adapter.py +++ b/packages/device-connect-agent-tools/tests/test_claude_adapter.py @@ -65,9 +65,8 @@ def _mock_sdk_and_connection(): TOOL_NAMES = ( - "describe_fleet", - "list_devices", - "get_device_functions", + "discover_labels", + "discover", "discover_devices", "invoke_device", "invoke_device_with_fallback", diff --git a/packages/device-connect-agent-tools/tests/test_discover.py b/packages/device-connect-agent-tools/tests/test_discover.py new file mode 100644 index 0000000..2bc93e0 --- /dev/null +++ b/packages/device-connect-agent-tools/tests/test_discover.py @@ -0,0 +1,361 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for the selector-driven ``discover`` and ``discover_labels`` tools. + +Uses a labeled mock fleet (cam-001, robot-001, sensor-001) drawn from the +existing DC test driver vocabulary so every selector exercises real device, +function, and event names. +""" +from unittest.mock import MagicMock, patch + +import pytest + +from device_connect_agent_tools import tools as tools_mod + + +# -- Fixture: labeled fleet --------------------------------------- + + +SAMPLE_DEVICES = [ + { + "device_id": "cam-001", + "device_type": "camera", + "location": "lab-A", + "status": {"state": "online"}, + "identity": {"device_type": "camera"}, + "labels": {"category": ["camera", "inference"], "location": "zone-A/dock"}, + "functions": [ + { + "name": "capture_image", + "description": "Capture an image", + "parameters": {"type": "object"}, + "labels": {"direction": "write", "modality": "rgb"}, + }, + ], + "events": [ + {"name": "object_detected", "labels": {"modality": "rgb"}}, + {"name": "state_change_detected", "labels": None}, + ], + }, + { + "device_id": "cam-002", + "device_type": "camera", + "location": "lab-A", + "status": {"state": "online"}, + "identity": {"device_type": "camera"}, + "labels": {"category": "camera", "location": "zone-B/dock"}, + "functions": [ + { + "name": "capture_image", + "description": "Capture an image", + "parameters": {"type": "object"}, + "labels": {"direction": "write", "modality": ["rgb", "4k"]}, + }, + ], + "events": [ + {"name": "object_detected", "labels": {"modality": "rgb"}}, + ], + }, + { + "device_id": "robot-001", + "device_type": "cleaner_robot", + "location": "lab-A", + "status": {"state": "idle"}, + "identity": {"device_type": "cleaner_robot"}, + "labels": {"category": "robot", "location": "zone-A/yard"}, + "functions": [ + { + "name": "dispatch_robot", + "description": "Dispatch", + "parameters": {}, + "labels": {"direction": "write", "safety": "critical"}, + }, + { + "name": "get_status", + "description": "Status", + "parameters": {}, + "labels": {"direction": "read"}, + }, + ], + "events": [ + {"name": "cleaning_finished", "labels": None}, + ], + }, + { + "device_id": "sensor-001", + "device_type": "temperature_sensor", + "location": "lab-B", + "status": {"state": "online"}, + "identity": {"device_type": "temperature_sensor"}, + "labels": {"category": "sensor", "location": "lab-B"}, + "functions": [ + {"name": "get_reading", "parameters": {}, "labels": {"direction": "read"}}, + {"name": "set_threshold", "parameters": {}, "labels": {"direction": "write"}}, + {"name": "set_location", "parameters": {}, "labels": {"direction": "write"}}, + ], + "events": [ + {"name": "reading", "labels": None}, + {"name": "threshold_exceeded", "labels": {"safety": "informational"}}, + ], + }, +] + + +@pytest.fixture +def mock_conn(): + conn = MagicMock() + conn.list_devices.return_value = SAMPLE_DEVICES + with patch.object(tools_mod, "get_connection", return_value=conn): + yield conn + + +# -- discover: device-only scope ----------------------------------- + + +class TestDiscoverDeviceOnly: + def test_match_by_category_label(self, mock_conn): + r = tools_mod.discover("device(category:camera)") + assert r["scope"] == "device_only" + assert r["matched"] == 2 + assert {row["device_id"] for row in r["results"]} == {"cam-001", "cam-002"} + + def test_multivalued_match_picks_composite_only(self, mock_conn): + # Only cam-001 has category:[camera, inference]. + r = tools_mod.discover("device(category:inference)") + assert r["matched"] == 1 + assert r["results"][0]["device_id"] == "cam-001" + + def test_or_within_key(self, mock_conn): + r = tools_mod.discover("device(category:[camera,robot])") + assert {row["device_id"] for row in r["results"]} == { + "cam-001", "cam-002", "robot-001" + } + + def test_glob_location(self, mock_conn): + r = tools_mod.discover("device(location:zone-A/*)") + assert {row["device_id"] for row in r["results"]} == {"cam-001", "robot-001"} + + def test_and_across_keys(self, mock_conn): + r = tools_mod.discover( + "device(category:[camera,robot], location:zone-A/*)" + ) + assert {row["device_id"] for row in r["results"]} == {"cam-001", "robot-001"} + + def test_match_all(self, mock_conn): + r = tools_mod.discover("device(*)") + assert r["matched"] == 4 + + def test_bare_id_match(self, mock_conn): + r = tools_mod.discover("device(cam-001)") + assert r["matched"] == 1 + assert r["results"][0]["device_id"] == "cam-001" + + def test_labels_surfaced_in_result(self, mock_conn): + r = tools_mod.discover("device(cam-001)") + assert r["results"][0]["labels"] == { + "category": ["camera", "inference"], + "location": "zone-A/dock", + } + + +# -- discover: function scope -------------------------------------- + + +class TestDiscoverFunctionScope: + def test_writes_fleet_wide(self, mock_conn): + r = tools_mod.discover("device(*).function(direction:write)") + assert r["scope"] == "device_function" + assert r["matched"] == 5 # capture x2, dispatch_robot, set_threshold, set_location + names = {row["name"] for row in r["results"]} + assert names == {"capture_image", "dispatch_robot", "set_threshold", "set_location"} + + def test_function_only_scope_by_name(self, mock_conn): + r = tools_mod.discover("function(get_reading)") + assert r["scope"] == "function_only" + assert r["matched"] == 1 + assert r["results"][0]["name"] == "get_reading" + assert r["results"][0]["device_id"] == "sensor-001" + + def test_anchored_glob_set_prefix(self, mock_conn): + r = tools_mod.discover("function(set_*)") + assert {row["name"] for row in r["results"]} == {"set_threshold", "set_location"} + + def test_below_threshold_returns_full_schemas(self, mock_conn): + r = tools_mod.discover("device(cam-001).function(*)") + assert r["matched"] == 1 + row = r["results"][0] + assert "parameters" in row + assert "description" in row + assert row["labels"] == {"direction": "write", "modality": "rgb"} + + def test_modality_or_within_key(self, mock_conn): + r = tools_mod.discover("device(*).function(modality:[rgb,thermal])") + assert r["matched"] == 2 + assert all(row["name"] == "capture_image" for row in r["results"]) + + def test_safety_critical_filter(self, mock_conn): + r = tools_mod.discover("function(safety:critical)") + assert r["matched"] == 1 + assert r["results"][0]["name"] == "dispatch_robot" + + def test_label_histogram_built(self, mock_conn): + r = tools_mod.discover("device(*).function(direction:write)") + hist = r["label_histogram"] + assert hist["direction"]["values"] == {"write": 5} + # modality is multi-valued on cam-002 (rgb + 4k) + modality = hist["modality"] + assert modality.get("multivalued") is True + assert modality["values"] == {"rgb": 2, "4k": 1} + + +# -- discover: event scope ----------------------------------------- + + +class TestDiscoverEventScope: + def test_event_by_modality(self, mock_conn): + r = tools_mod.discover("device(*).event(modality:rgb)") + assert r["scope"] == "device_event" + assert r["matched"] == 2 # cam-001 + cam-002 each emit object_detected + assert all(row["name"] == "object_detected" for row in r["results"]) + + def test_event_only_by_name(self, mock_conn): + r = tools_mod.discover("event(threshold_exceeded)") + assert r["scope"] == "event_only" + assert r["matched"] == 1 + + +# -- discover: pagination ------------------------------------------ + + +class TestDiscoverPagination: + def test_pagination_envelope(self, mock_conn): + r = tools_mod.discover("device(*)", limit=2) + assert r["matched"] == 4 + assert r["returned"] == 2 + assert r["offset"] == 0 + assert r["next_offset"] == 2 + + def test_offset_respected(self, mock_conn): + r = tools_mod.discover("device(*)", offset=2, limit=10) + assert r["offset"] == 2 + assert r["returned"] == 2 + assert r["next_offset"] is None + + def test_negative_offset_clamped(self, mock_conn): + r = tools_mod.discover("device(*)", offset=-5) + assert r["offset"] == 0 + + def test_hard_limit_caps_runaway_request(self, mock_conn): + r = tools_mod.discover("device(*)", limit=999_999) + # Hard ceiling is 1000; for 4 devices, the page just returns everything. + assert r["returned"] == 4 + + def test_zero_limit_falls_back_to_default(self, mock_conn): + r = tools_mod.discover("device(*)", limit=0) + # Default applies, all 4 fit in one page. + assert r["returned"] == 4 + + +# -- discover: errors ---------------------------------------------- + + +class TestDiscoverErrors: + def test_bad_selector_returns_error_envelope(self, mock_conn): + r = tools_mod.discover("not a selector at all") + assert "error" in r + assert r["matched"] == 0 + assert r["results"] == [] + + def test_unknown_scope_in_selector(self, mock_conn): + r = tools_mod.discover("widgets(*)") + assert "error" in r + assert "unknown scope" in r["error"].lower() + + def test_connection_failure_returns_error(self): + broken = MagicMock() + broken.list_devices.side_effect = RuntimeError("messaging down") + with patch.object(tools_mod, "get_connection", return_value=broken): + r = tools_mod.discover("device(*)") + assert "error" in r + assert r["matched"] == 0 + + def test_non_string_selector(self, mock_conn): + r = tools_mod.discover(None) # type: ignore[arg-type] + assert "error" in r + + +# -- discover_labels ------------------------------------------------ + + +class TestDiscoverLabels: + def test_multi_axis_default(self, mock_conn): + v = tools_mod.discover_labels() + assert v["total_devices"] == 4 + assert v["total_functions"] == 7 + assert v["total_events"] == 6 + assert "category" in v["device_keys"] + assert "direction" in v["function_keys"] + assert "modality" in v["event_keys"] + + def test_multivalued_annotation_on_device_category(self, mock_conn): + v = tools_mod.discover_labels() + cat = v["device_keys"]["category"] + assert cat["multivalued"] is True + # All 4 devices declared a category; cam-001 contributed to two values + # but unique_devices counts distinct devices. + assert cat["unique_devices"] == 4 + assert cat["values"] == {"camera": 2, "inference": 1, "robot": 1, "sensor": 1} + + def test_singleton_keys_not_flagged_multivalued(self, mock_conn): + v = tools_mod.discover_labels() + direction = v["function_keys"]["direction"] + assert direction.get("multivalued") is not True + + def test_per_key_pagination(self, mock_conn): + v = tools_mod.discover_labels(key="device.location") + assert v["axis"] == "device" + assert v["key"] == "location" + # 4 distinct location values, sorted by frequency desc then alpha + assert v["matched"] == 4 + assert list(v["values"].keys())[0] == "lab-B" # only single value with count 1, alpha tiebreak + + def test_per_key_function_axis(self, mock_conn): + v = tools_mod.discover_labels(key="function.direction") + assert v["axis"] == "function" + assert v["values"] == {"write": 5, "read": 2} + + def test_per_key_unknown_axis(self, mock_conn): + v = tools_mod.discover_labels(key="thing.bogus") + assert "error" in v + + def test_per_key_missing_dot(self, mock_conn): + v = tools_mod.discover_labels(key="just_a_key") + assert "error" in v + assert "axis-qualified" in v["error"] + + +# -- Deprecation warnings ------------------------------------------ + + +class TestDeprecationWarnings: + def test_describe_fleet_emits_warning(self, mock_conn, recwarn): + tools_mod.describe_fleet() + assert any("describe_fleet" in str(w.message) for w in recwarn.list) + + def test_list_devices_emits_warning(self, mock_conn, recwarn): + tools_mod.list_devices() + assert any("list_devices" in str(w.message) for w in recwarn.list) + + def test_get_device_functions_emits_warning(self, mock_conn, recwarn): + # get_device_functions calls conn.get_device which we haven't mocked; + # the warning is emitted before that call so we still observe it. + mock_conn.get_device = MagicMock(return_value={ + "device_id": "cam-001", "functions": [], "events": [], + "identity": {}, "status": {}, "capabilities": {}, + }) + # Force a fresh patch so get_device path is hit + with patch.object(tools_mod, "get_connection", return_value=mock_conn): + tools_mod.get_device_functions("cam-001") + assert any("get_device_functions" in str(w.message) for w in recwarn.list) diff --git a/packages/device-connect-agent-tools/tests/test_langchain_adapter.py b/packages/device-connect-agent-tools/tests/test_langchain_adapter.py index 6b731d2..d647ee3 100644 --- a/packages/device-connect-agent-tools/tests/test_langchain_adapter.py +++ b/packages/device-connect-agent-tools/tests/test_langchain_adapter.py @@ -69,19 +69,27 @@ def _mock_langchain_and_connection(): del sys.modules[key] +EXPECTED_TOOLS = { + "discover_labels", + "discover", + "invoke_device", + "invoke_device_with_fallback", + "get_device_status", + "discover_devices", +} + + class TestLangchainAdapterExports: def test_module_exports_all_tools(self): from device_connect_agent_tools.adapters import langchain as adapter - for name in ("discover_devices", "invoke_device", "invoke_device_with_fallback", - "get_device_status", "describe_fleet", "list_devices", "get_device_functions"): + for name in EXPECTED_TOOLS: assert hasattr(adapter, name), f"Missing export: {name}" def test_all_list(self): from device_connect_agent_tools.adapters import langchain as adapter - expected = {"discover_devices", "invoke_device", "invoke_device_with_fallback", "get_device_status", "list_devices", "get_device_functions", "describe_fleet"} - assert set(adapter.__all__) == expected + assert set(adapter.__all__) == EXPECTED_TOOLS def test_tools_are_structured_tool_instances(self): from device_connect_agent_tools.adapters import langchain as adapter diff --git a/packages/device-connect-agent-tools/tests/test_strands_adapter.py b/packages/device-connect-agent-tools/tests/test_strands_adapter.py index 6a1ea6f..a40b5ad 100644 --- a/packages/device-connect-agent-tools/tests/test_strands_adapter.py +++ b/packages/device-connect-agent-tools/tests/test_strands_adapter.py @@ -52,19 +52,27 @@ def _mock_strands_and_connection(): sys.modules.pop("strands", None) +EXPECTED_TOOLS = { + "discover_labels", + "discover", + "invoke_device", + "invoke_device_with_fallback", + "get_device_status", + "discover_devices", +} + + class TestStrandsAdapterExports: def test_module_exports_all_tools(self): from device_connect_agent_tools.adapters import strands as adapter - for name in ("discover_devices", "invoke_device", "invoke_device_with_fallback", - "get_device_status", "describe_fleet", "list_devices", "get_device_functions"): + for name in EXPECTED_TOOLS: assert hasattr(adapter, name), f"Missing export: {name}" def test_all_list(self): from device_connect_agent_tools.adapters import strands as adapter - expected = {"discover_devices", "invoke_device", "invoke_device_with_fallback", "get_device_status", "list_devices", "get_device_functions", "describe_fleet"} - assert set(adapter.__all__) == expected + assert set(adapter.__all__) == EXPECTED_TOOLS def test_tools_are_callable(self): from device_connect_agent_tools.adapters import strands as adapter diff --git a/packages/device-connect-server/device_connect_server/portal/views/devices.py b/packages/device-connect-server/device_connect_server/portal/views/devices.py index 84125ac..3f82309 100644 --- a/packages/device-connect-server/device_connect_server/portal/views/devices.py +++ b/packages/device-connect-server/device_connect_server/portal/views/devices.py @@ -320,7 +320,7 @@ async def download_starter_script(request: web.Request): """Device Connect — starter AI agent (Strands + OpenAI). Connects to Device Connect, discovers your fleet, and reacts to device -events by calling tools (list_devices, get_device_functions, invoke_device). +events by calling tools (discover_labels, discover, invoke_device). LLM inference runs through the Arm internal OpenAI proxy. Usage: @@ -403,7 +403,7 @@ async def prepare(self) -> Dict[str, Any]: from strands import Agent from strands.models.openai import OpenAIModel from device_connect_agent_tools.adapters.strands import ( - describe_fleet, list_devices, get_device_functions, + discover_labels, discover, invoke_device, invoke_device_with_fallback, get_device_status, ) @@ -416,7 +416,7 @@ async def prepare(self) -> Dict[str, Any]: params={"max_tokens": self._max_tokens}, ), tools=[ - describe_fleet, list_devices, get_device_functions, + discover_labels, discover, invoke_device, invoke_device_with_fallback, get_device_status, ], system_prompt=self._build_system_prompt(), @@ -441,21 +441,26 @@ def _build_system_prompt(self) -> str: for dt, info in sorted(by_type.items()): locs = ", ".join(sorted(info["locations"])) lines.append(f" - {info['count']}x {dt} (at: {locs})") - fleet = "\\n".join(lines) or " (none yet — call describe_fleet() to refresh)" + fleet = "\\n".join(lines) or " (none yet -- call discover() to refresh)" return ( f"You are an AI agent connected to the Device Connect IoT network.\\n\\n" f"YOUR GOAL: {self.goal}\\n\\n" f"FLEET OVERVIEW ({len(self.devices)} devices):\\n{fleet}\\n\\n" f"DISCOVERY TOOLS:\\n" - f" - describe_fleet() — fleet summary\\n" - f" - list_devices(device_type=..., location=...) — browse devices\\n" - f" - get_device_functions(device_id) — see what a device can do\\n" - f" - invoke_device(device_id, function, params) — call a device function\\n\\n" + f" - discover_labels(key=None) -- fleet label vocabulary " + f"(category, location, direction, modality, ...)\\n" + f" - discover(selector) -- resolve a selector to devices, " + f"functions, or events. Examples:\\n" + f" device(category:camera, location:zone-A/*)\\n" + f" device(robot-001).function(direction:write)\\n" + f" function(safety:critical)\\n" + f" - invoke_device(device_id, function, params) -- call a device function\\n\\n" f"INSTRUCTIONS:\\n" f"When you receive device events, you MUST:\\n" f"1. Analyze the events\\n" - f"2. Use get_device_functions() to check available functions if needed\\n" + f"2. Use discover() with a function-scoped selector to check " + f"available functions if needed\\n" f"3. Use invoke_device() to interact with devices\\n" f"4. Report what you found and what actions you took\\n\\n" f"Always provide llm_reasoning when invoking devices.\\n" diff --git a/tests/drivers/camera.py b/tests/drivers/camera.py index 5a5804d..a2b659f 100644 --- a/tests/drivers/camera.py +++ b/tests/drivers/camera.py @@ -20,6 +20,7 @@ class TestCameraDriver(DeviceDriver): """Simulated camera for integration tests.""" device_type = "test_camera" + labels = {"category": "camera"} def __init__(self, failure_rate: float = 0.0, min_latency_ms: float = 10, max_latency_ms: float = 50, location: str = "test-zone"): @@ -51,7 +52,7 @@ def identity(self) -> DeviceIdentity: def status(self) -> DeviceStatus: return DeviceStatus(location=self._location) - @rpc() + @rpc(labels={"direction": "write", "modality": "rgb"}) async def capture_image(self, resolution: str = "1080p") -> dict: """Capture a simulated test image.""" await self.simulate_delay() @@ -64,12 +65,12 @@ async def capture_image(self, resolution: str = "1080p") -> dict: "device_id": getattr(self, "_device_id", "unknown"), } - @emit() + @emit(labels={"modality": "motion"}) async def state_change_detected(self, zone_id: str, state_class: str, details: Optional[str] = None): """State change detected in camera view.""" pass - @emit() + @emit(labels={"modality": "rgb"}) async def object_detected(self, label: str, confidence: float, bbox: Optional[list] = None): """Object detected in camera view.""" pass diff --git a/tests/drivers/robot.py b/tests/drivers/robot.py index be0e59e..5641a79 100644 --- a/tests/drivers/robot.py +++ b/tests/drivers/robot.py @@ -19,6 +19,7 @@ class TestRobotDriver(DeviceDriver): """Simulated cleaning robot for integration tests.""" device_type = "test_robot" + labels = {"category": "robot"} def __init__(self, clean_duration: float = 0.5, failure_rate: float = 0.0, min_latency_ms: float = 10, max_latency_ms: float = 50, @@ -59,7 +60,7 @@ def identity(self) -> DeviceIdentity: def status(self) -> DeviceStatus: return DeviceStatus(location=self._location) - @rpc() + @rpc(labels={"direction": "write", "safety": "critical"}) async def dispatch_robot(self, zone_id: str) -> dict: """Dispatch the robot to clean a zone.""" await self.simulate_delay() @@ -72,7 +73,7 @@ async def dispatch_robot(self, zone_id: str) -> dict: self._cleaning_task = asyncio.create_task(self._do_cleaning(zone_id)) return {"status": "accepted", "zone_id": zone_id, "estimated_duration": self._clean_duration} - @rpc() + @rpc(labels={"direction": "read"}) async def get_status(self) -> dict: """Get current robot status.""" await self.simulate_delay() diff --git a/tests/drivers/sensor.py b/tests/drivers/sensor.py index ba6baed..5632ce8 100644 --- a/tests/drivers/sensor.py +++ b/tests/drivers/sensor.py @@ -20,6 +20,7 @@ class TestSensorDriver(DeviceDriver): """Simulated temperature/humidity sensor for integration tests.""" device_type = "test_sensor" + labels = {"category": "sensor"} def __init__(self, failure_rate: float = 0.0, min_latency_ms: float = 10, max_latency_ms: float = 50, location: str = "test-room", @@ -59,7 +60,7 @@ def identity(self) -> DeviceIdentity: def status(self) -> DeviceStatus: return DeviceStatus(location=self._location, availability="available") - @rpc() + @rpc(labels={"direction": "read", "modality": "thermal"}) async def get_reading(self, unit: str = "celsius") -> dict: """Get current temperature and humidity reading.""" await self.simulate_delay() @@ -78,13 +79,13 @@ async def get_reading(self, unit: str = "celsius") -> dict: "device_id": getattr(self, "_device_id", "unknown"), } - @rpc() + @rpc(labels={"direction": "write", "safety": "critical"}) async def set_threshold(self, temperature: float, humidity: Optional[float] = None) -> dict: """Set alert thresholds.""" await self.simulate_delay() return {"status": "success", "temperature_threshold": temperature} - @rpc() + @rpc(labels={"direction": "write"}) async def set_location(self, location: str) -> dict: """Update the sensor's location.""" await self.simulate_delay() @@ -92,12 +93,12 @@ async def set_location(self, location: str) -> dict: self._location = location return {"status": "success", "old_location": old, "location": location} - @emit() + @emit(labels={"modality": "thermal"}) async def reading(self, temperature: float, humidity: float, unit: str = "celsius"): """Periodic sensor reading.""" pass - @emit() + @emit(labels={"safety": "critical"}) async def threshold_exceeded(self, temperature: float, humidity: float, exceeded: str): """Threshold exceeded alert.""" pass diff --git a/tests/tests/test_tools_selector.py b/tests/tests/test_tools_selector.py new file mode 100644 index 0000000..d04e7cc --- /dev/null +++ b/tests/tests/test_tools_selector.py @@ -0,0 +1,570 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Integration tests for selector-driven discovery tools. + +Covers ``discover()`` and ``discover_labels()`` against real devices +registered via the messaging backend. Exercises the full selector grammar +end-to-end across all five scope shapes (device / device.function / +device.event / function / event), label filters (category, location, +direction, modality, safety), pagination, and the legacy-location mirror. +""" + +import asyncio +import time + +import pytest + +SETTLE_TIME = 0.3 +DISCOVERY_TIMEOUT = 5.0 + + +async def _wait_for_devices(messaging_url, expected_ids): + """Connect and poll until all expected ``device_ids`` are visible. + + Returns the list of flattened device dicts. Caller is responsible for + disconnecting. + """ + from device_connect_agent_tools import connect + from device_connect_agent_tools.connection import get_connection + + await asyncio.to_thread(connect, nats_url=messaging_url) + deadline = time.monotonic() + DISCOVERY_TIMEOUT + while True: + conn = get_connection() + devices = await asyncio.to_thread(conn.list_devices) + ids = {d.get("device_id") for d in devices} + if expected_ids.issubset(ids) or time.monotonic() > deadline: + return devices + await asyncio.sleep(0.25) + + +# -- discover: device-only scope --------------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_wildcard_returns_all_devices(device_spawner, messaging_url): + """``discover('device(*)')`` returns the full roster.""" + await device_spawner.spawn_camera("itest-sel-all-cam", location="lab-A") + await device_spawner.spawn_sensor("itest-sel-all-sensor", location="lab-B") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices(messaging_url, {"itest-sel-all-cam", "itest-sel-all-sensor"}) + try: + result = await asyncio.to_thread(discover, "device(*)") + assert result["scope"] == "device_only" + assert result["matched"] >= 2 + ids = {d["device_id"] for d in result["results"]} + assert {"itest-sel-all-cam", "itest-sel-all-sensor"} <= ids + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_by_device_id(device_spawner, messaging_url): + """A bare-id selector resolves to one device.""" + await device_spawner.spawn_camera("itest-sel-id-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices(messaging_url, {"itest-sel-id-cam"}) + try: + result = await asyncio.to_thread(discover, "device(itest-sel-id-cam)") + assert result["scope"] == "device_only" + assert result["matched"] == 1 + assert result["results"][0]["device_id"] == "itest-sel-id-cam" + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_by_id_glob(device_spawner, messaging_url): + """Bare-id selectors accept globs (anchored fnmatch).""" + await device_spawner.spawn_camera("itest-sel-glob-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-sel-glob-cam-2", location="lab-A") + await device_spawner.spawn_sensor("itest-sel-glob-sensor", location="lab-B") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices( + messaging_url, + {"itest-sel-glob-cam-1", "itest-sel-glob-cam-2", "itest-sel-glob-sensor"}, + ) + try: + result = await asyncio.to_thread(discover, "device(itest-sel-glob-cam-*)") + ids = {d["device_id"] for d in result["results"]} + assert "itest-sel-glob-cam-1" in ids + assert "itest-sel-glob-cam-2" in ids + assert "itest-sel-glob-sensor" not in ids + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_by_category_label(device_spawner, messaging_url): + """``device(category:camera)`` returns only cameras (label-based).""" + await device_spawner.spawn_camera("itest-sel-cat-cam", location="lab-A") + await device_spawner.spawn_robot("itest-sel-cat-robot", location="lab-A") + await device_spawner.spawn_sensor("itest-sel-cat-sensor", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices( + messaging_url, + {"itest-sel-cat-cam", "itest-sel-cat-robot", "itest-sel-cat-sensor"}, + ) + try: + result = await asyncio.to_thread(discover, "device(category:camera)") + ids = {d["device_id"] for d in result["results"]} + assert "itest-sel-cat-cam" in ids + assert "itest-sel-cat-robot" not in ids + assert "itest-sel-cat-sensor" not in ids + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_or_within_category(device_spawner, messaging_url): + """Bracket lists OR within a key: cameras or robots, not sensors.""" + await device_spawner.spawn_camera("itest-sel-or-cam", location="lab-A") + await device_spawner.spawn_robot("itest-sel-or-robot", location="lab-A") + await device_spawner.spawn_sensor("itest-sel-or-sensor", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices( + messaging_url, + {"itest-sel-or-cam", "itest-sel-or-robot", "itest-sel-or-sensor"}, + ) + try: + result = await asyncio.to_thread( + discover, "device(category:[camera,robot])" + ) + ids = {d["device_id"] for d in result["results"]} + assert "itest-sel-or-cam" in ids + assert "itest-sel-or-robot" in ids + assert "itest-sel-or-sensor" not in ids + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_and_across_category_and_location( + device_spawner, messaging_url +): + """Comma is AND across keys: category=camera AND location=lab-A.""" + await device_spawner.spawn_camera("itest-sel-and-cam-a", location="lab-A") + await device_spawner.spawn_camera("itest-sel-and-cam-b", location="lab-B") + await device_spawner.spawn_robot("itest-sel-and-robot-a", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices( + messaging_url, + {"itest-sel-and-cam-a", "itest-sel-and-cam-b", "itest-sel-and-robot-a"}, + ) + try: + result = await asyncio.to_thread( + discover, "device(category:camera, location:lab-A)" + ) + ids = {d["device_id"] for d in result["results"]} + assert "itest-sel-and-cam-a" in ids + assert "itest-sel-and-cam-b" not in ids # wrong location + assert "itest-sel-and-robot-a" not in ids # wrong category + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_by_location_via_legacy_mirror(device_spawner, messaging_url): + """Legacy ``DeviceStatus.location`` is mirrored into ``labels['location']``. + + The flatten_device location-mirror lifts ``status.location`` into + ``labels['location']`` when ``capabilities.labels`` does not declare + one, so selector queries on location work even for drivers that only + populate the legacy heartbeat field. + """ + await device_spawner.spawn_camera("itest-sel-mirror-cam", location="lab-A") + await device_spawner.spawn_sensor("itest-sel-mirror-sensor", location="lab-B") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices( + messaging_url, {"itest-sel-mirror-cam", "itest-sel-mirror-sensor"} + ) + try: + result = await asyncio.to_thread(discover, "device(location:lab-A)") + ids = {d["device_id"] for d in result["results"]} + assert "itest-sel-mirror-cam" in ids + assert "itest-sel-mirror-sensor" not in ids + finally: + await asyncio.to_thread(disconnect) + + +# -- discover: function-scoped -------------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_function_scope_per_device(device_spawner, messaging_url): + """``device().function(*)`` returns a device's RPC roster.""" + await device_spawner.spawn_camera("itest-sel-fn-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices(messaging_url, {"itest-sel-fn-cam"}) + try: + result = await asyncio.to_thread( + discover, "device(itest-sel-fn-cam).function(*)" + ) + assert result["scope"] == "device_function" + names = {row.get("name") for row in result["results"]} + assert "capture_image" in names + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_function_by_name_fleet_wide(device_spawner, messaging_url): + """``device(*).function()`` returns ``(device, function)`` tuples.""" + await device_spawner.spawn_camera("itest-sel-fnname-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-sel-fnname-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices( + messaging_url, {"itest-sel-fnname-cam-1", "itest-sel-fnname-cam-2"} + ) + try: + result = await asyncio.to_thread( + discover, "device(*).function(capture_image)" + ) + device_ids = {row["device_id"] for row in result["results"]} + assert {"itest-sel-fnname-cam-1", "itest-sel-fnname-cam-2"} <= device_ids + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_function_by_direction_label(device_spawner, messaging_url): + """``device(*).function(direction:write)`` matches on FunctionDef labels.""" + await device_spawner.spawn_camera("itest-sel-dir-cam", location="lab-A") + await device_spawner.spawn_sensor("itest-sel-dir-sensor", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices(messaging_url, {"itest-sel-dir-cam", "itest-sel-dir-sensor"}) + try: + result = await asyncio.to_thread( + discover, "device(*).function(direction:write)" + ) + names = {row.get("name") for row in result["results"]} + # camera.capture_image (write), sensor.set_threshold (write), + # sensor.set_location (write) + assert "capture_image" in names + assert "set_threshold" in names + assert "get_reading" not in names # direction:read + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_function_safety_critical(device_spawner, messaging_url): + """``function(safety:critical)`` returns critical RPCs fleet-wide.""" + await device_spawner.spawn_robot("itest-sel-crit-robot", location="lab-A") + await device_spawner.spawn_sensor("itest-sel-crit-sensor", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices(messaging_url, {"itest-sel-crit-robot", "itest-sel-crit-sensor"}) + try: + result = await asyncio.to_thread(discover, "function(safety:critical)") + assert result["scope"] == "function_only" + names = {row.get("name") for row in result["results"]} + # robot.dispatch_robot, sensor.set_threshold are safety:critical + assert "dispatch_robot" in names + assert "set_threshold" in names + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_function_and_labels(device_spawner, messaging_url): + """``function(direction:write, modality:rgb)`` ANDs across function labels.""" + await device_spawner.spawn_camera("itest-sel-fnand-cam", location="lab-A") + await device_spawner.spawn_sensor("itest-sel-fnand-sensor", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices(messaging_url, {"itest-sel-fnand-cam", "itest-sel-fnand-sensor"}) + try: + result = await asyncio.to_thread( + discover, "function(direction:write, modality:rgb)" + ) + names = {row.get("name") for row in result["results"]} + # only camera.capture_image is direction:write AND modality:rgb + assert names == {"capture_image"} or "capture_image" in names + assert "set_threshold" not in names # write but no modality:rgb + finally: + await asyncio.to_thread(disconnect) + + +# -- discover: event-scoped ----------------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_event_by_name_fleet_wide(device_spawner, messaging_url): + """``event()`` returns events fleet-wide.""" + await device_spawner.spawn_camera("itest-sel-evname-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices(messaging_url, {"itest-sel-evname-cam"}) + try: + result = await asyncio.to_thread(discover, "event(object_detected)") + assert result["scope"] == "event_only" + device_ids = {row["device_id"] for row in result["results"]} + assert "itest-sel-evname-cam" in device_ids + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_event_by_modality_label(device_spawner, messaging_url): + """``device(*).event(modality:rgb)`` matches on EventDef labels.""" + await device_spawner.spawn_camera("itest-sel-evmod-cam", location="lab-A") + await device_spawner.spawn_sensor("itest-sel-evmod-sensor", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices(messaging_url, {"itest-sel-evmod-cam", "itest-sel-evmod-sensor"}) + try: + result = await asyncio.to_thread( + discover, "device(*).event(modality:rgb)" + ) + names = {row.get("name") for row in result["results"]} + # camera.object_detected has modality:rgb + assert "object_detected" in names + # sensor.reading has modality:thermal, not rgb + assert "reading" not in names + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_event_safety_critical(device_spawner, messaging_url): + """``event(safety:critical)`` finds the sensor.threshold_exceeded event.""" + await device_spawner.spawn_sensor("itest-sel-evcrit-sensor", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices(messaging_url, {"itest-sel-evcrit-sensor"}) + try: + result = await asyncio.to_thread(discover, "event(safety:critical)") + names = {row.get("name") for row in result["results"]} + assert "threshold_exceeded" in names + finally: + await asyncio.to_thread(disconnect) + + +# -- discover: pagination & errors ---------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_pagination(device_spawner, messaging_url): + """``offset`` and ``limit`` produce stable, non-overlapping pages.""" + ids = {f"itest-sel-page-cam-{i}" for i in range(3)} + for did in sorted(ids): + await device_spawner.spawn_camera(did, location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices(messaging_url, ids) + try: + page1 = await asyncio.to_thread( + discover, "device(category:camera)", 0, 2 + ) + page2 = await asyncio.to_thread( + discover, "device(category:camera)", page1["next_offset"] or 0, 2 + ) + assert page1["returned"] <= 2 + page1_ids = {d["device_id"] for d in page1["results"]} + page2_ids = {d["device_id"] for d in page2["results"]} + assert not (page1_ids & page2_ids), "pages should not overlap" + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_offset_past_end_returns_empty(device_spawner, messaging_url): + """An offset beyond ``matched`` returns an empty page with ``next_offset=None``.""" + await device_spawner.spawn_camera("itest-sel-oob-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover + + await _wait_for_devices(messaging_url, {"itest-sel-oob-cam"}) + try: + result = await asyncio.to_thread(discover, "device(*)", 9999, 50) + assert result["returned"] == 0 + assert result["results"] == [] + assert result["next_offset"] is None + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_invalid_selector_returns_error(device_spawner, messaging_url): + """A bad selector returns an error-as-data envelope, not a raise.""" + await device_spawner.spawn_camera("itest-sel-err-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import connect, disconnect, discover + + await asyncio.to_thread(connect, nats_url=messaging_url) + try: + result = await asyncio.to_thread(discover, "device(") + assert "error" in result + assert result["matched"] == 0 + assert result["results"] == [] + finally: + await asyncio.to_thread(disconnect) + + +# -- discover_labels() ----------------------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_labels_includes_category(device_spawner, messaging_url): + """Vocabulary surfaces ``category`` from device-level labels.""" + await device_spawner.spawn_camera("itest-sel-vcat-cam", location="lab-A") + await device_spawner.spawn_robot("itest-sel-vcat-robot", location="lab-A") + await device_spawner.spawn_sensor("itest-sel-vcat-sensor", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover_labels + + await _wait_for_devices( + messaging_url, + {"itest-sel-vcat-cam", "itest-sel-vcat-robot", "itest-sel-vcat-sensor"}, + ) + try: + result = await asyncio.to_thread(discover_labels) + cat = result["device_keys"].get("category") + assert cat is not None + values = cat["values"] + assert "camera" in values + assert "robot" in values + assert "sensor" in values + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_labels_includes_location_via_mirror( + device_spawner, messaging_url +): + """Vocabulary surfaces ``location`` even when only ``DeviceStatus.location`` is set.""" + await device_spawner.spawn_camera("itest-sel-vloc-cam", location="lab-A") + await device_spawner.spawn_sensor("itest-sel-vloc-sensor", location="lab-B") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover_labels + + await _wait_for_devices(messaging_url, {"itest-sel-vloc-cam", "itest-sel-vloc-sensor"}) + try: + result = await asyncio.to_thread(discover_labels) + loc = result["device_keys"].get("location") + assert loc is not None + values = loc["values"] + assert "lab-A" in values + assert "lab-B" in values + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_labels_function_direction_histogram( + device_spawner, messaging_url +): + """Function-axis vocabulary surfaces ``direction`` with read/write counts.""" + await device_spawner.spawn_camera("itest-sel-vdir-cam", location="lab-A") + await device_spawner.spawn_sensor("itest-sel-vdir-sensor", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover_labels + + await _wait_for_devices(messaging_url, {"itest-sel-vdir-cam", "itest-sel-vdir-sensor"}) + try: + result = await asyncio.to_thread(discover_labels) + direction = result["function_keys"].get("direction") + assert direction is not None + values = direction["values"] + assert "read" in values + assert "write" in values + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_discover_labels_per_key_pagination(device_spawner, messaging_url): + """``discover_labels(key='device.category')`` paginates one key's values.""" + await device_spawner.spawn_camera("itest-sel-vpg-cam", location="lab-A") + await device_spawner.spawn_robot("itest-sel-vpg-robot", location="lab-A") + await device_spawner.spawn_sensor("itest-sel-vpg-sensor", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, discover_labels + + await _wait_for_devices( + messaging_url, + {"itest-sel-vpg-cam", "itest-sel-vpg-robot", "itest-sel-vpg-sensor"}, + ) + try: + result = await asyncio.to_thread(discover_labels, "device.category") + assert result["axis"] == "device" + assert result["key"] == "category" + assert "values" in result + # at least camera, robot, sensor are present + assert {"camera", "robot", "sensor"} <= set(result["values"].keys()) + finally: + await asyncio.to_thread(disconnect) From 8cbf77c5d1f5eaef78f736050f7ef40d1214f815 Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Sat, 9 May 2026 21:08:34 -0700 Subject: [PATCH 04/21] feat(discovery): structure discover/discover_labels error envelope Errors returned by discover() and discover_labels() are now structured {"code": ..., "message": ...} dicts rather than free-form strings. This lets callers branch on the code programmatically while still surfacing the message to logs or end users. Codes emitted: - invalid_selector selector is not a string - selector_parse_error selector is a string but malformed - connection_error registry / messaging backend unavailable - key_not_axis_qualified discover_labels key missing axis prefix - unknown_axis discover_labels axis not in {device, function, event} --- docs/adr/0001-selector-driven-discovery.md | 26 ++++++++++ .../device_connect_agent_tools/tools.py | 51 ++++++++++++++++--- .../tests/test_discover.py | 15 +++--- tests/tests/test_tools_selector.py | 2 +- 4 files changed, 79 insertions(+), 15 deletions(-) diff --git a/docs/adr/0001-selector-driven-discovery.md b/docs/adr/0001-selector-driven-discovery.md index 635ed67..678b123 100644 --- a/docs/adr/0001-selector-driven-discovery.md +++ b/docs/adr/0001-selector-driven-discovery.md @@ -160,6 +160,32 @@ Operation tools (`invoke_many`, `broadcast`) do not paginate - that is a streaming-dispatch concern. Subscribe to the result channel for per-target detail at large fan-out. +## Error responses + +`discover` and `discover_labels` return errors as data inside the response +envelope rather than raising. The shape is stable so callers can branch on +the `code` programmatically and surface `message` to logs or users: + +```json +{ "matched": 0, "returned": 0, "offset": 0, "next_offset": null, + "results": [], + "error": { + "code": "selector_parse_error", + "message": "Unknown scope 'widgets' at position 0\n widgets(*)\n ^" + } +} +``` + +Codes: + +| Code | Cause | +| --- | --- | +| `invalid_selector` | Selector is not a string (or otherwise unusable as input) | +| `selector_parse_error` | Selector is a string but malformed | +| `connection_error` | Registry or messaging backend unavailable | +| `key_not_axis_qualified` | `discover_labels(key=...)` missing the `device.` / `function.` / `event.` prefix | +| `unknown_axis` | `discover_labels(key=...)` axis prefix not in `{device, function, event}` | + ## Worked examples ### Find every camera in lab-A and capture an image from each diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py b/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py index 0d3aa73..db71bc2 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py @@ -94,7 +94,29 @@ def _normalize_pagination(offset: int, limit: int, default_limit: int) -> tuple[ return safe_offset, safe_limit -def _empty_envelope(scope: str | None = None, error: str | None = None) -> dict[str, Any]: +def _error(code: str, message: str) -> dict[str, str]: + """Build the canonical structured error object. + + Errors are returned as data (not raised) inside the response envelope. + The ``code`` is a stable, machine-readable string callers may switch on; + ``message`` is human-readable and may include positional detail (parse + caret, axis name, etc.) suitable for logging or surfacing to the user. + + Codes currently emitted: + - ``selector_parse_error`` selector string is malformed + - ``invalid_selector`` selector is not a usable input + (None, non-string, etc.) + - ``connection_error`` registry / messaging unavailable + - ``key_not_axis_qualified`` discover_labels key missing axis prefix + - ``unknown_axis`` discover_labels axis not in + {device, function, event} + """ + return {"code": code, "message": message} + + +def _empty_envelope( + scope: str | None = None, error: dict[str, str] | None = None +) -> dict[str, Any]: """Build the canonical zero-result response envelope.""" out: dict[str, Any] = { "matched": 0, @@ -219,19 +241,26 @@ def discover( # Parse the selector at the system boundary; surface a clean error to # the caller rather than raising into agent code. + if not isinstance(selector, str): + return _empty_envelope( + error=_error( + "invalid_selector", + f"Selector must be a string, got {type(selector).__name__}", + ) + ) try: sel: Selector = parse_selector(selector) except SelectorParseError as e: - return _empty_envelope(error=str(e)) - except (TypeError, ValueError) as e: - return _empty_envelope(error=f"Invalid selector: {e}") + return _empty_envelope(error=_error("selector_parse_error", str(e))) try: conn = get_connection() devices = conn.list_devices() except Exception as e: logger.error("discover(%r) failed loading fleet: %s", selector, e) - return _empty_envelope(scope=sel.scope.value, error=str(e)) + return _empty_envelope( + scope=sel.scope.value, error=_error("connection_error", str(e)) + ) # Apply the device-axis filter (vacuously True when sel.device is None). matched_devices = [ @@ -362,7 +391,7 @@ def discover_labels( devices = conn.list_devices() except Exception as e: logger.error("discover_labels failed loading fleet: %s", e) - return _empty_envelope(error=str(e)) + return _empty_envelope(error=_error("connection_error", str(e))) # Aggregate function and event entities once. functions: list[dict] = [] @@ -392,7 +421,10 @@ def discover_labels( # Per-key form: split on the first dot to pick an axis. if "." not in key: return _empty_envelope( - error=f"Key must be axis-qualified (device., function., event.): {key!r}" + error=_error( + "key_not_axis_qualified", + f"Key must be axis-qualified (device., function., event.): {key!r}", + ) ) axis, label_key = key.split(".", 1) if axis == "device": @@ -406,7 +438,10 @@ def discover_labels( total = len(events) else: return _empty_envelope( - error=f"Unknown axis {axis!r} (expected device|function|event)" + error=_error( + "unknown_axis", + f"Unknown axis {axis!r} (expected device|function|event)", + ) ) counts = source.get(label_key, {}) diff --git a/packages/device-connect-agent-tools/tests/test_discover.py b/packages/device-connect-agent-tools/tests/test_discover.py index 2bc93e0..efc7de5 100644 --- a/packages/device-connect-agent-tools/tests/test_discover.py +++ b/packages/device-connect-agent-tools/tests/test_discover.py @@ -264,26 +264,28 @@ def test_zero_limit_falls_back_to_default(self, mock_conn): class TestDiscoverErrors: def test_bad_selector_returns_error_envelope(self, mock_conn): r = tools_mod.discover("not a selector at all") - assert "error" in r + assert r["error"]["code"] == "selector_parse_error" assert r["matched"] == 0 assert r["results"] == [] def test_unknown_scope_in_selector(self, mock_conn): r = tools_mod.discover("widgets(*)") assert "error" in r - assert "unknown scope" in r["error"].lower() + assert r["error"]["code"] == "selector_parse_error" + assert "unknown scope" in r["error"]["message"].lower() def test_connection_failure_returns_error(self): broken = MagicMock() broken.list_devices.side_effect = RuntimeError("messaging down") with patch.object(tools_mod, "get_connection", return_value=broken): r = tools_mod.discover("device(*)") - assert "error" in r + assert r["error"]["code"] == "connection_error" + assert "messaging down" in r["error"]["message"] assert r["matched"] == 0 def test_non_string_selector(self, mock_conn): r = tools_mod.discover(None) # type: ignore[arg-type] - assert "error" in r + assert r["error"]["code"] == "invalid_selector" # -- discover_labels ------------------------------------------------ @@ -328,12 +330,13 @@ def test_per_key_function_axis(self, mock_conn): def test_per_key_unknown_axis(self, mock_conn): v = tools_mod.discover_labels(key="thing.bogus") - assert "error" in v + assert v["error"]["code"] == "unknown_axis" def test_per_key_missing_dot(self, mock_conn): v = tools_mod.discover_labels(key="just_a_key") assert "error" in v - assert "axis-qualified" in v["error"] + assert v["error"]["code"] == "key_not_axis_qualified" + assert "axis-qualified" in v["error"]["message"] # -- Deprecation warnings ------------------------------------------ diff --git a/tests/tests/test_tools_selector.py b/tests/tests/test_tools_selector.py index d04e7cc..7711393 100644 --- a/tests/tests/test_tools_selector.py +++ b/tests/tests/test_tools_selector.py @@ -459,7 +459,7 @@ async def test_discover_invalid_selector_returns_error(device_spawner, messaging await asyncio.to_thread(connect, nats_url=messaging_url) try: result = await asyncio.to_thread(discover, "device(") - assert "error" in result + assert result["error"]["code"] == "selector_parse_error" assert result["matched"] == 0 assert result["results"] == [] finally: From 85369ee7d796cb36521979b63cbcec18c8861a5f Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Sun, 10 May 2026 10:07:01 -0700 Subject: [PATCH 05/21] docs: move discovery guide to docs/discovery.md and trim to shipped tools The doc is a developer guide rather than a decision record: drop the "ADR 0001:" framing, status line, and motivation paragraph. Trim the content to the discovery surface that ships with this PR (labels, selector grammar, discover, discover_labels, response envelope, error codes) so worked examples are runnable today. --- docs/{adr/0001-selector-driven-discovery.md => discovery.md} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename docs/{adr/0001-selector-driven-discovery.md => discovery.md} (100%) diff --git a/docs/adr/0001-selector-driven-discovery.md b/docs/discovery.md similarity index 100% rename from docs/adr/0001-selector-driven-discovery.md rename to docs/discovery.md From a046f469659479d5ad1d39afeb37e8bead67ca6f Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Sun, 10 May 2026 10:07:59 -0700 Subject: [PATCH 06/21] docs: trim discovery guide to shipped tools Trim docs/discovery.md to the discovery surface that ships in this PR (labels, selector grammar, discover, discover_labels, response envelope, error codes). Drop the ADR framing (status line, summary/motivation), the "Operations" section listing tools that have not landed yet, the CLI section, and worked examples that called those tools, so the guide matches what a developer can actually run today. --- docs/discovery.md | 213 ++++++++++++++++++++-------------------------- 1 file changed, 91 insertions(+), 122 deletions(-) diff --git a/docs/discovery.md b/docs/discovery.md index 678b123..6d2b8b4 100644 --- a/docs/discovery.md +++ b/docs/discovery.md @@ -1,30 +1,18 @@ -# ADR 0001: Selector-driven discovery and operations +# Discovery -- **Status:** Accepted +Device Connect uses one selector grammar to address devices, functions, and +events. The same selector string drives discovery: it tells the system +**which** entities you mean. Labels attached to devices, functions, and +events provide the dimensions to filter on. -## Summary - -Device Connect exposes one selector grammar that addresses devices, -functions, and events. The same selector string drives every discovery and -operation tool: it tells the system **which** entities you mean. Labels -attached to devices, functions, and events provide the dimensions to filter -on. - -Two reasons this matters in practice: - -- **Agent context budgets.** Loading every device's full schema into an LLM - context exhausts the budget on fleets of more than a few dozen devices. - Selectors let an agent narrow first and load schemas only for what it - actually needs. -- **Cross-cutting queries.** Real questions are rarely "list this one - device" - they are "every camera in lab-A", "all critical RPCs", - "any motion event in zone-B". One grammar covers all of them. +This guide covers the labels schema, the selector grammar, and the two +tools that resolve selectors. ## Labels -Labels are key/value metadata on devices, functions, and events. Values are -strings or lists of strings. Lists express composite identity (a smart -camera that is both `camera` and `inference`). +Labels are key/value metadata. Values are strings or lists of strings. +Lists express composite identity (a smart camera that is both `camera` and +`inference`). Drivers declare labels in two places: @@ -49,7 +37,7 @@ class SmartCamera(DeviceDriver): These keys carry conventional meaning. Custom keys are always allowed alongside them. -| Question the agent asks | Key | Applies to | Example values | +| Question | Key | Applies to | Example values | | --- | --- | --- | --- | | What is it? | `category` | device | `camera`, `robot`, `hub`, `sensor`, `actuator`, `inference` | | Where is it? | `location` | device | `lab-A`, `zone-A/dock` (`/`-hierarchical, glob-able) | @@ -60,6 +48,13 @@ alongside them. The RPC-vs-event distinction is structural (FunctionDef vs EventDef) and is expressed by the selector scope, not by a label. +### Drivers without label declarations + +Drivers that populate only the legacy `DeviceStatus.location` heartbeat +field are still discoverable by location: the value is mirrored into +`labels["location"]` at the discovery boundary so selector queries on +location work without a driver change. + ## Selector grammar ``` @@ -74,8 +69,7 @@ Inside `(...)`: - `key:value` - single-value match - `key:[v1,v2]` - OR within a key (matches if the label value contains any - of the listed values; multi-valued labels match if any element is in the - list) + listed value; multi-valued labels match if any element is in the list) - `key:pattern*` - anchored glob (`*`, `?`); `set_*` matches `set_threshold` but not `unset_threshold`. Use `*set*` for substring. - `k1:v1,k2:v2` - AND across keys @@ -88,7 +82,7 @@ Keys inside `device(...)` resolve against device labels; keys inside resolve against event labels. The `.` chains: "narrow to these devices, then narrow to these functions or events on them." -### Examples +### Selector examples ``` device(category:camera) all cameras @@ -102,63 +96,66 @@ function(estop) fleet emergency-st ## Tools -### Discovery +### `discover(selector, offset=0, limit=200)` -| Tool | What it returns | -| --- | --- | -| `discover_labels(key=None, offset=0, limit=50)` | Fleet label vocabulary. With no `key`, returns top values per key across each axis (device, function, event). With `key="device.location"` (etc.), paginates one key's values. Use this first when you do not know which dimensions are available. | -| `discover(selector, offset=0, limit=200)` | Resolves a selector to matched entities. Returns devices, function tuples, or event tuples depending on the selector scope. Includes a `label_histogram` so you can see which dimensions to narrow on next without a separate call. | +Resolves a selector to matched entities. Returns devices, function tuples, +or event tuples depending on the selector scope. The response includes a +`label_histogram` so you can see which dimensions to narrow on next without +a separate call. `discover()` includes full schemas inline when the matched set is small, and switches to a name-and-labels summary above `DEVICE_CONNECT_FUNCTION_THRESHOLD` (default 20). The threshold is -configurable. - -### Operations - -Calling a function on devices is one logical operation; the only choice is -whether you want to wait for replies and how they are surfaced. +configurable via environment variable. -| Tool | Selector resolves to | Reply mode | -| --- | --- | --- | -| `invoke(selector, params)` | exactly one RPC tuple | sync, single result | -| `invoke_many(selector, params, where=, bindings=)` | any number of RPC tuples | sync, aggregated | -| `broadcast(selector, function, params, where=, bindings=, fire_at=, on_late=)` | any number of RPC tuples | async; correlation-tagged replies stream as events | -| `subscribe(selector)` | events, or `correlation:` for a broadcast's replies | subscription handle | -| `await_replies(correlation_id, timeout=, until=)` | replies for one broadcast | sync helper that subscribes, collects, returns | +### `discover_labels(key=None, offset=0, limit=50)` -`invoke_many` and `broadcast` accept an optional `where` predicate -evaluated at the edge against each candidate's identity, labels, and shared -`bindings`. Use `where` for self-knowable state ("battery > 50%") and -shared `bindings` for dispatcher-computed selection masks (spatial regions, -ML score top-K, random samples). +Returns the fleet label vocabulary. Use this first when you do not know +which dimensions are available. -`broadcast` accepts `fire_at` (wall-clock epoch seconds) for synchronized -fan-out: each device holds the message and fires from its own clock at the -target time. `on_late` (`"skip"` or `"fire"`) controls behaviour when a -device receives the message after the deadline. +- With no `key`: returns top values per key across each axis (`device_keys`, + `function_keys`, `event_keys`). +- With a `key` like `"device.location"` or `"function.direction"`: + paginates the full value list for that one key. -## Pagination +## Response envelope -`discover` and `discover_labels` accept `offset` and `limit`. Responses -follow a stable envelope: +`discover` returns a stable envelope: ```json { - "matched": 7421, - "returned": 200, + "scope": "device_only", + "matched": 47, + "returned": 20, "offset": 0, - "next_offset": 200, - "results": [...] + "next_offset": 20, + "results": [...], + "label_histogram": { + "category": { + "values": {"camera": 312, "robot": 89, "sensor": 601}, + "multivalued": true, + "unique_devices": 1002 + } + } } ``` -`next_offset` is `null` when there are no more pages. The hard ceiling on -`limit` is 1000 to prevent runaway responses; ask for more pages instead. +Fields: -Operation tools (`invoke_many`, `broadcast`) do not paginate - that is a -streaming-dispatch concern. Subscribe to the result channel for per-target -detail at large fan-out. +- `scope` - one of `device_only`, `device_function`, `device_event`, + `function_only`, `event_only`. +- `matched` - total matched entities (across all pages). +- `returned` - rows in this page. +- `offset` / `next_offset` - pagination cursor; `next_offset` is `null` when + no more pages. +- `results` - per-page rows. Shape depends on scope (devices, function + tuples, or event tuples). +- `label_histogram` - per-key vocabulary across the matched set + (pre-pagination), so you can choose how to narrow next. On the device + axis, multi-valued keys also carry `unique_devices`. + +The hard ceiling on `limit` is 1000 to prevent runaway responses; ask for +more pages instead. ## Error responses @@ -176,8 +173,6 @@ the `code` programmatically and surface `message` to logs or users: } ``` -Codes: - | Code | Cause | | --- | --- | | `invalid_selector` | Selector is not a string (or otherwise unusable as input) | @@ -188,75 +183,49 @@ Codes: ## Worked examples -### Find every camera in lab-A and capture an image from each +### Browse the fleet vocabulary ```python -result = invoke_many( - selector="device(category:camera, location:lab-A).function(capture_image)", - params={"resolution": "1080p"}, -) -# {"candidates": 12, "matched": 12, "succeeded": 12, "results": [...], "errors": []} -``` - -### Async fleet emergency-stop +from device_connect_agent_tools import connect, discover_labels -```python -broadcast("function(estop)") -# {"correlation_id": "br-7f3a91", "candidates": 240} - -# Optionally wait for replies: -replies = await_replies("br-7f3a91", timeout=5.0) -``` - -### Synchronized actuation across a phone fleet +connect() +vocab = discover_labels() +# {"total_devices": 1247, "total_functions": 7100, "total_events": 1292, +# "device_keys": {"category": {...}, "location": {...}}, +# "function_keys": {"direction": {...}, "modality": {...}, "safety": {...}}, +# "event_keys": {"modality": {...}}} -```python -broadcast( - selector="device(category:phone, location:auditorium-A)", - function="set_flashlight", - params={"on": True, "color": "white"}, - where="mask[seat_row][seat_col] == 1", - bindings={"mask": }, - fire_at=time.time() + 0.500, - on_late="skip", -) +# Drill into one dimension: +locations = discover_labels(key="device.location", limit=50) ``` -### Browse the fleet vocabulary first +### Find every camera in lab-A ```python -vocab = discover_labels() -# {"total_devices": 1247, "total_functions": 7100, -# "device_keys": {"category": {...}, "location": {...}}, -# "function_keys": {"direction": {...}, "modality": {...}, "safety": {...}}, -# "event_keys": {"modality": {...}}} +from device_connect_agent_tools import discover -# Then narrow to one dimension: -locations = discover_labels(key="device.location", limit=50) +result = discover("device(category:camera, location:lab-A/*)") +for d in result["results"]: + print(d["device_id"], d["labels"]) ``` -### Subscribe to motion events in lab-A +### Find every write RPC on cameras, fleet-wide ```python -sub = subscribe("device(location:lab-A/*).event(modality:motion)") -# {"subscription_id": "sub-abc123", "matched": 8} +result = discover("device(category:camera).function(direction:write)") +for row in result["results"]: + print(row["device_id"], row["name"]) ``` -## CLI - -The same selector syntax drives the operator CLIs. Every CLI command maps -to the matching tool call. +### Paginate a large result set +```python +offset = 0 +while True: + page = discover("device(*)", offset=offset, limit=200) + for d in page["results"]: + process(d) + if page["next_offset"] is None: + break + offset = page["next_offset"] ``` -devctl discover-labels [--key K] [--offset N] [--limit M] -devctl discover "" [--offset N] [--limit M] - -statectl invoke "" [--param k=v] -statectl invoke-many "" [--param k=v] [--where E] -statectl broadcast "" [--param k=v] [--where E] [--fire-at T] -statectl subscribe "" -statectl await "" [--timeout T] -``` - -CLI flags `--param k=v` and `--where E` pack into the tool arguments; the -CLIs are thin shell wrappers over the Python tools. From bf137d841eb2ee90197e1dce66b1e4bde1924c58 Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Sun, 10 May 2026 11:08:55 -0700 Subject: [PATCH 07/21] feat(invoke): selector-based invoke and invoke_many with sync fan-out Add two selector-driven invocation tools that replace the legacy invoke_device(device_id, function, params) shape: - invoke(selector, params, llm_reasoning) resolves a function-scoped selector to exactly one (device, function) tuple and calls it. Returns {success, device_id, function, result|error}. Returns no_match, ambiguous_match, invalid_invoke_scope, or invalid_selector errors as structured envelopes when the selector does not resolve cleanly. - invoke_many(selector, params, timeout, max_concurrency, llm_reasoning) resolves to N (device, function) tuples and fans out the calls in parallel via a thread pool. Partial-failure semantics: a single target's failure does not abort siblings. Returns {candidates, matched, succeeded, failed, results, errors} with per-target structured errors. Per-target timeout defaults to 30s. invoke_device gains a DeprecationWarning pointing to invoke(); the function still works for one release while callers migrate. Adapters (Claude Agent SDK, Strands, LangChain, the in-tree StrandsOpenAIDeviceConnectAgent, and the operator-facing AGENT_SCRIPT template) drop invoke_device and expose invoke / invoke_many instead. invoke_device_with_fallback stays unchanged -- it covers a different ergonomic case (try a list of device ids in order) with no selector equivalent. 22 unit tests cover scope rejection, ambiguous and zero matches, JSON-RPC error mapping, partial failure, per-target timeout propagation, and llm_reasoning stripping. 9 integration tests cover single-target invoke, robot dispatch through to event emission, fan-out across multiple cameras, partial failure, and zero-candidate empty envelopes. --- .../device_connect_agent_tools/__init__.py | 32 +- .../adapters/claude.py | 62 +++- .../adapters/langchain.py | 23 +- .../adapters/strands.py | 23 +- .../adapters/strands_agent.py | 15 +- .../device_connect_agent_tools/tools.py | 285 ++++++++++++++- .../tests/test_claude_adapter.py | 3 +- .../tests/test_invoke.py | 336 ++++++++++++++++++ .../tests/test_langchain_adapter.py | 3 +- .../tests/test_strands_adapter.py | 3 +- .../portal/views/devices.py | 16 +- tests/tests/test_tools_invoke.py | 256 ++++++++++--- 12 files changed, 941 insertions(+), 116 deletions(-) create mode 100644 packages/device-connect-agent-tools/tests/test_invoke.py diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py b/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py index de79913..c809baa 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py @@ -4,19 +4,21 @@ """Device Connect Tools — framework-agnostic SDK for Device Connect IoT. -Selector-driven discovery keeps LLM context small: +Selector-driven discovery and invocation keep LLM context small: - from device_connect_agent_tools import connect, discover, discover_labels + from device_connect_agent_tools import connect, discover, discover_labels, invoke connect() vocab = discover_labels() # fleet vocabulary cams = discover("device(category:camera, location:zone-A/*)") # device roster writes = discover("device(*).function(direction:write)") # function tuples - result = invoke_device("camera-001", "capture_image", {"resolution": "1080p"}) + result = invoke("device(camera-001).function(capture_image)", + {"resolution": "1080p"}) -The older ``describe_fleet`` / ``list_devices`` / ``get_device_functions`` -trio remains available for one release as advisory-deprecated wrappers -- -prefer ``discover`` / ``discover_labels`` for new code. +The older ``describe_fleet`` / ``list_devices`` / ``get_device_functions`` / +``invoke_device`` family remains available for one release as +advisory-deprecated wrappers -- prefer ``discover`` / ``discover_labels`` / +``invoke`` / ``invoke_many`` for new code. """ from device_connect_agent_tools.agent import DeviceConnectAgent @@ -25,14 +27,17 @@ # Selector-driven discovery (preferred) discover, discover_labels, - # Invocation - invoke_device, + # Selector-driven invocation (preferred) + invoke, + invoke_many, + # Other invocation helpers invoke_device_with_fallback, get_device_status, - # Advisory-deprecated discovery wrappers (one-release transition) + # Advisory-deprecated wrappers (one-release transition) describe_fleet, list_devices, get_device_functions, + invoke_device, discover_devices, ) @@ -46,13 +51,16 @@ # Selector-driven discovery (preferred) "discover", "discover_labels", - # Invocation - "invoke_device", + # Selector-driven invocation (preferred) + "invoke", + "invoke_many", + # Other invocation helpers "invoke_device_with_fallback", "get_device_status", - # Advisory-deprecated -- use discover() / discover_labels() instead + # Advisory-deprecated -- use discover / discover_labels / invoke instead "describe_fleet", "list_devices", "get_device_functions", + "invoke_device", "discover_devices", ] diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py index 807abcb..9dd08d8 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py @@ -4,7 +4,7 @@ """Claude Agent SDK adapter — exposes Device Connect tools to claude-agent-sdk. -Selector-driven discovery keeps LLM context small:: +Selector-driven discovery and invocation keep LLM context small:: import anyio from claude_agent_sdk import ClaudeSDKClient, ClaudeAgentOptions, AssistantMessage, TextBlock @@ -45,7 +45,8 @@ async def main(): discover as _discover, discover_labels as _discover_labels, discover_devices as _discover_devices, - invoke_device as _invoke_device, + invoke as _invoke, + invoke_many as _invoke_many, invoke_device_with_fallback as _invoke_device_with_fallback, get_device_status as _get_device_status, ) @@ -101,27 +102,54 @@ async def discover(args: dict[str, Any]) -> dict[str, Any]: ) -# Invocation tools +# Selector-driven invocation tools (recommended) @tool( - "invoke_device", - "Call a function on a Device Connect device. Use discover() with a " - "function-scoped selector first to learn available functions and " - "parameters.", - {"device_id": str, "function": str, "params": dict, "llm_reasoning": str}, + "invoke", + "Call exactly one function on one device. The selector must resolve " + "to a single (device, function) tuple -- use device().function() " + "or function() scope. Returns {success, device_id, function, " + "result|error}. Use invoke_many for fan-out across multiple targets.", + {"selector": str, "params": dict, "llm_reasoning": str}, ) -async def invoke_device(args: dict[str, Any]) -> dict[str, Any]: +async def invoke(args: dict[str, Any]) -> dict[str, Any]: return _text( - _invoke_device( - device_id=args["device_id"], - function=args["function"], + _invoke( + selector=args["selector"], + params=args.get("params"), + llm_reasoning=args.get("llm_reasoning"), + ) + ) + + +@tool( + "invoke_many", + "Fan out a function call over a selector-resolved set of (device, " + "function) tuples in parallel. Partial-failure semantics: per-target " + "results and errors are returned even if some targets fail. Returns " + "{candidates, matched, succeeded, failed, results, errors}. Each " + "target gets a per-call timeout (default 30s).", + { + "selector": str, "params": dict, "timeout": float, + "max_concurrency": int, "llm_reasoning": str, + }, +) +async def invoke_many(args: dict[str, Any]) -> dict[str, Any]: + return _text( + _invoke_many( + selector=args["selector"], params=args.get("params"), + timeout=float(args.get("timeout", 30.0)), + max_concurrency=int(args.get("max_concurrency", 32)), llm_reasoning=args.get("llm_reasoning"), ) ) +# Other invocation helpers + + @tool( "invoke_device_with_fallback", "Call a function with automatic fallback across a list of device IDs. " @@ -148,12 +176,12 @@ async def get_device_status(args: dict[str, Any]) -> dict[str, Any]: return _text(_get_device_status(device_id=args["device_id"])) -# Backward-compatible (long-deprecated — prefer discover() / discover_labels()) +# Backward-compatible (long-deprecated -- prefer discover() / invoke()) @tool( "discover_devices", - "Deprecated — use discover() and discover_labels() instead. Discovers " + "Deprecated -- use discover() and discover_labels() instead. Discovers " "all devices with full function schemas.", {"device_type": str, "refresh": bool}, ) @@ -176,7 +204,8 @@ def create_device_connect_server(name: str = "device-connect"): tools=[ discover_labels, discover, - invoke_device, + invoke, + invoke_many, invoke_device_with_fallback, get_device_status, discover_devices, @@ -187,7 +216,8 @@ def create_device_connect_server(name: str = "device-connect"): __all__ = [ "discover_labels", "discover", - "invoke_device", + "invoke", + "invoke_many", "invoke_device_with_fallback", "get_device_status", "discover_devices", diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py index f934024..35d5e51 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py @@ -4,16 +4,16 @@ """LangChain adapter — wraps Device Connect tools as LangChain StructuredTools. -Selector-driven discovery keeps LLM context small: +Selector-driven discovery and invocation keep LLM context small: from device_connect_agent_tools import connect from device_connect_agent_tools.adapters.langchain import ( - discover_labels, discover, invoke_device, + discover_labels, discover, invoke, invoke_many, ) from langgraph.prebuilt import create_react_agent connect() - agent = create_react_agent(model, [discover_labels, discover, invoke_device]) + agent = create_react_agent(model, [discover_labels, discover, invoke, invoke_many]) Requires: pip install device-connect-agent-tools[langchain] """ @@ -24,27 +24,32 @@ discover as _discover, discover_labels as _discover_labels, discover_devices as _discover_devices, - invoke_device as _invoke_device, + invoke as _invoke, + invoke_many as _invoke_many, invoke_device_with_fallback as _invoke_device_with_fallback, get_device_status as _get_device_status, ) -# Selector-driven discovery tools (recommended) +# Selector-driven discovery (recommended) discover_labels = StructuredTool.from_function(_discover_labels) discover = StructuredTool.from_function(_discover) -# Invocation tools -invoke_device = StructuredTool.from_function(_invoke_device) +# Selector-driven invocation (recommended) +invoke = StructuredTool.from_function(_invoke) +invoke_many = StructuredTool.from_function(_invoke_many) + +# Other invocation helpers invoke_device_with_fallback = StructuredTool.from_function(_invoke_device_with_fallback) get_device_status = StructuredTool.from_function(_get_device_status) -# Backward-compatible (long-deprecated — prefer discover() / discover_labels()) +# Backward-compatible (long-deprecated -- prefer discover() / invoke()) discover_devices = StructuredTool.from_function(_discover_devices) __all__ = [ "discover_labels", "discover", - "invoke_device", + "invoke", + "invoke_many", "invoke_device_with_fallback", "get_device_status", "discover_devices", diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py index 848f362..d22fcf7 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py @@ -4,16 +4,16 @@ """Strands adapter — wraps Device Connect tools with @strands.tool. -Selector-driven discovery keeps LLM context small: +Selector-driven discovery and invocation keep LLM context small: from device_connect_agent_tools import connect from device_connect_agent_tools.adapters.strands import ( - discover_labels, discover, invoke_device, + discover_labels, discover, invoke, invoke_many, ) from strands import Agent connect() - agent = Agent(tools=[discover_labels, discover, invoke_device]) + agent = Agent(tools=[discover_labels, discover, invoke, invoke_many]) agent("What devices are online?") Requires: pip install device-connect-agent-tools[strands] @@ -25,27 +25,32 @@ discover as _discover, discover_labels as _discover_labels, discover_devices as _discover_devices, - invoke_device as _invoke_device, + invoke as _invoke, + invoke_many as _invoke_many, invoke_device_with_fallback as _invoke_device_with_fallback, get_device_status as _get_device_status, ) -# Selector-driven discovery tools (recommended) +# Selector-driven discovery (recommended) discover_labels = strands_tool(_discover_labels) discover = strands_tool(_discover) -# Invocation tools -invoke_device = strands_tool(_invoke_device) +# Selector-driven invocation (recommended) +invoke = strands_tool(_invoke) +invoke_many = strands_tool(_invoke_many) + +# Other invocation helpers invoke_device_with_fallback = strands_tool(_invoke_device_with_fallback) get_device_status = strands_tool(_get_device_status) -# Backward-compatible (long-deprecated — prefer discover() / discover_labels()) +# Backward-compatible (long-deprecated -- prefer discover() / invoke()) discover_devices = strands_tool(_discover_devices) __all__ = [ "discover_labels", "discover", - "invoke_device", + "invoke", + "invoke_many", "invoke_device_with_fallback", "get_device_status", "discover_devices", diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands_agent.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands_agent.py index a3f0cf5..c5f5e67 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands_agent.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands_agent.py @@ -62,7 +62,8 @@ async def prepare(self) -> Dict[str, Any]: from device_connect_agent_tools.adapters.strands import ( discover_labels, discover, - invoke_device, + invoke, + invoke_many, invoke_device_with_fallback, get_device_status, ) @@ -74,7 +75,7 @@ async def prepare(self) -> Dict[str, Any]: model=AnthropicModel(model_id=self._model_id, max_tokens=self._max_tokens), tools=[ discover_labels, discover, - invoke_device, invoke_device_with_fallback, get_device_status, + invoke, invoke_many, invoke_device_with_fallback, get_device_status, ], system_prompt=system_prompt, ) @@ -120,14 +121,18 @@ def _build_system_prompt(self) -> str: f"functions, or events. Examples:\n" f" device(category:camera, location:zone-A/*)\n" f" device(robot-001).function(direction:write)\n" - f" function(safety:critical)\n" - f" - invoke_device(device_id, function, params) -- call a device function\n\n" + f" function(safety:critical)\n\n" + f"INVOCATION TOOLS:\n" + f" - invoke(selector, params) -- call exactly one function. " + f"Selector must resolve to one (device, function) tuple.\n" + f" - invoke_many(selector, params) -- fan out a function call " + f"over a selector-resolved set in parallel.\n\n" f"INSTRUCTIONS:\n" f"When you receive device events, you MUST:\n" f"1. Analyze the events\n" f"2. Use discover() with a function-scoped selector to check " f"available functions if needed\n" - f"3. Use invoke_device() to interact with devices\n" + f"3. Use invoke() or invoke_many() to interact with devices\n" f"4. Report what you found and what actions you took\n\n" f"Always provide llm_reasoning when invoking devices to explain your decision.\n" f"Always call at least one tool per batch of events." diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py b/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py index db71bc2..528e554 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py @@ -463,6 +463,269 @@ def discover_labels( return out +# ── Selector-driven operations ─────────────────────────────────── + + +# Default per-target timeout for invoke_many fan-out. Configurable per call. +DEFAULT_INVOKE_TIMEOUT = 30.0 + +# Cap on parallel worker threads for invoke_many fan-out. Larger fleets can +# raise this via the ``max_concurrency`` argument; the default keeps thread +# overhead bounded while still parallelising typical 10-100 device fan-outs. +DEFAULT_INVOKE_CONCURRENCY = 32 + + +def _resolve_function_tuples( + selector: str, +) -> tuple[list[dict] | None, dict[str, Any] | None]: + """Resolve a selector to (device_id, function_name) tuples for invocation. + + Walks pagination so callers do not have to. Returns ``(rows, None)`` on + success or ``(None, error_envelope)`` if the selector failed to parse, + used a non-function scope, or the registry was unreachable. + """ + rows: list[dict] = [] + offset = 0 + while True: + page = discover(selector, offset=offset, limit=DISCOVER_HARD_LIMIT) + if "error" in page: + return None, page + if page["scope"] not in ( + Scope.DEVICE_FUNCTION.value, Scope.FUNCTION_ONLY.value, + ): + return None, _empty_envelope( + scope=page["scope"], + error=_error( + "invalid_invoke_scope", + "invoke/invoke_many require a function-scoped selector " + "(device(...).function(...) or function(...)); got " + f"scope={page['scope']!r}", + ), + ) + rows.extend(page["results"]) + if page["next_offset"] is None: + break + offset = page["next_offset"] + return rows, None + + +def _shape_invoke_response( + response: dict[str, Any], + device_id: str, + function_name: str, +) -> dict[str, Any]: + """Normalize a JSON-RPC response into a {success, result|error} envelope. + + JSON-RPC error objects arrive as ``{"code": int, "message": str}`` from + the wire; this maps them to the structured ``{code: str, message: str}`` + error shape that the rest of the agent surface uses. + """ + if "error" in response: + err = response["error"] + if isinstance(err, dict): + code = str(err.get("code", "invoke_failed")) + message = str(err.get("message", err)) + else: + code, message = "invoke_failed", str(err) + return { + "success": False, + "device_id": device_id, + "function": function_name, + "error": {"code": code, "message": message}, + } + return { + "success": True, + "device_id": device_id, + "function": function_name, + "result": response.get("result", {}), + } + + +def invoke( + selector: str, + params: dict[str, Any] | None = None, + llm_reasoning: str | None = None, +) -> dict[str, Any]: + """Resolve a selector to one (device, function) tuple and invoke it. + + Use this when the call is unambiguous -- one device, one function. + The selector must use ``device().function()`` or + ``function()`` scope. + + Args: + selector: Selector expression resolving to exactly one function. + params: Function parameters dict. Do NOT put ``llm_reasoning`` + inside ``params``. + llm_reasoning: Decision rationale for observability. + + Returns: + On success: ``{"success": True, "device_id": ..., "function": ..., + "result": ...}``. + On failure: ``{"success": False, "error": {"code": ..., + "message": ...}}``. Codes include the discover() codes plus + ``no_match`` (zero matches), ``ambiguous_match`` (multiple + matches), ``invalid_invoke_scope`` (selector did not target + functions), and ``invoke_failed`` (the device returned an error). + """ + rows, error_envelope = _resolve_function_tuples(selector) + if error_envelope is not None: + return {"success": False, "error": error_envelope["error"]} + + if not rows: + return { + "success": False, + "error": _error( + "no_match", + f"selector matched 0 functions: {selector!r}", + ), + } + if len(rows) > 1: + return { + "success": False, + "error": _error( + "ambiguous_match", + f"selector matched {len(rows)} functions, expected exactly 1: " + f"{selector!r}", + ), + "candidates": [ + {"device_id": r.get("device_id"), "function": r.get("name")} + for r in rows[:10] + ], + } + + row = rows[0] + device_id = row.get("device_id") or "" + function_name = row.get("name") or "" + + trace_id = f"trace-{uuid.uuid4().hex[:12]}" + if llm_reasoning: + truncated = ( + llm_reasoning[:200] + "..." + if len(llm_reasoning) > 200 else llm_reasoning + ) + logger.info( + "[%s] [%s::%s] Reason: %s", + trace_id, device_id, function_name, truncated, + ) + + try: + conn = get_connection() + clean = {k: v for k, v in (params or {}).items() if k != "llm_reasoning"} + response = conn.invoke(device_id, function_name, params=clean) + except Exception as e: + logger.error( + "[%s] %s::%s -> ERROR: %s", + trace_id, device_id, function_name, e, + ) + return { + "success": False, + "device_id": device_id, + "function": function_name, + "error": _error("invoke_failed", str(e)), + } + return _shape_invoke_response(response, device_id, function_name) + + +def invoke_many( + selector: str, + params: dict[str, Any] | None = None, + timeout: float = DEFAULT_INVOKE_TIMEOUT, + max_concurrency: int = DEFAULT_INVOKE_CONCURRENCY, + llm_reasoning: str | None = None, +) -> dict[str, Any]: + """Resolve a selector to (device, function) tuples and invoke each in parallel. + + Returns aggregated results with partial-failure semantics: a single + target's failure does not abort the rest. Each target gets ``timeout`` + seconds; the overall call returns once every target has finished or + timed out. + + Args: + selector: Function-scoped selector + (``device(...).function(...)`` or ``function(...)``). + params: Function parameters dict applied to every target. + timeout: Per-target timeout in seconds. + max_concurrency: Cap on parallel worker threads. + llm_reasoning: Decision rationale for observability. + + Returns: + ``{"candidates": N, "matched": N, "succeeded": S, "failed": F, + "results": [{device_id, function, result}, ...], + "errors": [{device_id, function, error}, ...]}``. + + ``candidates`` is the count returned by the selector resolver. + ``matched`` is the same value in this release; once edge-side + ``where`` predicates land, ``matched`` will narrow below + ``candidates`` to reflect post-predicate self-election. + + On selector parse / connection failure the envelope is returned + with all counts at zero plus a top-level ``error`` field. + """ + import concurrent.futures + + rows, error_envelope = _resolve_function_tuples(selector) + if error_envelope is not None: + return { + "candidates": 0, "matched": 0, "succeeded": 0, "failed": 0, + "results": [], "errors": [], "error": error_envelope["error"], + } + + out: dict[str, Any] = { + "candidates": len(rows), + "matched": len(rows), + "succeeded": 0, + "failed": 0, + "results": [], + "errors": [], + } + if not rows: + return out + + workers = max(1, min(max_concurrency, len(rows))) + clean = {k: v for k, v in (params or {}).items() if k != "llm_reasoning"} + + def call_one(row: dict) -> dict[str, Any]: + device_id = row.get("device_id") or "" + function_name = row.get("name") or "" + try: + conn = get_connection() + response = conn.invoke( + device_id, function_name, params=clean, timeout=timeout, + ) + except Exception as e: + response = {"error": {"code": "invoke_failed", "message": str(e)}} + return _shape_invoke_response(response, device_id, function_name) + + if llm_reasoning: + truncated = ( + llm_reasoning[:200] + "..." + if len(llm_reasoning) > 200 else llm_reasoning + ) + logger.info( + "[invoke_many::%d targets] Reason: %s", len(rows), truncated, + ) + + with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as exe: + futures = [exe.submit(call_one, row) for row in rows] + for future in concurrent.futures.as_completed(futures): + shaped = future.result() + if shaped["success"]: + out["results"].append({ + "device_id": shaped["device_id"], + "function": shaped["function"], + "result": shaped["result"], + }) + out["succeeded"] += 1 + else: + out["errors"].append({ + "device_id": shaped["device_id"], + "function": shaped["function"], + "error": shaped["error"], + }) + out["failed"] += 1 + return out + + # ── Hierarchical discovery tools ───────────────────────────────── @@ -650,22 +913,20 @@ def invoke_device( params: dict[str, Any] | None = None, llm_reasoning: str | None = None, ) -> dict[str, Any]: - """Call a function on a Device Connect device. + """Call a function on a Device Connect device (deprecated; use invoke()). Args: device_id: Target device ID (e.g., "robot-001", "camera-001"). - function: Function name to call (e.g., "start_cleaning", "capture_image"). - params: Function parameters as a dictionary. Check get_device_functions() for schemas. - Do NOT put llm_reasoning inside params. - llm_reasoning: Why you're calling this function — for observability. - - Example: - result = invoke_device( - device_id="robot-001", function="start_cleaning", - params={"zone": "zone-A"}, - llm_reasoning="Camera detected spill in zone-A" - ) + function: Function name to call. + params: Function parameters as a dictionary. + llm_reasoning: Why you're calling this function -- for observability. """ + warnings.warn( + "invoke_device(device_id, function, ...) is deprecated; use " + "invoke('device().function()', params) instead.", + DeprecationWarning, + stacklevel=2, + ) trace_id = f"trace-{uuid.uuid4().hex[:12]}" if llm_reasoning: truncated = llm_reasoning[:200] + "..." if len(llm_reasoning) > 200 else llm_reasoning diff --git a/packages/device-connect-agent-tools/tests/test_claude_adapter.py b/packages/device-connect-agent-tools/tests/test_claude_adapter.py index b0e2ac6..311aab5 100644 --- a/packages/device-connect-agent-tools/tests/test_claude_adapter.py +++ b/packages/device-connect-agent-tools/tests/test_claude_adapter.py @@ -68,7 +68,8 @@ def _mock_sdk_and_connection(): "discover_labels", "discover", "discover_devices", - "invoke_device", + "invoke", + "invoke_many", "invoke_device_with_fallback", "get_device_status", ) diff --git a/packages/device-connect-agent-tools/tests/test_invoke.py b/packages/device-connect-agent-tools/tests/test_invoke.py new file mode 100644 index 0000000..aae1a83 --- /dev/null +++ b/packages/device-connect-agent-tools/tests/test_invoke.py @@ -0,0 +1,336 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for the selector-driven ``invoke`` and ``invoke_many`` tools. + +Uses a small labeled fleet (cam-001, cam-002, robot-001, sensor-001) drawn +from the existing DC test driver vocabulary so every selector exercises +real device, function, and event names. +""" +import time +from unittest.mock import MagicMock, patch + +import pytest + +from device_connect_agent_tools import tools as tools_mod + + +# -- Fixtures ------------------------------------------------------- + + +SAMPLE_DEVICES = [ + { + "device_id": "cam-001", + "device_type": "camera", + "location": "lab-A", + "status": {"state": "online"}, + "identity": {"device_type": "camera"}, + "labels": {"category": "camera", "location": "lab-A"}, + "functions": [ + { + "name": "capture_image", + "parameters": {}, + "labels": {"direction": "write", "modality": "rgb"}, + }, + ], + "events": [], + }, + { + "device_id": "cam-002", + "device_type": "camera", + "location": "lab-A", + "status": {"state": "online"}, + "identity": {"device_type": "camera"}, + "labels": {"category": "camera", "location": "lab-A"}, + "functions": [ + { + "name": "capture_image", + "parameters": {}, + "labels": {"direction": "write", "modality": "rgb"}, + }, + ], + "events": [], + }, + { + "device_id": "robot-001", + "device_type": "cleaner_robot", + "location": "lab-A", + "status": {"state": "idle"}, + "identity": {"device_type": "cleaner_robot"}, + "labels": {"category": "robot", "location": "lab-A"}, + "functions": [ + { + "name": "dispatch_robot", + "parameters": {}, + "labels": {"direction": "write", "safety": "critical"}, + }, + ], + "events": [], + }, + { + "device_id": "sensor-001", + "device_type": "temperature_sensor", + "location": "lab-B", + "status": {"state": "online"}, + "identity": {"device_type": "temperature_sensor"}, + "labels": {"category": "sensor", "location": "lab-B"}, + "functions": [ + { + "name": "get_reading", + "parameters": {}, + "labels": {"direction": "read"}, + }, + ], + "events": [], + }, +] + + +def _conn_with_invoke(invoke_side_effect): + """Return a mock Connection whose .invoke() applies ``invoke_side_effect``. + + ``invoke_side_effect`` is called with ``(device_id, function_name, + params, timeout)`` and must return a JSON-RPC response dict. + """ + conn = MagicMock() + conn.list_devices.return_value = SAMPLE_DEVICES + + def _invoke(device_id, function_name, params=None, timeout=None): + return invoke_side_effect(device_id, function_name, params, timeout) + + conn.invoke.side_effect = _invoke + return conn + + +@pytest.fixture +def all_succeed_conn(): + def _ok(device_id, function_name, params, timeout): + return {"jsonrpc": "2.0", "id": "1", "result": { + "device_id": device_id, "function": function_name, "params": params, + }} + conn = _conn_with_invoke(_ok) + with patch.object(tools_mod, "get_connection", return_value=conn): + yield conn + + +# -- invoke --------------------------------------------------------- + + +class TestInvoke: + def test_single_match_returns_success(self, all_succeed_conn): + r = tools_mod.invoke( + "device(cam-001).function(capture_image)", + params={"resolution": "1080p"}, + ) + assert r["success"] is True + assert r["device_id"] == "cam-001" + assert r["function"] == "capture_image" + assert r["result"]["params"] == {"resolution": "1080p"} + + def test_function_only_selector_with_unique_name(self, all_succeed_conn): + r = tools_mod.invoke("function(get_reading)") + assert r["success"] is True + assert r["device_id"] == "sensor-001" + assert r["function"] == "get_reading" + + def test_no_match_returns_no_match_error(self, all_succeed_conn): + r = tools_mod.invoke("device(*).function(does_not_exist)") + assert r["success"] is False + assert r["error"]["code"] == "no_match" + assert "does_not_exist" in r["error"]["message"] + + def test_ambiguous_match_returns_error_with_candidates(self, all_succeed_conn): + # capture_image exists on both cam-001 and cam-002. + r = tools_mod.invoke("function(capture_image)") + assert r["success"] is False + assert r["error"]["code"] == "ambiguous_match" + assert "expected exactly 1" in r["error"]["message"] + ids = {c["device_id"] for c in r["candidates"]} + assert ids == {"cam-001", "cam-002"} + + def test_device_only_scope_rejected(self, all_succeed_conn): + # Device-only scope cannot resolve to a function. + r = tools_mod.invoke("device(robot-001)") + assert r["success"] is False + assert r["error"]["code"] == "invalid_invoke_scope" + + def test_event_scope_rejected(self, all_succeed_conn): + r = tools_mod.invoke("event(reading)") + assert r["success"] is False + assert r["error"]["code"] == "invalid_invoke_scope" + + def test_selector_parse_error_propagated(self, all_succeed_conn): + r = tools_mod.invoke("not a selector") + assert r["success"] is False + assert r["error"]["code"] == "selector_parse_error" + + def test_non_string_selector_rejected(self, all_succeed_conn): + r = tools_mod.invoke(None) # type: ignore[arg-type] + assert r["success"] is False + assert r["error"]["code"] == "invalid_selector" + + def test_jsonrpc_error_maps_to_invoke_failed(self): + def _err(device_id, function_name, params, timeout): + return { + "jsonrpc": "2.0", "id": "1", + "error": {"code": -32000, "message": "device busy"}, + } + conn = _conn_with_invoke(_err) + with patch.object(tools_mod, "get_connection", return_value=conn): + r = tools_mod.invoke("device(robot-001).function(dispatch_robot)") + assert r["success"] is False + assert r["error"]["code"] == "-32000" + assert r["error"]["message"] == "device busy" + assert r["device_id"] == "robot-001" + assert r["function"] == "dispatch_robot" + + def test_connection_exception_returns_invoke_failed(self): + conn = MagicMock() + conn.list_devices.return_value = SAMPLE_DEVICES + conn.invoke.side_effect = RuntimeError("messaging down") + with patch.object(tools_mod, "get_connection", return_value=conn): + r = tools_mod.invoke("device(cam-001).function(capture_image)") + assert r["success"] is False + assert r["error"]["code"] == "invoke_failed" + assert "messaging down" in r["error"]["message"] + + def test_llm_reasoning_stripped_from_params(self, all_succeed_conn): + tools_mod.invoke( + "device(cam-001).function(capture_image)", + params={"resolution": "1080p", "llm_reasoning": "should not appear"}, + llm_reasoning="caller reasoning", + ) + # Inspect the params actually delivered to the wire: + sent = all_succeed_conn.invoke.call_args.kwargs["params"] + assert "llm_reasoning" not in sent + assert sent["resolution"] == "1080p" + + +# -- invoke_many ---------------------------------------------------- + + +class TestInvokeMany: + def test_zero_matches_returns_empty_envelope(self, all_succeed_conn): + r = tools_mod.invoke_many("device(*).function(does_not_exist)") + assert r["candidates"] == 0 + assert r["matched"] == 0 + assert r["succeeded"] == 0 + assert r["failed"] == 0 + assert r["results"] == [] + assert r["errors"] == [] + assert "error" not in r + + def test_all_succeed(self, all_succeed_conn): + r = tools_mod.invoke_many("device(*).function(capture_image)") + assert r["candidates"] == 2 + assert r["matched"] == 2 + assert r["succeeded"] == 2 + assert r["failed"] == 0 + ids = {row["device_id"] for row in r["results"]} + assert ids == {"cam-001", "cam-002"} + # Each result row is shaped {device_id, function, result}. + for row in r["results"]: + assert row["function"] == "capture_image" + assert "result" in row + + def test_partial_failure_shape(self): + def _half_fail(device_id, function_name, params, timeout): + if device_id == "cam-001": + return {"jsonrpc": "2.0", "id": "1", "result": {"ok": True}} + return { + "jsonrpc": "2.0", "id": "1", + "error": {"code": -32000, "message": "down"}, + } + conn = _conn_with_invoke(_half_fail) + with patch.object(tools_mod, "get_connection", return_value=conn): + r = tools_mod.invoke_many("device(*).function(capture_image)") + assert r["candidates"] == 2 + assert r["matched"] == 2 + assert r["succeeded"] == 1 + assert r["failed"] == 1 + assert {row["device_id"] for row in r["results"]} == {"cam-001"} + assert {row["device_id"] for row in r["errors"]} == {"cam-002"} + for row in r["errors"]: + assert row["error"]["code"] == "-32000" + assert row["error"]["message"] == "down" + + def test_invalid_scope_returns_error_envelope(self, all_succeed_conn): + r = tools_mod.invoke_many("device(robot-001)") + assert r["candidates"] == 0 + assert r["error"]["code"] == "invalid_invoke_scope" + + def test_selector_parse_error_propagated(self, all_succeed_conn): + r = tools_mod.invoke_many("widgets(*)") + assert r["candidates"] == 0 + assert r["error"]["code"] == "selector_parse_error" + + def test_per_target_timeout_passed_to_connection(self, all_succeed_conn): + tools_mod.invoke_many( + "device(*).function(capture_image)", timeout=7.5, + ) + # Every conn.invoke call should carry the same timeout. + for call in all_succeed_conn.invoke.call_args_list: + assert call.kwargs["timeout"] == 7.5 + + def test_max_concurrency_caps_thread_pool(self, all_succeed_conn): + # The fan-out group has 3 targets (capture_image x2 + dispatch_robot + # don't share name; pick a selector that resolves to multiple). Use + # function(direction:write) which selects 4 distinct rows. + r = tools_mod.invoke_many( + "function(direction:write)", max_concurrency=1, + ) + assert r["candidates"] >= 2 + assert r["succeeded"] == r["candidates"] + + def test_connection_exception_recorded_per_target(self): + # Mix: cam-001 succeeds, cam-002's call raises locally. + def _mixed(device_id, function_name, params, timeout): + if device_id == "cam-002": + raise RuntimeError("messaging blip") + return {"jsonrpc": "2.0", "id": "1", "result": {"ok": True}} + conn = _conn_with_invoke(_mixed) + with patch.object(tools_mod, "get_connection", return_value=conn): + r = tools_mod.invoke_many("device(*).function(capture_image)") + assert r["succeeded"] == 1 + assert r["failed"] == 1 + cam002_err = next(e for e in r["errors"] if e["device_id"] == "cam-002") + assert cam002_err["error"]["code"] == "invoke_failed" + assert "messaging blip" in cam002_err["error"]["message"] + + def test_llm_reasoning_stripped_from_params(self, all_succeed_conn): + tools_mod.invoke_many( + "device(*).function(capture_image)", + params={"resolution": "4k", "llm_reasoning": "should not appear"}, + ) + for call in all_succeed_conn.invoke.call_args_list: + sent = call.kwargs["params"] + assert "llm_reasoning" not in sent + assert sent["resolution"] == "4k" + + +# -- _resolve_function_tuples --------------------------------------- + + +class TestResolveFunctionTuples: + def test_walks_all_pages(self, all_succeed_conn): + # Use a small DISCOVER_HARD_LIMIT temporarily. + with patch.object(tools_mod, "DISCOVER_HARD_LIMIT", 1): + rows, err = tools_mod._resolve_function_tuples( + "device(*).function(direction:write)" + ) + assert err is None + # 4 distinct (device, function) tuples for direction:write across the + # mock fleet (cam-001, cam-002, robot-001, sensor-001 set_threshold + # and set_location). With limit=1 per page, the resolver had to + # paginate through all of them. + assert len(rows) >= 2 + for row in rows: + assert "device_id" in row + assert "name" in row + + def test_propagates_discover_error(self, all_succeed_conn): + rows, err = tools_mod._resolve_function_tuples("not a selector") + assert rows is None + assert err is not None + assert err["error"]["code"] == "selector_parse_error" diff --git a/packages/device-connect-agent-tools/tests/test_langchain_adapter.py b/packages/device-connect-agent-tools/tests/test_langchain_adapter.py index d647ee3..c4a487e 100644 --- a/packages/device-connect-agent-tools/tests/test_langchain_adapter.py +++ b/packages/device-connect-agent-tools/tests/test_langchain_adapter.py @@ -72,7 +72,8 @@ def _mock_langchain_and_connection(): EXPECTED_TOOLS = { "discover_labels", "discover", - "invoke_device", + "invoke", + "invoke_many", "invoke_device_with_fallback", "get_device_status", "discover_devices", diff --git a/packages/device-connect-agent-tools/tests/test_strands_adapter.py b/packages/device-connect-agent-tools/tests/test_strands_adapter.py index a40b5ad..30d1ae0 100644 --- a/packages/device-connect-agent-tools/tests/test_strands_adapter.py +++ b/packages/device-connect-agent-tools/tests/test_strands_adapter.py @@ -55,7 +55,8 @@ def _mock_strands_and_connection(): EXPECTED_TOOLS = { "discover_labels", "discover", - "invoke_device", + "invoke", + "invoke_many", "invoke_device_with_fallback", "get_device_status", "discover_devices", diff --git a/packages/device-connect-server/device_connect_server/portal/views/devices.py b/packages/device-connect-server/device_connect_server/portal/views/devices.py index 3f82309..7f5bf1e 100644 --- a/packages/device-connect-server/device_connect_server/portal/views/devices.py +++ b/packages/device-connect-server/device_connect_server/portal/views/devices.py @@ -320,7 +320,7 @@ async def download_starter_script(request: web.Request): """Device Connect — starter AI agent (Strands + OpenAI). Connects to Device Connect, discovers your fleet, and reacts to device -events by calling tools (discover_labels, discover, invoke_device). +events by calling tools (discover_labels, discover, invoke, invoke_many). LLM inference runs through the Arm internal OpenAI proxy. Usage: @@ -404,7 +404,7 @@ async def prepare(self) -> Dict[str, Any]: from strands.models.openai import OpenAIModel from device_connect_agent_tools.adapters.strands import ( discover_labels, discover, - invoke_device, invoke_device_with_fallback, get_device_status, + invoke, invoke_many, invoke_device_with_fallback, get_device_status, ) result = await super().prepare() @@ -417,7 +417,7 @@ async def prepare(self) -> Dict[str, Any]: ), tools=[ discover_labels, discover, - invoke_device, invoke_device_with_fallback, get_device_status, + invoke, invoke_many, invoke_device_with_fallback, get_device_status, ], system_prompt=self._build_system_prompt(), ) @@ -454,14 +454,18 @@ def _build_system_prompt(self) -> str: f"functions, or events. Examples:\\n" f" device(category:camera, location:zone-A/*)\\n" f" device(robot-001).function(direction:write)\\n" - f" function(safety:critical)\\n" - f" - invoke_device(device_id, function, params) -- call a device function\\n\\n" + f" function(safety:critical)\\n\\n" + f"INVOCATION TOOLS:\\n" + f" - invoke(selector, params) -- call exactly one function. " + f"Selector must resolve to one (device, function) tuple.\\n" + f" - invoke_many(selector, params) -- fan out a function call " + f"over a selector-resolved set in parallel.\\n\\n" f"INSTRUCTIONS:\\n" f"When you receive device events, you MUST:\\n" f"1. Analyze the events\\n" f"2. Use discover() with a function-scoped selector to check " f"available functions if needed\\n" - f"3. Use invoke_device() to interact with devices\\n" + f"3. Use invoke() or invoke_many() to interact with devices\\n" f"4. Report what you found and what actions you took\\n\\n" f"Always provide llm_reasoning when invoking devices.\\n" f"Always call at least one tool per batch of events." diff --git a/tests/tests/test_tools_invoke.py b/tests/tests/test_tools_invoke.py index 447f301..df9878b 100644 --- a/tests/tests/test_tools_invoke.py +++ b/tests/tests/test_tools_invoke.py @@ -2,40 +2,65 @@ # # SPDX-License-Identifier: Apache-2.0 -"""Integration tests for device-connect-agent-tools invoke_device(). +"""Integration tests for selector-driven invocation tools. -Tests that the agent SDK can invoke device RPCs via the messaging backend. +Covers ``invoke()`` and ``invoke_many()`` against real devices registered +via the messaging backend. Exercises single-match, ambiguous-match, +selector-scope rejection, parallel fan-out, and partial-failure semantics +end-to-end. """ import asyncio -import pytest +import time +import pytest SETTLE_TIME = 0.3 +DISCOVERY_TIMEOUT = 5.0 + + +async def _wait_for_devices(messaging_url, expected_ids): + """Connect and poll until all expected ``device_ids`` are visible.""" + from device_connect_agent_tools import connect + from device_connect_agent_tools.connection import get_connection + + await asyncio.to_thread(connect, nats_url=messaging_url) + deadline = time.monotonic() + DISCOVERY_TIMEOUT + while True: + conn = get_connection() + devices = await asyncio.to_thread(conn.list_devices) + ids = {d.get("device_id") for d in devices} + if expected_ids.issubset(ids) or time.monotonic() > deadline: + return devices + await asyncio.sleep(0.25) + + +# -- invoke --------------------------------------------------------- @pytest.mark.asyncio @pytest.mark.integration async def test_invoke_sensor_reading(device_spawner, messaging_url): - """invoke_device() should call sensor's get_reading and return result.""" + """invoke() calls sensor.get_reading and returns the reading payload.""" await device_spawner.spawn_sensor( - "itest-tools-invoke-sensor", initial_temp=23.5, initial_humidity=50.0, + "itest-inv-read-sensor", initial_temp=23.5, initial_humidity=50.0, ) await asyncio.sleep(SETTLE_TIME) - from device_connect_agent_tools import connect, disconnect, invoke_device + from device_connect_agent_tools import disconnect, invoke - await asyncio.to_thread(connect, nats_url=messaging_url) + await _wait_for_devices(messaging_url, {"itest-inv-read-sensor"}) try: result = await asyncio.to_thread( - invoke_device, - device_id="itest-tools-invoke-sensor", - function="get_reading", - params={"unit": "celsius"}, - llm_reasoning="Testing sensor read", + invoke, + "device(itest-inv-read-sensor).function(get_reading)", + {"unit": "celsius"}, + "Testing sensor read", ) - assert isinstance(result, dict) - assert result.get("success") is True or "temperature" in result.get("result", {}) + assert result["success"] is True + assert result["device_id"] == "itest-inv-read-sensor" + assert result["function"] == "get_reading" + assert "temperature" in result["result"] finally: await asyncio.to_thread(disconnect) @@ -43,25 +68,26 @@ async def test_invoke_sensor_reading(device_spawner, messaging_url): @pytest.mark.asyncio @pytest.mark.integration async def test_invoke_robot_dispatch(device_spawner, event_capture, messaging_url): - """invoke_device() should dispatch robot and trigger cleaning.""" + """invoke() dispatches the robot and the cleaning_finished event arrives.""" await device_spawner.spawn_robot( - "itest-tools-invoke-robot", clean_duration=0.3, + "itest-inv-robot", clean_duration=0.3, ) await asyncio.sleep(SETTLE_TIME) - async with event_capture.subscribe("device-connect.*.itest-tools-invoke-robot.event.*") as events: - from device_connect_agent_tools import connect, disconnect, invoke_device + async with event_capture.subscribe( + "device-connect.*.itest-inv-robot.event.*" + ) as events: + from device_connect_agent_tools import disconnect, invoke - await asyncio.to_thread(connect, nats_url=messaging_url) + await _wait_for_devices(messaging_url, {"itest-inv-robot"}) try: result = await asyncio.to_thread( - invoke_device, - device_id="itest-tools-invoke-robot", - function="dispatch_robot", - params={"zone_id": "zone-tools"}, - llm_reasoning="Testing robot dispatch via tools", + invoke, + "device(itest-inv-robot).function(dispatch_robot)", + {"zone_id": "zone-tools"}, + "Testing robot dispatch", ) - assert isinstance(result, dict) + assert result["success"] is True finally: await asyncio.to_thread(disconnect) @@ -71,42 +97,184 @@ async def test_invoke_robot_dispatch(device_spawner, event_capture, messaging_ur @pytest.mark.asyncio @pytest.mark.integration -async def test_invoke_unknown_device(messaging_url): - """invoke_device() on non-existent device should return error.""" - from device_connect_agent_tools import connect, disconnect, invoke_device +async def test_invoke_no_match_returns_no_match(device_spawner, messaging_url): + """A selector that resolves to zero functions returns ``no_match``.""" + await device_spawner.spawn_camera("itest-inv-nomatch-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import connect, disconnect, invoke await asyncio.to_thread(connect, nats_url=messaging_url) try: result = await asyncio.to_thread( - invoke_device, - device_id="nonexistent-device-xyz", - function="ping", - llm_reasoning="Testing error handling", + invoke, + "device(itest-inv-nomatch-cam).function(does_not_exist)", + ) + assert result["success"] is False + assert result["error"]["code"] == "no_match" + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_invoke_ambiguous_match_returns_error(device_spawner, messaging_url): + """A selector matching multiple (device, function) tuples returns an error.""" + await device_spawner.spawn_camera("itest-inv-amb-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-inv-amb-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, invoke + + await _wait_for_devices( + messaging_url, {"itest-inv-amb-cam-1", "itest-inv-amb-cam-2"} + ) + try: + result = await asyncio.to_thread( + invoke, "device(itest-inv-amb-cam-*).function(capture_image)", + ) + assert result["success"] is False + assert result["error"]["code"] == "ambiguous_match" + cand_ids = {c["device_id"] for c in result["candidates"]} + assert {"itest-inv-amb-cam-1", "itest-inv-amb-cam-2"} <= cand_ids + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_invoke_device_only_scope_rejected(device_spawner, messaging_url): + """A device-only selector cannot resolve to a function.""" + await device_spawner.spawn_camera("itest-inv-scope-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import connect, disconnect, invoke + + await asyncio.to_thread(connect, nats_url=messaging_url) + try: + result = await asyncio.to_thread(invoke, "device(itest-inv-scope-cam)") + assert result["success"] is False + assert result["error"]["code"] == "invalid_invoke_scope" + finally: + await asyncio.to_thread(disconnect) + + +# -- invoke_many ---------------------------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_invoke_many_succeeds_across_devices(device_spawner, messaging_url): + """invoke_many() fans out a single function across multiple matching devices.""" + await device_spawner.spawn_camera("itest-inv-many-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-inv-many-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, invoke_many + + await _wait_for_devices( + messaging_url, {"itest-inv-many-cam-1", "itest-inv-many-cam-2"} + ) + try: + result = await asyncio.to_thread( + invoke_many, + "device(itest-inv-many-cam-*).function(capture_image)", + {"resolution": "720p"}, ) - assert isinstance(result, dict) - assert result.get("success") is False + assert result["candidates"] == 2 + assert result["matched"] == 2 + assert result["succeeded"] == 2 + assert result["failed"] == 0 + ids = {row["device_id"] for row in result["results"]} + assert ids == {"itest-inv-many-cam-1", "itest-inv-many-cam-2"} finally: await asyncio.to_thread(disconnect) @pytest.mark.asyncio @pytest.mark.integration -async def test_invoke_camera_capture(device_spawner, messaging_url): - """invoke_device() should capture image from camera.""" - await device_spawner.spawn_camera("itest-tools-invoke-cam") +async def test_invoke_many_partial_failure(device_spawner, messaging_url): + """A failing target is recorded in errors while siblings succeed.""" + await device_spawner.spawn_camera( + "itest-inv-many-pf-cam-1", location="lab-A", failure_rate=1.0, + ) + await device_spawner.spawn_camera( + "itest-inv-many-pf-cam-2", location="lab-A", + ) await asyncio.sleep(SETTLE_TIME) - from device_connect_agent_tools import connect, disconnect, invoke_device + from device_connect_agent_tools import disconnect, invoke_many + + await _wait_for_devices( + messaging_url, + {"itest-inv-many-pf-cam-1", "itest-inv-many-pf-cam-2"}, + ) + try: + result = await asyncio.to_thread( + invoke_many, + "device(itest-inv-many-pf-cam-*).function(capture_image)", + ) + assert result["candidates"] == 2 + assert result["matched"] == 2 + assert result["succeeded"] == 1 + assert result["failed"] == 1 + success_ids = {row["device_id"] for row in result["results"]} + error_ids = {row["device_id"] for row in result["errors"]} + assert success_ids == {"itest-inv-many-pf-cam-2"} + assert error_ids == {"itest-inv-many-pf-cam-1"} + for row in result["errors"]: + assert "code" in row["error"] + assert "message" in row["error"] + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_invoke_many_zero_candidates(device_spawner, messaging_url): + """No matches yields an empty envelope, not an error.""" + await device_spawner.spawn_camera("itest-inv-many-zero-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import connect, disconnect, invoke_many await asyncio.to_thread(connect, nats_url=messaging_url) try: result = await asyncio.to_thread( - invoke_device, - device_id="itest-tools-invoke-cam", - function="capture_image", - params={"resolution": "720p"}, - llm_reasoning="Testing camera capture via tools", + invoke_many, + "device(itest-no-such-device).function(capture_image)", ) - assert isinstance(result, dict) + assert result["candidates"] == 0 + assert result["matched"] == 0 + assert result["succeeded"] == 0 + assert result["failed"] == 0 + assert result["results"] == [] + assert result["errors"] == [] + assert "error" not in result + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_invoke_many_function_only_selector(device_spawner, messaging_url): + """function() selects the function across the whole fleet.""" + await device_spawner.spawn_sensor( + "itest-inv-many-fo-sensor", initial_temp=20.0, + ) + await device_spawner.spawn_camera("itest-inv-many-fo-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, invoke_many + + await _wait_for_devices( + messaging_url, {"itest-inv-many-fo-cam", "itest-inv-many-fo-sensor"} + ) + try: + result = await asyncio.to_thread(invoke_many, "function(get_reading)") + ids = {row["device_id"] for row in result["results"]} + assert "itest-inv-many-fo-sensor" in ids + # Camera does not have get_reading; should not be in results. + assert "itest-inv-many-fo-cam" not in ids finally: await asyncio.to_thread(disconnect) From b64edac4d6426d6443ee7f071bcbff63e60b9429 Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Sun, 10 May 2026 11:15:16 -0700 Subject: [PATCH 08/21] feat(predicate): add CEL where evaluator with optional [predicate] extra Add device_connect_edge.predicate, a thin wrapper around cel-python that compiles where expressions into reusable WherePredicate objects and evaluates them against device-local context (identity, labels, status, shared bindings). CEL was chosen over JSONLogic because the v4 design's mask-indexing pattern (mask[seat_row][seat_col] == 1) needs computed array indices, which JSONLogic's literal-path var operator cannot express without flattening the mask to 1D and indexing arithmetically. CEL handles it natively. cel-python is an optional dependency. Importing the module without it installed succeeds; compiling or evaluating a predicate raises a clear PredicateCompileError pointing at the [predicate] extra: pip install device-connect-edge[predicate] pip install device-connect-agent-tools[predicate] The evaluator is shared by the dispatcher (validates expressions before sending them out) and the device runtime (evaluates per-call to decide whether to execute a fan-out). 16 unit tests cover compilation, evaluation, the mask-indexing regression case, missing-variable and type-mismatch error surfaces, and evaluator reusability. --- .../device-connect-agent-tools/pyproject.toml | 1 + .../device_connect_edge/predicate.py | 163 ++++++++++++++++++ packages/device-connect-edge/pyproject.toml | 3 + .../tests/test_predicate.py | 131 ++++++++++++++ 4 files changed, 298 insertions(+) create mode 100644 packages/device-connect-edge/device_connect_edge/predicate.py create mode 100644 packages/device-connect-edge/tests/test_predicate.py diff --git a/packages/device-connect-agent-tools/pyproject.toml b/packages/device-connect-agent-tools/pyproject.toml index ec0f198..606073c 100644 --- a/packages/device-connect-agent-tools/pyproject.toml +++ b/packages/device-connect-agent-tools/pyproject.toml @@ -37,6 +37,7 @@ strands = ["strands-agents>=1.0"] langchain = ["langchain-core>=0.2"] claude = ["claude-agent-sdk>=0.1"] mcp = ["fastmcp>=1.0"] +predicate = ["device-connect-edge[predicate]"] dev = [ "pytest>=8.0", "pytest-asyncio>=0.23", diff --git a/packages/device-connect-edge/device_connect_edge/predicate.py b/packages/device-connect-edge/device_connect_edge/predicate.py new file mode 100644 index 0000000..6ddc7c0 --- /dev/null +++ b/packages/device-connect-edge/device_connect_edge/predicate.py @@ -0,0 +1,163 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""CEL ``where`` predicate evaluator for self-election at the edge. + +A ``where`` predicate is a CEL (Common Expression Language) expression that +each candidate device evaluates against its own context to decide whether +to execute a fan-out call. The predicate sees four top-level variables: + + identity device-local identity dict (device_id, device_type, ...) + labels device labels (the same labels selectors filter on) + status device status (heartbeat-updated: location, availability, + battery, online, ...) + bindings shared payload supplied by the caller (selection masks, + thresholds, lookup tables) + +Examples (every example here ships with v4 spec):: + + battery > 50 + labels.category == "camera" && status.battery > 50 + mask[seat_row][seat_col] == 1 + bindings.threshold < status.temperature + +CEL is sandboxed by construction: no I/O, no filesystem, no exec. This +module wraps `cel-python` with lazy import so device-connect-edge does +not require it as a hard dependency. Install with the optional +``[predicate]`` extra:: + + pip install device-connect-edge[predicate] + +The evaluator is shared by the dispatcher (validates the expression +before broadcast) and the device runtime (evaluates per-call to decide +whether to execute the fan-out). +""" + +from __future__ import annotations + +from typing import Any, Mapping + + +class PredicateCompileError(ValueError): + """Raised when a ``where`` expression fails to compile. + + Carries the original cel-python error chained so callers can drill in + if they need the exact parse position. + """ + + +class PredicateEvalError(RuntimeError): + """Raised when an otherwise-valid predicate fails at evaluation time. + + Typical causes: missing context key, type mismatch (e.g. comparing a + string to an int), or arithmetic overflow. + """ + + +# Lazy import: ``cel-python`` is an optional extra. Importers of this module +# pay no cost unless they actually compile a predicate. +def _require_celpy(): + try: + import celpy # type: ignore[import-not-found] + return celpy + except ImportError as e: + raise PredicateCompileError( + "where predicates require the 'cel-python' package; " + "install with the [predicate] extra: " + "pip install 'device-connect-edge[predicate]'" + ) from e + + +def _to_cel(value: Any) -> Any: + """Recursively wrap a Python value as the matching CEL type. + + Native Python ints, strings, dicts, and lists arrive at the boundary + untyped; cel-python's evaluator expects its own typed wrappers + (``IntType``, ``MapType``, ``ListType``, ...). We wrap once at the + top of evaluation rather than asking callers to import celtypes. + """ + celpy = _require_celpy() + ct = celpy.celtypes + if value is None: + return None + if isinstance(value, bool): + return ct.BoolType(value) + if isinstance(value, int): + return ct.IntType(value) + if isinstance(value, float): + return ct.DoubleType(value) + if isinstance(value, str): + return ct.StringType(value) + if isinstance(value, (bytes, bytearray)): + return ct.BytesType(bytes(value)) + if isinstance(value, Mapping): + return ct.MapType({ + ct.StringType(str(k)): _to_cel(v) for k, v in value.items() + }) + if isinstance(value, (list, tuple)): + return ct.ListType([_to_cel(v) for v in value]) + # Fallback: stringify. Rare; happens for custom objects in the context. + return ct.StringType(str(value)) + + +class WherePredicate: + """A compiled ``where`` predicate, ready to evaluate against device context. + + Compile once (typically at the dispatcher when the call comes in or at + the edge when the broadcast envelope is received), then evaluate once + per candidate. Predicates are stateless and safe to reuse across calls. + """ + + __slots__ = ("expression", "_program") + + def __init__(self, expression: str, _program: Any): + self.expression = expression + self._program = _program + + def evaluate(self, context: Mapping[str, Any]) -> bool: + """Return ``True`` if the predicate holds for ``context``. + + ``context`` should be a flat mapping of variable name to Python + value. Common keys: ``identity``, ``labels``, ``status``, + ``bindings``. Missing keys are not auto-defaulted; if the + predicate references one, the call raises PredicateEvalError so + the caller can decide between fail-open and fail-closed. + """ + celpy = _require_celpy() + cel_context = {k: _to_cel(v) for k, v in context.items()} + try: + result = self._program.evaluate(cel_context) + except celpy.CELEvalError as e: + raise PredicateEvalError( + f"failed to evaluate where {self.expression!r}: {e}" + ) from e + return bool(result) + + +def compile_where(expression: str) -> WherePredicate: + """Compile a ``where`` expression into a reusable :class:`WherePredicate`. + + Raises :class:`PredicateCompileError` if cel-python is not installed + or the expression is malformed. + """ + celpy = _require_celpy() + if not isinstance(expression, str): + raise PredicateCompileError( + f"where expression must be a string, got {type(expression).__name__}" + ) + if not expression.strip(): + raise PredicateCompileError("where expression must be non-empty") + env = celpy.Environment() + try: + ast = env.compile(expression) + except Exception as e: + # cel-python surfaces parse errors via several exception classes + # depending on the failure mode (lark.UnexpectedToken, ValueError, + # CELParseError). Catch broadly and rewrap so callers only see + # PredicateCompileError. + raise PredicateCompileError( + f"failed to compile where {expression!r}: {e}" + ) from e + program = env.program(ast) + return WherePredicate(expression=expression, _program=program) diff --git a/packages/device-connect-edge/pyproject.toml b/packages/device-connect-edge/pyproject.toml index 27b5e88..58de4d1 100644 --- a/packages/device-connect-edge/pyproject.toml +++ b/packages/device-connect-edge/pyproject.toml @@ -38,6 +38,9 @@ dependencies = [ [project.optional-dependencies] zenoh = [] # Zenoh is now a core dependency; kept for backward compat +predicate = [ + "cel-python>=0.5.0", +] telemetry = [ "opentelemetry-api>=1.30.0", "opentelemetry-sdk>=1.30.0", diff --git a/packages/device-connect-edge/tests/test_predicate.py b/packages/device-connect-edge/tests/test_predicate.py new file mode 100644 index 0000000..dfaff81 --- /dev/null +++ b/packages/device-connect-edge/tests/test_predicate.py @@ -0,0 +1,131 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the CEL ``where`` predicate evaluator. + +These tests require the ``[predicate]`` extra (cel-python). They are +skipped automatically when cel-python is not installed so the rest of +the edge test suite stays runnable on minimal installs. +""" +from __future__ import annotations + +import pytest + +celpy = pytest.importorskip("celpy") + +from device_connect_edge.predicate import ( + PredicateCompileError, + PredicateEvalError, + WherePredicate, + compile_where, +) + + +# -- compile_where -------------------------------------------------- + + +class TestCompile: + def test_simple_comparison_compiles(self): + p = compile_where("battery > 50") + assert isinstance(p, WherePredicate) + assert p.expression == "battery > 50" + + def test_boolean_combination_compiles(self): + p = compile_where("a > 1 && b < 10 || c == 'x'") + assert isinstance(p, WherePredicate) + + def test_array_indexing_compiles(self): + p = compile_where("mask[row][col] == 1") + assert isinstance(p, WherePredicate) + + def test_label_dot_access_compiles(self): + p = compile_where("labels.category == 'camera'") + assert isinstance(p, WherePredicate) + + def test_empty_expression_rejected(self): + with pytest.raises(PredicateCompileError): + compile_where("") + with pytest.raises(PredicateCompileError): + compile_where(" ") + + def test_non_string_rejected(self): + with pytest.raises(PredicateCompileError): + compile_where(123) # type: ignore[arg-type] + + def test_malformed_expression_rejected(self): + with pytest.raises(PredicateCompileError) as exc: + compile_where("a > > b") + assert "failed to compile" in str(exc.value) + + +# -- evaluate ------------------------------------------------------- + + +class TestEvaluate: + def test_truthy_comparison(self): + p = compile_where("battery > 50") + assert p.evaluate({"battery": 80}) is True + assert p.evaluate({"battery": 30}) is False + + def test_label_match(self): + p = compile_where("labels.category == 'camera'") + assert p.evaluate({"labels": {"category": "camera"}}) is True + assert p.evaluate({"labels": {"category": "robot"}}) is False + + def test_2d_mask_indexing(self): + # The mask-indexing case is the deciding example for picking CEL + # over JSONLogic; keep it as a regression guard. + p = compile_where("mask[row][col] == 1") + ctx = { + "mask": [[0, 1, 0], [1, 0, 0]], + "row": 0, + "col": 1, + } + assert p.evaluate(ctx) is True + ctx["col"] = 0 + assert p.evaluate(ctx) is False + + def test_combined_label_and_status(self): + p = compile_where("labels.category == 'camera' && status.battery > 50") + ctx = { + "labels": {"category": "camera"}, + "status": {"battery": 80}, + } + assert p.evaluate(ctx) is True + ctx["status"]["battery"] = 30 + assert p.evaluate(ctx) is False + ctx["labels"]["category"] = "robot" + ctx["status"]["battery"] = 80 + assert p.evaluate(ctx) is False + + def test_bindings_and_status_compose(self): + p = compile_where("status.temperature > bindings.threshold") + ctx = { + "status": {"temperature": 75.5}, + "bindings": {"threshold": 70.0}, + } + assert p.evaluate(ctx) is True + + def test_string_in_list(self): + p = compile_where("labels.category in ['camera', 'inference']") + assert p.evaluate({"labels": {"category": "camera"}}) is True + assert p.evaluate({"labels": {"category": "robot"}}) is False + + def test_missing_variable_raises_eval_error(self): + p = compile_where("status.battery > 50") + with pytest.raises(PredicateEvalError): + p.evaluate({}) + + def test_type_mismatch_raises_eval_error(self): + p = compile_where("battery > 50") + with pytest.raises(PredicateEvalError): + p.evaluate({"battery": "not a number"}) + + def test_evaluator_is_reusable(self): + # Compile once, evaluate against many contexts. Reusability is the + # property that lets callers compile broadcast envelopes once at + # the dispatcher and ship them to N targets. + p = compile_where("battery > 50") + results = [p.evaluate({"battery": v}) for v in (10, 50, 51, 100)] + assert results == [False, False, True, True] From 4e2208a3712bb2de58de8583a38223c9c28d032c Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Sun, 10 May 2026 12:05:01 -0700 Subject: [PATCH 09/21] feat(broadcast): async fan-out with correlation, fire_at, and subscribe Add the async selector-driven fan-out path so callers do not have to block on the slowest device: - broadcast(selector, params, where=, bindings=, fire_at=, on_late=) publishes a single envelope to a fanout subject keyed by tenant. Returns immediately with a correlation_id and the candidate count. Compile-validates the optional CEL where predicate at the dispatcher so syntax errors short-circuit before reaching the wire. - DeviceRuntime._broadcast_subscription receives envelopes on ``device-connect..broadcast``. Each candidate self-elects via the target_device_ids gate (pre-resolved by the dispatcher from the selector), then evaluates the optional where predicate against its own context (identity, labels, status, shared bindings). On match the device executes the function and emits a reply on ``device-connect...event.async_reply.`` carrying {success, result|error, actually_fired_at}. - fire_at + on_late synchronized fan-out: the edge holds the message until the wall-clock deadline and fires from its own clock. on_late=skip drops late arrivals (preserves coherence for card-stunt / light-show style workloads); on_late=fire executes immediately. The achieved spread depends on NTP residual (~5-10 ms typical) rather than network jitter (~50-150 ms). - subscribe(selector) returns a Subscription handle. Two selector forms: ``correlation:`` for broadcast replies, and event-scoped selectors (``event()`` or ``device(...).event()``) for live event streams. The handle exposes sync read() and a yielding iter() with idle-timeout reset. - await_replies(correlation_id, timeout, until) sync helper for the common broadcast-then-collect pattern; subscribes, drains, returns the list of reply payloads. The edge predicate context mirrors DeviceStatus.location into labels["location"] when the driver did not declare a labels.location itself, matching the dispatcher-side flatten_device contract so the same selector and predicate strings work on both sides. Test coverage: 38 unit tests across broadcast (12), subscribe (12), and existing modules; 5 NATS integration tests cover end-to-end broadcast + reply, where filter at the edge, fire_at synchronization spread, on_late=skip late-arrival drop, and subscribe(correlation:) streaming. --- .../device_connect_agent_tools/__init__.py | 10 + .../device_connect_agent_tools/connection.py | 22 +- .../device_connect_agent_tools/tools.py | 405 ++++++++++++++++++ .../tests/test_broadcast.py | 201 +++++++++ .../tests/test_subscribe.py | 203 +++++++++ .../device_connect_edge/device.py | 152 +++++++ tests/tests/test_tools_broadcast.py | 213 +++++++++ 7 files changed, 1205 insertions(+), 1 deletion(-) create mode 100644 packages/device-connect-agent-tools/tests/test_broadcast.py create mode 100644 packages/device-connect-agent-tools/tests/test_subscribe.py create mode 100644 tests/tests/test_tools_broadcast.py diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py b/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py index c809baa..1a7c1e0 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py @@ -30,6 +30,11 @@ # Selector-driven invocation (preferred) invoke, invoke_many, + broadcast, + # Selector-driven subscription + Subscription, + subscribe, + await_replies, # Other invocation helpers invoke_device_with_fallback, get_device_status, @@ -54,6 +59,11 @@ # Selector-driven invocation (preferred) "invoke", "invoke_many", + "broadcast", + # Selector-driven subscription + "Subscription", + "subscribe", + "await_replies", # Other invocation helpers "invoke_device_with_fallback", "get_device_status", diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py b/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py index dae997c..b399f70 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py @@ -409,13 +409,33 @@ async def _async_invoke( # ── Broadcast ──────────────────────────────────────────────────── + def publish_broadcast(self, envelope: Dict[str, Any]) -> None: + """Publish a selector-driven broadcast envelope to the fanout subject. + + The envelope shape is documented in + ``device_connect_edge.device.DeviceRuntime._broadcast_subscription``; + every device subscribed to ``device-connect..broadcast`` + receives the message and self-elects via ``target_device_ids`` and + the optional ``where`` predicate. + """ + return self._run(self._async_publish_broadcast(envelope)) + + async def _async_publish_broadcast(self, envelope: Dict[str, Any]) -> None: + subject = f"device-connect.{self.zone}.broadcast" + await self._client.publish(subject, json.dumps(envelope).encode()) + def broadcast( self, function: str, params: Optional[Dict[str, Any]] = None, timeout: float = 5.0, ) -> List[Dict[str, Any]]: - """Invoke a function on all discovered devices and collect results.""" + """Invoke a function on all discovered devices and collect results. + + Sequential sync fan-out (one invoke per device). Predates the + selector-driven broadcast tool; left in place for callers that want + a simple "call this on everyone" without setting up subscriptions. + """ devices = self.list_devices() results = [] for d in devices: diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py b/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py index 528e554..c81faf2 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py @@ -20,6 +20,7 @@ import logging import os +import time import uuid import warnings from typing import Any @@ -726,6 +727,410 @@ def call_one(row: dict) -> dict[str, Any]: return out +def broadcast( + selector: str, + params: dict[str, Any] | None = None, + where: str | None = None, + bindings: dict[str, Any] | None = None, + fire_at: float | None = None, + on_late: str = "skip", + llm_reasoning: str | None = None, +) -> dict[str, Any]: + """Async selector-driven fan-out. Returns immediately with a correlation id. + + Use ``broadcast`` when the caller does not want to block on the slowest + device. Each candidate self-elects via the optional ``where`` predicate + (CEL, evaluated at the edge against the device's identity, labels, live + status, and the shared ``bindings``) and emits its reply as an event on + a per-device subject keyed by ``correlation_id``:: + + device-connect...event.async_reply. + + Subscribe to those replies via ``subscribe('correlation:')`` or wait + for them with ``await_replies(correlation_id, timeout=...)``. + + Args: + selector: Function-scoped selector. The selector must resolve to a + single function name across the matched devices; if multiple + functions match, an ``ambiguous_function`` error is returned. + params: Function parameters dict applied to every target. + where: Optional CEL predicate evaluated at the edge per candidate + (e.g. ``"status.battery > 50"``, ``"mask[row][col] == 1"``). + Validated at the dispatcher before publication so syntax + errors return immediately rather than reaching the wire. + bindings: Shared payload merged into the predicate context as + ``bindings.``. Keep small (selection masks, thresholds, + top-K rankings); the same bytes ship to every device. + fire_at: Optional wall-clock epoch seconds. Each device holds the + message and fires its function from its own clock at + ``fire_at`` for synchronized fan-out. + on_late: Policy when a device receives a ``fire_at`` message after + the deadline. ``"skip"`` (default) drops the call; ``"fire"`` + executes immediately. + llm_reasoning: Decision rationale for observability. + + Returns: + On success: ``{"correlation_id": "br-...", "candidates": N, + "selector": ..., "function": ...}``. + On failure: ``{"candidates": 0, "error": {"code", "message"}}`` + with codes including the discover() codes, + ``invalid_invoke_scope``, ``ambiguous_function``, + ``invalid_predicate``, and ``invalid_on_late``. + """ + if on_late not in ("skip", "fire"): + return { + "candidates": 0, + "error": _error( + "invalid_on_late", + f"on_late must be 'skip' or 'fire', got {on_late!r}", + ), + } + + rows, error_envelope = _resolve_function_tuples(selector) + if error_envelope is not None: + return {"candidates": 0, "error": error_envelope["error"]} + + if not rows: + # Empty fan-out: still mint a correlation id so callers waiting on + # replies see a clean "no candidates" rather than a hang. + return { + "correlation_id": f"br-{uuid.uuid4().hex[:12]}", + "candidates": 0, + "selector": selector, + } + + # Broadcast assumes one function per call. If the selector resolves to + # multiple distinct functions, surface that as a structured error so + # the caller can either narrow the selector or split into multiple + # broadcasts. + function_names = {row.get("name") for row in rows if row.get("name")} + if len(function_names) != 1: + return { + "candidates": len(rows), + "error": _error( + "ambiguous_function", + f"selector resolved to {len(function_names)} distinct " + "functions; broadcast requires exactly one function per call: " + f"{sorted(function_names)!r}", + ), + } + function_name = next(iter(function_names)) + + # Compile-validate the where predicate before going to the wire so a + # syntax error short-circuits without bothering devices. + if where is not None: + try: + from device_connect_edge.predicate import compile_where + compile_where(where) + except Exception as e: + return { + "candidates": len(rows), + "error": _error("invalid_predicate", str(e)), + } + + correlation_id = f"br-{uuid.uuid4().hex[:12]}" + target_device_ids = sorted({ + row.get("device_id") for row in rows if row.get("device_id") + }) + clean_params = { + k: v for k, v in (params or {}).items() if k != "llm_reasoning" + } + + envelope: dict[str, Any] = { + "correlation_id": correlation_id, + "function": function_name, + "params": clean_params, + "target_device_ids": target_device_ids, + } + if where: + envelope["where"] = where + if bindings: + envelope["bindings"] = bindings + if fire_at is not None: + envelope["fire_at"] = float(fire_at) + envelope["on_late"] = on_late + + if llm_reasoning: + truncated = ( + llm_reasoning[:200] + "..." + if len(llm_reasoning) > 200 else llm_reasoning + ) + logger.info( + "[broadcast::%s::%d targets] Reason: %s", + correlation_id, len(target_device_ids), truncated, + ) + + try: + conn = get_connection() + conn.publish_broadcast(envelope) + except Exception as e: + logger.error("broadcast publish failed: %s", e) + return { + "candidates": len(target_device_ids), + "error": _error("connection_error", str(e)), + } + + return { + "correlation_id": correlation_id, + "candidates": len(target_device_ids), + "selector": selector, + "function": function_name, + } + + +# ── Selector-driven subscription ───────────────────────────────── + + +# Sentinel used to recognise the broadcast-reply form of a subscribe +# selector (``correlation:``). Kept short so the selector reads +# naturally; the parser matches an exact prefix. +_CORRELATION_PREFIX = "correlation:" + + +class Subscription: + """A live subscription handle returned by :func:`subscribe`. + + Two selector forms produce a subscription: + + * ``"correlation:"`` -- replies from a prior :func:`broadcast` call, + keyed by ``correlation_id`` and routed across all devices that fired. + * Event-scoped selectors (``event()`` or + ``device(...).event()``) -- a multiplex of matching events + across the resolved candidate set. + + The handle exposes a sync ``read`` API that drains buffered messages. + Use as a context manager (or call :meth:`close`) to tear the + underlying messaging subscription down deterministically:: + + with subscribe("correlation:" + cid) as sub: + for reply in sub.iter(timeout=5.0): + process(reply) + """ + + def __init__(self, conn: Any, inbox_names: list[str]): + self._conn = conn + self._inbox_names = list(inbox_names) + self._closed = False + self._cursor = 0 # index into the concatenated message stream + + def read(self, max_messages: int | None = None) -> list[dict[str, Any]]: + """Drain currently buffered messages without blocking. + + Returns parsed payload dicts (already JSON-decoded by the + connection's buffered subscription path). Subsequent calls return + only messages that arrived after the previous call. + """ + if self._closed: + return [] + out: list[dict[str, Any]] = [] + for name in self._inbox_names: + inboxes = self._conn.get_inbox(name) + buffered = inboxes.get(name, []) or [] + # Each buffered entry is (subject, payload). We expose the + # parsed payload but stamp the subject onto it so callers can + # distinguish per-source messages without parsing it themselves. + for subject, payload in buffered: + if not isinstance(payload, dict): + payload = {"raw": payload} + payload = {**payload, "_subject": subject} + out.append(payload) + # Fast cursor: trim per-inbox buffers we have already returned by + # truncating from the front. The connection layer already caps each + # inbox at 1000 entries, so bounded growth is its concern. + for name in self._inbox_names: + self._conn._inbox[name] = [] + if max_messages is not None: + out = out[:max_messages] + return out + + def iter(self, timeout: float = 5.0, poll_interval: float = 0.05): + """Yield messages until ``timeout`` elapses with no new arrivals. + + ``timeout`` resets each time at least one message is yielded, so + callers can drain a steady stream without re-parameterising the + wait. Use ``read`` instead for one-shot draining. + """ + deadline = time.monotonic() + timeout + while not self._closed: + new = self.read() + if new: + for msg in new: + yield msg + deadline = time.monotonic() + timeout + continue + if time.monotonic() >= deadline: + return + time.sleep(poll_interval) + + def close(self) -> None: + """Tear down the underlying messaging subscriptions.""" + if self._closed: + return + self._closed = True + for name in self._inbox_names: + try: + self._conn.unsubscribe_buffered(name) + except Exception: # pragma: no cover - cleanup best effort + logger.debug("close: unsubscribe %s failed", name, exc_info=True) + + def __enter__(self) -> "Subscription": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + self.close() + + +def _correlation_subjects(conn: Any, correlation_id: str) -> list[str]: + """Build the per-device wildcard reply subjects for a correlation id. + + The reply template is ``device-connect...event + .async_reply.``; ```` is single-token wildcarded + so a subscription receives replies from any device that fires the + broadcast without having to enumerate them up-front. + """ + return [ + f"device-connect.{conn.zone}.*.event.async_reply.{correlation_id}", + ] + + +def _event_subjects_for_selector(selector: str) -> tuple[list[str] | None, dict[str, Any] | None]: + """Resolve an event-scoped selector to per-device subjects. + + Returns ``(subjects, None)`` on success or ``(None, error_envelope)`` + if the selector failed to parse or used a non-event scope. + """ + rows: list[dict] = [] + offset = 0 + while True: + page = discover(selector, offset=offset, limit=DISCOVER_HARD_LIMIT) + if "error" in page: + return None, page + if page["scope"] not in (Scope.DEVICE_EVENT.value, Scope.EVENT_ONLY.value): + return None, _empty_envelope( + scope=page["scope"], + error=_error( + "invalid_subscribe_scope", + "subscribe requires an event-scoped selector " + "(device(...).event(...) or event(...)) or " + "'correlation:'; got " + f"scope={page['scope']!r}", + ), + ) + rows.extend(page["results"]) + if page["next_offset"] is None: + break + offset = page["next_offset"] + + conn = get_connection() + subjects: list[str] = [] + seen: set[str] = set() + for row in rows: + device_id = row.get("device_id") or "" + event_name = row.get("name") or "" + if not device_id or not event_name: + continue + subj = f"device-connect.{conn.zone}.{device_id}.event.{event_name}" + if subj not in seen: + seen.add(subj) + subjects.append(subj) + return subjects, None + + +def subscribe(selector: str) -> Subscription: + """Subscribe to events or broadcast replies matching a selector. + + Args: + selector: One of: + - ``"correlation:"`` for broadcast replies of a prior call. + - An event-scoped selector (``event()`` or + ``device(...).event()``) for live event streams. + + Returns: + A :class:`Subscription` handle. Iterate with ``sub.iter(timeout)`` + or drain currently-buffered messages with ``sub.read()``. Always + close (or use ``with``) to tear the underlying subscription down. + + Raises: + ValueError on selector errors. The selector string is checked at + the boundary; downstream subscribe calls are not retried, so a + parse error fails fast. + """ + if not isinstance(selector, str) or not selector.strip(): + raise ValueError("subscribe selector must be a non-empty string") + + conn = get_connection() + if selector.startswith(_CORRELATION_PREFIX): + correlation_id = selector[len(_CORRELATION_PREFIX):].strip() + if not correlation_id: + raise ValueError( + "correlation form must be 'correlation:' with non-empty id" + ) + subjects = _correlation_subjects(conn, correlation_id) + inbox_prefix = f"sub-corr-{correlation_id}-{uuid.uuid4().hex[:8]}" + else: + subjects, error_envelope = _event_subjects_for_selector(selector) + if error_envelope is not None: + err = error_envelope.get("error") + msg = err.get("message", str(err)) if isinstance(err, dict) else str(err) + raise ValueError(msg) + if not subjects: + # Nothing to subscribe to. Return an idle Subscription so the + # caller's ``with subscribe(...) as sub: ...`` pattern still + # works without raising; ``read``/``iter`` will yield nothing. + return Subscription(conn, inbox_names=[]) + inbox_prefix = f"sub-evt-{uuid.uuid4().hex[:8]}" + + inbox_names: list[str] = [] + for i, subj in enumerate(subjects): + name = f"{inbox_prefix}-{i}" + conn.subscribe_buffered(subj, name=name) + inbox_names.append(name) + return Subscription(conn, inbox_names=inbox_names) + + +def await_replies( + correlation_id: str, + timeout: float = 10.0, + until: int | None = None, + poll_interval: float = 0.05, +) -> list[dict[str, Any]]: + """Block until ``timeout`` elapses or ``until`` replies have arrived. + + A sync helper for the common broadcast pattern: caller fires a + :func:`broadcast`, then waits for some replies. Builds a one-shot + subscription on the correlation reply subject, drains it, and tears + down before returning. + + Args: + correlation_id: The id returned by :func:`broadcast`. + timeout: Overall wall-clock limit in seconds. + until: Stop early once this many replies have been collected. + poll_interval: How often the helper polls the subscription buffer. + + Returns: + A list of reply payload dicts, each with at least + ``{correlation_id, device_id, success, result|error, + actually_fired_at}``. + """ + if not correlation_id: + return [] + sub = subscribe(f"{_CORRELATION_PREFIX}{correlation_id}") + try: + replies: list[dict[str, Any]] = [] + deadline = time.monotonic() + timeout + while True: + new = sub.read() + replies.extend(new) + if until is not None and len(replies) >= until: + break + if time.monotonic() >= deadline: + break + time.sleep(poll_interval) + return replies + finally: + sub.close() + + # ── Hierarchical discovery tools ───────────────────────────────── diff --git a/packages/device-connect-agent-tools/tests/test_broadcast.py b/packages/device-connect-agent-tools/tests/test_broadcast.py new file mode 100644 index 0000000..e25a35e --- /dev/null +++ b/packages/device-connect-agent-tools/tests/test_broadcast.py @@ -0,0 +1,201 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for the selector-driven ``broadcast`` tool. + +Uses the same labeled mock fleet (cam-001, cam-002, robot-001, sensor-001) +as the discover/invoke tests so selectors exercise real device, function, +and event names. +""" +import json +from unittest.mock import MagicMock, patch + +import pytest + +from device_connect_agent_tools import tools as tools_mod + + +SAMPLE_DEVICES = [ + { + "device_id": "cam-001", + "device_type": "camera", + "location": "lab-A", + "status": {"state": "online"}, + "identity": {"device_type": "camera"}, + "labels": {"category": "camera", "location": "lab-A"}, + "functions": [ + { + "name": "capture_image", + "parameters": {}, + "labels": {"direction": "write", "modality": "rgb"}, + }, + ], + "events": [], + }, + { + "device_id": "cam-002", + "device_type": "camera", + "location": "lab-A", + "status": {"state": "online"}, + "identity": {"device_type": "camera"}, + "labels": {"category": "camera", "location": "lab-A"}, + "functions": [ + { + "name": "capture_image", + "parameters": {}, + "labels": {"direction": "write", "modality": "rgb"}, + }, + ], + "events": [], + }, + { + "device_id": "sensor-001", + "device_type": "temperature_sensor", + "location": "lab-B", + "status": {"state": "online"}, + "identity": {"device_type": "temperature_sensor"}, + "labels": {"category": "sensor"}, + "functions": [ + { + "name": "get_reading", + "parameters": {}, + "labels": {"direction": "read"}, + }, + ], + "events": [], + }, +] + + +@pytest.fixture +def mock_conn(): + conn = MagicMock() + conn.list_devices.return_value = SAMPLE_DEVICES + conn.zone = "default" + # Capture the published envelope for assertions. + published: list[dict] = [] + conn.publish_broadcast.side_effect = lambda env: published.append(env) + conn._published = published + with patch.object(tools_mod, "get_connection", return_value=conn): + yield conn + + +# -- broadcast ------------------------------------------------------ + + +class TestBroadcast: + def test_returns_correlation_id_and_candidates(self, mock_conn): + r = tools_mod.broadcast("device(*).function(capture_image)") + assert r["correlation_id"].startswith("br-") + assert r["candidates"] == 2 + assert r["function"] == "capture_image" + assert "error" not in r + + def test_envelope_carries_function_and_targets(self, mock_conn): + tools_mod.broadcast( + "device(*).function(capture_image)", + params={"resolution": "4k"}, + ) + env = mock_conn._published[0] + assert env["function"] == "capture_image" + assert env["params"] == {"resolution": "4k"} + assert sorted(env["target_device_ids"]) == ["cam-001", "cam-002"] + # No optional fields when caller did not set them. + assert "where" not in env + assert "bindings" not in env + assert "fire_at" not in env + assert "on_late" not in env + + def test_where_and_bindings_propagate_to_envelope(self, mock_conn): + tools_mod.broadcast( + "device(*).function(capture_image)", + where="status.battery > 50", + bindings={"threshold": 80}, + ) + env = mock_conn._published[0] + assert env["where"] == "status.battery > 50" + assert env["bindings"] == {"threshold": 80} + + def test_fire_at_propagates_with_default_on_late(self, mock_conn): + tools_mod.broadcast( + "device(*).function(capture_image)", + fire_at=123456789.0, + ) + env = mock_conn._published[0] + assert env["fire_at"] == 123456789.0 + assert env["on_late"] == "skip" + + def test_fire_at_with_explicit_on_late_fire(self, mock_conn): + tools_mod.broadcast( + "device(*).function(capture_image)", + fire_at=123.0, on_late="fire", + ) + env = mock_conn._published[0] + assert env["on_late"] == "fire" + + def test_invalid_on_late_rejected(self, mock_conn): + r = tools_mod.broadcast( + "device(*).function(capture_image)", on_late="bogus", + ) + assert r["candidates"] == 0 + assert r["error"]["code"] == "invalid_on_late" + assert mock_conn.publish_broadcast.call_count == 0 + + def test_ambiguous_function_rejected(self, mock_conn): + # function(direction:read) resolves to multiple distinct functions + # (get_reading + dispatch_robot's get_status if it had read; here + # it just hits sensor's get_reading and possibly more). With our + # SAMPLE_DEVICES this matches just get_reading, so artificially + # broaden by picking a selector that crosses functions: + r = tools_mod.broadcast("device(*).function(*)") + assert r["candidates"] == 3 + assert r["error"]["code"] == "ambiguous_function" + + def test_zero_matches_returns_correlation_with_zero(self, mock_conn): + r = tools_mod.broadcast("device(*).function(does_not_exist)") + assert r["candidates"] == 0 + assert r["correlation_id"].startswith("br-") + # No envelope was published (no targets). + assert mock_conn.publish_broadcast.call_count == 0 + + def test_invalid_scope_rejected(self, mock_conn): + r = tools_mod.broadcast("device(cam-001)") + assert r["candidates"] == 0 + assert r["error"]["code"] == "invalid_invoke_scope" + + def test_selector_parse_error_propagated(self, mock_conn): + r = tools_mod.broadcast("widgets(*)") + assert r["candidates"] == 0 + assert r["error"]["code"] == "selector_parse_error" + + def test_invalid_predicate_rejected_before_publish(self, mock_conn): + # The predicate is compile-validated at the dispatcher; a syntax + # error short-circuits without publishing. + try: + import celpy # noqa: F401 + except ImportError: + pytest.skip("cel-python not installed") + r = tools_mod.broadcast( + "device(*).function(capture_image)", where="a > > b", + ) + assert r["error"]["code"] == "invalid_predicate" + assert mock_conn.publish_broadcast.call_count == 0 + + def test_publish_failure_returns_connection_error(self): + conn = MagicMock() + conn.list_devices.return_value = SAMPLE_DEVICES + conn.zone = "default" + conn.publish_broadcast.side_effect = RuntimeError("messaging down") + with patch.object(tools_mod, "get_connection", return_value=conn): + r = tools_mod.broadcast("device(*).function(capture_image)") + assert r["error"]["code"] == "connection_error" + assert "messaging down" in r["error"]["message"] + + def test_llm_reasoning_stripped_from_params(self, mock_conn): + tools_mod.broadcast( + "device(*).function(capture_image)", + params={"resolution": "4k", "llm_reasoning": "should not appear"}, + ) + env = mock_conn._published[0] + assert "llm_reasoning" not in env["params"] diff --git a/packages/device-connect-agent-tools/tests/test_subscribe.py b/packages/device-connect-agent-tools/tests/test_subscribe.py new file mode 100644 index 0000000..a6f032e --- /dev/null +++ b/packages/device-connect-agent-tools/tests/test_subscribe.py @@ -0,0 +1,203 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for the selector-driven subscribe + await_replies tools. + +The tests stand up a fake Connection that mirrors the buffered-inbox API +the production class exposes (``subscribe_buffered`` / +``unsubscribe_buffered`` / ``get_inbox`` / ``_inbox`` dict). Real +messaging is not exercised here; integration tests cover the wire. +""" +from unittest.mock import patch + +import pytest + +from device_connect_agent_tools import tools as tools_mod + + +SAMPLE_DEVICES = [ + { + "device_id": "cam-001", + "device_type": "camera", + "labels": {"category": "camera", "location": "lab-A"}, + "functions": [], + "events": [ + {"name": "object_detected", "labels": {"modality": "rgb"}}, + ], + }, + { + "device_id": "cam-002", + "device_type": "camera", + "labels": {"category": "camera", "location": "lab-A"}, + "functions": [], + "events": [ + {"name": "object_detected", "labels": {"modality": "rgb"}}, + ], + }, +] + + +class FakeConnection: + """Minimal fake of the agent-tools Connection used by Subscription.""" + + def __init__(self, devices=None, zone="default"): + self.zone = zone + self.devices = devices or [] + self._inbox: dict[str, list[tuple]] = {} + self.subscribed_subjects: list[str] = [] + self.unsubscribed_names: list[str] = [] + + def list_devices(self): + return list(self.devices) + + def subscribe_buffered(self, subject: str, name: str | None = None) -> str: + name = name or subject + self._inbox[name] = [] + self.subscribed_subjects.append(subject) + return name + + def unsubscribe_buffered(self, name: str) -> None: + self.unsubscribed_names.append(name) + self._inbox.pop(name, None) + + def get_inbox(self, name: str | None = None): + if name is not None: + return {name: list(self._inbox.get(name, []))} + return {k: list(v) for k, v in self._inbox.items()} + + # Test helper: simulate a message landing on a given subject. + def deliver(self, subject: str, payload: dict): + for name, _ in list(self._inbox.items()): + self._inbox[name].append((subject, payload)) + + +@pytest.fixture +def fake_conn(): + conn = FakeConnection(devices=SAMPLE_DEVICES) + with patch.object(tools_mod, "get_connection", return_value=conn): + yield conn + + +# -- subscribe ------------------------------------------------------ + + +class TestSubscribe: + def test_correlation_form_subscribes_to_reply_subject(self, fake_conn): + sub = tools_mod.subscribe("correlation:abc-123") + assert len(fake_conn.subscribed_subjects) == 1 + subj = fake_conn.subscribed_subjects[0] + assert subj == "device-connect.default.*.event.async_reply.abc-123" + sub.close() + assert fake_conn.unsubscribed_names + + def test_correlation_form_with_empty_id_rejected(self, fake_conn): + with pytest.raises(ValueError): + tools_mod.subscribe("correlation:") + + def test_event_selector_subscribes_per_device(self, fake_conn): + sub = tools_mod.subscribe("device(*).event(object_detected)") + # Two cameras emit object_detected -> two subjects subscribed. + assert len(fake_conn.subscribed_subjects) == 2 + for subj in fake_conn.subscribed_subjects: + assert subj.startswith("device-connect.default.") + assert subj.endswith(".event.object_detected") + sub.close() + + def test_event_selector_zero_matches_returns_idle(self, fake_conn): + sub = tools_mod.subscribe("event(no_such_event)") + assert fake_conn.subscribed_subjects == [] + # Idle subscription: read returns empty, close is a no-op. + assert sub.read() == [] + sub.close() + + def test_non_event_scope_rejected(self, fake_conn): + with pytest.raises(ValueError) as exc: + tools_mod.subscribe("device(cam-001)") + assert "subscribe requires" in str(exc.value) + + def test_empty_or_non_string_rejected(self, fake_conn): + with pytest.raises(ValueError): + tools_mod.subscribe("") + with pytest.raises(ValueError): + tools_mod.subscribe(None) # type: ignore[arg-type] + + +# -- Subscription --------------------------------------------------- + + +class TestSubscriptionHandle: + def test_read_drains_buffered_messages(self, fake_conn): + sub = tools_mod.subscribe("correlation:r1") + fake_conn.deliver( + "device-connect.default.cam-001.event.async_reply.r1", + {"correlation_id": "r1", "device_id": "cam-001", "success": True}, + ) + msgs = sub.read() + assert len(msgs) == 1 + assert msgs[0]["device_id"] == "cam-001" + # Subject is stamped onto the payload for source attribution. + assert "_subject" in msgs[0] + # A second read returns nothing -- the buffer is drained. + assert sub.read() == [] + sub.close() + + def test_context_manager_closes(self, fake_conn): + with tools_mod.subscribe("correlation:r2") as sub: + assert sub.read() == [] + assert fake_conn.unsubscribed_names # close() ran + + def test_iter_yields_until_idle_timeout(self, fake_conn): + sub = tools_mod.subscribe("correlation:r3") + fake_conn.deliver( + "device-connect.default.cam-001.event.async_reply.r3", + {"correlation_id": "r3", "device_id": "cam-001"}, + ) + # Short timeout; iter() should yield the buffered reply then exit + # once no new messages arrive within the idle window. + msgs = list(sub.iter(timeout=0.1, poll_interval=0.01)) + assert len(msgs) == 1 + sub.close() + + +# -- await_replies -------------------------------------------------- + + +class TestAwaitReplies: + def test_empty_correlation_id_returns_empty_list(self, fake_conn): + assert tools_mod.await_replies("") == [] + + def test_collects_replies_until_count(self, fake_conn): + # Pre-stage two replies on the to-be-subscribed subject. await_replies + # subscribes (drains nothing yet), then deliver more during the loop. + # We deliver up-front via the fake's deliver hook so the first poll + # picks them up. + def deliver_when_subscribed(subject, name=None): + n = FakeConnection.subscribe_buffered(fake_conn, subject, name) + # Pre-load a couple of replies so the first poll returns them. + fake_conn.deliver( + "device-connect.default.cam-001.event.async_reply.r4", + {"correlation_id": "r4", "device_id": "cam-001"}, + ) + fake_conn.deliver( + "device-connect.default.cam-002.event.async_reply.r4", + {"correlation_id": "r4", "device_id": "cam-002"}, + ) + return n + + with patch.object( + fake_conn, "subscribe_buffered", side_effect=deliver_when_subscribed, + ): + replies = tools_mod.await_replies( + "r4", timeout=2.0, until=2, poll_interval=0.01, + ) + assert len(replies) == 2 + ids = {r["device_id"] for r in replies} + assert ids == {"cam-001", "cam-002"} + + def test_returns_after_timeout_with_partial(self, fake_conn): + # No replies delivered -> after timeout, returns empty list. + replies = tools_mod.await_replies( + "r5", timeout=0.1, poll_interval=0.01, + ) + assert replies == [] diff --git a/packages/device-connect-edge/device_connect_edge/device.py b/packages/device-connect-edge/device_connect_edge/device.py index 40d5c63..b64d443 100644 --- a/packages/device-connect-edge/device_connect_edge/device.py +++ b/packages/device-connect-edge/device_connect_edge/device.py @@ -1135,6 +1135,151 @@ async def on_msg(data: bytes, reply_subject: Optional[str]): self._logger.info("Subscribed to commands on %s", subj) + async def _broadcast_subscription(self) -> None: + """Subscribe to selector-driven broadcasts and self-elect to handle. + + Broadcast envelope shape (JSON over a fanout subject):: + + { + "correlation_id": "br-abc123", + "function": "capture_image", + "params": {"resolution": "4k"}, + "target_device_ids": ["cam-001", "cam-002"], // pre-resolved + "where": "status.battery > 50", // optional CEL + "bindings": {"mask": [[0,1],[1,0]]}, // optional + "fire_at": 1234567890.5, // optional, epoch s + "on_late": "skip" // skip|fire + } + + On match, the device executes the function and emits a reply on + ``device-connect...event.async_reply.`` + with ``{correlation_id, device_id, success, result|error, + actually_fired_at}``. + """ + subj = f"device-connect.{self.tenant}.broadcast" + + async def on_msg(data: bytes, reply_subject: Optional[str]): + try: + envelope = json.loads(data) + except Exception as e: + self._logger.debug("Broadcast: malformed envelope: %s", e) + return + + correlation_id = envelope.get("correlation_id") + if not correlation_id: + return + + # Self-election step 1: target_device_ids gate (pre-resolved by + # the dispatcher from the selector). When absent or empty, treat + # the broadcast as fleet-wide. + targets = envelope.get("target_device_ids") or [] + if targets and self.device_id not in targets: + return + + function_name = envelope.get("function") + if not function_name: + return + params_dict = envelope.get("params", {}) or {} + + # Self-election step 2: where predicate against {identity, labels, + # status, bindings}. A failed compile or eval is treated as + # fail-closed (do not execute). + where_expr = envelope.get("where") + if where_expr: + try: + from device_connect_edge.predicate import compile_where + predicate = compile_where(where_expr) + caps = self._driver.capabilities if self._driver else self.capabilities + status = self._driver.status if self._driver else None + labels = (caps.labels if caps and caps.labels else {}) or {} + status_dict = ( + status.model_dump() if status and hasattr(status, "model_dump") else {} + ) + # Mirror the legacy DeviceStatus.location into labels so + # ``labels.location`` works in predicates without the driver + # having to declare it explicitly. Matches the dispatcher-side + # flatten_device contract. + if "location" not in labels and status_dict.get("location"): + labels = {**labels, "location": status_dict["location"]} + context = { + "identity": ( + caps.identity.model_dump() + if caps and getattr(caps, "identity", None) else {} + ), + "labels": labels, + "status": status_dict, + "bindings": envelope.get("bindings", {}) or {}, + } + if not predicate.evaluate(context): + return + except Exception as e: + self._logger.warning( + "Broadcast %s: where predicate failed (skipping): %s", + correlation_id, e, + ) + return + + # fire_at: hold the message until the wall-clock deadline. The + # on_late policy decides what to do if the message arrives past + # the deadline (skip preserves coherence; fire runs anyway). + fire_at = envelope.get("fire_at") + on_late = envelope.get("on_late", "skip") + if fire_at is not None: + delay = float(fire_at) - time.time() + if delay < 0 and on_late == "skip": + self._logger.info( + "Broadcast %s arrived %.3fs late, on_late=skip", + correlation_id, -delay, + ) + return + if delay > 0: + await asyncio.sleep(delay) + + # Execute the driver function and emit the reply. + actually_fired_at = time.time() + reply_subj = ( + f"device-connect.{self.tenant}.{self.device_id}" + f".event.async_reply.{correlation_id}" + ) + try: + if self._driver is None: + raise RuntimeError("no driver configured") + driver_functions = self._driver._get_functions() + if function_name not in driver_functions: + raise RuntimeError(f"unknown function: {function_name}") + result = await self._driver.invoke(function_name, **params_dict) + reply_payload = { + "correlation_id": correlation_id, + "device_id": self.device_id, + "success": True, + "result": result, + "actually_fired_at": actually_fired_at, + } + except Exception as e: + self._logger.warning( + "Broadcast %s: function %s failed: %s", + correlation_id, function_name, e, + ) + reply_payload = { + "correlation_id": correlation_id, + "device_id": self.device_id, + "success": False, + "error": {"code": "invoke_failed", "message": str(e)}, + "actually_fired_at": actually_fired_at, + } + try: + await self.messaging.publish( + reply_subj, json.dumps(reply_payload).encode(), + ) + except Exception as e: # pragma: no cover + self._logger.warning( + "Broadcast %s: reply publish failed: %s", correlation_id, e, + ) + + await self.messaging.subscribe(subj, callback=on_msg) + self._logger.info("Subscribed to broadcasts on %s", subj) + + async def _event_dispatch_loop(self) -> None: """Send queued events, retrying on failure.""" @@ -1372,6 +1517,13 @@ async def run(self) -> None: # Subscribe to commands BEFORE capability routines so log order makes sense await self._cmd_subscription() + # Subscribe to fleet broadcasts (best-effort; broadcast is opt-in for + # callers, so failure here should not block command handling). + try: + await self._broadcast_subscription() + except Exception as e: # pragma: no cover - best effort logging + self._logger.warning("Broadcast subscription failed: %s", e) + # Start capability routines if driver supports them (CapabilityDriverMixin) # This must happen after registration so events don't fire before device is registered if hasattr(self._driver, 'start_capability_routines'): diff --git a/tests/tests/test_tools_broadcast.py b/tests/tests/test_tools_broadcast.py new file mode 100644 index 0000000..0e7413f --- /dev/null +++ b/tests/tests/test_tools_broadcast.py @@ -0,0 +1,213 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Integration tests for selector-driven broadcast + correlation replies. + +End-to-end coverage for the async fan-out path: +- Dispatcher publishes a broadcast envelope on the fanout subject. +- Each device runtime self-elects via target_device_ids and the optional + CEL ``where`` predicate. +- Devices execute the function and emit a reply on the per-device async + reply subject keyed by correlation_id. +- ``await_replies`` collects replies for a bounded window. +""" + +import asyncio +import time + +import pytest + +SETTLE_TIME = 0.4 +DISCOVERY_TIMEOUT = 5.0 + + +async def _wait_for_devices(messaging_url, expected_ids): + from device_connect_agent_tools import connect + from device_connect_agent_tools.connection import get_connection + + await asyncio.to_thread(connect, nats_url=messaging_url) + deadline = time.monotonic() + DISCOVERY_TIMEOUT + while True: + conn = get_connection() + devices = await asyncio.to_thread(conn.list_devices) + ids = {d.get("device_id") for d in devices} + if expected_ids.issubset(ids) or time.monotonic() > deadline: + return devices + await asyncio.sleep(0.25) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_broadcast_returns_correlation_and_replies_arrive( + device_spawner, messaging_url, +): + """broadcast() returns a correlation_id and matching devices reply on the + per-device async reply subject.""" + await device_spawner.spawn_camera("itest-bc-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-bc-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import ( + await_replies, broadcast, disconnect, + ) + + await _wait_for_devices(messaging_url, {"itest-bc-cam-1", "itest-bc-cam-2"}) + try: + result = await asyncio.to_thread( + broadcast, + "device(itest-bc-cam-*).function(capture_image)", + {"resolution": "720p"}, + ) + assert result["correlation_id"].startswith("br-") + assert result["candidates"] == 2 + assert result["function"] == "capture_image" + + replies = await asyncio.to_thread( + await_replies, result["correlation_id"], timeout=5.0, until=2, + ) + assert len(replies) == 2 + ids = {r["device_id"] for r in replies} + assert ids == {"itest-bc-cam-1", "itest-bc-cam-2"} + for r in replies: + assert r["success"] is True + assert r["correlation_id"] == result["correlation_id"] + assert "actually_fired_at" in r + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_broadcast_where_filters_at_edge(device_spawner, messaging_url): + """A CEL where predicate runs at each candidate; only matches reply.""" + pytest.importorskip("celpy") + await device_spawner.spawn_camera("itest-bcw-cam-a", location="lab-A") + await device_spawner.spawn_camera("itest-bcw-cam-b", location="lab-B") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import ( + await_replies, broadcast, disconnect, + ) + + await _wait_for_devices(messaging_url, {"itest-bcw-cam-a", "itest-bcw-cam-b"}) + try: + result = await asyncio.to_thread( + broadcast, + "device(itest-bcw-cam-*).function(capture_image)", + {"resolution": "1080p"}, + "labels.location == 'lab-A'", # where predicate + ) + assert result["candidates"] == 2 + + replies = await asyncio.to_thread( + await_replies, result["correlation_id"], timeout=3.0, + ) + # Only cam-a is in lab-A; cam-b silently self-deselects. + ids = {r["device_id"] for r in replies} + assert ids == {"itest-bcw-cam-a"} + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_broadcast_fire_at_synchronizes_fan_out( + device_spawner, messaging_url, +): + """fire_at causes each device to fire from its own clock at the deadline.""" + await device_spawner.spawn_camera("itest-bcf-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-bcf-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import ( + await_replies, broadcast, disconnect, + ) + + await _wait_for_devices(messaging_url, {"itest-bcf-cam-1", "itest-bcf-cam-2"}) + try: + # Schedule 0.5s in the future; on_late=skip so any tardy device drops + # the call rather than firing late and breaking the coherence. + scheduled = time.time() + 0.5 + result = await asyncio.to_thread( + broadcast, + "device(itest-bcf-cam-*).function(capture_image)", + None, None, None, + scheduled, # fire_at + "skip", # on_late + ) + assert result["candidates"] == 2 + + replies = await asyncio.to_thread( + await_replies, result["correlation_id"], timeout=3.0, until=2, + ) + assert len(replies) == 2 + # actually_fired_at should be at-or-after the scheduled time on each. + for r in replies: + assert r["actually_fired_at"] >= scheduled - 0.05 # small slack + # Achieved spread should be tight (well under network jitter). + spread = max(r["actually_fired_at"] for r in replies) - min( + r["actually_fired_at"] for r in replies + ) + assert spread < 0.5, f"fire_at spread too wide: {spread:.3f}s" + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_broadcast_fire_at_late_with_skip_drops( + device_spawner, messaging_url, +): + """A fire_at in the past with on_late=skip yields no replies.""" + await device_spawner.spawn_camera("itest-bcl-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import ( + await_replies, broadcast, disconnect, + ) + + await _wait_for_devices(messaging_url, {"itest-bcl-cam"}) + try: + past = time.time() - 5.0 # already 5s late + result = await asyncio.to_thread( + broadcast, + "device(itest-bcl-cam).function(capture_image)", + None, None, None, past, "skip", + ) + replies = await asyncio.to_thread( + await_replies, result["correlation_id"], timeout=1.5, + ) + assert replies == [] + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_subscribe_correlation_form(device_spawner, messaging_url): + """subscribe('correlation:') captures replies as they arrive.""" + await device_spawner.spawn_camera("itest-bcs-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-bcs-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import broadcast, disconnect, subscribe + + await _wait_for_devices(messaging_url, {"itest-bcs-cam-1", "itest-bcs-cam-2"}) + try: + result = await asyncio.to_thread( + broadcast, + "device(itest-bcs-cam-*).function(capture_image)", + ) + cid = result["correlation_id"] + + def collect(): + with subscribe(f"correlation:{cid}") as sub: + # Drain over a short window. + return list(sub.iter(timeout=2.0, poll_interval=0.05)) + + replies = await asyncio.to_thread(collect) + ids = {r["device_id"] for r in replies} + assert ids == {"itest-bcs-cam-1", "itest-bcs-cam-2"} + finally: + await asyncio.to_thread(disconnect) From 072ef0133944a20c0dbfed9c6248109cabcb8938 Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Sun, 10 May 2026 12:10:20 -0700 Subject: [PATCH 10/21] feat(cli): selector-driven verbs in devctl and statectl Add the operator-facing shell surface for selector-driven discovery and operations: devctl verbs (read-side): - devctl discover "" [--offset N] [--limit M] - devctl discover-labels [--key K] [--offset N] [--limit M] statectl verbs (write-side): - statectl invoke "" [--param k=v ...] - statectl invoke-many "" [--param k=v ...] [--timeout T] [--max-concurrency N] - statectl broadcast "" [--param k=v ...] [--where E] [--bindings JSON] [--fire-at T] [--on-late skip|fire] - statectl subscribe "" [--timeout T] [--until N] - statectl await [--timeout T] [--until N] Each verb is a thin wrapper over the Python tool of the same name and exits non-zero on tool-side errors so they compose into shell pipelines naturally. Parameter values are decoded as JSON when they look like JSON (numbers, booleans, arrays, objects, quoted strings) and pass through as strings otherwise, so common shapes (--param resolution=4k, --param zones='[1,2,3]') work without quoting heroics. The historical ``devctl discover`` verb (mDNS scan for uncommissioned devices) is renamed to ``mdns-scan`` with ``scan`` as an alias, so ``discover`` is free for the selector-driven sense. Existing scripts should switch from ``devctl discover`` to ``devctl scan`` if they were exercising the mDNS path. 22 parser-shape unit tests guard against argument drift; the underlying tools already have full unit and integration coverage from earlier phases. --- .../device_connect_server/devctl/cli.py | 27 +- .../devctl/selector_cli.py | 103 +++++++ .../device_connect_server/statectl/cli.py | 23 ++ .../statectl/operations_cli.py | 282 ++++++++++++++++++ .../test_selector_cli.py | 199 ++++++++++++ 5 files changed, 630 insertions(+), 4 deletions(-) create mode 100644 packages/device-connect-server/device_connect_server/devctl/selector_cli.py create mode 100644 packages/device-connect-server/device_connect_server/statectl/operations_cli.py create mode 100644 packages/device-connect-server/tests/device_connect_server/test_selector_cli.py diff --git a/packages/device-connect-server/device_connect_server/devctl/cli.py b/packages/device-connect-server/device_connect_server/devctl/cli.py index 071b423..f73ec6a 100644 --- a/packages/device-connect-server/device_connect_server/devctl/cli.py +++ b/packages/device-connect-server/device_connect_server/devctl/cli.py @@ -574,9 +574,20 @@ def create_parser() -> argparse.ArgumentParser: p_reg.add_argument("--broker", default=None, help="Broker URL") p_reg.add_argument("--keepalive", action="store_true", help="Start heartbeat loop") - # discover command - p_discover = sub.add_parser("discover", help="Discover uncommissioned devices") - p_discover.add_argument("--timeout", type=int, default=5, help="Timeout in seconds") + # mdns-scan: discover uncommissioned devices on the local network. + # Renamed from the historical ``discover`` verb so the selector-driven + # ``discover`` below (which queries the fleet, not the local network) + # can take the natural name. + p_scan = sub.add_parser( + "mdns-scan", help="Discover uncommissioned devices via mDNS", + aliases=["scan"], + ) + p_scan.add_argument("--timeout", type=int, default=5, help="Timeout in seconds") + + # Selector-driven fleet discovery (new). Registers ``discover`` and + # ``discover-labels`` as parser entries. + from device_connect_server.devctl import selector_cli + selector_cli.register_subparsers(sub) # commission command p_commission = sub.add_parser("commission", help="Commission a device with PIN") @@ -617,9 +628,17 @@ def main(argv: Optional[List[str]] = None) -> None: loop.stop() print("\nbye!") - elif args.cmd == "discover": + elif args.cmd in ("mdns-scan", "scan"): asyncio.run(discover_devices(timeout=args.timeout)) + elif args.cmd == "discover": + from device_connect_server.devctl import selector_cli + sys.exit(selector_cli.run_discover(args)) + + elif args.cmd == "discover-labels": + from device_connect_server.devctl import selector_cli + sys.exit(selector_cli.run_discover_labels(args)) + elif args.cmd == "commission": asyncio.run( commission_device( diff --git a/packages/device-connect-server/device_connect_server/devctl/selector_cli.py b/packages/device-connect-server/device_connect_server/devctl/selector_cli.py new file mode 100644 index 0000000..68a6637 --- /dev/null +++ b/packages/device-connect-server/device_connect_server/devctl/selector_cli.py @@ -0,0 +1,103 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""``devctl`` selector-driven discovery verbs. + +Thin wrappers around ``device_connect_agent_tools.discover`` and +``discover_labels`` so operators can drive the same selector grammar +from a shell. +""" +from __future__ import annotations + +import json +import os +from typing import Any + + +def _connect(broker: str | None) -> None: + """Best-effort connect to the messaging backend. + + Reuses ``DEVICE_CONNECT_*`` and ``NATS_URL`` env vars when ``broker`` is + not given. Kept as a thin wrapper so all CLI verbs share the same + connect-or-fail semantics. + """ + from device_connect_agent_tools import connect + + if broker: + connect(nats_url=broker) + else: + nats_url = os.getenv("NATS_URL") or os.getenv("DEVICE_CONNECT_NATS_URL") + if nats_url: + connect(nats_url=nats_url) + else: + connect() + + +def _pretty(data: Any) -> str: + """Render a JSON payload for terminal output.""" + return json.dumps(data, indent=2, sort_keys=True, default=str) + + +def run_discover(args: Any) -> int: + """Execute ``devctl discover ""``.""" + from device_connect_agent_tools import disconnect, discover + + _connect(getattr(args, "broker", None)) + try: + result = discover( + args.selector, + offset=int(args.offset or 0), + limit=int(args.limit or 200), + ) + print(_pretty(result)) + return 0 if "error" not in result else 1 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +def run_discover_labels(args: Any) -> int: + """Execute ``devctl discover-labels [--key K]``.""" + from device_connect_agent_tools import disconnect, discover_labels + + _connect(getattr(args, "broker", None)) + try: + result = discover_labels( + key=args.key, + offset=int(args.offset or 0), + limit=int(args.limit or 50), + ) + print(_pretty(result)) + return 0 if "error" not in result else 1 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +def register_subparsers(sub: Any) -> None: + """Attach the discover / discover-labels subparsers to a devctl parser.""" + p = sub.add_parser( + "discover", + help="Resolve a selector to devices, functions, or events", + ) + p.add_argument("selector", help="Selector expression (e.g. 'device(category:camera)')") + p.add_argument("--broker", default=None, help="Messaging broker URL") + p.add_argument("--offset", type=int, default=0, help="Pagination offset") + p.add_argument("--limit", type=int, default=200, help="Page size") + + p = sub.add_parser( + "discover-labels", + help="Browse fleet label vocabulary", + ) + p.add_argument( + "--key", default=None, + help="Axis-qualified label key (e.g. 'device.location') for per-key pagination", + ) + p.add_argument("--broker", default=None, help="Messaging broker URL") + p.add_argument("--offset", type=int, default=0, help="Pagination offset") + p.add_argument("--limit", type=int, default=50, help="Page size") diff --git a/packages/device-connect-server/device_connect_server/statectl/cli.py b/packages/device-connect-server/device_connect_server/statectl/cli.py index e1a03ef..161afdd 100644 --- a/packages/device-connect-server/device_connect_server/statectl/cli.py +++ b/packages/device-connect-server/device_connect_server/statectl/cli.py @@ -408,6 +408,13 @@ def create_parser() -> argparse.ArgumentParser: # stats sub.add_parser("stats", help="Key counts by namespace") + # Selector-driven operations (invoke / invoke-many / broadcast / + # subscribe / await). These verbs do not touch etcd; they run over + # the messaging fabric. They live under statectl because they all + # change the live state of devices. + from device_connect_server.statectl import operations_cli + operations_cli.register_subparsers(sub) + return parser @@ -430,9 +437,25 @@ async def _run(args) -> None: await handler(client, args) +_OPERATIONS_DISPATCH = { + "invoke": "run_invoke", + "invoke-many": "run_invoke_many", + "broadcast": "run_broadcast", + "subscribe": "run_subscribe", + "await": "run_await", +} + + def main(): parser = create_parser() args = parser.parse_args() + if args.cmd in _OPERATIONS_DISPATCH: + # Operations verbs run over messaging, not etcd. Bypass the etcd + # client setup that the COMMANDS dispatch table assumes. + from device_connect_server.statectl import operations_cli + handler = getattr(operations_cli, _OPERATIONS_DISPATCH[args.cmd]) + sys.exit(handler(args)) + try: asyncio.run(_run(args)) except KeyboardInterrupt: diff --git a/packages/device-connect-server/device_connect_server/statectl/operations_cli.py b/packages/device-connect-server/device_connect_server/statectl/operations_cli.py new file mode 100644 index 0000000..630709a --- /dev/null +++ b/packages/device-connect-server/device_connect_server/statectl/operations_cli.py @@ -0,0 +1,282 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""``statectl`` selector-driven operations verbs. + +Thin wrappers around the agent-tools ``invoke`` / ``invoke_many`` / +``broadcast`` / ``subscribe`` / ``await_replies`` functions so operators +can fire selector-driven calls from a shell. +""" +from __future__ import annotations + +import json +import os +from typing import Any + + +def _connect(broker: str | None) -> None: + """Connect to the messaging backend using the same env-or-broker rules + as devctl's selector verbs.""" + from device_connect_agent_tools import connect + + if broker: + connect(nats_url=broker) + else: + nats_url = os.getenv("NATS_URL") or os.getenv("DEVICE_CONNECT_NATS_URL") + if nats_url: + connect(nats_url=nats_url) + else: + connect() + + +def _parse_param_kv(values: list[str] | None) -> dict[str, Any]: + """Parse ``--param k=v`` repeated args into a function-params dict. + + Values that look like JSON (``[...]``, ``{...}``, numbers, ``true`` / + ``false`` / ``null``) are decoded; everything else stays a string. This + matches what an operator would expect when typing + ``--param resolution=1080p --param tags='["a","b"]'``. + """ + out: dict[str, Any] = {} + for entry in values or []: + if "=" not in entry: + raise ValueError(f"--param must be 'k=v', got {entry!r}") + k, _, v = entry.partition("=") + k = k.strip() + if not k: + raise ValueError(f"--param has empty key in {entry!r}") + v_stripped = v.strip() + # JSON-decode obvious JSON-shaped values; fall back to raw string. + if ( + v_stripped.startswith(("[", "{", '"')) + or v_stripped in ("true", "false", "null") + or _looks_numeric(v_stripped) + ): + try: + out[k] = json.loads(v_stripped) + continue + except json.JSONDecodeError: + pass + out[k] = v + return out + + +def _looks_numeric(s: str) -> bool: + try: + float(s) + return True + except ValueError: + return False + + +def _pretty(data: Any) -> str: + return json.dumps(data, indent=2, sort_keys=True, default=str) + + +# -- verbs ---------------------------------------------------------- + + +def run_invoke(args: Any) -> int: + from device_connect_agent_tools import disconnect, invoke + + _connect(getattr(args, "broker", None)) + try: + result = invoke( + args.selector, + params=_parse_param_kv(args.param), + llm_reasoning=args.reason, + ) + print(_pretty(result)) + return 0 if result.get("success") else 1 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +def run_invoke_many(args: Any) -> int: + from device_connect_agent_tools import disconnect, invoke_many + + _connect(getattr(args, "broker", None)) + try: + result = invoke_many( + args.selector, + params=_parse_param_kv(args.param), + timeout=float(args.timeout), + max_concurrency=int(args.max_concurrency), + llm_reasoning=args.reason, + ) + print(_pretty(result)) + return 0 if "error" not in result else 1 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +def run_broadcast(args: Any) -> int: + from device_connect_agent_tools import broadcast, disconnect + + bindings = None + if args.bindings: + try: + bindings = json.loads(args.bindings) + except json.JSONDecodeError as e: + print(f"--bindings must be valid JSON: {e}") + return 2 + + _connect(getattr(args, "broker", None)) + try: + result = broadcast( + args.selector, + params=_parse_param_kv(args.param), + where=args.where, + bindings=bindings, + fire_at=float(args.fire_at) if args.fire_at is not None else None, + on_late=args.on_late, + llm_reasoning=args.reason, + ) + print(_pretty(result)) + return 0 if "error" not in result else 1 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +def run_subscribe(args: Any) -> int: + """Stream events / replies for ``args.selector`` to stdout. + + Each message is printed as one JSON line so the output can be piped + into ``jq`` or grep. Runs until ``--timeout`` of idle silence elapses + or ``--until`` messages have been printed (whichever comes first). + """ + from device_connect_agent_tools import disconnect, subscribe + + _connect(getattr(args, "broker", None)) + try: + count = 0 + with subscribe(args.selector) as sub: + for msg in sub.iter( + timeout=float(args.timeout), poll_interval=0.05, + ): + print(json.dumps(msg, default=str)) + count += 1 + if args.until is not None and count >= int(args.until): + break + return 0 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +def run_await(args: Any) -> int: + from device_connect_agent_tools import await_replies, disconnect + + _connect(getattr(args, "broker", None)) + try: + replies = await_replies( + args.correlation_id, + timeout=float(args.timeout), + until=int(args.until) if args.until is not None else None, + ) + print(_pretty(replies)) + return 0 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +# -- parser wiring -------------------------------------------------- + + +def register_subparsers(sub: Any) -> None: + """Attach the operation subparsers to a statectl parser.""" + p = sub.add_parser("invoke", help="Call exactly one function on one device") + p.add_argument("selector", help="Function-scoped selector") + p.add_argument( + "--param", action="append", default=[], + help="Function param as k=v (repeatable; JSON values decoded)", + ) + p.add_argument("--reason", default=None, help="LLM reasoning") + p.add_argument("--broker", default=None, help="Messaging broker URL") + + p = sub.add_parser( + "invoke-many", help="Fan out a call over a selector-resolved set", + ) + p.add_argument("selector", help="Function-scoped selector") + p.add_argument( + "--param", action="append", default=[], + help="Function param as k=v (repeatable; JSON values decoded)", + ) + p.add_argument("--timeout", default=30.0, help="Per-target timeout (s)") + p.add_argument( + "--max-concurrency", default=32, dest="max_concurrency", + help="Parallel worker cap", + ) + p.add_argument("--reason", default=None, help="LLM reasoning") + p.add_argument("--broker", default=None, help="Messaging broker URL") + + p = sub.add_parser( + "broadcast", + help="Async fan-out; returns correlation_id", + ) + p.add_argument("selector", help="Function-scoped selector") + p.add_argument( + "--param", action="append", default=[], + help="Function param as k=v (repeatable; JSON values decoded)", + ) + p.add_argument( + "--where", default=None, + help="CEL predicate evaluated at the edge per candidate", + ) + p.add_argument( + "--bindings", default=None, + help="JSON-encoded bindings dict (shared payload for the predicate)", + ) + p.add_argument( + "--fire-at", default=None, dest="fire_at", + help="Wall-clock epoch seconds for synchronized fan-out", + ) + p.add_argument( + "--on-late", choices=["skip", "fire"], default="skip", dest="on_late", + help="Policy when fire_at deadline has passed (default: skip)", + ) + p.add_argument("--reason", default=None, help="LLM reasoning") + p.add_argument("--broker", default=None, help="Messaging broker URL") + + p = sub.add_parser( + "subscribe", help="Stream events or broadcast replies to stdout", + ) + p.add_argument( + "selector", + help="Event selector or 'correlation:' for broadcast replies", + ) + p.add_argument( + "--timeout", default=10.0, + help="Idle-silence timeout per message (s; resets on each arrival)", + ) + p.add_argument( + "--until", default=None, + help="Stop after this many messages are printed", + ) + p.add_argument("--broker", default=None, help="Messaging broker URL") + + p = sub.add_parser( + "await", help="Collect replies for a broadcast correlation_id", + ) + p.add_argument("correlation_id", help="Correlation id returned by broadcast") + p.add_argument("--timeout", default=10.0, help="Overall timeout (s)") + p.add_argument( + "--until", default=None, + help="Stop after this many replies have been collected", + ) + p.add_argument("--broker", default=None, help="Messaging broker URL") diff --git a/packages/device-connect-server/tests/device_connect_server/test_selector_cli.py b/packages/device-connect-server/tests/device_connect_server/test_selector_cli.py new file mode 100644 index 0000000..2ed2cf8 --- /dev/null +++ b/packages/device-connect-server/tests/device_connect_server/test_selector_cli.py @@ -0,0 +1,199 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Smoke tests for the selector-driven CLI verbs. + +Argument-parser shape only; the underlying tools (``discover``, +``invoke``, ``broadcast``, etc.) have their own unit and integration +tests. These guards catch parser-config regressions (missing positional, +typoed dest, alias drift). +""" +from __future__ import annotations + +import json + +import pytest + +from device_connect_server.devctl import cli as devctl_cli +from device_connect_server.devctl import selector_cli +from device_connect_server.statectl import cli as statectl_cli +from device_connect_server.statectl import operations_cli + + +# -- devctl --------------------------------------------------------- + + +class TestDevctlSelectorParser: + def test_discover_requires_selector(self): + parser = devctl_cli.create_parser() + with pytest.raises(SystemExit): + parser.parse_args(["discover"]) + + def test_discover_parses_selector(self): + parser = devctl_cli.create_parser() + args = parser.parse_args(["discover", "device(category:camera)"]) + assert args.cmd == "discover" + assert args.selector == "device(category:camera)" + assert args.offset == 0 + assert args.limit == 200 + + def test_discover_offset_limit_override(self): + parser = devctl_cli.create_parser() + args = parser.parse_args( + ["discover", "device(*)", "--offset", "100", "--limit", "50"] + ) + assert args.offset == 100 + assert args.limit == 50 + + def test_discover_labels_no_key(self): + parser = devctl_cli.create_parser() + args = parser.parse_args(["discover-labels"]) + assert args.cmd == "discover-labels" + assert args.key is None + assert args.limit == 50 + + def test_discover_labels_key_pagination(self): + parser = devctl_cli.create_parser() + args = parser.parse_args( + ["discover-labels", "--key", "device.location", "--limit", "20"] + ) + assert args.key == "device.location" + assert args.limit == 20 + + def test_legacy_discover_renamed_to_mdns_scan(self): + # The historical "discover" verb (mDNS scan) now lives under + # mdns-scan; the alias "scan" keeps it discoverable. + parser = devctl_cli.create_parser() + for verb in ("mdns-scan", "scan"): + args = parser.parse_args([verb]) + # Both aliases share the same args.cmd + assert args.cmd in ("mdns-scan", "scan") + + +# -- statectl ------------------------------------------------------- + + +class TestStatectlOperationsParser: + def test_invoke_requires_selector(self): + parser = statectl_cli.create_parser() + with pytest.raises(SystemExit): + parser.parse_args(["invoke"]) + + def test_invoke_parses(self): + parser = statectl_cli.create_parser() + args = parser.parse_args( + [ + "invoke", "device(robot-001).function(grip_close)", + "--param", "force_n=10", + "--reason", "test", + ] + ) + assert args.cmd == "invoke" + assert args.selector == "device(robot-001).function(grip_close)" + assert args.param == ["force_n=10"] + assert args.reason == "test" + + def test_invoke_many_with_timeout(self): + parser = statectl_cli.create_parser() + args = parser.parse_args( + [ + "invoke-many", + "function(safety:critical)", + "--timeout", "5", + "--max-concurrency", "8", + ] + ) + assert args.cmd == "invoke-many" + assert float(args.timeout) == 5.0 + assert int(args.max_concurrency) == 8 + + def test_broadcast_full_signature(self): + parser = statectl_cli.create_parser() + args = parser.parse_args( + [ + "broadcast", + "device(category:phone).function(set_flashlight)", + "--param", "on=true", + "--param", "color=white", + "--where", "labels.location == 'lab-A'", + "--bindings", '{"mask": [[0,1],[1,0]]}', + "--fire-at", "1700000000.0", + "--on-late", "fire", + ] + ) + assert args.cmd == "broadcast" + assert args.selector.startswith("device(category:phone)") + assert args.where == "labels.location == 'lab-A'" + assert args.on_late == "fire" + + def test_broadcast_rejects_unknown_on_late(self): + parser = statectl_cli.create_parser() + with pytest.raises(SystemExit): + parser.parse_args( + [ + "broadcast", "device(*).function(do)", + "--on-late", "bogus", + ] + ) + + def test_subscribe_parses_correlation_form(self): + parser = statectl_cli.create_parser() + args = parser.parse_args( + ["subscribe", "correlation:br-abc123", "--until", "5"] + ) + assert args.cmd == "subscribe" + assert args.selector == "correlation:br-abc123" + assert int(args.until) == 5 + + def test_await_requires_correlation_id(self): + parser = statectl_cli.create_parser() + with pytest.raises(SystemExit): + parser.parse_args(["await"]) + + def test_await_parses(self): + parser = statectl_cli.create_parser() + args = parser.parse_args( + ["await", "br-abc123", "--timeout", "2.5", "--until", "10"] + ) + assert args.correlation_id == "br-abc123" + assert float(args.timeout) == 2.5 + assert int(args.until) == 10 + + +# -- parameter parsing ---------------------------------------------- + + +class TestParseParamKV: + def test_string_values_default(self): + result = operations_cli._parse_param_kv(["a=hello", "b=world"]) + assert result == {"a": "hello", "b": "world"} + + def test_numbers_decoded(self): + result = operations_cli._parse_param_kv(["count=5", "ratio=0.75"]) + assert result == {"count": 5, "ratio": 0.75} + + def test_booleans_decoded(self): + result = operations_cli._parse_param_kv(["on=true", "off=false"]) + assert result == {"on": True, "off": False} + + def test_json_array_decoded(self): + result = operations_cli._parse_param_kv(["zones=[1,2,3]"]) + assert result == {"zones": [1, 2, 3]} + + def test_json_object_decoded(self): + result = operations_cli._parse_param_kv(['nested={"a":1}']) + assert result == {"nested": {"a": 1}} + + def test_string_with_equals(self): + # The split is on the first '=', so values may contain further '='. + result = operations_cli._parse_param_kv(["query=a=b"]) + assert result == {"query": "a=b"} + + def test_invalid_form_rejected(self): + with pytest.raises(ValueError): + operations_cli._parse_param_kv(["no_equals_sign"]) + + def test_empty_key_rejected(self): + with pytest.raises(ValueError): + operations_cli._parse_param_kv(["=value"]) From 02a94201bc97fe30eb4e22edd45aff0b95bbf0ad Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Sun, 10 May 2026 12:14:12 -0700 Subject: [PATCH 11/21] docs: extend discovery guide for operations, where, and CLI Add the operations layer (invoke / invoke_many / broadcast / subscribe / await_replies) to docs/discovery.md, with the edge-side ``where`` predicate, synchronized fan-out via ``fire_at`` / ``on_late``, worked examples that exercise each tool, and the corresponding devctl / statectl CLI verbs. The guide now covers everything the discovery API ships: labels schema, selector grammar, the five scope shapes, response envelope, error codes, all seven tools, and the CLI surface. --- docs/discovery.md | 184 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 182 insertions(+), 2 deletions(-) diff --git a/docs/discovery.md b/docs/discovery.md index 6d2b8b4..ee33d60 100644 --- a/docs/discovery.md +++ b/docs/discovery.md @@ -96,7 +96,9 @@ function(estop) fleet emergency-st ## Tools -### `discover(selector, offset=0, limit=200)` +### Discovery + +#### `discover(selector, offset=0, limit=200)` Resolves a selector to matched entities. Returns devices, function tuples, or event tuples depending on the selector scope. The response includes a @@ -108,7 +110,7 @@ and switches to a name-and-labels summary above `DEVICE_CONNECT_FUNCTION_THRESHOLD` (default 20). The threshold is configurable via environment variable. -### `discover_labels(key=None, offset=0, limit=50)` +#### `discover_labels(key=None, offset=0, limit=50)` Returns the fleet label vocabulary. Use this first when you do not know which dimensions are available. @@ -118,6 +120,84 @@ which dimensions are available. - With a `key` like `"device.location"` or `"function.direction"`: paginates the full value list for that one key. +### Operations + +Calling a function on devices is one logical operation; the only choice +is whether the caller waits for replies and how they arrive. + +| Tool | Selector resolves to | Reply mode | +| --- | --- | --- | +| `invoke(selector, params)` | exactly one (device, function) tuple | sync, single result | +| `invoke_many(selector, params, timeout=)` | any number of (device, function) tuples | sync, aggregated | +| `broadcast(selector, params, where=, bindings=, fire_at=, on_late=)` | any number of (device, function) tuples | async; correlation-tagged replies stream as events | +| `subscribe(selector)` | events, or `"correlation:"` for broadcast replies | live stream (`Subscription` handle) | +| `await_replies(correlation_id, timeout=, until=)` | replies for one broadcast | sync helper that subscribes, collects, returns | + +`invoke_many` runs every target's call in parallel and returns when each +target has finished or hit its per-target timeout (30 s default). Partial +failures do not abort siblings; the response carries both `results` and +`errors` lists. + +`broadcast` does the same fan-out asynchronously: the caller gets a +`correlation_id` immediately and replies stream back on a per-device +subject keyed by that id. Subscribe with `subscribe("correlation:")` +or block with `await_replies(correlation_id, timeout=...)`. + +### Edge-side `where` predicate + +`broadcast` accepts an optional `where` expression that runs at each +candidate device. The predicate is a CEL (Common Expression Language) +string and sees four variables: + +- `identity` — device-local identity dict (`device_id`, `device_type`, ...) +- `labels` — device labels (the same labels selectors filter on) +- `status` — device status (heartbeat-updated: `location`, `availability`, + `battery`, `online`, ...) +- `bindings` — the shared payload passed to `broadcast` (selection masks, + thresholds, lookup tables) + +```python +broadcast( + "device(category:camera).function(capture_image)", + params={"resolution": "4k"}, + where="status.battery > 50 && labels.location == 'lab-A'", +) +``` + +The `where` predicate is sandboxed by CEL (no I/O, no filesystem). The +predicate evaluator is an optional install: + +``` +pip install device-connect-agent-tools[predicate] +``` + +Without the extra, calling `broadcast(..., where=...)` returns an +`invalid_predicate` error immediately at the dispatcher; calls without a +`where` work unchanged. + +### Synchronized fan-out (`fire_at` + `on_late`) + +`broadcast` accepts an optional `fire_at` (wall-clock epoch seconds). +Each device holds the message and fires from its own clock at the +deadline. `on_late` controls behaviour when a device receives the +message past the deadline: + +- `"skip"` (default) — drop the call to preserve coherence. +- `"fire"` — execute immediately. + +```python +broadcast( + "device(category:phone).function(set_flashlight)", + params={"on": True, "color": "white"}, + fire_at=time.time() + 0.500, # 500 ms in the future + on_late="skip", +) +``` + +With NTP-synced devices the achieved spread is typically 5-10 ms +(clock-sync residual) rather than the 50-150 ms a naive fire-on-receipt +broadcast would produce. + ## Response envelope `discover` returns a stable envelope: @@ -229,3 +309,103 @@ while True: break offset = page["next_offset"] ``` + +### Invoke a single function + +```python +from device_connect_agent_tools import invoke + +result = invoke( + "device(robot-001).function(grip_close)", + {"force_n": 10}, +) +# {"success": True, "device_id": "robot-001", "function": "grip_close", +# "result": {...}} +``` + +### Fan out across every camera in lab-A + +```python +from device_connect_agent_tools import invoke_many + +result = invoke_many( + "device(category:camera, location:lab-A).function(capture_image)", + {"resolution": "4k"}, +) +# {"candidates": 12, "matched": 12, "succeeded": 12, "failed": 0, +# "results": [...], "errors": []} +``` + +### Async fleet emergency stop + +```python +from device_connect_agent_tools import broadcast, await_replies + +result = broadcast("function(estop)") +# {"correlation_id": "br-7f3a91", "candidates": 240, ...} + +replies = await_replies(result["correlation_id"], timeout=5.0) +# list of {device_id, success, result|error, actually_fired_at} +``` + +### Synchronized actuation across a phone fleet + +```python +import time +from device_connect_agent_tools import broadcast + +mask = build_mask_from_scores(threshold=0.8) # caller-side selection +broadcast( + "device(category:phone, location:auditorium-A).function(set_flashlight)", + params={"on": True, "color": "white"}, + where="mask[seat_row][seat_col] == 1 && status.battery > 30", + bindings={"mask": mask}, + fire_at=time.time() + 0.5, + on_late="skip", +) +``` + +### Subscribe to motion events in lab-A + +```python +from device_connect_agent_tools import subscribe + +with subscribe("device(location:lab-A/*).event(modality:motion)") as sub: + for event in sub.iter(timeout=60.0): + handle(event) +``` + +## CLI + +The same selector syntax drives the operator CLIs. Every CLI command +maps to the matching Python tool call. + +``` +# Discovery (devctl) +devctl discover "" [--offset N] [--limit M] +devctl discover-labels [--key K] [--offset N] [--limit M] + +# Operations (statectl) +statectl invoke "" [--param k=v ...] +statectl invoke-many "" [--param k=v ...] [--timeout T] +statectl broadcast "" [--param k=v ...] [--where E] + [--bindings JSON] [--fire-at T] + [--on-late skip|fire] +statectl subscribe "" [--timeout T] [--until N] +statectl await [--timeout T] [--until N] +``` + +`--param k=v` accepts JSON-shaped values (numbers, booleans, arrays, +objects); everything else passes through as a string. So +`--param resolution=4k` and `--param zones='[1,2,3]'` both work +without quoting heroics. + +Each verb exits non-zero on tool-side errors so the verbs compose into +shell pipelines: + +``` +statectl broadcast "device(category:camera).function(capture_image)" \ + --param resolution=4k \ + | jq -r .correlation_id \ + | xargs statectl await --timeout 5 +``` From 7c760ab655ab510e4509c695fdaa2e96d083dcd6 Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Sun, 10 May 2026 12:39:14 -0700 Subject: [PATCH 12/21] fix(broadcast): robustness pass on edge handler, subscribe, and CLI Applies findings from the pre-merge review of the operations stack: Edge runtime (device.py): - Hand the broadcast envelope off to a tracked task so the subscription callback returns immediately. A long fire_at hold or slow driver function no longer blocks subsequent broadcasts from being received. - Extract _handle_broadcast_envelope and _evaluate_where so the where self-election step is isolated, unit-testable, and the callback body stays flat. - Splice device_id into the predicate's identity context so the natural ``identity.device_id == "..."`` form works (DeviceIdentity itself does not carry device_id; that lives on the runtime). Wire format (tools.py + device.py): - Rename the broadcast envelope's ``target_device_ids`` field to ``targets`` before any edge ships. Shorter, less prescriptive, and matches the dispatcher-side ``candidates`` naming. Subscription handle (tools.py): - Fix a race in Subscription.read(): truncate by the snapshot length captured BEFORE iteration, not by clearing post-iteration. A message appended by the messaging callback during draining now survives to the next read instead of being silently dropped. - Add __iter__ so ``for msg in sub:`` works with a sensible 30s idle timeout, matching the standard Python iteration protocol. CLI (statectl/operations_cli.py): - statectl subscribe now catches KeyboardInterrupt cleanly (exit 130), distinguishes "got messages" (exit 0) from "idle timeout with no messages" (exit 4), so shell pipelines can branch on either outcome. - statectl invoke-many exits 3 when any target failed (alongside the existing 1 for top-level errors), so partial failure is visible to callers without parsing JSON. ASCII compliance (predicate.py, tools.py): - Drop a banned-vocabulary token from a docstring. - Replace an em-dash in invoke_device's docstring with ASCII text. New tests: - Unit: __iter__ protocol + race-safety guard for Subscription.read. - Integration: broadcast where=identity.device_id in bindings.allow (exercises the new identity context + bindings path), await_replies(until=) early-return timing, ``for msg in sub:`` iteration end-to-end, and subscribe(event(...)) live-event capture. --- .../device_connect_agent_tools/connection.py | 2 +- .../device_connect_agent_tools/tools.py | 48 ++-- .../tests/test_broadcast.py | 2 +- .../tests/test_subscribe.py | 39 +++ .../device_connect_edge/device.py | 251 ++++++++++-------- .../device_connect_edge/predicate.py | 2 +- .../statectl/operations_cli.py | 35 ++- tests/tests/test_tools_broadcast.py | 145 ++++++++++ 8 files changed, 387 insertions(+), 137 deletions(-) diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py b/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py index b399f70..4dce5fb 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py @@ -415,7 +415,7 @@ def publish_broadcast(self, envelope: Dict[str, Any]) -> None: The envelope shape is documented in ``device_connect_edge.device.DeviceRuntime._broadcast_subscription``; every device subscribed to ``device-connect..broadcast`` - receives the message and self-elects via ``target_device_ids`` and + receives the message and self-elects via ``targets`` and the optional ``where`` predicate. """ return self._run(self._async_publish_broadcast(envelope)) diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py b/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py index c81faf2..c99c5bc 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py @@ -829,7 +829,7 @@ def broadcast( } correlation_id = f"br-{uuid.uuid4().hex[:12]}" - target_device_ids = sorted({ + targets = sorted({ row.get("device_id") for row in rows if row.get("device_id") }) clean_params = { @@ -840,7 +840,7 @@ def broadcast( "correlation_id": correlation_id, "function": function_name, "params": clean_params, - "target_device_ids": target_device_ids, + "targets": targets, } if where: envelope["where"] = where @@ -857,7 +857,7 @@ def broadcast( ) logger.info( "[broadcast::%s::%d targets] Reason: %s", - correlation_id, len(target_device_ids), truncated, + correlation_id, len(targets), truncated, ) try: @@ -866,13 +866,13 @@ def broadcast( except Exception as e: logger.error("broadcast publish failed: %s", e) return { - "candidates": len(target_device_ids), + "candidates": len(targets), "error": _error("connection_error", str(e)), } return { "correlation_id": correlation_id, - "candidates": len(target_device_ids), + "candidates": len(targets), "selector": selector, "function": function_name, } @@ -919,26 +919,27 @@ def read(self, max_messages: int | None = None) -> list[dict[str, Any]]: Returns parsed payload dicts (already JSON-decoded by the connection's buffered subscription path). Subsequent calls return only messages that arrived after the previous call. + + Race-safe against the messaging callback that appends to the same + inbox: each inbox is read by snapshotting its current length and + truncating only that prefix, so a message that arrives during + iteration stays buffered for the next ``read``. """ if self._closed: return [] out: list[dict[str, Any]] = [] for name in self._inbox_names: - inboxes = self._conn.get_inbox(name) - buffered = inboxes.get(name, []) or [] - # Each buffered entry is (subject, payload). We expose the - # parsed payload but stamp the subject onto it so callers can - # distinguish per-source messages without parsing it themselves. - for subject, payload in buffered: + buf = self._conn._inbox.get(name) or [] + # Snapshot the consumed prefix length BEFORE iterating, then + # truncate by exactly that many items. Any message appended by + # the messaging callback between the snapshot and the truncation + # remains buffered for a subsequent ``read``. + n = len(buf) + for subject, payload in buf[:n]: if not isinstance(payload, dict): payload = {"raw": payload} - payload = {**payload, "_subject": subject} - out.append(payload) - # Fast cursor: trim per-inbox buffers we have already returned by - # truncating from the front. The connection layer already caps each - # inbox at 1000 entries, so bounded growth is its concern. - for name in self._inbox_names: - self._conn._inbox[name] = [] + out.append({**payload, "_subject": subject}) + self._conn._inbox[name] = buf[n:] if max_messages is not None: out = out[:max_messages] return out @@ -962,6 +963,15 @@ def iter(self, timeout: float = 5.0, poll_interval: float = 0.05): return time.sleep(poll_interval) + def __iter__(self): + """Allow ``for msg in sub:`` with a default 30-second idle timeout. + + Delegates to :meth:`iter` with sensible defaults so the idiomatic + Python iteration form works. Use ``sub.iter(timeout=...)`` directly + when the default does not fit. + """ + return self.iter(timeout=30.0, poll_interval=0.05) + def close(self) -> None: """Tear down the underlying messaging subscriptions.""" if self._closed: @@ -1324,7 +1334,7 @@ def invoke_device( device_id: Target device ID (e.g., "robot-001", "camera-001"). function: Function name to call. params: Function parameters as a dictionary. - llm_reasoning: Why you're calling this function -- for observability. + llm_reasoning: Why you are calling this function (for observability). """ warnings.warn( "invoke_device(device_id, function, ...) is deprecated; use " diff --git a/packages/device-connect-agent-tools/tests/test_broadcast.py b/packages/device-connect-agent-tools/tests/test_broadcast.py index e25a35e..e8d8831 100644 --- a/packages/device-connect-agent-tools/tests/test_broadcast.py +++ b/packages/device-connect-agent-tools/tests/test_broadcast.py @@ -100,7 +100,7 @@ def test_envelope_carries_function_and_targets(self, mock_conn): env = mock_conn._published[0] assert env["function"] == "capture_image" assert env["params"] == {"resolution": "4k"} - assert sorted(env["target_device_ids"]) == ["cam-001", "cam-002"] + assert sorted(env["targets"]) == ["cam-001", "cam-002"] # No optional fields when caller did not set them. assert "where" not in env assert "bindings" not in env diff --git a/packages/device-connect-agent-tools/tests/test_subscribe.py b/packages/device-connect-agent-tools/tests/test_subscribe.py index a6f032e..a8b4be4 100644 --- a/packages/device-connect-agent-tools/tests/test_subscribe.py +++ b/packages/device-connect-agent-tools/tests/test_subscribe.py @@ -159,6 +159,45 @@ def test_iter_yields_until_idle_timeout(self, fake_conn): assert len(msgs) == 1 sub.close() + def test_for_loop_protocol_via_dunder_iter(self, fake_conn): + # ``for msg in sub:`` should drive __iter__ which delegates to iter() + # with a sensible default timeout. Break early so the test does not + # block on the 30s default. + sub = tools_mod.subscribe("correlation:r_iter") + fake_conn.deliver( + "device-connect.default.cam-001.event.async_reply.r_iter", + {"correlation_id": "r_iter", "device_id": "cam-001"}, + ) + gathered: list[dict] = [] + for msg in sub: + gathered.append(msg) + break # one message is enough to confirm __iter__ wiring + sub.close() + assert len(gathered) == 1 + assert gathered[0]["device_id"] == "cam-001" + + def test_read_does_not_drop_messages_appended_during_iteration(self, fake_conn): + # Race-safety guard: simulate a callback that appends a fresh + # message between the read's snapshot and truncation. The message + # must still be visible on the next read(). + sub = tools_mod.subscribe("correlation:r_race") + fake_conn.deliver( + "device-connect.default.cam-001.event.async_reply.r_race", + {"correlation_id": "r_race", "device_id": "cam-001", "ordinal": 1}, + ) + first = sub.read() + assert len(first) == 1 + # Now simulate a late-arriving append into the same inbox AFTER + # the previous read drained the prefix. + fake_conn.deliver( + "device-connect.default.cam-002.event.async_reply.r_race", + {"correlation_id": "r_race", "device_id": "cam-002", "ordinal": 2}, + ) + second = sub.read() + assert len(second) == 1 + assert second[0]["device_id"] == "cam-002" + sub.close() + # -- await_replies -------------------------------------------------- diff --git a/packages/device-connect-edge/device_connect_edge/device.py b/packages/device-connect-edge/device_connect_edge/device.py index b64d443..96e31d4 100644 --- a/packages/device-connect-edge/device_connect_edge/device.py +++ b/packages/device-connect-edge/device_connect_edge/device.py @@ -1144,17 +1144,22 @@ async def _broadcast_subscription(self) -> None: "correlation_id": "br-abc123", "function": "capture_image", "params": {"resolution": "4k"}, - "target_device_ids": ["cam-001", "cam-002"], // pre-resolved - "where": "status.battery > 50", // optional CEL - "bindings": {"mask": [[0,1],[1,0]]}, // optional - "fire_at": 1234567890.5, // optional, epoch s - "on_late": "skip" // skip|fire + "targets": ["cam-001", "cam-002"], // pre-resolved + "where": "status.battery > 50", // optional CEL + "bindings": {"mask": [[0,1],[1,0]]}, // optional + "fire_at": 1234567890.5, // optional, epoch s + "on_late": "skip" // skip|fire } On match, the device executes the function and emits a reply on ``device-connect...event.async_reply.`` with ``{correlation_id, device_id, success, result|error, actually_fired_at}``. + + The envelope is processed in a tracked task so the subscription + loop does not block on ``fire_at`` sleeps or long-running driver + functions; subsequent broadcasts can continue to land while an + earlier one is in flight. """ subj = f"device-connect.{self.tenant}.broadcast" @@ -1169,115 +1174,151 @@ async def on_msg(data: bytes, reply_subject: Optional[str]): if not correlation_id: return - # Self-election step 1: target_device_ids gate (pre-resolved by - # the dispatcher from the selector). When absent or empty, treat - # the broadcast as fleet-wide. - targets = envelope.get("target_device_ids") or [] + # Cheap self-election: target gate (pre-resolved by the dispatcher + # from the selector). When absent or empty, treat as fleet-wide. + targets = envelope.get("targets") or [] if targets and self.device_id not in targets: return - function_name = envelope.get("function") - if not function_name: + if not envelope.get("function"): return - params_dict = envelope.get("params", {}) or {} - # Self-election step 2: where predicate against {identity, labels, - # status, bindings}. A failed compile or eval is treated as - # fail-closed (do not execute). - where_expr = envelope.get("where") - if where_expr: - try: - from device_connect_edge.predicate import compile_where - predicate = compile_where(where_expr) - caps = self._driver.capabilities if self._driver else self.capabilities - status = self._driver.status if self._driver else None - labels = (caps.labels if caps and caps.labels else {}) or {} - status_dict = ( - status.model_dump() if status and hasattr(status, "model_dump") else {} - ) - # Mirror the legacy DeviceStatus.location into labels so - # ``labels.location`` works in predicates without the driver - # having to declare it explicitly. Matches the dispatcher-side - # flatten_device contract. - if "location" not in labels and status_dict.get("location"): - labels = {**labels, "location": status_dict["location"]} - context = { - "identity": ( - caps.identity.model_dump() - if caps and getattr(caps, "identity", None) else {} - ), - "labels": labels, - "status": status_dict, - "bindings": envelope.get("bindings", {}) or {}, - } - if not predicate.evaluate(context): - return - except Exception as e: - self._logger.warning( - "Broadcast %s: where predicate failed (skipping): %s", - correlation_id, e, - ) - return + # Hand off to a tracked task. The task owns the where evaluation, + # the fire_at sleep, and the driver call, so this callback returns + # immediately and the messaging subscription stays drained. + self._track_task(asyncio.create_task( + self._handle_broadcast_envelope(envelope, correlation_id) + )) - # fire_at: hold the message until the wall-clock deadline. The - # on_late policy decides what to do if the message arrives past - # the deadline (skip preserves coherence; fire runs anyway). - fire_at = envelope.get("fire_at") - on_late = envelope.get("on_late", "skip") - if fire_at is not None: - delay = float(fire_at) - time.time() - if delay < 0 and on_late == "skip": - self._logger.info( - "Broadcast %s arrived %.3fs late, on_late=skip", - correlation_id, -delay, - ) - return - if delay > 0: - await asyncio.sleep(delay) + await self.messaging.subscribe(subj, callback=on_msg) + self._logger.info("Subscribed to broadcasts on %s", subj) - # Execute the driver function and emit the reply. - actually_fired_at = time.time() - reply_subj = ( - f"device-connect.{self.tenant}.{self.device_id}" - f".event.async_reply.{correlation_id}" - ) - try: - if self._driver is None: - raise RuntimeError("no driver configured") - driver_functions = self._driver._get_functions() - if function_name not in driver_functions: - raise RuntimeError(f"unknown function: {function_name}") - result = await self._driver.invoke(function_name, **params_dict) - reply_payload = { - "correlation_id": correlation_id, - "device_id": self.device_id, - "success": True, - "result": result, - "actually_fired_at": actually_fired_at, - } - except Exception as e: - self._logger.warning( - "Broadcast %s: function %s failed: %s", - correlation_id, function_name, e, - ) - reply_payload = { - "correlation_id": correlation_id, - "device_id": self.device_id, - "success": False, - "error": {"code": "invoke_failed", "message": str(e)}, - "actually_fired_at": actually_fired_at, - } - try: - await self.messaging.publish( - reply_subj, json.dumps(reply_payload).encode(), - ) - except Exception as e: # pragma: no cover - self._logger.warning( - "Broadcast %s: reply publish failed: %s", correlation_id, e, + + async def _handle_broadcast_envelope( + self, envelope: Dict[str, Any], correlation_id: str, + ) -> None: + """Process one broadcast envelope: evaluate where, honour fire_at, invoke, reply. + + Runs in its own task so a long-held ``fire_at`` or slow driver + function does not block the subscription callback from accepting + subsequent broadcasts. + """ + function_name = envelope.get("function") + params_dict = envelope.get("params", {}) or {} + + # Step 1: where predicate against {identity, labels, status, bindings}. + # A failed compile or eval is treated as fail-closed (do not execute); + # the message is logged at WARNING with the correlation_id so an + # operator can correlate a silent skip with a misspelled label key. + where_expr = envelope.get("where") + if where_expr and not self._evaluate_where( + where_expr, envelope.get("bindings"), correlation_id, + ): + return + + # Step 2: fire_at hold. The on_late policy decides what to do when + # the message arrives past the deadline (skip preserves coherence; + # fire runs anyway). + fire_at = envelope.get("fire_at") + on_late = envelope.get("on_late", "skip") + if fire_at is not None: + delay = float(fire_at) - time.time() + if delay < 0 and on_late == "skip": + self._logger.info( + "Broadcast %s arrived %.3fs late, on_late=skip", + correlation_id, -delay, ) + return + if delay > 0: + await asyncio.sleep(delay) - await self.messaging.subscribe(subj, callback=on_msg) - self._logger.info("Subscribed to broadcasts on %s", subj) + # Step 3: execute and reply. + actually_fired_at = time.time() + reply_subj = ( + f"device-connect.{self.tenant}.{self.device_id}" + f".event.async_reply.{correlation_id}" + ) + try: + if self._driver is None: + raise RuntimeError("no driver configured") + driver_functions = self._driver._get_functions() + if function_name not in driver_functions: + raise RuntimeError(f"unknown function: {function_name}") + result = await self._driver.invoke(function_name, **params_dict) + reply_payload: Dict[str, Any] = { + "correlation_id": correlation_id, + "device_id": self.device_id, + "success": True, + "result": result, + "actually_fired_at": actually_fired_at, + } + except Exception as e: + self._logger.warning( + "Broadcast %s: function %s failed: %s", + correlation_id, function_name, e, + ) + reply_payload = { + "correlation_id": correlation_id, + "device_id": self.device_id, + "success": False, + "error": {"code": "invoke_failed", "message": str(e)}, + "actually_fired_at": actually_fired_at, + } + try: + await self.messaging.publish( + reply_subj, json.dumps(reply_payload).encode(), + ) + except Exception as e: # pragma: no cover + self._logger.warning( + "Broadcast %s: reply publish failed: %s", correlation_id, e, + ) + + + def _evaluate_where( + self, + where_expr: str, + bindings: Optional[Dict[str, Any]], + correlation_id: str, + ) -> bool: + """Compile and evaluate a where predicate; return True iff it passes. + + Returns False (do not execute) on compile or eval errors, logging + a warning so silent self-deselection is operator-visible. + """ + try: + from device_connect_edge.predicate import compile_where + predicate = compile_where(where_expr) + caps = self._driver.capabilities if self._driver else self.capabilities + status = self._driver.status if self._driver else None + labels = (caps.labels if caps and caps.labels else {}) or {} + status_dict = ( + status.model_dump() if status and hasattr(status, "model_dump") else {} + ) + # Mirror DeviceStatus.location into labels so ``labels.location`` + # works in predicates without the driver having to declare it + # explicitly. Matches the dispatcher-side flatten_device contract. + if "location" not in labels and status_dict.get("location"): + labels = {**labels, "location": status_dict["location"]} + # The DeviceIdentity model carries device_type / manufacturer / + # model / firmware_version but NOT device_id (which lives on the + # runtime). Splice it in so predicates can write the natural + # ``identity.device_id == "..."``. + identity_dict: Dict[str, Any] = {"device_id": self.device_id} + if caps and getattr(caps, "identity", None): + identity_dict.update(caps.identity.model_dump()) + context = { + "identity": identity_dict, + "labels": labels, + "status": status_dict, + "bindings": bindings or {}, + } + return bool(predicate.evaluate(context)) + except Exception as e: + self._logger.warning( + "Broadcast %s: where predicate failed (skipping): %s", + correlation_id, e, + ) + return False async def _event_dispatch_loop(self) -> None: diff --git a/packages/device-connect-edge/device_connect_edge/predicate.py b/packages/device-connect-edge/device_connect_edge/predicate.py index 6ddc7c0..5bf5ff6 100644 --- a/packages/device-connect-edge/device_connect_edge/predicate.py +++ b/packages/device-connect-edge/device_connect_edge/predicate.py @@ -15,7 +15,7 @@ bindings shared payload supplied by the caller (selection masks, thresholds, lookup tables) -Examples (every example here ships with v4 spec):: +Examples:: battery > 50 labels.category == "camera" && status.battery > 50 diff --git a/packages/device-connect-server/device_connect_server/statectl/operations_cli.py b/packages/device-connect-server/device_connect_server/statectl/operations_cli.py index 630709a..7ddc9ef 100644 --- a/packages/device-connect-server/device_connect_server/statectl/operations_cli.py +++ b/packages/device-connect-server/device_connect_server/statectl/operations_cli.py @@ -109,7 +109,13 @@ def run_invoke_many(args: Any) -> int: llm_reasoning=args.reason, ) print(_pretty(result)) - return 0 if "error" not in result else 1 + # Exit non-zero on a top-level error OR when any target failed, so + # shell pipelines can detect partial failure without parsing JSON. + if "error" in result: + return 1 + if result.get("failed", 0) > 0: + return 3 + return 0 finally: try: disconnect() @@ -154,21 +160,30 @@ def run_subscribe(args: Any) -> int: Each message is printed as one JSON line so the output can be piped into ``jq`` or grep. Runs until ``--timeout`` of idle silence elapses or ``--until`` messages have been printed (whichever comes first). + Exit codes: + 0 one or more messages were printed + 4 idle-timeout reached with zero messages + 130 interrupted with Ctrl-C """ from device_connect_agent_tools import disconnect, subscribe _connect(getattr(args, "broker", None)) + count = 0 try: - count = 0 with subscribe(args.selector) as sub: - for msg in sub.iter( - timeout=float(args.timeout), poll_interval=0.05, - ): - print(json.dumps(msg, default=str)) - count += 1 - if args.until is not None and count >= int(args.until): - break - return 0 + try: + for msg in sub.iter( + timeout=float(args.timeout), poll_interval=0.05, + ): + print(json.dumps(msg, default=str)) + count += 1 + if args.until is not None and count >= int(args.until): + break + except KeyboardInterrupt: + # Clean exit on Ctrl-C: the ``with`` block tears the + # subscription down before this returns. + return 130 + return 0 if count > 0 else 4 finally: try: disconnect() diff --git a/tests/tests/test_tools_broadcast.py b/tests/tests/test_tools_broadcast.py index 0e7413f..975016e 100644 --- a/tests/tests/test_tools_broadcast.py +++ b/tests/tests/test_tools_broadcast.py @@ -183,6 +183,151 @@ async def test_broadcast_fire_at_late_with_skip_drops( await asyncio.to_thread(disconnect) +@pytest.mark.asyncio +@pytest.mark.integration +async def test_broadcast_where_with_bindings(device_spawner, messaging_url): + """A where predicate that reads bindings. self-elects per-target.""" + pytest.importorskip("celpy") + await device_spawner.spawn_camera("itest-bcbnd-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-bcbnd-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import ( + await_replies, broadcast, disconnect, + ) + + await _wait_for_devices( + messaging_url, {"itest-bcbnd-cam-1", "itest-bcbnd-cam-2"} + ) + try: + # Allowlist sent in bindings; the predicate uses bindings.allow to + # select. Devices not in the allowlist self-deselect silently. + result = await asyncio.to_thread( + broadcast, + "device(itest-bcbnd-cam-*).function(capture_image)", + None, + "identity.device_id in bindings.allow", + {"allow": ["itest-bcbnd-cam-1"]}, + ) + assert result["candidates"] == 2 + replies = await asyncio.to_thread( + await_replies, result["correlation_id"], timeout=3.0, + ) + ids = {r["device_id"] for r in replies} + assert ids == {"itest-bcbnd-cam-1"} + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_await_replies_until_stops_early(device_spawner, messaging_url): + """``await_replies`` returns once ``until`` replies have arrived.""" + await device_spawner.spawn_camera("itest-awu-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-awu-cam-2", location="lab-A") + await device_spawner.spawn_camera("itest-awu-cam-3", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import ( + await_replies, broadcast, disconnect, + ) + + await _wait_for_devices( + messaging_url, {"itest-awu-cam-1", "itest-awu-cam-2", "itest-awu-cam-3"} + ) + try: + result = await asyncio.to_thread( + broadcast, "device(itest-awu-cam-*).function(capture_image)", + ) + assert result["candidates"] == 3 + # until=1 should let us return after the first reply arrives even + # though more are coming. + t0 = time.monotonic() + replies = await asyncio.to_thread( + await_replies, result["correlation_id"], + timeout=5.0, until=1, poll_interval=0.02, + ) + elapsed = time.monotonic() - t0 + assert len(replies) >= 1 + # Sanity: returning early should be well under the timeout. + assert elapsed < 2.0, f"await_replies(until=1) took {elapsed:.2f}s" + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_subscribe_iter_protocol(device_spawner, messaging_url): + """``for msg in sub:`` works via Subscription.__iter__.""" + await device_spawner.spawn_camera("itest-subiter-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-subiter-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import broadcast, disconnect, subscribe + + await _wait_for_devices( + messaging_url, {"itest-subiter-cam-1", "itest-subiter-cam-2"} + ) + try: + result = await asyncio.to_thread( + broadcast, + "device(itest-subiter-cam-*).function(capture_image)", + ) + cid = result["correlation_id"] + + def collect(): + # Exercise the bare ``for msg in sub:`` form (uses __iter__). + # Break after both expected replies arrive so the test stays + # bounded regardless of the default idle timeout. + with subscribe(f"correlation:{cid}") as sub: + gathered: list[dict] = [] + for msg in sub: + gathered.append(msg) + if len(gathered) >= 2: + break + return gathered + + replies = await asyncio.to_thread(collect) + ids = {r["device_id"] for r in replies} + assert ids == {"itest-subiter-cam-1", "itest-subiter-cam-2"} + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_subscribe_event_selector_live_stream(device_spawner, messaging_url): + """subscribe(event()) receives live events from matching devices.""" + device, driver = await device_spawner.spawn_camera( + "itest-evsub-cam", location="lab-A", + ) + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, subscribe + + await _wait_for_devices(messaging_url, {"itest-evsub-cam"}) + try: + with subscribe("device(itest-evsub-cam).event(object_detected)") as sub: + await asyncio.sleep(SETTLE_TIME) # let subscription warm up + await driver.trigger_event( + "object_detected", + {"label": "person", "confidence": 0.95}, + ) + msgs = await asyncio.to_thread( + list, sub.iter(timeout=2.0, poll_interval=0.05), + ) + # The event arrives via the JSON-RPC event subject; payload is + # under either ``params`` or top-level depending on transport. + matching = [ + m for m in msgs + if (m.get("params") or {}).get("label") == "person" + or m.get("label") == "person" + ] + assert matching, f"no object_detected events received: {msgs}" + finally: + await asyncio.to_thread(disconnect) + + @pytest.mark.asyncio @pytest.mark.integration async def test_subscribe_correlation_form(device_spawner, messaging_url): From 8660801eeb511b8bb9a35e959bac8b7ad01b7caa Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Sun, 10 May 2026 16:44:50 -0700 Subject: [PATCH 13/21] feat(adapters): expose broadcast and await_replies via all three adapters Phases 4-5 added broadcast() and await_replies() to the agent-tools surface but the adapter migration in feat(invoke) only carried invoke / invoke_many across. The flashlight-auditorium demo needs the LLM to issue selector-driven broadcasts with where + bindings + fire_at, so broadcast and await_replies both need to be Strands/LangChain/Claude tools as well. Tool descriptions for the Claude adapter spell out the broadcast + await_replies pairing (caller fires broadcast, then awaits replies by correlation_id) so agents discover the workflow from the tool docs. subscribe() is intentionally NOT exposed via the adapters: it returns a Subscription handle that does not serialise cleanly as a tool result and is more natural to call from operator code or the CLI than from an LLM. Agents needing the same shape use broadcast + await_replies. --- .../adapters/claude.py | 55 +++++++++++++++++++ .../adapters/langchain.py | 6 ++ .../adapters/strands.py | 6 ++ .../tests/test_claude_adapter.py | 2 + .../tests/test_langchain_adapter.py | 2 + .../tests/test_strands_adapter.py | 2 + 6 files changed, 73 insertions(+) diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py index 9dd08d8..f4a2883 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py @@ -47,6 +47,8 @@ async def main(): discover_devices as _discover_devices, invoke as _invoke, invoke_many as _invoke_many, + broadcast as _broadcast, + await_replies as _await_replies, invoke_device_with_fallback as _invoke_device_with_fallback, get_device_status as _get_device_status, ) @@ -147,6 +149,55 @@ async def invoke_many(args: dict[str, Any]) -> dict[str, Any]: ) +@tool( + "broadcast", + "Async selector-driven fan-out. Returns immediately with a " + "correlation_id; replies stream on a per-device subject keyed by id. " + "Each candidate self-elects via the optional CEL `where` predicate " + "(evaluated at the edge against identity/labels/status/bindings) and " + "executes the function. Use fire_at (wall-clock epoch seconds) + " + "on_late (skip|fire) for synchronized fan-out. Pair with " + "await_replies(correlation_id) to collect outcomes.", + { + "selector": str, "params": dict, "where": str, "bindings": dict, + "fire_at": float, "on_late": str, "llm_reasoning": str, + }, +) +async def broadcast(args: dict[str, Any]) -> dict[str, Any]: + return _text( + _broadcast( + selector=args["selector"], + params=args.get("params"), + where=args.get("where"), + bindings=args.get("bindings"), + fire_at=args.get("fire_at"), + on_late=args.get("on_late", "skip"), + llm_reasoning=args.get("llm_reasoning"), + ) + ) + + +@tool( + "await_replies", + "Collect replies for a broadcast() call. Subscribes to the " + "correlation reply subject, drains for up to `timeout` seconds (or " + "until `until` replies have arrived), then returns the list.", + { + "correlation_id": str, "timeout": float, "until": int, + "poll_interval": float, + }, +) +async def await_replies(args: dict[str, Any]) -> dict[str, Any]: + return _text( + _await_replies( + correlation_id=args["correlation_id"], + timeout=float(args.get("timeout", 10.0)), + until=int(args["until"]) if args.get("until") is not None else None, + poll_interval=float(args.get("poll_interval", 0.05)), + ) + ) + + # Other invocation helpers @@ -206,6 +257,8 @@ def create_device_connect_server(name: str = "device-connect"): discover, invoke, invoke_many, + broadcast, + await_replies, invoke_device_with_fallback, get_device_status, discover_devices, @@ -218,6 +271,8 @@ def create_device_connect_server(name: str = "device-connect"): "discover", "invoke", "invoke_many", + "broadcast", + "await_replies", "invoke_device_with_fallback", "get_device_status", "discover_devices", diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py index 35d5e51..c18ed7e 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py @@ -26,6 +26,8 @@ discover_devices as _discover_devices, invoke as _invoke, invoke_many as _invoke_many, + broadcast as _broadcast, + await_replies as _await_replies, invoke_device_with_fallback as _invoke_device_with_fallback, get_device_status as _get_device_status, ) @@ -37,6 +39,8 @@ # Selector-driven invocation (recommended) invoke = StructuredTool.from_function(_invoke) invoke_many = StructuredTool.from_function(_invoke_many) +broadcast = StructuredTool.from_function(_broadcast) +await_replies = StructuredTool.from_function(_await_replies) # Other invocation helpers invoke_device_with_fallback = StructuredTool.from_function(_invoke_device_with_fallback) @@ -50,6 +54,8 @@ "discover", "invoke", "invoke_many", + "broadcast", + "await_replies", "invoke_device_with_fallback", "get_device_status", "discover_devices", diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py index d22fcf7..b68c16b 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py @@ -27,6 +27,8 @@ discover_devices as _discover_devices, invoke as _invoke, invoke_many as _invoke_many, + broadcast as _broadcast, + await_replies as _await_replies, invoke_device_with_fallback as _invoke_device_with_fallback, get_device_status as _get_device_status, ) @@ -38,6 +40,8 @@ # Selector-driven invocation (recommended) invoke = strands_tool(_invoke) invoke_many = strands_tool(_invoke_many) +broadcast = strands_tool(_broadcast) +await_replies = strands_tool(_await_replies) # Other invocation helpers invoke_device_with_fallback = strands_tool(_invoke_device_with_fallback) @@ -51,6 +55,8 @@ "discover", "invoke", "invoke_many", + "broadcast", + "await_replies", "invoke_device_with_fallback", "get_device_status", "discover_devices", diff --git a/packages/device-connect-agent-tools/tests/test_claude_adapter.py b/packages/device-connect-agent-tools/tests/test_claude_adapter.py index 311aab5..4960a49 100644 --- a/packages/device-connect-agent-tools/tests/test_claude_adapter.py +++ b/packages/device-connect-agent-tools/tests/test_claude_adapter.py @@ -70,6 +70,8 @@ def _mock_sdk_and_connection(): "discover_devices", "invoke", "invoke_many", + "broadcast", + "await_replies", "invoke_device_with_fallback", "get_device_status", ) diff --git a/packages/device-connect-agent-tools/tests/test_langchain_adapter.py b/packages/device-connect-agent-tools/tests/test_langchain_adapter.py index c4a487e..9aae070 100644 --- a/packages/device-connect-agent-tools/tests/test_langchain_adapter.py +++ b/packages/device-connect-agent-tools/tests/test_langchain_adapter.py @@ -74,6 +74,8 @@ def _mock_langchain_and_connection(): "discover", "invoke", "invoke_many", + "broadcast", + "await_replies", "invoke_device_with_fallback", "get_device_status", "discover_devices", diff --git a/packages/device-connect-agent-tools/tests/test_strands_adapter.py b/packages/device-connect-agent-tools/tests/test_strands_adapter.py index 30d1ae0..4e46ceb 100644 --- a/packages/device-connect-agent-tools/tests/test_strands_adapter.py +++ b/packages/device-connect-agent-tools/tests/test_strands_adapter.py @@ -57,6 +57,8 @@ def _mock_strands_and_connection(): "discover", "invoke", "invoke_many", + "broadcast", + "await_replies", "invoke_device_with_fallback", "get_device_status", "discover_devices", From 08c9aa7d69029c88963c251c07b146f5b4f4cc6b Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Sun, 10 May 2026 18:45:51 -0700 Subject: [PATCH 14/21] fix(broadcast): read identity from driver, not from DeviceCapabilities The broadcast handler built the where-predicate context from ``caps.identity`` -- but DeviceCapabilities does not carry an ``identity`` field; that lives on the driver as a separate DeviceIdentity model. The ``getattr(caps, "identity", None)`` fallback masked the bug: identity_dict was always just ``{"device_id": ...}`` with none of the driver's extra fields (seat_row, seat_col, x-mhp slot metadata, ...) reaching the predicate. Symptom: a where predicate like ``bindings.mask[identity.seat_row][identity.seat_col] == 1`` failed at every candidate (CEL surfaces undefined field access as CELEvalError, fail-closed fires, nobody self-elects). Fix: read identity from ``self._driver.identity`` and splice in ``device_id`` from the runtime. Backwards-compatible with drivers that don't expose an identity property (driver_identity is None -> only device_id is present, same as before for those drivers). Surfaced while building the flashlight-auditorium demo, where each phone exposes its seat coordinates as extra fields on DeviceIdentity and the spell-CMU broadcast indexes a 2D mask by those coordinates. --- .../device_connect_edge/device.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/packages/device-connect-edge/device_connect_edge/device.py b/packages/device-connect-edge/device_connect_edge/device.py index 96e31d4..c776d1a 100644 --- a/packages/device-connect-edge/device_connect_edge/device.py +++ b/packages/device-connect-edge/device_connect_edge/device.py @@ -1299,13 +1299,18 @@ def _evaluate_where( # explicitly. Matches the dispatcher-side flatten_device contract. if "location" not in labels and status_dict.get("location"): labels = {**labels, "location": status_dict["location"]} - # The DeviceIdentity model carries device_type / manufacturer / - # model / firmware_version but NOT device_id (which lives on the - # runtime). Splice it in so predicates can write the natural - # ``identity.device_id == "..."``. + # DeviceIdentity is exposed by the driver, not by DeviceCapabilities; + # they are independent pydantic models. Read identity from the + # driver so extra fields (seat_row, seat_col, x-mhp metadata, ...) + # reach the predicate context. Splice in device_id which lives on + # the runtime so predicates can write + # ``identity.device_id == "..."`` naturally. identity_dict: Dict[str, Any] = {"device_id": self.device_id} - if caps and getattr(caps, "identity", None): - identity_dict.update(caps.identity.model_dump()) + driver_identity = ( + getattr(self._driver, "identity", None) if self._driver else None + ) + if driver_identity is not None and hasattr(driver_identity, "model_dump"): + identity_dict.update(driver_identity.model_dump()) context = { "identity": identity_dict, "labels": labels, From 641b7bd380d3f8f36c6b5dd56cb3a636ab87faeb Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Mon, 11 May 2026 08:18:41 -0700 Subject: [PATCH 15/21] feat: add device mandate enforcement --- docs/device-mandates-spec.md | 51 ++++ .../device_connect_agent_tools/tools.py | 45 +++- .../tests/test_agent_mandates.py | 117 +++++++++ .../device_connect_edge/__init__.py | 12 + .../device_connect_edge/device.py | 66 ++++- .../device_connect_edge/drivers/__init__.py | 2 + .../device_connect_edge/drivers/base.py | 2 + .../device_connect_edge/drivers/decorators.py | 10 + .../device_connect_edge/mandates.py | 247 ++++++++++++++++++ .../device_connect_edge/types.py | 6 + .../tests/test_device_mandates.py | 179 +++++++++++++ .../device-connect-edge/tests/test_drivers.py | 20 +- .../tests/test_mandate_verifier.py | 144 ++++++++++ .../device-connect-edge/tests/test_types.py | 8 + 14 files changed, 901 insertions(+), 8 deletions(-) create mode 100644 docs/device-mandates-spec.md create mode 100644 packages/device-connect-agent-tools/tests/test_agent_mandates.py create mode 100644 packages/device-connect-edge/device_connect_edge/mandates.py create mode 100644 packages/device-connect-edge/tests/test_device_mandates.py create mode 100644 packages/device-connect-edge/tests/test_mandate_verifier.py diff --git a/docs/device-mandates-spec.md b/docs/device-mandates-spec.md new file mode 100644 index 0000000..7e817fd --- /dev/null +++ b/docs/device-mandates-spec.md @@ -0,0 +1,51 @@ +# Spec: Device Mandates + +## Objective + +Add an optional verifiable authorization layer for Device Connect RPC execution. A device function can declare that it requires a Device Mandate, and the runtime refuses to execute protected RPCs unless the caller presents a signed mandate that authorizes the target device, method, parameters, and validity window. + +The first implementation slice proves the contract end to end with a lightweight HMAC-backed mandate format suitable for local tests and demos. The verifier is intentionally small and pluggable so a later slice can replace or augment the credential format with UCAN, Biscuit, or a standards-track profile without changing the decorator or RPC metadata contract. + +## Commands + +- Edge tests: `pytest packages/device-connect-edge/tests -q` +- Agent tools tests: `pytest packages/device-connect-agent-tools/tests -q` +- Focused mandate tests: `pytest packages/device-connect-edge/tests/test_mandate_verifier.py packages/device-connect-edge/tests/test_device_mandates.py packages/device-connect-agent-tools/tests/test_agent_mandates.py -q` + +## Project Structure + +- `packages/device-connect-edge/device_connect_edge/mandates.py`: mandate data helpers, signing, and verification. +- `packages/device-connect-edge/device_connect_edge/drivers/decorators.py`: `@requires_mandate` decorator metadata. +- `packages/device-connect-edge/device_connect_edge/device.py`: runtime enforcement before driver invocation. +- `packages/device-connect-edge/device_connect_edge/types.py`: function capability metadata for mandate requirements. +- `packages/device-connect-agent-tools/device_connect_agent_tools/tools.py`: pass mandate metadata through `_dc_meta`. + +## Testing Strategy + +Use test-driven slices: + +- Pure unit tests for signing, verification, time windows, device/method binding, numeric constraints, tamper detection, and replay denial. +- Runtime tests for protected RPC denial before driver execution and successful execution with a valid mandate. +- Agent-tools tests that verify `invoke`, `invoke_many`, `broadcast`, and legacy `invoke_device` attach mandate data inside `_dc_meta`. + +## Boundaries + +- Always: fail closed for protected methods; keep mandate support optional for unprotected methods; preserve existing unprotected RPC behavior. +- Ask first: adding non-stdlib crypto/credential dependencies; changing transport protocols; adding persistent receipt storage; modifying CI. +- Never: treat unsigned client-provided mandate dictionaries as valid; pass `_dc_meta` into user driver methods; weaken existing ACL/TLS/JWT checks. + +## Success Criteria + +- A driver can mark an RPC with `@requires_mandate(scope="actuation")`. +- Discovery/capability metadata shows mandate requirements for protected functions. +- Direct JSON-RPC and broadcast execution reject protected functions with no mandate, invalid signature, wrong device, wrong method, expired mandate, or out-of-range parameters. +- Direct JSON-RPC and broadcast execution allow a protected function with a valid closed mandate. +- Agent tools can attach mandate data to invoke paths through `_dc_meta`. +- Existing unprotected RPC tests continue to pass. + +## Open Questions + +- Which production credential format should be the default: UCAN, Biscuit, or a future AP2-compatible non-payment profile? +- Where should production principal keys live: OS keystore, HSM/KMS, commissioning bundle, or registry-backed trust store? +- Should execution receipts be persisted first in the server state store or emitted as signed events before storage is added? +- Should replay protection be in-memory per device for v0, or backed by the server state layer for distributed deployments? diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py b/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py index c99c5bc..a9a8e5b 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py @@ -542,10 +542,30 @@ def _shape_invoke_response( } +def _clean_params_with_mandate( + params: dict[str, Any] | None, + mandate: dict[str, Any] | None, +) -> dict[str, Any]: + clean = {k: v for k, v in (params or {}).items() if k != "llm_reasoning"} + if mandate is None: + return clean + meta = clean.get("_dc_meta") + if meta is None: + meta = {} + elif not isinstance(meta, dict): + raise ValueError("_dc_meta must be an object when mandate is provided") + else: + meta = dict(meta) + meta["mandate"] = mandate + clean["_dc_meta"] = meta + return clean + + def invoke( selector: str, params: dict[str, Any] | None = None, llm_reasoning: str | None = None, + mandate: dict[str, Any] | None = None, ) -> dict[str, Any]: """Resolve a selector to one (device, function) tuple and invoke it. @@ -611,7 +631,7 @@ def invoke( try: conn = get_connection() - clean = {k: v for k, v in (params or {}).items() if k != "llm_reasoning"} + clean = _clean_params_with_mandate(params, mandate) response = conn.invoke(device_id, function_name, params=clean) except Exception as e: logger.error( @@ -633,6 +653,7 @@ def invoke_many( timeout: float = DEFAULT_INVOKE_TIMEOUT, max_concurrency: int = DEFAULT_INVOKE_CONCURRENCY, llm_reasoning: str | None = None, + mandate: dict[str, Any] | None = None, ) -> dict[str, Any]: """Resolve a selector to (device, function) tuples and invoke each in parallel. @@ -683,7 +704,13 @@ def invoke_many( return out workers = max(1, min(max_concurrency, len(rows))) - clean = {k: v for k, v in (params or {}).items() if k != "llm_reasoning"} + try: + clean = _clean_params_with_mandate(params, mandate) + except ValueError as e: + return { + "candidates": len(rows), "matched": len(rows), "succeeded": 0, "failed": len(rows), + "results": [], "errors": [], "error": _error("invalid_params", str(e)), + } def call_one(row: dict) -> dict[str, Any]: device_id = row.get("device_id") or "" @@ -735,6 +762,7 @@ def broadcast( fire_at: float | None = None, on_late: str = "skip", llm_reasoning: str | None = None, + mandate: dict[str, Any] | None = None, ) -> dict[str, Any]: """Async selector-driven fan-out. Returns immediately with a correlation id. @@ -832,9 +860,13 @@ def broadcast( targets = sorted({ row.get("device_id") for row in rows if row.get("device_id") }) - clean_params = { - k: v for k, v in (params or {}).items() if k != "llm_reasoning" - } + try: + clean_params = _clean_params_with_mandate(params, mandate) + except ValueError as e: + return { + "candidates": len(targets), + "error": _error("invalid_params", str(e)), + } envelope: dict[str, Any] = { "correlation_id": correlation_id, @@ -1327,6 +1359,7 @@ def invoke_device( function: str, params: dict[str, Any] | None = None, llm_reasoning: str | None = None, + mandate: dict[str, Any] | None = None, ) -> dict[str, Any]: """Call a function on a Device Connect device (deprecated; use invoke()). @@ -1349,7 +1382,7 @@ def invoke_device( try: conn = get_connection() - clean = {k: v for k, v in (params or {}).items() if k != "llm_reasoning"} + clean = _clean_params_with_mandate(params, mandate) response = conn.invoke(device_id, function, params=clean) if "error" in response: diff --git a/packages/device-connect-agent-tools/tests/test_agent_mandates.py b/packages/device-connect-agent-tools/tests/test_agent_mandates.py new file mode 100644 index 0000000..a889dec --- /dev/null +++ b/packages/device-connect-agent-tools/tests/test_agent_mandates.py @@ -0,0 +1,117 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Agent-tool tests for carrying Device Mandates in _dc_meta.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from device_connect_agent_tools import tools as tools_mod + + +SAMPLE_DEVICES = [ + { + "device_id": "lock-001", + "device_type": "lock", + "status": {"state": "online"}, + "identity": {"device_type": "lock"}, + "labels": {"category": "lock"}, + "functions": [ + { + "name": "unlock", + "parameters": {}, + "labels": {"direction": "write", "safety": "critical"}, + "mandate": {"required": True, "scope": "actuation"}, + }, + ], + "events": [], + }, + { + "device_id": "lock-002", + "device_type": "lock", + "status": {"state": "online"}, + "identity": {"device_type": "lock"}, + "labels": {"category": "lock"}, + "functions": [ + { + "name": "unlock", + "parameters": {}, + "labels": {"direction": "write", "safety": "critical"}, + "mandate": {"required": True, "scope": "actuation"}, + }, + ], + "events": [], + }, +] + + +MANDATE = {"format": "device-connect-hmac-v0", "closed": {"id": "closed-1"}} + + +def _conn(): + conn = MagicMock() + conn.list_devices.return_value = SAMPLE_DEVICES + conn.invoke.return_value = {"jsonrpc": "2.0", "id": "1", "result": {"ok": True}} + conn._published = [] + conn.publish_broadcast.side_effect = lambda env: conn._published.append(env) + return conn + + +def test_invoke_attaches_mandate_under_dc_meta(): + conn = _conn() + with patch.object(tools_mod, "get_connection", return_value=conn): + result = tools_mod.invoke( + "device(lock-001).function(unlock)", + params={"duration_s": 30}, + mandate=MANDATE, + ) + + assert result["success"] is True + sent = conn.invoke.call_args.kwargs["params"] + assert sent["duration_s"] == 30 + assert sent["_dc_meta"]["mandate"] == MANDATE + + +def test_invoke_many_attaches_mandate_to_each_call(): + conn = _conn() + with patch.object(tools_mod, "get_connection", return_value=conn): + tools_mod.invoke_many( + "device(category:lock).function(unlock)", + params={"duration_s": 30}, + mandate=MANDATE, + ) + + assert conn.invoke.call_count == 2 + for call in conn.invoke.call_args_list: + assert call.kwargs["params"]["_dc_meta"]["mandate"] == MANDATE + + +def test_broadcast_attaches_mandate_under_params_dc_meta(): + conn = _conn() + with patch.object(tools_mod, "get_connection", return_value=conn): + result = tools_mod.broadcast( + "device(category:lock).function(unlock)", + params={"duration_s": 30}, + mandate=MANDATE, + ) + + assert result["candidates"] == 2 + env = conn._published[0] + assert env["params"]["duration_s"] == 30 + assert env["params"]["_dc_meta"]["mandate"] == MANDATE + + +def test_legacy_invoke_device_attaches_mandate_under_dc_meta(): + conn = _conn() + with patch.object(tools_mod, "get_connection", return_value=conn): + tools_mod.invoke_device( + "lock-001", + "unlock", + params={"duration_s": 30}, + mandate=MANDATE, + ) + + sent = conn.invoke.call_args.kwargs["params"] + assert sent["_dc_meta"]["mandate"] == MANDATE diff --git a/packages/device-connect-edge/device_connect_edge/__init__.py b/packages/device-connect-edge/device_connect_edge/__init__.py index 4812c51..b3a3b4b 100644 --- a/packages/device-connect-edge/device_connect_edge/__init__.py +++ b/packages/device-connect-edge/device_connect_edge/__init__.py @@ -43,6 +43,13 @@ async def alert(self, level: str, msg: str): FunctionDef, EventDef, ) +from device_connect_edge.mandates import ( + MandateInvocationContext, + MandateVerificationResult, + create_closed_mandate, + create_open_mandate, + verify_mandate, +) from device_connect_edge.discovery_provider import DiscoveryProvider from device_connect_edge.registry_client import RegistryClient from device_connect_edge.errors import ( @@ -66,6 +73,11 @@ async def alert(self, level: str, msg: str): "DeviceStatus", "FunctionDef", "EventDef", + "MandateInvocationContext", + "MandateVerificationResult", + "create_closed_mandate", + "create_open_mandate", + "verify_mandate", "DiscoveryProvider", "RegistryClient", "DeviceConnectError", diff --git a/packages/device-connect-edge/device_connect_edge/device.py b/packages/device-connect-edge/device_connect_edge/device.py index c776d1a..8768930 100644 --- a/packages/device-connect-edge/device_connect_edge/device.py +++ b/packages/device-connect-edge/device_connect_edge/device.py @@ -77,6 +77,11 @@ async def capture_image(self, resolution: str = "1080p") -> dict: DeviceIdentity, DeviceStatus, ) +from device_connect_edge.mandates import ( + MandateInvocationContext, + MandateVerificationResult, + verify_mandate, +) # Type checking imports for driver support if TYPE_CHECKING: @@ -163,6 +168,16 @@ def build_rpc_error(id_: str, code: int, msg: str) -> bytes: ).encode() +def _broadcast_error(message: str) -> Dict[str, str]: + code = "invoke_failed" + if ":" in message: + prefix, _, rest = message.partition(":") + if prefix.startswith("mandate_") or prefix in {"invalid_mandate", "unknown_mandate_key"}: + code = prefix + message = rest.strip() or message + return {"code": code, "message": message} + + class DeviceRuntime: """High-level runtime for Device Connect devices. @@ -259,6 +274,7 @@ def __init__( auto_commission: bool = True, commissioning_port: int = 5540, allow_insecure: Optional[bool] = None, + mandate_keys: Optional[Dict[str, Union[bytes, str]]] = None, ) -> None: # Store driver reference and connect driver to this device self._driver = driver @@ -349,6 +365,8 @@ def __init__( self.allow_insecure = os.getenv("DEVICE_CONNECT_ALLOW_INSECURE", "").lower() in ("1", "true", "yes") else: self.allow_insecure = allow_insecure + self._mandate_keys: Dict[str, Union[bytes, str]] = dict(mandate_keys or {}) + self._mandate_replay_cache: set[str] = set() self._factory_identity: Optional[dict] = None # Initialize logger and internal state early (before commissioning checks) @@ -1112,6 +1130,19 @@ async def on_msg(data: bytes, reply_subject: Optional[str]): "device_connect.source_device": source_device or "", }, ): + mandate_result = self._verify_mandate_for_invocation( + method, params_dict, dc_meta, + ) + if not mandate_result.ok: + if reply_subject: + await self.messaging.publish( + reply_subject, + build_rpc_error( + payload["id"], -32041, + mandate_result.message or mandate_result.error_code or "mandate_denied", + ) + ) + return # Pass source_device to driver for logging (existing pattern) if source_device: params_dict["source_device"] = source_device @@ -1135,6 +1166,32 @@ async def on_msg(data: bytes, reply_subject: Optional[str]): self._logger.info("Subscribed to commands on %s", subj) + def _verify_mandate_for_invocation( + self, + function_name: str, + params: Dict[str, Any], + dc_meta: Optional[Dict[str, Any]], + ) -> MandateVerificationResult: + """Verify mandate metadata when a function declares it is required.""" + if self._driver is None: + return MandateVerificationResult(ok=True) + method = self._driver._get_functions().get(function_name) + mandate_policy = getattr(method, "_mandate", None) + if not mandate_policy or not mandate_policy.get("required"): + return MandateVerificationResult(ok=True) + meta = dc_meta if isinstance(dc_meta, dict) else {} + return verify_mandate( + meta.get("mandate"), + context=MandateInvocationContext( + device_id=self.device_id, + method=function_name, + params=params, + ), + key_resolver=self._mandate_keys.get, + replay_cache=self._mandate_replay_cache, + ) + + async def _broadcast_subscription(self) -> None: """Subscribe to selector-driven broadcasts and self-elect to handle. @@ -1205,6 +1262,7 @@ async def _handle_broadcast_envelope( """ function_name = envelope.get("function") params_dict = envelope.get("params", {}) or {} + dc_meta = params_dict.pop("_dc_meta", {}) # Step 1: where predicate against {identity, labels, status, bindings}. # A failed compile or eval is treated as fail-closed (do not execute); @@ -1244,6 +1302,12 @@ async def _handle_broadcast_envelope( driver_functions = self._driver._get_functions() if function_name not in driver_functions: raise RuntimeError(f"unknown function: {function_name}") + mandate_result = self._verify_mandate_for_invocation( + function_name, params_dict, dc_meta, + ) + if not mandate_result.ok: + code = mandate_result.error_code or "mandate_denied" + raise RuntimeError(f"{code}: {mandate_result.message}") result = await self._driver.invoke(function_name, **params_dict) reply_payload: Dict[str, Any] = { "correlation_id": correlation_id, @@ -1261,7 +1325,7 @@ async def _handle_broadcast_envelope( "correlation_id": correlation_id, "device_id": self.device_id, "success": False, - "error": {"code": "invoke_failed", "message": str(e)}, + "error": _broadcast_error(str(e)), "actually_fired_at": actually_fired_at, } try: diff --git a/packages/device-connect-edge/device_connect_edge/drivers/__init__.py b/packages/device-connect-edge/device_connect_edge/drivers/__init__.py index 3abdf1d..ac996b8 100644 --- a/packages/device-connect-edge/device_connect_edge/drivers/__init__.py +++ b/packages/device-connect-edge/device_connect_edge/drivers/__init__.py @@ -37,6 +37,7 @@ async def motion_detected(self, zone: str, confidence: float): emit, before_emit, periodic, + requires_mandate, build_function_schema, build_event_schema, ) @@ -55,6 +56,7 @@ async def motion_detected(self, zone: str, confidence: float): "emit", "before_emit", "periodic", + "requires_mandate", "on", "build_function_schema", "build_event_schema", diff --git a/packages/device-connect-edge/device_connect_edge/drivers/base.py b/packages/device-connect-edge/device_connect_edge/drivers/base.py index 6fc013d..8c01d4b 100644 --- a/packages/device-connect-edge/device_connect_edge/drivers/base.py +++ b/packages/device-connect-edge/device_connect_edge/drivers/base.py @@ -389,12 +389,14 @@ def _collect_functions(self) -> List[FunctionDef]: description = getattr(attr, "_description", "") parameters = build_function_schema(attr) labels = getattr(attr, "_labels", None) + mandate = getattr(attr, "_mandate", None) functions.append(FunctionDef( name=func_name, description=description, parameters=parameters, labels=labels, + mandate=mandate, tags=[] )) diff --git a/packages/device-connect-edge/device_connect_edge/drivers/decorators.py b/packages/device-connect-edge/device_connect_edge/drivers/decorators.py index 4237699..aeade78 100644 --- a/packages/device-connect-edge/device_connect_edge/drivers/decorators.py +++ b/packages/device-connect-edge/device_connect_edge/drivers/decorators.py @@ -511,6 +511,7 @@ async def wrapper(self, *args, **kwargs): wrapper._description = func._description wrapper._arg_descriptions = func._arg_descriptions wrapper._labels = func._labels + wrapper._mandate = getattr(func, "_mandate", None) wrapper._original_func = func # For schema extraction return wrapper @@ -518,6 +519,15 @@ async def wrapper(self, *args, **kwargs): return decorator +def requires_mandate(scope: str = "actuation") -> Callable: + """Mark an RPC method as requiring a valid Device Mandate.""" + def decorator(func: Callable) -> Callable: + func._mandate = {"required": True, "scope": scope} + return func + + return decorator + + def emit( name: Optional[str] = None, description: Optional[str] = None, diff --git a/packages/device-connect-edge/device_connect_edge/mandates.py b/packages/device-connect-edge/device_connect_edge/mandates.py new file mode 100644 index 0000000..a2ebb03 --- /dev/null +++ b/packages/device-connect-edge/device_connect_edge/mandates.py @@ -0,0 +1,247 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Device Mandate helpers. + +This module implements the first Device Mandate credential profile used by +Device Connect tests and demos. It is intentionally small and stdlib-only: +the public runtime contract is the mandate envelope and verifier interface, +while production credential formats can be added behind the same boundary. +""" + +from __future__ import annotations + +import hashlib +import hmac +import json +import uuid +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any, Callable + + +MANDATE_FORMAT = "device-connect-hmac-v0" + + +@dataclass(frozen=True) +class MandateInvocationContext: + """Concrete invocation a closed mandate must authorize.""" + + device_id: str + method: str + params: dict[str, Any] + now: datetime | None = None + + +@dataclass(frozen=True) +class MandateVerificationResult: + """Verifier result for expected allow/deny outcomes.""" + + ok: bool + error_code: str | None = None + message: str = "" + + +KeyResolver = Callable[[str], bytes | str | None] + + +def create_open_mandate( + *, + principal: str, + agent: str, + device_id: str, + methods: list[str], + constraints: dict[str, Any] | None, + not_before: datetime, + not_after: datetime, + key: bytes | str, + mandate_id: str | None = None, +) -> dict[str, Any]: + """Create and sign an open mandate.""" + + payload = { + "id": mandate_id or f"open-{uuid.uuid4().hex[:12]}", + "principal": principal, + "agent": agent, + "device_id": device_id, + "methods": list(methods), + "constraints": constraints or {}, + "not_before": _format_dt(not_before), + "not_after": _format_dt(not_after), + } + return {**payload, "signature": _sign(payload, key)} + + +def create_closed_mandate( + *, + open_mandate: dict[str, Any], + agent: str, + device_id: str, + method: str, + params: dict[str, Any], + key: bytes | str, + issued_at: datetime, + mandate_id: str | None = None, + nonce: str | None = None, +) -> dict[str, Any]: + """Create and sign a closed mandate for one concrete invocation.""" + + payload = { + "format": MANDATE_FORMAT, + "id": mandate_id or f"closed-{uuid.uuid4().hex[:12]}", + "agent": agent, + "open_mandate": open_mandate, + "invocation": { + "device_id": device_id, + "method": method, + "params": params, + }, + "issued_at": _format_dt(issued_at), + "nonce": nonce or uuid.uuid4().hex, + } + return {**payload, "signature": _sign(payload, key)} + + +def verify_mandate( + mandate: dict[str, Any] | None, + *, + context: MandateInvocationContext, + key_resolver: KeyResolver, + replay_cache: set[str] | None = None, +) -> MandateVerificationResult: + """Verify that a closed mandate authorizes an invocation.""" + + if not mandate: + return _deny("mandate_required", "mandate_required: protected RPC needs a mandate") + if not isinstance(mandate, dict): + return _deny("invalid_mandate", "invalid_mandate: mandate must be an object") + if mandate.get("format") != MANDATE_FORMAT: + return _deny("invalid_mandate", "invalid_mandate: unsupported mandate format") + + open_mandate = mandate.get("open_mandate") + if not isinstance(open_mandate, dict): + return _deny("invalid_mandate", "invalid_mandate: missing open mandate") + + principal = open_mandate.get("principal") + agent = mandate.get("agent") + if not isinstance(principal, str) or not isinstance(agent, str): + return _deny("invalid_mandate", "invalid_mandate: missing principal or agent") + if open_mandate.get("agent") != agent: + return _deny("mandate_agent_denied", "mandate_agent_denied: agent mismatch") + + principal_key = key_resolver(principal) + agent_key = key_resolver(agent) + if principal_key is None or agent_key is None: + return _deny("unknown_mandate_key", "unknown_mandate_key: signer key unavailable") + if not _signature_valid(open_mandate, principal_key): + return _deny("invalid_mandate_signature", "invalid_mandate_signature: open mandate") + if not _signature_valid(mandate, agent_key): + return _deny("invalid_mandate_signature", "invalid_mandate_signature: closed mandate") + + now = _as_utc(context.now or datetime.now(timezone.utc)) + not_before = _parse_dt(str(open_mandate.get("not_before", ""))) + not_after = _parse_dt(str(open_mandate.get("not_after", ""))) + if not_before is None or not_after is None: + return _deny("invalid_mandate", "invalid_mandate: invalid validity window") + if now < not_before: + return _deny("mandate_not_yet_valid", "mandate_not_yet_valid") + if now > not_after: + return _deny("mandate_expired", "mandate_expired") + + if open_mandate.get("device_id") != context.device_id: + return _deny("mandate_device_denied", "mandate_device_denied") + if context.method not in (open_mandate.get("methods") or []): + return _deny("mandate_method_denied", "mandate_method_denied") + + invocation = mandate.get("invocation") or {} + if invocation.get("device_id") != context.device_id: + return _deny("mandate_device_denied", "mandate_device_denied") + if invocation.get("method") != context.method: + return _deny("mandate_method_denied", "mandate_method_denied") + if invocation.get("params") != context.params: + return _deny("mandate_params_denied", "mandate_params_denied") + + constraint_error = _check_constraints( + open_mandate.get("constraints") or {}, context.params, + ) + if constraint_error is not None: + return _deny("mandate_constraint_denied", constraint_error) + + nonce = mandate.get("nonce") + if replay_cache is not None and isinstance(nonce, str): + if nonce in replay_cache: + return _deny("mandate_replayed", "mandate_replayed") + replay_cache.add(nonce) + + return MandateVerificationResult(ok=True) + + +def _deny(code: str, message: str) -> MandateVerificationResult: + return MandateVerificationResult(ok=False, error_code=code, message=message) + + +def _sign(payload: dict[str, Any], key: bytes | str) -> str: + return hmac.new(_key_bytes(key), _canonical(payload), hashlib.sha256).hexdigest() + + +def _signature_valid(payload: dict[str, Any], key: bytes | str) -> bool: + expected = payload.get("signature") + if not isinstance(expected, str): + return False + unsigned = {k: v for k, v in payload.items() if k != "signature"} + return hmac.compare_digest(expected, _sign(unsigned, key)) + + +def _canonical(payload: dict[str, Any]) -> bytes: + return json.dumps( + payload, sort_keys=True, separators=(",", ":"), ensure_ascii=True, + ).encode() + + +def _key_bytes(key: bytes | str) -> bytes: + return key if isinstance(key, bytes) else key.encode() + + +def _format_dt(value: datetime) -> str: + return _as_utc(value).isoformat().replace("+00:00", "Z") + + +def _parse_dt(value: str) -> datetime | None: + try: + return _as_utc(datetime.fromisoformat(value.replace("Z", "+00:00"))) + except ValueError: + return None + + +def _as_utc(value: datetime) -> datetime: + if value.tzinfo is None: + return value.replace(tzinfo=timezone.utc) + return value.astimezone(timezone.utc) + + +def _check_constraints( + constraints: dict[str, Any], params: dict[str, Any], +) -> str | None: + for name, rules in constraints.items(): + if name not in params: + return f"mandate_constraint_denied: missing {name}" + value = params[name] + if not isinstance(rules, dict): + if value != rules: + return f"mandate_constraint_denied: {name}" + continue + for op, expected in rules.items(): + if op == "eq" and value != expected: + return f"mandate_constraint_denied: {name}" + if op == "lte" and not value <= expected: + return f"mandate_constraint_denied: {name}" + if op == "lt" and not value < expected: + return f"mandate_constraint_denied: {name}" + if op == "gte" and not value >= expected: + return f"mandate_constraint_denied: {name}" + if op == "gt" and not value > expected: + return f"mandate_constraint_denied: {name}" + if op == "in" and value not in expected: + return f"mandate_constraint_denied: {name}" + return None diff --git a/packages/device-connect-edge/device_connect_edge/types.py b/packages/device-connect-edge/device_connect_edge/types.py index 5440708..e73c1ea 100644 --- a/packages/device-connect-edge/device_connect_edge/types.py +++ b/packages/device-connect-edge/device_connect_edge/types.py @@ -68,6 +68,12 @@ class FunctionDef(BaseModel): "(read|write), safety (critical|informational), modality (rgb|thermal|...). " "Custom keys are allowed." ) + mandate: Optional[Dict[str, Any]] = Field( + default=None, + description="Optional execution authorization policy metadata. When present with " + "{'required': True}, the runtime requires a valid Device Mandate before " + "executing this function." + ) tags: List[str] = Field( default_factory=list, description="Tags for categorization (e.g., ['vision', 'capture'])" diff --git a/packages/device-connect-edge/tests/test_device_mandates.py b/packages/device-connect-edge/tests/test_device_mandates.py new file mode 100644 index 0000000..cda43fe --- /dev/null +++ b/packages/device-connect-edge/tests/test_device_mandates.py @@ -0,0 +1,179 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Runtime enforcement tests for mandate-protected RPCs.""" + +from __future__ import annotations + +import json +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock + +import pytest + +from device_connect_edge.device import DeviceRuntime +from device_connect_edge.drivers import DeviceDriver, requires_mandate, rpc +from device_connect_edge.mandates import create_closed_mandate, create_open_mandate + + +PRINCIPAL_KEY = b"principal-secret" +AGENT_KEY = b"agent-secret" + + +class LockDriver(DeviceDriver): + device_type = "lock" + + def __init__(self): + super().__init__() + self.unlock_calls = 0 + + @requires_mandate(scope="actuation") + @rpc(labels={"direction": "write", "safety": "critical"}) + async def unlock(self, duration_s: int) -> dict: + """Unlock for a bounded duration.""" + self.unlock_calls += 1 + return {"unlocked": True, "duration_s": duration_s} + + @rpc(labels={"direction": "read"}) + async def get_status(self) -> dict: + """Return lock status.""" + return {"locked": True} + + +def _valid_mandate(params: dict | None = None) -> dict: + now = datetime.now(timezone.utc) + params = params or {"duration_s": 30} + open_mandate = create_open_mandate( + principal="operator", + agent="agent-1", + device_id="lock-001", + methods=["unlock"], + constraints={"duration_s": {"lte": 60}}, + not_before=now - timedelta(minutes=5), + not_after=now + timedelta(minutes=5), + key=PRINCIPAL_KEY, + mandate_id="open-1", + ) + return create_closed_mandate( + open_mandate=open_mandate, + agent="agent-1", + device_id="lock-001", + method="unlock", + params=params, + key=AGENT_KEY, + issued_at=now, + mandate_id="closed-1", + nonce="nonce-1", + ) + + +def _runtime(driver: LockDriver) -> DeviceRuntime: + return DeviceRuntime( + driver=driver, + device_id="lock-001", + messaging_urls=["nats://localhost:4222"], + mandate_keys={"operator": PRINCIPAL_KEY, "agent-1": AGENT_KEY}, + ) + + +async def _invoke_callback(rt: DeviceRuntime, method: str, params: dict) -> dict: + rt.messaging = AsyncMock() + rt.messaging.subscribe = AsyncMock() + await rt._cmd_subscription() + on_msg = rt.messaging.subscribe.call_args[1]["callback"] + await on_msg( + json.dumps({ + "jsonrpc": "2.0", + "id": "req-1", + "method": method, + "params": params, + }).encode(), + reply_subject="reply.inbox.1", + ) + return json.loads(rt.messaging.publish.call_args[0][1]) + + +class TestRequiresMandateDecorator: + def test_capability_metadata_includes_mandate_requirement(self): + driver = LockDriver() + fn = next(f for f in driver.functions if f.name == "unlock") + assert fn.mandate == {"required": True, "scope": "actuation"} + + def test_unprotected_capability_has_no_mandate_requirement(self): + driver = LockDriver() + fn = next(f for f in driver.functions if f.name == "get_status") + assert fn.mandate is None + + +class TestCommandMandateEnforcement: + @pytest.mark.asyncio + async def test_protected_rpc_without_mandate_is_denied_before_driver_call(self): + driver = LockDriver() + response = await _invoke_callback( + _runtime(driver), "unlock", {"duration_s": 30}, + ) + + assert response["error"]["code"] == -32041 + assert "mandate_required" in response["error"]["message"] + assert driver.unlock_calls == 0 + + @pytest.mark.asyncio + async def test_protected_rpc_with_valid_mandate_executes(self): + driver = LockDriver() + response = await _invoke_callback( + _runtime(driver), + "unlock", + {"duration_s": 30, "_dc_meta": {"mandate": _valid_mandate()}}, + ) + + assert response["result"] == {"unlocked": True, "duration_s": 30} + assert driver.unlock_calls == 1 + + @pytest.mark.asyncio + async def test_unprotected_rpc_executes_without_mandate(self): + response = await _invoke_callback(_runtime(LockDriver()), "get_status", {}) + assert response["result"] == {"locked": True} + + @pytest.mark.asyncio + async def test_broadcast_protected_rpc_without_mandate_is_denied(self): + driver = LockDriver() + rt = _runtime(driver) + rt.messaging = AsyncMock() + + await rt._handle_broadcast_envelope( + { + "correlation_id": "br-1", + "function": "unlock", + "params": {"duration_s": 30}, + }, + "br-1", + ) + + payload = json.loads(rt.messaging.publish.call_args[0][1]) + assert payload["success"] is False + assert payload["error"]["code"] == "mandate_required" + assert driver.unlock_calls == 0 + + @pytest.mark.asyncio + async def test_broadcast_protected_rpc_with_valid_mandate_executes(self): + driver = LockDriver() + rt = _runtime(driver) + rt.messaging = AsyncMock() + + await rt._handle_broadcast_envelope( + { + "correlation_id": "br-1", + "function": "unlock", + "params": { + "duration_s": 30, + "_dc_meta": {"mandate": _valid_mandate()}, + }, + }, + "br-1", + ) + + payload = json.loads(rt.messaging.publish.call_args[0][1]) + assert payload["success"] is True + assert payload["result"] == {"unlocked": True, "duration_s": 30} + assert driver.unlock_calls == 1 diff --git a/packages/device-connect-edge/tests/test_drivers.py b/packages/device-connect-edge/tests/test_drivers.py index 5763be6..1d5f2de 100644 --- a/packages/device-connect-edge/tests/test_drivers.py +++ b/packages/device-connect-edge/tests/test_drivers.py @@ -11,7 +11,7 @@ import pytest from unittest.mock import AsyncMock, MagicMock -from device_connect_edge.drivers import DeviceDriver, rpc, emit, build_function_schema, build_event_schema +from device_connect_edge.drivers import DeviceDriver, rpc, emit, requires_mandate, build_function_schema, build_event_schema from device_connect_edge.drivers.base import on from device_connect_edge.types import DeviceIdentity, DeviceStatus @@ -198,6 +198,24 @@ async def capture(self, resolution: str = "1080p") -> dict: assert capture._labels == {"direction": "write", "modality": ["rgb", "4k"]} +class TestRequiresMandate: + def test_requires_mandate_above_rpc(self): + @requires_mandate(scope="actuation") + @rpc() + async def unlock(self) -> dict: + return {} + + assert unlock._mandate == {"required": True, "scope": "actuation"} + + def test_requires_mandate_below_rpc(self): + @rpc() + @requires_mandate(scope="actuation") + async def unlock(self) -> dict: + return {} + + assert unlock._mandate == {"required": True, "scope": "actuation"} + + class TestEmitLabels: def test_default_none(self): @emit() diff --git a/packages/device-connect-edge/tests/test_mandate_verifier.py b/packages/device-connect-edge/tests/test_mandate_verifier.py new file mode 100644 index 0000000..0ce7f58 --- /dev/null +++ b/packages/device-connect-edge/tests/test_mandate_verifier.py @@ -0,0 +1,144 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for Device Mandate signing and verification helpers.""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + +from device_connect_edge.mandates import ( + MandateInvocationContext, + create_closed_mandate, + create_open_mandate, + verify_mandate, +) + + +PRINCIPAL_KEY = b"principal-secret" +AGENT_KEY = b"agent-secret" + + +def _keys(principal: str) -> bytes | None: + return { + "operator": PRINCIPAL_KEY, + "agent-1": AGENT_KEY, + }.get(principal) + + +def _valid_mandate(params: dict | None = None) -> dict: + now = datetime(2026, 5, 11, 12, 0, tzinfo=timezone.utc) + params = params or {"duration_s": 30} + open_mandate = create_open_mandate( + principal="operator", + agent="agent-1", + device_id="lock-001", + methods=["unlock"], + constraints={"duration_s": {"lte": 60}}, + not_before=now - timedelta(minutes=5), + not_after=now + timedelta(minutes=5), + key=PRINCIPAL_KEY, + mandate_id="open-1", + ) + return create_closed_mandate( + open_mandate=open_mandate, + agent="agent-1", + device_id="lock-001", + method="unlock", + params=params, + key=AGENT_KEY, + issued_at=now, + mandate_id="closed-1", + nonce="nonce-1", + ) + + +def _context(**overrides) -> MandateInvocationContext: + base = { + "device_id": "lock-001", + "method": "unlock", + "params": {"duration_s": 30}, + "now": datetime(2026, 5, 11, 12, 0, tzinfo=timezone.utc), + } + base.update(overrides) + return MandateInvocationContext(**base) + + +def test_valid_closed_mandate_verifies(): + result = verify_mandate(_valid_mandate(), context=_context(), key_resolver=_keys) + assert result.ok is True + assert result.error_code is None + + +def test_missing_mandate_fails_closed(): + result = verify_mandate(None, context=_context(), key_resolver=_keys) + assert result.ok is False + assert result.error_code == "mandate_required" + + +def test_tampered_parameters_fail_signature_check(): + mandate = _valid_mandate() + mandate["invocation"]["params"]["duration_s"] = 45 + + result = verify_mandate(mandate, context=_context(), key_resolver=_keys) + + assert result.ok is False + assert result.error_code == "invalid_mandate_signature" + + +def test_wrong_device_is_denied(): + result = verify_mandate( + _valid_mandate(), + context=_context(device_id="other-lock"), + key_resolver=_keys, + ) + assert result.ok is False + assert result.error_code == "mandate_device_denied" + + +def test_wrong_method_is_denied(): + result = verify_mandate( + _valid_mandate(), + context=_context(method="lock"), + key_resolver=_keys, + ) + assert result.ok is False + assert result.error_code == "mandate_method_denied" + + +def test_expired_mandate_is_denied(): + result = verify_mandate( + _valid_mandate(), + context=_context(now=datetime(2026, 5, 11, 12, 10, tzinfo=timezone.utc)), + key_resolver=_keys, + ) + assert result.ok is False + assert result.error_code == "mandate_expired" + + +def test_parameter_constraint_is_enforced(): + mandate = _valid_mandate(params={"duration_s": 75}) + result = verify_mandate( + mandate, + context=_context(params={"duration_s": 75}), + key_resolver=_keys, + ) + assert result.ok is False + assert result.error_code == "mandate_constraint_denied" + + +def test_replay_cache_denies_reused_nonce(): + seen: set[str] = set() + mandate = _valid_mandate() + + first = verify_mandate( + mandate, context=_context(), key_resolver=_keys, replay_cache=seen, + ) + second = verify_mandate( + mandate, context=_context(), key_resolver=_keys, replay_cache=seen, + ) + + assert first.ok is True + assert second.ok is False + assert second.error_code == "mandate_replayed" diff --git a/packages/device-connect-edge/tests/test_types.py b/packages/device-connect-edge/tests/test_types.py index bb34151..15c2ab0 100644 --- a/packages/device-connect-edge/tests/test_types.py +++ b/packages/device-connect-edge/tests/test_types.py @@ -101,6 +101,14 @@ def test_function_labels_roundtrip(self): f2 = FunctionDef.model_validate_json(f.model_dump_json()) assert f2.labels == f.labels + def test_function_mandate_roundtrip(self): + f = FunctionDef( + name="unlock", + mandate={"required": True, "scope": "actuation"}, + ) + f2 = FunctionDef.model_validate_json(f.model_dump_json()) + assert f2.mandate == {"required": True, "scope": "actuation"} + def test_event_labels_default_none(self): e = EventDef(name="heartbeat") assert e.labels is None From 1d0bf516716c16488382c3778aace4b8a4c14dd5 Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Mon, 11 May 2026 08:22:25 -0700 Subject: [PATCH 16/21] feat: expose mandates in agent adapters --- .../adapters/claude.py | 11 +++- .../tests/test_claude_adapter.py | 65 +++++++++++++++++++ .../tests/test_langchain_adapter.py | 7 ++ .../tests/test_strands_adapter.py | 7 ++ 4 files changed, 88 insertions(+), 2 deletions(-) diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py index f4a2883..9256968 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py @@ -113,7 +113,10 @@ async def discover(args: dict[str, Any]) -> dict[str, Any]: "to a single (device, function) tuple -- use device().function() " "or function() scope. Returns {success, device_id, function, " "result|error}. Use invoke_many for fan-out across multiple targets.", - {"selector": str, "params": dict, "llm_reasoning": str}, + { + "selector": str, "params": dict, "llm_reasoning": str, + "mandate": dict, + }, ) async def invoke(args: dict[str, Any]) -> dict[str, Any]: return _text( @@ -121,6 +124,7 @@ async def invoke(args: dict[str, Any]) -> dict[str, Any]: selector=args["selector"], params=args.get("params"), llm_reasoning=args.get("llm_reasoning"), + mandate=args.get("mandate"), ) ) @@ -134,7 +138,7 @@ async def invoke(args: dict[str, Any]) -> dict[str, Any]: "target gets a per-call timeout (default 30s).", { "selector": str, "params": dict, "timeout": float, - "max_concurrency": int, "llm_reasoning": str, + "max_concurrency": int, "llm_reasoning": str, "mandate": dict, }, ) async def invoke_many(args: dict[str, Any]) -> dict[str, Any]: @@ -145,6 +149,7 @@ async def invoke_many(args: dict[str, Any]) -> dict[str, Any]: timeout=float(args.get("timeout", 30.0)), max_concurrency=int(args.get("max_concurrency", 32)), llm_reasoning=args.get("llm_reasoning"), + mandate=args.get("mandate"), ) ) @@ -161,6 +166,7 @@ async def invoke_many(args: dict[str, Any]) -> dict[str, Any]: { "selector": str, "params": dict, "where": str, "bindings": dict, "fire_at": float, "on_late": str, "llm_reasoning": str, + "mandate": dict, }, ) async def broadcast(args: dict[str, Any]) -> dict[str, Any]: @@ -173,6 +179,7 @@ async def broadcast(args: dict[str, Any]) -> dict[str, Any]: fire_at=args.get("fire_at"), on_late=args.get("on_late", "skip"), llm_reasoning=args.get("llm_reasoning"), + mandate=args.get("mandate"), ) ) diff --git a/packages/device-connect-agent-tools/tests/test_claude_adapter.py b/packages/device-connect-agent-tools/tests/test_claude_adapter.py index 4960a49..99ff17d 100644 --- a/packages/device-connect-agent-tools/tests/test_claude_adapter.py +++ b/packages/device-connect-agent-tools/tests/test_claude_adapter.py @@ -110,3 +110,68 @@ def test_create_server_bundles_all_tools(self): assert server["name"] == "device-connect" bundled = {t._tool_name for t in server["tools"]} assert bundled == set(TOOL_NAMES) + + @pytest.mark.parametrize("name", ("invoke", "invoke_many", "broadcast")) + def test_invocation_schemas_include_optional_mandate(self, name): + from device_connect_agent_tools.adapters import claude as adapter + + schema = getattr(adapter, name)._tool_schema + + assert schema["mandate"] is dict + + +class TestClaudeAdapterMandates: + @pytest.mark.asyncio + async def test_invoke_forwards_mandate(self): + from device_connect_agent_tools.adapters import claude as adapter + + mandate = {"format": "device-connect-hmac-v0", "closed": {"id": "m-1"}} + + with patch.object(adapter, "_invoke", return_value={"success": True}) as invoke: + await adapter.invoke( + { + "selector": "device(lock-001).function(unlock)", + "params": {"duration_s": 30}, + "mandate": mandate, + } + ) + + assert invoke.call_args.kwargs["mandate"] == mandate + + @pytest.mark.asyncio + async def test_invoke_many_forwards_mandate(self): + from device_connect_agent_tools.adapters import claude as adapter + + mandate = {"format": "device-connect-hmac-v0", "closed": {"id": "m-1"}} + + with patch.object( + adapter, "_invoke_many", return_value={"succeeded": 1} + ) as invoke_many: + await adapter.invoke_many( + { + "selector": "device(category:lock).function(unlock)", + "params": {"duration_s": 30}, + "mandate": mandate, + } + ) + + assert invoke_many.call_args.kwargs["mandate"] == mandate + + @pytest.mark.asyncio + async def test_broadcast_forwards_mandate(self): + from device_connect_agent_tools.adapters import claude as adapter + + mandate = {"format": "device-connect-hmac-v0", "closed": {"id": "m-1"}} + + with patch.object( + adapter, "_broadcast", return_value={"candidates": 1} + ) as broadcast: + await adapter.broadcast( + { + "selector": "device(category:lock).function(unlock)", + "params": {"duration_s": 30}, + "mandate": mandate, + } + ) + + assert broadcast.call_args.kwargs["mandate"] == mandate diff --git a/packages/device-connect-agent-tools/tests/test_langchain_adapter.py b/packages/device-connect-agent-tools/tests/test_langchain_adapter.py index 9aae070..210930d 100644 --- a/packages/device-connect-agent-tools/tests/test_langchain_adapter.py +++ b/packages/device-connect-agent-tools/tests/test_langchain_adapter.py @@ -9,6 +9,7 @@ """ import sys +from inspect import signature from types import ModuleType from unittest.mock import MagicMock, patch @@ -111,3 +112,9 @@ def test_tool_descriptions_not_empty(self): for name in adapter.__all__: assert len(getattr(adapter, name).description) > 0, f"{name} has empty description" + + @pytest.mark.parametrize("name", ("invoke", "invoke_many", "broadcast")) + def test_invocation_tools_inherit_mandate_signature(self, name): + from device_connect_agent_tools.adapters import langchain as adapter + + assert "mandate" in signature(getattr(adapter, name)._func).parameters diff --git a/packages/device-connect-agent-tools/tests/test_strands_adapter.py b/packages/device-connect-agent-tools/tests/test_strands_adapter.py index 4e46ceb..0943968 100644 --- a/packages/device-connect-agent-tools/tests/test_strands_adapter.py +++ b/packages/device-connect-agent-tools/tests/test_strands_adapter.py @@ -9,6 +9,7 @@ """ import sys +from inspect import signature from types import ModuleType from unittest.mock import MagicMock, patch @@ -88,3 +89,9 @@ def test_tool_names_match(self): for name in adapter.__all__: assert getattr(adapter, name).__name__ == name, f"{name}.__name__ mismatch" + + @pytest.mark.parametrize("name", ("invoke", "invoke_many", "broadcast")) + def test_invocation_tools_inherit_mandate_signature(self, name): + from device_connect_agent_tools.adapters import strands as adapter + + assert "mandate" in signature(getattr(adapter, name).__wrapped__).parameters From fc46357638175addb30d528fe1a9be3b926299e3 Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Mon, 11 May 2026 08:27:33 -0700 Subject: [PATCH 17/21] docs: add device mandate examples --- docs/README.md | 2 + docs/device-mandates.md | 108 +++++++++++ .../device_mandates/mandate_examples.py | 174 ++++++++++++++++++ 3 files changed, 284 insertions(+) create mode 100644 docs/device-mandates.md create mode 100644 packages/device-connect-edge/examples/device_mandates/mandate_examples.py diff --git a/docs/README.md b/docs/README.md index ae5ef88..16fa39d 100644 --- a/docs/README.md +++ b/docs/README.md @@ -3,3 +3,5 @@ Developer reference material for Device Connect. - **class-map.html** — Interactive class/module relationship diagram. Open in a browser to explore the architecture. +- **device-mandates.md** — Concise guide and runnable examples for mandate-protected device functions. +- **device-mandates-spec.md** — Implementation notes and acceptance criteria for the Device Mandates feature. diff --git a/docs/device-mandates.md b/docs/device-mandates.md new file mode 100644 index 0000000..8cdd783 --- /dev/null +++ b/docs/device-mandates.md @@ -0,0 +1,108 @@ +# Device Mandates + +Device Mandates add a signed authorization envelope to sensitive RPCs. A driver marks a function with `@requires_mandate`, and the runtime denies that function unless the call includes a valid closed mandate in `_dc_meta.mandate`. + +Use mandates for actuation that can affect safety, access, cost, or physical state. Read-only functions usually should not require mandates. + +## Protect an RPC + +Decorate the RPC with `@requires_mandate`. The decorator may be placed above or below `@rpc()`. + +```python +from device_connect_edge.drivers import DeviceDriver, requires_mandate, rpc + + +class SmartLockDriver(DeviceDriver): + device_type = "smart_lock" + + @requires_mandate(scope="actuation") + @rpc() + async def unlock(self, duration_s: int = 10) -> dict: + return {"state": "unlocked", "duration_s": duration_s} +``` + +Discovery metadata for `unlock` includes: + +```json +{"mandate": {"required": true, "scope": "actuation"}} +``` + +## Create Mandates + +An open mandate is signed by the principal and delegates bounded authority to an agent. A closed mandate is signed by the agent for one concrete invocation. + +```python +from datetime import datetime, timedelta, timezone + +from device_connect_edge import create_closed_mandate, create_open_mandate + +now = datetime.now(timezone.utc) +principal_key = b"principal-demo-key" +agent_key = b"agent-demo-key" + +open_mandate = create_open_mandate( + principal="operator", + agent="agent-1", + device_id="lock-front-door", + methods=["unlock"], + constraints={"duration_s": {"lte": 30}}, + not_before=now - timedelta(seconds=5), + not_after=now + timedelta(minutes=5), + key=principal_key, +) + +closed_mandate = create_closed_mandate( + open_mandate=open_mandate, + agent="agent-1", + device_id="lock-front-door", + method="unlock", + params={"duration_s": 20}, + key=agent_key, + issued_at=now, +) +``` + +Pass the closed mandate through agent tools with the `mandate` argument: + +```python +from device_connect_agent_tools import invoke + +result = invoke( + "device(lock-front-door).function(unlock)", + params={"duration_s": 20}, + mandate=closed_mandate, +) +``` + +## Valid and Invalid Use Cases + +Valid smart-lock use: unlock the front door for 20 seconds when the open mandate allows `unlock` on `lock-front-door` and constrains `duration_s <= 30`. + +Invalid smart-lock use: reuse that same mandate for `duration_s=60`, another device, another method, or changed parameters. The signature and constraint checks fail closed before the driver method runs. + +Valid heater use: set a room heater to 21.5 C when the open mandate allows `set_temperature` on `heater-living-room` and constrains `target_c` between 18 and 23. + +Invalid heater use: request `target_c=28` or replay a previously used closed mandate nonce. The verifier denies the call. + +See `packages/device-connect-edge/examples/device_mandates/mandate_examples.py` for runnable local examples of these cases. + +## Testing Commands + +Run the focused mandate tests: + +```bash +pytest packages/device-connect-edge/tests/test_mandate_verifier.py packages/device-connect-edge/tests/test_device_mandates.py packages/device-connect-agent-tools/tests/test_agent_mandates.py -q +``` + +Run package test suites: + +```bash +pytest packages/device-connect-edge/tests -q +pytest packages/device-connect-agent-tools/tests -q +``` + +Run the examples: + +```bash +PYTHONPATH=packages/device-connect-edge python packages/device-connect-edge/examples/device_mandates/mandate_examples.py +``` diff --git a/packages/device-connect-edge/examples/device_mandates/mandate_examples.py b/packages/device-connect-edge/examples/device_mandates/mandate_examples.py new file mode 100644 index 0000000..25eec19 --- /dev/null +++ b/packages/device-connect-edge/examples/device_mandates/mandate_examples.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Local examples for Device Mandates. + +Run from the repository root: + PYTHONPATH=packages/device-connect-edge python packages/device-connect-edge/examples/device_mandates/mandate_examples.py + +Focused tests: + pytest packages/device-connect-edge/tests/test_mandate_verifier.py packages/device-connect-edge/tests/test_device_mandates.py packages/device-connect-agent-tools/tests/test_agent_mandates.py -q +""" + +from __future__ import annotations + +import asyncio +from datetime import datetime, timedelta, timezone +from typing import Any + +from device_connect_edge import create_closed_mandate, create_open_mandate +from device_connect_edge.drivers import DeviceDriver, requires_mandate, rpc +from device_connect_edge.mandates import MandateInvocationContext, verify_mandate + + +PRINCIPAL_KEY = b"principal-demo-key" +AGENT_KEY = b"agent-demo-key" + + +class SmartLockDriver(DeviceDriver): + """Smart lock with mandate-protected actuation.""" + + device_type = "smart_lock" + + @requires_mandate(scope="actuation") + @rpc() + async def unlock(self, duration_s: int = 10) -> dict[str, Any]: + return {"state": "unlocked", "duration_s": duration_s} + + @rpc() + async def get_status(self) -> dict[str, str]: + return {"state": "locked"} + + +class HeaterDriver(DeviceDriver): + """Heater with mandate-protected setpoint changes.""" + + device_type = "heater" + + @rpc() + async def get_temperature(self) -> dict[str, float]: + return {"current_c": 20.5} + + @requires_mandate(scope="actuation") + @rpc() + async def set_temperature(self, target_c: float) -> dict[str, float]: + return {"target_c": target_c} + + +def key_resolver(principal: str) -> bytes | None: + return {"operator": PRINCIPAL_KEY, "agent-1": AGENT_KEY}.get(principal) + + +def closed_mandate( + *, + device_id: str, + method: str, + params: dict[str, Any], + constraints: dict[str, Any], +) -> dict[str, Any]: + now = datetime.now(timezone.utc) + open_mandate = create_open_mandate( + principal="operator", + agent="agent-1", + device_id=device_id, + methods=[method], + constraints=constraints, + not_before=now - timedelta(seconds=5), + not_after=now + timedelta(minutes=5), + key=PRINCIPAL_KEY, + ) + return create_closed_mandate( + open_mandate=open_mandate, + agent="agent-1", + device_id=device_id, + method=method, + params=params, + key=AGENT_KEY, + issued_at=now, + ) + + +def verify_example( + *, + label: str, + mandate: dict[str, Any] | None, + device_id: str, + method: str, + params: dict[str, Any], + replay_cache: set[str] | None = None, +) -> None: + result = verify_mandate( + mandate, + context=MandateInvocationContext( + device_id=device_id, + method=method, + params=params, + ), + key_resolver=key_resolver, + replay_cache=replay_cache, + ) + outcome = "allowed" if result.ok else f"denied ({result.error_code})" + print(f"{label}: {outcome}") + + +async def main() -> None: + lock = SmartLockDriver() + heater = HeaterDriver() + + unlock_policy = getattr(lock.unlock, "_mandate", None) + heater_policy = getattr(heater.set_temperature, "_mandate", None) + print(f"smart-lock unlock mandate policy: {unlock_policy}") + print(f"heater set_temperature mandate policy: {heater_policy}") + + valid_unlock_params = {"duration_s": 20} + valid_unlock = closed_mandate( + device_id="lock-front-door", + method="unlock", + params=valid_unlock_params, + constraints={"duration_s": {"lte": 30}}, + ) + verify_example( + label="valid smart-lock unlock", + mandate=valid_unlock, + device_id="lock-front-door", + method="unlock", + params=valid_unlock_params, + ) + verify_example( + label="invalid smart-lock duration", + mandate=valid_unlock, + device_id="lock-front-door", + method="unlock", + params={"duration_s": 60}, + ) + + valid_heat_params = {"target_c": 21.5} + valid_heat = closed_mandate( + device_id="heater-living-room", + method="set_temperature", + params=valid_heat_params, + constraints={"target_c": {"gte": 18, "lte": 23}}, + ) + replay_cache: set[str] = set() + verify_example( + label="valid heater setpoint", + mandate=valid_heat, + device_id="heater-living-room", + method="set_temperature", + params=valid_heat_params, + replay_cache=replay_cache, + ) + verify_example( + label="invalid heater replay", + mandate=valid_heat, + device_id="heater-living-room", + method="set_temperature", + params=valid_heat_params, + replay_cache=replay_cache, + ) + + +if __name__ == "__main__": + asyncio.run(main()) From 9c40bd911b1a696b77449a80624c41a75a6d3e07 Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Mon, 11 May 2026 08:37:47 -0700 Subject: [PATCH 18/21] feat: add server mandate receipts --- .../portal/services/execution_receipts.py | 99 +++++++++ .../portal/services/mandates.py | 108 ++++++++++ .../portal/views/agent_api.py | 196 ++++++++++++++++-- .../test_portal_agent_api.py | 162 +++++++++++++++ .../test_portal_mandates.py | 137 ++++++++++++ 5 files changed, 689 insertions(+), 13 deletions(-) create mode 100644 packages/device-connect-server/device_connect_server/portal/services/execution_receipts.py create mode 100644 packages/device-connect-server/device_connect_server/portal/services/mandates.py create mode 100644 packages/device-connect-server/tests/device_connect_server/test_portal_mandates.py diff --git a/packages/device-connect-server/device_connect_server/portal/services/execution_receipts.py b/packages/device-connect-server/device_connect_server/portal/services/execution_receipts.py new file mode 100644 index 0000000..59583ed --- /dev/null +++ b/packages/device-connect-server/device_connect_server/portal/services/execution_receipts.py @@ -0,0 +1,99 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Execution receipt helpers for mandate-aware invokes.""" + +from __future__ import annotations + +import hashlib +import hmac +import json +import os +import secrets +from datetime import datetime, timezone +from typing import Any + + +def build_receipt( + *, + trace_id: str, + tenant: str, + actor: dict[str, Any], + device_id: str, + function: str, + params: dict[str, Any], + status: str, + elapsed_ms: int, + response: Any = None, + error: dict[str, Any] | None = None, + mandate: dict[str, Any] | None = None, + mandate_required: bool = False, + mandate_verified: bool = False, + mandate_error_code: str | None = None, +) -> dict[str, Any]: + receipt = { + "receipt_id": "rcpt-" + secrets.token_hex(8), + "trace_id": trace_id, + "tenant": tenant, + "actor": { + "token_id": actor.get("token_id"), + "username": actor.get("username"), + }, + "device_id": device_id, + "function": function, + "status": status, + "authorized": status != "denied", + "mandate": _mandate_summary( + mandate, + required=mandate_required, + verified=mandate_verified, + error_code=mandate_error_code, + ), + "params_sha256": hash_json(params), + "response_sha256": hash_json(response) if response is not None else None, + "error": error, + "elapsed_ms": elapsed_ms, + "issued_at": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"), + } + receipt["signature"] = sign_receipt(receipt) + return receipt + + +def hash_json(value: Any) -> str: + payload = json.dumps( + value, sort_keys=True, separators=(",", ":"), ensure_ascii=True, + default=str, + ).encode() + return hashlib.sha256(payload).hexdigest() + + +def sign_receipt(receipt: dict[str, Any]) -> str | None: + key = os.getenv("DC_RECEIPT_SIGNING_KEY") + if not key: + return None + unsigned = {k: v for k, v in receipt.items() if k != "signature"} + payload = json.dumps( + unsigned, sort_keys=True, separators=(",", ":"), ensure_ascii=True, + default=str, + ).encode() + return hmac.new(key.encode(), payload, hashlib.sha256).hexdigest() + + +def _mandate_summary( + mandate: dict[str, Any] | None, + *, + required: bool, + verified: bool, + error_code: str | None, +) -> dict[str, Any]: + open_mandate = mandate.get("open_mandate") if isinstance(mandate, dict) else {} + return { + "required": required, + "verified": verified, + "id": mandate.get("id") if isinstance(mandate, dict) else None, + "open_mandate_id": open_mandate.get("id") if isinstance(open_mandate, dict) else None, + "principal": open_mandate.get("principal") if isinstance(open_mandate, dict) else None, + "agent": mandate.get("agent") if isinstance(mandate, dict) else None, + "error_code": error_code, + } diff --git a/packages/device-connect-server/device_connect_server/portal/services/mandates.py b/packages/device-connect-server/device_connect_server/portal/services/mandates.py new file mode 100644 index 0000000..c1b5d63 --- /dev/null +++ b/packages/device-connect-server/device_connect_server/portal/services/mandates.py @@ -0,0 +1,108 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Server-side helpers for Device Mandates.""" + +from __future__ import annotations + +import json +import os +from typing import Any + +from device_connect_edge.mandates import ( + MandateInvocationContext, + MandateVerificationResult, + verify_mandate, +) + + +_SERVER_MANDATE_REPLAY_CACHE: set[str] = set() + + +def get_function_mandate_policy( + device_doc: dict[str, Any] | None, + function: str, +) -> dict[str, Any] | None: + """Return mandate policy metadata for a function in a registry document.""" + capabilities = (device_doc or {}).get("capabilities") or {} + for fn in capabilities.get("functions") or []: + if fn.get("name") == function: + mandate = fn.get("mandate") + return mandate if isinstance(mandate, dict) else None + return None + + +def extract_mandate( + body: dict[str, Any], + params: dict[str, Any], +) -> dict[str, Any] | None: + """Extract mandate from top-level body or params._dc_meta.""" + mandate = body.get("mandate") + if isinstance(mandate, dict): + return mandate + dc_meta = params.get("_dc_meta") + if isinstance(dc_meta, dict) and isinstance(dc_meta.get("mandate"), dict): + return dc_meta["mandate"] + return None + + +def strip_dc_meta(params: dict[str, Any]) -> dict[str, Any]: + """Return user parameters only, excluding reserved Device Connect metadata.""" + return {k: v for k, v in params.items() if k != "_dc_meta"} + + +def attach_mandate( + params: dict[str, Any], + source_params: dict[str, Any], + mandate: dict[str, Any] | None, +) -> dict[str, Any]: + """Attach a mandate to params._dc_meta while preserving existing metadata.""" + out = dict(params) + existing_meta = source_params.get("_dc_meta") + meta = dict(existing_meta) if isinstance(existing_meta, dict) else {} + if mandate is not None: + meta["mandate"] = mandate + if meta: + out["_dc_meta"] = meta + return out + + +def verify_server_mandate( + *, + device_doc: dict[str, Any] | None, + device_id: str, + function: str, + params: dict[str, Any], + mandate: dict[str, Any] | None, +) -> MandateVerificationResult: + """Verify a mandate when policy requires it or a caller supplied one.""" + policy = get_function_mandate_policy(device_doc, function) + mandate_required = bool(policy and policy.get("required")) + if not mandate_required and mandate is None: + return MandateVerificationResult(ok=True) + return verify_mandate( + mandate, + context=MandateInvocationContext( + device_id=device_id, + method=function, + params=params, + ), + key_resolver=resolve_mandate_key, + replay_cache=_SERVER_MANDATE_REPLAY_CACHE, + ) + + +def resolve_mandate_key(principal_or_agent: str) -> bytes | str | None: + """Resolve principal/agent signing keys from DC_MANDATE_KEYS_JSON.""" + raw = os.getenv("DC_MANDATE_KEYS_JSON", "") + if not raw: + return None + try: + keys = json.loads(raw) + except json.JSONDecodeError: + return None + if not isinstance(keys, dict): + return None + key = keys.get(principal_or_agent) + return key if isinstance(key, str) else None diff --git a/packages/device-connect-server/device_connect_server/portal/views/agent_api.py b/packages/device-connect-server/device_connect_server/portal/views/agent_api.py index 0aae37c..903af8f 100644 --- a/packages/device-connect-server/device_connect_server/portal/views/agent_api.py +++ b/packages/device-connect-server/device_connect_server/portal/views/agent_api.py @@ -27,6 +27,8 @@ from ..services import cli_auth as cli_auth_svc from ..services import credentials as credentials_svc +from ..services import execution_receipts as receipts_svc +from ..services import mandates as mandates_svc from ..services import registry_client, tokens as tokens_svc from ..services.backend import get_backend, validate_name @@ -100,10 +102,14 @@ def _err( code: str, message: str, trace_id: str | None = None, + extra: dict[str, Any] | None = None, ) -> web.Response: + payload = {"success": False, "trace_id": trace_id or _trace_id(), + "error": {"code": code, "message": message}} + if extra: + payload.update(extra) return web.json_response( - {"success": False, "trace_id": trace_id or _trace_id(), - "error": {"code": code, "message": message}}, + payload, status=status, ) @@ -645,26 +651,96 @@ async def device_invoke(request: web.Request) -> web.Response: if not isinstance(params, dict): return _err(status=400, code="invalid_params", message="params must be an object", trace_id=trace) + clean_params = mandates_svc.strip_dc_meta(params) + mandate = mandates_svc.extract_mandate(body, params) timeout = _clamp_timeout(body.get("timeout")) reason = _truncate(body.get("reason") or body.get("llm_reasoning") or "", 500) + device_doc = registry_client.get_device(full_name) + mandate_policy = mandates_svc.get_function_mandate_policy(device_doc, function) + mandate_required = bool(mandate_policy and mandate_policy.get("required")) + mandate_result = mandates_svc.verify_server_mandate( + device_doc=device_doc, + device_id=full_name, + function=function, + params=clean_params, + mandate=mandate, + ) + if not mandate_result.ok: + receipt = receipts_svc.build_receipt( + trace_id=trace, + tenant=tenant, + actor=request.get("token") or {}, + device_id=full_name, + function=function, + params=clean_params, + status="denied", + elapsed_ms=0, + error={"code": mandate_result.error_code, "message": mandate_result.message}, + mandate=mandate, + mandate_required=mandate_required, + mandate_verified=False, + mandate_error_code=mandate_result.error_code, + ) + _audit(request, "invoke_denied", trace_id=trace, device_id=full_name, + function=function, receipt_id=receipt["receipt_id"], + error=mandate_result.error_code) + return _err( + status=403, + code=mandate_result.error_code or "mandate_denied", + message=mandate_result.message or "mandate denied", + trace_id=trace, + extra={"receipt": receipt}, + ) + params_for_rpc = mandates_svc.attach_mandate(clean_params, params, mandate) backend = get_backend() started = time.monotonic() try: - result = await backend.rpc_invoke(tenant, full_name, function, params, timeout=timeout) + result = await backend.rpc_invoke(tenant, full_name, function, params_for_rpc, timeout=timeout) elapsed_ms = int((time.monotonic() - started) * 1000) + receipt = receipts_svc.build_receipt( + trace_id=trace, + tenant=tenant, + actor=request.get("token") or {}, + device_id=full_name, + function=function, + params=clean_params, + status="succeeded", + elapsed_ms=elapsed_ms, + response=result, + mandate=mandate, + mandate_required=mandate_required, + mandate_verified=bool(mandate), + ) _audit(request, "invoke", trace_id=trace, device_id=full_name, function=function, elapsed_ms=elapsed_ms, success=True, - reason=_truncate(reason, 120)) + reason=_truncate(reason, 120), receipt_id=receipt["receipt_id"]) return _ok({"device_id": full_name, "function": function, - "elapsed_ms": elapsed_ms, "response": result}, + "elapsed_ms": elapsed_ms, "response": result, + "receipt": receipt}, trace_id=trace) except Exception as e: elapsed_ms = int((time.monotonic() - started) * 1000) + receipt = receipts_svc.build_receipt( + trace_id=trace, + tenant=tenant, + actor=request.get("token") or {}, + device_id=full_name, + function=function, + params=clean_params, + status="failed", + elapsed_ms=elapsed_ms, + error={"code": "invoke_failed", "message": str(e)}, + mandate=mandate, + mandate_required=mandate_required, + mandate_verified=bool(mandate), + ) _audit(request, "invoke", trace_id=trace, device_id=full_name, function=function, elapsed_ms=elapsed_ms, success=False, - reason=_truncate(reason, 120), error=str(e)) - return _err(status=502, code="invoke_failed", message=str(e), trace_id=trace) + reason=_truncate(reason, 120), error=str(e), + receipt_id=receipt["receipt_id"]) + return _err(status=502, code="invoke_failed", message=str(e), trace_id=trace, + extra={"receipt": receipt}) async def invoke_with_fallback(request: web.Request) -> web.Response: @@ -695,35 +771,129 @@ async def invoke_with_fallback(request: web.Request) -> web.Response: return _err(status=400, code="missing_function", message="function is required", trace_id=trace) params = body.get("params") or {} + if not isinstance(params, dict): + return _err(status=400, code="invalid_params", message="params must be an object", + trace_id=trace) + clean_params = mandates_svc.strip_dc_meta(params) timeout = _clamp_timeout(body.get("timeout")) reason = _truncate(body.get("reason") or body.get("llm_reasoning") or "", 500) backend = get_backend() failures = [] + receipts = [] for idx, raw_id in enumerate(ids): full_name = _full_device_name(tenant, raw_id) + device_doc = registry_client.get_device(full_name) + mandate = _mandate_for_device(body, params, tenant, raw_id, full_name) + mandate_policy = mandates_svc.get_function_mandate_policy(device_doc, function) + mandate_required = bool(mandate_policy and mandate_policy.get("required")) + mandate_result = mandates_svc.verify_server_mandate( + device_doc=device_doc, + device_id=full_name, + function=function, + params=clean_params, + mandate=mandate, + ) + if not mandate_result.ok: + receipt = receipts_svc.build_receipt( + trace_id=trace, + tenant=tenant, + actor=request.get("token") or {}, + device_id=full_name, + function=function, + params=clean_params, + status="denied", + elapsed_ms=0, + error={"code": mandate_result.error_code, "message": mandate_result.message}, + mandate=mandate, + mandate_required=mandate_required, + mandate_verified=False, + mandate_error_code=mandate_result.error_code, + ) + receipts.append(receipt) + failures.append({ + "device_id": full_name, + "error": mandate_result.message, + "code": mandate_result.error_code, + "receipt": receipt, + }) + continue + params_for_rpc = mandates_svc.attach_mandate(clean_params, params, mandate) started = time.monotonic() try: - response = await backend.rpc_invoke(tenant, full_name, function, params, timeout=timeout) + response = await backend.rpc_invoke(tenant, full_name, function, params_for_rpc, timeout=timeout) elapsed_ms = int((time.monotonic() - started) * 1000) + receipt = receipts_svc.build_receipt( + trace_id=trace, + tenant=tenant, + actor=request.get("token") or {}, + device_id=full_name, + function=function, + params=clean_params, + status="succeeded", + elapsed_ms=elapsed_ms, + response=response, + mandate=mandate, + mandate_required=mandate_required, + mandate_verified=bool(mandate), + ) _audit(request, "invoke_fallback", trace_id=trace, device_id=full_name, function=function, elapsed_ms=elapsed_ms, success=True, - reason=_truncate(reason, 120)) + reason=_truncate(reason, 120), receipt_id=receipt["receipt_id"]) return _ok( {"device_id": full_name, "function": function, "elapsed_ms": elapsed_ms, "response": response, + "receipt": receipt, "tried": [{"device_id": _full_device_name(tenant, x), "ok": (i == idx)} for i, x in enumerate(ids[: idx + 1])], - "failures": failures}, + "failures": failures, "receipts": receipts + [receipt]}, trace_id=trace, ) except Exception as e: - failures.append({"device_id": full_name, "error": str(e)}) + elapsed_ms = int((time.monotonic() - started) * 1000) + receipt = receipts_svc.build_receipt( + trace_id=trace, + tenant=tenant, + actor=request.get("token") or {}, + device_id=full_name, + function=function, + params=clean_params, + status="failed", + elapsed_ms=elapsed_ms, + error={"code": "invoke_failed", "message": str(e)}, + mandate=mandate, + mandate_required=mandate_required, + mandate_verified=bool(mandate), + ) + receipts.append(receipt) + failures.append({"device_id": full_name, "error": str(e), "receipt": receipt}) _audit(request, "invoke_fallback", trace_id=trace, function=function, success=False, reason=_truncate(reason, 120)) - return _err(status=502, code="all_failed", - message="All fallback devices failed", trace_id=trace) + all_denied = bool(failures) and all(f.get("code", "").startswith("mandate_") or f.get("code") == "invalid_mandate" for f in failures) + return _err( + status=403 if all_denied else 502, + code="all_denied" if all_denied else "all_failed", + message="All fallback devices were denied" if all_denied else "All fallback devices failed", + trace_id=trace, + extra={"failures": failures, "receipts": receipts}, + ) + + +def _mandate_for_device( + body: dict[str, Any], + params: dict[str, Any], + tenant: str, + raw_id: str, + full_name: str, +) -> dict[str, Any] | None: + mandates = body.get("mandates") + if isinstance(mandates, dict): + for key in (full_name, raw_id, _full_device_name(tenant, raw_id)): + mandate = mandates.get(key) + if isinstance(mandate, dict): + return mandate + return mandates_svc.extract_mandate(body, params) # ── event streaming (bounded) ────────────────────────────────────── diff --git a/packages/device-connect-server/tests/device_connect_server/test_portal_agent_api.py b/packages/device-connect-server/tests/device_connect_server/test_portal_agent_api.py index 287d299..a6dcd61 100644 --- a/packages/device-connect-server/tests/device_connect_server/test_portal_agent_api.py +++ b/packages/device-connect-server/tests/device_connect_server/test_portal_agent_api.py @@ -11,6 +11,7 @@ from __future__ import annotations +from datetime import datetime, timedelta, timezone from unittest.mock import patch import pytest @@ -20,6 +21,7 @@ from device_connect_server.portal.app import auth_middleware from device_connect_server.portal.services import tokens as tokens_svc from device_connect_server.portal.views import agent_api +from device_connect_edge import create_closed_mandate, create_open_mandate # A registry doc with extra fields the API must surface untouched. @@ -58,6 +60,51 @@ "registry": {"registered_at": "2026-05-01T12:00:00+00:00"}, } +PROTECTED_LOCK = { + "device_id": "acme-lock-001", + "tenant": "acme", + "identity": {"device_type": "lock"}, + "status": {"online": True}, + "capabilities": { + "functions": [ + { + "name": "unlock", + "parameters": {"type": "object"}, + "mandate": {"required": True, "scope": "actuation"}, + }, + {"name": "get_status", "parameters": {"type": "object"}}, + ], + "events": [], + }, +} + +PRINCIPAL_KEY = "principal-secret" +AGENT_KEY = "agent-secret" + + +def _closed_mandate(device_id: str = "acme-lock-001", params: dict | None = None) -> dict: + now = datetime.now(timezone.utc) + params = params or {"duration_s": 30} + open_mandate = create_open_mandate( + principal="operator", + agent="agent-1", + device_id=device_id, + methods=["unlock"], + constraints={"duration_s": {"lte": 60}}, + not_before=now - timedelta(seconds=5), + not_after=now + timedelta(minutes=5), + key=PRINCIPAL_KEY, + ) + return create_closed_mandate( + open_mandate=open_mandate, + agent="agent-1", + device_id=device_id, + method="unlock", + params=params, + key=AGENT_KEY, + issued_at=now, + ) + @pytest.fixture def fake_record(): @@ -451,6 +498,121 @@ async def rpc_invoke(self, tenant, full_name, fn, params, timeout): assert seen["timeout"] == agent_api.MAX_INVOKE_TIMEOUT_S +class TestInvokeMandates: + async def test_protected_function_with_valid_mandate_returns_receipt( + self, invoke_client, monkeypatch, + ): + monkeypatch.setenv( + "DC_MANDATE_KEYS_JSON", + '{"operator":"principal-secret","agent-1":"agent-secret"}', + ) + seen = {} + + class _FakeBackend: + def backend_name(self): return "test" + async def rpc_invoke(self, tenant, full_name, fn, params, timeout): + seen["params"] = params + return {"ok": True} + + mandate = _closed_mandate() + with patch( + "device_connect_server.portal.views.agent_api.registry_client.get_device", + return_value=PROTECTED_LOCK, + ), patch( + "device_connect_server.portal.views.agent_api.get_backend", + return_value=_FakeBackend(), + ): + r = await invoke_client.post( + "/api/agent/v1/devices/lock-001/invoke", + headers=H(), + json={ + "function": "unlock", + "params": {"duration_s": 30}, + "mandate": mandate, + }, + ) + + assert r.status == 200 + body = await r.json() + receipt = body["result"]["receipt"] + assert receipt["status"] == "succeeded" + assert receipt["mandate"]["verified"] is True + assert receipt["mandate"]["principal"] == "operator" + assert seen["params"]["_dc_meta"]["mandate"] == mandate + + async def test_protected_function_without_mandate_returns_denial_receipt( + self, invoke_client, monkeypatch, + ): + monkeypatch.setenv( + "DC_MANDATE_KEYS_JSON", + '{"operator":"principal-secret","agent-1":"agent-secret"}', + ) + + class _FakeBackend: + def backend_name(self): return "test" + async def rpc_invoke(self, tenant, full_name, fn, params, timeout): + raise AssertionError("backend must not be called") + + with patch( + "device_connect_server.portal.views.agent_api.registry_client.get_device", + return_value=PROTECTED_LOCK, + ), patch( + "device_connect_server.portal.views.agent_api.get_backend", + return_value=_FakeBackend(), + ): + r = await invoke_client.post( + "/api/agent/v1/devices/lock-001/invoke", + headers=H(), + json={"function": "unlock", "params": {"duration_s": 30}}, + ) + + assert r.status == 403 + body = await r.json() + assert body["error"]["code"] == "mandate_required" + assert body["receipt"]["status"] == "denied" + assert body["receipt"]["mandate"]["required"] is True + + async def test_existing_dc_meta_is_preserved_when_mandate_is_attached( + self, invoke_client, monkeypatch, + ): + monkeypatch.setenv( + "DC_MANDATE_KEYS_JSON", + '{"operator":"principal-secret","agent-1":"agent-secret"}', + ) + seen = {} + + class _FakeBackend: + def backend_name(self): return "test" + async def rpc_invoke(self, tenant, full_name, fn, params, timeout): + seen["params"] = params + return {"ok": True} + + mandate = _closed_mandate() + with patch( + "device_connect_server.portal.views.agent_api.registry_client.get_device", + return_value=PROTECTED_LOCK, + ), patch( + "device_connect_server.portal.views.agent_api.get_backend", + return_value=_FakeBackend(), + ): + r = await invoke_client.post( + "/api/agent/v1/devices/lock-001/invoke", + headers=H(), + json={ + "function": "unlock", + "params": { + "duration_s": 30, + "_dc_meta": {"traceparent": "trace"}, + }, + "mandate": mandate, + }, + ) + + assert r.status == 200 + assert seen["params"]["_dc_meta"]["traceparent"] == "trace" + assert seen["params"]["_dc_meta"]["mandate"] == mandate + + # ── invoke-with-fallback duplicate device id (regression) ───────── diff --git a/packages/device-connect-server/tests/device_connect_server/test_portal_mandates.py b/packages/device-connect-server/tests/device_connect_server/test_portal_mandates.py new file mode 100644 index 0000000..ac705a8 --- /dev/null +++ b/packages/device-connect-server/tests/device_connect_server/test_portal_mandates.py @@ -0,0 +1,137 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for portal mandate and receipt helpers.""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + +from device_connect_edge import create_closed_mandate, create_open_mandate +from device_connect_server.portal.services import execution_receipts, mandates + + +PRINCIPAL_KEY = "principal-secret" +AGENT_KEY = "agent-secret" + + +DEVICE_DOC = { + "device_id": "acme-lock-001", + "capabilities": { + "functions": [ + { + "name": "unlock", + "mandate": {"required": True, "scope": "actuation"}, + }, + {"name": "get_status"}, + ] + }, +} + + +def _mandate(params: dict | None = None) -> dict: + now = datetime.now(timezone.utc) + params = params or {"duration_s": 30} + open_mandate = create_open_mandate( + principal="operator", + agent="agent-1", + device_id="acme-lock-001", + methods=["unlock"], + constraints={"duration_s": {"lte": 60}}, + not_before=now - timedelta(seconds=5), + not_after=now + timedelta(minutes=5), + key=PRINCIPAL_KEY, + ) + return create_closed_mandate( + open_mandate=open_mandate, + agent="agent-1", + device_id="acme-lock-001", + method="unlock", + params=params, + key=AGENT_KEY, + issued_at=now, + ) + + +def test_get_function_mandate_policy(): + assert mandates.get_function_mandate_policy(DEVICE_DOC, "unlock") == { + "required": True, + "scope": "actuation", + } + assert mandates.get_function_mandate_policy(DEVICE_DOC, "get_status") is None + + +def test_extract_and_attach_mandate_preserves_existing_meta(): + mandate = _mandate() + params = {"duration_s": 30, "_dc_meta": {"traceparent": "trace"}} + + assert mandates.extract_mandate({"mandate": mandate}, params) == mandate + attached = mandates.attach_mandate( + mandates.strip_dc_meta(params), params, mandate, + ) + + assert attached["duration_s"] == 30 + assert attached["_dc_meta"]["traceparent"] == "trace" + assert attached["_dc_meta"]["mandate"] == mandate + + +def test_verify_server_mandate_validates_protected_function(monkeypatch): + monkeypatch.setenv( + "DC_MANDATE_KEYS_JSON", + '{"operator":"principal-secret","agent-1":"agent-secret"}', + ) + mandates._SERVER_MANDATE_REPLAY_CACHE.clear() + + result = mandates.verify_server_mandate( + device_doc=DEVICE_DOC, + device_id="acme-lock-001", + function="unlock", + params={"duration_s": 30}, + mandate=_mandate(), + ) + + assert result.ok is True + + +def test_verify_server_mandate_denies_missing_mandate(monkeypatch): + monkeypatch.setenv( + "DC_MANDATE_KEYS_JSON", + '{"operator":"principal-secret","agent-1":"agent-secret"}', + ) + + result = mandates.verify_server_mandate( + device_doc=DEVICE_DOC, + device_id="acme-lock-001", + function="unlock", + params={"duration_s": 30}, + mandate=None, + ) + + assert result.ok is False + assert result.error_code == "mandate_required" + + +def test_execution_receipt_hashes_payload_and_can_sign(monkeypatch): + monkeypatch.setenv("DC_RECEIPT_SIGNING_KEY", "receipt-secret") + + receipt = execution_receipts.build_receipt( + trace_id="trace-1", + tenant="acme", + actor={"token_id": "tok-1", "username": "alice"}, + device_id="acme-lock-001", + function="unlock", + params={"duration_s": 30}, + status="succeeded", + elapsed_ms=12, + response={"ok": True}, + mandate=_mandate(), + mandate_required=True, + mandate_verified=True, + ) + + assert receipt["receipt_id"].startswith("rcpt-") + assert receipt["params_sha256"] + assert receipt["response_sha256"] + assert receipt["signature"] + assert receipt["mandate"]["verified"] is True From 7b0e749d71399b54c5c5ecf01e2f2e637d0b6122 Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Mon, 11 May 2026 08:43:04 -0700 Subject: [PATCH 19/21] feat: add mandate receipt query log --- .../portal/services/execution_receipts.py | 37 +++++++++ .../portal/views/agent_api.py | 76 +++++++++++++++---- .../test_portal_mandates.py | 32 ++++++++ 3 files changed, 132 insertions(+), 13 deletions(-) diff --git a/packages/device-connect-server/device_connect_server/portal/services/execution_receipts.py b/packages/device-connect-server/device_connect_server/portal/services/execution_receipts.py index 59583ed..3449ef6 100644 --- a/packages/device-connect-server/device_connect_server/portal/services/execution_receipts.py +++ b/packages/device-connect-server/device_connect_server/portal/services/execution_receipts.py @@ -14,6 +14,9 @@ from datetime import datetime, timezone from typing import Any +_RECEIPTS: list[dict[str, Any]] = [] +_MAX_RECEIPTS = 1000 + def build_receipt( *, @@ -60,6 +63,40 @@ def build_receipt( return receipt +def record_receipt(receipt: dict[str, Any]) -> dict[str, Any]: + """Append a receipt to the process-local audit log.""" + _RECEIPTS.append(dict(receipt)) + if len(_RECEIPTS) > _MAX_RECEIPTS: + del _RECEIPTS[: len(_RECEIPTS) - _MAX_RECEIPTS] + return receipt + + +def get_receipt(receipt_id: str) -> dict[str, Any] | None: + for receipt in reversed(_RECEIPTS): + if receipt.get("receipt_id") == receipt_id: + return dict(receipt) + return None + + +def list_receipts( + *, + tenant: str | None = None, + device_id: str | None = None, + limit: int = 100, +) -> list[dict[str, Any]]: + safe_limit = max(1, min(int(limit or 100), 1000)) + out = [] + for receipt in reversed(_RECEIPTS): + if tenant is not None and receipt.get("tenant") != tenant: + continue + if device_id is not None and receipt.get("device_id") != device_id: + continue + out.append(dict(receipt)) + if len(out) >= safe_limit: + break + return out + + def hash_json(value: Any) -> str: payload = json.dumps( value, sort_keys=True, separators=(",", ":"), ensure_ascii=True, diff --git a/packages/device-connect-server/device_connect_server/portal/views/agent_api.py b/packages/device-connect-server/device_connect_server/portal/views/agent_api.py index 903af8f..ca47f8e 100644 --- a/packages/device-connect-server/device_connect_server/portal/views/agent_api.py +++ b/packages/device-connect-server/device_connect_server/portal/views/agent_api.py @@ -76,6 +76,8 @@ def setup_routes(app: web.Application): r.add_post(PREFIX + "/devices/{device_id}/credentials:rotate", device_credentials_rotate) r.add_post(PREFIX + "/devices/{device_id}/invoke", device_invoke) r.add_post(PREFIX + "/invoke-with-fallback", invoke_with_fallback) + r.add_get(PREFIX + "/receipts", receipts_list) + r.add_get(PREFIX + "/receipts/{receipt_id}", receipt_get) r.add_get( PREFIX + "/devices/{device_id}/events/{event_name}/stream", device_event_stream, @@ -666,7 +668,7 @@ async def device_invoke(request: web.Request) -> web.Response: mandate=mandate, ) if not mandate_result.ok: - receipt = receipts_svc.build_receipt( + receipt = receipts_svc.record_receipt(receipts_svc.build_receipt( trace_id=trace, tenant=tenant, actor=request.get("token") or {}, @@ -680,7 +682,7 @@ async def device_invoke(request: web.Request) -> web.Response: mandate_required=mandate_required, mandate_verified=False, mandate_error_code=mandate_result.error_code, - ) + )) _audit(request, "invoke_denied", trace_id=trace, device_id=full_name, function=function, receipt_id=receipt["receipt_id"], error=mandate_result.error_code) @@ -698,7 +700,7 @@ async def device_invoke(request: web.Request) -> web.Response: try: result = await backend.rpc_invoke(tenant, full_name, function, params_for_rpc, timeout=timeout) elapsed_ms = int((time.monotonic() - started) * 1000) - receipt = receipts_svc.build_receipt( + receipt = receipts_svc.record_receipt(receipts_svc.build_receipt( trace_id=trace, tenant=tenant, actor=request.get("token") or {}, @@ -711,7 +713,7 @@ async def device_invoke(request: web.Request) -> web.Response: mandate=mandate, mandate_required=mandate_required, mandate_verified=bool(mandate), - ) + )) _audit(request, "invoke", trace_id=trace, device_id=full_name, function=function, elapsed_ms=elapsed_ms, success=True, reason=_truncate(reason, 120), receipt_id=receipt["receipt_id"]) @@ -721,7 +723,7 @@ async def device_invoke(request: web.Request) -> web.Response: trace_id=trace) except Exception as e: elapsed_ms = int((time.monotonic() - started) * 1000) - receipt = receipts_svc.build_receipt( + receipt = receipts_svc.record_receipt(receipts_svc.build_receipt( trace_id=trace, tenant=tenant, actor=request.get("token") or {}, @@ -734,7 +736,7 @@ async def device_invoke(request: web.Request) -> web.Response: mandate=mandate, mandate_required=mandate_required, mandate_verified=bool(mandate), - ) + )) _audit(request, "invoke", trace_id=trace, device_id=full_name, function=function, elapsed_ms=elapsed_ms, success=False, reason=_truncate(reason, 120), error=str(e), @@ -795,7 +797,7 @@ async def invoke_with_fallback(request: web.Request) -> web.Response: mandate=mandate, ) if not mandate_result.ok: - receipt = receipts_svc.build_receipt( + receipt = receipts_svc.record_receipt(receipts_svc.build_receipt( trace_id=trace, tenant=tenant, actor=request.get("token") or {}, @@ -809,7 +811,7 @@ async def invoke_with_fallback(request: web.Request) -> web.Response: mandate_required=mandate_required, mandate_verified=False, mandate_error_code=mandate_result.error_code, - ) + )) receipts.append(receipt) failures.append({ "device_id": full_name, @@ -823,7 +825,7 @@ async def invoke_with_fallback(request: web.Request) -> web.Response: try: response = await backend.rpc_invoke(tenant, full_name, function, params_for_rpc, timeout=timeout) elapsed_ms = int((time.monotonic() - started) * 1000) - receipt = receipts_svc.build_receipt( + receipt = receipts_svc.record_receipt(receipts_svc.build_receipt( trace_id=trace, tenant=tenant, actor=request.get("token") or {}, @@ -836,7 +838,7 @@ async def invoke_with_fallback(request: web.Request) -> web.Response: mandate=mandate, mandate_required=mandate_required, mandate_verified=bool(mandate), - ) + )) _audit(request, "invoke_fallback", trace_id=trace, device_id=full_name, function=function, elapsed_ms=elapsed_ms, success=True, reason=_truncate(reason, 120), receipt_id=receipt["receipt_id"]) @@ -851,7 +853,7 @@ async def invoke_with_fallback(request: web.Request) -> web.Response: ) except Exception as e: elapsed_ms = int((time.monotonic() - started) * 1000) - receipt = receipts_svc.build_receipt( + receipt = receipts_svc.record_receipt(receipts_svc.build_receipt( trace_id=trace, tenant=tenant, actor=request.get("token") or {}, @@ -864,13 +866,13 @@ async def invoke_with_fallback(request: web.Request) -> web.Response: mandate=mandate, mandate_required=mandate_required, mandate_verified=bool(mandate), - ) + )) receipts.append(receipt) failures.append({"device_id": full_name, "error": str(e), "receipt": receipt}) _audit(request, "invoke_fallback", trace_id=trace, function=function, success=False, reason=_truncate(reason, 120)) - all_denied = bool(failures) and all(f.get("code", "").startswith("mandate_") or f.get("code") == "invalid_mandate" for f in failures) + all_denied = bool(failures) and all(_is_mandate_denial(f.get("code")) for f in failures) return _err( status=403 if all_denied else 502, code="all_denied" if all_denied else "all_failed", @@ -896,6 +898,54 @@ def _mandate_for_device( return mandates_svc.extract_mandate(body, params) +def _is_mandate_denial(code: Any) -> bool: + if not isinstance(code, str): + return False + return ( + code.startswith("mandate_") + or code in {"invalid_mandate", "invalid_mandate_signature", "unknown_mandate_key"} + ) + + +# ── execution receipts ───────────────────────────────────────────── + + +async def receipts_list(request: web.Request) -> web.Response: + trace = _trace_id() + _, err = _require_scope(request, "devices:read") + if err: + return err + tenant, err = _resolve_tenant(request) + if err: + return err + device_id = request.query.get("device_id") + full_device_id = _full_device_name(tenant, device_id) if device_id else None + try: + limit = int(request.query.get("limit", "100")) + except ValueError: + limit = 100 + receipts = receipts_svc.list_receipts( + tenant=tenant, + device_id=full_device_id, + limit=limit, + ) + return _ok({"receipts": receipts, "returned": len(receipts)}, trace_id=trace) + + +async def receipt_get(request: web.Request) -> web.Response: + trace = _trace_id() + _, err = _require_scope(request, "devices:read") + if err: + return err + tenant, err = _resolve_tenant(request) + if err: + return err + receipt = receipts_svc.get_receipt(request.match_info["receipt_id"]) + if receipt is None or receipt.get("tenant") != tenant: + return _err(status=404, code="not_found", message="Receipt not found", trace_id=trace) + return _ok({"receipt": receipt}, trace_id=trace) + + # ── event streaming (bounded) ────────────────────────────────────── diff --git a/packages/device-connect-server/tests/device_connect_server/test_portal_mandates.py b/packages/device-connect-server/tests/device_connect_server/test_portal_mandates.py index ac705a8..bd93275 100644 --- a/packages/device-connect-server/tests/device_connect_server/test_portal_mandates.py +++ b/packages/device-connect-server/tests/device_connect_server/test_portal_mandates.py @@ -135,3 +135,35 @@ def test_execution_receipt_hashes_payload_and_can_sign(monkeypatch): assert receipt["response_sha256"] assert receipt["signature"] assert receipt["mandate"]["verified"] is True + + +def test_execution_receipt_log_lists_latest_by_tenant_device_and_limit(): + execution_receipts._RECEIPTS.clear() + + first = execution_receipts.record_receipt({ + "receipt_id": "rcpt-1", + "tenant": "acme", + "device_id": "acme-lock-001", + "status": "succeeded", + }) + second = execution_receipts.record_receipt({ + "receipt_id": "rcpt-2", + "tenant": "acme", + "device_id": "acme-heater-001", + "status": "denied", + }) + execution_receipts.record_receipt({ + "receipt_id": "rcpt-3", + "tenant": "other", + "device_id": "other-lock-001", + "status": "succeeded", + }) + + assert execution_receipts.get_receipt("rcpt-1") == first + assert execution_receipts.get_receipt("missing") is None + assert execution_receipts.list_receipts(tenant="acme") == [second, first] + assert execution_receipts.list_receipts( + tenant="acme", + device_id="acme-lock-001", + ) == [first] + assert execution_receipts.list_receipts(tenant="acme", limit=1) == [second] From 93e5591050197cb44b1147664eb16fe37984097e Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Mon, 11 May 2026 08:48:03 -0700 Subject: [PATCH 20/21] fix: use tenant scoped registry lookup for mandates --- .../portal/views/agent_api.py | 4 +-- .../test_portal_agent_api.py | 27 ++++++++++++++++--- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/packages/device-connect-server/device_connect_server/portal/views/agent_api.py b/packages/device-connect-server/device_connect_server/portal/views/agent_api.py index ca47f8e..fbd2e04 100644 --- a/packages/device-connect-server/device_connect_server/portal/views/agent_api.py +++ b/packages/device-connect-server/device_connect_server/portal/views/agent_api.py @@ -657,7 +657,7 @@ async def device_invoke(request: web.Request) -> web.Response: mandate = mandates_svc.extract_mandate(body, params) timeout = _clamp_timeout(body.get("timeout")) reason = _truncate(body.get("reason") or body.get("llm_reasoning") or "", 500) - device_doc = registry_client.get_device(full_name) + device_doc = _device_doc(tenant, device_id) mandate_policy = mandates_svc.get_function_mandate_policy(device_doc, function) mandate_required = bool(mandate_policy and mandate_policy.get("required")) mandate_result = mandates_svc.verify_server_mandate( @@ -785,7 +785,7 @@ async def invoke_with_fallback(request: web.Request) -> web.Response: receipts = [] for idx, raw_id in enumerate(ids): full_name = _full_device_name(tenant, raw_id) - device_doc = registry_client.get_device(full_name) + device_doc = _device_doc(tenant, raw_id) mandate = _mandate_for_device(body, params, tenant, raw_id, full_name) mandate_policy = mandates_svc.get_function_mandate_policy(device_doc, function) mandate_required = bool(mandate_policy and mandate_policy.get("required")) diff --git a/packages/device-connect-server/tests/device_connect_server/test_portal_agent_api.py b/packages/device-connect-server/tests/device_connect_server/test_portal_agent_api.py index a6dcd61..15d6119 100644 --- a/packages/device-connect-server/tests/device_connect_server/test_portal_agent_api.py +++ b/packages/device-connect-server/tests/device_connect_server/test_portal_agent_api.py @@ -486,6 +486,9 @@ async def rpc_invoke(self, tenant, full_name, fn, params, timeout): return {"ok": True} with patch( + "device_connect_server.portal.views.agent_api.registry_client.get_device", + return_value=None, + ), patch( "device_connect_server.portal.views.agent_api.get_backend", return_value=_FakeBackend(), ): @@ -515,9 +518,14 @@ async def rpc_invoke(self, tenant, full_name, fn, params, timeout): return {"ok": True} mandate = _closed_mandate() + def _lookup_device(tenant, device_id): + assert tenant == "acme" + assert device_id == "acme-lock-001" + return PROTECTED_LOCK + with patch( "device_connect_server.portal.views.agent_api.registry_client.get_device", - return_value=PROTECTED_LOCK, + side_effect=_lookup_device, ), patch( "device_connect_server.portal.views.agent_api.get_backend", return_value=_FakeBackend(), @@ -553,9 +561,14 @@ def backend_name(self): return "test" async def rpc_invoke(self, tenant, full_name, fn, params, timeout): raise AssertionError("backend must not be called") + def _lookup_device(tenant, device_id): + assert tenant == "acme" + assert device_id == "acme-lock-001" + return PROTECTED_LOCK + with patch( "device_connect_server.portal.views.agent_api.registry_client.get_device", - return_value=PROTECTED_LOCK, + side_effect=_lookup_device, ), patch( "device_connect_server.portal.views.agent_api.get_backend", return_value=_FakeBackend(), @@ -588,9 +601,14 @@ async def rpc_invoke(self, tenant, full_name, fn, params, timeout): return {"ok": True} mandate = _closed_mandate() + def _lookup_device(tenant, device_id): + assert tenant == "acme" + assert device_id == "acme-lock-001" + return PROTECTED_LOCK + with patch( "device_connect_server.portal.views.agent_api.registry_client.get_device", - return_value=PROTECTED_LOCK, + side_effect=_lookup_device, ), patch( "device_connect_server.portal.views.agent_api.get_backend", return_value=_FakeBackend(), @@ -631,6 +649,9 @@ async def rpc_invoke(self, tenant, full_name, fn, params, timeout): return {"ok": True, "attempt": len(attempts)} with patch( + "device_connect_server.portal.views.agent_api.registry_client.get_device", + return_value=None, + ), patch( "device_connect_server.portal.views.agent_api.get_backend", return_value=_FakeBackend(), ): From 38a08f85ca5ea0113043ab042efdbbab63cd6619 Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Mon, 11 May 2026 08:51:00 -0700 Subject: [PATCH 21/21] fix: expose mandate metadata from loaded capabilities --- .../drivers/capability_loader.py | 3 +++ .../tests/test_capability_loader.py | 25 +++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/packages/device-connect-edge/device_connect_edge/drivers/capability_loader.py b/packages/device-connect-edge/device_connect_edge/drivers/capability_loader.py index 824add6..87102e4 100644 --- a/packages/device-connect-edge/device_connect_edge/drivers/capability_loader.py +++ b/packages/device-connect-edge/device_connect_edge/drivers/capability_loader.py @@ -512,6 +512,9 @@ def _register_functions(self, loaded: LoadedCapability) -> None: "description": description, "parameters": parameters, } + mandate = getattr(attr, "_mandate", None) + if mandate is not None: + loaded.function_schemas[func_name]["mandate"] = mandate # Register with namespace prefix self._functions[f"{cap_id}.{func_name}"] = attr diff --git a/packages/device-connect-edge/tests/test_capability_loader.py b/packages/device-connect-edge/tests/test_capability_loader.py index 0f7d22c..d962c70 100644 --- a/packages/device-connect-edge/tests/test_capability_loader.py +++ b/packages/device-connect-edge/tests/test_capability_loader.py @@ -62,6 +62,20 @@ async def custom(self, value: str = "default") -> dict: return {"value": value} """ +MANDATE_CAPABILITY_CODE = """\ +from device_connect_edge.drivers.decorators import requires_mandate, rpc + +class MandateCapability: + def __init__(self, device=None): + self.device = device + + @requires_mandate(scope="actuation") + @rpc() + async def unlock(self, duration_s: int) -> dict: + \"\"\"Unlock with delegated authorization.\"\"\" + return {"unlocked": True} +""" + EMIT_CAPABILITY_CODE = """\ from device_connect_edge.drivers.decorators import rpc, emit @@ -206,6 +220,17 @@ async def test_function_schemas_populated(self, loader, tmp_path): assert "parameters" in schema assert "description" in schema + @pytest.mark.asyncio + async def test_function_schemas_include_mandate_metadata(self, loader, tmp_path): + _write_capability(tmp_path, "mandate-cap", "MandateCapability", MANDATE_CAPABILITY_CODE) + await loader.load_all() + + loaded = loader.get_capabilities()["mandate-cap"] + assert loaded.function_schemas["unlock"]["mandate"] == { + "required": True, + "scope": "actuation", + } + # -- Extracting @emit methods --