Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 195 additions & 8 deletions src/openagents/lms/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,15 +233,19 @@ class BedrockProvider(BaseModelProvider):
def __init__(self, model_name: str, region: Optional[str] = None, **kwargs):
    """Initialize the async Bedrock provider.

    Args:
        model_name: Bedrock model identifier (e.g. an Anthropic Claude or
            Qwen model ID).
        region: AWS region; falls back to the AWS_DEFAULT_REGION environment
            variable, then "us-east-1".
        **kwargs: Accepted for interface compatibility; unused here.

    Raises:
        ImportError: If no bearer token is configured and aioboto3 is not
            installed.
    """
    self.model_name = model_name
    self.region = region or os.getenv("AWS_DEFAULT_REGION", "us-east-1")
    # When AWS_BEARER_TOKEN_BEDROCK is set, requests go straight to the
    # Bedrock HTTP Converse endpoint and aioboto3 is never needed, so the
    # import (and session creation) is gated on the token's absence.
    self.bearer_token = os.getenv("AWS_BEARER_TOKEN_BEDROCK")
    self.session = None

    if not self.bearer_token:
        try:
            import aioboto3
        except ImportError:
            raise ImportError(
                "aioboto3 package is required for async Bedrock provider (unless using AWS_BEARER_TOKEN_BEDROCK). "
                "Install with: pip install aioboto3"
            )
        self.session = aioboto3.Session()

