# PPO_f.py
import time
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import trange
import wandb
from Agent import Agent
from Utils import make_env
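
# The Agent interface is assumed from how it is used below:
# agent.get_action_and_value(obs, action=None) is expected to return
# (action, logprob, entropy, value, start, end), and agent.get_value(obs) the
# critic's value estimate; the extra start/end outputs are passed through unchanged.
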
def policy_rollout(agent, envs, num_steps, num_paralel_envs, writer, global_step, env_name):
    """Collect num_steps transitions from the vectorized envs with the current policy."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Rollout storage buffers of shape (num_steps, num_paralel_envs, ...).
    obs = torch.zeros((num_steps, num_paralel_envs, np.prod(envs.single_observation_space.shape))).to(device)
    actions = torch.zeros((num_steps, num_paralel_envs) + envs.single_action_space.shape).to(device)
    logprobs = torch.zeros((num_steps, num_paralel_envs)).to(device)
    rewards = torch.zeros((num_steps, num_paralel_envs)).to(device)
    dones = torch.zeros((num_steps, num_paralel_envs)).to(device)
    values = torch.zeros((num_steps, num_paralel_envs)).to(device)

    next_obs, _ = envs.reset()
    next_obs = torch.Tensor(next_obs).to(device)
    next_done = torch.zeros(num_paralel_envs).to(device)

    for step in range(num_steps):
        global_step += num_paralel_envs
        obs[step] = next_obs
        dones[step] = next_done

        # Sample an action from the current policy without tracking gradients.
        with torch.no_grad():
            action, logprob, _, value, start, end = agent.get_action_and_value(next_obs)
            values[step] = value.flatten()
        actions[step] = action
        logprobs[step] = logprob

        # Step the vectorized environments and store the reward.
        next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())
        done = np.logical_or(terminations, truncations)
        rewards[step] = torch.tensor(reward).flatten().to(device)
        next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)

        # Log episodic statistics for the environments that just finished.
        if "final_info" in infos:
            for info in infos["final_info"]:
                # Skip the envs that are not done
                if info:
                    # print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                    writer.add_scalar(f"charts/episodic_return_{env_name}", info["episode"]["r"], global_step)
                    writer.add_scalar(f"charts/episodic_length_{env_name}", info["episode"]["l"], global_step)

    return obs, actions, logprobs, rewards, dones, values, next_obs, next_done, global_step


def calculate_returns(agent, envs, gamma, gae_lambda, num_steps, num_paralel_envs, writer, global_step, env_name):
    """Run a rollout, compute GAE advantages and returns, and flatten everything into one batch."""
    obs, actions, logprobs, rewards, dones, values, next_obs, next_done, global_step = policy_rollout(
        agent, envs, num_steps, num_paralel_envs, writer, global_step, env_name)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Generalized Advantage Estimation (GAE), bootstrapping from the value of the last observation.
    with torch.no_grad():
        next_value = agent.get_value(next_obs).reshape(1, -1)
        advantages = torch.zeros_like(rewards).to(device)
        lastgaelam = 0
        for t in reversed(range(num_steps)):
            if t == num_steps - 1:
                next_non_terminal = 1.0 - next_done
                nextvalues = next_value
            else:
                next_non_terminal = 1.0 - dones[t + 1]
                nextvalues = values[t + 1]
            # TD error: delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_{t+1}) - V(s_t)
            delta = rewards[t] + gamma * nextvalues * next_non_terminal - values[t]
            advantages[t] = lastgaelam = delta + gamma * gae_lambda * next_non_terminal * lastgaelam
        returns = advantages + values

    # Flatten the (num_steps, num_paralel_envs) buffers into a single batch dimension.
    b_obs = obs.reshape((-1, np.prod(envs.single_observation_space.shape)))
    b_logprobs = logprobs.reshape(-1)
    b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
    b_advantages = advantages.reshape(-1)
    b_returns = returns.reshape(-1)
    b_values = values.reshape(-1)
    return b_obs, b_logprobs, b_actions, b_advantages, b_returns, b_values, global_step


def calculate_loss(agent, b_obs, b_logprobs, b_actions, b_advantages, b_returns, b_values, mb_inds, clip_coef, ent_coef, vf_coef):
    """Compute the PPO clipped surrogate loss on one minibatch of the flattened batch."""
    _, newlogprob, entropy, newvalue, start, end = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])
    logratio = newlogprob - b_logprobs[mb_inds]
    ratio = torch.exp(logratio)
    mb_advantages = b_advantages[mb_inds]

    # Clipped surrogate policy objective (negated, since we minimize the loss).
    pg_loss1 = -mb_advantages * ratio
    pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
    pg_loss = torch.max(pg_loss1, pg_loss2).mean()

    # Clipped value loss, as in the original PPO implementation.
    newvalue = newvalue.view(-1)
    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
    v_clipped = b_values[mb_inds] + torch.clamp(newvalue - b_values[mb_inds], -clip_coef, clip_coef)
    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
    v_loss = 0.5 * v_loss_max.mean()

    # Entropy bonus encourages exploration; it is subtracted from the total loss.
    entropy_loss = entropy.mean()
    loss = pg_loss - ent_coef * entropy_loss + v_loss * vf_coef
    return loss, v_loss, pg_loss, entropy_loss, start, end
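

# Example usage sketch: one way these helpers could be wired into a PPO update
# loop. The hyperparameter values, the make_env call, the Agent constructor, and
# the environment id are assumptions for illustration only.
if __name__ == "__main__":
    num_steps, num_paralel_envs = 128, 4            # assumed rollout sizes
    gamma, gae_lambda = 0.99, 0.95                  # assumed discount / GAE params
    clip_coef, ent_coef, vf_coef = 0.2, 0.01, 0.5   # assumed PPO coefficients
    num_updates, update_epochs, minibatch_size = 10, 4, 128
    env_name = "CartPole-v1"                        # assumed environment id

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # make_env signature assumed: make_env(env_id, idx) returning a thunk.
    envs = gym.vector.SyncVectorEnv([make_env(env_name, i) for i in range(num_paralel_envs)])
    agent = Agent(envs).to(device)                  # assumed Agent constructor
    optimizer = optim.Adam(agent.parameters(), lr=2.5e-4, eps=1e-5)
    writer = SummaryWriter(f"runs/{env_name}_{int(time.time())}")

    global_step = 0
    batch_size = num_steps * num_paralel_envs
    for update in trange(num_updates):
        # Collect a rollout and compute flattened advantages/returns.
        b_obs, b_logprobs, b_actions, b_advantages, b_returns, b_values, global_step = calculate_returns(
            agent, envs, gamma, gae_lambda, num_steps, num_paralel_envs, writer, global_step, env_name)

        # Several epochs of minibatch SGD on the clipped PPO objective.
        b_inds = np.arange(batch_size)
        for epoch in range(update_epochs):
            np.random.shuffle(b_inds)
            for start_idx in range(0, batch_size, minibatch_size):
                mb_inds = b_inds[start_idx:start_idx + minibatch_size]
                loss, v_loss, pg_loss, entropy_loss, _, _ = calculate_loss(
                    agent, b_obs, b_logprobs, b_actions, b_advantages, b_returns, b_values,
                    mb_inds, clip_coef, ent_coef, vf_coef)
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), 0.5)
                optimizer.step()

        writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
        writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)

    writer.close()
    envs.close()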