-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
73 lines (63 loc) · 2.13 KB
/
main.py
File metadata and controls
73 lines (63 loc) · 2.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import asyncio
from contextlib import asynccontextmanager
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.gzip import GZipMiddleware
import torch
from services import EmbeddingService, RerankerService
from controllers import EmbeddingController, RerankController
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
# Load environment variables (python-dotenv) before any configuration is read.
load_dotenv()

# TCP port the server binds to: the PORT environment variable, else 8000.
port = int(os.environ.get("PORT", "8000"))

# Process-wide service instances, shared by the lifespan hook and controllers.
embeddingService = EmbeddingService()
rerankerService = RerankerService()

# Controllers wrap a service each; their `router` attributes are mounted on
# the application further down in this module.
embeddingController = EmbeddingController(embeddingService)
crossEncoderRerankerController = RerankController(rerankerService)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup and shutdown work around the application's lifetime.

    Startup:
        * runs ``embeddingService.LoadModel`` in a worker thread,
        * awaits ``rerankerService.LoadModel()``,
        * starts the batcher of each service,
        * when CUDA is available and ``EmbeddingService.keepaliveTask`` is
          unset, schedules ``embeddingService.GpuKeepAlive()`` on the running
          loop and stores the resulting task on the class.

    Shutdown:
        * cancels each batcher's ``task`` if one exists.

    Args:
        app: The FastAPI application this lifespan is attached to. It is not
            used by the body.

    Yields:
        None. Control returns to the server while the application is running.
    """
    import logging  # function-scope import keeps this block self-contained

    await asyncio.to_thread(embeddingService.LoadModel)
    await rerankerService.LoadModel()
    await embeddingService.batcher.start()
    await rerankerService.batcher.start()

    try:
        # The task is stored on the class (not the instance), so the guard
        # prevents scheduling a second keep-alive if one is already set.
        if (
            torch.cuda.is_available()
            and getattr(EmbeddingService, "keepaliveTask", None) is None
        ):
            running_loop = asyncio.get_running_loop()
            EmbeddingService.keepaliveTask = running_loop.create_task(
                embeddingService.GpuKeepAlive()
            )
    except Exception:
        # The keep-alive is best-effort: startup continues without it, but the
        # failure is recorded instead of being silently discarded.
        logging.getLogger(__name__).exception(
            "Failed to start the GPU keep-alive task; continuing without it"
        )

    try:
        yield
    finally:
        # Stop the background batching loops on shutdown.
        for batcher in (embeddingService.batcher, rerankerService.batcher):
            if batcher.task is not None:
                batcher.task.cancel()
# ASGI application; model loading and batcher start/stop are handled by the
# `lifespan` context manager.
app = FastAPI(lifespan=lifespan)

# NOTE(review): wildcard origins together with allow_credentials=True is
# discouraged by the FastAPI CORS documentation — confirm this is intended
# before exposing the service to browsers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Compress response bodies of at least 800 bytes.
app.add_middleware(GZipMiddleware, minimum_size=800)

# Mount both controllers under the versioned API prefix, embedding first.
for _controller in (embeddingController, crossEncoderRerankerController):
    app.include_router(_controller.router, prefix="/api/v1")
if __name__ == "__main__":
    # Launch the server only when this file is executed as a script.
    # A single worker is used; the event loop and HTTP parser are pinned to
    # uvloop and httptools, and idle keep-alive connections last 300 seconds.
    server_options = {
        "host": "0.0.0.0",
        "port": port,
        "workers": 1,
        "loop": "uvloop",
        "http": "httptools",
        "timeout_keep_alive": 300,
    }
    uvicorn.run("main:app", **server_options)