-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
49 lines (42 loc) · 1.4 KB
/
train.py
File metadata and controls
49 lines (42 loc) · 1.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# train.py
import os
import pandas as pd
import joblib
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.feature_extraction.text import TfidfVectorizer
# Input: CSV file read by load_data(); it must provide "memo" and
# "subcategory" columns.
CSV_PATH = "training_data.csv"

# Outputs: joblib dump targets written by train() for the fitted classifier
# and the fitted TF-IDF vectorizer, respectively.
MODEL_PATH = "model.pkl"
VECT_PATH = "vectorizer.pkl"

# Subcategory labels accepted as training targets. load_data() discards any
# row whose "subcategory" value is not in this list.
ALLOWED_SUBS = [
    "쇼핑",
    "보험ㆍ대출기타ㆍ금융",
    "식비",
    "이체",
    "교통",
    "의료ㆍ건강ㆍ피트니스",
    "주거ㆍ통신",
    "생활",
    "카페ㆍ간식",
    "수입",
    "미분류",
]
def load_data():
    """Read the training CSV and keep only the rows usable for training.

    Rows with a missing ``memo`` or ``subcategory`` value are dropped, and
    of the remainder only rows whose ``subcategory`` appears in
    ``ALLOWED_SUBS`` are kept.

    Returns:
        pandas.DataFrame: The filtered contents of ``CSV_PATH``.

    Raises:
        FileNotFoundError: If ``CSV_PATH`` does not exist.
    """
    if not os.path.exists(CSV_PATH):
        raise FileNotFoundError(f"{CSV_PATH} 가 없습니다.")

    raw = pd.read_csv(CSV_PATH)
    # Both the text and its label must be present for a row to be usable.
    complete = raw.dropna(subset=["memo", "subcategory"])
    # Discard rows labelled with a subcategory outside the allowed set.
    has_known_label = complete["subcategory"].isin(ALLOWED_SUBS)
    return complete[has_known_label]
def train():
    """Fit the memo -> subcategory classifier and save it to disk.

    Loads the filtered training rows via ``load_data()``, converts each memo
    into TF-IDF features, fits a logistic-regression classifier on them, and
    writes the classifier to ``MODEL_PATH`` and the fitted vectorizer to
    ``VECT_PATH`` with joblib.

    Raises:
        FileNotFoundError: Propagated from ``load_data()`` when the training
            CSV is missing.
        ValueError: If no usable rows remain after filtering, or the
            remaining rows cover fewer than two distinct subcategories.
    """
    df = load_data()

    # Fail early with an explicit message instead of letting the vectorizer
    # or the classifier reject an empty / single-label dataset further down.
    if df.empty:
        raise ValueError(f"{CSV_PATH} 에 학습할 수 있는 데이터가 없습니다.")
    if df["subcategory"].nunique() < 2:
        raise ValueError("학습에는 서로 다른 subcategory 가 2개 이상 필요합니다.")

    # Coerce memos to str so that cells pandas parsed as numbers do not break
    # the vectorizer's text preprocessing; a no-op for memos that are
    # already strings.
    memos = df["memo"].astype(str)

    # Character n-grams (2-5 characters, built within word boundaries),
    # chosen for short Korean text. Only the char_wb analyzer is configured;
    # no separate word-level n-grams are added.
    vectorizer = TfidfVectorizer(
        analyzer="char_wb", ngram_range=(2, 5), min_df=1
    )
    X = vectorizer.fit_transform(memos)
    y = df["subcategory"].values

    # Logistic regression with class-imbalance correction
    # (class_weight="balanced"). multi_class="ovr" requests one-vs-rest
    # rather than a single multinomial model.
    # NOTE(review): newer scikit-learn releases deprecate the `multi_class`
    # argument -- confirm against the installed version before upgrading.
    clf = LogisticRegression(
        max_iter=200,
        n_jobs=None,
        class_weight="balanced",
        multi_class="ovr"
    )
    clf.fit(X, y)

    # The classifier and the vectorizer are stored as two separate files.
    joblib.dump(clf, MODEL_PATH)
    joblib.dump(vectorizer, VECT_PATH)
    print("모델 저장:", MODEL_PATH, VECT_PATH)
# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    train()