diff --git a/client_gen_config.md b/client_gen_config.md index 6b931fc..5f689cc 100644 --- a/client_gen_config.md +++ b/client_gen_config.md @@ -73,4 +73,17 @@ directive: where: '$.paths."/v2/apps".post' transform: > $["parameters"] = []; + + # Strip path-level and operation-level "servers" overrides from all paths. + # Without this, Autorest generates a duplicate `endpoint` parameter in + # GeneratedClient.__init__ (one per unique server URL), which is a Python + # syntax error. Runtime multi-base-URL routing is handled by + # _BaseURLProxy in custom_extensions.py instead. + - from: openapi-document + where: '$.paths.*' + transform: > + delete $["servers"]; + for (const m of ["get","post","put","patch","delete","head","options"]) { + if ($[m] && $[m].servers) { delete $[m].servers; } + } ``` diff --git a/src/pydo/_patch.py b/src/pydo/_patch.py index 350d70f..b76607e 100644 --- a/src/pydo/_patch.py +++ b/src/pydo/_patch.py @@ -11,6 +11,7 @@ from azure.core.credentials import AccessToken from pydo.custom_policies import CustomHttpLoggingPolicy +from pydo.custom_extensions import _BaseURLProxy, INFERENCE_BASE_URL from pydo import GeneratedClient, _version if TYPE_CHECKING: @@ -32,13 +33,28 @@ def get_token(self, *args, **kwargs) -> AccessToken: class Client(GeneratedClient): # type: ignore """The official DigitalOcean Python client - :param token: A valid API token. + :param token: A valid API token / model access key. :type token: str :keyword endpoint: Service URL. Default value is "https://api.digitalocean.com". :paramtype endpoint: str + :keyword inference_endpoint: Serverless inference URL. + Default value is "https://inference.do-ai.run". + :paramtype inference_endpoint: str + :keyword agent_endpoint: Agent inference URL. Pass the per-agent + subdomain (e.g. ``"https://.agents.do-ai.run"``). + Required only when using agent inference endpoints. 
+ :paramtype agent_endpoint: str """ - def __init__(self, token: str, *, timeout: int = 120, **kwargs): + def __init__( + self, + token: str, + *, + timeout: int = 120, + inference_endpoint: str = INFERENCE_BASE_URL, + agent_endpoint: str = "", + **kwargs, + ): logger = kwargs.get("logger") if logger is not None and kwargs.get("http_logging_policy") == "": kwargs["http_logging_policy"] = CustomHttpLoggingPolicy(logger=logger) @@ -48,6 +64,41 @@ def __init__(self, token: str, *, timeout: int = 120, **kwargs): TokenCredentials(token), timeout=timeout, sdk_moniker=sdk_moniker, **kwargs ) + self._setup_inference_routing(inference_endpoint, agent_endpoint) + + def _setup_inference_routing( + self, + inference_endpoint: str, + agent_endpoint: str, + ) -> None: + """Route Inference / AgentInference operation groups to their endpoints. + + * ``*Inference*`` (but not ``*AgentInference*``) → *inference_endpoint* + * ``*AgentInference*`` → *agent_endpoint* + + Both use the same token passed to ``Client.__init__``.
+ """ + inference_proxy = _BaseURLProxy( + self._client, + inference_endpoint, + ) + + agent_proxy = None + if agent_endpoint: + agent_proxy = _BaseURLProxy( + self._client, + agent_endpoint, + ) + + for attr in self.__dict__.values(): + if not hasattr(attr, "_client"): + continue + class_name = type(attr).__name__ + if class_name.startswith("AgentInference") and agent_proxy: + attr._client = agent_proxy + elif class_name.startswith("Inference"): + attr._client = inference_proxy + __all__ = ["Client"] diff --git a/src/pydo/aio/_patch.py b/src/pydo/aio/_patch.py index 37d16e1..c2bc290 100644 --- a/src/pydo/aio/_patch.py +++ b/src/pydo/aio/_patch.py @@ -13,6 +13,7 @@ from pydo import _version from pydo.custom_policies import CustomHttpLoggingPolicy +from pydo.custom_extensions import _BaseURLProxy, INFERENCE_BASE_URL from pydo.aio import GeneratedClient if TYPE_CHECKING: @@ -34,15 +35,30 @@ async def get_token(self, *args, **kwargs) -> AccessToken: class Client(GeneratedClient): # type: ignore - """The official DigitalOcean Python client + """The official DigitalOcean Python client (async) - :param token: A valid API token. + :param token: A valid API token / model access key. :type token: str :keyword endpoint: Service URL. Default value is "https://api.digitalocean.com". :paramtype endpoint: str + :keyword inference_endpoint: Serverless inference URL. + Default value is "https://inference.do-ai.run". + :paramtype inference_endpoint: str + :keyword agent_endpoint: Agent inference URL. Pass the per-agent + subdomain (e.g. ``"https://.agents.do-ai.run"``). + Required only when using agent inference endpoints. 
+ :paramtype agent_endpoint: str """ - def __init__(self, token: str, *, timeout: int = 120, **kwargs): + def __init__( + self, + token: str, + *, + timeout: int = 120, + inference_endpoint: str = INFERENCE_BASE_URL, + agent_endpoint: str = "", + **kwargs, + ): logger = kwargs.get("logger") if logger is not None and kwargs.get("http_logging_policy") == "": kwargs["http_logging_policy"] = CustomHttpLoggingPolicy(logger=logger) @@ -52,6 +68,41 @@ def __init__(self, token: str, *, timeout: int = 120, **kwargs): TokenCredentials(token), timeout=timeout, sdk_moniker=sdk_moniker, **kwargs ) + self._setup_inference_routing(inference_endpoint, agent_endpoint) + + def _setup_inference_routing( + self, + inference_endpoint: str, + agent_endpoint: str, + ) -> None: + """Route Inference / AgentInference operation groups to their endpoints. + + * ``*Inference*`` (but not ``*AgentInference*``) → *inference_endpoint* + * ``*AgentInference*`` → *agent_endpoint* + + Both use the same token passed to ``Client.__init__``. + """ + inference_proxy = _BaseURLProxy( + self._client, + inference_endpoint, + ) + + agent_proxy = None + if agent_endpoint: + agent_proxy = _BaseURLProxy( + self._client, + agent_endpoint, + ) + + for attr in self.__dict__.values(): + if not hasattr(attr, "_client"): + continue + class_name = type(attr).__name__ + if class_name.startswith("AgentInference") and agent_proxy: + attr._client = agent_proxy + elif class_name.startswith("Inference"): + attr._client = inference_proxy + # Add all objects you want publicly available to users at this package level __all__ = ["Client"] # type: List[str] diff --git a/src/pydo/aio/operations/_patch.py b/src/pydo/aio/operations/_patch.py index 39ea63f..5836b4b 100644 --- a/src/pydo/aio/operations/_patch.py +++ b/src/pydo/aio/operations/_patch.py @@ -5,6 +5,10 @@ """Customize generated code here. 
Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize + +Async mirror of ``pydo/operations/_patch.py``. See that module for the +full design rationale. If no inference / agent-inference operations have +been generated, this module exports nothing. """ from typing import TYPE_CHECKING @@ -12,9 +16,85 @@ # pylint: disable=unused-import,ungrouped-imports from typing import List -__all__ = ( - [] -) # type: List[str] # Add all objects you want publicly available to users at this package level +# --------------------------------------------------------------------------- +# Serverless Inference operations (async) +# --------------------------------------------------------------------------- + +try: + from ._operations import InferenceOperations as _GeneratedInferenceOperations + + import pydo.operations._operations as _ops + + _HAS_INFERENCE = True +except ImportError: + _HAS_INFERENCE = False + +if _HAS_INFERENCE: + from pydo.custom_extensions import ( + AsyncStreamingMixin, + install_streaming_wrappers, + ) + + class InferenceOperations(AsyncStreamingMixin, _GeneratedInferenceOperations): + """Async InferenceOperations with fully automatic streaming support. + + Mirror of the sync version in ``pydo/operations/_patch.py``. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + install_streaming_wrappers( + self, _GeneratedInferenceOperations, _ops, is_async=True + ) + + +# --------------------------------------------------------------------------- +# Agent Inference operations (async) +# --------------------------------------------------------------------------- + +try: + from ._operations import ( + AgentInferenceOperations as _GeneratedAgentInferenceOperations, + ) + + if not _HAS_INFERENCE: + import pydo.operations._operations as _ops + + _HAS_AGENT_INFERENCE = True +except ImportError: + _HAS_AGENT_INFERENCE = False + +if _HAS_AGENT_INFERENCE: + if not _HAS_INFERENCE: + from pydo.custom_extensions import ( + AsyncStreamingMixin, + install_streaming_wrappers, + ) + + class AgentInferenceOperations( + AsyncStreamingMixin, _GeneratedAgentInferenceOperations + ): + """Async AgentInferenceOperations with fully automatic streaming support. + + Mirror of the sync version in ``pydo/operations/_patch.py``. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + install_streaming_wrappers( + self, _GeneratedAgentInferenceOperations, _ops, is_async=True + ) + + +# --------------------------------------------------------------------------- +# Exports +# --------------------------------------------------------------------------- + +__all__ = [] # type: ignore[assignment] +if _HAS_INFERENCE: + __all__.append("InferenceOperations") +if _HAS_AGENT_INFERENCE: + __all__.append("AgentInferenceOperations") def patch_sdk(): diff --git a/src/pydo/custom_extensions.py b/src/pydo/custom_extensions.py new file mode 100644 index 0000000..76d651b --- /dev/null +++ b/src/pydo/custom_extensions.py @@ -0,0 +1,420 @@ +# ------------------------------------ +# Copyright (c) DigitalOcean. +# Licensed under the Apache-2.0 License. +# ------------------------------------ +"""Multi-base-URL routing and SSE streaming support. 
+ +This file is preserved during ``make clean`` (matches the custom_*.py pattern) +and is NOT overwritten by code generation. + +Architecture +------------ +* ``_BaseURLProxy`` – lightweight proxy around a ``PipelineClient`` + that prepends a configured base URL to the generated path. Reuses + the original client's pipeline (auth, retry, logging). Usable for + any alternate host, not limited to inference endpoints. + +* ``StreamingMixin`` / ``AsyncStreamingMixin`` – mix-in classes that provide + the ``_auto_streaming_call`` method. Can be mixed into any operation + group (e.g. ``InferenceOperations``, ``AgentInferenceOperations``, or + future groups) in their ``_patch.py`` files. + +* ``SSEStream`` / ``AsyncSSEStream`` – iterators that parse Server-Sent + Events from a streaming HTTP response and yield parsed JSON chunks. + +* ``install_streaming_wrappers`` – called once in an Operations ``__init__`` + to automatically wrap every generated method that has a matching request + builder. When the user passes ``"stream": True`` in the body, the wrapper + runs the pipeline in streaming mode and returns an SSEStream instead of + falling through to the generated code (which would do ``response.json()`` + and fail on SSE data). + +Extensibility +------------- +Both multi-base-URL **and** streaming are fully automatic for new endpoints: + +1. Add the endpoint to the OpenAPI spec and run ``make generate``. +2. Autorest creates a new method in the generated operations class and a + matching ``build___request`` function. +3. ``install_streaming_wrappers`` discovers the pair at init time and creates + a wrapper that intercepts ``stream: True``. +4. ``_BaseURLProxy`` routes the request to the correct server. + +No manual changes to any ``_patch.py`` file are needed. 
+""" +import inspect +import json +import re +from io import IOBase +from typing import Any, AsyncIterator, Callable, Iterator, Optional +from urllib.parse import urlparse + +from azure.core.exceptions import ( + ClientAuthenticationError, + HttpResponseError, + ResourceExistsError, + ResourceNotFoundError, + ResourceNotModifiedError, + map_error, +) +from azure.core.rest import HttpRequest +from azure.core.utils import case_insensitive_dict + + +INFERENCE_BASE_URL = "https://inference.do-ai.run" + +_STREAMING_ERROR_MAP = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, +} + + +# --------------------------------------------------------------------------- +# Multi-base-URL proxy +# --------------------------------------------------------------------------- + + +class _BaseURLProxy: + """Proxy that redirects a PipelineClient to an alternate base URL. + + Reuses the wrapped client's pipeline (auth, retry, logging, etc.) + while transparently prepending a different base URL to the + generated path. This is a generic utility and is not limited to + inference endpoints. + + Parameters + ---------- + original_client : PipelineClient + The client whose pipeline (and transport) will be reused. + base_url : str + Target base URL (e.g. ``https://inference.do-ai.run``). 
+ """ + + def __init__(self, original_client, base_url: str): + self._original = original_client + self._base_url = base_url.rstrip("/") + + def format_url(self, url_template: str, **kwargs: Any) -> str: + parsed = urlparse(url_template) + if not parsed.scheme: + return self._base_url + url_template + return url_template + + @property + def _pipeline(self): + return self._original._pipeline + + def __getattr__(self, name: str) -> Any: + return getattr(self._original, name) + + +# --------------------------------------------------------------------------- +# Auto-wrap generated methods with streaming detection +# --------------------------------------------------------------------------- + + +def _class_name_to_builder_prefix(class_name: str) -> str: + """Convert ``'AgentInferenceOperations'`` → ``'agent_inference'``.""" + name = class_name.replace("Operations", "") + return re.sub(r"(? "SSEStream": + body = kwargs.pop("body") + + error_map = dict(_STREAMING_ERROR_MAP) + error_map.update(kwargs.pop("error_map", {}) or {}) + kwargs.pop("cls", None) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + content_type: Optional[str] = kwargs.pop( + "content_type", _headers.pop("Content-Type", None) + ) + content_type = content_type or "application/json" + + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + body = {**body, "stream": True} + _json = body + + _request = builder( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + **kwargs, + ) + _request.url = self._client.format_url(_request.url) # type: ignore[attr-defined] + + pipeline_response = self._client._pipeline.run( # type: ignore[attr-defined] + _request, stream=True + ) + + response = pipeline_response.http_response + if response.status_code not in [200]: + response.read() + map_error( + status_code=response.status_code, + response=response, + 
error_map=error_map, + ) + raise HttpResponseError(response=response) + + return SSEStream(response) + + +class AsyncStreamingMixin: + """Provides ``_auto_streaming_call`` for **async** operation groups.""" + + async def _auto_streaming_call( + self, + builder: Callable[..., HttpRequest], + **kwargs: Any, + ) -> "AsyncSSEStream": + body = kwargs.pop("body") + + error_map = dict(_STREAMING_ERROR_MAP) + error_map.update(kwargs.pop("error_map", {}) or {}) + kwargs.pop("cls", None) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + content_type: Optional[str] = kwargs.pop( + "content_type", _headers.pop("Content-Type", None) + ) + content_type = content_type or "application/json" + + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + body = {**body, "stream": True} + _json = body + + _request = builder( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + **kwargs, + ) + _request.url = self._client.format_url(_request.url) # type: ignore[attr-defined] + + pipeline_response = await self._client._pipeline.run( # type: ignore[attr-defined] + _request, stream=True + ) + + response = pipeline_response.http_response + if response.status_code not in [200]: + await response.read() + map_error( + status_code=response.status_code, + response=response, + error_map=error_map, + ) + raise HttpResponseError(response=response) + + return AsyncSSEStream(response) + + +# --------------------------------------------------------------------------- +# SSE stream iterators +# --------------------------------------------------------------------------- + + +class SSEStream: + """Synchronous iterator over Server-Sent Events. + + Yields parsed JSON objects for each ``data:`` line. Stops on + ``data: [DONE]``. 
+ + Usage:: + + stream = client..(body={ + ..., + "stream": True, + }) + with stream: + for chunk in stream: + print(chunk) + """ + + def __init__(self, response: Any): + self._response = response + + def __iter__(self) -> Iterator[dict]: + return self._iter_events() + + def _iter_events(self) -> Iterator[dict]: + buf = "" + for raw in self._response.iter_bytes(): + text = raw.decode("utf-8") if isinstance(raw, bytes) else raw + buf += text + while "\n" in buf: + line, buf = buf.split("\n", 1) + line = line.strip() + if not line: + continue + if line.startswith("data:"): + data = line[5:].strip() + if data == "[DONE]": + return + try: + yield json.loads(data) + except json.JSONDecodeError: + continue + + def close(self) -> None: + self._response.close() + + def __enter__(self) -> "SSEStream": + return self + + def __exit__(self, *args: Any) -> None: + self.close() + + +class AsyncSSEStream: + """Asynchronous iterator over Server-Sent Events. + + Yields parsed JSON objects for each ``data:`` line. Stops on + ``data: [DONE]``. 
+ + Usage:: + + stream = await client..(body={ + ..., + "stream": True, + }) + async with stream: + async for chunk in stream: + print(chunk) + """ + + def __init__(self, response: Any): + self._response = response + + def __aiter__(self) -> AsyncIterator[dict]: + return self._iter_events() + + async def _iter_events(self) -> AsyncIterator[dict]: + buf = "" + async for raw in self._response.iter_bytes(): + text = raw.decode("utf-8") if isinstance(raw, bytes) else raw + buf += text + while "\n" in buf: + line, buf = buf.split("\n", 1) + line = line.strip() + if not line: + continue + if line.startswith("data:"): + data = line[5:].strip() + if data == "[DONE]": + return + try: + yield json.loads(data) + except json.JSONDecodeError: + continue + + def close(self) -> None: + self._response.close() + + async def __aenter__(self) -> "AsyncSSEStream": + return self + + async def __aexit__(self, *args: Any) -> None: + self.close() diff --git a/src/pydo/operations/_patch.py b/src/pydo/operations/_patch.py index 8a843f7..31f0f00 100644 --- a/src/pydo/operations/_patch.py +++ b/src/pydo/operations/_patch.py @@ -5,17 +5,109 @@ """Customize generated code here. Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize + +Streaming for inference / agent-inference operations +----------------------------------------------------- +All generated methods on ``InferenceOperations`` and +``AgentInferenceOperations`` that accept a ``body`` parameter are +**automatically** wrapped at init time. When the caller passes +``"stream": True`` in the request body, the wrapper bypasses the +generated (non-streaming) code, runs the HTTP pipeline with +``stream=True``, and returns an :class:`~pydo.custom_extensions.SSEStream`. + +This means new endpoints added to the OpenAPI spec and regenerated via +``make generate`` get streaming support with **zero manual changes** to +this file. 
""" from typing import TYPE_CHECKING -from ._operations import DropletsOperations as Droplets - if TYPE_CHECKING: # pylint: disable=unused-import,ungrouped-imports pass +# --------------------------------------------------------------------------- +# Serverless Inference operations +# --------------------------------------------------------------------------- + +try: + from ._operations import InferenceOperations as _GeneratedInferenceOperations + + import pydo.operations._operations as _ops + + _HAS_INFERENCE = True +except ImportError: + _HAS_INFERENCE = False + +if _HAS_INFERENCE: + from pydo.custom_extensions import ( + StreamingMixin, + install_streaming_wrappers, + ) + + class InferenceOperations(StreamingMixin, _GeneratedInferenceOperations): + """InferenceOperations with fully automatic streaming support. + + Every generated method that takes a ``body`` parameter is wrapped + so that ``body["stream"] == True`` triggers SSE streaming + automatically. No per-endpoint overrides are needed. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + install_streaming_wrappers( + self, _GeneratedInferenceOperations, _ops, is_async=False + ) + + +# --------------------------------------------------------------------------- +# Agent Inference operations +# --------------------------------------------------------------------------- + +try: + from ._operations import ( + AgentInferenceOperations as _GeneratedAgentInferenceOperations, + ) + + if not _HAS_INFERENCE: + import pydo.operations._operations as _ops + + _HAS_AGENT_INFERENCE = True +except ImportError: + _HAS_AGENT_INFERENCE = False + +if _HAS_AGENT_INFERENCE: + if not _HAS_INFERENCE: + from pydo.custom_extensions import ( + StreamingMixin, + install_streaming_wrappers, + ) + + class AgentInferenceOperations(StreamingMixin, _GeneratedAgentInferenceOperations): + """AgentInferenceOperations with fully automatic streaming support. 
+ + Same auto-wrapping strategy as ``InferenceOperations``. The + builder prefix is derived from the class name + (``agent_inference``), so builders named + ``build_agent_inference__request`` are discovered + automatically. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + install_streaming_wrappers( + self, _GeneratedAgentInferenceOperations, _ops, is_async=False + ) + + +# --------------------------------------------------------------------------- +# Exports +# --------------------------------------------------------------------------- -__all__ = [] +__all__ = [] # type: ignore[assignment] +if _HAS_INFERENCE: + __all__.append("InferenceOperations") +if _HAS_AGENT_INFERENCE: + __all__.append("AgentInferenceOperations") def patch_sdk():