-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapi_service.py
More file actions
119 lines (100 loc) · 3.75 KB
/
api_service.py
File metadata and controls
119 lines (100 loc) · 3.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field, field_validator
from typing import List, Dict, Union, Optional, Any
import logging
import uvicorn
import numpy as np
import re
from src.model_loader import ModelLoader
from src.data_validator import DataValidator
from src.config import API_PORT, API_HOST, REQUIRED_NUMERIC_FEATURES, CATEGORICAL_FEATURES
from src.utils import setup_logging
# Apply the project's logging configuration (src.utils.setup_logging) at import
# time, then obtain this module's logger.
setup_logging()
logger = logging.getLogger(name=__name__)

# Application object served by start_server(); endpoints attach via decorators.
app = FastAPI(title="GBM Model API")
@app.get("/health")
async def health_check():
    """Report that the API process is up and answering requests.

    The payload is constant; it does not exercise the model or the validator.
    """
    status = "healthy"
    return {"status": status}
# Module-level instances, constructed once at import time and reused by every
# /predict request.
model = ModelLoader()
data_validator = DataValidator()
# Translation table that deletes currency symbols, thousands separators and
# percent signs from numeric strings.
_STRIP_NUMERIC_CHARS = str.maketrans("", "", "$,%")


def clean_numeric_value(value: Any) -> float:
    """Coerce a raw input value to ``float``.

    Strings have every ``$``, ``,`` and ``%`` removed before parsing, so
    ``"$1,234.50"`` -> ``1234.5``, ``"-$5"`` -> ``-5.0`` and ``"45%"`` -> ``45.0``
    (percentages are NOT divided by 100). All other values go through
    ``float()``, which accepts ``int``/``float``/``bool`` as well as ``Decimal``
    and NumPy scalars such as ``numpy.int64``.

    Args:
        value: The raw value, e.g. straight from a JSON request body.

    Returns:
        The parsed number, or ``0.0`` when ``value`` is ``None`` or cannot be
        converted (unparseable string, container, or an int too large for a
        float).
    """
    if value is None:
        return 0.0
    if isinstance(value, str):
        # Deleting '$' already turns "-$5" into "-5", so negative currency
        # amounts need no separate rule.
        value = value.translate(_STRIP_NUMERIC_CHARS)
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        # Deliberate best-effort policy: unconvertible input becomes 0.0
        # instead of raising.
        return 0.0
class PredictionInput(BaseModel):
    """A single prediction instance from the request body.

    Six model features are declared explicitly; any other keys in the payload
    are retained as extra attributes (``extra="allow"``). Every field has a
    default, so absent features become ``0.0`` (numeric) or ``""``
    (categorical).

    Before type validation, the declared categorical fields are coerced to
    ``str`` (``None`` -> ``""``) and every other declared field is passed
    through ``clean_numeric_value``, so inputs such as ``"$1,200"`` or ``"12%"``
    are accepted and unparseable values become ``0.0``.
    """

    # pydantic v2 model configuration; replaces the deprecated inner
    # ``class Config`` while keeping extra fields allowed.
    model_config = {"extra": "allow"}

    # Numeric features used by the model (default 0.0 when absent)
    x_39: float = Field(default=0.0)
    x_49: float = Field(default=0.0)
    x_61: float = Field(default=0.0)
    x_75: float = Field(default=0.0)
    # Categorical features used by the model (default "" when absent)
    x_9: str = Field(default="")  # Day of week
    x_65: str = Field(default="")  # Car make

    @field_validator('*', mode='before')
    @classmethod
    def clean_fields(cls, v: Any, info: Any) -> Any:
        """Normalise a raw declared-field value before pydantic type checks.

        Args:
            v: The raw incoming value.
            info: pydantic validation info; only ``info.field_name`` is read.

        Returns:
            A ``str`` for the categorical fields, otherwise a ``float``.
        """
        if info.field_name in ('x_9', 'x_65'):
            return str(v) if v is not None else ""
        return clean_numeric_value(v)
# Positive-class probability at or above which "business_outcome" is 1.
BUSINESS_OUTCOME_THRESHOLD = 0.75

# Model features echoed back to the caller under "feature_inputs".
RESPONSE_FEATURES = ('x_39', 'x_49', 'x_61', 'x_75', 'x_9', 'x_65')


@app.post("/predict")
async def predict(data: List[PredictionInput]):
    """
    Endpoint for model predictions.
    Accepts a JSON array containing one or more instances to score.
    """
    # NOTE(review): model.predict_proba runs synchronously inside this
    # ``async def`` and so blocks the event loop while it executes; consider a
    # plain ``def`` endpoint (FastAPI runs those in a threadpool) if that matters.
    if not data:
        # Nothing to score; don't hand an empty batch to the model.
        return []
    try:
        # pydantic v2 serialisation (``.dict()`` is a deprecated alias); extra
        # fields allowed on PredictionInput are included.
        features = [item.model_dump() for item in data]
        for instance in features:
            data_validator.validate_features(instance)

        probabilities = model.predict_proba(features)

        results = []
        for i, proba in enumerate(probabilities):
            instance = features[i]
            # Column 1 is used as the positive-class probability.
            positive = float(proba[1])
            results.append({
                "business_outcome": 1 if positive >= BUSINESS_OUTCOME_THRESHOLD else 0,
                "prediction": positive,
                "feature_inputs": {name: instance[name] for name in RESPONSE_FEATURES},
            })
        return results
    except ValueError as e:
        # ValueErrors (validator, model, or formatting) are reported as client errors.
        logger.error("Validation error: %s", e)
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # Log the full traceback server-side, but don't leak internal error
        # text to the client.
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail="Internal prediction error") from e
def start_server():
    """Serve ``app`` with uvicorn on the configured API_HOST / API_PORT."""
    bind_host, bind_port = API_HOST, API_PORT
    uvicorn.run(app, host=bind_host, port=bind_port)


if __name__ == "__main__":
    # Start the server only when run as a script, not when imported.
    start_server()