-
Notifications
You must be signed in to change notification settings - Fork 173
Expand file tree
/
Copy pathapi.py
More file actions
164 lines (144 loc) · 5.62 KB
/
api.py
File metadata and controls
164 lines (144 loc) · 5.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import json
import logging
import re
import urllib
import urllib.parse
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from datetime import timedelta
from textwrap import indent
from typing import Any
import httpx
import openai
import pydantic
from a2a.client import A2AClientHTTPError, Client, ClientConfig, ClientFactory
from a2a.types import AgentCard
from agentstack_sdk.platform.context import ContextToken
from httpx import HTTPStatusError
from httpx._types import RequestFiles
from agentstack_cli import configuration
from agentstack_cli.configuration import Configuration
logger = logging.getLogger(__name__)
config = Configuration()
API_BASE_URL = "api/v1/"
async def api_request(
method: str,
path: str,
json: dict | None = None,
files: RequestFiles | None = None,
params: dict[str, Any] | None = None,
use_auth: bool = True,
) -> dict | None:
"""Make an API request to the server."""
async with configuration.use_platform_client() as client:
response = await client.request(
method,
urllib.parse.urljoin(API_BASE_URL, path),
json=json,
files=files,
params=params,
timeout=60,
headers=(
{"Authorization": f"Bearer {token.access_token}"}
if use_auth and (token := await config.auth_manager.load_auth_token())
else {}
),
)
if response.is_error:
error = ""
try:
error = response.json()
error = error.get("detail", str(error))
except Exception:
response.raise_for_status()
if response.status_code == 401:
message = f'{error}\nexport AGENTSTACK__ADMIN_PASSWORD="<PASSWORD>" to set the admin password.'
raise HTTPStatusError(message=message, request=response.request, response=response)
raise HTTPStatusError(message=error, request=response.request, response=response)
if response.content:
return response.json()
async def api_stream(
method: str,
path: str,
json: dict | None = None,
params: dict[str, Any] | None = None,
use_auth: bool = True,
) -> AsyncIterator[dict[str, Any]]:
"""Make a streaming API request to the server."""
import json as jsonlib
async with (
configuration.use_platform_client() as client,
client.stream(
method,
urllib.parse.urljoin(API_BASE_URL, path),
json=json,
params=params,
timeout=timedelta(hours=1).total_seconds(),
headers=(
{"Authorization": f"Bearer {token.access_token}"}
if use_auth and (token := await config.auth_manager.load_auth_token())
else {}
),
) as response,
):
response: httpx.Response
if response.is_error:
error = ""
try:
[error] = [jsonlib.loads(message) async for message in response.aiter_text()]
error = error.get("detail", str(error))
except Exception:
response.raise_for_status()
raise HTTPStatusError(message=error, request=response.request, response=response)
async for line in response.aiter_lines():
if line:
yield jsonlib.loads(re.sub("^data:", "", line).strip())
async def fetch_server_version() -> str | None:
"""Fetch server version from OpenAPI schema."""
class OpenAPIInfo(pydantic.BaseModel):
version: str
class OpenAPISchema(pydantic.BaseModel):
info: OpenAPIInfo
try:
response = await api_request("GET", "openapi.json", use_auth=False)
if not response:
return None
schema = OpenAPISchema.model_validate(response)
return schema.info.version
except Exception as e:
logger.warning("Failed to fetch server version: %s", e)
return None
@asynccontextmanager
async def a2a_client(agent_card: AgentCard, context_token: ContextToken) -> AsyncIterator[Client]:
try:
async with httpx.AsyncClient(
headers={"Authorization": f"Bearer {context_token.token.get_secret_value()}"},
follow_redirects=True,
timeout=timedelta(hours=1).total_seconds(),
) as httpx_client:
yield ClientFactory(ClientConfig(httpx_client=httpx_client, use_client_preference=True)).create(
card=agent_card
)
except A2AClientHTTPError as ex:
card_data = json.dumps(
agent_card.model_dump(include={"url", "additional_interfaces", "preferred_transport"}), indent=2
)
raise RuntimeError(
f"The agent is not reachable, please check that the agent card is configured properly.\n"
f"Agent connection info:\n{indent(card_data, prefix=' ')}\n"
"Full Error:\n"
f"{indent(str(ex), prefix=' ')}"
) from ex
@asynccontextmanager
async def openai_client() -> AsyncIterator[openai.AsyncOpenAI]:
async with Configuration().use_platform_client() as platform_client:
headers = platform_client.headers.copy()
headers.pop("Authorization", None)
yield openai.AsyncOpenAI(
api_key=platform_client.headers.get("Authorization", "").removeprefix("Bearer ") or "dummy",
base_url=urllib.parse.urljoin(str(platform_client.base_url), urllib.parse.urljoin(API_BASE_URL, "openai")),
default_headers=headers,
)