116 changes: 113 additions & 3 deletions main.py
@@ -1,11 +1,13 @@
import inspect

from fastapi import FastAPI, Request, Depends, status, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.concurrency import run_in_threadpool


from starlette.responses import StreamingResponse, Response
from pydantic import BaseModel, ConfigDict
from typing import List, Union, Generator, Iterator
from typing import List, Union, Generator, Iterator, AsyncGenerator


from utils.pipelines.auth import bearer_security, get_current_user
@@ -134,7 +136,7 @@ async def load_module_from_path(module_name, module_path):

    try:
        # Read the module content
        with open(module_path, "r") as file:
        with open(module_path, "r", encoding="utf-8") as file:
            content = file.read()

        # Parse frontmatter
@@ -786,4 +788,112 @@ def stream_content():
                    ],
                }

    return await run_in_threadpool(job)
    # Asynchronous Job (for non-blocking pipelines)
    async def async_job():

        if form_data.stream:
            res = await pipe(
                user_message=user_message,
                model_id=pipeline_id,
                messages=messages,
                body=form_data.model_dump(),
            )
            logging.info(f"stream:true:{res}")

            async def stream_content():
                if isinstance(res, str):
                    message = stream_message_template(form_data.model, res)
                    logging.info(f"stream_content:str:{message}")
                    yield f"data: {json.dumps(message)}\n\n"

                elif inspect.isasyncgen(res) or isinstance(res, AsyncGenerator):
                    async for chunk in res:
                        if isinstance(chunk, BaseModel):
                            chunk = chunk.model_dump_json()
                            chunk = f"data: {chunk}"
                        elif isinstance(chunk, dict):
                            chunk = json.dumps(chunk)
                            chunk = f"data: {chunk}"

                        # Decode bytes chunks; str chunks have no .decode and are left as-is
                        try:
                            chunk = chunk.decode("utf-8")
                            logging.info(f"stream_content:AsyncGenerator:{chunk}")
                        except:
                            pass

                        if isinstance(chunk, str) and chunk.startswith("data:"):
                            yield f"{chunk}\n\n"
                        else:
                            chunk = stream_message_template(form_data.model, chunk)
                            yield f"data: {json.dumps(chunk)}\n\n"
                else:
                    logging.warning(f"Unhandled async response type: {type(res)}")


                finish_message = {
                    "id": f"{form_data.model}-{str(uuid.uuid4())}",
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": form_data.model,
                    "choices": [
                        {
                            "index": 0,
                            "delta": {},
                            "logprobs": None,
                            "finish_reason": "stop",
                        }
                    ],
                }
                yield f"data: {json.dumps(finish_message)}\n\n"
                yield "data: [DONE]"

            return StreamingResponse(stream_content(), media_type="text/event-stream")

        else:
            res = await pipe(
                user_message=user_message,
                model_id=pipeline_id,
                messages=messages,
                body=form_data.model_dump(),
            )
            logging.info(f"stream:false:{res}")

            if isinstance(res, dict):
                return res
            elif isinstance(res, BaseModel):
                return res.model_dump()
            else:
                message = ""
                if isinstance(res, str):
                    message = res
                elif inspect.isasyncgen(res) or isinstance(res, AsyncGenerator):
                    async for chunk in res:
                        message = f"{message}{chunk}"
                else:
                    logging.warning(f"Unhandled async response type: {type(res)}")

                logging.info(f"stream:false:{message}")
                return {
                    "id": f"{form_data.model}-{str(uuid.uuid4())}",
                    "object": "chat.completion",
                    "created": int(time.time()),
                    "model": form_data.model,
                    "choices": [
                        {
                            "index": 0,
                            "message": {
                                "role": "assistant",
                                "content": message,
                            },
                            "logprobs": None,
                            "finish_reason": "stop",
                        }
                    ],
                }

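    # Coroutine pipes run natively on the event loop via async_job(); regular function pipes keep the original thread-pool path.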
    if inspect.iscoroutinefunction(pipe):
        logging.info(f"Executing ASYNC job for pipeline: {form_data.model}")
        return await async_job()
    else:
        logging.info(f"Executing SYNC job in thread pool for pipeline: {form_data.model}")
        return await run_in_threadpool(job)
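
For reference, below is a minimal sketch of an async pipeline that this change enables. It assumes the usual pipelines convention of a Pipeline class whose pipe method receives user_message, model_id, messages, and body (the same keyword arguments the handler passes above); the class body, names, and return values here are illustrative only, not code from this PR. Because pipe is a coroutine function, inspect.iscoroutinefunction(pipe) routes the request to async_job(): returning a string (or dict/BaseModel) takes the non-streaming path, while returning an async generator takes the SSE streaming path.

# Hypothetical example pipeline (illustrative sketch, not part of this PR)
class Pipeline:
    def __init__(self):
        self.name = "Async Example Pipeline"

    async def pipe(self, user_message, model_id, messages, body):
        # Non-streaming request: return a plain string (a dict or BaseModel also works).
        if not body.get("stream", False):
            return f"echo: {user_message}"

        # Streaming request: return an async generator; the handler above wraps each
        # yielded chunk into an OpenAI-style "chat.completion.chunk" SSE event.
        async def generate():
            for token in user_message.split():
                yield token + " "

        return generate()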