-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_model.py
More file actions
97 lines (76 loc) · 2.8 KB
/
run_model.py
File metadata and controls
97 lines (76 loc) · 2.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import gymnasium as gym
from stable_baselines3 import PPO
import os
import webbrowser
from tensorboard import program
from archery_env import ArcheryGymEnv
# Output locations for saved model checkpoints and TensorBoard logs.
models_dir = "models"
logs_dir = "logs"
os.makedirs(models_dir, exist_ok=True)
os.makedirs(logs_dir, exist_ok=True)


def _launch_tensorboard(logdir):
    """Start a TensorBoard server on *logdir* and open it in the default browser.

    Best effort: any failure is printed and swallowed so the model viewer can
    still run without TensorBoard.

    Returns:
        The ``program.TensorBoard`` instance on success, otherwise ``None``.
    """
    try:
        print("Launching TensorBoard...")
        tb = program.TensorBoard()
        # argv[0] is the program name slot expected by TensorBoard's flag parser.
        tb.configure(argv=[None, '--logdir', logdir])
        url = tb.launch()
        print(f"TensorBoard started at {url}")
        webbrowser.open(url)
        return tb
    except Exception as e:
        # Deliberately broad: TensorBoard is a convenience, not a requirement.
        print(f"Could not auto-launch TensorBoard: {e}")
        return None


# Only launch when run as a script: importing this module must not spawn a
# server or open a browser tab. The reference is kept at module level so the
# TensorBoard object lives as long as the script (as the original `tb` did).
if __name__ == "__main__":
    _tensorboard = _launch_tensorboard(logs_dir)
def get_saved_models(directory="models"):
    """List the saved ``.zip`` model files in *directory*.

    If *directory* does not exist yet, it is created and an empty list is
    returned.

    Args:
        directory: Folder to scan for model archives.

    Returns:
        Matching filenames (not full paths), in reverse lexicographic order.
    """
    if os.path.exists(directory):
        archives = (name for name in os.listdir(directory) if name.endswith(".zip"))
        return sorted(archives, reverse=True)
    os.makedirs(directory)
    return []
# Terminal-step reward at or above this counts as a hit. Per the original
# author's note, ArcheryGymEnv uses 100.0 for a hit.
_HIT_REWARD = 100


def _select_model(models):
    """Prompt until the user picks one of *models*; return its filename, or None to quit."""
    while True:
        try:
            choice = input(f"\nSelect a model number (1-{len(models)}) or 'q' to quit: ")
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl+C at the prompt: quit cleanly instead of a traceback.
            print()
            return None
        if choice.lower() == 'q':
            return None
        try:
            idx = int(choice) - 1
        except ValueError:
            print("Please enter a number.")
            continue
        if 0 <= idx < len(models):
            return models[idx]
        print("Invalid number.")


def _run_simulation(env, model, obs):
    """Step *env* with *model* until interrupted, printing hit/accuracy stats per shot."""
    shots_fired = 0
    shots_hit = 0
    while True:
        # NOTE(review): predict() samples stochastically by default
        # (deterministic=False); pass deterministic=True if the accuracy
        # figure should reflect the greedy policy.
        action, _ = model.predict(obs)
        obs, reward, terminated, truncated, _ = env.step(action)
        if not (terminated or truncated):
            continue
        # An episode is one shot; only its final reward decides hit/miss.
        shots_fired += 1
        hit = reward >= _HIT_REWARD
        if hit:
            shots_hit += 1
        accuracy = (shots_hit / shots_fired) * 100
        # Update the environment's label so it draws on the next frame.
        env.accuracy_label = f"{accuracy:.1f}%"
        print(f"Shot: {shots_fired} | Hit: {hit} | Accuracy: {accuracy:.3f}%")
        obs, _ = env.reset()


def main():
    """Let the user pick a saved PPO model and watch it play ArcheryGymEnv.

    Lists the ``.zip`` checkpoints in ``models_dir`` and prompts for a choice.
    It then loads the chosen model and runs it in a human-rendered
    environment until Ctrl+C, reporting running hit accuracy after every shot.
    The environment is always closed, even if loading or resetting fails.
    """
    models = get_saved_models(models_dir)
    if not models:
        print("No models found. Run main_train.py first.")
        return
    print("\n--- Available Models ---")
    for number, model_file in enumerate(models, start=1):
        print(f"{number}: {model_file}")

    selected_model_name = _select_model(models)
    if selected_model_name is None:
        return

    model_path = os.path.join(models_dir, selected_model_name)
    print(f"Loading {selected_model_name}...")
    env = ArcheryGymEnv(render_mode="human")
    # Everything after the env exists runs under try/finally so the window is
    # released even if loading the checkpoint or the first reset raises.
    try:
        model = PPO.load(model_path)
        obs, _ = env.reset()
        print("Running simulation... (Press Ctrl+C to stop)")
        _run_simulation(env, model, obs)
    except KeyboardInterrupt:
        print("\nStopping simulation...")
    finally:
        env.close()
# Script entry point: start the interactive viewer only when executed
# directly, never when this module is imported.
if __name__ == "__main__":
    main()