diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6344262..c9853ea 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -33,7 +33,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] + python-version: ['3.13', '3.14'] steps: - uses: actions/checkout@v4 diff --git a/CHANGELOG.rst b/CHANGELOG.rst index cf620ea..5510b03 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,10 +2,31 @@ Changelog ========= -Version 6.2.0 (2025-12-17) +Version 7.0.0 (2025-12-17) ========================== -**BREAKING CHANGES**: Enumerations refactored for type safety and consistency +**BREAKING CHANGES**: +- Minimum Python version raised to 3.13 +- Enumerations refactored for type safety and consistency + +Removed +------- +- **Python 3.9-3.12 Support**: Minimum Python version is now 3.13 + + Home Assistant has deprecated Python 3.12 support, making Python 3.13 the de facto minimum for this ecosystem. + + Python 3.13 features and improvements: + + - **Experimental free-threaded mode** (PEP 703): Optional GIL removal for true parallelism + - **JIT compiler** (PEP 744): Just-in-time compilation for performance improvements + - **Better error messages**: Enhanced suggestions for NameError, AttributeError, and import errors + - **Type system enhancements**: TypeVars with defaults (PEP 696), @deprecated decorator (PEP 702), ReadOnly TypedDict (PEP 705) + - **Performance**: ~5-10% faster overall, optimized dictionary/set operations, better function calls + - PEP 695: New type parameter syntax for generics + - PEP 701: f-string improvements + - Built-in ``datetime.UTC`` constant + + If you need Python 3.12 support, use version 6.1.x of this library. - **CommandCode moved**: Import from ``nwp500.enums`` instead of ``nwp500.constants`` @@ -22,6 +43,13 @@ Version 6.2.0 (2025-12-17) Added ----- +- **Python 3.12+ Optimizations**: Leverage latest Python features + + - PEP 695: New type parameter syntax (``def func[T](...)`` instead of ``TypeVar``) + - Use ``datetime.UTC`` constant instead of ``datetime.timezone.utc`` + - Native union syntax (``X | Y`` instead of ``Union[X, Y]``) + - Cleaner generic type annotations throughout codebase + - **Enumerations Module (``src/nwp500/enums.py``)**: Comprehensive type-safe enums for device control and status - Status value enums: ``OnOffFlag``, ``Operation``, ``DhwOperationSetting``, ``CurrentOperationMode``, ``HeatSource``, ``DREvent``, ``WaterLevel``, ``FilterChange``, ``RecirculationMode`` diff --git a/README.rst b/README.rst index df122d1..e78744a 100644 --- a/README.rst +++ b/README.rst @@ -175,12 +175,10 @@ The library includes type-safe data models with automatic unit conversions: Requirements ============ -* Python 3.9+ +* Python 3.13+ * aiohttp >= 3.8.0 -* websockets >= 10.0 -* cryptography >= 3.4.0 * pydantic >= 2.0.0 -* awsiotsdk >= 1.21.0 +* awsiotsdk >= 1.27.0 License ======= diff --git a/docs/development/history.rst b/docs/development/history.rst index 12cc5a2..afd7e6e 100644 --- a/docs/development/history.rst +++ b/docs/development/history.rst @@ -18,7 +18,7 @@ providing: - Automatic reconnection with exponential backoff - Command queuing for reliable communication - Historical energy usage data (EMS API) -- Modern Python 3.9+ codebase with native type hints +- Modern Python 3.13+ codebase with native type hints Current Status -------------- @@ -38,7 +38,7 @@ The library is feature-complete with: - Comprehensive documentation - Working examples for all features - Unit tests with good coverage -- Python 3.9+ with modern type hints +- Python 3.13+ with modern type hints Implementation Milestones ------------------------- diff --git a/docs/installation.rst b/docs/installation.rst index b9ca36a..6e3a475 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -5,7 +5,7 @@ Installation Requirements ============ -* Python 3.9 or higher +* Python 3.13 or higher * pip (Python package installer) * Navien Smart Control account @@ -51,7 +51,7 @@ Core Dependencies The library requires: * ``aiohttp>=3.8.0`` - Async HTTP client for REST API -* ``awsiotsdk>=1.20.0`` - AWS IoT SDK for MQTT +* ``awsiotsdk>=1.27.0`` - AWS IoT SDK for MQTT * ``pydantic>=2.0.0`` - Data validation and models Optional Dependencies @@ -128,7 +128,7 @@ The MQTT client requires the AWS IoT SDK: .. code-block:: bash - pip install awsiotsdk>=1.20.0 + pip install awsiotsdk>=1.27.0 Upgrading ========= diff --git a/docs/python_api/models.rst b/docs/python_api/models.rst index 9432cd0..f719e9a 100644 --- a/docs/python_api/models.rst +++ b/docs/python_api/models.rst @@ -292,7 +292,7 @@ Complete real-time device status with 100+ fields. # Water usage detection if status.dhw_use: print("Water usage detected (short-term)") - if status.dhw_useSustained: + if status.dhw_use_sustained: print("Water usage detected (sustained)") # Errors diff --git a/docs/python_api/mqtt_client.rst b/docs/python_api/mqtt_client.rst index 0b1a14a..ec758a2 100644 --- a/docs/python_api/mqtt_client.rst +++ b/docs/python_api/mqtt_client.rst @@ -207,7 +207,7 @@ subscribe_device_status() print(f"Target: {status.dhw_temperature_setting}°F") print(f"Mode: {status.dhw_operation_setting.name}") print(f"Power: {status.current_inst_power}W") - print(f"Energy: {status.available_energy_capacity}%") + print(f"Energy: {status.dhw_charge_per}%") # Check if actively heating if status.operation_busy: @@ -566,11 +566,15 @@ subscribe_energy_usage() print(f"Electric: {energy.total.heat_element_percentage:.1f}%") print("\nDaily Breakdown:") - for day_data in energy.data: - print(f" Date: Day {len(energy.data)}") - print(f" Total: {day_data.total_usage} Wh") - print(f" HP: {day_data.hpUsage} Wh ({day_data.hpTime}h)") - print(f" HE: {day_data.heUsage} Wh ({day_data.heTime}h)") + for monthly_data in energy.usage: + print(f" Month: {monthly_data.year}-{monthly_data.month}") + for day_data in monthly_data.data: + # Skip empty days (all zeros) + if day_data.total_usage > 0: + print(f" Day {monthly_data.data.index(day_data) + 1}:") + print(f" Total: {day_data.total_usage} Wh") + print(f" HP: {day_data.heat_pump_usage} Wh ({day_data.heat_pump_time}h)") + print(f" HE: {day_data.heat_element_usage} Wh ({day_data.heat_element_time}h)") await mqtt.subscribe_energy_usage(device, on_energy) await mqtt.request_energy_usage(device, year=2024, months=[10]) diff --git a/docs/quickstart.rst b/docs/quickstart.rst index a60fb84..595e90e 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -8,7 +8,7 @@ in just a few minutes. Prerequisites ============= -* Python 3.9 or higher +* Python 3.13 or higher * Navien Smart Control account (via Navilink mobile app) * At least one Navien NWP500 device registered to your account * Valid email and password for your Navien account diff --git a/examples/mask.py b/examples/mask.py index 4d48a84..94628f5 100644 --- a/examples/mask.py +++ b/examples/mask.py @@ -7,15 +7,14 @@ from __future__ import annotations import re -from typing import Optional -def mask_mac(mac: Optional[str]) -> str: +def mask_mac(mac: str | None) -> str: """Always return fully redacted MAC address label, never expose partial values.""" return "[REDACTED_MAC]" -def mask_mac_in_topic(topic: str, mac_addr: Optional[str] = None) -> str: +def mask_mac_in_topic(topic: str, mac_addr: str | None = None) -> str: """Return topic with any MAC-like substrings replaced. Also ensures a direct literal match of mac_addr is redacted. @@ -33,7 +32,7 @@ def mask_mac_in_topic(topic: str, mac_addr: Optional[str] = None) -> str: __all__ = ["mask_mac", "mask_mac_in_topic"] -def mask_any(value: Optional[str]) -> str: +def mask_any(value: str | None) -> str: """Generic redaction for strings considered sensitive in examples. Always returns a short redaction tag; keep implementation simple so examples @@ -51,7 +50,7 @@ def mask_any(value: Optional[str]) -> str: return "[REDACTED]" -def mask_location(city: Optional[str], state: Optional[str]) -> str: +def mask_location(city: str | None, state: str | None) -> str: """Redact location fields for examples. Returns a single redaction tag if either city or state are present. diff --git a/pyproject.toml b/pyproject.toml index fac5689..dc5ec7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ version_scheme = "no-guess-dev" [tool.ruff] # Ruff configuration for code formatting and linting line-length = 80 -target-version = "py39" +target-version = "py313" # Exclude directories exclude = [ @@ -99,7 +99,7 @@ line-ending = "auto" strict = true # Python version target -python_version = "3.9" +python_version = "3.13" # Module discovery files = ["src/nwp500", "tests"] diff --git a/scripts/diagnose_mqtt_connection.py b/scripts/diagnose_mqtt_connection.py new file mode 100755 index 0000000..33d5168 --- /dev/null +++ b/scripts/diagnose_mqtt_connection.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 +""" +MQTT Connection Diagnostic Tool for Navien Smart Control. + +This script connects to the MQTT broker and monitors connection stability for a +specified duration. It outputs detailed diagnostics including: +- Connection drops and error reasons +- Reconnection attempts +- Session duration statistics +- Message throughput + +Usage: + python scripts/diagnose_mqtt_connection.py [--duration SECONDS] [--verbose] +""" + +import argparse +import asyncio +import contextlib +import logging +import os +import signal +import sys +from datetime import datetime + +# Add src to path to allow running from project root +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src")) + +try: + from nwp500 import NavienAuthClient, NavienMqttClient + from nwp500.mqtt_utils import MqttConnectionConfig +except ImportError: + print( + "Error: Could not import nwp500 library. " + "Run from project root with installed dependencies." + ) + sys.exit(1) + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger("mqtt_diagnostics") + + +async def main(): + parser = argparse.ArgumentParser(description="MQTT Connection Diagnostics") + parser.add_argument( + "--duration", + type=int, + default=60, + help="Duration to run monitoring in seconds (0 for indefinite)", + ) + parser.add_argument( + "--verbose", action="store_true", help="Enable verbose logging" + ) + parser.add_argument( + "--email", help="Navien account email (or use NAVIEN_EMAIL env var)" + ) + parser.add_argument( + "--password", + help="Navien account password (or use NAVIEN_PASSWORD env var)", + ) + + args = parser.parse_args() + + # Get credentials + email = args.email or os.getenv("NAVIEN_EMAIL") + password = args.password or os.getenv("NAVIEN_PASSWORD") + + if not email or not password: + print("Error: Credentials required.") + print("Set NAVIEN_EMAIL and NAVIEN_PASSWORD environment variables") + print("OR provide --email and --password arguments.") + sys.exit(1) + + if args.verbose: + logging.getLogger("nwp500").setLevel(logging.DEBUG) + logging.getLogger("awscrt").setLevel(logging.DEBUG) + + print(f"Starting MQTT diagnostics for {email}") + print(f"Monitoring duration: {args.duration} seconds") + print("Press Ctrl+C to stop early and generate report") + print("-" * 60) + + try: + async with NavienAuthClient(email, password) as auth_client: + # Configure connection for investigation + config = MqttConnectionConfig( + auto_reconnect=True, + max_reconnect_attempts=5, + enable_command_queue=True, + ) + + mqtt_client = NavienMqttClient(auth_client, config=config) + + # Setup signal handler for graceful shutdown + stop_event = asyncio.Event() + + def signal_handler(): + print("\nStopping diagnostics...") + stop_event.set() + + loop = asyncio.get_running_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + loop.add_signal_handler(sig, signal_handler) + + # Connect + logger.info("Connecting to MQTT broker...") + await mqtt_client.connect() + logger.info("Connected!") + + # Start monitoring loop + start_time = datetime.now() + conn_start = start_time + + while not stop_event.is_set(): + # Check duration + if ( + args.duration > 0 + and (datetime.now() - start_time).total_seconds() + > args.duration + ): + logger.info("Duration reached.") + break + + # Print periodic status + if (datetime.now() - conn_start).total_seconds() >= 10: + metrics = mqtt_client.diagnostics.get_metrics() + uptime = metrics.current_session_uptime_seconds + drops = metrics.total_connection_drops + reconnects = metrics.connection_recovered + + status = ( + "Connected" + if mqtt_client.is_connected + else "Disconnected" + ) + print( + f"Status: {status} | " + f"Uptime: {uptime:.1f}s | " + f"Drops: {drops} | " + f"Reconnects: {reconnects}" + ) + conn_start = datetime.now() + + await asyncio.sleep(1) + + # Final Summary + print("\n" + "=" * 60) + print("DIAGNOSTIC REPORT") + print("=" * 60) + mqtt_client.diagnostics.print_summary() + + # Export JSON + json_report = mqtt_client.diagnostics.export_json() + report_file = ( + f"mqtt_diag_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" + ) + with open(report_file, "w") as f: + f.write(json_report) + print(f"\nDetailed JSON report saved to: {report_file}") + + # Disconnect + await mqtt_client.disconnect() + + except Exception as e: + logger.error(f"Diagnostic error: {e}", exc_info=True) + sys.exit(1) + + +if __name__ == "__main__": + with contextlib.suppress(KeyboardInterrupt): + asyncio.run(main()) diff --git a/scripts/extract_changelog.py b/scripts/extract_changelog.py index 0f5110e..22c6038 100755 --- a/scripts/extract_changelog.py +++ b/scripts/extract_changelog.py @@ -8,25 +8,25 @@ def extract_version_notes(changelog_path: str, version: str) -> str: """ Extract the changelog section for a specific version. - + Args: changelog_path: Path to CHANGELOG.rst version: Version string (e.g., "6.0.4") - + Returns: The changelog content for that version """ - with open(changelog_path, "r") as f: + with open(changelog_path) as f: content = f.read() - + # Match the version header and capture everything until the next version # Pattern: "Version X.Y.Z (DATE)" followed by "===" line - version_pattern = rf"Version {re.escape(version)} \([^)]+\)\n=+\n(.*?)(?=\nVersion \d+\.\d+\.\d+ \([^)]+\)\n=+|$)" - + version_pattern = rf"Version {re.escape(version)} \([^)]+\)\n=+\n(.*?)(?=\nVersion \d+\.\d+\.\d+ \([^)]+\)\n=+|$)" # noqa: E501 + match = re.search(version_pattern, content, re.DOTALL) if not match: return f"Release {version}" - + notes = match.group(1).strip() return notes @@ -35,9 +35,9 @@ def extract_version_notes(changelog_path: str, version: str) -> str: if len(sys.argv) != 3: print("Usage: extract_changelog.py ") sys.exit(1) - + changelog_path = sys.argv[1] version = sys.argv[2].lstrip("v") # Remove 'v' prefix if present - + notes = extract_version_notes(changelog_path, version) print(notes) diff --git a/scripts/lint.py b/scripts/lint.py index 89bbeb9..3ff7801 100644 --- a/scripts/lint.py +++ b/scripts/lint.py @@ -48,7 +48,7 @@ def main(): lint_commands = [ ( [ - "python3", + sys.executable, "-m", "ruff", "check", @@ -60,7 +60,7 @@ def main(): ), ( [ - "python3", + sys.executable, "-m", "ruff", "format", diff --git a/setup.cfg b/setup.cfg index 2f0e1a3..602876f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,10 +32,6 @@ classifiers = Development Status :: 4 - Beta Programming Language :: Python Programming Language :: Python :: 3 - Programming Language :: Python :: 3.9 - Programming Language :: Python :: 3.10 - Programming Language :: Python :: 3.11 - Programming Language :: Python :: 3.12 Programming Language :: Python :: 3.13 Programming Language :: Python :: 3 :: Only @@ -48,7 +44,7 @@ package_dir = =src # Require a min/specific Python version (comma-separated conditions) -python_requires = >=3.9 +python_requires = >=3.13 # Add here dependencies of your project (line-separated), e.g. requests>=2.2,<3.0. # Version specifiers like >=2.2,<3.0 avoid problems due to API changes in @@ -56,7 +52,7 @@ python_requires = >=3.9 # For more information, check out https://semver.org/. install_requires = aiohttp>=3.8.0 - awsiotsdk>=1.26.0 + awsiotsdk>=1.27.0 pydantic>=2.0.0 diff --git a/src/nwp500/api_client.py b/src/nwp500/api_client.py index 93d58dc..170126f 100644 --- a/src/nwp500/api_client.py +++ b/src/nwp500/api_client.py @@ -5,7 +5,7 @@ """ import logging -from typing import Any, Optional +from typing import Any, Self import aiohttp @@ -46,7 +46,7 @@ def __init__( self, auth_client: NavienAuthClient, base_url: str = API_BASE_URL, - session: Optional[aiohttp.ClientSession] = None, + session: aiohttp.ClientSession | None = None, ): """ Initialize Navien API client. @@ -80,7 +80,7 @@ def __init__( ) self._owned_auth = False # Never own auth_client - async def __aenter__(self) -> "NavienAPIClient": + async def __aenter__(self) -> Self: """Enter async context manager.""" return self @@ -92,8 +92,8 @@ async def _make_request( self, method: str, endpoint: str, - json_data: Optional[dict[str, Any]] = None, - params: Optional[dict[str, Any]] = None, + json_data: dict[str, Any | None] = None, + params: dict[str, Any | None] = None, retry_on_auth_failure: bool = True, ) -> dict[str, Any]: """ @@ -392,7 +392,7 @@ async def update_push_token( # Convenience methods - async def get_first_device(self) -> Optional[Device]: + async def get_first_device(self) -> Device | None: """ Get the first device associated with the user. @@ -408,6 +408,6 @@ def is_authenticated(self) -> bool: return self._auth_client.is_authenticated @property - def user_email(self) -> Optional[str]: + def user_email(self) -> str | None: """Get current user email.""" return self._auth_client.user_email diff --git a/src/nwp500/auth.py b/src/nwp500/auth.py index 52e1772..8ce0774 100644 --- a/src/nwp500/auth.py +++ b/src/nwp500/auth.py @@ -14,7 +14,7 @@ import json import logging from datetime import datetime, timedelta -from typing import Any, Optional +from typing import Any, Self import aiohttp from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator @@ -70,16 +70,16 @@ class AuthTokens(NavienBaseModel): access_token: str = "" refresh_token: str = "" authentication_expires_in: int = 3600 - access_key_id: Optional[str] = None - secret_key: Optional[str] = None - session_token: Optional[str] = None - authorization_expires_in: Optional[int] = None + access_key_id: str | None = None + secret_key: str | None = None + session_token: str | None = None + authorization_expires_in: int | None = None # Calculated fields issued_at: datetime = Field(default_factory=datetime.now) _expires_at: datetime = PrivateAttr() - _aws_expires_at: Optional[datetime] = PrivateAttr(default=None) + _aws_expires_at: datetime | None = PrivateAttr(default=None) @model_validator(mode="before") @classmethod @@ -271,9 +271,9 @@ def __init__( user_id: str, password: str, base_url: str = API_BASE_URL, - session: Optional[aiohttp.ClientSession] = None, + session: aiohttp.ClientSession | None = None, timeout: int = 30, - stored_tokens: Optional[AuthTokens] = None, + stored_tokens: AuthTokens | None = None, ): """ Initialize the authentication client. @@ -302,8 +302,8 @@ def __init__( self._password = password # Current authentication state - self._auth_response: Optional[AuthenticationResponse] = None - self._user_email: Optional[str] = None + self._auth_response: AuthenticationResponse | None = None + self._user_email: str | None = None # Restore tokens if provided if stored_tokens: @@ -315,7 +315,7 @@ def __init__( ) self._user_email = user_id - async def __aenter__(self) -> "NavienAuthClient": + async def __aenter__(self) -> Self: """Async context manager entry.""" if self._owned_session: self._session = self._create_session() @@ -557,7 +557,7 @@ async def re_authenticate(self) -> AuthenticationResponse: _logger.info("Re-authenticating with stored credentials") return await self.sign_in(self._user_id, self._password) - async def ensure_valid_token(self) -> Optional[AuthTokens]: + async def ensure_valid_token(self) -> AuthTokens | None: """ Ensure we have a valid access token, refreshing if necessary. @@ -598,17 +598,17 @@ def is_authenticated(self) -> bool: return self._auth_response is not None @property - def current_user(self) -> Optional[UserInfo]: + def current_user(self) -> UserInfo | None: """Get current authenticated user info.""" return self._auth_response.user_info if self._auth_response else None @property - def current_tokens(self) -> Optional[AuthTokens]: + def current_tokens(self) -> AuthTokens | None: """Get current authentication tokens.""" return self._auth_response.tokens if self._auth_response else None @property - def user_email(self) -> Optional[str]: + def user_email(self) -> str | None: """Get the email address of the authenticated user.""" return self._user_email diff --git a/src/nwp500/cli/commands.py b/src/nwp500/cli/commands.py index 4a4ea3d..c82e321 100644 --- a/src/nwp500/cli/commands.py +++ b/src/nwp500/cli/commands.py @@ -3,7 +3,7 @@ import asyncio import json import logging -from typing import Any, Optional +from typing import Any from nwp500 import Device, DeviceFeature, DeviceStatus, NavienMqttClient from nwp500.exceptions import MqttError, Nwp500Error, ValidationError @@ -15,7 +15,7 @@ async def get_controller_serial_number( mqtt: NavienMqttClient, device: Device, timeout: float = 10.0 -) -> Optional[str]: +) -> str | None: """Retrieve controller serial number from device. Args: @@ -40,7 +40,7 @@ def on_feature(feature: DeviceFeature) -> None: serial_number = await asyncio.wait_for(future, timeout=timeout) _logger.info(f"Controller serial number retrieved: {serial_number}") return serial_number - except asyncio.TimeoutError: + except TimeoutError: _logger.error("Timed out waiting for controller serial number.") return None @@ -62,7 +62,7 @@ def on_status(status: DeviceStatus) -> None: try: await asyncio.wait_for(future, timeout=10) - except asyncio.TimeoutError: + except TimeoutError: _logger.error("Timed out waiting for device status response.") @@ -103,7 +103,7 @@ def raw_callback(topic: str, message: dict[str, Any]) -> None: try: await asyncio.wait_for(future, timeout=10) - except asyncio.TimeoutError: + except TimeoutError: _logger.error("Timed out waiting for device status response.") @@ -132,7 +132,7 @@ def on_device_info(info: Any) -> None: try: await asyncio.wait_for(future, timeout=10) - except asyncio.TimeoutError: + except TimeoutError: _logger.error("Timed out waiting for device info response.") @@ -230,7 +230,7 @@ def on_status_response(status: DeviceStatus) -> None: "Mode command sent but no status response received" ) - except asyncio.TimeoutError: + except TimeoutError: _logger.error("Timed out waiting for mode change confirmation") except ValidationError as e: @@ -301,7 +301,7 @@ def on_status_response(status: DeviceStatus) -> None: "Temperature command sent but no status response received" ) - except asyncio.TimeoutError: + except TimeoutError: _logger.error( "Timed out waiting for temperature change confirmation" ) @@ -374,7 +374,7 @@ def on_power_change_response(status: DeviceStatus) -> None: ) ) - except asyncio.TimeoutError: + except TimeoutError: _logger.error(f"Timed out waiting for power {action} confirmation") except MqttError as e: @@ -457,7 +457,7 @@ def raw_callback(topic: str, message: dict[str, Any]) -> None: try: await asyncio.wait_for(future, timeout=10) - except asyncio.TimeoutError: + except TimeoutError: _logger.error("Timed out waiting for reservation response.") @@ -499,7 +499,7 @@ def raw_callback(topic: str, message: dict[str, Any]) -> None: try: await asyncio.wait_for(future, timeout=10) - except asyncio.TimeoutError: + except TimeoutError: _logger.error("Timed out waiting for reservation update response.") @@ -591,7 +591,7 @@ def on_status_response(status: DeviceStatus) -> None: _logger.info(f"TOU {action} successful.") else: _logger.warning("TOU command sent but no response received") - except asyncio.TimeoutError: + except TimeoutError: _logger.error(f"Timed out waiting for TOU {action} confirmation") except MqttError as e: @@ -622,5 +622,5 @@ def raw_callback(topic: str, message: dict[str, Any]) -> None: try: await asyncio.wait_for(future, timeout=15) - except asyncio.TimeoutError: + except TimeoutError: _logger.error("Timed out waiting for energy usage response.") diff --git a/src/nwp500/cli/token_storage.py b/src/nwp500/cli/token_storage.py index 676a9f7..904e704 100644 --- a/src/nwp500/cli/token_storage.py +++ b/src/nwp500/cli/token_storage.py @@ -3,7 +3,6 @@ import json import logging from pathlib import Path -from typing import Optional from nwp500.auth import AuthTokens @@ -31,7 +30,7 @@ def save_tokens(tokens: AuthTokens, email: str) -> None: _logger.error(f"Failed to save tokens: {e}") -def load_tokens() -> tuple[Optional[AuthTokens], Optional[str]]: +def load_tokens() -> tuple[AuthTokens | None, str | None]: """ Load authentication tokens and user email from a file. diff --git a/src/nwp500/encoding.py b/src/nwp500/encoding.py index e447fe9..57078d0 100644 --- a/src/nwp500/encoding.py +++ b/src/nwp500/encoding.py @@ -8,7 +8,6 @@ from collections.abc import Iterable from numbers import Real -from typing import Union from .exceptions import ParameterValidationError, RangeValidationError @@ -35,7 +34,7 @@ # ============================================================================ -def encode_week_bitfield(days: Iterable[Union[str, int]]) -> int: +def encode_week_bitfield(days: Iterable[str | int]) -> int: """ Convert a collection of day names or indices into a reservation bitfield. @@ -323,8 +322,8 @@ def decode_reservation_hex(hex_string: str) -> list[dict[str, int]]: def build_reservation_entry( *, - enabled: Union[bool, int], - days: Iterable[Union[str, int]], + enabled: bool | int, + days: Iterable[str | int], hour: int, minute: int, mode_id: int, @@ -424,13 +423,13 @@ def build_reservation_entry( def build_tou_period( *, season_months: Iterable[int], - week_days: Iterable[Union[str, int]], + week_days: Iterable[str | int], start_hour: int, start_minute: int, end_hour: int, end_minute: int, - price_min: Union[int, Real], - price_max: Union[int, Real], + price_min: int | Real, + price_max: int | Real, decimal_point: int, ) -> dict[str, int]: """Build a TOU (Time of Use) period entry. diff --git a/src/nwp500/events.py b/src/nwp500/events.py index b2e1b3b..81140ca 100644 --- a/src/nwp500/events.py +++ b/src/nwp500/events.py @@ -10,8 +10,9 @@ import inspect import logging from collections import defaultdict +from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Callable, Optional +from typing import Any __author__ = "Emmanuel Levijarvi" __copyright__ = "Emmanuel Levijarvi" @@ -147,7 +148,7 @@ def once( ) def off( - self, event: str, callback: Optional[Callable[..., Any]] = None + self, event: str, callback: Callable[..., Any | None] = None ) -> int: """ Remove event listener(s). @@ -326,7 +327,7 @@ def event_names(self) -> list[str]: """ return list(self._listeners.keys()) - def remove_all_listeners(self, event: Optional[str] = None) -> int: + def remove_all_listeners(self, event: str | None = None) -> int: """ Remove all listeners for an event, or all listeners for all events. @@ -360,7 +361,7 @@ def remove_all_listeners(self, event: Optional[str] = None) -> int: async def wait_for( self, event: str, - timeout: Optional[float] = None, + timeout: float | None = None, ) -> tuple[Any, ...]: """ Wait for an event to be emitted. @@ -406,7 +407,7 @@ def handler(*args: Any, **kwargs: Any) -> None: # Return just args for simplicity (most common case) return args_tuple - except asyncio.TimeoutError: + except TimeoutError: # Remove the listener on timeout self.off(event, handler) raise diff --git a/src/nwp500/exceptions.py b/src/nwp500/exceptions.py index 6958c3e..ab3f2cd 100644 --- a/src/nwp500/exceptions.py +++ b/src/nwp500/exceptions.py @@ -64,7 +64,7 @@ # handle other validation errors """ -from typing import Any, Optional +from typing import Any __author__ = "Emmanuel Levijarvi" __copyright__ = "Emmanuel Levijarvi" @@ -89,8 +89,8 @@ def __init__( self, message: str, *, - error_code: Optional[str] = None, - details: Optional[dict[str, Any]] = None, + error_code: str | None = None, + details: dict[str, Any | None] = None, retriable: bool = False, ): """Initialize base exception. @@ -151,8 +151,8 @@ class AuthenticationError(Nwp500Error): def __init__( self, message: str, - status_code: Optional[int] = None, - response: Optional[dict[str, Any]] = None, + status_code: int | None = None, + response: dict[str, Any | None] = None, **kwargs: Any, ): """Initialize authentication error. @@ -218,8 +218,8 @@ class APIError(Nwp500Error): def __init__( self, message: str, - code: Optional[int] = None, - response: Optional[dict[str, Any]] = None, + code: int | None = None, + response: dict[str, Any | None] = None, **kwargs: Any, ): """Initialize API error. @@ -337,7 +337,7 @@ class ParameterValidationError(ValidationError): def __init__( self, message: str, - parameter: Optional[str] = None, + parameter: str | None = None, value: Any = None, **kwargs: Any, ): @@ -376,7 +376,7 @@ class RangeValidationError(ValidationError): def __init__( self, message: str, - field: Optional[str] = None, + field: str | None = None, value: Any = None, min_value: Any = None, max_value: Any = None, diff --git a/src/nwp500/models.py b/src/nwp500/models.py index d6aacca..89f5510 100644 --- a/src/nwp500/models.py +++ b/src/nwp500/models.py @@ -7,7 +7,7 @@ """ import logging -from typing import Annotated, Any, Optional, Union +from typing import Annotated, Any from pydantic import BaseModel, BeforeValidator, ConfigDict, Field from pydantic.alias_generators import to_camel @@ -143,7 +143,7 @@ def model_dump(self, **kwargs: Any) -> dict[str, Any]: @staticmethod def _convert_enums_to_names( - data: Any, visited: Optional[set[int]] = None + data: Any, visited: set[int | None] = None ) -> Any: """Recursively convert Enum values to their names. @@ -194,21 +194,21 @@ class DeviceInfo(NavienBaseModel): home_seq: int = 0 mac_address: str = "" additional_value: str = "" - device_type: Union[DeviceType, int] = DeviceType.NPF700_WIFI + device_type: DeviceType | int = DeviceType.NPF700_WIFI device_name: str = "Unknown" connected: int = 0 - install_type: Optional[str] = None + install_type: str | None = None class Location(NavienBaseModel): """Location information for a device.""" - state: Optional[str] = None - city: Optional[str] = None - address: Optional[str] = None - latitude: Optional[float] = None - longitude: Optional[float] = None - altitude: Optional[float] = None + state: str | None = None + city: str | None = None + address: str | None = None + latitude: float | None = None + longitude: float | None = None + altitude: float | None = None class Device(NavienBaseModel): @@ -223,11 +223,11 @@ class FirmwareInfo(NavienBaseModel): mac_address: str = "" additional_value: str = "" - device_type: Union[DeviceType, int] = DeviceType.NPF700_WIFI + device_type: DeviceType | int = DeviceType.NPF700_WIFI cur_sw_code: int = 0 cur_version: int = 0 - downloaded_version: Optional[int] = None - device_group: Optional[str] = None + downloaded_version: int | None = None + device_group: str | None = None class TOUSchedule(NavienBaseModel): @@ -256,9 +256,9 @@ def model_validate( cls, obj: Any, *, - strict: Optional[bool] = None, - from_attributes: Optional[bool] = None, - context: Optional[dict[str, Any]] = None, + strict: bool | None = None, + from_attributes: bool | None = None, + context: dict[str, Any | None] = None, **kwargs: Any, ) -> "TOUInfo": # Handle nested structure where fields are in 'touInfo' @@ -914,7 +914,7 @@ class DeviceFeature(NavienBaseModel): "(1=USA, complies with FCC Part 15 Class B)" ) ) - model_type_code: Union[UnitType, int] = Field( + model_type_code: UnitType | int = Field( description="Model type identifier: NWP500 series model variant" ) control_type_code: int = Field( @@ -1148,14 +1148,14 @@ class MqttRequest(NavienBaseModel): """MQTT command request payload.""" command: int - device_type: Union[DeviceType, int] + device_type: DeviceType | int mac_address: str additional_value: str = "..." - mode: Optional[str] = None - param: list[Union[int, float]] = Field(default_factory=list) + mode: str | None = None + param: list[int | float] = Field(default_factory=list) param_str: str = "" - month: Optional[list[int]] = None - year: Optional[int] = None + month: list[int] | None = None + year: int | None = None class MqttCommand(NavienBaseModel): @@ -1165,7 +1165,7 @@ class MqttCommand(NavienBaseModel): session_id: str = Field(alias="sessionID") request_topic: str response_topic: str - request: Union[MqttRequest, dict[str, Any]] + request: MqttRequest | dict[str, Any] protocol_version: int = 2 @@ -1233,9 +1233,7 @@ class EnergyUsageResponse(NavienBaseModel): total: EnergyUsageTotal usage: list[MonthlyEnergyData] - def get_month_data( - self, year: int, month: int - ) -> Optional[MonthlyEnergyData]: + def get_month_data(self, year: int, month: int) -> MonthlyEnergyData | None: """Get energy usage data for a specific month. Args: diff --git a/src/nwp500/mqtt_client.py b/src/nwp500/mqtt_client.py index f4b6bab..0623935 100644 --- a/src/nwp500/mqtt_client.py +++ b/src/nwp500/mqtt_client.py @@ -15,8 +15,8 @@ import json import logging import uuid -from collections.abc import Sequence -from typing import TYPE_CHECKING, Any, Callable +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, Any from awscrt import mqtt from awscrt.exceptions import AwsCrtError @@ -42,6 +42,7 @@ from .mqtt_command_queue import MqttCommandQueue from .mqtt_connection import MqttConnection from .mqtt_device_control import MqttDeviceController +from .mqtt_diagnostics import MqttDiagnosticsCollector from .mqtt_periodic import MqttPeriodicRequestManager from .mqtt_reconnection import MqttReconnectionHandler from .mqtt_subscriptions import MqttSubscriptionManager @@ -186,6 +187,9 @@ def __init__( self._reconnect_task: asyncio.Task[None] | None = None self._periodic_manager: MqttPeriodicRequestManager | None = None + # Diagnostics collector + self._diagnostics = MqttDiagnosticsCollector() + # Connection state (simpler than checking _connection_manager) self._connection: mqtt.Connection | None = None self._connected = False @@ -238,6 +242,29 @@ def _on_connection_interrupted_internal( if self._reconnection_handler and self.config.auto_reconnect: self._reconnection_handler.on_connection_interrupted(error) + # Record diagnostic event + active_subs = 0 + if self._subscription_manager: + # Access protected subscriber count for diagnostics + # pylint: disable=protected-access + active_subs = len(self._subscription_manager._subscriptions) + + # Record drop asynchronously + self._schedule_coroutine( + self._diagnostics.record_connection_drop( + error=error, + reconnect_attempt=( + self._reconnection_handler.attempt_count + if self._reconnection_handler + else 0 + ), + active_subscriptions=active_subs, + queued_commands=( + self._command_queue.count if self._command_queue else 0 + ), + ) + ) + def _on_connection_resumed_internal( self, return_code: Any, session_present: Any ) -> None: @@ -259,6 +286,16 @@ def _on_connection_resumed_internal( return_code, session_present ) + # Record diagnostic event + self._schedule_coroutine( + self._diagnostics.record_connection_success( + event_type="resumed", + session_present=session_present, + return_code=return_code, + attempt_number=0, # Reset on success + ) + ) + # Send any queued commands if self.config.enable_command_queue and self._command_queue: self._schedule_coroutine(self._send_queued_commands_internal()) @@ -521,6 +558,16 @@ async def connect(self) -> bool: ) _logger.info("All components initialized successfully") + + # Record diagnostic event + self._schedule_coroutine( + self._diagnostics.record_connection_success( + event_type="connected", + session_present=False, # Initial connect + attempt_number=0, + ) + ) + return True return False @@ -1486,3 +1533,8 @@ async def reset_reconnect(self) -> None: """ if self._reconnection_handler: self._reconnection_handler.reset() + + @property + def diagnostics(self) -> MqttDiagnosticsCollector: + """Get the diagnostics collector instance.""" + return self._diagnostics diff --git a/src/nwp500/mqtt_command_queue.py b/src/nwp500/mqtt_command_queue.py index 9729359..f15da7c 100644 --- a/src/nwp500/mqtt_command_queue.py +++ b/src/nwp500/mqtt_command_queue.py @@ -5,10 +5,13 @@ and automatically sends them when the connection is restored. """ +from __future__ import annotations + import asyncio import logging -from datetime import datetime -from typing import TYPE_CHECKING, Any, Callable +from collections.abc import Callable +from datetime import UTC, datetime +from typing import TYPE_CHECKING, Any from awscrt import mqtt @@ -37,7 +40,7 @@ class MqttCommandQueue: new commands (FIFO with overflow dropping). """ - def __init__(self, config: "MqttConnectionConfig"): + def __init__(self, config: MqttConnectionConfig): """ Initialize the command queue. @@ -45,7 +48,7 @@ def __init__(self, config: "MqttConnectionConfig"): config: MQTT connection configuration with queue settings """ self.config = config - # Use asyncio.Queue instead of deque for better async support + # Python 3.10+ handles asyncio.Queue initialization without running loop self._queue: asyncio.Queue[QueuedCommand] = asyncio.Queue( maxsize=config.max_queued_commands ) @@ -73,7 +76,10 @@ def enqueue( return command = QueuedCommand( - topic=topic, payload=payload, qos=qos, timestamp=datetime.utcnow() + topic=topic, + payload=payload, + qos=qos, + timestamp=datetime.now(UTC), ) # If queue is full, drop oldest command first diff --git a/src/nwp500/mqtt_connection.py b/src/nwp500/mqtt_connection.py index bafd51f..dea8255 100644 --- a/src/nwp500/mqtt_connection.py +++ b/src/nwp500/mqtt_connection.py @@ -9,7 +9,8 @@ import asyncio import json import logging -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from collections.abc import Callable +from typing import TYPE_CHECKING, Any from awscrt import mqtt from awscrt.exceptions import AwsCrtError @@ -47,10 +48,10 @@ def __init__( self, config: "MqttConnectionConfig", auth_client: "NavienAuthClient", - on_connection_interrupted: Optional[ - Callable[[mqtt.Connection, Exception], None] - ] = None, - on_connection_resumed: Optional[Callable[[Any, Any], None]] = None, + on_connection_interrupted: ( + Callable[[mqtt.Connection, Exception], None] | None + ) = None, + on_connection_resumed: Callable[[Any, Any | None], None] | None = None, ): """ Initialize connection manager. @@ -83,7 +84,7 @@ def __init__( self.config = config self._auth_client = auth_client - self._connection: Optional[mqtt.Connection] = None + self._connection: mqtt.Connection | None = None self._connected = False self._on_connection_interrupted = on_connection_interrupted self._on_connection_resumed = on_connection_resumed @@ -236,7 +237,7 @@ async def subscribe( self, topic: str, qos: mqtt.QoS, - callback: Optional[Callable[..., None]] = None, + callback: Callable[..., None] | None = None, ) -> tuple[Any, int]: """ Subscribe to an MQTT topic. @@ -320,7 +321,7 @@ async def unsubscribe(self, topic: str) -> int: async def publish( self, topic: str, - payload: Union[str, dict[str, Any], Any], + payload: str | dict[str, Any, Any], qos: mqtt.QoS = mqtt.QoS.AT_LEAST_ONCE, ) -> int: """ @@ -398,7 +399,7 @@ def is_connected(self) -> bool: return self._connected @property - def connection(self) -> Optional[mqtt.Connection]: + def connection(self) -> mqtt.Connection | None: """Get the underlying MQTT connection. Returns: diff --git a/src/nwp500/mqtt_device_control.py b/src/nwp500/mqtt_device_control.py index 0dd007c..fd9b7e5 100644 --- a/src/nwp500/mqtt_device_control.py +++ b/src/nwp500/mqtt_device_control.py @@ -14,9 +14,9 @@ """ import logging -from collections.abc import Awaitable, Sequence +from collections.abc import Awaitable, Callable, Sequence from datetime import datetime -from typing import Any, Callable, Optional +from typing import Any from .enums import CommandCode, DhwOperationSetting from .exceptions import ParameterValidationError, RangeValidationError @@ -187,7 +187,7 @@ async def set_dhw_mode( self, device: Device, mode_id: int, - vacation_days: Optional[int] = None, + vacation_days: int | None = None, ) -> int: """ Set DHW (Domestic Hot Water) operation mode. diff --git a/src/nwp500/mqtt_diagnostics.py b/src/nwp500/mqtt_diagnostics.py index 43cb58f..d5ca647 100644 --- a/src/nwp500/mqtt_diagnostics.py +++ b/src/nwp500/mqtt_diagnostics.py @@ -8,13 +8,16 @@ - Client-side configuration issues (insufficient keep-alive, poor backoff) """ +from __future__ import annotations + import json import logging import time from collections import defaultdict +from collections.abc import Callable from dataclasses import asdict, dataclass, field -from datetime import datetime -from typing import Any, Callable, Optional +from datetime import UTC, datetime +from typing import Any from awscrt.exceptions import AwsCrtError @@ -30,11 +33,11 @@ class ConnectionDropEvent: """Record of a single connection drop event.""" timestamp: str # ISO 8601 timestamp - error_name: Optional[str] = None - error_message: Optional[str] = None - error_code: Optional[int] = None + error_name: str | None = None + error_message: str | None = None + error_code: int | None = None reconnect_attempt: int = 0 - duration_connected_seconds: Optional[float] = None + duration_connected_seconds: float | None = None active_subscriptions: int = 0 queued_commands: int = 0 @@ -50,9 +53,9 @@ class ConnectionEvent: timestamp: str # ISO 8601 timestamp event_type: str # "connected", "resumed", "deep_reconnected" session_present: bool = False - return_code: Optional[int] = None + return_code: int | None = None attempt_number: int = 0 - time_to_reconnect_seconds: Optional[float] = None + time_to_reconnect_seconds: float | None = None def to_dict(self) -> dict[str, Any]: """Convert to dictionary.""" @@ -82,8 +85,8 @@ class MqttMetrics: ) # Bucketed by attempt count # Recent activity - last_drop_timestamp: Optional[str] = None - last_successful_connect_timestamp: Optional[str] = None + last_drop_timestamp: str | None = None + last_successful_connect_timestamp: str | None = None connection_recovered: int = 0 # Number of successful reconnections # QoS tracking @@ -135,10 +138,10 @@ def __init__( self._metrics = MqttMetrics() # Session tracking - self._session_start_time: Optional[float] = None + self._session_start_time: float | None = None self._session_duration_history: list[float] = [] - self._last_connection_timestamp: Optional[str] = None - self._last_drop_timestamp: Optional[float] = None + self._last_connection_timestamp: str | None = None + self._last_drop_timestamp: float | None = None # Error categorization self._aws_error_name_counts: dict[str, int] = defaultdict(int) @@ -162,7 +165,7 @@ def on_connection_drop( async def record_connection_drop( self, - error: Optional[Exception] = None, + error: Exception | None = None, reconnect_attempt: int = 0, active_subscriptions: int = 0, queued_commands: int = 0, @@ -176,7 +179,7 @@ async def record_connection_drop( active_subscriptions: Number of active subscriptions at time of drop queued_commands: Number of commands in the queue """ - now = datetime.utcnow().isoformat() + "Z" + now = datetime.now(UTC).isoformat() duration = None if self._session_start_time is not None: @@ -254,7 +257,7 @@ async def record_connection_success( self, event_type: str = "connected", session_present: bool = False, - return_code: Optional[int] = None, + return_code: int | None = None, attempt_number: int = 0, ) -> None: """ @@ -266,7 +269,7 @@ async def record_connection_success( return_code: MQTT return code attempt_number: Reconnection attempt number (0 = initial connect) """ - now = datetime.utcnow().isoformat() + "Z" + now = datetime.now(UTC).isoformat() time_to_reconnect = None # Update metrics @@ -367,7 +370,7 @@ def export_json(self) -> str: JSON string suitable for storing or sending to monitoring systems """ export_data = { - "timestamp": datetime.utcnow().isoformat() + "Z", + "timestamp": datetime.now(UTC).isoformat(), "metrics": self.get_metrics().to_dict(), "recent_drops": [ event.to_dict() for event in self.get_recent_drops(50) diff --git a/src/nwp500/mqtt_periodic.py b/src/nwp500/mqtt_periodic.py index 87b6c11..c43f8b9 100644 --- a/src/nwp500/mqtt_periodic.py +++ b/src/nwp500/mqtt_periodic.py @@ -9,10 +9,13 @@ - Per-device, per-type task management """ +from __future__ import annotations + import asyncio import contextlib import logging -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any from awscrt.exceptions import AwsCrtError @@ -237,7 +240,7 @@ async def periodic_request() -> None: async def stop_periodic_requests( self, device: Device, - request_type: Optional[PeriodicRequestType] = None, + request_type: PeriodicRequestType | None = None, ) -> None: """ Stop sending periodic requests for a device. @@ -288,9 +291,7 @@ async def stop_periodic_requests( + (f" (type={request_type.value})" if request_type else "") ) - async def stop_all_periodic_tasks( - self, reason: Optional[str] = None - ) -> None: + async def stop_all_periodic_tasks(self, reason: str | None = None) -> None: """ Stop all periodic request tasks. diff --git a/src/nwp500/mqtt_reconnection.py b/src/nwp500/mqtt_reconnection.py index efde312..bd6e316 100644 --- a/src/nwp500/mqtt_reconnection.py +++ b/src/nwp500/mqtt_reconnection.py @@ -5,11 +5,13 @@ the MQTT connection is interrupted. """ +from __future__ import annotations + import asyncio import contextlib import logging -from collections.abc import Awaitable -from typing import TYPE_CHECKING, Any, Callable, Optional +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Any from awscrt.exceptions import AwsCrtError @@ -33,12 +35,12 @@ class MqttReconnectionHandler: def __init__( self, - config: "MqttConnectionConfig", + config: MqttConnectionConfig, is_connected_func: Callable[[], bool], schedule_coroutine_func: Callable[[Any], None], reconnect_func: Callable[[], Awaitable[None]], - deep_reconnect_func: Optional[Callable[[], Awaitable[None]]] = None, - emit_event_func: Optional[Callable[..., Awaitable[Any]]] = None, + deep_reconnect_func: Callable[[], Awaitable[None]] | None = None, + emit_event_func: Callable[..., Awaitable[Any]] | None = None, ): """ Initialize reconnection handler. @@ -62,7 +64,7 @@ def __init__( self._emit_event = emit_event_func self._reconnect_attempts = 0 - self._reconnect_task: Optional[asyncio.Task[None]] = None + self._reconnect_task: asyncio.Task[None] | None = None self._manual_disconnect = False self._enabled = False diff --git a/src/nwp500/mqtt_subscriptions.py b/src/nwp500/mqtt_subscriptions.py index 6de2b71..71006e5 100644 --- a/src/nwp500/mqtt_subscriptions.py +++ b/src/nwp500/mqtt_subscriptions.py @@ -9,10 +9,13 @@ - State change detection and event emission """ +from __future__ import annotations + import asyncio import json import logging -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any from awscrt import mqtt from awscrt.exceptions import AwsCrtError @@ -67,7 +70,7 @@ def __init__( ] = {} # Track previous state for change detection - self._previous_status: Optional[DeviceStatus] = None + self._previous_status: DeviceStatus | None = None @property def subscriptions(self) -> dict[str, mqtt.QoS]: diff --git a/src/nwp500/mqtt_utils.py b/src/nwp500/mqtt_utils.py index c6a4570..c77983a 100644 --- a/src/nwp500/mqtt_utils.py +++ b/src/nwp500/mqtt_utils.py @@ -5,12 +5,14 @@ configuration classes, and common data structures used across MQTT modules. """ +from __future__ import annotations + import re import uuid from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Any, Optional +from typing import Any from awscrt import mqtt @@ -29,7 +31,7 @@ ] -def redact(obj: Any, keys_to_redact: Optional[set[str]] = None) -> Any: +def redact(obj: Any, keys_to_redact: set[str] | None = None) -> Any: """Return a redacted copy of obj with sensitive keys masked. This is a lightweight sanitizer for log messages to avoid emitting @@ -166,7 +168,7 @@ class MqttConnectionConfig: endpoint: str = AWS_IOT_ENDPOINT region: str = AWS_REGION - client_id: Optional[str] = None + client_id: str | None = None clean_session: bool = True keep_alive_secs: int = 1200 @@ -271,7 +273,9 @@ def topic_matches_pattern(topic: str, pattern: str) -> bool: if len(topic_parts) != len(pattern_parts): return False - for topic_part, pattern_part in zip(topic_parts, pattern_parts): + for topic_part, pattern_part in zip( + topic_parts, pattern_parts, strict=True + ): if pattern_part != "+" and topic_part != pattern_part: return False diff --git a/src/nwp500/utils.py b/src/nwp500/utils.py index 22ed9e0..5f1be09 100644 --- a/src/nwp500/utils.py +++ b/src/nwp500/utils.py @@ -5,11 +5,12 @@ including performance monitoring decorators and helper functions. """ -import asyncio import functools +import inspect import logging import time -from typing import Any, Callable, TypeVar, cast +from collections.abc import Callable +from typing import Any, cast __author__ = "Emmanuel Levijarvi" __copyright__ = "Emmanuel Levijarvi" @@ -17,10 +18,8 @@ _logger = logging.getLogger(__name__) -F = TypeVar("F", bound=Callable[..., Any]) - -def log_performance(func: F) -> F: +def log_performance[F: Callable[..., Any]](func: F) -> F: """Log execution time for async functions at DEBUG level. This decorator measures the execution time of async functions and logs @@ -48,7 +47,7 @@ async def fetch_device_status(device_id: str) -> dict: - Uses time.perf_counter() for high-resolution timing - Preserves function metadata (name, docstring, etc.) """ - if not asyncio.iscoroutinefunction(func): + if not inspect.iscoroutinefunction(func): raise TypeError( "@log_performance can only be applied to async " f"functions, got {func}" diff --git a/tests/test_cli_basic.py b/tests/test_cli_basic.py new file mode 100644 index 0000000..60c79f4 --- /dev/null +++ b/tests/test_cli_basic.py @@ -0,0 +1,24 @@ +"""Basic tests for CLI entry point.""" + +import sys +from unittest.mock import patch + +import pytest + +from nwp500.cli.__main__ import run + + +def test_cli_help(): + """Test that CLI help command works.""" + with patch.object(sys, "argv", ["nwp-cli", "--help"]): + with pytest.raises(SystemExit) as excinfo: + run() + assert excinfo.value.code == 0 + + +def test_cli_no_args(): + """Test that CLI without args shows help.""" + with patch.object(sys, "argv", ["nwp-cli"]): + with pytest.raises(SystemExit) as excinfo: + run() + assert excinfo.value.code != 0 diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py new file mode 100644 index 0000000..31639a2 --- /dev/null +++ b/tests/test_cli_commands.py @@ -0,0 +1,151 @@ +"""Tests for CLI command handlers.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from nwp500.cli.commands import ( + get_controller_serial_number, + handle_set_dhw_temp_request, + handle_set_mode_request, + handle_status_request, +) +from nwp500.models import Device, DeviceFeature, DeviceStatus + + +@pytest.fixture +def mock_device(): + device = MagicMock(spec=Device) + device.device_info = MagicMock() + device.device_info.device_type = 123 + return device + + +@pytest.fixture +def mock_mqtt(): + mqtt = MagicMock() + # Async methods need to be AsyncMock + mqtt.subscribe_device_feature = AsyncMock() + mqtt.request_device_info = AsyncMock() + mqtt.subscribe_device_status = AsyncMock() + mqtt.request_device_status = AsyncMock() + mqtt.set_dhw_mode = AsyncMock() + mqtt.set_dhw_temperature = AsyncMock() + return mqtt + + +@pytest.mark.asyncio +async def test_get_controller_serial_number_success(mock_mqtt, mock_device): + """Test successful retrieval of controller serial number.""" + # Setup the feature that will be returned + feature = MagicMock(spec=DeviceFeature) + feature.controller_serial_number = "TEST_SERIAL_123" + + # When subscribe is called, capture the callback and call it immediately + async def side_effect_subscribe(device, callback): + callback(feature) + return None + + mock_mqtt.subscribe_device_feature.side_effect = side_effect_subscribe + + serial = await get_controller_serial_number( + mock_mqtt, mock_device, timeout=1.0 + ) + + assert serial == "TEST_SERIAL_123" + mock_mqtt.request_device_info.assert_called_once_with(mock_device) + + +@pytest.mark.asyncio +async def test_get_controller_serial_number_timeout(mock_mqtt, mock_device): + """Test timeout when retrieving controller serial number.""" + # Do nothing when subscribe is called, so future never completes + mock_mqtt.subscribe_device_feature.return_value = None + + # Reduce timeout for test speed + serial = await get_controller_serial_number( + mock_mqtt, mock_device, timeout=0.1 + ) + + assert serial is None + mock_mqtt.request_device_info.assert_called_once_with(mock_device) + + +@pytest.mark.asyncio +async def test_handle_status_request(mock_mqtt, mock_device, capsys): + """Test status request handler prints output.""" + status = MagicMock(spec=DeviceStatus) + status.model_dump.return_value = {"some": "data"} + + async def side_effect_subscribe(device, callback): + callback(status) + return None + + mock_mqtt.subscribe_device_status.side_effect = side_effect_subscribe + + await handle_status_request(mock_mqtt, mock_device) + + mock_mqtt.request_device_status.assert_called_once_with(mock_device) + captured = capsys.readouterr() + assert "some" in captured.out + assert "data" in captured.out + + +@pytest.mark.asyncio +async def test_handle_set_mode_request_success(mock_mqtt, mock_device): + """Test successful mode setting.""" + status = MagicMock(spec=DeviceStatus) + # Configure nested mock explicitly to avoid spec issues with Pydantic + operation_mode = MagicMock() + operation_mode.name = "HEAT_PUMP" + status.operation_mode = operation_mode + status.model_dump.return_value = {"mode": "HEAT_PUMP"} + + async def side_effect_subscribe(device, callback): + # Invoke callback immediately; handler waits on completed future + callback(status) + return None + + mock_mqtt.subscribe_device_status.side_effect = side_effect_subscribe + + await handle_set_mode_request(mock_mqtt, mock_device, "heat-pump") + + mock_mqtt.set_dhw_mode.assert_called_once_with( + mock_device, 1 + ) # 1 = Heat Pump + + +@pytest.mark.asyncio +async def test_handle_set_mode_request_invalid_mode(mock_mqtt, mock_device): + """Test setting an invalid mode.""" + await handle_set_mode_request(mock_mqtt, mock_device, "invalid-mode") + + mock_mqtt.set_dhw_mode.assert_not_called() + + +@pytest.mark.asyncio +async def test_handle_set_dhw_temp_request_success(mock_mqtt, mock_device): + """Test successful temperature setting.""" + status = MagicMock(spec=DeviceStatus) + status.dhw_target_temperature_setting = 120 + status.model_dump.return_value = {"temp": 120} + + async def side_effect_subscribe(device, callback): + callback(status) + return None + + mock_mqtt.subscribe_device_status.side_effect = side_effect_subscribe + + await handle_set_dhw_temp_request(mock_mqtt, mock_device, 120.0) + + mock_mqtt.set_dhw_temperature.assert_called_once_with(mock_device, 120.0) + + +@pytest.mark.asyncio +async def test_handle_set_dhw_temp_request_out_of_range(mock_mqtt, mock_device): + """Test setting temperature out of range.""" + await handle_set_dhw_temp_request(mock_mqtt, mock_device, 160.0) # > 150 + mock_mqtt.set_dhw_temperature.assert_not_called() + + await handle_set_dhw_temp_request(mock_mqtt, mock_device, 90.0) # < 95 + mock_mqtt.set_dhw_temperature.assert_not_called() diff --git a/tests/test_command_queue.py b/tests/test_command_queue.py index 0f5150e..67c44fe 100644 --- a/tests/test_command_queue.py +++ b/tests/test_command_queue.py @@ -1,7 +1,7 @@ """Tests for command queue functionality.""" from collections import deque -from datetime import datetime +from datetime import UTC, datetime from awscrt import mqtt @@ -14,7 +14,7 @@ def test_queued_command_dataclass(): topic = "test/topic" payload = {"key": "value"} qos = mqtt.QoS.AT_LEAST_ONCE - timestamp = datetime.utcnow() + timestamp = datetime.now(UTC) command = QueuedCommand( topic=topic, payload=payload, qos=qos, timestamp=timestamp @@ -69,7 +69,7 @@ def test_queued_command_fifo_order(): # Add commands for i in range(5): - timestamp = datetime.utcnow() + timestamp = datetime.now(UTC) timestamps.append(timestamp) command = QueuedCommand( topic=f"test/topic/{i}", diff --git a/tests/test_mqtt_client_init.py b/tests/test_mqtt_client_init.py index f3a4609..82fb307 100644 --- a/tests/test_mqtt_client_init.py +++ b/tests/test_mqtt_client_init.py @@ -731,18 +731,20 @@ def test_recover_connection_method_signature( assert "recover" in mqtt_client.recover_connection.__doc__.lower() assert "connection" in mqtt_client.recover_connection.__doc__.lower() - def test_recover_connection_error_handling_docstring( - self, auth_client_with_valid_tokens - ): + def test_recover_connection_error_handling_docstring(self): """Test that recover_connection docstring documents error handling.""" - mqtt_client = NavienMqttClient(auth_client_with_valid_tokens) - - doc = mqtt_client.recover_connection.__doc__ + # Check class attribute directly to avoid bound method issues + doc = NavienMqttClient.recover_connection.__doc__ assert doc is not None # Should mention it can raise exceptions - assert "TokenRefreshError" in doc or "raise" in doc.lower(), ( - "Docstring should document error handling" + # Case insensitive check for 'raises' or specific errors + doc_lower = doc.lower() + has_raises = "raises" in doc_lower + has_token_error = "tokenrefresherror" in doc_lower + + assert has_raises or has_token_error, ( + f"Docstring should document error handling. Got: {doc[:100]}..." ) @pytest.mark.asyncio