async def chat_completion(
self,
Expand All @@ -252,6 +256,8 @@ async def chat_completion(
# Format depends on the specific model
if "claude" in self.model_name.lower():
return await self._claude_bedrock_completion(messages, tools)
elif "qwen" in self.model_name.lower():
return await self._qwen_bedrock_completion(messages, tools)
else:
raise NotImplementedError(
f"Model {self.model_name} not yet supported in Bedrock provider"
Expand All @@ -263,6 +269,10 @@ async def _claude_bedrock_completion(
tools: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
"""Handle Claude models on Bedrock."""
# Convert to Converse API format for bearer token, or invoke_model for boto3
if self.bearer_token:
return await self._claude_bedrock_bearer(messages, tools)

# Convert to Claude Bedrock format
claude_messages = []
system_message = None
Expand Down Expand Up @@ -322,6 +332,183 @@ async def _claude_bedrock_completion(

return result

async def _claude_bedrock_bearer(
    self,
    messages: List[Dict[str, Any]],
    tools: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
    """Handle Claude models on Bedrock using a bearer token (Converse API).

    Posts directly to the bedrock-runtime ``/converse`` HTTP endpoint with
    ``AWS_BEARER_TOKEN_BEDROCK``, bypassing boto3 credentials.

    Args:
        messages: OpenAI-style chat messages (system/user/assistant/tool).
        tools: Optional tool specs in Anthropic (``input_schema``) or
            OpenAI (``parameters``) shape.

    Returns:
        Dict with "content" (concatenated text), "tool_calls"
        (list of {"id", "name", "arguments"-as-JSON-string}) and, when the
        API reports it, a "usage" dict of token counts.

    Raises:
        RuntimeError: If the endpoint returns a non-200 status.
    """
    import aiohttp

    converse_messages: List[Dict[str, Any]] = []
    system_parts: List[Dict[str, str]] = []

    def _append(role: str, content: List[Dict[str, Any]]) -> None:
        # The Converse API rejects consecutive messages with the same role.
        # Several tool results in a row (each mapped to a "user" message)
        # would trigger that, so merge into the previous message instead.
        if converse_messages and converse_messages[-1]["role"] == role:
            converse_messages[-1]["content"].extend(content)
        else:
            converse_messages.append({"role": role, "content": list(content)})

    for msg in messages:
        role = msg["role"]
        if role == "system":
            system_parts.append({"text": msg["content"]})
        elif role == "tool":
            # Tool results travel inside a "user" message in the Converse API.
            _append("user", [{
                "toolResult": {
                    "toolUseId": msg.get("tool_call_id", ""),
                    "content": [{"text": msg.get("content", "")}],
                }
            }])
        elif role == "assistant" and msg.get("tool_calls"):
            content: List[Dict[str, Any]] = []
            if msg.get("content"):
                content.append({"text": msg["content"]})
            for tc in msg["tool_calls"]:
                raw_args = tc.get(
                    "arguments", tc.get("function", {}).get("arguments", "{}")
                )
                content.append({
                    "toolUse": {
                        "toolUseId": tc.get("id", ""),
                        "name": tc.get("name", tc.get("function", {}).get("name", "")),
                        # Arguments may arrive pre-parsed; only decode JSON strings.
                        "input": raw_args if isinstance(raw_args, dict) else json.loads(raw_args),
                    }
                })
            _append("assistant", content)
        else:
            _append(role, [{"text": msg.get("content", "") or ""}])

    body: Dict[str, Any] = {"messages": converse_messages}

    if system_parts:
        body["system"] = system_parts

    if tools:
        tool_config: Dict[str, Any] = {"tools": []}
        for tool in tools:
            tool_config["tools"].append({
                "toolSpec": {
                    "name": tool.get("name", ""),
                    "description": tool.get("description", ""),
                    "inputSchema": {
                        # Accept either Anthropic "input_schema" or OpenAI
                        # "parameters"; fall back to an empty object schema.
                        "json": tool.get(
                            "input_schema",
                            tool.get("parameters", {"type": "object", "properties": {}}),
                        )
                    },
                }
            })
        body["toolConfig"] = tool_config

    url = f"https://bedrock-runtime.{self.region}.amazonaws.com/model/{self.model_name}/converse"

    async with aiohttp.ClientSession() as session:
        async with session.post(
            url,
            json=body,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.bearer_token}",
            },
        ) as resp:
            if resp.status != 200:
                error_text = await resp.text()
                raise RuntimeError(f"Bedrock API error ({resp.status}): {error_text}")
            response_body = await resp.json()

    # Parse the Converse API response into the provider's common shape.
    result: Dict[str, Any] = {"content": "", "tool_calls": []}

    message = response_body.get("output", {}).get("message", {})
    for content_block in message.get("content", []):
        if "text" in content_block:
            result["content"] += content_block["text"]
        elif "toolUse" in content_block:
            tool_use = content_block["toolUse"]
            result["tool_calls"].append({
                "id": tool_use["toolUseId"],
                "name": tool_use["name"],
                "arguments": json.dumps(tool_use["input"]),
            })

    # Surface token usage when the API reports it.
    usage = response_body.get("usage", {})
    if usage:
        input_tokens = usage.get("inputTokens")
        output_tokens = usage.get("outputTokens")
        result["usage"] = {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": (
                (input_tokens or 0) + (output_tokens or 0)
                if input_tokens or output_tokens
                else None
            ),
        }

    return result

async def _qwen_bedrock_completion(
    self,
    messages: List[Dict[str, Any]],
    tools: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
    """Handle Qwen models on Bedrock (OpenAI-compatible request format).

    Args:
        messages: OpenAI-style chat messages.
        tools: Optional OpenAI-style tool specs (wrapped as
            {"type": "function", "function": tool} in the request).

    Returns:
        Dict with "content", "tool_calls", and optionally "usage",
        matching the shape returned by the Claude paths.
    """
    # With a bearer token, route through the model-agnostic Converse API
    # helper (despite its Claude-specific name, it only builds generic
    # Converse payloads).
    if self.bearer_token:
        return await self._claude_bedrock_bearer(messages, tools)

    # Qwen uses OpenAI-compatible message format. Use .get(): assistant
    # messages that carry only tool_calls may have no "content" key.
    # NOTE(review): tool_calls / tool_call_id fields are not forwarded on
    # this path — confirm whether multi-turn tool use needs them here.
    qwen_messages = [
        {"role": msg["role"], "content": msg.get("content") or ""}
        for msg in messages
    ]

    body: Dict[str, Any] = {
        "messages": qwen_messages,
        "max_tokens": 4096,
    }

    if tools:
        body["tools"] = [{"type": "function", "function": tool} for tool in tools]
        body["tool_choice"] = "auto"

    async with self.session.client(
        "bedrock-runtime", region_name=self.region
    ) as client:
        response = await client.invoke_model(
            modelId=self.model_name, body=json.dumps(body)
        )
        # aioboto3's StreamingBody.read() is a coroutine — it must be
        # awaited, and read while the client context is still open so the
        # underlying connection has not been released.
        raw_body = await response["body"].read()

    response_body = json.loads(raw_body)

    # Standardize to the provider's common response shape.
    result: Dict[str, Any] = {"content": "", "tool_calls": []}

    if "choices" in response_body:
        # OpenAI-compatible response format.
        message = response_body["choices"][0].get("message", {})
        result["content"] = message.get("content", "")
        for tool_call in message.get("tool_calls") or []:
            result["tool_calls"].append({
                "id": tool_call.get("id", ""),
                "name": tool_call["function"]["name"],
                "arguments": tool_call["function"]["arguments"],
            })
    elif "content" in response_body:
        # Direct content response.
        result["content"] = response_body["content"]

    # Extract token usage when reported.
    usage = response_body.get("usage", {})
    if usage:
        result["usage"] = {
            "input_tokens": usage.get("prompt_tokens"),
            "output_tokens": usage.get("completion_tokens"),
            "total_tokens": usage.get("total_tokens"),
        }

    return result

def format_tools(self, tools: List[Any]) -> List[Dict[str, Any]]:
"""Format tools for Bedrock."""
formatted_tools = []
Expand Down
Loading