-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathml_server.py
More file actions
179 lines (154 loc) · 5.92 KB
/
ml_server.py
File metadata and controls
179 lines (154 loc) · 5.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
# ml_server.py
import os
import csv
import joblib
import pandas as pd
from typing import Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
MODEL_PATH = "model.pkl"
VECT_PATH = "vectorizer.pkl"
CSV_PATH = "training_data.csv"
ALLOWED_SUBS = [
"쇼핑","보험ㆍ대출기타ㆍ금융","식비","이체","교통",
"의료ㆍ건강ㆍ피트니스","주거ㆍ통신","생활","카페ㆍ간식","수입","미분류"
]
# ---------- 유틸 ----------
def ensure_csv():
if os.path.exists(CSV_PATH):
return
# 샘플 몇 개만 넣어 초기화 (없어도 동작)
rows = [
["스타벅스 커피 | 스타벅스 | 강남점", "카페ㆍ간식"],
["이디야 음료 | 이디야 | 역삼점", "카페ㆍ간식"],
["무신사 쇼핑 | 무신사 | 온라인", "쇼핑"],
["쿠팡 쇼핑 | 쿠팡 | 온라인", "쇼핑"],
["GS25 편의점 | GS25 | 선릉점", "식비"],
["세브란스병원 | 진료 | 외래", "의료ㆍ건강ㆍ피트니스"],
["버스 결제 | 서울버스 | 교통카드", "교통"],
["관리비 자동이체 | 관리사무소 | 아파트", "주거ㆍ통신"],
]
with open(CSV_PATH, "w", newline="", encoding="utf-8") as f:
w = csv.writer(f)
w.writerow(["memo","subcategory"])
w.writerows(rows)
print("✅ training_data.csv 초기화 완료")
def extract_center_memo(memo: str) -> str:
if memo is None:
return ""
parts = [p.strip() for p in memo.split("|")]
if len(parts) >= 2:
return parts[1]
return memo.strip()
def hard_rules(memo: str, amount: Optional[float]):
"""
하드룰:
- 메모에 '진주종합금융센터' 포함 → 미분류
- 금액이 양수 → 수입
"""
center = extract_center_memo(memo)
if "진주종합금융센터" in memo or "진주종합금융센터" in center:
return "미분류", 0.99, "하드룰: 진주종합금융센터 → 미분류"
if amount is not None and amount > 0:
return "수입", 0.99, "하드룰: 금액 양수 → 수입"
return None
# ---------- 모델 로딩 ----------
def load_model():
if not (os.path.exists(MODEL_PATH) and os.path.exists(VECT_PATH)):
raise FileNotFoundError(
"모델 파일이 없습니다. 먼저 `python train.py` 로 학습하세요."
)
clf = joblib.load(MODEL_PATH)
vect = joblib.load(VECT_PATH)
return clf, vect
ensure_csv()
clf, vect = None, None
try:
clf, vect = load_model()
print("✅ 모델 로드 완료")
except Exception as e:
print("⚠️ 모델 로드 실패:", e)
# ---------- FastAPI ----------
app = FastAPI(title="KB532 ML Server", version="1.0.0")
class PredictReq(BaseModel):
memo: str
amount: Optional[float] = None # 금액(원) - 양수면 수입
# 필요시 userId, txId 등 추가 가능
class PredictRes(BaseModel):
subcategory: str
confidence: float
rationale: str
center_memo: str
class FeedbackReq(BaseModel):
memo: str
chosen_subcategory: str # 유저가 선택한 소분류
@app.post("/predict", response_model=PredictRes)
def predict(req: PredictReq):
global clf, vect
if clf is None or vect is None:
raise HTTPException(503, "모델이 없습니다. /retrain 으로 학습 후 사용하세요.")
# 하드 룰 우선 적용
ruled = hard_rules(req.memo, req.amount)
center = extract_center_memo(req.memo)
if ruled is not None:
sub, conf, why = ruled
# ‘기타지출’은 금지 → 미분류 강제
if sub == "기타지출":
sub = "미분류"
why += " (+기타지출 금지 규칙)"
return PredictRes(subcategory=sub, confidence=conf, rationale=why, center_memo=center)
# ML 예측
X = vect.transform([center])
probs = getattr(clf, "predict_proba")(X)[0]
labels = clf.classes_
best_idx = probs.argmax()
sub = labels[best_idx]
conf = float(probs[best_idx])
# ‘기타지출’은 미분류로 치환
if sub == "기타지출":
sub = "미분류"
return PredictRes(
subcategory=sub,
confidence=conf,
rationale="ML 예측 (center memo 기반 TF-IDF + LogisticRegression)",
center_memo=center
)
@app.post("/feedback")
def feedback(req: FeedbackReq):
"""
유저가 ‘미분류’를 직접 수정했을 때 호출.
CSV에 누적 → 추후 /retrain 으로 반영.
"""
sub = req.chosen_subcategory.strip()
if sub not in ALLOWED_SUBS:
raise HTTPException(400, f"허용되지 않은 소분류입니다: {sub}")
with open(CSV_PATH, "a", newline="", encoding="utf-8") as f:
w = csv.writer(f)
w.writerow([req.memo, sub])
return {"ok": True, "message": "피드백 저장 완료. 충분히 쌓이면 /retrain 실행하세요."}
@app.post("/retrain")
def retrain():
"""
누적된 training_data.csv 기반으로 즉시 재학습.
"""
global clf, vect
try:
df = pd.read_csv(CSV_PATH).dropna(subset=["memo","subcategory"])
df = df[df["subcategory"].isin(ALLOWED_SUBS)]
if len(df) < 5:
raise HTTPException(400, "학습 데이터가 너무 적습니다. 최소 5건 이상 필요.")
# 벡터라이저/모델 재학습 (train.py와 동일 설정)
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
vect = TfidfVectorizer(analyzer="char_wb", ngram_range=(2,5), min_df=1)
X = vect.fit_transform(df["memo"])
y = df["subcategory"].values
clf = LogisticRegression(max_iter=200, class_weight="balanced", multi_class="ovr")
clf.fit(X, y)
joblib.dump(clf, MODEL_PATH)
joblib.dump(vect, VECT_PATH)
return {"ok": True, "message": "재학습 완료"}
except HTTPException:
raise
except Exception as e:
raise HTTPException(500, f"재학습 실패: {e}")