Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 123 additions & 30 deletions src/app/endpoints/chat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# src/app/endpoints/chat.py
import json
import time
from typing import Optional
from fastapi import APIRouter, HTTPException
from app.logger import logger
from schemas.request import GeminiRequest, OpenAIChatRequest
Expand All @@ -19,36 +21,111 @@ async def translate_chat(request: GeminiRequest):
if not session_manager:
raise HTTPException(status_code=503, detail="Session manager is not initialized.")
try:
# This call now correctly uses the fixed session manager
response = await session_manager.get_response(request.model, request.message, request.files)
return {"response": response.text}
except Exception as e:
logger.error(f"Error in /translate endpoint: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Error during translation: {str(e)}")

def convert_to_openai_format(response_text: str, model: str, stream: bool = False):
return {
"id": f"chatcmpl-{int(time.time())}",
"object": "chat.completion.chunk" if stream else "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [
{

def _build_tools_prompt(tools: list) -> str:
"""Convert OpenAI tool definitions to a system prompt for Gemini."""
declarations = []
for t in tools:
if t.get("type") == "function" and "function" in t:
declarations.append(t["function"])
if not declarations:
return ""
lines = [
"You have access to the following tools. When you want to call a tool, respond with "
"ONLY a JSON object in this exact format, with no other text before or after:\n"
'{"tool_call": {"name": "<tool_name>", "arguments": {<arguments>}}}\n',
"Available tools:",
]
for fn in declarations:
lines.append(f"- {fn['name']}: {fn.get('description', '')}")
if fn.get("parameters"):
lines.append(f" Parameters: {json.dumps(fn['parameters'])}")
return "\n".join(lines)


def _parse_tool_call(text: str) -> Optional[dict]:
"""Extract a tool_call JSON object from model response text."""
decoder = json.JSONDecoder()
for i, ch in enumerate(text):
if ch == '{':
try:
obj, _ = decoder.raw_decode(text, i)
if isinstance(obj, dict) and "tool_call" in obj:
return obj["tool_call"]
except (json.JSONDecodeError, ValueError):
pass
return None


def convert_to_openai_format(
    response_text: str,
    model: str,
    stream: bool = False,
    tool_call: Optional[dict] = None,
):
    """Wrap a Gemini response in an OpenAI chat-completion payload.

    Args:
        response_text: Raw text produced by Gemini.
        model: Model identifier echoed back in the payload.
        stream: When True, label the payload as a streaming chunk.
        tool_call: Parsed ``{"name": ..., "arguments": ...}`` dict when the
            model requested a tool invocation, otherwise ``None``.

    Returns:
        A dict matching the OpenAI chat-completion response schema. With a
        tool call, ``content`` is None and ``finish_reason`` is
        ``"tool_calls"``; otherwise the text is returned with ``"stop"``.
    """
    ts = int(time.time())
    if tool_call:
        args = tool_call.get("arguments", {})
        message = {
            "role": "assistant",
            # OpenAI clients expect content to be null on a tool-call turn.
            "content": None,
            "tool_calls": [{
                "id": f"call_{ts}",
                "type": "function",
                "function": {
                    "name": tool_call.get("name", ""),
                    # OpenAI transports arguments as a JSON string.
                    "arguments": json.dumps(args) if isinstance(args, dict) else args,
                },
            }],
        }
        finish_reason = "tool_calls"
    else:
        message = {"role": "assistant", "content": response_text}
        finish_reason = "stop"
    return {
        "id": f"chatcmpl-{ts}",
        "object": "chat.completion.chunk" if stream else "chat.completion",
        "created": ts,
        "model": model,
        "choices": [{
            "index": 0,
            "message": message,
            "finish_reason": finish_reason,
        }],
        # Token accounting is not available from the Gemini web API.
        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
    }


@router.get("/v1/models")
async def list_models():
    """Return the available Gemini models in OpenAI /v1/models list format."""
    # Imported lazily so the module loads even if gemini_webapi is absent.
    from gemini_webapi.constants import Model

    created = int(time.time())
    models = []
    for candidate in Model:
        if candidate == Model.UNSPECIFIED:
            continue
        models.append({
            "id": candidate.model_name,
            "object": "model",
            "created": created,
            "owned_by": "google",
        })
    return {
        "object": "list",
        "data": models,
        "usage": {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        },
    }


@router.post("/v1/chat/completions")
async def chat_completions(request: OpenAIChatRequest):
try:
Expand All @@ -61,34 +138,50 @@ async def chat_completions(request: OpenAIChatRequest):
if not request.messages:
raise HTTPException(status_code=400, detail="No messages provided.")

# Build conversation prompt with system prompt and full history
conversation_parts = []

# Inject tool definitions as a system prompt section
if request.tools:
tools_prompt = _build_tools_prompt(request.tools)
if tools_prompt:
conversation_parts.append(tools_prompt)

for msg in request.messages:
role = msg.get("role", "user")
content = msg.get("content", "")
if not content:
continue
content = msg.get("content") or ""

if role == "system":
conversation_parts.append(f"System: {content}")
elif role == "user":
conversation_parts.append(f"User: {content}")
elif role == "assistant":
conversation_parts.append(f"Assistant: {content}")
tool_calls = msg.get("tool_calls")
if tool_calls:
for tc in tool_calls:
fn = tc.get("function", {})
conversation_parts.append(
f"Assistant called tool {fn.get('name')}: {fn.get('arguments', '')}"
)
elif content:
conversation_parts.append(f"Assistant: {content}")
elif role == "tool":
tool_call_id = msg.get("tool_call_id", "")
conversation_parts.append(f"Tool result [{tool_call_id}]: {content}")

if not conversation_parts:
raise HTTPException(status_code=400, detail="No valid messages found.")

# Join all parts with newlines
final_prompt = "\n\n".join(conversation_parts)

if request.model:
try:
response = await gemini_client.generate_content(message=final_prompt, model=request.model.value, files=None)
return convert_to_openai_format(response.text, request.model.value, is_stream)
except Exception as e:
logger.error(f"Error in /v1/chat/completions endpoint: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Error processing chat completion: {str(e)}")
else:
if not request.model:
raise HTTPException(status_code=400, detail="Model not specified in the request.")

try:
response = await gemini_client.generate_content(message=final_prompt, model=request.model, files=None)
logger.debug(f"Gemini raw response: {response.text!r}")
tool_call = _parse_tool_call(response.text) if request.tools else None
logger.debug(f"Parsed tool_call: {tool_call}")
return convert_to_openai_format(response.text, request.model, is_stream, tool_call)
except Exception as e:
logger.error(f"Error in /v1/chat/completions endpoint: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Error processing chat completion: {str(e)}")
3 changes: 1 addition & 2 deletions src/app/endpoints/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@ async def gemini_generate(request: GeminiRequest):
raise HTTPException(status_code=503, detail=str(e))

try:
# Use the value attribute for the model (since GeminiRequest.model is an Enum)
files: Optional[List[Union[str, Path]]] = [Path(f) for f in request.files] if request.files else None
response = await gemini_client.generate_content(request.message, request.model.value, files=files)
response = await gemini_client.generate_content(request.message, request.model, files=files)
return {"response": response.text}
except Exception as e:
logger.error(f"Error in /gemini endpoint: {e}", exc_info=True)
Expand Down
Loading