diff --git a/openaq/__init__.py b/openaq/__init__.py index 9cebf52d..61a08327 100644 --- a/openaq/__init__.py +++ b/openaq/__init__.py @@ -2,7 +2,7 @@ import logging -__version__ = "1.0.0rc3" +__version__ = "1.0.0rc4" logger = logging.getLogger(__name__) @@ -18,6 +18,7 @@ GatewayTimeoutError, HTTPRateLimitError, IdentifierOutOfBoundsError, + InvalidParameterError, NotAuthorizedError, NotFoundError, RateLimitError, @@ -34,6 +35,7 @@ "NotAuthorizedError", "NotFoundError", "TimeoutError", + "InvalidParameterError", "ValidationError", "GatewayTimeoutError", "HTTPRateLimitError", @@ -43,4 +45,5 @@ "ServerError", "ServiceUnavailableError", "BadGatewayError", + "__version__", ] diff --git a/openaq/client.py b/openaq/client.py index 4598eaeb..4c9931af 100644 --- a/openaq/client.py +++ b/openaq/client.py @@ -14,6 +14,7 @@ from pathlib import Path from types import TracebackType from typing import Mapping +from urllib.parse import urljoin, urlparse from openaq import __version__ from openaq.core.exceptions import ApiKeyMissingError, RateLimitError @@ -21,9 +22,7 @@ DEFAULT_LIMITS, DEFAULT_TIMEOUT, Headers, - Limits, Response, - Timeout, Transport, ) from openaq.models.countries import Countries @@ -83,17 +82,11 @@ class OpenAQ: auto_wait: Whether to automatically wait when rate limited. Defaults to True. base_url: The base URL for the API endpoint. - transport: The transport instance for making HTTP requests. For internal + _transport: The transport instance for making HTTP requests. For internal use. rate_limit_override: Override the default rate limit capacity of 60 requests per minute. Useful for accounts with a higher rate limit. Defaults to None. - timeout: Timeout configuration for HTTP requests. Defaults to 5 seconds - for connection and pool, and 8 seconds for read to account for the - API's 6 second processing limit. Pass None for no timeout. - limits: Connection pool limits for the HTTP transport. Defaults to 20 - maximum connections with 10 keepalive connections. Keepalive - connections expire after 30 seconds. Note: An API key can either be passed directly to the OpenAQ client class at @@ -131,9 +124,7 @@ def __init__( headers: Mapping[str, str] | None = None, auto_wait: bool = True, base_url: str = DEFAULT_BASE_URL, - transport: Transport | None = None, - timeout: float | Timeout | None = DEFAULT_TIMEOUT, - limits: Limits = DEFAULT_LIMITS, + _transport: Transport | None = None, # internal use only rate_limit_override: int | None = None, ) -> None: """Initializes the OpenAQ client. @@ -146,14 +137,8 @@ def __init__( auto_wait: Whether to automatically wait when rate limited. Defaults to True. base_url: The base URL for the API endpoint. - transport: The transport instance for making HTTP requests. For + _transport: The transport instance for making HTTP requests. For internal use. - timeout: Timeout configuration for HTTP requests. Defaults to 5 - seconds for connection and 8 seconds for read. Pass None for no - timeout. - limits: Connection pool limits for the HTTP transport. Defaults to - 20 maximum connections with 10 keepalive connections expiring - after 30 seconds. rate_limit_override: Initial rate limit capacity in requests per minute. Defaults to 60 and is corrected automatically from server response headers after the first request. @@ -163,12 +148,17 @@ def __init__( URL is used. """ self._api_key = _resolve_api_key(api_key) - self._base_url = base_url + parsed = urlparse(base_url) + if not parsed.scheme or not parsed.netloc: + raise ValueError( + f"Invalid base_url, must be a fully qualified URL: {base_url!r}" + ) + self._base_url = parsed.geturl() self._auto_wait = auto_wait self._transport = ( - transport - if transport is not None - else Transport(timeout=timeout, limits=limits) + _transport + if _transport is not None + else Transport(timeout=DEFAULT_TIMEOUT, limits=DEFAULT_LIMITS) ) self._headers = Headers(headers or {}) @@ -181,8 +171,8 @@ def __init__( ) self._user_agent = f"openaq-python-{__version__}-{platform.python_version()}" - assert self._api_key is not None - self._headers["X-API-Key"] = self._api_key + if self._api_key: + self._headers["X-API-Key"] = self._api_key self._headers["User-Agent"] = self._user_agent self._headers["Accept"] = ACCEPT_HEADER @@ -251,9 +241,13 @@ def _check_rate_limit(self) -> None: self._wait_for_rate_limit_reset() self._rate_limit_remaining = self._rate_limit_capacity else: - message = f"Rate limit exceeded. Limit resets in {self._rate_limit_reset_seconds} seconds" - logger.error(message) - raise RateLimitError(message) + logger.error( + "Rate limit exceeded. Limit resets in %s seconds", + self._rate_limit_reset_seconds, + ) + raise RateLimitError( + f"Rate limit exceeded. Limit resets in {self._rate_limit_reset_seconds} seconds" + ) self._rate_limit_remaining -= 1 def _set_rate_limit(self, headers: Headers) -> None: @@ -286,7 +280,7 @@ def _wait_for_rate_limit_reset(self) -> None: """ wait_seconds = self._rate_limit_reset_seconds if wait_seconds > 0: - logger.info(f"Rate limit hit. Waiting {wait_seconds} seconds for reset.") + logger.info("Rate limit hit. Waiting %s seconds for reset.", wait_seconds) time.sleep(wait_seconds) def _get_int_header(self, headers: Headers, key: str, default: int) -> int: @@ -349,7 +343,7 @@ def _do( """ self._check_rate_limit() request_headers = self._build_request_headers(headers) - url = self._base_url + path + url = urljoin(self._base_url, path.lstrip("/")) data = self._transport.send_request( method=method, url=url, params=params, headers=request_headers ) diff --git a/openaq/core/transport.py b/openaq/core/transport.py index 2009ef81..f455dab9 100644 --- a/openaq/core/transport.py +++ b/openaq/core/transport.py @@ -1,10 +1,11 @@ -"""Base class and utlity functions for working with client transport.""" +"""Base class and utility functions for working with client transport.""" from __future__ import annotations import http.client import json import logging +import ssl import threading import time import urllib.parse @@ -336,7 +337,7 @@ def release(self, pc: PooledConnection, *, discard: bool = False) -> None: def close_all(self) -> None: """Closes all idle connections in the pool and resets its state.""" - with self._lock: + with self._has_capacity: for q in self._idle.values(): for pc in q: try: @@ -345,6 +346,7 @@ def close_all(self) -> None: pass self._idle.clear() self._total = 0 + self._has_capacity.notify_all() def _encode_params( @@ -434,21 +436,31 @@ def _raw_request( for attempt in range(2): pc = self._pool.acquire(host, self._pool_timeout) try: - if self._read_timeout is not None: - if pc.conn.sock is not None: - pc.conn.sock.settimeout(self._read_timeout) - pc.conn.request(method, path, headers=dict(headers)) - raw = pc.conn.getresponse() - - # After connect, set socket timeout for the read. if self._read_timeout is not None and pc.conn.sock is not None: pc.conn.sock.settimeout(self._read_timeout) - + raw = pc.conn.getresponse() body = raw.read() resp = Response(raw.status, body, raw.msg) self._pool.release(pc) return resp + + except ssl.SSLCertVerificationError as exc: + self._pool.release(pc, discard=True) + logger.error( + "SSL certificate verification failed for %s: %s. " + "On macOS, run 'Install Certificates.command' in your Python " + "installation directory to fix this.", + host, + exc, + ) + raise + + except ssl.SSLError as exc: + self._pool.release(pc, discard=True) + logger.error("SSL error for %s: %s", host, exc) + raise + except (OSError, http.client.HTTPException) as exc: self._pool.release(pc, discard=True) if attempt == 1: @@ -495,7 +507,7 @@ def send_request( if parsed.query: path = f"{path}?{parsed.query}" - res = self._raw_request(method, host, path, headers) + res = self._raw_request(method.upper(), host, path, headers) logger.debug("Received response: %s from %s", res.status_code, url) return check_response(res) @@ -504,6 +516,20 @@ def close(self) -> None: self._pool.close_all() +_HTTP_SATUS_MAP = { + HTTPStatus.BAD_REQUEST: BadRequestError, + HTTPStatus.NOT_FOUND: NotFoundError, + HTTPStatus.REQUEST_TIMEOUT: TimeoutError, + HTTPStatus.FORBIDDEN: ForbiddenError, + HTTPStatus.UNPROCESSABLE_ENTITY: ValidationError, + HTTPStatus.TOO_MANY_REQUESTS: HTTPRateLimitError, + HTTPStatus.UNAUTHORIZED: NotAuthorizedError, + HTTPStatus.INTERNAL_SERVER_ERROR: ServerError, + HTTPStatus.BAD_GATEWAY: BadGatewayError, + HTTPStatus.SERVICE_UNAVAILABLE: ServiceUnavailableError, +} + + def check_response(res: Response) -> Response: """Checks the HTTP response of the request. @@ -528,42 +554,18 @@ def check_response(res: Response) -> Response: """ if res.status_code >= HTTPStatus.OK and res.status_code < HTTPStatus.BAD_REQUEST: return res - elif res.status_code == HTTPStatus.BAD_REQUEST: - logger.exception(f"HTTP {res.status_code} - {res.text}") - raise BadRequestError(res.text) - elif res.status_code == HTTPStatus.NOT_FOUND: - logger.exception(f"HTTP {res.status_code} - {res.text}") - raise NotFoundError(res.text) - elif res.status_code == HTTPStatus.REQUEST_TIMEOUT: - logger.exception(f"HTTP {res.status_code} - {res.text}") - raise TimeoutError(res.text) - elif res.status_code == HTTPStatus.FORBIDDEN: - logger.exception(f"HTTP {res.status_code} - {res.text}") - raise ForbiddenError(res.text) - elif res.status_code == HTTPStatus.UNPROCESSABLE_ENTITY: - logger.exception(f"HTTP {res.status_code} - {res.text}") - raise ValidationError(res.text) - elif res.status_code == HTTPStatus.TOO_MANY_REQUESTS: - logger.exception(f"HTTP {res.status_code} - {res.text}") - raise HTTPRateLimitError(res.text) - elif res.status_code == HTTPStatus.UNAUTHORIZED: - logger.exception(f"HTTP {res.status_code} - {res.text}") - raise NotAuthorizedError(res.text) - elif res.status_code == HTTPStatus.INTERNAL_SERVER_ERROR: - logger.exception(f"HTTP {res.status_code} - {res.text}") - raise ServerError(res.text) - elif res.status_code == HTTPStatus.BAD_GATEWAY: - logger.exception(f"HTTP {res.status_code} - {res.text}") - raise BadGatewayError(res.text) - elif res.status_code == HTTPStatus.SERVICE_UNAVAILABLE: - logger.exception(f"HTTP {res.status_code} - {res.text}") - raise ServiceUnavailableError(res.text) - elif res.status_code == HTTPStatus.GATEWAY_TIMEOUT: - logger.exception(f"HTTP {res.status_code} - {res.text}") + if res.status_code == HTTPStatus.GATEWAY_TIMEOUT: + logger.error("HTTP %s - %s", res.status_code, res.text) raise GatewayTimeoutError( "Your request timed out on the server. " "Consider reducing the complexity of your request." ) - else: - logger.exception(f"HTTP {res.status_code} - {res.text}") - raise Exception + try: + http_status = HTTPStatus(res.status_code) + except ValueError: + http_status = None + exc_class = ( + _HTTP_SATUS_MAP.get(http_status, ServerError) if http_status else ServerError + ) + logger.error("HTTP %s - %s", res.status_code, res.text) + raise exc_class(res.text) diff --git a/openaq/core/validators.py b/openaq/core/validators.py index e0caab32..bdc38d8c 100644 --- a/openaq/core/validators.py +++ b/openaq/core/validators.py @@ -861,7 +861,7 @@ def validate_datetime_params( if date_to: if not date_check(date_from) or not date_check(date_to): raise InvalidParameterError( - f"Invalid date_from or date_to, must be either date type or ISO-8601 formatted date string, got {type(date_from)} and {type(date_to)}" + f"Invalid date_from or date_to, must be either datetime.date type or ISO-8601 formatted date string, got {type(date_from)} and {type(date_to)}" ) date_from_date = to_date(date_from) date_to_date = to_date(date_to) @@ -888,7 +888,7 @@ def validate_datetime_params( if datetime_to is not None: if not datetime_check(datetime_from) or not datetime_check(datetime_to): raise InvalidParameterError( - f"Invalid datetime_from or datetime_to, must be either datetime type or ISO-8601 formatted string, got {type(datetime_from)} and {type(datetime_to)}" + f"Invalid datetime_from or datetime_to, must be either datetime.datetime type or ISO-8601 formatted string, got {type(datetime_from)} and {type(datetime_to)}" ) datetime_from_datetime = to_datetime(datetime_from) datetime_to_datetime = to_datetime(datetime_to) @@ -902,7 +902,7 @@ def validate_datetime_params( elif datetime_from is not None: if not datetime_check(datetime_from): raise InvalidParameterError( - f"Invalid datetime_from, must be either datetime type or ISO-8601 formatted string, got {type(datetime_from)}" + f"Invalid datetime_from, must be either datetime.datetime type or ISO-8601 formatted string, got {type(datetime_from)}" ) datetime_from_datetime = to_datetime(datetime_from) if not datetime_from_lesser_check(datetime_from_datetime): diff --git a/openaq/models/base.py b/openaq/models/base.py index 259a4aca..e252f5c5 100644 --- a/openaq/models/base.py +++ b/openaq/models/base.py @@ -1,3 +1,5 @@ +"""Base class for OpenAQ API resource models.""" + from __future__ import annotations from typing import TYPE_CHECKING diff --git a/openaq/models/countries.py b/openaq/models/countries.py index f7c7340e..a0e7b8c4 100644 --- a/openaq/models/countries.py +++ b/openaq/models/countries.py @@ -1,3 +1,5 @@ +"""Resource model for interacting with the countries endpoints of the OpenAQ API.""" + from openaq.core.models import build_query_params from openaq.core.responses import CountriesResponse from openaq.core.types import SortOrder diff --git a/openaq/models/instruments.py b/openaq/models/instruments.py index 8c95cbb1..a23a7328 100644 --- a/openaq/models/instruments.py +++ b/openaq/models/instruments.py @@ -1,3 +1,5 @@ +"""Resource model for interacting with the instruments endpoints of the OpenAQ API.""" + from openaq.core.models import build_query_params from openaq.core.responses import InstrumentsResponse from openaq.core.types import SortOrder diff --git a/openaq/models/licenses.py b/openaq/models/licenses.py index 78d1f678..65aab921 100644 --- a/openaq/models/licenses.py +++ b/openaq/models/licenses.py @@ -1,3 +1,5 @@ +"""Resource model for interacting with the licenses endpoints of the OpenAQ API.""" + from openaq.core.models import build_query_params from openaq.core.responses import LicensesResponse from openaq.core.types import SortOrder diff --git a/openaq/models/locations.py b/openaq/models/locations.py index 6558a666..e926d697 100644 --- a/openaq/models/locations.py +++ b/openaq/models/locations.py @@ -1,3 +1,5 @@ +"""Resource model for interacting with the locations endpoints of the OpenAQ API.""" + from openaq.core.models import build_query_params from openaq.core.responses import ( LatestResponse, diff --git a/openaq/models/manufacturers.py b/openaq/models/manufacturers.py index 86261002..e82ffbbe 100644 --- a/openaq/models/manufacturers.py +++ b/openaq/models/manufacturers.py @@ -1,3 +1,5 @@ +"""Resource model for interacting with the manufacturers endpoints of the OpenAQ API.""" + from openaq.core.models import build_query_params from openaq.core.responses import InstrumentsResponse, ManufacturersResponse from openaq.core.types import SortOrder diff --git a/openaq/models/measurements.py b/openaq/models/measurements.py index 751e731b..b1f986ba 100644 --- a/openaq/models/measurements.py +++ b/openaq/models/measurements.py @@ -1,3 +1,5 @@ +"""Resource model for interacting with the measurements endpoints of the OpenAQ API.""" + import datetime from typing import overload diff --git a/openaq/models/owners.py b/openaq/models/owners.py index be41ca53..f4476f79 100644 --- a/openaq/models/owners.py +++ b/openaq/models/owners.py @@ -1,3 +1,5 @@ +"""Resource model for interacting with the owners endpoints of the OpenAQ API.""" + from openaq.core.models import build_query_params from openaq.core.responses import OwnersResponse from openaq.core.types import SortOrder diff --git a/openaq/models/parameters.py b/openaq/models/parameters.py index 0d19613f..2c905444 100644 --- a/openaq/models/parameters.py +++ b/openaq/models/parameters.py @@ -1,3 +1,5 @@ +"""Resource model for interacting with the parameters endpoints of the OpenAQ API.""" + from openaq.core.models import build_query_params from openaq.core.responses import LatestResponse, ParametersResponse from openaq.core.types import ParameterType, SortOrder diff --git a/openaq/models/providers.py b/openaq/models/providers.py index ea3dc7e3..4aab5e75 100644 --- a/openaq/models/providers.py +++ b/openaq/models/providers.py @@ -1,3 +1,5 @@ +"""Resource model for interacting with the providers endpoints of the OpenAQ API.""" + from openaq.core.models import build_query_params from openaq.core.responses import ProvidersResponse from openaq.core.types import SortOrder diff --git a/openaq/models/sensors.py b/openaq/models/sensors.py index 7f59f435..52dfee0b 100644 --- a/openaq/models/sensors.py +++ b/openaq/models/sensors.py @@ -1,3 +1,5 @@ +"""Resource model for interacting with the sensors endpoints of the OpenAQ API.""" + from openaq.core.responses import SensorsResponse from openaq.core.validators import validate_integer_id diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 3bc171fc..a847992b 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -36,7 +36,7 @@ def mock_config_file(): class TestClient: @pytest.fixture() def setup(self): - self.client = OpenAQ(api_key="abc123-def456-ghi789", transport=MockTransport()) + self.client = OpenAQ(api_key="abc123-def456-ghi789", _transport=MockTransport()) @pytest.fixture() def mock_openaq_api_key_env_vars(self): @@ -46,7 +46,7 @@ def mock_openaq_api_key_env_vars(self): yield def test_transport_property(self, setup): - assert isinstance(self.client.transport, MockTransport) + assert isinstance(self.client._transport, MockTransport) with pytest.raises(AttributeError): self.client.transport = MockTransport() @@ -61,7 +61,7 @@ def test_custom_headers(self, setup): self.client = OpenAQ( api_key="abc123-def456-ghi789", base_url="https://mycustom.openaq.org", - transport=MockTransport(), + _transport=MockTransport(), ) assert self.client.headers["X-API-Key"] == "abc123-def456-ghi789" @@ -69,22 +69,22 @@ def test_client_params(self, setup): self.client = OpenAQ( api_key="abc123-def456-ghi789", base_url="https://mycustom.openaq.org", - transport=MockTransport(), + _transport=MockTransport(), ) assert self.client._base_url == "https://mycustom.openaq.org" def test_api_env_var(self, mock_openaq_api_key_env_vars): - client = OpenAQ(transport=MockTransport()) + client = OpenAQ(_transport=MockTransport()) assert client.api_key == "openaq-1a2b3c4d5e6f7g8h9i0j1k2l3m4n5o6p" @pytest.mark.usefixtures("mock_config_file") def test_api_key_from_config(self): if int(platform.python_version_tuple()[1]) >= 11: - client = OpenAQ(transport=MockTransport()) + client = OpenAQ(_transport=MockTransport()) assert client.api_key == "test_api_key" else: with pytest.raises(ApiKeyMissingError): - client = OpenAQ(transport=MockTransport()) + client = OpenAQ(_transport=MockTransport()) def test_api_key_arg_override_env_var(self, setup, mock_openaq_api_key_env_vars): assert self.client.api_key == "abc123-def456-ghi789" @@ -97,6 +97,10 @@ def test_api_key_arg_override_env_vars_config( ): assert self.client.api_key == "abc123-def456-ghi789" + def test_raises_api_key_missing_error_when_key_is_none(self): + with pytest.raises(ApiKeyMissingError): + OpenAQ(api_key=None, _transport=MockTransport()) + @patch('openaq.client.datetime') @patch('time.sleep') @patch('openaq.client.logger') @@ -111,7 +115,7 @@ def test_wait_for_rate_limit_reset_waits_when_positive( mock_sleep.assert_called_once_with(5) mock_logger.info.assert_called_once_with( - "Rate limit hit. Waiting 5 seconds for reset." + "Rate limit hit. Waiting %s seconds for reset.", 5 ) @patch('openaq.client.datetime') @@ -169,7 +173,7 @@ def test_context_manager_exit_closes_even_with_exception(self, setup): def test_blocks_after_custom_limit(self): client = OpenAQ( api_key="abc123-def456-ghi789", - transport=MockTransport(), + _transport=MockTransport(), auto_wait=False, rate_limit_override=5, ) @@ -184,7 +188,7 @@ def test_allows_exactly_override_requests(self): limit = 10 client = OpenAQ( api_key="abc123-def456-ghi789", - transport=MockTransport(), + _transport=MockTransport(), auto_wait=False, rate_limit_override=limit, ) @@ -328,27 +332,46 @@ def test_do_raises_before_sending_when_rate_limited(self, setup): def test_default_timeout_applied_to_transport(self): client = OpenAQ(api_key="abc123-def456-ghi789") - assert client.transport._connect_timeout == DEFAULT_TIMEOUT.connect - assert client.transport._read_timeout == DEFAULT_TIMEOUT.read - - def test_custom_timeout_passed_to_transport(self): - custom_timeout = Timeout(10.0, read=15.0) - client = OpenAQ(api_key="abc123-def456-ghi789", timeout=custom_timeout) - assert client.transport._connect_timeout == custom_timeout.connect - assert client.transport._read_timeout == custom_timeout.read + assert client._transport._connect_timeout == DEFAULT_TIMEOUT.connect + assert client._transport._read_timeout == DEFAULT_TIMEOUT.read def test_default_limits_applied_to_transport(self): client = OpenAQ(api_key="abc123-def456-ghi789") - assert client.transport._pool._max_total == DEFAULT_LIMITS.max_connections + assert client._transport._pool._max_total == DEFAULT_LIMITS.max_connections assert ( - client.transport._pool._max_idle == DEFAULT_LIMITS.max_keepalive_connections + client._transport._pool._max_idle + == DEFAULT_LIMITS.max_keepalive_connections ) - def test_custom_limits_passed_to_transport(self): - custom_limits = Limits(max_connections=5, max_keepalive_connections=2) - client = OpenAQ(api_key="abc123-def456-ghi789", limits=custom_limits) - assert client.transport._pool._max_total == 5 - assert client.transport._pool._max_idle == 2 + def test_do_does_not_produce_double_slash_in_url(self, setup): + """Leading slash on path should not create double slash when base_url has trailing slash.""" + mock_response = MagicMock() + mock_response.headers = Headers({}) + self.client._transport.send_request = Mock(return_value=mock_response) + + self.client._do("get", "/locations/1") + + call_kwargs = self.client._transport.send_request.call_args + url = call_kwargs.kwargs["url"] + assert "//" not in url.replace( + "https://", "" + ), f"URL contains double slash: {url}" + + def test_raises_value_error_for_base_url_without_scheme(self): + with pytest.raises(ValueError, match="Invalid base_url"): + OpenAQ( + api_key="abc123-def456-ghi789", + base_url="api.openaq.org/v3/", + _transport=MockTransport(), + ) + + def test_raises_value_error_for_base_url_without_netloc(self): + with pytest.raises(ValueError, match="Invalid base_url"): + OpenAQ( + api_key="abc123-def456-ghi789", + base_url="https://", + _transport=MockTransport(), + ) def test_tomllib_conditional_import(): diff --git a/tests/unit/test_transport.py b/tests/unit/test_transport.py index 3a29dee3..42548f98 100644 --- a/tests/unit/test_transport.py +++ b/tests/unit/test_transport.py @@ -1,4 +1,5 @@ import http.client +import ssl import threading import time from unittest import mock @@ -290,3 +291,57 @@ def test_raises_after_two_failures(self): with acquire, release: with pytest.raises(OSError): transport._raw_request("GET", "api.openaq.org", "/v3/locations", {}) + + def test_send_request_uppercases_method(self): + transport = Transport() + transport._raw_request = mock.Mock( + return_value=Response(200, b'{}', http.client.HTTPMessage()) + ) + + transport.send_request( + "get", "https://api.openaq.org/v3/locations/1", None, Headers() + ) + + method = transport._raw_request.call_args.args[0] + assert method == "GET" + + @pytest.mark.parametrize( + "exc", + [ + ssl.SSLCertVerificationError("cert verify failed"), + ssl.SSLError("ssl error"), + ], + ) + def test_ssl_error_discards_connection_and_raises(self, exc): + transport = self.make_transport() + raw = make_raw_response() + acquire, release, pc = self.patch_pool(transport, raw) + pc.conn.request.side_effect = exc + + with acquire, release as mock_release: + with pytest.raises(ssl.SSLError): + transport._raw_request("GET", "api.openaq.org", "/v3/locations", {}) + mock_release.assert_called_once_with(pc, discard=True) + + def test_ssl_error_logs(self): + transport = self.make_transport() + raw = make_raw_response() + acquire, release, pc = self.patch_pool(transport, raw) + pc.conn.request.side_effect = ssl.SSLCertVerificationError("cert verify failed") + + with acquire, release: + with pytest.raises(ssl.SSLCertVerificationError): + with mock.patch("openaq.core.transport.logger") as mock_logger: + transport._raw_request("GET", "api.openaq.org", "/v3/locations", {}) + mock_logger.error.assert_called_once() + + def test_ssl_error_does_not_retry(self): + transport = self.make_transport() + raw = make_raw_response() + acquire, release, pc = self.patch_pool(transport, raw) + pc.conn.request.side_effect = ssl.SSLError("SSL error") + + with acquire, release: + with pytest.raises(ssl.SSLError): + transport._raw_request("GET", "api.openaq.org", "/v3/locations", {}) + assert pc.conn.request.call_count == 1