From f0a92556b1916f34d95ccce151ea14b17d904d6c Mon Sep 17 00:00:00 2001 From: SSharma-10 Date: Wed, 18 Mar 2026 14:58:02 +0530 Subject: [PATCH 1/3] streaming and multi Host support --- client_gen_config.md | 13 + src/pydo/_patch.py | 58 +++- src/pydo/aio/_patch.py | 60 +++- src/pydo/aio/operations/_patch.py | 86 +++++- src/pydo/custom_inference.py | 469 ++++++++++++++++++++++++++++++ src/pydo/operations/_patch.py | 98 ++++++- 6 files changed, 773 insertions(+), 11 deletions(-) create mode 100644 src/pydo/custom_inference.py diff --git a/client_gen_config.md b/client_gen_config.md index 6b931fc6..b7ad1b72 100644 --- a/client_gen_config.md +++ b/client_gen_config.md @@ -73,4 +73,17 @@ directive: where: '$.paths."/v2/apps".post' transform: > $["parameters"] = []; + + # Strip path-level and operation-level "servers" overrides from all paths. + # Without this, Autorest generates a duplicate `endpoint` parameter in + # GeneratedClient.__init__ (one per unique server URL), which is a Python + # syntax error. Runtime multi-base-URL routing is handled by + # _InferenceClientProxy in custom_inference.py instead. + - from: openapi-document + where: '$.paths.*' + transform: > + delete $["servers"]; + for (const m of ["get","post","put","patch","delete","head","options"]) { + if ($[m] && $[m].servers) { delete $[m].servers; } + } ``` diff --git a/src/pydo/_patch.py b/src/pydo/_patch.py index 350d70fd..63bb0bc9 100644 --- a/src/pydo/_patch.py +++ b/src/pydo/_patch.py @@ -11,6 +11,7 @@ from azure.core.credentials import AccessToken from pydo.custom_policies import CustomHttpLoggingPolicy +from pydo.custom_inference import _InferenceClientProxy, INFERENCE_BASE_URL from pydo import GeneratedClient, _version if TYPE_CHECKING: @@ -32,13 +33,28 @@ def get_token(self, *args, **kwargs) -> AccessToken: class Client(GeneratedClient): # type: ignore """The official DigitalOcean Python client - :param token: A valid API token. + :param token: A valid API token / model access key. 
:type token: str :keyword endpoint: Service URL. Default value is "https://api.digitalocean.com". :paramtype endpoint: str + :keyword inference_endpoint: Serverless inference URL. + Default value is "https://inference.do-ai.run". + :paramtype inference_endpoint: str + :keyword agent_endpoint: Agent inference URL. Pass the per-agent + subdomain (e.g. ``"https://.agents.do-ai.run"``). + Required only when using agent inference endpoints. + :paramtype agent_endpoint: str """ - def __init__(self, token: str, *, timeout: int = 120, **kwargs): + def __init__( + self, + token: str, + *, + timeout: int = 120, + inference_endpoint: str = INFERENCE_BASE_URL, + agent_endpoint: str = "", + **kwargs, + ): logger = kwargs.get("logger") if logger is not None and kwargs.get("http_logging_policy") == "": kwargs["http_logging_policy"] = CustomHttpLoggingPolicy(logger=logger) @@ -48,6 +64,44 @@ def __init__(self, token: str, *, timeout: int = 120, **kwargs): TokenCredentials(token), timeout=timeout, sdk_moniker=sdk_moniker, **kwargs ) + self._setup_inference_routing(inference_endpoint, agent_endpoint) + + def _setup_inference_routing( + self, + inference_endpoint: str, + agent_endpoint: str, + ) -> None: + """Route Inference / AgentInference operation groups to their endpoints. + + * ``*Inference*`` (but not ``*AgentInference*``) → *inference_endpoint* + * ``*AgentInference*`` → *agent_endpoint* + + Both use the same token passed to ``Client.__init__``. 
+ """ + inference_proxy = _InferenceClientProxy( + self._client, + inference_endpoint, + strip_path_segments=["/inference/"], + version_prefix="/v1/", + ) + + agent_proxy = None + if agent_endpoint: + agent_proxy = _InferenceClientProxy( + self._client, + agent_endpoint, + path_replacements={"/v1/": "/api/v1/"}, + ) + + for attr in self.__dict__.values(): + if not hasattr(attr, "_client"): + continue + class_name = type(attr).__name__ + if class_name.startswith("AgentInference") and agent_proxy: + attr._client = agent_proxy + elif class_name.startswith("Inference"): + attr._client = inference_proxy + __all__ = ["Client"] diff --git a/src/pydo/aio/_patch.py b/src/pydo/aio/_patch.py index 37d16e19..634769e5 100644 --- a/src/pydo/aio/_patch.py +++ b/src/pydo/aio/_patch.py @@ -13,6 +13,7 @@ from pydo import _version from pydo.custom_policies import CustomHttpLoggingPolicy +from pydo.custom_inference import _InferenceClientProxy, INFERENCE_BASE_URL from pydo.aio import GeneratedClient if TYPE_CHECKING: @@ -34,15 +35,30 @@ async def get_token(self, *args, **kwargs) -> AccessToken: class Client(GeneratedClient): # type: ignore - """The official DigitalOcean Python client + """The official DigitalOcean Python client (async) - :param token: A valid API token. + :param token: A valid API token / model access key. :type token: str :keyword endpoint: Service URL. Default value is "https://api.digitalocean.com". :paramtype endpoint: str + :keyword inference_endpoint: Serverless inference URL. + Default value is "https://inference.do-ai.run". + :paramtype inference_endpoint: str + :keyword agent_endpoint: Agent inference URL. Pass the per-agent + subdomain (e.g. ``"https://.agents.do-ai.run"``). + Required only when using agent inference endpoints. 
+ :paramtype agent_endpoint: str """ - def __init__(self, token: str, *, timeout: int = 120, **kwargs): + def __init__( + self, + token: str, + *, + timeout: int = 120, + inference_endpoint: str = INFERENCE_BASE_URL, + agent_endpoint: str = "", + **kwargs, + ): logger = kwargs.get("logger") if logger is not None and kwargs.get("http_logging_policy") == "": kwargs["http_logging_policy"] = CustomHttpLoggingPolicy(logger=logger) @@ -52,6 +68,44 @@ def __init__(self, token: str, *, timeout: int = 120, **kwargs): TokenCredentials(token), timeout=timeout, sdk_moniker=sdk_moniker, **kwargs ) + self._setup_inference_routing(inference_endpoint, agent_endpoint) + + def _setup_inference_routing( + self, + inference_endpoint: str, + agent_endpoint: str, + ) -> None: + """Route Inference / AgentInference operation groups to their endpoints. + + * ``*Inference*`` (but not ``*AgentInference*``) → *inference_endpoint* + * ``*AgentInference*`` → *agent_endpoint* + + Both use the same token passed to ``Client.__init__``. + """ + inference_proxy = _InferenceClientProxy( + self._client, + inference_endpoint, + strip_path_segments=["/inference/"], + version_prefix="/v1/", + ) + + agent_proxy = None + if agent_endpoint: + agent_proxy = _InferenceClientProxy( + self._client, + agent_endpoint, + path_replacements={"/v1/": "/api/v1/"}, + ) + + for attr in self.__dict__.values(): + if not hasattr(attr, "_client"): + continue + class_name = type(attr).__name__ + if class_name.startswith("AgentInference") and agent_proxy: + attr._client = agent_proxy + elif class_name.startswith("Inference"): + attr._client = inference_proxy + # Add all objects you want publicly available to users at this package level __all__ = ["Client"] # type: List[str] diff --git a/src/pydo/aio/operations/_patch.py b/src/pydo/aio/operations/_patch.py index 39ea63f6..7c363673 100644 --- a/src/pydo/aio/operations/_patch.py +++ b/src/pydo/aio/operations/_patch.py @@ -5,6 +5,10 @@ """Customize generated code here. 
Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize + +Async mirror of ``pydo/operations/_patch.py``. See that module for the +full design rationale. If no inference / agent-inference operations have +been generated, this module exports nothing. """ from typing import TYPE_CHECKING @@ -12,9 +16,85 @@ # pylint: disable=unused-import,ungrouped-imports from typing import List -__all__ = ( - [] -) # type: List[str] # Add all objects you want publicly available to users at this package level +# --------------------------------------------------------------------------- +# Serverless Inference operations (async) +# --------------------------------------------------------------------------- + +try: + from ._operations import InferenceOperations as _GeneratedInferenceOperations + + import pydo.operations._operations as _ops + + _HAS_INFERENCE = True +except ImportError: + _HAS_INFERENCE = False + +if _HAS_INFERENCE: + from pydo.custom_inference import ( + AsyncStreamingMixin, + install_streaming_wrappers, + ) + + class InferenceOperations(AsyncStreamingMixin, _GeneratedInferenceOperations): + """Async InferenceOperations with fully automatic streaming support. + + Mirror of the sync version in ``pydo/operations/_patch.py``. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + install_streaming_wrappers( + self, _GeneratedInferenceOperations, _ops, is_async=True + ) + + +# --------------------------------------------------------------------------- +# Agent Inference operations (async) +# --------------------------------------------------------------------------- + +try: + from ._operations import ( + AgentInferenceOperations as _GeneratedAgentInferenceOperations, + ) + + if not _HAS_INFERENCE: + import pydo.operations._operations as _ops + + _HAS_AGENT_INFERENCE = True +except ImportError: + _HAS_AGENT_INFERENCE = False + +if _HAS_AGENT_INFERENCE: + if not _HAS_INFERENCE: + from pydo.custom_inference import ( + AsyncStreamingMixin, + install_streaming_wrappers, + ) + + class AgentInferenceOperations( + AsyncStreamingMixin, _GeneratedAgentInferenceOperations + ): + """Async AgentInferenceOperations with fully automatic streaming support. + + Mirror of the sync version in ``pydo/operations/_patch.py``. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + install_streaming_wrappers( + self, _GeneratedAgentInferenceOperations, _ops, is_async=True + ) + + +# --------------------------------------------------------------------------- +# Exports +# --------------------------------------------------------------------------- + +__all__ = [] # type: ignore[assignment] +if _HAS_INFERENCE: + __all__.append("InferenceOperations") +if _HAS_AGENT_INFERENCE: + __all__.append("AgentInferenceOperations") def patch_sdk(): diff --git a/src/pydo/custom_inference.py b/src/pydo/custom_inference.py new file mode 100644 index 00000000..79d7dc42 --- /dev/null +++ b/src/pydo/custom_inference.py @@ -0,0 +1,469 @@ +# ------------------------------------ +# Copyright (c) DigitalOcean. +# Licensed under the Apache-2.0 License. +# ------------------------------------ +"""Multi-base-URL routing and SSE streaming support for inference operations. 
+ +This file is preserved during ``make clean`` (matches the custom_*.py pattern) +and is NOT overwritten by code generation. + +Architecture +------------ +* ``_InferenceClientProxy`` – lightweight proxy around a ``PipelineClient`` + that rewrites request URLs to a configured base URL. Supports stripping + tag-based path segments (e.g. ``/inference/``) and rewriting the API + version prefix (``/v2/`` → ``/v1/`` by default). Reuses the original + client's pipeline (auth, retry, logging). + +* ``StreamingMixin`` / ``AsyncStreamingMixin`` – mix-in classes that provide + the ``_auto_streaming_call`` method used by both ``InferenceOperations`` + and ``AgentInferenceOperations`` (and any future groups) in their + ``_patch.py`` files. + +* ``SSEStream`` / ``AsyncSSEStream`` – iterators that parse Server-Sent + Events from a streaming HTTP response and yield parsed JSON chunks. + +* ``install_streaming_wrappers`` – called once in an Operations ``__init__`` + to automatically wrap every generated method that has a matching request + builder. When the user passes ``"stream": True`` in the body, the wrapper + runs the pipeline in streaming mode and returns an SSEStream instead of + falling through to the generated code (which would do ``response.json()`` + and fail on SSE data). + +Supported routing patterns +-------------------------- +============================== ================================== ============== +Pattern Base URL Auth +============================== ================================== ============== +Serverless Inference ``inference.do-ai.run`` Model access key +Agent Inference ``{id}.agents.do-ai.run`` Access key +============================== ================================== ============== + +Extensibility +------------- +Both multi-base-URL **and** streaming are fully automatic for new endpoints: + +1. Add the endpoint to the OpenAPI spec and run ``make generate``. +2. 
Autorest creates a new method in the generated operations class and a + matching ``build___request`` function. +3. ``install_streaming_wrappers`` discovers the pair at init time and creates + a wrapper that intercepts ``stream: True``. +4. ``_InferenceClientProxy`` routes the request to the correct server. + +No manual changes to any ``_patch.py`` file are needed. +""" +import inspect +import json +import re +from io import IOBase +from typing import Any, AsyncIterator, Callable, Iterator, Optional, Union +from urllib.parse import urlparse + +from azure.core.exceptions import ( + ClientAuthenticationError, + HttpResponseError, + ResourceExistsError, + ResourceNotFoundError, + ResourceNotModifiedError, + map_error, +) +from azure.core.rest import HttpRequest +from azure.core.utils import case_insensitive_dict + + +INFERENCE_BASE_URL = "https://inference.do-ai.run" + +_STREAMING_ERROR_MAP = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, +} + + +# --------------------------------------------------------------------------- +# Multi-base-URL proxy +# --------------------------------------------------------------------------- + + +class _InferenceClientProxy: + """Proxy that redirects a PipelineClient to an alternate base URL. + + Reuses the wrapped client's pipeline (auth, retry, logging, etc.) + while transparently rewriting URL paths. + + Parameters + ---------- + original_client : PipelineClient + The client whose pipeline (and transport) will be reused. + base_url : str + Target base URL (e.g. ``https://inference.do-ai.run``). + strip_path_segments : list[str] | None + Tag-based path segments to remove from the generated URL. + For example ``["/inference/"]`` strips the Autorest-generated + ``/inference/`` segment so ``/v1/inference/chat/completions`` + becomes ``/v1/chat/completions``. 
+ path_replacements : dict[str, str] | None + Arbitrary find→replace pairs applied to the URL path before + the base URL is prepended. Applied after ``strip_path_segments``. + Example: ``{"/v1/agent/": "/api/v1/"}`` rewrites the generated + ``/v1/agent/chat/completions`` to ``/api/v1/chat/completions``. + version_prefix : str + Replacement for the ``/v2/`` prefix that Autorest generates. + Defaults to ``"/v1/"`` (serverless inference). Set to ``"/"`` + to strip the version entirely (agent inference). + """ + + def __init__( + self, + original_client, + base_url: str, + *, + strip_path_segments: Optional[list] = None, + path_replacements: Optional[dict] = None, + version_prefix: str = "/v1/", + ): + self._original = original_client + self._base_url = base_url.rstrip("/") + self._strip_path_segments = strip_path_segments or [] + self._path_replacements = path_replacements or {} + self._version_prefix = version_prefix + + def format_url(self, url_template: str, **kwargs: Any) -> str: + if url_template.startswith("/v2/"): + url_template = self._version_prefix + url_template[4:] + elif url_template == "/v2": + url_template = self._version_prefix.rstrip("/") or "/" + + for seg in self._strip_path_segments: + if seg in url_template: + url_template = url_template.replace(seg, "/", 1) + + for find, replace in self._path_replacements.items(): + if find in url_template: + url_template = url_template.replace(find, replace, 1) + + while "//" in url_template and not url_template.startswith("http"): + url_template = url_template.replace("//", "/") + + parsed = urlparse(url_template) + if not parsed.scheme: + return self._base_url + url_template + return url_template + + @property + def _pipeline(self): + return self._original._pipeline + + def __getattr__(self, name: str) -> Any: + return getattr(self._original, name) + + +# --------------------------------------------------------------------------- +# Auto-wrap generated methods with streaming detection +# 
--------------------------------------------------------------------------- + + +def _class_name_to_builder_prefix(class_name: str) -> str: + """Convert ``'AgentInferenceOperations'`` → ``'agent_inference'``.""" + name = class_name.replace("Operations", "") + return re.sub(r"(? "SSEStream": + body = kwargs.pop("body") + + error_map = dict(_STREAMING_ERROR_MAP) + error_map.update(kwargs.pop("error_map", {}) or {}) + kwargs.pop("cls", None) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + content_type: Optional[str] = kwargs.pop( + "content_type", _headers.pop("Content-Type", None) + ) + content_type = content_type or "application/json" + + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + body = {**body, "stream": True} + _json = body + + _request = builder( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + **kwargs, + ) + _request.url = self._client.format_url(_request.url) # type: ignore[attr-defined] + + pipeline_response = self._client._pipeline.run( # type: ignore[attr-defined] + _request, stream=True + ) + + response = pipeline_response.http_response + if response.status_code not in [200]: + response.read() + map_error( + status_code=response.status_code, + response=response, + error_map=error_map, + ) + raise HttpResponseError(response=response) + + return SSEStream(response) + + +class AsyncStreamingMixin: + """Provides ``_auto_streaming_call`` for **async** operation groups.""" + + async def _auto_streaming_call( + self, + builder: Callable[..., HttpRequest], + **kwargs: Any, + ) -> "AsyncSSEStream": + body = kwargs.pop("body") + + error_map = dict(_STREAMING_ERROR_MAP) + error_map.update(kwargs.pop("error_map", {}) or {}) + kwargs.pop("cls", None) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + content_type: Optional[str] = 
kwargs.pop( + "content_type", _headers.pop("Content-Type", None) + ) + content_type = content_type or "application/json" + + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + body = {**body, "stream": True} + _json = body + + _request = builder( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + **kwargs, + ) + _request.url = self._client.format_url(_request.url) # type: ignore[attr-defined] + + pipeline_response = await self._client._pipeline.run( # type: ignore[attr-defined] + _request, stream=True + ) + + response = pipeline_response.http_response + if response.status_code not in [200]: + await response.read() + map_error( + status_code=response.status_code, + response=response, + error_map=error_map, + ) + raise HttpResponseError(response=response) + + return AsyncSSEStream(response) + + +# --------------------------------------------------------------------------- +# SSE stream iterators +# --------------------------------------------------------------------------- + + +class SSEStream: + """Synchronous iterator over Server-Sent Events. + + Yields parsed JSON objects for each ``data:`` line. Stops on + ``data: [DONE]``. 
+ + Usage:: + + stream = client.inference.({ + ..., + "stream": True, + }) + with stream: + for chunk in stream: + print(chunk) + """ + + def __init__(self, response: Any): + self._response = response + + def __iter__(self) -> Iterator[dict]: + return self._iter_events() + + def _iter_events(self) -> Iterator[dict]: + buf = "" + for raw in self._response.iter_bytes(): + text = raw.decode("utf-8") if isinstance(raw, bytes) else raw + buf += text + while "\n" in buf: + line, buf = buf.split("\n", 1) + line = line.strip() + if not line: + continue + if line.startswith("data:"): + data = line[5:].strip() + if data == "[DONE]": + return + try: + yield json.loads(data) + except json.JSONDecodeError: + continue + + def close(self) -> None: + self._response.close() + + def __enter__(self) -> "SSEStream": + return self + + def __exit__(self, *args: Any) -> None: + self.close() + + +class AsyncSSEStream: + """Asynchronous iterator over Server-Sent Events. + + Yields parsed JSON objects for each ``data:`` line. Stops on + ``data: [DONE]``. 
+ + Usage:: + + stream = await client.inference.({ + ..., + "stream": True, + }) + async with stream: + async for chunk in stream: + print(chunk) + """ + + def __init__(self, response: Any): + self._response = response + + def __aiter__(self) -> AsyncIterator[dict]: + return self._iter_events() + + async def _iter_events(self) -> AsyncIterator[dict]: + buf = "" + async for raw in self._response.iter_bytes(): + text = raw.decode("utf-8") if isinstance(raw, bytes) else raw + buf += text + while "\n" in buf: + line, buf = buf.split("\n", 1) + line = line.strip() + if not line: + continue + if line.startswith("data:"): + data = line[5:].strip() + if data == "[DONE]": + return + try: + yield json.loads(data) + except json.JSONDecodeError: + continue + + def close(self) -> None: + self._response.close() + + async def __aenter__(self) -> "AsyncSSEStream": + return self + + async def __aexit__(self, *args: Any) -> None: + self.close() diff --git a/src/pydo/operations/_patch.py b/src/pydo/operations/_patch.py index 8a843f7c..86990d91 100644 --- a/src/pydo/operations/_patch.py +++ b/src/pydo/operations/_patch.py @@ -5,17 +5,109 @@ """Customize generated code here. Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize + +Streaming for inference / agent-inference operations +----------------------------------------------------- +All generated methods on ``InferenceOperations`` and +``AgentInferenceOperations`` that accept a ``body`` parameter are +**automatically** wrapped at init time. When the caller passes +``"stream": True`` in the request body, the wrapper bypasses the +generated (non-streaming) code, runs the HTTP pipeline with +``stream=True``, and returns an :class:`~pydo.custom_inference.SSEStream`. + +This means new endpoints added to the OpenAPI spec and regenerated via +``make generate`` get streaming support with **zero manual changes** to +this file. 
""" from typing import TYPE_CHECKING -from ._operations import DropletsOperations as Droplets - if TYPE_CHECKING: # pylint: disable=unused-import,ungrouped-imports pass +# --------------------------------------------------------------------------- +# Serverless Inference operations +# --------------------------------------------------------------------------- + +try: + from ._operations import InferenceOperations as _GeneratedInferenceOperations + + import pydo.operations._operations as _ops + + _HAS_INFERENCE = True +except ImportError: + _HAS_INFERENCE = False + +if _HAS_INFERENCE: + from pydo.custom_inference import ( + StreamingMixin, + install_streaming_wrappers, + ) + + class InferenceOperations(StreamingMixin, _GeneratedInferenceOperations): + """InferenceOperations with fully automatic streaming support. + + Every generated method that takes a ``body`` parameter is wrapped + so that ``body["stream"] == True`` triggers SSE streaming + automatically. No per-endpoint overrides are needed. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + install_streaming_wrappers( + self, _GeneratedInferenceOperations, _ops, is_async=False + ) + + +# --------------------------------------------------------------------------- +# Agent Inference operations +# --------------------------------------------------------------------------- + +try: + from ._operations import ( + AgentInferenceOperations as _GeneratedAgentInferenceOperations, + ) + + if not _HAS_INFERENCE: + import pydo.operations._operations as _ops + + _HAS_AGENT_INFERENCE = True +except ImportError: + _HAS_AGENT_INFERENCE = False + +if _HAS_AGENT_INFERENCE: + if not _HAS_INFERENCE: + from pydo.custom_inference import ( + StreamingMixin, + install_streaming_wrappers, + ) + + class AgentInferenceOperations(StreamingMixin, _GeneratedAgentInferenceOperations): + """AgentInferenceOperations with fully automatic streaming support. 
+ + Same auto-wrapping strategy as ``InferenceOperations``. The + builder prefix is derived from the class name + (``agent_inference``), so builders named + ``build_agent_inference__request`` are discovered + automatically. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + install_streaming_wrappers( + self, _GeneratedAgentInferenceOperations, _ops, is_async=False + ) + + +# --------------------------------------------------------------------------- +# Exports +# --------------------------------------------------------------------------- -__all__ = [] +__all__ = [] # type: ignore[assignment] +if _HAS_INFERENCE: + __all__.append("InferenceOperations") +if _HAS_AGENT_INFERENCE: + __all__.append("AgentInferenceOperations") def patch_sdk(): From 54e3fbf2c2c72e393660b02075e9d11fa8e73e6c Mon Sep 17 00:00:00 2001 From: SSharma-10 Date: Wed, 18 Mar 2026 16:40:09 +0530 Subject: [PATCH 2/3] url fix --- src/pydo/_patch.py | 3 --- src/pydo/aio/_patch.py | 3 --- 2 files changed, 6 deletions(-) diff --git a/src/pydo/_patch.py b/src/pydo/_patch.py index 63bb0bc9..6876a833 100644 --- a/src/pydo/_patch.py +++ b/src/pydo/_patch.py @@ -81,8 +81,6 @@ def _setup_inference_routing( inference_proxy = _InferenceClientProxy( self._client, inference_endpoint, - strip_path_segments=["/inference/"], - version_prefix="/v1/", ) agent_proxy = None @@ -90,7 +88,6 @@ def _setup_inference_routing( agent_proxy = _InferenceClientProxy( self._client, agent_endpoint, - path_replacements={"/v1/": "/api/v1/"}, ) for attr in self.__dict__.values(): diff --git a/src/pydo/aio/_patch.py b/src/pydo/aio/_patch.py index 634769e5..fdfb4733 100644 --- a/src/pydo/aio/_patch.py +++ b/src/pydo/aio/_patch.py @@ -85,8 +85,6 @@ def _setup_inference_routing( inference_proxy = _InferenceClientProxy( self._client, inference_endpoint, - strip_path_segments=["/inference/"], - version_prefix="/v1/", ) agent_proxy = None @@ -94,7 +92,6 @@ def _setup_inference_routing( agent_proxy = 
_InferenceClientProxy( self._client, agent_endpoint, - path_replacements={"/v1/": "/api/v1/"}, ) for attr in self.__dict__.values(): From 974f55c7e7227ac152168163799ae47214ae2f4c Mon Sep 17 00:00:00 2001 From: SSharma-10 Date: Wed, 18 Mar 2026 21:09:18 +0530 Subject: [PATCH 3/3] file renaming --- client_gen_config.md | 2 +- src/pydo/_patch.py | 6 +- src/pydo/aio/_patch.py | 6 +- src/pydo/aio/operations/_patch.py | 4 +- ...stom_inference.py => custom_extensions.py} | 85 ++++--------------- src/pydo/operations/_patch.py | 6 +- 6 files changed, 30 insertions(+), 79 deletions(-) rename src/pydo/{custom_inference.py => custom_extensions.py} (79%) diff --git a/client_gen_config.md b/client_gen_config.md index b7ad1b72..5f689cc7 100644 --- a/client_gen_config.md +++ b/client_gen_config.md @@ -78,7 +78,7 @@ directive: # Without this, Autorest generates a duplicate `endpoint` parameter in # GeneratedClient.__init__ (one per unique server URL), which is a Python # syntax error. Runtime multi-base-URL routing is handled by - # _InferenceClientProxy in custom_inference.py instead. + # _BaseURLProxy in custom_extensions.py instead. - from: openapi-document where: '$.paths.*' transform: > diff --git a/src/pydo/_patch.py b/src/pydo/_patch.py index 6876a833..b76607e3 100644 --- a/src/pydo/_patch.py +++ b/src/pydo/_patch.py @@ -11,7 +11,7 @@ from azure.core.credentials import AccessToken from pydo.custom_policies import CustomHttpLoggingPolicy -from pydo.custom_inference import _InferenceClientProxy, INFERENCE_BASE_URL +from pydo.custom_extensions import _BaseURLProxy, INFERENCE_BASE_URL from pydo import GeneratedClient, _version if TYPE_CHECKING: @@ -78,14 +78,14 @@ def _setup_inference_routing( Both use the same token passed to ``Client.__init__``. 
""" - inference_proxy = _InferenceClientProxy( + inference_proxy = _BaseURLProxy( self._client, inference_endpoint, ) agent_proxy = None if agent_endpoint: - agent_proxy = _InferenceClientProxy( + agent_proxy = _BaseURLProxy( self._client, agent_endpoint, ) diff --git a/src/pydo/aio/_patch.py b/src/pydo/aio/_patch.py index fdfb4733..c2bc2906 100644 --- a/src/pydo/aio/_patch.py +++ b/src/pydo/aio/_patch.py @@ -13,7 +13,7 @@ from pydo import _version from pydo.custom_policies import CustomHttpLoggingPolicy -from pydo.custom_inference import _InferenceClientProxy, INFERENCE_BASE_URL +from pydo.custom_extensions import _BaseURLProxy, INFERENCE_BASE_URL from pydo.aio import GeneratedClient if TYPE_CHECKING: @@ -82,14 +82,14 @@ def _setup_inference_routing( Both use the same token passed to ``Client.__init__``. """ - inference_proxy = _InferenceClientProxy( + inference_proxy = _BaseURLProxy( self._client, inference_endpoint, ) agent_proxy = None if agent_endpoint: - agent_proxy = _InferenceClientProxy( + agent_proxy = _BaseURLProxy( self._client, agent_endpoint, ) diff --git a/src/pydo/aio/operations/_patch.py b/src/pydo/aio/operations/_patch.py index 7c363673..5836b4b5 100644 --- a/src/pydo/aio/operations/_patch.py +++ b/src/pydo/aio/operations/_patch.py @@ -30,7 +30,7 @@ _HAS_INFERENCE = False if _HAS_INFERENCE: - from pydo.custom_inference import ( + from pydo.custom_extensions import ( AsyncStreamingMixin, install_streaming_wrappers, ) @@ -66,7 +66,7 @@ def __init__(self, *args, **kwargs): if _HAS_AGENT_INFERENCE: if not _HAS_INFERENCE: - from pydo.custom_inference import ( + from pydo.custom_extensions import ( AsyncStreamingMixin, install_streaming_wrappers, ) diff --git a/src/pydo/custom_inference.py b/src/pydo/custom_extensions.py similarity index 79% rename from src/pydo/custom_inference.py rename to src/pydo/custom_extensions.py index 79d7dc42..76d651b8 100644 --- a/src/pydo/custom_inference.py +++ b/src/pydo/custom_extensions.py @@ -2,23 +2,22 @@ # Copyright 
(c) DigitalOcean. # Licensed under the Apache-2.0 License. # ------------------------------------ -"""Multi-base-URL routing and SSE streaming support for inference operations. +"""Multi-base-URL routing and SSE streaming support. This file is preserved during ``make clean`` (matches the custom_*.py pattern) and is NOT overwritten by code generation. Architecture ------------ -* ``_InferenceClientProxy`` – lightweight proxy around a ``PipelineClient`` - that rewrites request URLs to a configured base URL. Supports stripping - tag-based path segments (e.g. ``/inference/``) and rewriting the API - version prefix (``/v2/`` → ``/v1/`` by default). Reuses the original - client's pipeline (auth, retry, logging). +* ``_BaseURLProxy`` – lightweight proxy around a ``PipelineClient`` + that prepends a configured base URL to the generated path. Reuses + the original client's pipeline (auth, retry, logging). Usable for + any alternate host, not limited to inference endpoints. * ``StreamingMixin`` / ``AsyncStreamingMixin`` – mix-in classes that provide - the ``_auto_streaming_call`` method used by both ``InferenceOperations`` - and ``AgentInferenceOperations`` (and any future groups) in their - ``_patch.py`` files. + the ``_auto_streaming_call`` method. Can be mixed into any operation + group (e.g. ``InferenceOperations``, ``AgentInferenceOperations``, or + future groups) in their ``_patch.py`` files. * ``SSEStream`` / ``AsyncSSEStream`` – iterators that parse Server-Sent Events from a streaming HTTP response and yield parsed JSON chunks. @@ -30,15 +29,6 @@ falling through to the generated code (which would do ``response.json()`` and fail on SSE data). 
-Supported routing patterns --------------------------- -============================== ================================== ============== -Pattern Base URL Auth -============================== ================================== ============== -Serverless Inference ``inference.do-ai.run`` Model access key -Agent Inference ``{id}.agents.do-ai.run`` Access key -============================== ================================== ============== - Extensibility ------------- Both multi-base-URL **and** streaming are fully automatic for new endpoints: @@ -48,7 +38,7 @@ matching ``build___request`` function. 3. ``install_streaming_wrappers`` discovers the pair at init time and creates a wrapper that intercepts ``stream: True``. -4. ``_InferenceClientProxy`` routes the request to the correct server. +4. ``_BaseURLProxy`` routes the request to the correct server. No manual changes to any ``_patch.py`` file are needed. """ @@ -56,7 +46,7 @@ import json import re from io import IOBase -from typing import Any, AsyncIterator, Callable, Iterator, Optional, Union +from typing import Any, AsyncIterator, Callable, Iterator, Optional from urllib.parse import urlparse from azure.core.exceptions import ( @@ -86,11 +76,13 @@ # --------------------------------------------------------------------------- -class _InferenceClientProxy: +class _BaseURLProxy: """Proxy that redirects a PipelineClient to an alternate base URL. Reuses the wrapped client's pipeline (auth, retry, logging, etc.) - while transparently rewriting URL paths. + while transparently prepending a different base URL to the + generated path. This is a generic utility and is not limited to + inference endpoints. Parameters ---------- @@ -98,54 +90,13 @@ class _InferenceClientProxy: The client whose pipeline (and transport) will be reused. base_url : str Target base URL (e.g. ``https://inference.do-ai.run``). - strip_path_segments : list[str] | None - Tag-based path segments to remove from the generated URL. 
- For example ``["/inference/"]`` strips the Autorest-generated - ``/inference/`` segment so ``/v1/inference/chat/completions`` - becomes ``/v1/chat/completions``. - path_replacements : dict[str, str] | None - Arbitrary find→replace pairs applied to the URL path before - the base URL is prepended. Applied after ``strip_path_segments``. - Example: ``{"/v1/agent/": "/api/v1/"}`` rewrites the generated - ``/v1/agent/chat/completions`` to ``/api/v1/chat/completions``. - version_prefix : str - Replacement for the ``/v2/`` prefix that Autorest generates. - Defaults to ``"/v1/"`` (serverless inference). Set to ``"/"`` - to strip the version entirely (agent inference). """ - def __init__( - self, - original_client, - base_url: str, - *, - strip_path_segments: Optional[list] = None, - path_replacements: Optional[dict] = None, - version_prefix: str = "/v1/", - ): + def __init__(self, original_client, base_url: str): self._original = original_client self._base_url = base_url.rstrip("/") - self._strip_path_segments = strip_path_segments or [] - self._path_replacements = path_replacements or {} - self._version_prefix = version_prefix def format_url(self, url_template: str, **kwargs: Any) -> str: - if url_template.startswith("/v2/"): - url_template = self._version_prefix + url_template[4:] - elif url_template == "/v2": - url_template = self._version_prefix.rstrip("/") or "/" - - for seg in self._strip_path_segments: - if seg in url_template: - url_template = url_template.replace(seg, "/", 1) - - for find, replace in self._path_replacements.items(): - if find in url_template: - url_template = url_template.replace(find, replace, 1) - - while "//" in url_template and not url_template.startswith("http"): - url_template = url_template.replace("//", "/") - parsed = urlparse(url_template) if not parsed.scheme: return self._base_url + url_template @@ -244,7 +195,7 @@ def wrapper(*args, **kwargs): # noqa: E303 # ---------------------------------------------------------------------------
-# Streaming call mix-ins (reused by Inference & AgentInference _patch.py) +# Streaming call mix-ins (reused by any operation group's _patch.py) # --------------------------------------------------------------------------- @@ -373,7 +324,7 @@ class SSEStream: Usage:: - stream = client.inference.<operation>({ + stream = client.<group>.<operation>(body={ ..., "stream": True, }) @@ -425,7 +376,7 @@ class AsyncSSEStream: Usage:: - stream = await client.inference.<operation>({ + stream = await client.<group>.<operation>(body={ ..., "stream": True, }) diff --git a/src/pydo/operations/_patch.py b/src/pydo/operations/_patch.py index 86990d91..31f0f00a 100644 --- a/src/pydo/operations/_patch.py +++ b/src/pydo/operations/_patch.py @@ -13,7 +13,7 @@ **automatically** wrapped at init time. When the caller passes ``"stream": True`` in the request body, the wrapper bypasses the generated (non-streaming) code, runs the HTTP pipeline with -``stream=True``, and returns an :class:`~pydo.custom_inference.SSEStream`. +``stream=True``, and returns an :class:`~pydo.custom_extensions.SSEStream`. This means new endpoints added to the OpenAPI spec and regenerated via ``make generate`` get streaming support with **zero manual changes** to @@ -39,7 +39,7 @@ _HAS_INFERENCE = False if _HAS_INFERENCE: - from pydo.custom_inference import ( + from pydo.custom_extensions import ( StreamingMixin, install_streaming_wrappers, ) @@ -77,7 +77,7 @@ def __init__(self, *args, **kwargs): if _HAS_AGENT_INFERENCE: if not _HAS_INFERENCE: - from pydo.custom_inference import ( + from pydo.custom_extensions import ( StreamingMixin, install_streaming_wrappers, )