-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
96 lines (80 loc) · 2.94 KB
/
train.py
File metadata and controls
96 lines (80 loc) · 2.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# Training entry-point script: parses CLI arguments, builds an environment
# and an agent (DQN or DRQN), then runs the experience-collection loop.
import time
import utils
import argparse
import datetime
import torch, gc
from rl_algorithm.drqn.agent import DRQNAgent
from rl_algorithm.dqn.agent import DQNAgent
# Release Python garbage and any cached CUDA allocations left over from a
# previous run before training starts.
gc.collect()
torch.cuda.empty_cache()
# Parse arguments
def _str2bool(value):
    """Parse a textual boolean CLI value into a bool.

    argparse's ``type=bool`` is a footgun: ``bool("False")`` is ``True``
    because any non-empty string is truthy.  This helper accepts the usual
    spellings and raises a clear error for anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "1", "yes", "y"):
        return True
    if lowered in ("false", "0", "no", "n"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % (value,))


parser = argparse.ArgumentParser()
parser.add_argument("--env", required=True, help="name of the environment to train on (REQUIRED)")
parser.add_argument("--model", default=None, help="name of the model (default: {ENV}_{ALGO}_{TIME})")
parser.add_argument("--seed", type=int, default=-1, help="specific seed")
parser.add_argument("--frames", type=int, default=2*10**6, help="number of frames of training (default: 2e6)")
parser.add_argument("--max-memory", type=int, default=500000, help="Maximum experiences stored (default: 500000)")
parser.add_argument("--lr", type=float, default=0.0001, help="learning rate (default: 0.0001)")
parser.add_argument("--algorithm", type=str, default="dqn", help="dqn, drqn")
parser.add_argument("--rnd_scale", type=float, default=None)
parser.add_argument("--softmax_ww", type=int, default=50)
# BUG FIX: was ``type=bool`` — under that, ``--log_wandb False`` evaluated to
# True, so the flag could never be turned off from the command line.
parser.add_argument("--log_wandb", type=_str2bool, default=False)
args = parser.parse_args()
# Make cuDNN deterministic so runs with the same seed are reproducible
# (disabling benchmark autotuning trades some throughput for determinism).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Set run dir
date = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")
default_model_name = "_".join([args.env, args.algorithm, date])
model_name = args.model if args.model else default_model_name
model_dir = utils.get_model_dir(model_name)

# Pick the compute device and start the wall-clock timer for this run.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
start_time = time.time()

# Outer accumulators of returns across episodes (train and eval).
return_per_frame = []
test_return_per_frame = []

# Seed everything, then build the training and evaluation environments
# with the same seed.
seed = args.seed
utils.seed(seed)
env = utils.make_env(args.env, seed)
eval_env = utils.make_env(args.env, seed)

# Per-episode running returns plus frame/episode counters.
return_per_frame_ = []
test_return_per_frame_ = []
num_frames = 0
episode = 0

# Load observations preprocessor
obs_space, preprocess_obss = utils.get_obss_preprocessor(
    env.observation_space
)
# Exploration strategies made available to the agent.
exploration_options = ["epsilon-random", "epsilon-z", "epsilon-rnd", "epsilon"]

# Map algorithm name -> agent class; both constructors share one signature,
# so a single dispatch replaces the two duplicated ``if`` bodies.
_AGENT_CLASSES = {"dqn": DQNAgent, "drqn": DRQNAgent}
if args.algorithm not in _AGENT_CLASSES:
    # BUG FIX: previously an unrecognized --algorithm left ``agent`` unbound
    # and the script crashed later with an unrelated NameError inside the
    # training loop.  Fail fast with a clear message instead.
    raise ValueError(
        "unknown algorithm %r; expected one of %s"
        % (args.algorithm, sorted(_AGENT_CLASSES))
    )
agent = _AGENT_CLASSES[args.algorithm](
    env=env,
    eval_env=eval_env,
    exploration_options=exploration_options,
    device=device,
    args=args,
    preprocess_obs=preprocess_obss,
    model_dir=model_dir,
)
# Main training loop: run until the frame budget given by --frames is spent.
while num_frames < args.frames:
    update_start_time = time.time()
    # Delegate one episode of interaction (and learning) to the agent; the
    # returned log dict carries the updated cumulative frame count.
    logs = agent.collect_experiences(
        start_time=start_time,
        episode=episode,
        num_frames=num_frames,
        return_per_frame_=return_per_frame_,
        test_return_per_frame_=test_return_per_frame_,
    )
    update_end_time = time.time()
    num_frames = logs["num_frames"]
    episode += 1
    # NOTE(review): these append the *same* list objects every iteration, so
    # the outer lists end up holding aliases that all reflect the latest
    # contents — presumably collect_experiences mutates them in place.  If
    # independent per-episode snapshots are intended, a copy should be
    # appended instead.  TODO confirm intent.
    return_per_frame.append(return_per_frame_)
    test_return_per_frame.append(test_return_per_frame_)