-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlearning.py
More file actions
335 lines (271 loc) · 11.6 KB
/
learning.py
File metadata and controls
335 lines (271 loc) · 11.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
import functools
import os
import time
import einops
import flax
import jax
import jax.numpy as jnp
import jax.random as random
import optax
import orbax.checkpoint as ocp
from flax.training import train_state
import airports
import configure
import game
import global_stopwatch
import modeling
import utils
# Number of trajectory_step calls made per batch in collect_trajectories.
MAX_TURNS = 7 * (8 + 2 + 2) + 2
# Factory for starting game states (called bare, and vmapped with key=... in
# make_state). A fixed layout such as airports.KEF can be substituted here.
AIRPORT = airports.random_state
def make_score_fn():
    """Build a jitted function that maps a single game.State to a scalar score.

    States whose result is game.WIN score 1. All other states get a
    hand-weighted progress measure (get_raw_score), rescaled linearly so that
    the state returned by AIRPORT() scores 0 and a hand-built reference
    state (rmax below) scores 0.25.

    Returns:
        score_state: a jax.jit-compiled callable taking one game.State.
    """
    def get_raw_score(state: game.State):
        """Unnormalized weighted sum of game features for one (unbatched) state."""
        # True when altitude is 0 and exactly one track space remains.
        is_landing = (state.altitude == 0) & (state.track_left == 1)
        # Both axis values and both engine values are > 0.
        axis_done = (state.pilot_axis > 0) & (state.copilot_axis > 0)
        engine_done = (state.pilot_engine > 0) & (state.copilot_engine > 0)
        # Under fuel_rule: whether the FUEL slot is filled; otherwise engines are
        # done or the leak rule is off.
        fuel_done = jnp.where(state.fuel_rule, state.is_filled[game.FUEL_INDEX],
                              engine_done | ~state.leak_rule)
        return (-4 * state.altitude +
                -4 * state.track_left +
                -2 * state.num_planes() +
                -0.5 * state.wind_speed() +
                # +1 per True entry of is_on.
                jnp.sum(state.is_on) +
                # 0.1 per unit of fuel; 6 units are subtracted while fuel_done is
                # false (presumably an anticipated fuel cost — confirm).
                0.1 * jnp.where(fuel_done, state.fuel, state.fuel - 6) +
                # Bonus for tilt inside [min_tilt[0], max_tilt[0]] while track remains,
                # only when those bounds are narrower than (-2, 2).
                1.0 * (axis_done & (state.track_left > 0) &
                       ((state.min_tilt[0] != -2) | (state.max_tilt[0] != 2)) &
                       (state.tilt >= state.min_tilt[0]) & (state.tilt <= state.max_tilt[0])) +
                # While landing, smaller nonzero engine values score higher.
                0.1 * is_landing * (state.pilot_engine > 0) * (6 - state.pilot_engine) +
                0.1 * is_landing * (state.copilot_engine > 0) * (6 - state.copilot_engine) +
                # Landing bonus: both engines set and their sum <= brake_speed.
                (is_landing & engine_done &
                 (state.pilot_engine + state.copilot_engine <= state.brake_speed)) +
                # Landing bonus: both axis values set and tilt == 0.
                (is_landing & axis_done & (state.tilt == 0)))
    # Bottom of the scale: the raw score of AIRPORT() called with no arguments.
    # NOTE(review): if AIRPORT is randomized, r0/rmax depend on its default behaviour — confirm.
    r0 = get_raw_score(AIRPORT())
    # Top of the 0..0.25 scale: a hand-picked configuration built from AIRPORT().
    rmax = get_raw_score(AIRPORT().replace(
        coffees=jnp.array(3, dtype=jnp.int32),
        rerolls=jnp.array(3, dtype=jnp.int32),
        altitude=jnp.array(0, dtype=jnp.int32),
        track_left=jnp.array(1, dtype=jnp.int32),
        is_on=jnp.array([slot.is_switch for slot in game.SLOTS]),
        approach_track=jnp.array([0, -1, -1, -1, -1, -1, -1]),
        brake_speed=jnp.array(2, dtype=jnp.int32),
        pilot_engine=jnp.array(1, dtype=jnp.int32),
        copilot_engine=jnp.array(1, dtype=jnp.int32),
        pilot_axis=jnp.array(1, dtype=jnp.int32),
        copilot_axis=jnp.array(1, dtype=jnp.int32),
        wind=jnp.array(9, dtype=jnp.int32),
    ))

    @jax.jit
    def score_state(state: game.State):
        """Return 1 for a win, else 0.25 * (raw - r0) / (rmax - r0); assumes rmax != r0."""
        return jnp.where(
            state.result == game.WIN, 1, 0.25 * (get_raw_score(state) - r0) / (rmax - r0))
    return score_state
