diff --git a/.coverage_full b/.coverage_full new file mode 100644 index 0000000..6aa8830 Binary files /dev/null and b/.coverage_full differ diff --git a/.secrets.baseline b/.secrets.baseline index 1beb0ff..28b888f 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -268,28 +268,28 @@ "filename": "src/api/hive_auth_async.py", "hashed_secret": "5dc786e32e3a0a4611daaf397721c6ef64cd71b0", "is_verified": false, - "line_number": 48 + "line_number": 49 }, { "type": "Secret Keyword", "filename": "src/api/hive_auth_async.py", "hashed_secret": "ac9f290e69cee683ba3c63461f1f3fa02765032a", "is_verified": false, - "line_number": 49 + "line_number": 50 }, { "type": "Secret Keyword", "filename": "src/api/hive_auth_async.py", "hashed_secret": "351b174ccf89601f6f4bd3f3970a4aba7d17c98e", "is_verified": false, - "line_number": 52 + "line_number": 53 }, { "type": "Secret Keyword", "filename": "src/api/hive_auth_async.py", "hashed_secret": "576956b5291ac38d04ef5f82cc974286a857f0b2", "is_verified": false, - "line_number": 109 + "line_number": 110 } ], "src/api/srp_crypto.py": [ @@ -298,112 +298,112 @@ "filename": "src/api/srp_crypto.py", "hashed_secret": "3e619ee0820ecf213c2f38c634e416b53defe3b0", "is_verified": false, - "line_number": 11 + "line_number": 10 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "b8e0d506d969f09a9af89ce89fd9759b72c63262", "is_verified": false, - "line_number": 12 + "line_number": 11 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "e97a751edc71e9afbe0c0f63ec94873392833f9f", "is_verified": false, - "line_number": 13 + "line_number": 12 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "92488c021dd524a2f4e116666b3645308fa0e35c", "is_verified": false, - "line_number": 14 + "line_number": 13 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "d4571e2f026f458aecd2950b0eb6aec190276177", "is_verified": false, - "line_number": 15 + "line_number": 14 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "8109d3c2f659f13cb61fc9e71eed574efe8c8fd8", "is_verified": false, - "line_number": 16 + "line_number": 15 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "08cac7461d7b624b88c53ee47da09cbbb84ea290", "is_verified": false, - "line_number": 17 + "line_number": 16 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "95523fea7e6136c6148299dcc3077debfa2976b3", "is_verified": false, - "line_number": 18 + "line_number": 17 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "c978fb77621e86f5e9077653fe5345ac1616b466", "is_verified": false, - "line_number": 19 + "line_number": 18 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "fc02990268ecf8a35a4912d60dab3754e5f43846", "is_verified": false, - "line_number": 20 + "line_number": 19 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "2c2c0ca491a73e95c8965b6641731057b65f6462", "is_verified": false, - "line_number": 21 + "line_number": 20 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "672b25c6be065170206f3fc6346ebb8e84cbb9d3", "is_verified": false, - "line_number": 22 + "line_number": 21 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "99d02e268ea3ee849fb6e359c6c1b019e4d07efd", "is_verified": false, - "line_number": 23 + "line_number": 22 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "e677fc4cb09d99e1e0d30af31f2e209e541e380e", "is_verified": false, - "line_number": 24 + "line_number": 23 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "05b69b06f40cae0c910a15b1ac75b1f7a847eccb", "is_verified": false, - "line_number": 25 + "line_number": 24 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "c7f914bac2d66eb3f8ae3888fa47bf1ada6caaf5", "is_verified": false, - "line_number": 26 + "line_number": 25 } ], "tests/unit/test_device_registration.py": [ @@ -419,21 +419,21 @@ "filename": "tests/unit/test_device_registration.py", "hashed_secret": "d8bce9746547bb7743e5933fbf0fc4f2d2cbcad3", "is_verified": false, - "line_number": 710 + "line_number": 660 }, { "type": "Secret Keyword", "filename": "tests/unit/test_device_registration.py", "hashed_secret": "e4f50034475acff058e17b35679f8ef1e54f86c5", "is_verified": false, - "line_number": 783 + "line_number": 733 }, { "type": "Secret Keyword", "filename": "tests/unit/test_device_registration.py", "hashed_secret": "6ab013c213c685b1f1b1a452796bf22afbd44699", "is_verified": false, - "line_number": 794 + "line_number": 744 } ], "tests/unit/test_hive_auth.py": [ @@ -486,14 +486,14 @@ "filename": "tests/unit/test_hive_auth_async.py", "hashed_secret": "5c5a15a8b0b3e154d77746945e563ba40100681b", "is_verified": false, - "line_number": 150 + "line_number": 165 }, { "type": "Secret Keyword", "filename": "tests/unit/test_hive_auth_async.py", "hashed_secret": "d8bce9746547bb7743e5933fbf0fc4f2d2cbcad3", "is_verified": false, - "line_number": 206 + "line_number": 221 } ], "tests/unit/test_hive_auth_async_extended.py": [ @@ -502,42 +502,42 @@ "filename": "tests/unit/test_hive_auth_async_extended.py", "hashed_secret": "5c5a15a8b0b3e154d77746945e563ba40100681b", "is_verified": false, - "line_number": 259 + "line_number": 260 }, { "type": "Secret Keyword", "filename": "tests/unit/test_hive_auth_async_extended.py", "hashed_secret": "d8bce9746547bb7743e5933fbf0fc4f2d2cbcad3", "is_verified": false, - "line_number": 340 + "line_number": 341 }, { "type": "Secret Keyword", "filename": "tests/unit/test_hive_auth_async_extended.py", "hashed_secret": "76f6b6f16cb41692b330fc806029e8a31e20b69b", "is_verified": false, - "line_number": 815 + "line_number": 816 }, { "type": "Secret Keyword", "filename": "tests/unit/test_hive_auth_async_extended.py", "hashed_secret": "b3ed2cf313e7546085c3c50622143ff31e467d23", "is_verified": false, - "line_number": 834 + "line_number": 835 }, { "type": "Secret Keyword", "filename": "tests/unit/test_hive_auth_async_extended.py", "hashed_secret": "7476b69b5005e05d536361f960a9d18b736dfbfc", "is_verified": false, - "line_number": 848 + "line_number": 849 }, { "type": "Secret Keyword", "filename": "tests/unit/test_hive_auth_async_extended.py", "hashed_secret": "ff9f30d9ba5a4ec386edddeacc27f74ef412085e", "is_verified": false, - "line_number": 855 + "line_number": 856 }, { "type": "Secret Keyword", @@ -553,21 +553,21 @@ "filename": "tests/unit/test_hive_helper_extended.py", "hashed_secret": "701b389b848a2b1cfab867093101d8d5ac56addd", "is_verified": false, - "line_number": 134 + "line_number": 102 }, { "type": "Secret Keyword", "filename": "tests/unit/test_hive_helper_extended.py", "hashed_secret": "18960546905b75c869e7de63961dc185f9a0a7c9", "is_verified": false, - "line_number": 141 + "line_number": 109 }, { "type": "Secret Keyword", "filename": "tests/unit/test_hive_helper_extended.py", "hashed_secret": "fbf52ca8a72d8ecd77235d3b3e5d014e19ffbff2", "is_verified": false, - "line_number": 143 + "line_number": 111 } ], "tests/unit/test_session_discovery_extended.py": [ @@ -580,5 +580,5 @@ } ] }, - "generated_at": "2026-05-17T16:44:49Z" + "generated_at": "2026-05-24T17:39:03Z" } diff --git a/src/__init__.py b/src/__init__.py index bf6f4a0..13762b9 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -12,7 +12,10 @@ from .helper.const import SMS_REQUIRED from .helper.hive_exceptions import ( HiveApiError, + HiveAuthCredentialError, HiveAuthError, + HiveConfigurationError, + HiveError, HiveFailedToRefreshTokens, HiveInvalid2FACode, HiveInvalidDeviceAuthentication, diff --git a/src/api/device_registration.py b/src/api/device_registration.py index 5013ed8..c330fe2 100644 --- a/src/api/device_registration.py +++ b/src/api/device_registration.py @@ -183,14 +183,11 @@ async def confirm_device(self, device_name: str | None = None): ), ) except botocore.exceptions.ClientError as err: - if err.__class__.__name__ in ( - "NotAuthorizedException", - "CodeMismatchException", - ): + code = (err.response or {}).get("Error", {}).get("Code", "") + if code in ("NotAuthorizedException", "CodeMismatchException"): raise HiveInvalid2FACode from err except botocore.exceptions.EndpointConnectionError as err: - if err.__class__.__name__ == "EndpointConnectionError": - raise HiveApiError from err + raise HiveApiError from err return result @@ -210,8 +207,7 @@ async def update_device_status(self): ), ) except botocore.exceptions.EndpointConnectionError as err: - if err.__class__.__name__ == "EndpointConnectionError": - raise HiveApiError from err + raise HiveApiError from err return result @@ -335,10 +331,10 @@ async def forget_device(self, access_token, device_key): ), ) except botocore.exceptions.ClientError as err: - if err.__class__.__name__ == "NotAuthorizedException": + code = (err.response or {}).get("Error", {}).get("Code", "") + if code == "NotAuthorizedException": raise HiveInvalid2FACode from err except botocore.exceptions.EndpointConnectionError as err: - if err.__class__.__name__ == "ResourceNotFoundException": - raise HiveApiError from err + raise HiveApiError from err return result diff --git a/src/api/hive_async_api.py b/src/api/hive_async_api.py index bc91789..f88ff5e 100644 --- a/src/api/hive_async_api.py +++ b/src/api/hive_async_api.py @@ -6,17 +6,14 @@ import time import requests -import urllib3 from aiohttp import ClientResponse, ClientSession, ClientTimeout, web_exceptions from pyquery import PyQuery -from ..helper.const import HTTP_FORBIDDEN, HTTP_OK, HTTP_UNAUTHORIZED +from ..helper.const import HTTP_FORBIDDEN, HTTP_UNAUTHORIZED from ..helper.hive_exceptions import FileInUse, HiveApiError, HiveAuthError, NoApiToken _LOGGER = logging.getLogger(__name__) -urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - class HiveApiAsync: """Hive API Code.""" @@ -97,21 +94,19 @@ async def request(self, method: str, url: str, **kwargs) -> ClientResponse: raise HiveAuthError( f"Token expired or forbidden calling {url} — HTTP {resp.status}" ) - if url is not None and resp.status is not None: - _LOGGER.error( - "Something has gone wrong calling %s - HTTP status is - %s — response: %s", - url, - resp.status, - resp_body[:200], - ) - + _LOGGER.error( + "Something has gone wrong calling %s - HTTP status is - %s — response: %s", + url, + resp.status, + resp_body[:200], + ) raise HiveApiError def get_login_info(self): """Get login properties to make the login request.""" url = "https://sso.hivehome.com/" - data = requests.get(url=url, verify=False, timeout=self.timeout) + data = requests.get(url=url, timeout=self.timeout) html = PyQuery(data.content) json_data = json.loads( '{"' @@ -128,33 +123,6 @@ def get_login_info(self): login_data.update({"REGION": json_data["HiveSSOPoolId"]}) return login_data - async def refresh_tokens(self): - """Refresh tokens - DEPRECATED NOW BY AWS TOKEN MANAGEMENT.""" - url = self.urls["refresh"] - if self.session is not None: - tokens = self.session.tokens.token_data - jsc = ( - "{" - + ",".join( - ('"' + str(i) + '": "' + str(t) + '" ' for i, t in tokens.items()) - ) - + "}" - ) - try: - await self.request("post", url, data=jsc) - - if self.json_return["original"] == HTTP_OK: - info = self.json_return["parsed"] - if "token" in info: - await self.session.update_tokens(info) - # pylint: disable-next=invalid-sequence-index - self.base_url = info["platform"]["endpoint"] - return True - except (ConnectionError, OSError, RuntimeError, ZeroDivisionError): - await self.error() - - return self.json_return - async def get_all(self): """Build and query all endpoint.""" json_return = {} @@ -214,8 +182,8 @@ async def motion_sensor(self, sensor, fromepoch, toepoch): """Call a way to get motion sensor info.""" json_return = {} url = ( - self.urls["base"] - + self.urls["products"] + self.base_url + + "/products" + "/" + sensor["type"] + "/" @@ -252,13 +220,7 @@ async def set_state(self, n_type, n_id, **kwargs): """Set the state of a Device.""" _LOGGER.debug("set_state - Setting state for %s/%s: %s", n_type, n_id, kwargs) json_return = {} - jsc = ( - "{" - + ",".join( - ('"' + str(i) + '": "' + str(t) + '" ' for i, t in kwargs.items()) - ) - + "}" - ) + jsc = json.dumps(kwargs) url = self.urls["nodes"].format(n_type, n_id) try: @@ -266,9 +228,9 @@ async def set_state(self, n_type, n_id, **kwargs): resp = await self.request("post", url, data=jsc) json_return["original"] = resp.status json_return["parsed"] = await resp.json(content_type=None) - except (FileInUse, OSError, RuntimeError, ConnectionError) as e: - if e.__class__.__name__ == "FileInUse": - return {"original": "file"} + except FileInUse: + return {"original": "file"} + except (OSError, RuntimeError, ConnectionError): await self.error() return json_return @@ -276,17 +238,20 @@ async def set_state(self, n_type, n_id, **kwargs): async def set_action(self, n_id, data): """Set the state of a Action.""" _LOGGER.debug("Setting action %s", n_id) + json_return = {} jsc = data url = self.urls["actions"] + "/" + n_id try: await self.is_file_being_used() - await self.request("put", url, data=jsc) - except (FileInUse, OSError, RuntimeError, ConnectionError) as e: - if e.__class__.__name__ == "FileInUse": - return {"original": "file"} + resp = await self.request("put", url, data=jsc) + json_return["original"] = resp.status + json_return["parsed"] = await resp.json(content_type=None) + except FileInUse: + return {"original": "file"} + except (OSError, RuntimeError, ConnectionError): await self.error() - return self.json_return + return json_return async def error(self): """An error has occurred interacting with the Hive API.""" diff --git a/src/api/hive_auth_async.py b/src/api/hive_auth_async.py index 4056690..420c47c 100644 --- a/src/api/hive_auth_async.py +++ b/src/api/hive_auth_async.py @@ -23,6 +23,7 @@ HiveInvalidPassword, HiveInvalidUsername, HiveRefreshTokenExpired, + HiveUnknownConfiguration, ) from .device_registration import DeviceRegistrationMixin from .hive_api import HiveApi @@ -58,17 +59,10 @@ def __init__( # pylint: disable=too-many-positional-arguments # noqa: PLR0913 device_group_key: str | None = None, device_key: str | None = None, device_password: str | None = None, - pool_region: str | None = None, client_secret: str | None = None, ): """Initialise async auth.""" - if pool_region is not None: - raise ValueError( - "pool_region and client should not both be specified " - "(region should be passed to the boto3 client instead)" - ) - - self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() + self.loop: asyncio.AbstractEventLoop | None = None self.username = username self.password = password self.device_group_key: str | None = device_group_key @@ -95,10 +89,17 @@ def __init__( # pylint: disable=too-many-positional-arguments # noqa: PLR0913 async def async_init(self): """Initialise async variables.""" + self.loop = asyncio.get_running_loop() self.data = await self.loop.run_in_executor(None, self.api.get_login_info) self._pool_id = self.data.get("UPID") self._client_id = self.data.get("CLIID") - self._region = self.data.get("REGION").split("_")[0] + region_raw = self.data.get("REGION") + if not self._pool_id or not region_raw: + raise HiveUnknownConfiguration( + "SSO login page did not return required pool/region data" + ) + self._region = region_raw.split("_")[0] + # Cognito USER_SRP_AUTH does not use IAM credentials — boto3 requires non-None values. self.client = await self.loop.run_in_executor( None, functools.partial( @@ -156,6 +157,8 @@ def get_password_authentication_key(self, username, password, server_b_value, sa u_value = calculate_u(self.large_a_value, server_b_value) if u_value == 0: raise ValueError("U cannot be zero.") + if not self._pool_id or "_" not in self._pool_id: + raise HiveUnknownConfiguration(f"Invalid pool ID format: {self._pool_id!r}") pool_id = self._pool_id.split("_")[1] username_password = f"{pool_id}{username}:{password}" username_password_hash = hash_sha256(username_password.encode("utf-8")) @@ -279,13 +282,13 @@ async def login(self): # noqa: PLR0912 ), ) except botocore.exceptions.ClientError as err: - if err.__class__.__name__ == "UserNotFoundException": + code = (err.response or {}).get("Error", {}).get("Code", "") + if code == "UserNotFoundException": _LOGGER.error("Cognito auth failed: user not found.") raise HiveInvalidUsername from err except botocore.exceptions.EndpointConnectionError as err: - if err.__class__.__name__ == "EndpointConnectionError": - _LOGGER.error("Cognito auth failed: cannot reach endpoint.") - raise HiveApiError from err + _LOGGER.error("Cognito auth failed: cannot reach endpoint.") + raise HiveApiError from err if response["ChallengeName"] == self.PASSWORD_VERIFIER_CHALLENGE: _LOGGER.debug("login - Processing PASSWORD_VERIFIER challenge.") @@ -303,20 +306,18 @@ async def login(self): # noqa: PLR0912 ), ) except botocore.exceptions.ClientError as err: - if err.__class__.__name__ == "NotAuthorizedException": + code = (err.response or {}).get("Error", {}).get("Code", "") + if code == "NotAuthorizedException": _LOGGER.error("Cognito auth challenge failed: not authorised.") raise HiveInvalidPassword from err - if err.__class__.__name__ == "ResourceNotFoundException": + if code == "ResourceNotFoundException": _LOGGER.error( "Cognito auth challenge failed: device resource not found." ) raise HiveInvalidDeviceAuthentication from err except botocore.exceptions.EndpointConnectionError as err: - if err.__class__.__name__ == "EndpointConnectionError": - _LOGGER.error( - "Cognito auth challenge failed: cannot reach endpoint." - ) - raise HiveApiError from err + _LOGGER.error("Cognito auth challenge failed: cannot reach endpoint.") + raise HiveApiError from err _LOGGER.debug("login - SRP auth challenge completed successfully.") @@ -385,10 +386,8 @@ async def device_login(self): raise HiveInvalidDeviceAuthentication from err raise except botocore.exceptions.EndpointConnectionError as err: - if err.__class__.__name__ == "EndpointConnectionError": - _LOGGER.error("Device login failed: cannot reach endpoint.") - raise HiveApiError from err - raise HiveInvalidDeviceAuthentication from err + _LOGGER.error("Device login failed: cannot reach endpoint.") + raise HiveApiError from err _LOGGER.debug("device_login - Device authentication completed successfully.") return result @@ -413,26 +412,24 @@ async def sms_2fa(self, entered_code, challenge_parameters): }, ), ) - self.access_token = result["AuthenticationResult"]["AccessToken"] - self.token_created = datetime.datetime.now() - if "NewDeviceMetadata" in result["AuthenticationResult"]: - self.device_group_key = result["AuthenticationResult"][ - "NewDeviceMetadata" - ]["DeviceGroupKey"] - self.device_key = result["AuthenticationResult"]["NewDeviceMetadata"][ - "DeviceKey" - ] + if result and "AuthenticationResult" in result: + self.access_token = result["AuthenticationResult"]["AccessToken"] + self.token_created = datetime.datetime.now() + if "NewDeviceMetadata" in result["AuthenticationResult"]: + self.device_group_key = result["AuthenticationResult"][ + "NewDeviceMetadata" + ]["DeviceGroupKey"] + self.device_key = result["AuthenticationResult"][ + "NewDeviceMetadata" + ]["DeviceKey"] except botocore.exceptions.ClientError as err: - if err.__class__.__name__ in ( - "NotAuthorizedException", - "CodeMismatchException", - ): + code = (err.response or {}).get("Error", {}).get("Code", "") + if code in ("NotAuthorizedException", "CodeMismatchException"): _LOGGER.error("2FA code rejected by Cognito.") raise HiveInvalid2FACode from err except botocore.exceptions.EndpointConnectionError as err: - if err.__class__.__name__ == "EndpointConnectionError": - _LOGGER.error("2FA failed: cannot reach Cognito endpoint.") - raise HiveApiError from err + _LOGGER.error("2FA failed: cannot reach Cognito endpoint.") + raise HiveApiError from err _LOGGER.debug("sms_2fa - 2FA authentication completed successfully.") return result @@ -478,11 +475,10 @@ async def refresh_token(self, token): ) raise HiveFailedToRefreshTokens from err except botocore.exceptions.EndpointConnectionError as err: - if err.__class__.__name__ == "EndpointConnectionError": - _LOGGER.error( - "refresh_token - Token refresh failed: cannot reach Cognito endpoint." - ) - raise HiveApiError from err + _LOGGER.error( + "refresh_token - Token refresh failed: cannot reach Cognito endpoint." + ) + raise HiveApiError from err _LOGGER.debug("refresh_token - Cognito token refresh completed successfully.") return result diff --git a/src/api/srp_crypto.py b/src/api/srp_crypto.py index faf141b..74cea33 100644 --- a/src/api/srp_crypto.py +++ b/src/api/srp_crypto.py @@ -1,7 +1,6 @@ """Pure SRP/HKDF crypto helpers for AWS Cognito authentication.""" import binascii -import concurrent.futures import hashlib import hmac import os @@ -28,7 +27,6 @@ # https://github.com/aws/amazon-cognito-identity-js/blob/master/src/AuthenticationHelper.js#L49 G_HEX = "2" INFO_BITS = bytearray("Caldera Derived Key", "utf-8") -POOL = concurrent.futures.ThreadPoolExecutor() def hex_to_long(hex_string): diff --git a/src/devices/boost.py b/src/devices/boost.py index 2e3f1b7..b1e22e9 100644 --- a/src/devices/boost.py +++ b/src/devices/boost.py @@ -29,7 +29,7 @@ async def get_boost_status(self, device: Device): data = self.session.data.products[device.hive_id] return HIVETOHA["Boost"].get(data["state"].get("boost", False), "ON") except KeyError as e: - _LOGGER.error(e) + _LOGGER.error("get_boost_status - KeyError for %s: %s", device.ha_name, e) return None async def get_boost_time(self, device: Device): @@ -43,5 +43,5 @@ async def get_boost_time(self, device: Device): data = self.session.data.products[device.hive_id] return data["state"]["boost"] except KeyError as e: - _LOGGER.error(e) + _LOGGER.error("get_boost_time - KeyError for %s: %s", device.ha_name, e) return None diff --git a/src/devices/color.py b/src/devices/color.py index ade4cb2..2442ae9 100644 --- a/src/devices/color.py +++ b/src/devices/color.py @@ -33,7 +33,7 @@ async def get_min_color_temp(self, device: Device): data = self.session.data.products[device.hive_id] state = data["props"]["colourTemperature"]["max"] return round((1 / state) * 1000000) - except KeyError as e: + except (KeyError, ZeroDivisionError) as e: _LOGGER.error(e) return None @@ -50,7 +50,7 @@ async def get_max_color_temp(self, device: Device): data = self.session.data.products[device.hive_id] state = data["props"]["colourTemperature"]["min"] return round((1 / state) * 1000000) - except KeyError as e: + except (KeyError, ZeroDivisionError) as e: _LOGGER.error(e) return None @@ -67,7 +67,7 @@ async def get_color_temp(self, device: Device): data = self.session.data.products[device.hive_id] state = data["state"]["colourTemperature"] return round((1 / state) * 1000000) - except KeyError as e: + except (KeyError, ZeroDivisionError) as e: _LOGGER.error(e) return None diff --git a/src/devices/heating.py b/src/devices/heating.py index 2b81e1e..2b07d51 100644 --- a/src/devices/heating.py +++ b/src/devices/heating.py @@ -170,6 +170,7 @@ async def get_mode(self, device: Device): """ state = None final = None + device_name = device.ha_name try: data = self.session.data.products[device.hive_id] @@ -178,7 +179,7 @@ async def get_mode(self, device: Device): state = data["props"]["previous"]["mode"] final = HIVETOHA[self.heating_type].get(state, state) except KeyError as e: - _LOGGER.error(e) + _LOGGER.error("get_mode - KeyError getting mode for %s: %s", device_name, e) return final diff --git a/src/devices/hotwater.py b/src/devices/hotwater.py index b6985b2..2de73d9 100644 --- a/src/devices/hotwater.py +++ b/src/devices/hotwater.py @@ -33,6 +33,7 @@ async def get_mode(self, device: Device): """ state = None final = None + device_name = device.ha_name try: data = self.session.data.products[device.hive_id] @@ -41,7 +42,7 @@ async def get_mode(self, device: Device): state = data["props"]["previous"]["mode"] final = HIVETOHA[self.hotwater_type].get(state, state) except KeyError as e: - _LOGGER.error(e) + _LOGGER.error("get_mode - KeyError getting mode for %s: %s", device_name, e) return final diff --git a/src/devices/sensor.py b/src/devices/sensor.py index 9f0a431..bed1ebb 100644 --- a/src/devices/sensor.py +++ b/src/devices/sensor.py @@ -4,12 +4,36 @@ from typing import Any from ..helper.compat_aliases import SensorCompatMixin -from ..helper.const import HIVE_TYPES, HIVETOHA, sensor_commands +from ..helper.const import HIVE_TYPES, HIVETOHA from ..helper.device_handler_base import BaseDeviceHandler from ..helper.hivedataclasses import Device _LOGGER = logging.getLogger(__name__) +sensor_commands = { + "SMOKE_CO": lambda s, d: s.session.hub.get_smoke_status(d), + "DOG_BARK": lambda s, d: s.session.hub.get_dog_bark_status(d), + "GLASS_BREAK": lambda s, d: s.session.hub.get_glass_break_status(d), + "Current_Temperature": lambda s, d: s.session.heating.get_current_temperature(d), + "Heating_Current_Temperature": lambda s, d: ( + s.session.heating.get_current_temperature(d) + ), + "Heating_Target_Temperature": lambda s, d: s.session.heating.get_target_temperature( + d + ), + "Heating_State": lambda s, d: s.session.heating.get_state(d), + "Heating_Mode": lambda s, d: s.session.heating.get_mode(d), + "Heating_Boost": lambda s, d: s.session.heating.get_boost_status(d), + "Hotwater_State": lambda s, d: s.session.hotwater.get_state(d), + "Hotwater_Mode": lambda s, d: s.session.hotwater.get_mode(d), + "Hotwater_Boost": lambda s, d: s.session.hotwater.get_boost(d), + "Battery": lambda s, d: s.session.attr.get_battery(d.device_id), + "Mode": lambda s, d: s.session.attr.get_mode(d.hive_id), + "Availability": lambda s, d: s.online(d), + "Connectivity": lambda s, d: s.online(d), + "Power": lambda s, d: s.session.switch.get_power_usage(d), +} + class HiveSensor(BaseDeviceHandler): """Hive Sensor Code.""" @@ -133,7 +157,7 @@ async def get_sensor(self, device: Device): device.device_data = props device.parent_device = data.get("parent", None) elif device.hive_type in HIVE_TYPES["Sensor"]: - data = self.session.data.devices.get(device.hive_id, {}) + data = self.session.data.devices.get(device.device_id, {}) device.status = {"state": await self.get_state(device)} props = data.get("props") or {} props["online"] = online diff --git a/src/helper/compat_aliases.py b/src/helper/compat_aliases.py index b1fe79e..8cdfc0e 100644 --- a/src/helper/compat_aliases.py +++ b/src/helper/compat_aliases.py @@ -7,6 +7,7 @@ from __future__ import annotations +from datetime import timedelta from typing import Any from .hivedataclasses import Device @@ -138,6 +139,7 @@ async def updateData(self, device: Device): # pylint: disable=invalid-name """Backwards-compatible alias for update_data.""" return await self.update_data(device) # type: ignore[attr-defined] - async def updateInterval(self, new_interval: int): # pylint: disable=invalid-name,unused-argument + async def updateInterval(self, new_interval: int): # pylint: disable=invalid-name """Backwards-compatible alias for Home Assistant Scan Interval.""" + self.config.scan_interval = timedelta(seconds=new_interval) # type: ignore[attr-defined] return True diff --git a/src/helper/const.py b/src/helper/const.py index 7c27601..59d1c7d 100644 --- a/src/helper/const.py +++ b/src/helper/const.py @@ -60,29 +60,6 @@ "Sensor": ["motionsensor", "contactsensor"], "Switch": ["activeplug"], } -sensor_commands = { - "SMOKE_CO": lambda s, d: s.session.hub.get_smoke_status(d), - "DOG_BARK": lambda s, d: s.session.hub.get_dog_bark_status(d), - "GLASS_BREAK": lambda s, d: s.session.hub.get_glass_break_status(d), - "Current_Temperature": lambda s, d: s.session.heating.get_current_temperature(d), - "Heating_Current_Temperature": lambda s, d: ( - s.session.heating.get_current_temperature(d) - ), - "Heating_Target_Temperature": lambda s, d: s.session.heating.get_target_temperature( - d - ), - "Heating_State": lambda s, d: s.session.heating.get_state(d), - "Heating_Mode": lambda s, d: s.session.heating.get_mode(d), - "Heating_Boost": lambda s, d: s.session.heating.get_boost_status(d), - "Hotwater_State": lambda s, d: s.session.hotwater.get_state(d), - "Hotwater_Mode": lambda s, d: s.session.hotwater.get_mode(d), - "Hotwater_Boost": lambda s, d: s.session.hotwater.get_boost(d), - "Battery": lambda s, d: s.session.attr.get_battery(d.device_id), - "Mode": lambda s, d: s.session.attr.get_mode(d.hive_id), - "Availability": lambda s, d: s.online(d), - "Connectivity": lambda s, d: s.online(d), - "Power": lambda s, d: s.session.switch.get_power_usage(d), -} PRODUCTS = { "sense": [ diff --git a/src/helper/hive_exceptions.py b/src/helper/hive_exceptions.py index 4881d12..2a69441 100644 --- a/src/helper/hive_exceptions.py +++ b/src/helper/hive_exceptions.py @@ -19,14 +19,22 @@ class NoApiToken(Exception): """ -class HiveApiError(Exception): - """Api error. +class HiveError(Exception): + """Common base class for all Hive-specific exceptions. Args: Exception (object): Exception object to invoke """ +class HiveApiError(HiveError): + """Api error. + + Args: + HiveError (object): Parent Hive error class + """ + + class HiveAuthError(HiveApiError): """Auth error (401/403) — token may be expired or invalid. @@ -35,65 +43,81 @@ class HiveAuthError(HiveApiError): """ -class HiveRefreshTokenExpired(Exception): +class HiveRefreshTokenExpired(HiveApiError): """Refresh token expired. Args: - Exception (object): Exception object to invoke + HiveApiError (object): Parent API error class """ -class HiveReauthRequired(Exception): - """Re-Authentication is required. +class HiveFailedToRefreshTokens(HiveApiError): + """Raise invalid refresh tokens. Args: - Exception (object): Exception object to invoke + HiveApiError (object): Parent API error class + """ + + +class HiveConfigurationError(HiveError): + """Base class for configuration-related errors. + + Args: + HiveError (object): Parent Hive error class """ -class HiveUnknownConfiguration(Exception): +class HiveUnknownConfiguration(HiveConfigurationError): """Unknown Hive Configuration. Args: - Exception (object): Exception object to invoke + HiveConfigurationError (object): Parent configuration error class """ -class HiveInvalidUsername(Exception): - """Raise invalid Username. +class HiveInvalidDeviceAuthentication(HiveConfigurationError): + """Raise invalid device authentication. Args: - Exception (object): Exception object to invoke + HiveConfigurationError (object): Parent configuration error class """ -class HiveInvalidPassword(Exception): - """Raise invalid password. +class HiveAuthCredentialError(HiveError): + """Base class for authentication credential errors. Args: - Exception (object): Exception object to invoke + HiveError (object): Parent Hive error class """ -class HiveInvalid2FACode(Exception): - """Raise invalid 2FA code. +class HiveInvalidUsername(HiveAuthCredentialError): + """Raise invalid Username. Args: - Exception (object): Exception object to invoke + HiveAuthCredentialError (object): Parent credential error class """ -class HiveInvalidDeviceAuthentication(Exception): - """Raise invalid device authentication. +class HiveInvalidPassword(HiveAuthCredentialError): + """Raise invalid password. Args: - Exception (object): Exception object to invoke + HiveAuthCredentialError (object): Parent credential error class """ -class HiveFailedToRefreshTokens(Exception): - """Raise invalid refresh tokens. +class HiveInvalid2FACode(HiveAuthCredentialError): + """Raise invalid 2FA code. Args: - Exception (object): Exception object to invoke + HiveAuthCredentialError (object): Parent credential error class + """ + + +class HiveReauthRequired(HiveError): + """Re-Authentication is required. + + Args: + HiveError (object): Parent Hive error class """ diff --git a/src/helper/hive_helper.py b/src/helper/hive_helper.py index d052f73..150a027 100644 --- a/src/helper/hive_helper.py +++ b/src/helper/hive_helper.py @@ -8,7 +8,6 @@ from typing import Any from .const import HIVE_TYPES -from .hivedataclasses import Device _LOGGER = logging.getLogger(__name__) @@ -26,7 +25,6 @@ def epoch_time(date_time: Any, pattern: str, action: str) -> Any: Converted value, or ``None`` if *action* is unrecognised. """ if action == "to_epoch": - pattern = "%d.%m.%Y %H:%M:%S" return int(time.mktime(time.strptime(str(date_time), pattern))) if action == "from_epoch": return datetime.datetime.fromtimestamp(int(date_time)).strftime(pattern) @@ -297,19 +295,6 @@ def get_schedule_nnl(self, hive_api_schedule: dict): # pylint: disable=too-many return schedule_now_and_next - def get_heat_on_demand_device(self, device: Device): - """Use TRV device to get the linked thermostat device. - - Args: - device ([dictionary]): [The TRV device to lookup.] - - Returns: - [dictionary]: [Gets the thermostat device linked to TRV.] - """ - trv = self.session.data.products.get(device["HiveID"]) - thermostat = self.session.data.products.get(trv["state"]["zone"]) - return thermostat - def sanitize_payload(self, payload: dict[str, Any]) -> dict[str, Any]: """Return a copy of payload with sensitive values masked for logs.""" diff --git a/src/helper/map.py b/src/helper/map.py index b4bd8ea..3b38c9d 100644 --- a/src/helper/map.py +++ b/src/helper/map.py @@ -10,6 +10,11 @@ class Map(dict): dict (dict): dictionary to map. """ - __getattr__ = dict.get + def __getattr__(self, key): + try: + return self[key] + except KeyError: + raise AttributeError(f"Map has no key {key!r}") from None + __setattr__ = dict.__setitem__ __delattr__ = dict.__delitem__ diff --git a/src/session/polling.py b/src/session/polling.py index 27341fc..abe5859 100644 --- a/src/session/polling.py +++ b/src/session/polling.py @@ -144,7 +144,17 @@ async def get_devices(self, _n_id: str): # pylint: disable=too-many-locals,too- api_call_start = time.monotonic() try: api_resp_d = await self.api.get_all() + api_call_duration = time.monotonic() - api_call_start + if api_call_duration > self._slow_poll_threshold: + _LOGGER.debug( + "get_devices - Hive API response took %.1fs — marking poll as slow.", + api_call_duration, + ) + self._last_poll_slow = True + else: + self._last_poll_slow = False except HiveAuthError: + self._last_poll_slow = False _LOGGER.warning( "Auth error (401/403) after token refresh, " "falling back to full device re-login." @@ -154,15 +164,6 @@ async def get_devices(self, _n_id: str): # pylint: disable=too-many-locals,too- self.api.get_all, reraise_as=HiveReauthRequired, ) - api_call_duration = time.monotonic() - api_call_start - if api_call_duration > self._slow_poll_threshold: - _LOGGER.debug( - "get_devices - Hive API response took %.1fs — marking poll as slow.", - api_call_duration, - ) - self._last_poll_slow = True - else: - self._last_poll_slow = False if not str(api_resp_d["original"]).startswith("2"): raise HTTPException if api_resp_d["parsed"] is None: diff --git a/tests/unit/test_color_extended.py b/tests/unit/test_color_extended.py index 6795724..b605e0e 100644 --- a/tests/unit/test_color_extended.py +++ b/tests/unit/test_color_extended.py @@ -99,3 +99,49 @@ async def test_keyerror_on_missing_product_returns_none(self): result = await handler.get_max_color_temp(device) assert result is None + + +class TestZeroDivisionGuards: + """Colour-temperature methods must return None instead of raising ZeroDivisionError.""" + + async def test_get_min_color_temp_zero_returns_none(self): + """min colourTemperature == 0 must return None, not raise ZeroDivisionError. + + get_min_color_temp reads colourTemperature['max'] and divides by it, + so 'max' must be 0 to trigger ZeroDivisionError. + """ + session = _make_session( + products={ + "light-1": {"props": {"colourTemperature": {"max": 0, "min": 153}}} + } + ) + h = _make_handler(session) + device = _make_device() + result = await h.get_min_color_temp(device) + assert result is None + + async def test_get_max_color_temp_zero_returns_none(self): + """max colourTemperature == 0 must return None, not raise ZeroDivisionError. + + get_max_color_temp reads colourTemperature['min'] and divides by it, + so 'min' must be 0 to trigger ZeroDivisionError. + """ + session = _make_session( + products={ + "light-1": {"props": {"colourTemperature": {"max": 500, "min": 0}}} + } + ) + h = _make_handler(session) + device = _make_device() + result = await h.get_max_color_temp(device) + assert result is None + + async def test_get_color_temp_zero_returns_none(self): + """state colourTemperature == 0 must return None, not raise ZeroDivisionError.""" + session = _make_session( + products={"light-1": {"state": {"colourTemperature": 0}}} + ) + h = _make_handler(session) + device = _make_device() + result = await h.get_color_temp(device) + assert result is None diff --git a/tests/unit/test_compat_aliases.py b/tests/unit/test_compat_aliases.py index 1c066e7..ec6c44e 100644 --- a/tests/unit/test_compat_aliases.py +++ b/tests/unit/test_compat_aliases.py @@ -11,7 +11,7 @@ SwitchCompatMixin, WaterHeaterCompatMixin, ) -from apyhiveapi.helper.hivedataclasses import Device +from apyhiveapi.helper.hivedataclasses import Device, SessionConfig def _make_device(): @@ -319,13 +319,58 @@ class Stub(SessionCompatMixin): assert s.deviceList is s.device_list async def test_update_interval_returns_true(self): - """updateInterval always returns True (deprecated no-op).""" + """updateInterval returns True and updates config.scan_interval.""" class Stub(SessionCompatMixin): """Stub for updateInterval test.""" device_list = {} + def __init__(self): + self.config = SessionConfig() + s = Stub() result = await s.updateInterval(60) assert result is True + + +# --------------------------------------------------------------------------- +# SessionCompatMixin.updateInterval — bug fix tests +# --------------------------------------------------------------------------- + + +def _make_concrete_session(): + """Return a minimal SessionCompatMixin subclass with a real SessionConfig.""" + + class ConcreteSession(SessionCompatMixin): + """Minimal concrete SessionCompatMixin for updateInterval tests.""" + + def __init__(self): + self.config = SessionConfig() + self.device_list = {} + + async def start_session(self, config=None): # pylint: disable=unused-argument + """Stub.""" + + async def update_data(self, device): # pylint: disable=unused-argument + """Stub.""" + + return ConcreteSession() + + +class TestSessionCompatMixinUpdateInterval: + """updateInterval must actually update config.scan_interval.""" + + async def test_update_interval_sets_scan_interval(self): + """updateInterval(300) must set self.config.scan_interval to timedelta(seconds=300).""" + from datetime import timedelta + + session = _make_concrete_session() + await session.updateInterval(300) + assert session.config.scan_interval == timedelta(seconds=300) + + async def test_update_interval_returns_true(self): + """updateInterval must return True on success.""" + session = _make_concrete_session() + result = await session.updateInterval(60) + assert result is True diff --git a/tests/unit/test_device_registration.py b/tests/unit/test_device_registration.py index 372eada..43ad3bc 100644 --- a/tests/unit/test_device_registration.py +++ b/tests/unit/test_device_registration.py @@ -482,32 +482,9 @@ async def test_other_client_error_does_not_raise(self): result = await stub.forget_device("acc-token", "dev-key") assert result is None - async def test_endpoint_error_does_not_raise_api_error(self): - """EndpointConnectionError only raises HiveApiError if class name is - 'ResourceNotFoundException', which can never be true for an - EndpointConnectionError. The exception is therefore silently swallowed.""" + async def test_endpoint_error_raises_api_error(self): stub = await _make_stub() stub.loop.run_in_executor.side_effect = _endpoint_error() - # The guard condition is always False for a real EndpointConnectionError, - # so no exception propagates. - result = await stub.forget_device("acc-token", "dev-key") - assert result is None - - async def test_endpoint_error_named_resource_not_found_raises_api_error(self): - """A subclass of EndpointConnectionError named 'ResourceNotFoundException' - satisfies the guard at line 339 and raises HiveApiError (line 340).""" - stub = await _make_stub() - # Craft a class whose __class__.__name__ == "ResourceNotFoundException" - # but which IS an EndpointConnectionError (so it's caught by the except clause) - resource_cls = type( - "ResourceNotFoundException", - (botocore.exceptions.EndpointConnectionError,), - {}, - ) - resource_err = resource_cls( - endpoint_url="https://cognito.eu-west-1.amazonaws.com" - ) - stub.loop.run_in_executor.side_effect = resource_err with pytest.raises(HiveApiError): await stub.forget_device("acc-token", "dev-key") @@ -636,33 +613,6 @@ async def test_other_client_error_is_swallowed(self): result = await stub.confirm_device("name") assert result is None # no HiveInvalid2FACode raised - async def test_endpoint_error_wrong_name_is_swallowed(self): - """EndpointConnectionError subclass with wrong __name__ is swallowed (190->193).""" - stub = await _make_stub() - stub.generate_hash_device = AsyncMock( - return_value={"PasswordVerifier": "pv", "Salt": "s"} - ) - wrong_cls = type( - "WrongEndpoint", (botocore.exceptions.EndpointConnectionError,), {} - ) - wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com") - stub.loop.run_in_executor.side_effect = wrong_err - result = await stub.confirm_device("name") - assert result is None # no HiveApiError raised - - -class TestUpdateDeviceStatusSwallowedEndpointError: - async def test_endpoint_error_wrong_name_is_swallowed(self): - """EndpointConnectionError with wrong name is caught but not re-raised (211->214).""" - stub = await _make_stub() - wrong_cls = type( - "WrongEndpoint", (botocore.exceptions.EndpointConnectionError,), {} - ) - wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com") - stub.loop.run_in_executor.side_effect = wrong_err - result = await stub.update_device_status() - assert result is None # no HiveApiError raised - class TestDeviceRegistration: async def test_calls_confirm_and_update(self): diff --git a/tests/unit/test_helpers.py b/tests/unit/test_helpers.py index 1f84c3e..9f9578f 100644 --- a/tests/unit/test_helpers.py +++ b/tests/unit/test_helpers.py @@ -38,11 +38,7 @@ class TestEpochTime: """Tests for the top-level epoch_time() helper function.""" def test_to_epoch_returns_int(self): - """to_epoch converts a date string to an integer Unix timestamp. - - Note: epoch_time ignores the *pattern* argument for "to_epoch" — - it always applies "%d.%m.%Y %H:%M:%S" internally. - """ + """to_epoch converts a date string to an integer Unix timestamp.""" result = epoch_time("01.01.2024 12:00:00", "%d.%m.%Y %H:%M:%S", "to_epoch") assert isinstance(result, int) diff --git a/tests/unit/test_hive_api.py b/tests/unit/test_hive_api.py index 8b6a182..1cd1093 100644 --- a/tests/unit/test_hive_api.py +++ b/tests/unit/test_hive_api.py @@ -213,119 +213,6 @@ def test_key_error_calls_error_and_returns_none(self): assert result is None -# --------------------------------------------------------------------------- -# Tests: HiveApi.refresh_tokens -# --------------------------------------------------------------------------- - - -class TestRefreshTokens: - def test_successful_with_token_key_updates_session(self): - """When the response contains 'token', session.update_tokens is called.""" - api = _make_api() - refresh_data = { - "token": "new-token", - "platform": {"endpoint": "https://new.endpoint.com"}, - } - mock_resp = _make_mock_response( - 200, json_data=refresh_data, text=json.dumps(refresh_data) - ) - - with patch.object(api, "request", return_value=mock_resp): - result = api.refresh_tokens() - - api.session.update_tokens.assert_called_once_with(refresh_data) - assert result["original"] == 200 - - def test_no_token_in_response_no_session_update(self): - """When response lacks 'token' key, update_tokens is not called.""" - api = _make_api() - response_data = {"other_key": "value"} - mock_resp = _make_mock_response( - 200, json_data=response_data, text=json.dumps(response_data) - ) - - with patch.object(api, "request", return_value=mock_resp): - api.refresh_tokens() - - api.session.update_tokens.assert_not_called() - - def test_none_tokens_defaults_to_empty_dict(self): - """Calling refresh_tokens() without arguments uses session.token_data.""" - api = _make_api() - response_data = {"other": "val"} - mock_resp = _make_mock_response( - 200, json_data=response_data, text=json.dumps(response_data) - ) - - with patch.object(api, "request", return_value=mock_resp) as mock_req: - api.refresh_tokens() - # Should have been called (session provides the tokens dict) - mock_req.assert_called_once() - - def test_os_error_calls_error(self): - api = _make_api() - with patch.object(api, "request", side_effect=OSError("connection failed")): - api.refresh_tokens() - - assert api.json_return["original"] == "Error making API call" - - def test_runtime_error_calls_error(self): - api = _make_api() - with patch.object(api, "request", side_effect=RuntimeError("fail")): - api.refresh_tokens() - - assert api.json_return["original"] == "Error making API call" - - def test_json_decode_error_calls_error(self): - """Bad JSON in response text triggers error().""" - api = _make_api() - mock_resp = MagicMock() - mock_resp.status_code = 200 - mock_resp.text = "not-json" - - with patch.object(api, "request", return_value=mock_resp): - api.refresh_tokens() - - assert api.json_return["original"] == "Error making API call" - - def test_explicit_tokens_arg_skips_none_branch(self): - """Passing a non-None tokens arg covers the 80->82 False branch.""" - api = _make_api() - explicit_tokens = {"key": "val"} - response_data = {"other": "x"} - mock_resp = _make_mock_response(200, json_data=response_data) - - with patch.object(api, "request", return_value=mock_resp): - api.refresh_tokens(tokens=explicit_tokens) - # Session is not None so session tokens overwrite, but no crash - api.session.update_tokens.assert_not_called() - - def test_session_none_skips_token_overwrite(self): - """When session is None the 83->85 False branch is taken (no token overwrite).""" - api = _make_api_no_session(token="standalone-token") - response_data = {"other": "x"} - mock_resp = _make_mock_response(200, json_data=response_data) - - with patch.object(api, "request", return_value=mock_resp): - api.refresh_tokens(tokens={"key": "val"}) - - def test_urls_base_updated_on_token_refresh(self): - """After a successful refresh the base URL is updated from the response.""" - api = _make_api() - refresh_data = { - "token": "new-tok", - "platform": {"endpoint": "https://new-platform.com/1.0"}, - } - mock_resp = _make_mock_response( - 200, json_data=refresh_data, text=json.dumps(refresh_data) - ) - - with patch.object(api, "request", return_value=mock_resp): - api.refresh_tokens() - - assert api.urls["base"] == "https://new-platform.com/1.0" - - # --------------------------------------------------------------------------- # Tests: HiveApi.get_all # --------------------------------------------------------------------------- diff --git a/tests/unit/test_hive_async_api.py b/tests/unit/test_hive_async_api.py index 86c1cd7..c630182 100644 --- a/tests/unit/test_hive_async_api.py +++ b/tests/unit/test_hive_async_api.py @@ -65,50 +65,42 @@ def _make_api_no_token(_url_contains_sso=False): class TestHiveApiAsyncRequest: - @pytest.mark.asyncio async def test_successful_200_returns_response(self): api = _make_api(status=200, json_data={"ok": True}) resp = await api.request("get", "https://beekeeper.hivehome.com/1.0/nodes/all") assert resp.status == 200 - @pytest.mark.asyncio async def test_201_also_succeeds(self): api = _make_api(status=201) resp = await api.request("post", "https://beekeeper.hivehome.com/1.0/nodes/x/y") assert resp.status == 201 - @pytest.mark.asyncio async def test_sso_url_without_token_does_not_raise(self): api = _make_api_no_token() # Should not raise NoApiToken because "sso" is in the URL resp = await api.request("get", "https://sso.hivehome.com/") assert resp.status == 200 - @pytest.mark.asyncio async def test_non_sso_without_token_raises_no_api_token(self): api = _make_api_no_token() with pytest.raises(NoApiToken): await api.request("get", "https://beekeeper.hivehome.com/1.0/nodes/all") - @pytest.mark.asyncio async def test_401_raises_hive_auth_error(self): api = _make_api(status=401) with pytest.raises(HiveAuthError): await api.request("get", "https://beekeeper.hivehome.com/1.0/nodes/all") - @pytest.mark.asyncio async def test_403_raises_hive_auth_error(self): api = _make_api(status=403) with pytest.raises(HiveAuthError): await api.request("get", "https://beekeeper.hivehome.com/1.0/nodes/all") - @pytest.mark.asyncio async def test_500_raises_hive_api_error(self): api = _make_api(status=500) with pytest.raises(HiveApiError): await api.request("get", "https://beekeeper.hivehome.com/1.0/nodes/all") - @pytest.mark.asyncio async def test_404_raises_hive_api_error(self): api = _make_api(status=404) with pytest.raises(HiveApiError): @@ -121,7 +113,6 @@ async def test_404_raises_hive_api_error(self): class TestGetAll: - @pytest.mark.asyncio async def test_successful_get_all_returns_parsed_json(self): payload = {"products": [], "devices": []} api = _make_api(status=200, json_data=payload) @@ -129,21 +120,18 @@ async def test_successful_get_all_returns_parsed_json(self): assert result["original"] == 200 assert result["parsed"] == payload - @pytest.mark.asyncio async def test_timeout_error_propagates(self): api = _make_api(status=200) api.websession.request.side_effect = asyncio.TimeoutError with pytest.raises(asyncio.TimeoutError): await api.get_all() - @pytest.mark.asyncio async def test_os_error_calls_error_method(self): api = _make_api(status=200) api.websession.request.side_effect = OSError("network down") with pytest.raises(web_exceptions.HTTPError): await api.get_all() - @pytest.mark.asyncio async def test_runtime_error_calls_error_method(self): api = _make_api(status=200) api.websession.request.side_effect = RuntimeError("boom") @@ -157,7 +145,6 @@ async def test_runtime_error_calls_error_method(self): class TestGetEndpoints: - @pytest.mark.asyncio async def test_get_devices_returns_parsed_json(self): payload = [{"id": "dev1"}] api = _make_api(status=200, json_data=payload) @@ -165,7 +152,6 @@ async def test_get_devices_returns_parsed_json(self): assert result["original"] == 200 assert result["parsed"] == payload - @pytest.mark.asyncio async def test_get_products_returns_parsed_json(self): payload = [{"id": "prod1"}] api = _make_api(status=200, json_data=payload) @@ -173,7 +159,6 @@ async def test_get_products_returns_parsed_json(self): assert result["original"] == 200 assert result["parsed"] == payload - @pytest.mark.asyncio async def test_get_actions_returns_parsed_json(self): payload = [{"id": "act1"}] api = _make_api(status=200, json_data=payload) @@ -181,21 +166,18 @@ async def test_get_actions_returns_parsed_json(self): assert result["original"] == 200 assert result["parsed"] == payload - @pytest.mark.asyncio async def test_get_devices_os_error_raises_http_error(self): api = _make_api(status=200) api.websession.request.side_effect = OSError with pytest.raises(web_exceptions.HTTPError): await api.get_devices() - @pytest.mark.asyncio async def test_get_products_os_error_raises_http_error(self): api = _make_api(status=200) api.websession.request.side_effect = OSError with pytest.raises(web_exceptions.HTTPError): await api.get_products() - @pytest.mark.asyncio async def test_get_actions_os_error_raises_http_error(self): api = _make_api(status=200) api.websession.request.side_effect = OSError @@ -209,13 +191,11 @@ async def test_get_actions_os_error_raises_http_error(self): class TestSetState: - @pytest.mark.asyncio async def test_file_in_use_returns_file_response(self): api = _make_api(status=200, file_mode=True) result = await api.set_state("heating", "node-1", mode="MANUAL") assert result == {"original": "file"} - @pytest.mark.asyncio async def test_successful_set_state(self): payload = {"id": "node-1", "mode": "MANUAL"} api = _make_api(status=200, json_data=payload) @@ -223,14 +203,12 @@ async def test_successful_set_state(self): assert result["original"] == 200 assert result["parsed"] == payload - @pytest.mark.asyncio async def test_os_error_calls_error_method(self): api = _make_api(status=200) api.websession.request.side_effect = OSError("fail") with pytest.raises(web_exceptions.HTTPError): await api.set_state("heating", "node-1", mode="MANUAL") - @pytest.mark.asyncio async def test_runtime_error_calls_error_method(self): api = _make_api(status=200) api.websession.request.side_effect = RuntimeError("fail") @@ -244,19 +222,24 @@ async def test_runtime_error_calls_error_method(self): class TestSetAction: - @pytest.mark.asyncio async def test_file_in_use_returns_file_response(self): api = _make_api(status=200, file_mode=True) result = await api.set_action("action-1", '{"status": "on"}') assert result == {"original": "file"} - @pytest.mark.asyncio - async def test_successful_set_action_returns_json_return(self): - api = _make_api(status=200) + async def test_successful_set_action_returns_status_200(self): + payload = {"id": "action-1", "status": "on"} + api = _make_api(status=200, json_data=payload) result = await api.set_action("action-1", '{"status": "on"}') - assert result == api.json_return + assert result["original"] == 200 + assert result["parsed"] == payload + + async def test_runtime_error_calls_error_method(self): + api = _make_api(status=200) + api.websession.request.side_effect = RuntimeError("fail") + with pytest.raises(web_exceptions.HTTPError): + await api.set_action("action-1", "{}") - @pytest.mark.asyncio async def test_os_error_calls_error_method(self): api = _make_api(status=200) api.websession.request.side_effect = OSError @@ -264,13 +247,45 @@ async def test_os_error_calls_error_method(self): await api.set_action("action-1", "{}") +# --------------------------------------------------------------------------- +# Tests: HiveApiAsync.motion_sensor +# --------------------------------------------------------------------------- + + +class TestMotionSensor: + async def test_url_does_not_double_base_url(self): + payload = [{"timestamp": 12345}] + api = _make_api(status=200, json_data=payload) + captured = {} + original_request = api.request + + async def capture_request(method, url, **kwargs): + captured["url"] = url + return await original_request(method, url, **kwargs) + + api.request = capture_request + sensor = {"type": "motionsensor", "id": "ms-001"} + await api.motion_sensor(sensor, 1000000, 2000000) + url = captured["url"] + assert url.startswith(api.base_url + "/products/") + assert "motionsensor/ms-001" in url + assert url.count("https://beekeeper") == 1 + + async def test_motion_sensor_returns_parsed_json(self): + payload = [{"timestamp": 12345}] + api = _make_api(status=200, json_data=payload) + sensor = {"type": "motionsensor", "id": "ms-001"} + result = await api.motion_sensor(sensor, 1000000, 2000000) + assert result["original"] == 200 + assert result["parsed"] == payload + + # --------------------------------------------------------------------------- # Tests: HiveApiAsync.error # --------------------------------------------------------------------------- class TestError: - @pytest.mark.asyncio async def test_error_raises_http_error(self): api = _make_api() with pytest.raises(web_exceptions.HTTPError): @@ -283,13 +298,11 @@ async def test_error_raises_http_error(self): class TestIsFileBeingUsed: - @pytest.mark.asyncio async def test_file_mode_raises_file_in_use(self): api = _make_api(file_mode=True) with pytest.raises(FileInUse): await api.is_file_being_used() - @pytest.mark.asyncio async def test_not_file_mode_does_not_raise(self): api = _make_api(file_mode=False) await api.is_file_being_used() # Should not raise diff --git a/tests/unit/test_hive_async_api_extended.py b/tests/unit/test_hive_async_api_extended.py index 5a9848c..eabf190 100644 --- a/tests/unit/test_hive_async_api_extended.py +++ b/tests/unit/test_hive_async_api_extended.py @@ -111,7 +111,7 @@ def test_makes_request_to_sso_url(self): api.get_login_info() mock_get.assert_called_once_with( - url="https://sso.hivehome.com/", verify=False, timeout=api.timeout + url="https://sso.hivehome.com/", timeout=api.timeout ) def test_uses_first_script_tag(self): @@ -135,103 +135,6 @@ def test_uses_first_script_tag(self): assert result["UPID"] == "eu-west-1_first" -# --------------------------------------------------------------------------- -# Tests: refresh_tokens() — lines 131-156 -# --------------------------------------------------------------------------- - - -class TestRefreshTokens: - """Cover lines 133-156: refresh_tokens() success, no-token, and error paths.""" - - async def test_successful_request_with_non_ok_json_return_returns_json_return(self): - """When request succeeds but json_return["original"] != HTTP_OK, returns json_return.""" - api = _make_api(status=200) - # request() will succeed (200) but json_return is not updated by refresh_tokens - # so json_return["original"] stays as the default string, not HTTP_OK (200) - result = await api.refresh_tokens() - # Returns self.json_return (the default dict) - assert result == api.json_return - - async def test_session_tokens_read_before_request(self): - """tokens are read from session.tokens.token_data before constructing the request.""" - api = _make_api(status=200, token="my-session-token") - api.session.tokens.token_data = { - "token": "my-session-token", - "refreshToken": "r-tok", - } - result = await api.refresh_tokens() - # No exception raised — tokens were read without error - assert result is not None - - async def test_connection_error_raises_http_error(self): - """ConnectionError inside the try block causes error() → HTTPError.""" - api = _make_api(status=200) - api.websession.request.side_effect = ConnectionError("connection refused") - with pytest.raises(web_exceptions.HTTPError): - await api.refresh_tokens() - - async def test_os_error_raises_http_error(self): - """OSError inside the try block causes error() → HTTPError.""" - api = _make_api(status=200) - api.websession.request.side_effect = OSError("network error") - with pytest.raises(web_exceptions.HTTPError): - await api.refresh_tokens() - - async def test_runtime_error_raises_http_error(self): - """RuntimeError inside the try block causes error() → HTTPError.""" - api = _make_api(status=200) - api.websession.request.side_effect = RuntimeError("bad state") - with pytest.raises(web_exceptions.HTTPError): - await api.refresh_tokens() - - async def test_zero_division_raises_http_error(self): - """ZeroDivisionError inside the try block causes error() → HTTPError.""" - api = _make_api(status=200) - api.websession.request.side_effect = ZeroDivisionError("division by zero") - with pytest.raises(web_exceptions.HTTPError): - await api.refresh_tokens() - - async def test_json_return_true_when_ok_status_in_json_return(self): - """When json_return["original"] equals HTTP_OK (200) and token is present, - update_tokens is called and base_url is updated, returning True.""" - api = _make_api(status=200) - # Manually set json_return to simulate a successful response - api.json_return = { - "original": 200, - "parsed": { - "token": "new-token", - "platform": {"endpoint": "https://new.endpoint"}, - }, - } - api.session.update_tokens = AsyncMock() - - # Patch request to be a no-op (doesn't modify json_return) - with patch.object(api, "request", new_callable=AsyncMock) as mock_req: - mock_req.return_value = MagicMock() - result = await api.refresh_tokens() - - assert result is True - api.session.update_tokens.assert_called_once_with(api.json_return["parsed"]) - assert api.base_url == "https://new.endpoint" - - async def test_json_return_true_without_token_in_parsed(self): - """When json_return["original"] == HTTP_OK but no 'token' in parsed, - update_tokens is NOT called and returns True.""" - api = _make_api(status=200) - api.json_return = { - "original": 200, - "parsed": {"other_key": "value"}, - } - api.session.update_tokens = AsyncMock() - - with patch.object(api, "request", new_callable=AsyncMock) as mock_req: - mock_req.return_value = MagicMock() - result = await api.refresh_tokens() - - assert result is True - api.session.update_tokens.assert_not_called() - - # --------------------------------------------------------------------------- # Tests: motion_sensor() — lines 213-235 # --------------------------------------------------------------------------- @@ -386,42 +289,35 @@ async def test_connection_error_raises_http_error(self): # --------------------------------------------------------------------------- -# Tests: request() — url=None and resp.status=None skips the logging branch +# Tests: set_state() JSON encoding — Fix A # --------------------------------------------------------------------------- -class TestRequestUrlOrStatusNone: - """Lines 100->108: when url is None or resp.status is None, skip log → raise directly.""" +class TestSetStateJsonEncoding: + """set_state must produce valid JSON even when kwarg values contain special characters.""" - async def test_none_status_skips_log_and_raises_hive_api_error(self): - """resp.status=None causes branch 100->108 (skips the log lines) then raises.""" - api = _make_api(status=200) - # Replace the websession response with one having status=None - bad_resp = _make_mock_response(status=None) - bad_resp.text = AsyncMock(return_value="") - api.websession.request.return_value = bad_resp - with pytest.raises(HiveApiError): - await api.request("get", None) + async def test_set_state_escapes_quotes_in_value(self): + """A value containing double-quotes must produce valid, parseable JSON.""" + import json # noqa: PLC0415 + session = MagicMock() + session.tokens.token_data = {"token": "tok"} + session.config.file = False + api = HiveApiAsync(hive_session=session) + api.urls = {"nodes": "https://beekeeper.hivehome.com/1.0/nodes/{}/{}"} -# --------------------------------------------------------------------------- -# Tests: refresh_tokens() — session=None (134->136) -# --------------------------------------------------------------------------- + captured = {} + + async def fake_request(_method, _url, **kwargs): + captured["data"] = kwargs.get("data") + resp = MagicMock() + resp.status = 200 + resp.json = AsyncMock(return_value={}) + return resp + with patch.object(api, "request", side_effect=fake_request): + with patch.object(api, "is_file_being_used", new=AsyncMock()): + await api.set_state("heating", "node-1", mode='MANUAL"injected') -class TestRefreshTokensSessionNone: - """Line 134->136: when self.session is None, skip token_data read (line 135).""" - - async def test_session_none_skips_token_data_read(self): - """When session is None, tokens is not set from session → jsc uses undefined.""" - ws = MagicMock() - ws.request.return_value = _make_mock_response(status=200) - ws.closed = False - ws.close = AsyncMock() - api = HiveApiAsync(hive_session=None, websession=ws) - # tokens is not defined before jsc, so this will raise NameError or UnboundLocalError; - # what we need is that line 134's False branch (134->136) is traversed. - try: - await api.refresh_tokens() - except (NameError, UnboundLocalError, AttributeError): - pass # expected — tokens was never defined since session is None + parsed = json.loads(captured["data"]) + assert parsed["mode"] == 'MANUAL"injected' diff --git a/tests/unit/test_hive_auth_async.py b/tests/unit/test_hive_auth_async.py index fb541f8..fa6939b 100644 --- a/tests/unit/test_hive_auth_async.py +++ b/tests/unit/test_hive_auth_async.py @@ -79,12 +79,27 @@ async def _make_auth( class TestHiveAuthAsyncInit: - def test_pool_region_raises_value_error(self): + def test_pool_region_no_longer_accepted(self): from apyhiveapi.api.hive_auth_async import HiveAuthAsync - with pytest.raises(ValueError, match="pool_region"): + with pytest.raises(TypeError): HiveAuthAsync(username="u", password="p", pool_region="eu-west-1") + async def test_async_init_sets_running_loop(self): + from apyhiveapi.api.hive_auth_async import HiveAuthAsync + + auth = HiveAuthAsync(username="u@test.com", password="pass") + assert auth.loop is None # not set until async_init + mock_data = { + "UPID": "eu-west-1_Test", + "CLIID": "client-id", + "REGION": "eu-west-1_Test", + } + with patch.object(auth.api, "get_login_info", return_value=mock_data): + with patch("boto3.client", return_value=MagicMock()): + await auth.async_init() + assert auth.loop is not None + async def test_file_flag_set_for_magic_username(self): from apyhiveapi.api.hive_auth_async import HiveAuthAsync @@ -444,6 +459,14 @@ async def test_new_device_metadata_in_sms_stores_keys(self): assert auth.device_group_key == "sms-grp" assert auth.device_key == "sms-dev" + @pytest.mark.asyncio + async def test_no_authentication_result_key_does_not_raise(self): + auth = await _make_auth() + auth.loop.run_in_executor.return_value = {"ChallengeName": "SMS_MFA"} + result = await auth.sms_2fa("123456", {"Session": "sess-1"}) + assert auth.access_token is None + assert result == {"ChallengeName": "SMS_MFA"} + # --------------------------------------------------------------------------- # Tests: refresh_token diff --git a/tests/unit/test_hive_auth_async_extended.py b/tests/unit/test_hive_auth_async_extended.py index 54d408d..bc03922 100644 --- a/tests/unit/test_hive_auth_async_extended.py +++ b/tests/unit/test_hive_auth_async_extended.py @@ -90,13 +90,13 @@ async def test_async_init_sets_pool_id_and_client_id(self): auth.client = None # trigger async_init flow mock_boto_client = MagicMock() - - auth.loop = MagicMock() - auth.loop.run_in_executor = AsyncMock( + mock_loop = MagicMock() + mock_loop.run_in_executor = AsyncMock( side_effect=[_LOGIN_INFO, mock_boto_client] ) - await auth.async_init() + with patch("asyncio.get_running_loop", return_value=mock_loop): + await auth.async_init() assert auth._pool_id == "eu-west-1_TestPool" assert auth._client_id == "test-client-id" @@ -116,12 +116,13 @@ async def test_async_init_splits_region_correctly(self): "REGION": "ap-southeast-2_XyzPool", } mock_boto_client = MagicMock() - auth.loop = MagicMock() - auth.loop.run_in_executor = AsyncMock( + mock_loop = MagicMock() + mock_loop.run_in_executor = AsyncMock( side_effect=[login_info, mock_boto_client] ) - await auth.async_init() + with patch("asyncio.get_running_loop", return_value=mock_loop): + await auth.async_init() assert auth._region == "ap-southeast-2" @@ -712,10 +713,10 @@ async def test_other_client_error_in_initiate_auth_falls_through(self): class TestLoginInitiateAuthSwallowedEndpointError: - """Arc 284->288: EndpointConnectionError caught but class name is wrong.""" + """EndpointConnectionError in initiate_auth always raises HiveApiError.""" - async def test_wrong_name_endpoint_error_in_initiate_auth_falls_through(self): - """EndpointConnectionError with wrong name is swallowed; response stays None.""" + async def test_wrong_name_endpoint_error_in_initiate_auth_raises_api_error(self): + """Any EndpointConnectionError subclass in initiate_auth raises HiveApiError.""" auth = await _make_auth() wrong_cls = type( @@ -724,7 +725,7 @@ async def test_wrong_name_endpoint_error_in_initiate_auth_falls_through(self): wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com") auth.loop.run_in_executor = AsyncMock(side_effect=wrong_err) - with pytest.raises((TypeError, KeyError)): + with pytest.raises(HiveApiError): await auth.login() @@ -771,10 +772,10 @@ async def test_other_client_error_in_challenge_falls_through(self): class TestLoginChallengeSwallowedEndpointError: - """Arc 313->319: EndpointConnectionError caught with wrong class name in challenge.""" + """EndpointConnectionError in respond_to_auth_challenge always raises HiveApiError.""" - async def test_wrong_name_endpoint_error_in_challenge_falls_through(self): - """EndpointConnectionError with wrong name is swallowed; result stays None.""" + async def test_wrong_name_endpoint_error_in_challenge_raises_api_error(self): + """Any EndpointConnectionError subclass in SRP challenge raises HiveApiError.""" auth = await _make_auth() challenge_response = { @@ -797,7 +798,7 @@ async def test_wrong_name_endpoint_error_in_challenge_falls_through(self): auth.loop.run_in_executor = AsyncMock( side_effect=[challenge_response, wrong_err] ) - with pytest.raises((TypeError, AttributeError)): + with pytest.raises(HiveApiError): await auth.login() @@ -884,11 +885,10 @@ async def test_device_login_calls_second_respond_to_auth_challenge(self): class TestDeviceLoginEndpointWrongName: - """Line 389: EndpointConnectionError with wrong __class__.__name__ raises - HiveInvalidDeviceAuthentication instead of HiveApiError.""" + """Any EndpointConnectionError in device_login always raises HiveApiError.""" - async def test_wrong_name_endpoint_error_raises_invalid_device_auth(self): - """A subclass of EndpointConnectionError with a different name hits line 389.""" + async def test_wrong_name_endpoint_error_raises_api_error(self): + """Any EndpointConnectionError subclass in device_login raises HiveApiError.""" auth = await _make_auth(device_key="dk-err", device_group_key="grp-err") auth.device_password = "dev-pass-err" @@ -898,7 +898,7 @@ async def test_wrong_name_endpoint_error_raises_invalid_device_auth(self): wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com") auth.loop.run_in_executor = AsyncMock(side_effect=wrong_err) - with pytest.raises(HiveInvalidDeviceAuthentication): + with pytest.raises(HiveApiError): await auth.device_login() @@ -930,10 +930,10 @@ async def test_other_client_error_is_swallowed_returns_none(self): class TestSms2faSwallowedEndpointError: - """Arc 431->435: EndpointConnectionError caught with wrong class name in sms_2fa.""" + """Any EndpointConnectionError in sms_2fa raises HiveApiError.""" - async def test_wrong_name_endpoint_error_is_swallowed(self): - """EndpointConnectionError subclass with wrong name is swallowed; returns None.""" + async def test_wrong_name_endpoint_error_raises_api_error(self): + """Any EndpointConnectionError subclass in sms_2fa raises HiveApiError.""" auth = await _make_auth() wrong_cls = type( @@ -942,8 +942,8 @@ async def test_wrong_name_endpoint_error_is_swallowed(self): wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com") auth.loop.run_in_executor = AsyncMock(side_effect=wrong_err) - result = await auth.sms_2fa("654321", {"Session": "sess-abc"}) - assert result is None + with pytest.raises(HiveApiError): + await auth.sms_2fa("654321", {"Session": "sess-abc"}) # --------------------------------------------------------------------------- @@ -952,10 +952,10 @@ async def test_wrong_name_endpoint_error_is_swallowed(self): class TestRefreshTokenSwallowedEndpointError: - """Arc 479->485: EndpointConnectionError caught with wrong class name in refresh_token.""" + """Any EndpointConnectionError in refresh_token raises HiveApiError.""" - async def test_wrong_name_endpoint_error_is_swallowed_returns_none(self): - """EndpointConnectionError subclass with wrong name is swallowed; result=None returned.""" + async def test_wrong_name_endpoint_error_raises_api_error(self): + """Any EndpointConnectionError subclass in refresh_token raises HiveApiError.""" auth = await _make_auth() wrong_cls = type( @@ -964,6 +964,58 @@ async def test_wrong_name_endpoint_error_is_swallowed_returns_none(self): wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com") auth.loop.run_in_executor = AsyncMock(side_effect=wrong_err) - # result initialised to None; exception swallowed; line 485 reached; returns None - result = await auth.refresh_token("some-refresh-token") - assert result is None + with pytest.raises(HiveApiError): + await auth.refresh_token("some-refresh-token") + + +# --------------------------------------------------------------------------- +# Tests: async_init() — missing REGION or UPID keys +# --------------------------------------------------------------------------- + + +class TestAsyncInitMissingKeys: + """async_init must raise HiveUnknownConfiguration when login info keys are absent.""" + + async def test_async_init_missing_region_raises_configuration_error(self): + """If REGION is absent from login info, raise HiveUnknownConfiguration.""" + from apyhiveapi.api.hive_auth_async import HiveAuthAsync + from apyhiveapi.helper.hive_exceptions import HiveUnknownConfiguration + + auth = HiveAuthAsync(username="user@test.com", password="pass") + bad_login_info = {"UPID": "eu-west-1_TestPool", "CLIID": "test-client-id"} + mock_loop = MagicMock() + mock_loop.run_in_executor = AsyncMock(side_effect=[bad_login_info]) + with patch("asyncio.get_running_loop", return_value=mock_loop): + with pytest.raises(HiveUnknownConfiguration): + await auth.async_init() + + async def test_async_init_missing_upid_raises_configuration_error(self): + """If UPID is absent from login info, raise HiveUnknownConfiguration.""" + from apyhiveapi.api.hive_auth_async import HiveAuthAsync + from apyhiveapi.helper.hive_exceptions import HiveUnknownConfiguration + + auth = HiveAuthAsync(username="user@test.com", password="pass") + bad_login_info = {"CLIID": "test-client-id", "REGION": "eu-west-1_TestPool"} + mock_loop = MagicMock() + mock_loop.run_in_executor = AsyncMock(side_effect=[bad_login_info]) + with patch("asyncio.get_running_loop", return_value=mock_loop): + with pytest.raises(HiveUnknownConfiguration): + await auth.async_init() + + +# --------------------------------------------------------------------------- +# Tests: get_password_authentication_key() — None _pool_id +# --------------------------------------------------------------------------- + + +class TestGetPasswordAuthKeyNonePoolId: + """get_password_authentication_key must not crash with AttributeError when _pool_id is None.""" + + async def test_none_pool_id_raises_configuration_error(self): + """If _pool_id is None, raise HiveUnknownConfiguration (not AttributeError).""" + from apyhiveapi.helper.hive_exceptions import HiveUnknownConfiguration + + auth = await _make_auth() + auth._pool_id = None + with pytest.raises(HiveUnknownConfiguration): + auth.get_password_authentication_key("user", "pass", "DEADBEEF", "ABCDEF") diff --git a/tests/unit/test_hive_exceptions.py b/tests/unit/test_hive_exceptions.py new file mode 100644 index 0000000..33ffb5c --- /dev/null +++ b/tests/unit/test_hive_exceptions.py @@ -0,0 +1,93 @@ +"""Unit tests for the hive_exceptions hierarchy.""" + +import pytest +from apyhiveapi.helper.hive_exceptions import ( + FileInUse, + HiveApiError, + HiveAuthCredentialError, + HiveAuthError, + HiveConfigurationError, + HiveError, + HiveFailedToRefreshTokens, + HiveInvalid2FACode, + HiveInvalidDeviceAuthentication, + HiveInvalidPassword, + HiveInvalidUsername, + HiveReauthRequired, + HiveRefreshTokenExpired, + HiveUnknownConfiguration, + NoApiToken, +) + + +class TestHiveErrorBase: + def test_hive_api_error_is_hive_error(self): + assert issubclass(HiveApiError, HiveError) + + def test_hive_auth_error_is_hive_api_error(self): + assert issubclass(HiveAuthError, HiveApiError) + + def test_hive_auth_error_is_hive_error(self): + assert issubclass(HiveAuthError, HiveError) + + def test_hive_refresh_token_expired_is_hive_api_error(self): + assert issubclass(HiveRefreshTokenExpired, HiveApiError) + + def test_hive_failed_to_refresh_is_hive_api_error(self): + assert issubclass(HiveFailedToRefreshTokens, HiveApiError) + + def test_hive_reauth_required_is_hive_error(self): + assert issubclass(HiveReauthRequired, HiveError) + + +class TestCredentialErrors: + def test_invalid_username_is_hive_auth_credential_error(self): + assert issubclass(HiveInvalidUsername, HiveAuthCredentialError) + + def test_invalid_password_is_hive_auth_credential_error(self): + assert issubclass(HiveInvalidPassword, HiveAuthCredentialError) + + def test_invalid_2fa_is_hive_auth_credential_error(self): + assert issubclass(HiveInvalid2FACode, HiveAuthCredentialError) + + def test_auth_credential_error_is_hive_error(self): + assert issubclass(HiveAuthCredentialError, HiveError) + + +class TestConfigurationErrors: + def test_unknown_config_is_hive_configuration_error(self): + assert issubclass(HiveUnknownConfiguration, HiveConfigurationError) + + def test_invalid_device_auth_is_hive_configuration_error(self): + assert issubclass(HiveInvalidDeviceAuthentication, HiveConfigurationError) + + def test_configuration_error_is_hive_error(self): + assert issubclass(HiveConfigurationError, HiveError) + + +class TestStandaloneExceptions: + def test_file_in_use_is_not_hive_error(self): + assert not issubclass(FileInUse, HiveError) + + def test_no_api_token_is_not_hive_error(self): + assert not issubclass(NoApiToken, HiveError) + + def test_file_in_use_is_exception(self): + assert issubclass(FileInUse, Exception) + + def test_no_api_token_is_exception(self): + assert issubclass(NoApiToken, Exception) + + +class TestInstantiable: + def test_hive_error_is_raiseable(self): + with pytest.raises(HiveError): + raise HiveError("test") + + def test_hive_api_error_caught_as_hive_error(self): + with pytest.raises(HiveError): + raise HiveApiError("test") + + def test_invalid_username_caught_as_hive_error(self): + with pytest.raises(HiveError): + raise HiveInvalidUsername("test") diff --git a/tests/unit/test_hive_helper_extended.py b/tests/unit/test_hive_helper_extended.py index 8c13ec3..f215a8d 100644 --- a/tests/unit/test_hive_helper_extended.py +++ b/tests/unit/test_hive_helper_extended.py @@ -59,38 +59,6 @@ def test_returns_false_when_cache_is_empty(self): assert helper.get_device_from_id("any-id") is False -# --------------------------------------------------------------------------- -# get_heat_on_demand_device — lines 315-317 -# --------------------------------------------------------------------------- - - -class TestGetHeatOnDemandDevice: - """Covers HiveHelper.get_heat_on_demand_device (lines 315-317).""" - - def test_returns_linked_thermostat(self): - """Looks up TRV by HiveID, then fetches linked thermostat by zone.""" - trv_id = "trv-001" - thermostat_id = "zone-001" - - trv_data = {"state": {"zone": thermostat_id}, "type": "trvcontrol"} - thermostat_data = {"id": thermostat_id, "type": "heating"} - - products = { - trv_id: trv_data, - thermostat_id: thermostat_data, - } - helper = _make_helper(products=products) - - # Device accessed with dict-style key "HiveID" as used inside the method - device = MagicMock() - device.__getitem__ = MagicMock( - side_effect=lambda k: trv_id if k == "HiveID" else None - ) - - result = helper.get_heat_on_demand_device(device) - assert result == thermostat_data - - # --------------------------------------------------------------------------- # sanitize_payload — list masking (line 329) and non-str/dict/list fallthrough # --------------------------------------------------------------------------- @@ -141,3 +109,29 @@ def test_long_string_partially_masked(self): payload = {"password": "supersecretpassword"} result = helper.sanitize_payload(payload) assert result["password"] == "supe...word" + + +# --------------------------------------------------------------------------- +# epoch_time — to_epoch must honour the pattern argument +# --------------------------------------------------------------------------- + + +class TestEpochTimePattern: + """epoch_time to_epoch must honour the pattern argument.""" + + def test_to_epoch_uses_caller_pattern(self): + """Passing a custom pattern must parse the date string with that pattern.""" + from apyhiveapi.helper.hive_helper import epoch_time + + # ISO date — only parses if the custom pattern is respected + result = epoch_time("2024-06-15", "%Y-%m-%d", "to_epoch") + assert isinstance(result, int), "Expected int epoch timestamp" + assert result > 0 + + def test_to_epoch_standard_hive_format_still_works(self): + """The standard Hive date+time format must still parse correctly.""" + from apyhiveapi.helper.hive_helper import epoch_time + + result = epoch_time("15.06.2024 12:00:00", "%d.%m.%Y %H:%M:%S", "to_epoch") + assert isinstance(result, int) + assert result > 0 diff --git a/tests/unit/test_map.py b/tests/unit/test_map.py index f6209f7..0cd5350 100644 --- a/tests/unit/test_map.py +++ b/tests/unit/test_map.py @@ -1,5 +1,6 @@ """Unit tests for Map — dot-notation dict wrapper.""" +import pytest from apyhiveapi.helper.map import Map @@ -15,10 +16,18 @@ def test_dict_read(): assert m["key"] == "value" -def test_missing_key_returns_none_not_keyerror(): - """Test that missing keys return None instead of raising KeyError.""" +def test_missing_key_raises_attribute_error(): + """Missing attribute access raises AttributeError.""" m = Map({}) - assert m.missing is None + with pytest.raises(AttributeError): + _ = m.missing + + +def test_missing_bracket_key_raises_key_error(): + """Missing bracket access raises KeyError (standard dict behaviour).""" + m = Map({}) + with pytest.raises(KeyError): + _ = m["missing"] def test_nested_access(): diff --git a/tests/unit/test_polling.py b/tests/unit/test_polling.py index e518941..2a18501 100644 --- a/tests/unit/test_polling.py +++ b/tests/unit/test_polling.py @@ -207,3 +207,77 @@ async def test_poll_devices_propagates_false(self): p.get_devices = AsyncMock(return_value=False) result = await p._poll_devices() assert result is False + + +# --------------------------------------------------------------------------- +# TestGetDevicesSlowPoll +# --------------------------------------------------------------------------- + + +class TestGetDevicesSlowPoll: + async def test_auth_error_sets_last_poll_slow_false(self): + from unittest.mock import AsyncMock, MagicMock + + from apyhiveapi.helper.hive_exceptions import HiveAuthError + + p = _make_polling() + p.api = MagicMock() + p.api.get_all = AsyncMock(side_effect=HiveAuthError()) + p.config = MagicMock() + p.config.file = False + p.tokens = MagicMock() + p._last_poll_slow = True # pre-set to True to confirm it gets cleared + + retry_result = { + "original": 200, + "parsed": {"products": [], "devices": [], "actions": []}, + } + + async def fake_retry_login(): + pass + + async def fake_retry_with_backoff(_fn, _reraise_as=None): + return retry_result + + p._retry_login = fake_retry_login + p._retry_with_backoff = fake_retry_with_backoff + p.hive_refresh_tokens = AsyncMock() + p.data = MagicMock() + p.data.products = {} + p.data.devices = {} + p.data.actions = {} + p.config.last_update = MagicMock() + p.config.scan_interval = MagicMock() + + await p.get_devices("No_ID") + assert p._last_poll_slow is False + + async def test_slow_api_call_sets_last_poll_slow_true(self): + from unittest.mock import AsyncMock, MagicMock + + p = _make_polling() + p._slow_poll_threshold = 0 # any call will be "slow" + p.api = MagicMock() + + slow_result = { + "original": 200, + "parsed": {"products": [], "devices": [], "actions": []}, + } + + async def slow_get_all(): + return slow_result + + p.api.get_all = slow_get_all + p.config = MagicMock() + p.config.file = False + p.tokens = MagicMock() + p.hive_refresh_tokens = AsyncMock() + p.data = MagicMock() + p.data.products = {} + p.data.devices = {} + p.data.actions = {} + p.config.last_update = MagicMock() + p.config.scan_interval = MagicMock() + + await p.get_devices("No_ID") + assert p._last_poll_slow is True diff --git a/tests/unit/test_remaining_branches.py b/tests/unit/test_remaining_branches.py index cb5b8fb..2b652d8 100644 --- a/tests/unit/test_remaining_branches.py +++ b/tests/unit/test_remaining_branches.py @@ -525,7 +525,8 @@ class TestSensorGetSensorHiveTypesSensorPath: async def test_contactsensor_in_hive_types_sensor_takes_else_branch(self): """contactsensor is in HIVE_TYPES['Sensor'] and not in sensor_commands key set, so the elif branch is taken.""" - from apyhiveapi.helper.const import HIVE_TYPES, sensor_commands + from apyhiveapi.devices.sensor import sensor_commands + from apyhiveapi.helper.const import HIVE_TYPES # 'contactsensor' is in HIVE_TYPES['Sensor'] and NOT a key in sensor_commands assert "contactsensor" in HIVE_TYPES["Sensor"] @@ -546,7 +547,8 @@ async def test_contactsensor_in_hive_types_sensor_takes_else_branch(self): async def test_motionsensor_in_hive_types_sensor_sets_status(self): """motionsensor is in HIVE_TYPES['Sensor'] and not in sensor_commands key set.""" - from apyhiveapi.helper.const import HIVE_TYPES, sensor_commands + from apyhiveapi.devices.sensor import sensor_commands + from apyhiveapi.helper.const import HIVE_TYPES assert "motionsensor" in HIVE_TYPES["Sensor"] assert "motionsensor" not in sensor_commands diff --git a/tests/unit/test_sensor_extended.py b/tests/unit/test_sensor_extended.py index 68122f3..604c785 100644 --- a/tests/unit/test_sensor_extended.py +++ b/tests/unit/test_sensor_extended.py @@ -2,7 +2,7 @@ # pylint: disable=protected-access -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch from apyhiveapi.devices.sensor import Sensor from apyhiveapi.helper.hivedataclasses import Device, SessionConfig @@ -146,6 +146,39 @@ async def test_contact_sensor_in_hive_types_sets_status(self): assert "state" in result.status session.attr.state_attributes.assert_awaited_once() + async def test_contact_sensor_uses_device_id_not_hive_id_for_props(self): + """HIVE_TYPES['Sensor'] branch must look up data.devices by device_id. + + Before the fix, line 160 used hive_id; data was always {} so + device.parent_device was always None even when the device existed. + """ + hive_id = "prod-abc" + device_id = "dev-xyz" # deliberately different from hive_id + + products = {} # contactsensor is NOT in products + devices = { + device_id: { + "props": {"online": True, "signal": -70}, + "parent": "hub-parent-id", + } + } + session = _make_session(products=products, devices=devices) + session.attr.online_offline = AsyncMock(return_value=True) + + device = _make_device( + hive_id=hive_id, device_id=device_id, hive_type="contactsensor" + ) + device.device_data = {"online": True} + + sensor = Sensor(session) + with patch.object(sensor, "get_state", new=AsyncMock(return_value="CLOSED")): + result = await sensor.get_sensor(device) + + assert result is not None + assert device.parent_device == "hub-parent-id", ( + "parent_device must come from data.devices[device_id], not hive_id lookup" + ) + class TestGetState: """Tests for HiveSensor.get_state covering the motionsensor branch (lines 37-42)."""