Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,8 @@ htmlcov/
atheris-report.md
fuzz-hypothesis-results.xml
**/fuzz/findings/

# Coverage artifacts
.coverage
.coverage.*
coverage.xml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ async def motion_detected(self, zone: str, confidence: float):
periodic,
build_function_schema,
build_event_schema,
get_rpc_source_device,
)
from device_connect_edge.drivers.transport import DriverTransport
from device_connect_edge.drivers.capability_loader import (
Expand All @@ -56,6 +57,7 @@ async def motion_detected(self, zone: str, confidence: float):
"before_emit",
"periodic",
"on",
"get_rpc_source_device",
"build_function_schema",
"build_event_schema",
# Capability loading
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,24 @@ async def detection_loop(self):
# - "internal": Called from another method on the same device
_call_origin: contextvars.ContextVar[str] = contextvars.ContextVar('call_origin', default='external')

# Authenticated source device of the in-flight RPC, exposed so driver
# handlers can perform per-call authorization without changing their
# signatures. Set by the @rpc wrapper for the duration of the handler
# call; None when the call did not carry a source identity (e.g. local
# routine calls).
_rpc_source_device: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
'rpc_source_device', default=None
)


def get_rpc_source_device() -> Optional[str]:
"""Return the authenticated source device id of the current RPC, if any.

Returns None when invoked outside an external RPC (e.g. local routine or
internal call) or when the transport did not supply a source identity.
"""
return _rpc_source_device.get()


class routine_context:
"""Context manager to mark calls as coming from a routine.
Expand Down Expand Up @@ -402,6 +420,10 @@ async def wrapper(self, *args, **kwargs):
# Extract source_device from kwargs (injected by DeviceRuntime)
source_device = kwargs.pop("source_device", None)

# Expose the authenticated caller to the handler via a contextvar so
# drivers can perform per-call authorization checks.
_src_token = _rpc_source_device.set(source_device)

# Build args summary (source_device already removed)
args_summary = _summarize_args(args, kwargs)

Expand Down Expand Up @@ -499,6 +521,7 @@ async def wrapper(self, *args, **kwargs):
span.set_status(StatusCode.ERROR, str(e))
raise
finally:
_rpc_source_device.reset(_src_token)
duration_ms = (time.monotonic() - t0) * 1000
metric_attrs = {"rpc.method": func_name, "device_connect.device.id": device_id, "status": status}
metrics.rpc_duration.record(duration_ms, metric_attrs)
Expand Down
70 changes: 70 additions & 0 deletions packages/device-connect-edge/tests/test_rpc_source_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

"""Tests for the @rpc caller-identity hook (get_rpc_source_device).

Security hardening: device drivers need to know the authenticated source
device of the in-flight RPC so they can perform per-call authorization. The
@rpc wrapper exposes it via a contextvar for the duration of the handler.
"""

import pytest

from device_connect_edge.drivers import DeviceDriver, rpc, get_rpc_source_device
from device_connect_edge.types import DeviceIdentity, DeviceStatus


class _CallerAwareDriver(DeviceDriver):
device_type = "caller_aware"

@property
def identity(self) -> DeviceIdentity:
return DeviceIdentity(device_type="caller_aware", manufacturer="Test", model="X")

@property
def status(self) -> DeviceStatus:
return DeviceStatus()

@rpc()
async def whoami(self) -> dict:
"""Return the authenticated caller as seen inside the handler."""
return {"caller": get_rpc_source_device()}

async def connect(self) -> None:
pass

async def disconnect(self) -> None:
pass


@pytest.mark.asyncio
async def test_source_device_visible_inside_handler():
d = _CallerAwareDriver()
res = await d.whoami(source_device="controller-1")
assert res["caller"] == "controller-1"


@pytest.mark.asyncio
async def test_source_device_none_when_absent():
d = _CallerAwareDriver()
res = await d.whoami()
assert res["caller"] is None


@pytest.mark.asyncio
async def test_source_device_reset_after_call():
d = _CallerAwareDriver()
await d.whoami(source_device="controller-1")
# Outside the handler the contextvar must be back to its default.
assert get_rpc_source_device() is None


@pytest.mark.asyncio
async def test_source_device_not_leaked_into_handler_kwargs():
# source_device must be consumed by the wrapper, not passed to the
# handler body (which does not declare it).
d = _CallerAwareDriver()
# whoami takes no params; passing source_device must not raise.
res = await d.whoami(source_device="x")
assert res["caller"] == "x"
Loading