diff --git a/ucapi/api.py b/ucapi/api.py index 5e30941..d0239cd 100644 --- a/ucapi/api.py +++ b/ucapi/api.py @@ -6,6 +6,7 @@ """ import asyncio +import inspect import json import logging import os @@ -13,6 +14,7 @@ from asyncio import AbstractEventLoop from copy import deepcopy from dataclasses import asdict, dataclass +from functools import wraps from typing import Any, Callable import websockets @@ -212,7 +214,7 @@ async def _handle_ws(self, websocket) -> None: # authenticate on connection await self._authenticate(websocket, True) - self._events.emit(uc.Events.CLIENT_CONNECTED) + self._events.emit(uc.Events.CLIENT_CONNECTED, websocket=websocket) async for message in websocket: # Distinguish between text (str) and binary (bytes-like) messages @@ -264,7 +266,7 @@ async def _handle_ws(self, websocket) -> None: self._clients.remove(websocket) _LOG.info("[%s] WS: Client removed", websocket.remote_address) - self._events.emit(uc.Events.CLIENT_DISCONNECTED) + self._events.emit(uc.Events.CLIENT_DISCONNECTED, websocket=websocket) async def _send_ok_result( self, websocket, req_id: int, msg_data: dict[str, Any] | list | None = None @@ -411,7 +413,7 @@ async def _process_ws_message(self, websocket, message) -> None: else: await self._handle_ws_request_msg(websocket, msg, req_id, msg_data) elif kind == "event": - await self._handle_ws_event_msg(msg, msg_data) + await self._handle_ws_event_msg(websocket, msg, msg_data) async def _process_ws_binary_message(self, websocket, data: bytes) -> None: """Process a binary WebSocket message using protobuf IntegrationMessage. @@ -710,10 +712,10 @@ async def _handle_ws_request_msg( elif msg == uc.WsMessages.ENTITY_COMMAND: await self._entity_command(websocket, req_id, msg_data) elif msg == uc.WsMessages.SUBSCRIBE_EVENTS: - await self._subscribe_events(msg_data) + await self._subscribe_events(websocket, msg_data) await self._send_ok_result(websocket, req_id) elif msg == uc.WsMessages.UNSUBSCRIBE_EVENTS: - await self._unsubscribe_events(msg_data) + await self._unsubscribe_events(websocket, msg_data) await self._send_ok_result(websocket, req_id) elif msg == uc.WsMessages.GET_DRIVER_METADATA: await self._send_ws_response( @@ -730,16 +732,16 @@ async def _handle_ws_request_msg( await self.driver_setup_error(websocket) async def _handle_ws_event_msg( - self, msg: str, msg_data: dict[str, Any] | None + self, websocket: Any, msg: str, msg_data: dict[str, Any] | None ) -> None: if msg == uc.WsMsgEvents.CONNECT: - self._events.emit(uc.Events.CONNECT) + self._events.emit(uc.Events.CONNECT, websocket=websocket) elif msg == uc.WsMsgEvents.DISCONNECT: - self._events.emit(uc.Events.DISCONNECT) + self._events.emit(uc.Events.DISCONNECT, websocket=websocket) elif msg == uc.WsMsgEvents.ENTER_STANDBY: - self._events.emit(uc.Events.ENTER_STANDBY) + self._events.emit(uc.Events.ENTER_STANDBY, websocket=websocket) elif msg == uc.WsMsgEvents.EXIT_STANDBY: - self._events.emit(uc.Events.EXIT_STANDBY) + self._events.emit(uc.Events.EXIT_STANDBY, websocket=websocket) elif msg == uc.WsMsgEvents.ABORT_DRIVER_SETUP: if not self._setup_handler: _LOG.warning( @@ -792,7 +794,9 @@ async def set_device_state(self, state: uc.DeviceStates) -> None: uc.EventCategory.DEVICE, ) - async def _subscribe_events(self, msg_data: dict[str, Any] | None) -> None: + async def _subscribe_events( + self, websocket: Any, msg_data: dict[str, Any] | None + ) -> None: if msg_data is None: _LOG.warning("Ignoring _subscribe_events: called with empty msg_data") return @@ -806,9 +810,15 @@ async def _subscribe_events(self, msg_data: dict[str, Any] | None) -> None: entity_id, ) - self._events.emit(uc.Events.SUBSCRIBE_ENTITIES, msg_data["entity_ids"]) + self._events.emit( + uc.Events.SUBSCRIBE_ENTITIES, + entity_ids=msg_data["entity_ids"], + websocket=websocket, + ) - async def _unsubscribe_events(self, msg_data: dict[str, Any] | None) -> bool: + async def _unsubscribe_events( + self, websocket: Any, msg_data: dict[str, Any] | None + ) -> bool: if msg_data is None: _LOG.warning("Ignoring _unsubscribe_events: called with empty msg_data") return False @@ -819,7 +829,11 @@ async def _unsubscribe_events(self, msg_data: dict[str, Any] | None) -> bool: if self._configured_entities.remove(entity_id) is False: res = False - self._events.emit(uc.Events.UNSUBSCRIBE_ENTITIES, msg_data["entity_ids"]) + self._events.emit( + uc.Events.UNSUBSCRIBE_ENTITIES, + entity_ids=msg_data["entity_ids"], + websocket=websocket, + ) return res @@ -1114,6 +1128,57 @@ async def driver_setup_error(self, websocket, error="OTHER") -> None: websocket, uc.WsMsgEvents.DRIVER_SETUP_CHANGE, data, uc.EventCategory.DEVICE ) + @staticmethod + def _wrap_event_listener(listener: Callable) -> Callable: + """Event listener wrapper for backwards compatibility. + + Wrap an event listener so it remains compatible if the library starts emitting + additional event parameters later. + + Example: + - listener() keeps working even if emitter calls listener(websocket) + - listener(websocket) keeps working if emitter calls listener(websocket, x, y) + """ + try: + sig = inspect.signature(listener) + except (TypeError, ValueError): + # Builtins / callables without inspectable signature: fall back to raw call. + return listener + + params = list(sig.parameters.values()) + + accepts_varargs = any( + p.kind == inspect.Parameter.VAR_POSITIONAL for p in params + ) + accepts_varkw = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params) + + # How many positional args can the listener accept (excluding *args/**kwargs)? + positional_kinds = ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + max_positional = sum(1 for p in params if p.kind in positional_kinds) + + # Which named kwargs are accepted (if no **kwargs)? + accepted_kw = { + p.name + for p in params + if p.kind + in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY) + } + + @wraps(listener) + def wrapper(*args: Any, **kwargs: Any): + call_args = args if accepts_varargs else args[:max_positional] + call_kwargs = ( + kwargs + if accepts_varkw + else {k: v for k, v in kwargs.items() if k in accepted_kw} + ) + return listener(*call_args, **call_kwargs) + + return wrapper + def add_listener(self, event: uc.Events, f: Callable) -> None: """ Register a callback handler for the given event. @@ -1121,7 +1186,7 @@ def add_listener(self, event: uc.Events, f: Callable) -> None: :param event: the event :param f: callback handler """ - self._events.add_listener(event, f) + self._events.add_listener(event, self._wrap_event_listener(f)) def listens_to(self, event: uc.Events) -> Callable[[Callable], Callable]: """ @@ -1132,7 +1197,7 @@ def listens_to(self, event: uc.Events) -> Callable[[Callable], Callable]: """ def on(f: Callable) -> Callable: - self._events.add_listener(event, f) + self._events.add_listener(event, self._wrap_event_listener(f)) return f return on diff --git a/ucapi/api_definitions.py b/ucapi/api_definitions.py index 0b09da2..ad6cc77 100644 --- a/ucapi/api_definitions.py +++ b/ucapi/api_definitions.py @@ -81,25 +81,77 @@ class WsMsgEvents(str, Enum): class Events(str, Enum): - """Internal library events.""" + """Internal library events. + + All event parameters are named parameters and optional. + """ CLIENT_CONNECTED = "client_connected" - """WebSocket client connected.""" + """WebSocket client connected. + + Named parameters: + + - websocket: WebSocket client connection + """ CLIENT_DISCONNECTED = "client_disconnected" - """WebSocket client disconnected.""" + """WebSocket client disconnected. + + Named parameters: + + - websocket: WebSocket client connection + """ ENTITY_ATTRIBUTES_UPDATED = "entity_attributes_updated" + """Entity attributes updated. + + Named parameters: + + - entity_id: entity identifier + - entity_type: entity type + - attributes: updated attributes""" SUBSCRIBE_ENTITIES = "subscribe_entities" - """Integration API `subscribe_events` message.""" + """Integration API `subscribe_events` message. + + Named parameters: + + - entity_ids: list of entity IDs to subscribe to + - websocket: WebSocket client connection + """ UNSUBSCRIBE_ENTITIES = "unsubscribe_entities" - """Integration API `unsubscribe_events` message.""" + """Integration API `unsubscribe_events` message. + + Named parameters: + + - entity_ids: list of entity IDs to unsubscribe + - websocket: WebSocket client connection + """ CONNECT = "connect" - """Integration-API `connect` event message.""" + """Integration-API `connect` event message. + + Named parameters: + + - websocket: WebSocket client connection + """ DISCONNECT = "disconnect" - """Integration-API `disconnect` event message.""" + """Integration-API `disconnect` event message. + + Named parameters: + + - websocket: WebSocket client connection + """ ENTER_STANDBY = "enter_standby" - """Integration-API `enter_standby` event message.""" + """Integration-API `enter_standby` event message. + + Named parameters: + + - websocket: WebSocket client connection + """ EXIT_STANDBY = "exit_standby" - """Integration-API `exit_standby` event message.""" + """Integration-API `exit_standby` event message. + + Named parameters: + + - websocket: WebSocket client connection + """ # Does EventCategory need to be public?