-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathvanilla_train.py
More file actions
97 lines (72 loc) · 2.64 KB
/
vanilla_train.py
File metadata and controls
97 lines (72 loc) · 2.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from pathlib import Path
from datetime import datetime
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import EvalCallback
from src.utils.utils import seed_everything
if __name__ == "__main__":
    # --- experiment identity ---------------------------------------------
    env_id = "Ant"
    env_version = "v5"
    # Build the Gymnasium env name from its parts so changing env_id/version
    # above is sufficient (originally "Ant-v5" was hard-coded twice below).
    env_name = f"{env_id}-{env_version}"
    task_name = f"train_{env_id}_default"

    # global settings
    seed = 1234

    # log_dir is the main storage for checkpoints, configs and tensorboard records
    log_dir = Path().cwd().joinpath("logs")
    # One timestamp per run (the original computed run_id twice; the two calls
    # could straddle a second boundary and disagree).
    run_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    run_dir = log_dir.joinpath("runs", task_name, run_id)

    # create directory where the model snapshots are stored
    # exist_ok=False: fail loudly rather than overwrite a previous run
    checkpoint_path = run_dir.joinpath("checkpoints")
    checkpoint_path.mkdir(parents=True, exist_ok=False)
    # create directory for the tensorboard record
    tensorboard_path = run_dir.joinpath("tensorboard")
    tensorboard_path.mkdir(parents=True, exist_ok=False)

    # seed all random generators used, for reproducibility
    # (use the `seed` variable, not a duplicated literal, so a single edit
    # changes every seeded component consistently)
    seed_everything(seed=seed)

    # instantiate the training environments
    train_env = make_vec_env(
        env_id=env_name,
        n_envs=4,
        vec_env_cls=DummyVecEnv,
        seed=seed,
    )
    # instantiate validation environments
    # offset the seed so eval episodes don't mirror the training streams
    eval_env = make_vec_env(
        env_id=env_name,
        n_envs=4,
        vec_env_cls=DummyVecEnv,
        seed=seed + 10,
    )

    # instantiate the agent
    agent = PPO(
        policy="MlpPolicy",
        env=train_env,
        learning_rate=3e-4,
        n_steps=1024,
        batch_size=256,
        gamma=0.98,
        gae_lambda=0.98,
        ent_coef=0.0,
        clip_range=0.2,
        verbose=1,
        seed=seed,
        device="cpu",
        tensorboard_log=tensorboard_path.as_posix(),
    )

    # instantiate callbacks
    callbacks = [
        EvalCallback(
            eval_env=eval_env,
            callback_on_new_best=None,  # alternatively one can pass e.g. StopTrainingOnRewardThreshold instance for early stopping
            best_model_save_path=checkpoint_path.joinpath("best_model").as_posix(),
            log_path=checkpoint_path.joinpath("logs").as_posix(),
            n_eval_episodes=5,
            eval_freq=10000,  # NOTE(review): eval_freq counts per-env steps; with 4 envs this evaluates every 40k total steps — confirm intent
            render=False,
            verbose=1,
        )
    ]

    # train the agent
    agent.learn(total_timesteps=1000000, callback=callbacks, progress_bar=True)
    print(f"training finished in {agent.num_timesteps} timesteps!")