-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
77 lines (69 loc) · 2.86 KB
/
main.py
File metadata and controls
77 lines (69 loc) · 2.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import json
import os
import shlex
import shutil
import subprocess
from typing import Optional

from fastapi import FastAPI, Request, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
app = FastAPI()

# Serve static assets (including index.html) under /static. The directory
# defaults to the current working directory, as before; set GAIL_STATIC_DIR
# to serve a dedicated assets folder instead.
# NOTE(review): serving "." publishes every file in the working directory to
# any client -- e.g. gail_model.R, which predict() loads from os.getcwd(),
# and presumably this module's own source. Prefer a directory that holds
# only public files.
STATIC_DIR = os.getenv('GAIL_STATIC_DIR', '.')
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
class GailInput(BaseModel):
    """Request body for POST /predict.

    Each field is forwarded to gail_model.R as a ``--<field_name> <value>``
    command-line flag. The ``Optional`` fields are left off the command line
    when they are ``None``.
    """

    # Required; every other field has a default.
    age: float
    # Integer race/ethnicity code. The code-to-group mapping lives in
    # gail_model.R, not here -- TODO confirm the valid values there.
    race: int = 1
    # Optional[...] rather than a bare ``float = None``: pydantic v2 does not
    # make a field nullable just because its default is None, so an explicit
    # JSON null from the client would otherwise be rejected with a 422.
    age_at_menarche: Optional[float] = None
    age_at_first_live_birth: Optional[float] = None
    # Presumably the number of prior breast biopsies -- confirm in gail_model.R.
    n_biopsy: int = 0
    # Presumably the number of affected first-degree relatives -- confirm.
    n_relatives: int = 0
    # Presumably a coded atypical-hyperplasia indicator -- confirm allowed codes.
    hyperplasia: int = 0
    age_at_biopsy: Optional[float] = None
    # Risk projection horizon handed to the model; presumably years -- confirm.
    projection_time: float = 5
@app.get("/", response_class=HTMLResponse)
async def index():
    # Hand back the single-page front end unchanged. The path is relative,
    # so it resolves against the process's current working directory: start
    # the server from the folder that contains index.html.
    page_path = 'index.html'
    return FileResponse(page_path)
# Payload fields forwarded to gail_model.R, in the order the flags have always
# been emitted. Each becomes "--<field> <value>"; the flag names are exactly
# the GailInput field names.
_CLI_FIELDS = (
    'age',
    'race',
    'age_at_menarche',
    'age_at_first_live_birth',
    'n_biopsy',
    'n_relatives',
    'hyperplasia',
    'age_at_biopsy',
    'projection_time',
)


def _build_cli_args(payload):
    """Translate *payload* into the argument tail for gail_model.R.

    Fields whose value is ``None`` are skipped instead of being passed as
    the literal string ``"None"``.
    """
    args = []
    for field in _CLI_FIELDS:
        value = getattr(payload, field)
        if value is not None:
            args += [f'--{field}', str(value)]
    return args


def _model_command(args):
    """Return the full argv that runs the model with *args*.

    Prefers ``docker run`` of the image named by GAIL_MODEL_IMAGE (default
    ``gail_model_container``). Falls back to a local ``Rscript`` running
    ``gail_model.R`` from the current working directory.

    Raises:
        HTTPException: 500 when neither docker nor Rscript is on PATH.
    """
    if shutil.which('docker'):
        image = os.getenv('GAIL_MODEL_IMAGE', 'gail_model_container')
        return ['docker', 'run', '--rm', image] + args
    rscript = shutil.which('Rscript')
    if not rscript:
        raise HTTPException(status_code=500, detail='Neither docker nor Rscript found on PATH. Install Docker or R to run the computation.')
    return [rscript, os.path.join(os.getcwd(), 'gail_model.R')] + args


def _run_model(cmd):
    """Run *cmd* to completion and return its stripped stdout.

    This blocks the calling thread, so async code must call it through a
    thread pool. An optional timeout in seconds is read from
    GAIL_MODEL_TIMEOUT. When it is unset, the call waits indefinitely, as it
    always has (``docker run`` may need to pull a missing image first).

    Raises:
        HTTPException: 500 with the captured stdout/stderr when the command
            exits non-zero, 500 when it cannot be launched at all, and 504
            when the configured timeout is exceeded.
    """
    raw_timeout = os.getenv('GAIL_MODEL_TIMEOUT')
    timeout = float(raw_timeout) if raw_timeout else None
    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, check=True, timeout=timeout)
    except subprocess.CalledProcessError as err:
        raise HTTPException(status_code=500, detail={'stdout': err.stdout, 'stderr': err.stderr}) from err
    except subprocess.TimeoutExpired as err:
        raise HTTPException(status_code=504, detail=f'Model did not finish within {timeout} seconds.') from err
    except OSError as err:
        # which() located the executable, but it could not be started
        # (removed in the meantime, not executable, ...).
        raise HTTPException(status_code=500, detail=f'Failed to start model command: {err}') from err
    return proc.stdout.strip()


@app.post("/predict")
async def predict(payload: GailInput):
    """Run the Gail model for *payload* and return the JSON it prints.

    Raises:
        HTTPException: 500 when no runner is available, when the model exits
            non-zero or cannot start, or when its stdout is not valid JSON;
            504 when GAIL_MODEL_TIMEOUT is set and exceeded.
    """
    cmd = _model_command(_build_cli_args(payload))
    # subprocess.run blocks. Calling it directly inside this coroutine would
    # freeze the event loop (and every other request) for the whole model run,
    # so it is offloaded to the thread pool and awaited.
    stdout = await run_in_threadpool(_run_model, cmd)
    try:
        data = json.loads(stdout)
    except json.JSONDecodeError as err:
        raise HTTPException(status_code=500, detail={'error': 'Failed to parse JSON from R script', 'raw_output': stdout}) from err
    return JSONResponse(content=data)