@flax.struct.dataclass
class Examples:
    """Training data built by collect_trajectories and consumed by train_step.

    In collect_trajectories every leaf has leading axes (batch, turn);
    make_minibatches merges those two axes into one example axis.
    """
    states: game.State  # states the actions were taken from
    actions: game.Action  # actions chosen by the collecting policy
    returns: jnp.ndarray  # discounted returns; value-loss targets in train_step
    advantages: jnp.ndarray  # weights for the policy-gradient / PPO objective
    log_probs: jnp.ndarray  # log-probs of `actions` at collection time (PPO ratio)
def make_train_state(config: configure.Config, key: random.PRNGKey):
    """Initialize model parameters and an Adam optimizer wrapped in a TrainState.

    Args:
        config: experiment configuration; supplies the model settings and
            `learning_rate`.
        key: PRNG key used to initialize the parameters.

    Returns:
        A flax TrainState whose apply_fn is the model's apply method.
    """
    # Model.init is traced on a batch of one example state.
    example_batch = jax.tree.map(lambda leaf: leaf[None], AIRPORT())
    network = modeling.Model(config)
    initial_params = network.init(key, example_batch)
    # Learning rate warms up from 0 to config.learning_rate over 100 steps, then stays constant.
    lr_schedule = optax.warmup_constant_schedule(
        init_value=0.0, peak_value=config.learning_rate, warmup_steps=100)
    return train_state.TrainState.create(
        apply_fn=network.apply, params=initial_params, tx=optax.adam(lr_schedule))
def make_checkpoint_manager(path):
    """Return an orbax CheckpointManager rooted at the absolute form of `path`.

    The manager keeps at most two checkpoints and creates the directory if needed.
    """
    checkpointer = ocp.PyTreeCheckpointer()
    manager_options = ocp.CheckpointManagerOptions(max_to_keep=2, create=True)
    root = os.path.abspath(path)
    return ocp.CheckpointManager(root, checkpointer, manager_options)
def load_checkpoint(train_state, path):
    """Restore the newest checkpoint under `path`, using `train_state` as the target structure.

    Args:
        train_state: pytree (e.g. a freshly built TrainState) giving the structure
            to restore into; every leaf is restored onto the first jax device.
        path: directory written by a manager from make_checkpoint_manager. Note that
            make_checkpoint_manager creates the directory if it does not exist.

    Returns:
        The restored pytree.

    Raises:
        FileNotFoundError: if no checkpoint has been saved under `path`.
    """
    checkpoint_manager = make_checkpoint_manager(path)
    step = checkpoint_manager.latest_step()
    if step is None:
        # latest_step() is None when nothing was saved; restore(None, ...) would
        # otherwise fail with an unhelpful error deep inside orbax.
        raise FileNotFoundError(f"No checkpoint found in {os.path.abspath(path)}")
    device = jax.devices()[0]
    sharding = jax.sharding.SingleDeviceSharding(device)
    restore_args = jax.tree.map(
        lambda _: ocp.type_handlers.ArrayRestoreArgs(sharding=sharding), train_state)
    return checkpoint_manager.restore(step, train_state,
                                      restore_kwargs={"restore_args": restore_args})
@functools.partial(jax.jit, static_argnums=(0,))
def train_step(config: configure.Config, train_state, examples: Examples):
    """Apply one gradient update to `train_state` on a minibatch of Examples.

    loss = policy_loss + value_weight * value_loss - entropy_weight * entropy.
    The policy loss is the PPO clipped surrogate when config.ppo is set, and
    -mean(log_prob * advantage) otherwise. The value and entropy terms are
    multiplied by a mask of examples whose state result is game.SAFE, but
    still averaged over the whole minibatch.

    Returns:
        (new_train_state, metrics) with the loss and its components.
    """
    def loss_fn(params):
        outputs = train_state.apply_fn(params, examples.states)
        new_log_probs = outputs.log_prob(examples.actions)
        adv = examples.advantages
        if not config.ppo:
            policy_loss = -jnp.mean(new_log_probs * adv)
        else:
            eps = config.ppo_clip_ratio
            ratio = jnp.exp(new_log_probs - examples.log_probs)
            surrogate = jnp.minimum(ratio * adv, jnp.clip(ratio, 1.0 - eps, 1.0 + eps) * adv)
            policy_loss = -jnp.mean(surrogate)
        live = (examples.states.result == game.SAFE)
        value_loss = jnp.mean(live * jnp.square(examples.returns - outputs.value))
        # sum(p * log p) over actions, i.e. the negative entropy of each policy.
        neg_entropy = jnp.sum(
            jax.nn.softmax(outputs.logits) * jax.nn.log_softmax(outputs.logits), axis=-1)
        entropy = -jnp.mean(live * neg_entropy)
        loss = policy_loss + config.value_weight * value_loss - config.entropy_weight * entropy
        aux = {'loss': loss, 'policy_loss': policy_loss, 'value_loss': value_loss,
               'entropy': entropy}
        return loss, aux

    grads, metrics = jax.grad(loss_fn, has_aux=True)(train_state.params)
    return train_state.apply_gradients(grads=grads), metrics
@functools.partial(jax.jit, static_argnums=3)
@functools.partial(jax.vmap, in_axes=(0, 0, None, None))
def get_returns_and_advantages(rewards, values, discount_factor, gae_lambda):
    """Compute discounted returns and advantages for a batch of trajectories.

    vmapped over the leading axis of `rewards` and `values`; inside, each is a
    1-D sequence over time. `gae_lambda` is a static Python number:
      * gae_lambda < 0   -> REINFORCE: advantages are the returns.
      * gae_lambda == 1  -> REINFORCE with baseline: returns - values.
      * otherwise        -> GAE(lambda), bootstrapping with value 0 after the last step.

    Returns:
        (returns, advantages), each shaped like `rewards`.
    """
    def discounted_suffix_sums(xs, decay, init):
        # out[t] = xs[t] + decay * out[t + 1]; the scan runs from the last step backwards.
        def accumulate(acc, x):
            acc = x + decay * acc
            return acc, acc
        _, out = jax.lax.scan(accumulate, init, xs, reverse=True)
        return out

    returns = discounted_suffix_sums(rewards, discount_factor, 0)
    if gae_lambda < 0:  # REINFORCE
        return returns, returns
    if gae_lambda == 1:  # REINFORCE with baseline
        return returns, returns - values
    next_values = jnp.append(values[1:], 0)
    deltas = rewards + discount_factor * next_values - values
    advantages = discounted_suffix_sums(deltas, discount_factor * gae_lambda, 0.0)
    return returns, advantages
# Batched, jitted state factory: vmaps AIRPORT over its arguments, e.g.
# make_state(key=random.split(key, n)) builds n initial states (see collect_trajectories).
make_state = jax.jit(jax.vmap(AIRPORT))
@jax.jit
def trajectory_step(train_state, states, scores, key):
    """Advance a batch of games by one policy action.

    Args:
        train_state: TrainState whose apply_fn/params define the policy.
        states: batched game.State to act from.
        scores: scores carried from the previous call (zeros at the start of
            collect_trajectories).
        key: PRNG key for sampling actions, or None to take the top action.

    Returns:
        (new_states, new_scores, (states, actions, rewards, values, log_probs)),
        where the tuple is the per-turn record collected by collect_trajectories.
    """
    policy = train_state.apply_fn(train_state.params, states)
    # `key is None` is decided at trace time: greedy vs. sampled actions.
    actions = policy.get_top_action() if key is None else policy.sample_action(key)
    action_log_probs = policy.log_prob(actions)
    next_states = jax.vmap(game.do_action)(states, actions)
    next_scores = jax.vmap(make_score_fn())(next_states)
    # if the action caused a loss, or the game was already over, zero-out the reward
    # (literally: keep the score delta only where the new result > game.SAFE or the
    # old state was still game.SAFE).
    keep = (next_states.result > game.SAFE) | (states.result == game.SAFE)
    rewards = (next_scores - scores) * keep
    # Value estimates are zeroed for states that were no longer game.SAFE.
    values = policy.value * (states.result == game.SAFE)
    return next_states, next_scores, (states, actions, rewards, values, action_log_probs)
def collect_trajectories(
        config: configure.Config, key, train_state, collect_metrics=True):
    """Play config.batch_size games for MAX_TURNS steps and turn them into minibatches.

    Args:
        config: experiment configuration (batch sizes, discount_factor, gae_lambda).
        key: PRNG key for initial states, action sampling and shuffling.
        train_state: TrainState used as the acting policy.
        collect_metrics: when False, skip the metric summary and return None for it.

    Returns:
        (examples, metrics): the list of Examples minibatches from make_minibatches,
        and a dict summarizing the first game of the batch (or None).
    """
    states = make_state(key=random.split(key, config.batch_size))
    scores = jnp.zeros(config.batch_size)
    history = []
    global_stopwatch.start("collect_trajectory")
    for _ in range(MAX_TURNS):  # replace with a scan?
        key, step_key = random.split(key)
        states, scores, result = trajectory_step(train_state, states, scores, step_key)
        history.append(result)
    global_stopwatch.stop()

    global_stopwatch.start("postprocess_trajectory")
    # history is a per-turn list of (states, actions, rewards, values, log_probs);
    # regroup by field and stack along axis 1 so leaves are (batch, turn, ...).
    states, actions, rewards, values, log_probs = map(
        lambda x: stack_lists(x, 1), zip(*history))
    returns, advantages = get_returns_and_advantages(
        rewards, values,
        jnp.array(config.discount_factor, jnp.float32),
        config.gae_lambda)
    examples = Examples(
        states=states, actions=actions, returns=returns, advantages=advantages,
        log_probs=log_probs)
    _, shuffle_key = random.split(key)
    examples = make_minibatches(examples, shuffle_key, config.train_batch_size)
    global_stopwatch.stop()

    if not collect_metrics:
        return examples, None
    global_stopwatch.start("get_metrics")
    metrics = _first_game_metrics(states, rewards)
    global_stopwatch.stop()
    return examples, metrics


def _first_game_metrics(states, rewards):
    """Summarize game 0 of a collected batch as a dict of Python scalars."""
    # Copy game 0's per-turn results to the host once, instead of one device sync per turn.
    results = jax.device_get(states.result[0])
    # If the game never leaves game.SAFE, report the last recorded turn.
    j = len(results) - 1
    result = results[j]
    for t, r in enumerate(results):
        if r != game.SAFE:
            # states[t] is the first non-SAFE state; report the state one turn earlier.
            j, result = t - 1, r
            break
    final_state = jax.tree.map(lambda x: x[0, j], states)
    total_reward = float(rewards[0].sum())
    return {
        "reward": 100 * total_reward,
        "altitude": float(final_state.altitude),
        "turns": float(8 * (7 - final_state.altitude) -
                       final_state.pilot_dice.num() - final_state.copilot_dice.num() + 1),
        "planes_left": float(jnp.maximum(0, final_state.approach_track).sum()),
        "landing_gear_on": float(final_state.min_speed - 5),
        "brakes_on": float((final_state.brake_speed - 1) if final_state.ice_rule else
                           (final_state.brake_speed / 2)),
        "flaps_on": float(final_state.max_speed - 8),
        "rerolls": float(final_state.rerolls),
        "result": int(result),
        "track_left": int(final_state.track_left),
        "win_percent": 100.0 if total_reward > 0.99 else 0.0
    }
@functools.partial(jax.jit, static_argnums=(1,))
def stack_lists(elems, axis=0):
    """Stack a sequence of identically-structured pytrees leaf-by-leaf along `axis`."""
    def stack_leaves(*leaves):
        return jnp.stack(leaves, axis)
    return jax.tree.map(stack_leaves, *elems)
@functools.partial(jax.jit, static_argnums=(2,))
def make_minibatches(examples, key, batch_size):
    """Merge the two leading axes of every leaf, shuffle, and split into minibatches.

    Args:
        examples: pytree whose leaves share leading axes (n, m); the sizes are read
            from `examples.returns`.
        key: PRNG key for the shuffle permutation.
        batch_size: static minibatch size.

    Returns:
        A list of n * m // batch_size pytrees, each with leading axis batch_size.
        The final n * m % batch_size shuffled examples are dropped so every
        minibatch has the same shape (one compiled train_step).

    Raises:
        ValueError: if batch_size is not in [1, n * m]. Such a value would otherwise
            produce no minibatches at all (or divide by zero).
    """
    n_examples = examples.returns.shape[0] * examples.returns.shape[1]
    if not 0 < batch_size <= n_examples:
        raise ValueError(
            f"batch_size must be in [1, {n_examples}], got {batch_size}")
    n_batches = n_examples // batch_size
    # Equivalent to einops 'n m ... -> (n m) ...'.
    flat = jax.tree.map(
        lambda x: x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:]), examples)
    idx = random.permutation(key, n_examples)
    shuffled = jax.tree.map(lambda x: x[idx], flat)
    return [jax.tree.map(lambda leaves: leaves[i * batch_size:(i + 1) * batch_size], shuffled)
            for i in range(n_batches)]
class History:
    """Accumulates per-step metric dicts and reports running averages."""

    def __init__(self):
        self.keys = []  # metric names, from the most recently appended dict
        self.history = []  # one list of metric values per append, ordered like self.keys

    def append(self, metrics):
        """Record one metrics dict; its key order defines the column order."""
        self.keys = list(metrics.keys())
        self.history.append(list(metrics.values()))

    def get_avgs(self, last_n=0):
        """Average each metric over the last `last_n` entries (all entries when 0).

        Returns an empty dict if nothing has been recorded yet.
        """
        # history[-0:] is the whole list, so last_n=0 averages everything.
        rows = self.history[-last_n:]
        if not rows:
            # jax.tree.map needs at least one tree; nothing to average.
            return {}
        avgs = jax.tree.map(lambda *args: sum(args) / len(args), *rows)
        return dict(zip(self.keys, avgs))

    def print_avgs(self, last_n=0):
        """Print get_avgs(last_n) as comma-separated 'name: value' with 2 decimals."""
        print(", ".join([f"{k}: {v:.2f}" for k, v in self.get_avgs(last_n).items()]))

    def write(self, path):
        """Persist (keys, history) to `path` via utils.write_pickle."""
        utils.write_pickle((self.keys, self.history), path)
def main(**kwargs):
    """Run the training loop.

    Keyword arguments override configure.Config fields (defaults supplied here:
    write=True, overwrite=True). Each step collects config.batch_size games,
    runs train_step on every minibatch, and logs averaged metrics every
    config.metrics_every steps. History and a checkpoint are written every 100 steps.
    """
    config_kwargs = dict(write=True, overwrite=True)
    config_kwargs.update(**kwargs)
    config = configure.Config(**config_kwargs)
    key = random.PRNGKey(1)  # fixed seed
    key, init_key = random.split(key)
    train_state = make_train_state(config, init_key)
    checkpoint_manager = make_checkpoint_manager(config.checkpoint_dir)
    if config.init_checkpoint:
        # Warm-start from the newest checkpoint in config.init_checkpoint.
        train_state = load_checkpoint(train_state, config.init_checkpoint)
    start_time = time.time()
    history = History()
    for step in range(config.num_steps):
        key, trajectory_key = random.split(key)
        # Metrics are only computed on steps that will be logged.
        examples, metrics = collect_trajectories(config, trajectory_key, train_state,
                                                 step % config.metrics_every == 0)
        global_stopwatch.start("train_step")
        for e in examples:
            train_state, train_metrics = train_step(config, train_state, e)
        global_stopwatch.stop()
        if metrics is not None:
            # Only the last minibatch's train metrics are recorded.
            # NOTE(review): train_metrics is unbound if `examples` is empty.
            metrics.update(train_metrics)
            # NOTE(review): excludes this step's own batch (0 at step 0).
            metrics["games_played"] = step * config.batch_size
            metrics["time"] = time.time() - start_time
            history.append(metrics)
        if step % config.metrics_every == 0:
            print(f"{config.experiment_name}: step {step}, ({metrics['games_played']} games), " +
                  f"{metrics['time']:.2f}s elapsed")
            history.print_avgs(100)
        # Drop timings from the first two steps (presumably dominated by jit compilation).
        # NOTE(review): nesting of this if/elif was reconstructed from unindented source — confirm.
        if step < 2:
            global_stopwatch.clear()
        elif step % 10 == 0:
            global_stopwatch.print_times()
        if step % 100 == 0:
            global_stopwatch.start("write_history_and_checkpoint")
            history.write(config.history_path)
            checkpoint_manager.save(step, train_state)
            global_stopwatch.stop()
if __name__ == "__main__":
    main(
        experiment_name="aiviator",
        num_steps=1_000_000,
        num_layers=8,
        batch_size=256,
        # 256 games x MAX_TURNS turns split into 32 equal minibatches.
        train_batch_size=MAX_TURNS * 256 // 32,
    )