-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
63 lines (49 loc) · 1.81 KB
/
main.py
File metadata and controls
63 lines (49 loc) · 1.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from PIL import Image
import io
import cv2
import numpy as np
import yolov5
from offside_detector import OffsideDetector, draw_definitive_results
# Application instance that the route decorators below register on.
app = FastAPI(title="Offside Detection API")
# CORS configuration: every origin and every request header is accepted,
# while POST is the only method listed for cross-origin requests.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["POST"],
allow_headers=["*"],
)
# Load the model once at import time so that every request reuses the same
# instance instead of reloading it per call.
# NOTE(review): the identifier looks like a hosted model name, so the first
# load presumably needs network access -- confirm.
model = yolov5.load('keremberke/yolov5m-football')
# NOTE(review): presumably the minimum confidence for a detection to be kept
# -- confirm against the yolov5 package documentation.
model.conf = 0.25
# NOTE(review): presumably the IoU threshold used for non-maximum
# suppression -- confirm.
model.iou = 0.45
# NOTE(review): presumably lets one box carry more than one class label
# -- confirm.
model.multi_label = True
# Shared detector instance; the endpoint passes it to both the offside
# detection call and the drawing helper.
detector = OffsideDetector()
def read_image_as_cv2(file_bytes: bytes) -> np.ndarray:
    """Decode encoded image bytes into an array in OpenCV's BGR channel order.

    Args:
        file_bytes: Raw bytes of an encoded image file.

    Returns:
        The decoded pixels as a NumPy array, converted from RGB to BGR.
    """
    byte_stream = io.BytesIO(file_bytes)
    # Force RGB mode first so the colour conversion below always receives
    # three-channel RGB input, whatever mode the source image was stored in.
    rgb_image = Image.open(byte_stream).convert("RGB")
    rgb_pixels = np.array(rgb_image)
    return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)
@app.get("/")
async def root():
    """Return a short status payload showing that the API is reachable."""
    status = {"message": "⚽ Offside Detection API is running!"}
    return status
@app.post("/detect-offside")
async def process_image(file: UploadFile = File(...)):
    """Run offside detection on an uploaded image and return it annotated.

    The upload is decoded, passed through the module-level ``model``, analysed
    by the module-level ``detector``, drawn on by ``draw_definitive_results``
    and finally encoded as JPEG.

    Args:
        file: The uploaded image file.

    Returns:
        A ``StreamingResponse`` with media type ``image/jpeg`` containing the
        annotated image.

    Raises:
        HTTPException: With status 400 when the upload is empty or cannot be
            decoded as an image; with status 500 when reading the upload,
            detection, annotation or JPEG encoding fails.
    """
    try:
        contents = await file.read()
        if not contents:
            # An empty body is a client mistake, not a server failure.
            raise HTTPException(400, "Invalid image: uploaded file is empty")

        try:
            image = read_image_as_cv2(contents)
        except (OSError, ValueError) as e:
            # Pillow signals unreadable or corrupt image data with OSError
            # subclasses (or ValueError); report that as a bad request.
            raise HTTPException(400, f"Invalid image: {e}") from e

        # NOTE(review): inference runs inline on the event loop. It is left
        # that way on purpose: `detector` is shared between requests and may
        # hold per-call state, so moving this to a threadpool is not
        # obviously safe -- confirm before changing.
        results = model(image)
        detections = results.pandas().xyxy[0]
        detections['class_name'] = detections['class'].map(lambda x: model.names[int(x)])
        offside_players = detector.detect_offside(detections, image)
        annotated = draw_definitive_results(image, detections, offside_players, detector)

        # imencode reports failure through its first return value; do not
        # send an unusable buffer back to the client.
        encoded_ok, buffer = cv2.imencode(".jpg", annotated)
        if not encoded_ok:
            raise RuntimeError("could not encode annotated image as JPEG")
    except HTTPException:
        # Keep the deliberate 4xx responses raised above out of the generic
        # handler below, which would otherwise rewrap them as 500.
        raise
    except Exception as e:
        raise HTTPException(500, f"Processing error: {str(e)}") from e

    # A freshly constructed BytesIO already starts at position 0.
    image_stream = io.BytesIO(buffer.tobytes())
    return StreamingResponse(image_stream, media_type="image/jpeg")
# Local dev runner: when this file is executed directly as a script, serve
# the app with uvicorn on host 0.0.0.0, port 8000.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )