Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/hatch_rest_api/aws_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from aiohttp import ClientSession, ClientResponse
import json

from .types import AwsIotCredentialsResponse
from .util_http import request_with_logging

_LOGGER = logging.getLogger(__name__)
Expand All @@ -11,22 +12,22 @@


class AwsHttp:
def __init__(self, client_session: ClientSession = None):
def __init__(self, client_session: ClientSession | None = None):
if client_session is None:
self.api_session = ClientSession(raise_for_status=True)
else:
self.api_session = client_session

async def cleanup_client_session(self):
async def cleanup_client_session(self) -> None:
await self.api_session.close()

@request_with_logging
async def _post_request_with_logging_and_errors_raised(
self, url: str, json_body: dict, headers: dict = None
self, url: str, json_body: dict, headers: dict | None = None
) -> ClientResponse:
return await self.api_session.post(url=url, json=json_body, headers=headers)

async def aws_credentials(self, region: str, identityId: str, aws_token: str):
async def aws_credentials(self, region: str, identityId: str, aws_token: str) -> AwsIotCredentialsResponse:
url = f"https://cognito-identity.{region}.amazonaws.com"
json_body = {
"IdentityId": identityId,
Expand Down
9 changes: 5 additions & 4 deletions src/hatch_rest_api/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
from collections.abc import Callable
import logging

_LOGGER = logging.getLogger(__name__)


class CallbacksMixin:
def _setup_callbacks(self):
self._callbacks = set()
def _setup_callbacks(self) -> None:
self._callbacks: set[Callable[[], None]] = set()

def register_callback(self, callback) -> None:
def register_callback(self, callback: Callable[[], None]) -> None:
if not hasattr(self, "_callbacks"):
self._setup_callbacks()
self._callbacks.add(callback)

def remove_callback(self, callback) -> None:
def remove_callback(self, callback: Callable[[], None]) -> None:
if not hasattr(self, "_callbacks"):
self._setup_callbacks()
self._callbacks.discard(callback)
Expand Down
8 changes: 4 additions & 4 deletions src/hatch_rest_api/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
NO_SOUND_ID = 19998


class RestMiniAudioTrack(Enum):
class RestMiniAudioTrack(int, Enum):
NONE = 0
Heartbeat = 10124
Water = 10125
Expand All @@ -24,7 +24,7 @@ class RestMiniAudioTrack(Enum):
Birds = 10131


class RestPlusAudioTrack(Enum):
class RestPlusAudioTrack(int, Enum):
NONE = 0
Stream = 2
PinkNoise = 3
Expand All @@ -39,7 +39,7 @@ class RestPlusAudioTrack(Enum):
RockABye = 14


class RIoTAudioTrack(Enum):
class RIoTAudioTrack(int, Enum):
NONE = NO_SOUND_ID
BrownNoise = 10200
WhiteNoise = 10137
Expand All @@ -62,7 +62,7 @@ class RIoTAudioTrack(Enum):
RockABye = 10194

@classmethod
def sound_url_map(cls):
def sound_url_map(cls) -> dict[int, str]:
"""
Hard-coded list, as some of these values are not returned by the 'sounds' API. These were found from manually browsing the app and playing each
song, collecting the necessary values (name, id, url) from the Home Assistant debug logs for the `ha_hatch` custom component integration.
Expand Down
7 changes: 4 additions & 3 deletions src/hatch_rest_api/contentful.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from aiohttp import ClientError, ClientResponse, ClientSession

from .errors import RateError
from .types import JsonType

_LOGGER = logging.getLogger(__name__)

Expand All @@ -15,13 +16,13 @@


class Contentful:
def __init__(self, client_session: ClientSession = None):
def __init__(self, client_session: ClientSession | None = None):
self.api_session = client_session or ClientSession()

async def cleanup_client_session(self):
async def cleanup_client_session(self) -> None:
await self.api_session.close()

async def graphql_query(self, query, auth_token=None, max_retries=3, **variables):
async def graphql_query(self, query: str, auth_token: str | None = None, max_retries: int = 3, **variables: JsonType) -> dict[str, JsonType]:
retry_count = 0
while True:
try:
Expand Down
72 changes: 56 additions & 16 deletions src/hatch_rest_api/hatch.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,39 @@
import asyncio
from collections.abc import Awaitable, Callable, Mapping, Sequence
import logging
from typing import Literal, TypedDict, overload

from aiohttp import ClientError, ClientResponse, ClientSession, __version__
from aiohttp.hdrs import USER_AGENT

from .errors import AuthError, RateError
from .types import (
IotDeviceInfo,
IotTokenResponse,
JsonType,
LoginResponse,
Product,
RestIotRoutine,
SimpleSoundContent,
)
from .util_http import request_with_logging

type ContentType = Literal["sound", "color", "windDown"]


class ContentResponse[T: SimpleSoundContent | Mapping[str, JsonType]](TypedDict):
contentItems: list[T]


_LOGGER = logging.getLogger(__name__)

API_URL: str = "https://data.hatchbaby.com/"


def request_with_logging_and_errors(func):
async def request_with_logging_wrapper(*args, **kwargs):
def request_with_logging_and_errors[**P, T: ClientResponse](
func: Callable[P, Awaitable[T]],
) -> Callable[P, Awaitable[T]]:
async def request_with_logging_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
response = await func(*args, **kwargs)

if response.status == 429:
Expand Down Expand Up @@ -45,20 +65,20 @@ async def request_with_logging_wrapper(*args, **kwargs):


class Hatch:
def __init__(self, client_session: ClientSession = None):
def __init__(self, client_session: ClientSession | None = None):
if client_session is None:
self.api_session = ClientSession(raise_for_status=True)
else:
self.api_session = client_session
_LOGGER.debug(f"api_session_version: {__version__}")

async def cleanup_client_session(self):
async def cleanup_client_session(self) -> None:
await self.api_session.close()

@request_with_logging_and_errors
@request_with_logging
async def _post_request_with_logging_and_errors_raised(
self, url: str, json_body: dict, auth_token: str = None
self, url: str, json_body: dict, auth_token: str | None = None
) -> ClientResponse:
headers = {USER_AGENT: "hatch_rest_api"}
if auth_token is not None:
Expand All @@ -68,7 +88,7 @@ async def _post_request_with_logging_and_errors_raised(
@request_with_logging
@request_with_logging_and_errors
async def _get_request_with_logging_and_errors_raised(
self, url: str, auth_token: str = None, params: dict = None
self, url: str, auth_token: str | None = None, params: dict | None = None
) -> ClientResponse:
headers = {USER_AGENT: "hatch_rest_api"}
if auth_token is not None:
Expand All @@ -86,10 +106,10 @@ async def login(self, email: str, password: str) -> str:
url=url, json_body=json_body
)
)
response_json = await response.json()
response_json: LoginResponse = await response.json()
return response_json["token"]

async def member(self, auth_token: str):
async def member(self, auth_token: str) -> dict[str, JsonType]:
url = API_URL + "service/app/v2/member"
response: ClientResponse = (
await self._get_request_with_logging_and_errors_raised(
Expand All @@ -99,7 +119,7 @@ async def member(self, auth_token: str):
response_json = await response.json()
return response_json["payload"]

async def iot_devices(self, auth_token: str):
async def iot_devices(self, auth_token: str) -> list[IotDeviceInfo]:
url = API_URL + "service/app/iotDevice/v2/fetch"
params = {
"iotProducts": [
Expand All @@ -119,7 +139,7 @@ async def iot_devices(self, auth_token: str):
response_json = await response.json()
return response_json["payload"]

async def token(self, auth_token: str):
async def token(self, auth_token: str) -> IotTokenResponse:
url = API_URL + "service/app/restPlus/token/v1/fetch"
response: ClientResponse = (
await self._get_request_with_logging_and_errors_raised(
Expand All @@ -129,7 +149,7 @@ async def token(self, auth_token: str):
response_json = await response.json()
return response_json["payload"]

async def favorites(self, auth_token: str, mac: str):
async def favorites(self, auth_token: str, mac: str) -> list[RestIotRoutine]:
url = API_URL + "service/app/routine/v2/fetch"
params = {"macAddress": mac}
response: ClientResponse = (
Expand All @@ -138,11 +158,11 @@ async def favorites(self, auth_token: str, mac: str):
)
)
response_json = await response.json()
favorites = response_json["payload"]
favorites: list[RestIotRoutine] = response_json["payload"]
favorites.sort(key=lambda x: x.get("displayOrder", float("inf")))
return favorites

async def routines(self, auth_token: str, mac: str):
async def routines(self, auth_token: str, mac: str) -> list[RestIotRoutine]:
url = API_URL + "service/app/routine/v2/fetch"
params = {"macAddress": mac, "types": "routine"}
response: ClientResponse = (
Expand All @@ -151,13 +171,33 @@ async def routines(self, auth_token: str, mac: str):
)
)
response_json = await response.json()
routines = response_json["payload"]
routines: list[RestIotRoutine] = response_json["payload"]
routines.sort(key=lambda x: x.get("displayOrder", float("inf")))
return routines

@overload
async def content(
self,
auth_token: str,
product: Product,
content: Sequence[Literal["sound"]],
max_retries: int = 3,
) -> ContentResponse[SimpleSoundContent]: ...
@overload
async def content(
self,
auth_token: str,
product: Product,
content: Sequence[Literal["color", "windDown"]],
max_retries: int = 3,
) -> ContentResponse[Mapping[str, JsonType]]: ...
async def content(
self, auth_token: str, product: str, content: list, max_retries: int = 3
):
self,
auth_token: str,
product: Product,
content: Sequence[ContentType],
max_retries: int = 3,
) -> ContentResponse[Mapping[str, JsonType]] | ContentResponse[SimpleSoundContent]:
# content options are ["sound", "color", "windDown"]
url = API_URL + "service/app/content/v1/fetchByProduct"
params = {"product": product, "contentTypes": content}
Expand Down
Loading