-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrainer_cli.py
More file actions
151 lines (117 loc) · 4.75 KB
/
trainer_cli.py
File metadata and controls
151 lines (117 loc) · 4.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#!/usr/bin/env python3
"""
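GTO Poker Trainer - Command Line Interface

Usage:
    python trainer_cli.py train [--quick] [--iterations N] [--device cpu|cuda] [--resume CHECKPOINT]
    python trainer_cli.py evaluate <checkpoint> [--device DEVICE]
    python trainer_cli.py analyze <checkpoint> [--device DEVICE]
    python trainer_cli.py test [--iterations N] [--device DEVICE]

Run `python trainer_cli.py <command> --help` to list every option of a command.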
"""
import argparse
import sys
import os
# Put this script's directory on sys.path so the local `src` package below can be imported.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.trainer import GTOTrainer, TrainingConfig, train_gto
from src.evaluation import analyze_preflop_strategy
def cmd_train(args):
    """Handle the ``train`` subcommand.

    Builds a TrainingConfig from the parsed CLI options, optionally restores
    the checkpoint named by ``--resume``, runs training, and finally prints
    the trainer's strategy analysis.

    Args:
        args: argparse namespace produced by the ``train`` subparser.
    """
    banner = "=" * 60
    for line in (banner, "GTO POKER TRAINER", banner, ""):
        print(line)

    trainer = GTOTrainer(
        TrainingConfig(
            quick_mode=args.quick,
            device=args.device,
            checkpoint_dir=args.checkpoint_dir,
            print_every=args.print_every,
            save_every=args.save_every,
            eval_every=args.eval_every,
        )
    )

    resume_path = args.resume
    if resume_path:
        # Continue from a previously saved checkpoint instead of starting fresh.
        trainer.load_checkpoint(resume_path)

    # ``--iterations`` defaults to None; per its help text it only *overrides*
    # the trainer's own iteration count when given.
    trainer.train(num_iterations=args.iterations)

    # Print final analysis
    print("\n" + banner)
    print("FINAL STRATEGY ANALYSIS")
    print(banner)
    print(trainer.analyze_strategy())
def cmd_evaluate(args):
    """Handle the ``evaluate`` subcommand.

    Loads the checkpoint at ``args.checkpoint`` on ``args.device`` and prints
    the trainer's strategy analysis.

    NOTE(review): the "evaluation" printed here is only
    ``trainer.analyze_strategy()``, the same report ``analyze`` shows.
    """
    print("Loading checkpoint...")
    trainer = GTOTrainer(TrainingConfig(device=args.device))
    trainer.load_checkpoint(args.checkpoint)

    print("\nRunning evaluation...")
    report = trainer.analyze_strategy()
    print(report)
def cmd_analyze(args):
    """Handle the ``analyze`` subcommand.

    Loads the checkpoint at ``args.checkpoint`` on ``args.device`` and prints
    the trainer's strategy analysis, preceded by a blank line.
    """
    print("Loading checkpoint...")
    cfg = TrainingConfig(device=args.device)
    trainer = GTOTrainer(cfg)
    trainer.load_checkpoint(args.checkpoint)

    analysis = trainer.analyze_strategy()
    print("\n" + analysis)
def cmd_quick_test(args):
    """Run a quick test to verify everything works.

    Trains in quick mode with short progress/save/eval intervals so setup
    problems surface within a handful of iterations.

    Args:
        args: argparse namespace from the ``test`` subparser. Reads
            ``device`` and ``iterations``. When ``iterations`` is missing
            or None, 20 iterations are used.
    """
    print("Running quick training test...")
    print("This will train for just a few iterations to verify setup.\n")
    config = TrainingConfig(
        quick_mode=True,
        device=args.device,
        print_every=5,
        save_every=20,
        eval_every=10,
    )
    trainer = GTOTrainer(config)

    # Explicit None check: the former `args.iterations or 20` treated an
    # explicitly requested 0 as "not given" and silently trained 20 instead.
    iterations = getattr(args, "iterations", None)
    if iterations is None:
        iterations = 20
    trainer.train(num_iterations=iterations)

    print("\nQuick test complete!")
    print("If you see strategy analysis above, everything is working.")
def _positive_int(text):
    """Convert an argparse string value to an int that must be >= 1.

    Used as ``type=`` for iteration counts and "every N iterations"
    intervals. Zero, negative, or non-integer values are rejected at parse
    time with a normal usage error instead of being passed to the trainer.

    Raises:
        argparse.ArgumentTypeError: if *text* is not an integer or is < 1.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(
            f"must be a positive integer, got {value}")
    return value


def main():
    """Parse command-line arguments and run the selected subcommand.

    Subcommands are ``train``, ``evaluate``, ``analyze`` and ``test``. With
    no subcommand, the top-level help is printed instead.
    """
    parser = argparse.ArgumentParser(
        description="GTO Poker Trainer - Train neural networks to play GTO poker"
    )
    subparsers = parser.add_subparsers(dest='command', help='Commands')
    # Every subcommand accepts the same set of devices.
    devices = ['cpu', 'cuda']

    # Train command
    train_parser = subparsers.add_parser('train', help='Train a GTO model')
    train_parser.add_argument('--quick', action='store_true',
                              help='Use quick curriculum for testing')
    train_parser.add_argument('--iterations', type=_positive_int, default=None,
                              help='Override number of iterations')
    train_parser.add_argument('--device', type=str, default='cpu',
                              choices=devices,
                              help='Device to train on')
    train_parser.add_argument('--checkpoint-dir', type=str, default='checkpoints',
                              help='Directory to save checkpoints')
    train_parser.add_argument('--resume', type=str, default=None,
                              help='Resume from checkpoint')
    train_parser.add_argument('--print-every', type=_positive_int, default=10,
                              help='Print progress every N iterations')
    train_parser.add_argument('--save-every', type=_positive_int, default=100,
                              help='Save checkpoint every N iterations')
    train_parser.add_argument('--eval-every', type=_positive_int, default=50,
                              help='Run evaluation every N iterations')

    # Evaluate command
    eval_parser = subparsers.add_parser('evaluate', help='Evaluate a trained model')
    eval_parser.add_argument('checkpoint', type=str, help='Path to checkpoint')
    eval_parser.add_argument('--device', type=str, default='cpu',
                             choices=devices, help='Device to run on')

    # Analyze command
    analyze_parser = subparsers.add_parser('analyze', help='Analyze strategy')
    analyze_parser.add_argument('checkpoint', type=str, help='Path to checkpoint')
    analyze_parser.add_argument('--device', type=str, default='cpu',
                                choices=devices, help='Device to run on')

    # Quick test command
    test_parser = subparsers.add_parser('test', help='Run a quick test')
    test_parser.add_argument('--device', type=str, default='cpu',
                             choices=devices, help='Device to train on')
    test_parser.add_argument('--iterations', type=_positive_int, default=20,
                             help='Number of training iterations')

    args = parser.parse_args()
    if args.command == 'train':
        cmd_train(args)
    elif args.command == 'evaluate':
        cmd_evaluate(args)
    elif args.command == 'analyze':
        cmd_analyze(args)
    elif args.command == 'test':
        cmd_quick_test(args)
    else:
        # No subcommand given: `dest='command'` is left as None.
        parser.print_help()
if __name__ == "__main__":
    # Launch the CLI only when executed as a script, not when imported.
    main()