diff --git a/appdaemon/plugins/hass/hassplugin.py b/appdaemon/plugins/hass/hassplugin.py index 1f6f89d1c..467a484ba 100644 --- a/appdaemon/plugins/hass/hassplugin.py +++ b/appdaemon/plugins/hass/hassplugin.py @@ -12,6 +12,7 @@ from datetime import datetime from time import perf_counter from typing import Any, Literal, Optional +from urllib.parse import urlencode import aiohttp from aiohttp import ClientResponseError, WebSocketError, WSMsgType @@ -373,9 +374,6 @@ async def websocket_send_json( Returns: A dict containing the response from Home Assistant. """ - request = utils.clean_kwargs(request) - request = utils.remove_literals(request, (None,)) - if not self.connect_event.is_set(): self.logger.debug("Not connected to websocket, skipping JSON send.") return @@ -387,7 +385,7 @@ async def websocket_send_json( if not silent: # include this in the "not auth" section so we don't accidentally put the token in the logs - req_json = json.dumps(request, indent=4) + req_json = utils.convert_json(request, indent=4) for i, line in enumerate(req_json.splitlines()): if i == 0: self.logger.debug(f"Sending JSON: {line}") @@ -396,7 +394,7 @@ async def websocket_send_json( send_time = perf_counter() try: - await self.ws.send_json(request) + await self.ws.send_json(request, dumps=utils.convert_json) # happens when the connection closes in the middle, which could be during shutdown except ConnectionResetError: if self.AD.stopping: @@ -405,7 +403,7 @@ async def websocket_send_json( else: raise # Something bad actually happened, so raise the exception - self.update_perf(bytes_sent=len(json.dumps(request)), requests_sent=1) + self.update_perf(bytes_sent=len(utils.convert_json(request)), requests_sent=1) match request: case {"type": "auth"}: @@ -454,25 +452,25 @@ async def http_method( **kwargs (optional): Zero or more keyword arguments. These get used as the data for the method, as appropriate. """ - kwargs = utils.clean_http_kwargs(kwargs) url = self.config.ha_url / endpoint.lstrip("/") try: - self.update_perf( - bytes_sent=len(str(url)) + len(json.dumps(kwargs).encode("utf-8")), - requests_sent=1, - ) - self.logger.debug(f"Hass {method.upper()} {endpoint}: {kwargs}") match method.lower(): case "get": - http_method = functools.partial(self.session.get, params=kwargs) + cleaned = utils.clean_http_params_for_urlencode(kwargs) + payload_size = len(urlencode(cleaned).encode("utf-8")) + http_method = functools.partial(self.session.get, params=cleaned) case "post": + payload_size = len(utils.convert_json(kwargs).encode("utf-8")) http_method = functools.partial(self.session.post, json=kwargs) case "delete": - http_method = functools.partial(self.session.delete, params=kwargs) + cleaned = utils.clean_http_params_for_urlencode(kwargs) + payload_size = len(urlencode(cleaned).encode("utf-8")) + http_method = functools.partial(self.session.delete, params=cleaned) case _: raise ValueError(f"Invalid method: {method}") + self.update_perf(bytes_sent=len(str(url)) + payload_size, requests_sent=1) timeout = utils.parse_timedelta(timeout) client_timeout = aiohttp.ClientTimeout(total=timeout.total_seconds()) diff --git a/appdaemon/utils.py b/appdaemon/utils.py index a05a17128..6d48aec11 100644 --- a/appdaemon/utils.py +++ b/appdaemon/utils.py @@ -1189,49 +1189,55 @@ def time_str(start: float, now: float | None = None) -> str: return format_timedelta((now or perf_counter()) - start) -def clean_kwargs(val: Any, *, http: bool = False) -> Any: - """Recursively clean a dict of kwargs. - - Conversions: - - datetime values are converted to ISO format strings - - Mapping values (like dicts) are converted to dicts of cleaned key-value pairs - - Iterable values (like lists and tuples) are converted to lists of cleaned values - - Other values are converted to strings +def remove_literals(val: Any, literal: Sequence[Any]) -> Any: + """Remove instances of literals from a nested data structure. + + Uses identity comparison (``is``) rather than equality (``==``) + to avoid ``0 == False`` and ``0.0 == False`` pitfalls. """ + def _is_literal(v: Any) -> bool: + return any(v is lit for lit in literal) match val: - case True if http: - return "true" - case str() | int() | float() | bool() | None: + case str(): return val - case datetime(): - return val.isoformat() case Mapping(): - return {k: clean_kwargs(v, http=http) for k, v in val.items()} + return {k: remove_literals(v, literal) for k, v in val.items() if not _is_literal(v)} case Iterable(): - return [clean_kwargs(v, http=http) for v in val] + return [remove_literals(v, literal) for v in val if not _is_literal(v)] case _: - return str(val) + return val -def remove_literals(val: Any, literal: Sequence[Any]) -> Any: - """Remove instances of literals from a nested data structure.""" +def clean_http_params_for_urlencode(val: Any) -> Any: + """Recursively cleans kwargs for use as URL query parameters. + + - None and False are excluded (HA treats param presence as enabled) + - True is converted to "true" + - datetime objects are converted to ISO format + - Other values are kept as-is + """ match val: - case str(): + case True: + return "true" + case str() | int() | float(): return val + case datetime(): + return val.isoformat() case Mapping(): - return {k: remove_literals(v, literal) for k, v in val.items() if v not in literal} + return { + k: clean_http_params_for_urlencode(v) + for k, v in val.items() + if v is not None and v is not False + } case Iterable(): - return [remove_literals(v, literal) for v in val if v not in literal] + return [ + clean_http_params_for_urlencode(v) + for v in val + if v is not None and v is not False + ] case _: - return val - - -def clean_http_kwargs(val: Any) -> Any: - """Recursively cleans the kwarg dict to prepare it for use in HTTP requests.""" - cleaned = clean_kwargs(val, http=True) - pruned = remove_literals(cleaned, (None, False)) - return pruned + return str(val) def unwrapped(func: Callable) -> Callable: diff --git a/tests/unit/test_kwarg_clean.py b/tests/unit/test_kwarg_clean.py index 7fbadeaa0..6eed1a551 100644 --- a/tests/unit/test_kwarg_clean.py +++ b/tests/unit/test_kwarg_clean.py @@ -1,9 +1,9 @@ -from copy import deepcopy +import json from datetime import datetime import pytest import pytz -from appdaemon.utils import clean_http_kwargs, clean_kwargs, remove_literals +from appdaemon.utils import clean_http_params_for_urlencode, convert_json, remove_literals pytestmark = [ pytest.mark.ci, @@ -22,34 +22,42 @@ } -def test_clean_kwargs(): - cleaned = clean_kwargs(BASE) - pruned = remove_literals(BASE, (None,)) - assert isinstance(cleaned["f"], str) - +def test_clean_http_params_for_urlencode(): + cleaned = clean_http_params_for_urlencode(BASE) assert cleaned["a"] == 1 assert cleaned["b"] == 2.0 assert cleaned["c"] == "three" - assert cleaned["d"] is True - assert cleaned["e"] is False - assert "g" not in pruned - - kwargs = deepcopy(BASE) - - kwargs["nested"] = deepcopy(BASE) - kwargs["nested"]["extra"] = deepcopy(BASE) - cleaned = clean_kwargs(kwargs) - assert isinstance(cleaned["nested"]["extra"]["f"], str) - - -def test_clean_http_kwargs(): - cleaned = clean_http_kwargs(BASE) - assert isinstance(cleaned["f"], str) assert cleaned["d"] == "true" assert "e" not in cleaned + assert isinstance(cleaned["f"], str) assert "g" not in cleaned +def test_clean_http_params_for_urlencode_preserves_zero(): + """0 and 0.0 must survive clean_http_params_for_urlencode (0 == False but 0 is not False).""" + data = {"offset": 0, "price": 0.0, "flag": False, "name": "test"} + cleaned = clean_http_params_for_urlencode(data) + assert cleaned["offset"] == 0 + assert cleaned["price"] == 0.0 + assert "flag" not in cleaned + assert cleaned["name"] == "test" + + +def test_clean_http_params_for_urlencode_nested(): + """Nested dicts and datetimes are cleaned recursively.""" + data = { + "outer": { + "inner": { + "dt": datetime(2025, 9, 22, 12, 0, 0, tzinfo=pytz.utc), + "gone": None, + } + } + } + cleaned = clean_http_params_for_urlencode(data) + assert cleaned["outer"]["inner"]["dt"] == "2025-09-22T12:00:00+00:00" + assert "gone" not in cleaned["outer"]["inner"] + + SERVICE_CALL = { 'type': 'call_service', 'domain': 'notify', @@ -68,8 +76,9 @@ def test_clean_http_kwargs(): } -def test_websocket_service_call_kwargs(): - cleaned = clean_kwargs(SERVICE_CALL) +def test_clean_http_params_for_urlencode_complex_nested(): + """Complex nested structure (like a service call) is cleaned correctly.""" + cleaned = clean_http_params_for_urlencode(SERVICE_CALL) match cleaned: case { "service_data": @@ -87,7 +96,119 @@ def test_websocket_service_call_kwargs(): case _: assert False, "Action format incorrect" case _: - assert False, "Action format incorrect" + assert False, "Structure format incorrect" + +def test_remove_literals_strips_none_from_service_call(): pruned = remove_literals(SERVICE_CALL, (None,)) assert "timeout" not in pruned["service_data"] + + +def test_remove_literals_preserves_zero(): + """remove_literals must use identity (is), not equality (==), to avoid 0 == False.""" + data = {"a": 0, "b": 0.0, "c": False, "d": None, "e": "hello"} + pruned = remove_literals(data, (None, False)) + assert pruned["a"] == 0 + assert pruned["b"] == 0.0 + assert "c" not in pruned + assert "d" not in pruned + assert pruned["e"] == "hello" + + +class TestConvertJson: + """convert_json is the JSON serializer used by the aiohttp session and websocket.""" + + def test_datetime_uses_isoformat(self): + dt = datetime(2025, 6, 15, 10, 0, 0, tzinfo=pytz.utc) + result = convert_json({"timestamp": dt}) + parsed = json.loads(result) + assert parsed["timestamp"] == "2025-06-15T10:00:00+00:00" + + def test_booleans_are_json_booleans(self): + result = convert_json({"flag": True, "other": False}) + parsed = json.loads(result) + assert parsed["flag"] is True + assert parsed["other"] is False + + def test_none_becomes_null(self): + result = convert_json({"value": None}) + parsed = json.loads(result) + assert parsed["value"] is None + + def test_zero_preserved(self): + result = convert_json({"rate": 0, "price": 0.0}) + parsed = json.loads(result) + assert parsed["rate"] == 0 + assert parsed["price"] == 0.0 + + def test_unknown_type_falls_back_to_str(self): + class Custom: + def __str__(self): + return "custom_value" + + result = convert_json({"obj": Custom()}) + parsed = json.loads(result) + assert parsed["obj"] == "custom_value" + + +class TestSetStateRegression: + """Regression tests for set_state scenarios from issues #2531, #2464, #2492. + + These simulate what happens when set_state kwargs pass through + session.post(json=kwargs) with convert_json as the serializer + (the transparent POST path). + """ + + def test_issue_2531_false_and_zero_attributes(self): + """Reproduces the exact scenario from issue #2531.""" + kwargs = { + "state": 1, + "attributes": { + "rate": 0, + "friendly_name": "Test Entity", + "unit_of_measurement": "GBP/kWh", + "plunge": False, + "plunge_start": False, + }, + } + result = json.loads(convert_json(kwargs)) + assert result["state"] == 1 + assert result["attributes"]["rate"] == 0 + assert result["attributes"]["plunge"] is False + assert result["attributes"]["plunge_start"] is False + assert result["attributes"]["friendly_name"] == "Test Entity" + + def test_issue_2492_zero_float_in_nested_dict(self): + """Reproduces the scenario from issue #2492 where 0.0 prices vanished.""" + kwargs = { + "state": "0.08", + "attributes": { + "prices": { + "2025-11-29T00:00:00+02:00": {"price": 0.0, "intervals": 4}, + "2025-11-29T11:00:00+02:00": {"price": 0.08, "intervals": 4}, + } + }, + } + result = json.loads(convert_json(kwargs)) + prices = result["attributes"]["prices"] + assert prices["2025-11-29T00:00:00+02:00"]["price"] == 0.0 + assert prices["2025-11-29T11:00:00+02:00"]["price"] == 0.08 + + def test_none_attribute_preserved_as_null(self): + """None values in attributes should become JSON null, not be dropped.""" + kwargs = { + "state": "on", + "attributes": { + "optional_field": None, + "name": "test", + }, + } + result = json.loads(convert_json(kwargs)) + assert "optional_field" in result["attributes"] + assert result["attributes"]["optional_field"] is None + + def test_state_zero_preserved(self): + """state=0 must not be dropped.""" + kwargs = {"state": 0, "attributes": {"icon": "mdi:radiator"}} + result = json.loads(convert_json(kwargs)) + assert result["state"] == 0