From c18d753601e0970c4383b66ba6ff64b37c5b1079 Mon Sep 17 00:00:00 2001 From: LifeJiggy Date: Sat, 11 Oct 2025 17:37:01 +0100 Subject: [PATCH 1/7] Fix known issues: kubernetes.get_kubeconfig and invoices.get_pdf_by_uuid --- src/pydo/operations/_patch.py | 26 ++++++++++++++++++++++++++ tests/mocked/test_billing.py | 5 ++--- tests/mocked/test_kubernetes.py | 3 --- 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/src/pydo/operations/_patch.py b/src/pydo/operations/_patch.py index 8a843f7c..2e1b836b 100644 --- a/src/pydo/operations/_patch.py +++ b/src/pydo/operations/_patch.py @@ -9,6 +9,8 @@ from typing import TYPE_CHECKING from ._operations import DropletsOperations as Droplets +from ._operations import KubernetesOperations as Kubernetes +from ._operations import InvoicesOperations as Invoices if TYPE_CHECKING: # pylint: disable=unused-import,ungrouped-imports @@ -25,3 +27,27 @@ def patch_sdk(): you can't accomplish using the techniques described in https://aka.ms/azsdk/python/dpcodegen/python/customize """ + + # Fix kubernetes.get_kubeconfig to return raw YAML content instead of trying to parse as JSON + def _get_kubeconfig(self, cluster_id, **kwargs): + """Get a Kubernetes config file for the specified cluster.""" + # Call the original method but with raw response + response = self._client.get( + f"/v2/kubernetes/clusters/{cluster_id}/kubeconfig", + **kwargs + ) + return response.content + + Kubernetes.get_kubeconfig = _get_kubeconfig + + # Fix invoices.get_pdf_by_uuid to return raw PDF content instead of trying to parse as JSON + def _get_pdf_by_uuid(self, invoice_uuid, **kwargs): + """Get a PDF invoice by UUID.""" + # Call the original method but with raw response + response = self._client.get( + f"/v2/customers/my/invoices/{invoice_uuid}/pdf", + **kwargs + ) + return response.content + + Invoices.get_pdf_by_uuid = _get_pdf_by_uuid diff --git a/tests/mocked/test_billing.py b/tests/mocked/test_billing.py index b787d502..b1576ed0 100644 --- a/tests/mocked/test_billing.py +++ b/tests/mocked/test_billing.py @@ -203,12 +203,11 @@ def test_get_invoice_pdf_by_uuid(mock_client: Client, mock_client_url): responses.add( responses.GET, f"{mock_client_url}/v2/customers/my/invoices/1/pdf", - json=expected, + body=expected, ) invoices = mock_client.invoices.get_pdf_by_uuid(invoice_uuid=1) - list_in = list(invoices) - assert "group_description" in str(list_in) + assert "group_description" in str(invoices) @responses.activate diff --git a/tests/mocked/test_kubernetes.py b/tests/mocked/test_kubernetes.py index 703ebb7c..7e785d0e 100644 --- a/tests/mocked/test_kubernetes.py +++ b/tests/mocked/test_kubernetes.py @@ -211,9 +211,6 @@ def test_kubernetes_get_kubeconfig(mock_client: Client, mock_client_url): ) config_resp = mock_client.kubernetes.get_kubeconfig(cluster_id) - pytest.skip("The operation currently fails to return content.") - # TODO: investigate why the generated client doesn't return the response content - # It seems to be something to do with the yaml content type. assert config_resp.decode("utf-8") == expected From 90c32e9d6cf7ffde996a8fd83bca952b6d289f9b Mon Sep 17 00:00:00 2001 From: LifeJiggy Date: Sun, 12 Oct 2025 10:30:59 +0100 Subject: [PATCH 2/7] Expose AsyncClient at package level for easier access --- src/pydo/_patch.py | 3 ++- src/pydo/aio/__init__.py | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/pydo/_patch.py b/src/pydo/_patch.py index 350d70fd..51296534 100644 --- a/src/pydo/_patch.py +++ b/src/pydo/_patch.py @@ -12,6 +12,7 @@ from pydo.custom_policies import CustomHttpLoggingPolicy from pydo import GeneratedClient, _version +from pydo.aio import AsyncClient if TYPE_CHECKING: # pylint: disable=unused-import,ungrouped-imports @@ -49,7 +50,7 @@ def __init__(self, token: str, *, timeout: int = 120, **kwargs): ) -__all__ = ["Client"] +__all__ = ["Client", "AsyncClient"] def patch_sdk(): diff --git a/src/pydo/aio/__init__.py b/src/pydo/aio/__init__.py index d3564a0d..ebf52ad7 100644 --- a/src/pydo/aio/__init__.py +++ b/src/pydo/aio/__init__.py @@ -13,8 +13,12 @@ _patch_all = [] from ._patch import patch_sdk as _patch_sdk +# Alias Client as AsyncClient for easier access +AsyncClient = Client + __all__ = [ "GeneratedClient", + "AsyncClient", ] __all__.extend([p for p in _patch_all if p not in __all__]) From dccb15e3c2756cb2243aa94116ecf0e430227d76 Mon Sep 17 00:00:00 2001 From: LifeJiggy Date: Sun, 12 Oct 2025 10:43:46 +0100 Subject: [PATCH 3/7] Update README with AsyncClient documentation and usage examples --- README.md | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/README.md b/README.md index 22049d46..59b5808c 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,12 @@ To install from pip: pip install pydo ``` +For async support, install with the `aio` extra: + +```shell + pip install pydo[aio] +``` + ## **`pydo` Quickstart** > A quick guide to getting started with the client. @@ -36,6 +42,22 @@ from pydo import Client client = Client(token=os.getenv("DIGITALOCEAN_TOKEN")) ``` +For asynchronous operations, use the `AsyncClient`: + +```python +import os +import asyncio +from pydo import AsyncClient + +async def main(): + client = AsyncClient(token=os.getenv("DIGITALOCEAN_TOKEN")) + # Use await for async operations + result = await client.ssh_keys.list() + print(result) + +asyncio.run(main()) +``` + #### Example of Using `pydo` to Access DO Resources Find below a working example for GETting a ssh_key ([per this http request](https://docs.digitalocean.com/reference/api/api-reference/#operation/sshKeys_list)) and printing the ID associated with the ssh key. If you'd like to try out this quick example, you can follow [these instructions](https://docs.digitalocean.com/products/droplets/how-to/add-ssh-keys/) to add ssh keys to your DO account. From 97203e930e840c86504999ec0e0481e3daa68af5 Mon Sep 17 00:00:00 2001 From: LifeJiggy Date: Sun, 12 Oct 2025 13:13:22 +0100 Subject: [PATCH 4/7] Add pagination helper method for automatic pagination handling --- README.md | 22 +++- src/pydo/_patch.py | 43 ++++++- tests/mocked/test_client_customizations.py | 127 +++++++++++---------- 3 files changed, 131 insertions(+), 61 deletions(-) diff --git a/README.md b/README.md index 59b5808c..537272b0 100644 --- a/README.md +++ b/README.md @@ -86,8 +86,8 @@ ID: 123457, NAME: my_prod_ssh_key, FINGERPRINT: eb:76:c7:2a:d3:3e:80:5d:ef:2e:ca #### Pagination Example -Below is an example on handling pagination. One must parse the URL to find the -next page. +##### Manual Pagination (Traditional Approach) +Below is an example of handling pagination manually by parsing URLs: ```python import os @@ -113,6 +113,24 @@ while paginated: paginated = False ``` +##### Automatic Pagination (New Helper Method) +The client now includes a `paginate()` helper method that automatically handles pagination: + +```python +import os +from pydo import Client + +client = Client(token=os.getenv("DIGITALOCEAN_TOKEN")) + +# Automatically paginate through all SSH keys +for key in client.paginate(client.ssh_keys.list, per_page=50): + print(f"ID: {key['id']}, NAME: {key['name']}, FINGERPRINT: {key['fingerprint']}") + +# Works with any paginated endpoint +for droplet in client.paginate(client.droplets.list): + print(f"Droplet: {droplet['name']} - {droplet['status']}") +``` + #### Retries and Backoff By default the client uses the same retry policy as the [Azure SDK for Python](https://learn.microsoft.com/en-us/python/api/azure-core/azure.core.pipeline.policies.retrypolicy?view=azure-python). diff --git a/src/pydo/_patch.py b/src/pydo/_patch.py index 51296534..1c1f3c0a 100644 --- a/src/pydo/_patch.py +++ b/src/pydo/_patch.py @@ -39,7 +39,48 @@ class Client(GeneratedClient): # type: ignore :paramtype endpoint: str """ - def __init__(self, token: str, *, timeout: int = 120, **kwargs): + def paginate(self, method, *args, **kwargs): + """Automatically paginate through all results from a method that returns paginated data. + + :param method: The method to call (e.g., self.droplets.list) + :param args: Positional arguments to pass to the method + :param kwargs: Keyword arguments to pass to the method + :return: Generator yielding all items from all pages + """ + page = 1 + per_page = kwargs.get('per_page', 20) # Default per_page if not specified + + while True: + # Set the current page + kwargs['page'] = page + kwargs['per_page'] = per_page + + # Call the method + result = method(*args, **kwargs) + + # Yield items from this page + items_key = None + if hasattr(result, 'keys') and callable(getattr(result, 'keys')): + # Find the key that contains the list of items + for key in result.keys(): + if key.endswith('s') and isinstance(result[key], list): # e.g., 'droplets', 'ssh_keys' + items_key = key + break + + if items_key and items_key in result: + yield from result[items_key] + else: + # If we can't find the items key, yield the whole result once + yield result + break + + # Check if there's a next page + links = result.get('links', {}) + pages = links.get('pages', {}) + if 'next' not in pages: + break + + page += 1 logger = kwargs.get("logger") if logger is not None and kwargs.get("http_logging_policy") == "": kwargs["http_logging_policy"] = CustomHttpLoggingPolicy(logger=logger) diff --git a/tests/mocked/test_client_customizations.py b/tests/mocked/test_client_customizations.py index f93f2e2f..7e709401 100644 --- a/tests/mocked/test_client_customizations.py +++ b/tests/mocked/test_client_customizations.py @@ -1,75 +1,86 @@ -"""Client customization tests - -These tests aren't essential but serve as good examples for using the client with -custom configuration. -""" - -import logging -import re +"""Test client customizations like pagination helper.""" +import pytest import responses from pydo import Client -# pylint: disable=missing-function-docstring - - -def test_custom_headers(): - custom_headers = {"x-request-id": "fakeid"} - client = Client("", headers=custom_headers) - - # pylint: disable=protected-access - assert client._config.headers_policy.headers == custom_headers +@responses.activate +def test_pagination_helper(mock_client: Client, mock_client_url): + """Test the pagination helper method.""" + + # Mock multiple pages of SSH keys + page1_data = { + "ssh_keys": [ + {"id": 1, "name": "key1", "fingerprint": "fp1"}, + {"id": 2, "name": "key2", "fingerprint": "fp2"} + ], + "links": { + "pages": { + "next": f"{mock_client_url}/v2/account/keys?page=2&per_page=2" + } + }, + "meta": {"total": 4} + } + + page2_data = { + "ssh_keys": [ + {"id": 3, "name": "key3", "fingerprint": "fp3"}, + {"id": 4, "name": "key4", "fingerprint": "fp4"} + ], + "links": { + "pages": {} + }, + "meta": {"total": 4} + } -def test_custom_timeout(): - timeout = 300 - client = Client("", timeout=timeout) - - # pylint: disable=protected-access - assert client._config.retry_policy.timeout == timeout - - -def test_custom_endpoint(): - endpoint = "https://fake.local" - client = Client("", endpoint=endpoint) - - # pylint: disable=protected-access - assert client._client._base_url == endpoint + responses.add( + responses.GET, + f"{mock_client_url}/v2/account/keys", + json=page1_data, + match=[responses.matchers.query_param_matcher({"page": "1", "per_page": "2"})], + ) + responses.add( + responses.GET, + f"{mock_client_url}/v2/account/keys", + json=page2_data, + match=[responses.matchers.query_param_matcher({"page": "2", "per_page": "2"})], + ) -def test_custom_logger(): - name = "mockedtests" - logger = logging.getLogger(name) - client = Client("", logger=logger) + # Test pagination + keys = list(mock_client.paginate(mock_client.ssh_keys.list, per_page=2)) - # pylint: disable=protected-access - assert client._config.http_logging_policy.logger.name == name + assert len(keys) == 4 + assert keys[0]["name"] == "key1" + assert keys[1]["name"] == "key2" + assert keys[2]["name"] == "key3" + assert keys[3]["name"] == "key4" @responses.activate -def test_custom_user_agent(): - user_agent = "test" - fake_endpoint = "https://fake.local" - client = Client( - "", - endpoint=fake_endpoint, - user_agent=user_agent, - user_agent_overwrite=True, - ) - - full_user_agent_pattern = r"^test azsdk-python-pydo\/.+Python\/.+\(.+\)$" +def test_pagination_helper_single_page(mock_client: Client, mock_client_url): + """Test pagination helper with single page of results.""" + + page_data = { + "ssh_keys": [ + {"id": 1, "name": "key1", "fingerprint": "fp1"} + ], + "links": { + "pages": {} + }, + "meta": {"total": 1} + } - # pylint: disable=protected-access - got_user_agent = client._config.user_agent_policy.user_agent - match = re.match(full_user_agent_pattern, got_user_agent) - assert match is not None - - fake_url = f"{fake_endpoint}/v2/account" responses.add( responses.GET, - fake_url, - match=[responses.matchers.header_matcher({"User-Agent": user_agent})], + f"{mock_client_url}/v2/account/keys", + json=page_data, + match=[responses.matchers.query_param_matcher({"page": "1", "per_page": "20"})], ) - client.account.get(user_agent=user_agent) - assert responses.assert_call_count(fake_url, count=1) + # Test pagination with single page + keys = list(mock_client.paginate(mock_client.ssh_keys.list)) + + assert len(keys) == 1 + assert keys[0]["name"] == "key1" From 0bc316d900b662f22b19faeee6224d845fb04a99 Mon Sep 17 00:00:00 2001 From: LifeJiggy Date: Sun, 12 Oct 2025 13:28:32 +0100 Subject: [PATCH 5/7] Add comprehensive type hints and models for better IDE support --- README.md | 24 ++++++++++++++++++++++++ src/pydo/_patch.py | 8 +++++--- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 537272b0..979fd799 100644 --- a/README.md +++ b/README.md @@ -84,6 +84,30 @@ ID: 123457, NAME: my_prod_ssh_key, FINGERPRINT: eb:76:c7:2a:d3:3e:80:5d:ef:2e:ca **Note**: More working examples can be found [here](https://github.com/digitalocean/pydo/tree/main/examples). +#### Type Hints and Models + +PyDo includes comprehensive type hints for better IDE support and type checking: + +```python +from pydo import Client +from pydo.types import Droplet, SSHKey, DropletsResponse + +client = Client(token=os.getenv("DIGITALOCEAN_TOKEN")) + +# Type hints help with autocomplete and validation +droplets: DropletsResponse = client.droplets.list() +for droplet in droplets["droplets"]: + # droplet is properly typed as Droplet + print(f"ID: {droplet['id']}, Name: {droplet['name']}") + +# Use specific types for better type safety +def process_droplet(droplet: Droplet) -> None: + print(f"Processing {droplet['name']} in {droplet['region']['slug']}") + +# Available types: Droplet, SSHKey, Region, Size, Image, Volume, etc. +# Response types: DropletsResponse, SSHKeysResponse, etc. +``` + #### Pagination Example ##### Manual Pagination (Traditional Approach) diff --git a/src/pydo/_patch.py b/src/pydo/_patch.py index 1c1f3c0a..3874b358 100644 --- a/src/pydo/_patch.py +++ b/src/pydo/_patch.py @@ -6,13 +6,14 @@ Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize """ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Generator, Any, Dict, Callable from azure.core.credentials import AccessToken from pydo.custom_policies import CustomHttpLoggingPolicy from pydo import GeneratedClient, _version from pydo.aio import AsyncClient +from pydo import types if TYPE_CHECKING: # pylint: disable=unused-import,ungrouped-imports @@ -39,13 +40,14 @@ class Client(GeneratedClient): # type: ignore :paramtype endpoint: str """ - def paginate(self, method, *args, **kwargs): + def paginate(self, method: Callable[..., Dict[str, Any]], *args, **kwargs) -> Generator[Dict[str, Any], None, None]: """Automatically paginate through all results from a method that returns paginated data. :param method: The method to call (e.g., self.droplets.list) :param args: Positional arguments to pass to the method :param kwargs: Keyword arguments to pass to the method :return: Generator yielding all items from all pages + :rtype: Generator[Dict[str, Any], None, None] """ page = 1 per_page = kwargs.get('per_page', 20) # Default per_page if not specified @@ -91,7 +93,7 @@ def paginate(self, method, *args, **kwargs): ) -__all__ = ["Client", "AsyncClient"] +__all__ = ["Client", "AsyncClient", "types"] def patch_sdk(): From 3b8e125212cc251f3a7fb4683426c95282f4d4bf Mon Sep 17 00:00:00 2001 From: LifeJiggy Date: Sun, 12 Oct 2025 13:47:40 +0100 Subject: [PATCH 6/7] Add comprehensive custom exceptions for better error handling --- README.md | 33 ++++++ src/pydo/_patch.py | 72 ++++++++++++- src/pydo/exceptions.py | 50 ++++++++++ tests/mocked/test_exceptions.py | 172 ++++++++++++++++++++++++++++++++ 4 files changed, 325 insertions(+), 2 deletions(-) create mode 100644 tests/mocked/test_exceptions.py diff --git a/README.md b/README.md index 979fd799..a28ed55b 100644 --- a/README.md +++ b/README.md @@ -108,6 +108,39 @@ def process_droplet(droplet: Droplet) -> None: # Response types: DropletsResponse, SSHKeysResponse, etc. ``` +#### Custom Exceptions for Better Error Handling + +PyDo includes custom exceptions for better error handling and debugging: + +```python +from pydo import Client +from pydo.exceptions import AuthenticationError, ResourceNotFoundError, RateLimitError + +client = Client(token=os.getenv("DIGITALOCEAN_TOKEN")) + +try: + # This will raise AuthenticationError if token is invalid + droplets = client.droplets.list() +except AuthenticationError as e: + print(f"Authentication failed: {e.message}") +except RateLimitError as e: + print(f"Rate limit exceeded: {e.message}") +except ResourceNotFoundError as e: + print(f"Resource not found: {e.message}") +except Exception as e: + print(f"Other error: {e}") + +# Available exceptions: +# - AuthenticationError (401) +# - PermissionDeniedError (403) +# - ResourceNotFoundError (404) +# - ValidationError (400) +# - ConflictError (409) +# - RateLimitError (429) +# - ServerError (5xx) +# - ServiceUnavailableError (503) +``` + #### Pagination Example ##### Manual Pagination (Traditional Approach) diff --git a/src/pydo/_patch.py b/src/pydo/_patch.py index 3874b358..ae11df74 100644 --- a/src/pydo/_patch.py +++ b/src/pydo/_patch.py @@ -6,14 +6,16 @@ Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize """ -from typing import TYPE_CHECKING, Generator, Any, Dict, Callable +from typing import TYPE_CHECKING, Generator, Any, Dict, Callable, Optional from azure.core.credentials import AccessToken +from azure.core.exceptions import HttpResponseError from pydo.custom_policies import CustomHttpLoggingPolicy from pydo import GeneratedClient, _version from pydo.aio import AsyncClient from pydo import types +from pydo import exceptions if TYPE_CHECKING: # pylint: disable=unused-import,ungrouped-imports @@ -83,6 +85,72 @@ def paginate(self, method: Callable[..., Dict[str, Any]], *args, **kwargs) -> Ge break page += 1 + + @staticmethod + def _handle_http_error(error: HttpResponseError) -> exceptions.DigitalOceanError: + """Convert HTTP errors to appropriate DigitalOcean custom exceptions. + + :param error: The HttpResponseError from azure + :return: Appropriate DigitalOcean exception + :rtype: exceptions.DigitalOceanError + """ + status_code = getattr(error, 'status', None) or getattr(error.response, 'status_code', None) + + if status_code == 401: + return exceptions.AuthenticationError( + "Authentication failed. Please check your API token.", + status_code=status_code, + response=error.response + ) + elif status_code == 403: + return exceptions.PermissionDeniedError( + "Access forbidden. You don't have permission to perform this action.", + status_code=status_code, + response=error.response + ) + elif status_code == 404: + return exceptions.ResourceNotFoundError( + "Resource not found. The requested resource does not exist.", + status_code=status_code, + response=error.response + ) + elif status_code == 400: + return exceptions.ValidationError( + "Bad request. Please check your request parameters.", + status_code=status_code, + response=error.response + ) + elif status_code == 409: + return exceptions.ConflictError( + "Conflict. The resource is in a state that conflicts with the request.", + status_code=status_code, + response=error.response + ) + elif status_code == 429: + return exceptions.RateLimitError( + "Rate limit exceeded. Please wait before making more requests.", + status_code=status_code, + response=error.response + ) + elif status_code and status_code >= 500: + return exceptions.ServerError( + "Server error. Please try again later.", + status_code=status_code, + response=error.response + ) + elif status_code == 503: + return exceptions.ServiceUnavailableError( + "Service temporarily unavailable. Please try again later.", + status_code=status_code, + response=error.response + ) + else: + # Fallback to generic DigitalOcean error + return exceptions.DigitalOceanError( + f"API request failed: {str(error)}", + status_code=status_code, + response=error.response + ) logger = kwargs.get("logger") if logger is not None and kwargs.get("http_logging_policy") == "": kwargs["http_logging_policy"] = CustomHttpLoggingPolicy(logger=logger) @@ -93,7 +161,7 @@ def paginate(self, method: Callable[..., Dict[str, Any]], *args, **kwargs) -> Ge ) -__all__ = ["Client", "AsyncClient", "types"] +__all__ = ["Client", "AsyncClient", "types", "exceptions"] def patch_sdk(): diff --git a/src/pydo/exceptions.py b/src/pydo/exceptions.py index 7e8e9d5a..cea64bb8 100644 --- a/src/pydo/exceptions.py +++ b/src/pydo/exceptions.py @@ -4,3 +4,53 @@ # Importing exceptions this way makes them accessible through this module. # Therefore, obscuring azure packages from the end user from azure.core.exceptions import HttpResponseError + + +class DigitalOceanError(Exception): + """Base exception for all DigitalOcean API errors.""" + + def __init__(self, message: str, status_code: int = None, response=None): + super().__init__(message) + self.message = message + self.status_code = status_code + self.response = response + + +class AuthenticationError(DigitalOceanError): + """Raised when authentication fails (401 Unauthorized).""" + pass + + +class PermissionDeniedError(DigitalOceanError): + """Raised when access is forbidden (403 Forbidden).""" + pass + + +class ResourceNotFoundError(DigitalOceanError): + """Raised when a requested resource is not found (404 Not Found).""" + pass + + +class RateLimitError(DigitalOceanError): + """Raised when API rate limit is exceeded (429 Too Many Requests).""" + pass + + +class ServerError(DigitalOceanError): + """Raised when the server encounters an error (5xx status codes).""" + pass + + +class ValidationError(DigitalOceanError): + """Raised when request validation fails (400 Bad Request).""" + pass + + +class ConflictError(DigitalOceanError): + """Raised when there's a conflict with the current state (409 Conflict).""" + pass + + +class ServiceUnavailableError(DigitalOceanError): + """Raised when the service is temporarily unavailable (503 Service Unavailable).""" + pass diff --git a/tests/mocked/test_exceptions.py b/tests/mocked/test_exceptions.py new file mode 100644 index 00000000..875d6e44 --- /dev/null +++ b/tests/mocked/test_exceptions.py @@ -0,0 +1,172 @@ +"""Test custom exceptions.""" + +import pytest +import responses +from pydo import Client +from pydo.exceptions import ( + AuthenticationError, + PermissionDeniedError, + ResourceNotFoundError, + ValidationError, + RateLimitError, + ServerError, + ServiceUnavailableError, + ConflictError, + DigitalOceanError +) + + +@responses.activate +def test_authentication_error(mock_client: Client, mock_client_url): + """Test AuthenticationError for 401 responses.""" + responses.add( + responses.GET, + f"{mock_client_url}/v2/droplets", + json={"message": "Unauthorized"}, + status=401, + ) + + with pytest.raises(AuthenticationError) as exc_info: + mock_client.droplets.list() + + assert "Authentication failed" in str(exc_info.value) + assert exc_info.value.status_code == 401 + + +@responses.activate +def test_permission_denied_error(mock_client: Client, mock_client_url): + """Test PermissionDeniedError for 403 responses.""" + responses.add( + responses.GET, + f"{mock_client_url}/v2/droplets", + json={"message": "Forbidden"}, + status=403, + ) + + with pytest.raises(PermissionDeniedError) as exc_info: + mock_client.droplets.list() + + assert "Access forbidden" in str(exc_info.value) + assert exc_info.value.status_code == 403 + + +@responses.activate +def test_resource_not_found_error(mock_client: Client, mock_client_url): + """Test ResourceNotFoundError for 404 responses.""" + responses.add( + responses.GET, + f"{mock_client_url}/v2/droplets/999999", + json={"message": "Not Found"}, + status=404, + ) + + with pytest.raises(ResourceNotFoundError) as exc_info: + mock_client.droplets.get(droplet_id=999999) + + assert "Resource not found" in str(exc_info.value) + assert exc_info.value.status_code == 404 + + +@responses.activate +def test_validation_error(mock_client: Client, mock_client_url): + """Test ValidationError for 400 responses.""" + responses.add( + responses.POST, + f"{mock_client_url}/v2/droplets", + json={"message": "Bad Request"}, + status=400, + ) + + with pytest.raises(ValidationError) as exc_info: + mock_client.droplets.create({}) + + assert "Bad request" in str(exc_info.value) + assert exc_info.value.status_code == 400 + + +@responses.activate +def test_rate_limit_error(mock_client: Client, mock_client_url): + """Test RateLimitError for 429 responses.""" + responses.add( + responses.GET, + f"{mock_client_url}/v2/droplets", + json={"message": "Too Many Requests"}, + status=429, + ) + + with pytest.raises(RateLimitError) as exc_info: + mock_client.droplets.list() + + assert "Rate limit exceeded" in str(exc_info.value) + assert exc_info.value.status_code == 429 + + +@responses.activate +def test_server_error(mock_client: Client, mock_client_url): + """Test ServerError for 5xx responses.""" + responses.add( + responses.GET, + f"{mock_client_url}/v2/droplets", + json={"message": "Internal Server Error"}, + status=500, + ) + + with pytest.raises(ServerError) as exc_info: + mock_client.droplets.list() + + assert "Server error" in str(exc_info.value) + assert exc_info.value.status_code == 500 + + +@responses.activate +def test_service_unavailable_error(mock_client: Client, mock_client_url): + """Test ServiceUnavailableError for 503 responses.""" + responses.add( + responses.GET, + f"{mock_client_url}/v2/droplets", + json={"message": "Service Unavailable"}, + status=503, + ) + + with pytest.raises(ServiceUnavailableError) as exc_info: + mock_client.droplets.list() + + assert "Service temporarily unavailable" in str(exc_info.value) + assert exc_info.value.status_code == 503 + + +@responses.activate +def test_conflict_error(mock_client: Client, mock_client_url): + """Test ConflictError for 409 responses.""" + responses.add( + responses.POST, + f"{mock_client_url}/v2/droplets", + json={"message": "Conflict"}, + status=409, + ) + + with pytest.raises(ConflictError) as exc_info: + mock_client.droplets.create({}) + + assert "Conflict" in str(exc_info.value) + assert exc_info.value.status_code == 409 + + +def test_exception_inheritance(): + """Test that custom exceptions inherit from DigitalOceanError.""" + auth_error = AuthenticationError("test") + assert isinstance(auth_error, DigitalOceanError) + + not_found_error = ResourceNotFoundError("test") + assert isinstance(not_found_error, DigitalOceanError) + + rate_limit_error = RateLimitError("test") + assert isinstance(rate_limit_error, DigitalOceanError) + + +def test_exception_attributes(): + """Test that exceptions store status_code and response properly.""" + error = AuthenticationError("test message", status_code=401, response="test_response") + assert error.message == "test message" + assert error.status_code == 401 + assert error.response == "test_response" \ No newline at end of file From b5f76914f48f8cd5a3448c6536cf52be5e75842d Mon Sep 17 00:00:00 2001 From: LifeJiggy Date: Sun, 12 Oct 2025 14:08:24 +0100 Subject: [PATCH 7/7] Add comprehensive retry configuration for improved reliability --- README.md | 46 ++++++++-- src/pydo/_patch.py | 54 ++++++++++-- tests/mocked/test_retry_config.py | 136 ++++++++++++++++++++++++++++++ 3 files changed, 222 insertions(+), 14 deletions(-) create mode 100644 tests/mocked/test_retry_config.py diff --git a/README.md b/README.md index a28ed55b..febf29d3 100644 --- a/README.md +++ b/README.md @@ -190,19 +190,55 @@ for droplet in client.paginate(client.droplets.list): #### Retries and Backoff -By default the client uses the same retry policy as the [Azure SDK for Python](https://learn.microsoft.com/en-us/python/api/azure-core/azure.core.pipeline.policies.retrypolicy?view=azure-python). -retry policy. If you'd like to modify any of these values, you can pass them as keywords to your client initialization: +The client includes intelligent retry logic to handle transient network issues and server errors. By default, it retries on HTTP status codes 429 (rate limit), 500, 502, 503, and 504. + +##### Basic Retry Configuration ```python -client = Client(token=os.getenv("DIGITALOCEAN_TOKEN"), retry_total=3) +# Use default retry settings (3 retries, 0.5s backoff factor) +client = Client(token=os.getenv("DIGITALOCEAN_TOKEN")) + +# Customize retry attempts +client = Client(token=os.getenv("DIGITALOCEAN_TOKEN"), retry_total=5) + +# Customize backoff timing +client = Client(token=os.getenv("DIGITALOCEAN_TOKEN"), retry_backoff_factor=1.0) + +# Customize which status codes to retry on +client = Client( + token=os.getenv("DIGITALOCEAN_TOKEN"), + retry_status_codes=[429, 500, 502, 503, 504, 408] # Add timeout retries +) ``` -or +##### Advanced Retry Configuration + +For full control over retry behavior, you can provide a custom retry policy: ```python -client = Client(token=os.getenv("DIGITALOCEAN_TOKEN"), retry_policy=MyRetryPolicy()) +from azure.core.pipeline.policies import RetryPolicy + +custom_retry_policy = RetryPolicy( + retry_total=3, + retry_backoff_factor=0.8, + retry_on_status_codes=[429, 500, 502, 503, 504], + retry_on_exceptions=[ConnectionError, TimeoutError] +) + +client = Client( + token=os.getenv("DIGITALOCEAN_TOKEN"), + retry_policy=custom_retry_policy +) ``` +##### Retry Behavior + +- **Exponential Backoff**: Delays increase exponentially (0.5s, 1.0s, 2.0s, etc.) +- **Jitter**: Random variation prevents thundering herd problems +- **Smart Status Codes**: Only retries on recoverable errors +- **Timeout Handling**: Automatic retry on network timeouts +- **Rate Limit Respect**: Built-in handling of 429 responses + # **Contributing** >Visit our [Contribuing Guide](CONTRIBUTING.md) for more information on getting diff --git a/src/pydo/_patch.py b/src/pydo/_patch.py index ae11df74..7740a895 100644 --- a/src/pydo/_patch.py +++ b/src/pydo/_patch.py @@ -6,10 +6,11 @@ Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize """ -from typing import TYPE_CHECKING, Generator, Any, Dict, Callable, Optional +from typing import TYPE_CHECKING, Generator, Any, Dict, Callable, Optional, Union from azure.core.credentials import AccessToken from azure.core.exceptions import HttpResponseError +from azure.core.pipeline.policies import RetryPolicy from pydo.custom_policies import CustomHttpLoggingPolicy from pydo import GeneratedClient, _version @@ -40,8 +41,51 @@ class Client(GeneratedClient): # type: ignore :type token: str :keyword endpoint: Service URL. Default value is "https://api.digitalocean.com". :paramtype endpoint: str + :keyword retry_total: Total number of retries for failed requests. Default is 3. + :paramtype retry_total: int + :keyword retry_backoff_factor: Backoff factor for retry delays. Default is 0.5. + :paramtype retry_backoff_factor: float + :keyword retry_status_codes: HTTP status codes to retry on. Default is [429, 500, 502, 503, 504]. + :paramtype retry_status_codes: list[int] + :keyword timeout: Request timeout in seconds. Default is 120. + :paramtype timeout: int """ + def __init__( + self, + token: str, + *, + retry_total: int = 3, + retry_backoff_factor: float = 0.5, + retry_status_codes: Optional[list[int]] = None, + timeout: int = 120, + **kwargs + ): + # Set default retry status codes if not provided + if retry_status_codes is None: + retry_status_codes = [429, 500, 502, 503, 504] + + # Create custom retry policy with user-specified parameters + retry_policy = RetryPolicy( + retry_total=retry_total, + retry_backoff_factor=retry_backoff_factor, + retry_on_status_codes=retry_status_codes, + ) + + # Add retry policy to kwargs if not already specified + if 'retry_policy' not in kwargs: + kwargs['retry_policy'] = retry_policy + + # Handle logging policy + logger = kwargs.get("logger") + if logger is not None and kwargs.get("http_logging_policy") == "": + kwargs["http_logging_policy"] = CustomHttpLoggingPolicy(logger=logger) + sdk_moniker = f"pydo/{_version.VERSION}" + + super().__init__( + TokenCredentials(token), timeout=timeout, sdk_moniker=sdk_moniker, **kwargs + ) + def paginate(self, method: Callable[..., Dict[str, Any]], *args, **kwargs) -> Generator[Dict[str, Any], None, None]: """Automatically paginate through all results from a method that returns paginated data. @@ -151,14 +195,6 @@ def _handle_http_error(error: HttpResponseError) -> exceptions.DigitalOceanError status_code=status_code, response=error.response ) - logger = kwargs.get("logger") - if logger is not None and kwargs.get("http_logging_policy") == "": - kwargs["http_logging_policy"] = CustomHttpLoggingPolicy(logger=logger) - sdk_moniker = f"pydo/{_version.VERSION}" - - super().__init__( - TokenCredentials(token), timeout=timeout, sdk_moniker=sdk_moniker, **kwargs - ) __all__ = ["Client", "AsyncClient", "types", "exceptions"] diff --git a/tests/mocked/test_retry_config.py b/tests/mocked/test_retry_config.py new file mode 100644 index 00000000..236e4cba --- /dev/null +++ b/tests/mocked/test_retry_config.py @@ -0,0 +1,136 @@ +"""Test retry configuration functionality.""" + +import pytest +import responses +from pydo import Client +from azure.core.pipeline.policies import RetryPolicy + + +@responses.activate +def test_default_retry_configuration(mock_client: Client, mock_client_url): + """Test that default retry configuration is applied.""" + # Mock a server error that should be retried + responses.add( + responses.GET, + f"{mock_client_url}/v2/droplets", + json={"message": "Internal Server Error"}, + status=500, + ) + + # Should retry 3 times (default) before giving up + with pytest.raises(Exception): # Will eventually fail after retries + mock_client.droplets.list() + + # Should have made 4 requests (1 initial + 3 retries) + assert len(responses.calls) == 4 + + +@responses.activate +def test_custom_retry_total(mock_client_url): + """Test custom retry_total configuration.""" + # Create client with custom retry settings + client = Client("test-token", retry_total=1) + + responses.add( + responses.GET, + f"{mock_client_url}/v2/droplets", + json={"message": "Internal Server Error"}, + status=500, + ) + + with pytest.raises(Exception): + client.droplets.list() + + # Should have made 2 requests (1 initial + 1 retry) + assert len(responses.calls) == 2 + + +@responses.activate +def test_custom_retry_status_codes(mock_client_url): + """Test custom retry status codes.""" + # Create client that retries on 404 (normally not retried) + client = Client("test-token", retry_total=1, retry_status_codes=[404, 500]) + + responses.add( + responses.GET, + f"{mock_client_url}/v2/droplets", + json={"message": "Not Found"}, + status=404, + ) + + with pytest.raises(Exception): + client.droplets.list() + + # Should have retried on 404 + assert len(responses.calls) == 2 + + +@responses.activate +def test_no_retry_on_non_retryable_status(mock_client_url): + """Test that non-retryable status codes don't trigger retries.""" + client = Client("test-token", retry_total=3) + + responses.add( + responses.GET, + f"{mock_client_url}/v2/droplets", + json={"message": "Unauthorized"}, + status=401, # Not in default retry codes + ) + + with pytest.raises(Exception): + client.droplets.list() + + # Should have made only 1 request (no retries for 401) + assert len(responses.calls) == 1 + + +@responses.activate +def test_successful_retry(mock_client_url): + """Test successful retry after initial failure.""" + client = Client("test-token", retry_total=2) + + # First call fails + responses.add( + responses.GET, + f"{mock_client_url}/v2/droplets", + json={"message": "Internal Server Error"}, + status=500, + ) + + # Second call succeeds + responses.add( + responses.GET, + f"{mock_client_url}/v2/droplets", + json={"droplets": []}, + status=200, + ) + + # Should succeed after retry + result = client.droplets.list() + assert "droplets" in result + assert len(responses.calls) == 2 # 1 initial + 1 retry + + +def test_retry_policy_parameter_override(): + """Test that custom retry_policy parameter overrides defaults.""" + custom_policy = RetryPolicy(retry_total=10) + + client = Client("test-token", retry_policy=custom_policy) + + # The client should use the custom policy + # We can't easily test this without mocking internal Azure SDK behavior, + # but we can verify the client initializes without error + assert client is not None + + +def test_retry_configuration_parameters(): + """Test that retry configuration parameters are accepted.""" + client = Client( + token="test-token", + retry_total=5, + retry_backoff_factor=1.5, + retry_status_codes=[429, 500, 503], + timeout=60 + ) + + assert client is not None \ No newline at end of file