-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
114 lines (90 loc) · 2.91 KB
/
Copy pathmain.py
File metadata and controls
114 lines (90 loc) · 2.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging
import os
import torch
from pydantic import BaseModel
# Initialize logging
# Root configuration: INFO threshold, with timestamp, logger name and level
# ahead of each message.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger shared by the startup code and the request handlers.
logger = logging.getLogger(__name__)
# Load fine-tuned model
MODEL_PATH = "./fine_tuned_model"
MODEL_NAME = "gpt2"

# Prefer the local fine-tuned checkpoint and fall back to MODEL_NAME when the
# directory is absent or cannot be loaded. Only the local load is guarded, so
# a failing fallback load is neither retried nor reported as a "local" failure.
tokenizer = None
model = None
if os.path.exists(MODEL_PATH):
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, local_files_only=True)
        model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, local_files_only=True)
    except Exception as e:
        logger.warning(f"Failed to load local model: {str(e)}, loading from remote")
        # Discard a half-loaded pair so tokenizer and model always come from
        # the same source.
        tokenizer = None
        model = None

if model is None:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
app = FastAPI()

# Enable CORS
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable and default to the wildcard when it is unset or empty.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # A wildcard origin must not be combined with credentialed requests, so
    # credentials are only enabled once explicit origins are configured.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
class ChatRequest(BaseModel):
    # Request body accepted by the /chat endpoint.

    # Text the model is asked to continue.
    prompt: str
@app.get("/health")
async def health_check():
    # Constant payload: no model or dependency state is inspected here.
    payload = dict(status="healthy")
    return payload
@app.post("/chat")
async def chat(request: ChatRequest):
    """Generate a text completion for the prompt in the request body.

    Returns:
        A dict ``{"response": text}`` where ``text`` is the first generated
        sequence decoded with special tokens skipped.

    Raises:
        HTTPException: 400 when the prompt is empty; 500 when tokenization,
            generation, or decoding fails.
    """
    # Log incoming request
    logger.info("Received chat request")

    # Validate before entering the try block: the broad handler below would
    # otherwise catch this HTTPException and turn the intended 400 into a 500.
    prompt = request.prompt
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt is required")

    try:
        # Generate response
        inputs = tokenizer.encode(prompt, return_tensors="pt")
        # NOTE(review): per the transformers docs, max_length bounds prompt
        # plus generated tokens, so a very long prompt may leave no room to
        # generate. Consider max_new_tokens if that matters.
        outputs = model.generate(
            inputs,
            max_length=256,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
        )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # logger.exception records the traceback; the client only ever sees a
        # generic message.
        logger.exception("Error processing request: %s", e)
        raise HTTPException(status_code=500, detail="Internal server error") from e

    # Log response
    logger.info(f"Generated response: {response}")
    return {"response": response}
@app.get("/models")
async def get_models():
    """Get information about available models."""
    # Single hard-coded catalogue entry; nothing is looked up at request time.
    conv_ai_entry = dict(
        name="harpertokenConvAI",
        type="causal_lm",
        description="Conversational AI model for text generation",
    )
    return {"models": [conv_ai_entry]}
@app.get("/status")
async def get_status():
    """Get API and model status."""
    # Fixed values: the handler reports constants and does not probe the model.
    report = dict(
        status="running",
        model_loaded=True,
        model_name="harpertokenConvAI",
    )
    return report
if __name__ == "__main__":
    # Imported here so uvicorn is only required when run as a script.
    import uvicorn

    # Serve on every interface, port 8000.
    bind_host = "0.0.0.0"
    bind_port = 8000
    uvicorn.run(app, host=bind_host, port=bind_port)