diff --git a/.gitignore b/.gitignore index 2ca8ddc..7bc9691 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,8 @@ htmlcov/ atheris-report.md fuzz-hypothesis-results.xml **/fuzz/findings/ + +# Coverage artifacts +.coverage +.coverage.* +coverage.xml diff --git a/packages/device-connect-edge/device_connect_edge/drivers/__init__.py b/packages/device-connect-edge/device_connect_edge/drivers/__init__.py index 3abdf1d..8a95e7e 100644 --- a/packages/device-connect-edge/device_connect_edge/drivers/__init__.py +++ b/packages/device-connect-edge/device_connect_edge/drivers/__init__.py @@ -39,6 +39,7 @@ async def motion_detected(self, zone: str, confidence: float): periodic, build_function_schema, build_event_schema, + get_rpc_source_device, ) from device_connect_edge.drivers.transport import DriverTransport from device_connect_edge.drivers.capability_loader import ( @@ -56,6 +57,7 @@ async def motion_detected(self, zone: str, confidence: float): "before_emit", "periodic", "on", + "get_rpc_source_device", "build_function_schema", "build_event_schema", # Capability loading diff --git a/packages/device-connect-edge/device_connect_edge/drivers/decorators.py b/packages/device-connect-edge/device_connect_edge/drivers/decorators.py index 4237699..dfa6935 100644 --- a/packages/device-connect-edge/device_connect_edge/drivers/decorators.py +++ b/packages/device-connect-edge/device_connect_edge/drivers/decorators.py @@ -72,6 +72,24 @@ async def detection_loop(self): # - "internal": Called from another method on the same device _call_origin: contextvars.ContextVar[str] = contextvars.ContextVar('call_origin', default='external') +# Authenticated source device of the in-flight RPC, exposed so driver +# handlers can perform per-call authorization without changing their +# signatures. Set by the @rpc wrapper for the duration of the handler +# call; None when the call did not carry a source identity (e.g. local +# routine calls). +_rpc_source_device: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar( + 'rpc_source_device', default=None +) + + +def get_rpc_source_device() -> Optional[str]: + """Return the authenticated source device id of the current RPC, if any. + + Returns None when invoked outside an external RPC (e.g. local routine or + internal call) or when the transport did not supply a source identity. + """ + return _rpc_source_device.get() + class routine_context: """Context manager to mark calls as coming from a routine. @@ -402,6 +420,10 @@ async def wrapper(self, *args, **kwargs): # Extract source_device from kwargs (injected by DeviceRuntime) source_device = kwargs.pop("source_device", None) + # Expose the authenticated caller to the handler via a contextvar so + # drivers can perform per-call authorization checks. + _src_token = _rpc_source_device.set(source_device) + # Build args summary (source_device already removed) args_summary = _summarize_args(args, kwargs) @@ -499,6 +521,7 @@ async def wrapper(self, *args, **kwargs): span.set_status(StatusCode.ERROR, str(e)) raise finally: + _rpc_source_device.reset(_src_token) duration_ms = (time.monotonic() - t0) * 1000 metric_attrs = {"rpc.method": func_name, "device_connect.device.id": device_id, "status": status} metrics.rpc_duration.record(duration_ms, metric_attrs) diff --git a/packages/device-connect-edge/tests/test_rpc_source_device.py b/packages/device-connect-edge/tests/test_rpc_source_device.py new file mode 100644 index 0000000..39681eb --- /dev/null +++ b/packages/device-connect-edge/tests/test_rpc_source_device.py @@ -0,0 +1,70 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the @rpc caller-identity hook (get_rpc_source_device). + +Security hardening: device drivers need to know the authenticated source +device of the in-flight RPC so they can perform per-call authorization. The +@rpc wrapper exposes it via a contextvar for the duration of the handler. +""" + +import pytest + +from device_connect_edge.drivers import DeviceDriver, rpc, get_rpc_source_device +from device_connect_edge.types import DeviceIdentity, DeviceStatus + + +class _CallerAwareDriver(DeviceDriver): + device_type = "caller_aware" + + @property + def identity(self) -> DeviceIdentity: + return DeviceIdentity(device_type="caller_aware", manufacturer="Test", model="X") + + @property + def status(self) -> DeviceStatus: + return DeviceStatus() + + @rpc() + async def whoami(self) -> dict: + """Return the authenticated caller as seen inside the handler.""" + return {"caller": get_rpc_source_device()} + + async def connect(self) -> None: + pass + + async def disconnect(self) -> None: + pass + + +@pytest.mark.asyncio +async def test_source_device_visible_inside_handler(): + d = _CallerAwareDriver() + res = await d.whoami(source_device="controller-1") + assert res["caller"] == "controller-1" + + +@pytest.mark.asyncio +async def test_source_device_none_when_absent(): + d = _CallerAwareDriver() + res = await d.whoami() + assert res["caller"] is None + + +@pytest.mark.asyncio +async def test_source_device_reset_after_call(): + d = _CallerAwareDriver() + await d.whoami(source_device="controller-1") + # Outside the handler the contextvar must be back to its default. + assert get_rpc_source_device() is None + + +@pytest.mark.asyncio +async def test_source_device_not_leaked_into_handler_kwargs(): + # source_device must be consumed by the wrapper, not passed to the + # handler body (which does not declare it). + d = _CallerAwareDriver() + # whoami takes no params; passing source_device must not raise. + res = await d.whoami(source_device="x") + assert res["caller"] == "x"