-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_classifier.py
More file actions
38 lines (31 loc) · 1.45 KB
/
train_classifier.py
File metadata and controls
38 lines (31 loc) · 1.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import argparse
import pandas as pd
from utils.pattern_classifier import PatternClassifier
def main(argv=None):
    """Train the Cup & Handle pattern classifier from a labelled report CSV.

    Reads the CSV passed via ``--report``, verifies it has a ``valid`` label
    column with enough rows of each class, and then hands the DataFrame to
    ``PatternClassifier.train`` using ``--out`` as the model path.

    Args:
        argv: Optional list of command-line arguments, without the program
            name. When ``None`` (the default) argparse reads
            ``sys.argv[1:]``, so calling ``main()`` with no arguments behaves
            exactly like the plain command-line invocation.

    Returns:
        None. If either class has fewer than ``--min-samples`` rows, a
        warning is printed and the function returns without training.

    Raises:
        ValueError: If the report has no ``valid`` column.
    """
    parser = argparse.ArgumentParser(description="Train Cup & Handle ML Classifier")
    parser.add_argument("--report", type=str, required=True, help="Path to report.csv")
    parser.add_argument(
        "--out",
        type=str,
        default="models/cup_handle_model.pkl",
        help="Output model file",
    )
    parser.add_argument(
        "--min-samples",
        type=int,
        default=5,
        help="Minimum number of rows required for each class (default: 5)",
    )
    args = parser.parse_args(argv)

    # -------------------------------
    # Step 1: Load dataset
    # -------------------------------
    df = pd.read_csv(args.report)

    if "valid" not in df.columns:
        raise ValueError("❌ 'valid' column not found in report.csv. Please run main.py first.")

    # -------------------------------
    # Step 2: Check class balance
    # -------------------------------
    # dropna=False keeps missing labels visible in the printed distribution.
    class_counts = df["valid"].value_counts(dropna=False).to_dict()
    print("📊 Class distribution:", class_counts)

    valid_count = class_counts.get(True, 0)
    invalid_count = class_counts.get(False, 0)
    if min(valid_count, invalid_count) < args.min_samples:
        print("⚠️ Warning: Extremely imbalanced dataset (too few Valid/Invalid samples).")
        print("👉 Consider generating more patterns with main.py before training.")
        # Exit early to avoid training on bad data
        return

    # -------------------------------
    # Step 3: Train + Save Model
    # -------------------------------
    clf = PatternClassifier(model_path=args.out)
    # Per the original author's note, train() also saves the model.
    clf.train(df)
# Run the training CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()