Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
306 changes: 306 additions & 0 deletions examples/cartpole/cartpole-ppo-plot.ipynb

Large diffs are not rendered by default.

2,615 changes: 2,615 additions & 0 deletions examples/cartpole/cartpole-ppo-psrs-homer16.ipynb

Large diffs are not rendered by default.

33,029 changes: 33,029 additions & 0 deletions examples/cartpole/cartpole-ppo-true-revealed.ipynb

Large diffs are not rendered by default.

2,876 changes: 2,876 additions & 0 deletions examples/cartpole/cartpole-ppo.ipynb

Large diffs are not rendered by default.

37 changes: 37 additions & 0 deletions examples/cartpole/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Run notebook to generate data

# Train HOMER
# cd into this directory first
# python train_homer_encoder.py \
# --num_epochs=20 \
# --seed=0 \
# --batch_size=64 \
# --latent_size=8 \
# --hidden_size=64 \
# --lr=1e-3 \
# --weight_decay=0.0 \
# --temperature_decay=False \
# --output_dir='outputs/models' \
# --num_samples=1000

python train_homer_encoder.py \
--num_epochs=1000 \
--seed=0 \
--batch_size=64 \
--latent_size=16 \
--hidden_size=64 \
--lr=1e-3 \
--weight_decay=0.0 \
--temperature_decay=False \
--output_dir='outputs/models'

python train_homer_encoder.py \
--num_epochs=1000 \
--seed=0 \
--batch_size=64 \
--latent_size=32 \
--hidden_size=64 \
--lr=1e-3 \
--weight_decay=0.0 \
--temperature_decay=False \
--output_dir='outputs/models'
89 changes: 89 additions & 0 deletions examples/cartpole/train_homer_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import argparse
import os
from datetime import datetime
import torch
from torch.utils.data import random_split

import random
import numpy as np
import pandas as pd

from offsim4rl.utils.dataset_utils import load_h5_dataset
from offsim4rl.data import SAS_Dataset
from offsim4rl.encoders.homer import HOMEREncoder
from offsim4rl.utils.vis_utils import plot_latent_state_color_map

import matplotlib.pyplot as plt
class HOMEREncoder_CartPole(HOMEREncoder):
def _visualize(self, fname='latent_state.png'):
obs, _, _ = val_dataset[:]
emb = self.encode(obs)
df_output = pd.DataFrame([(i, *x) for i, x in zip(emb, obs)], columns=['i', 'x', "x'", 'y', "y'"])
fig, ax = plt.subplots(figsize=(4, 4))
plt.scatter(df_output['x'], df_output['y'], c=df_output['i'], cmap='nipy_spectral', marker='.', lw=0, s=3)
plt.xlim(-2.5, 2.5)
plt.ylim(-0.25, 0.25)
plt.axhline(-0.2095, c='gray')
plt.axhline(0.2095, c='gray')
plt.axvline(-2.4, c='gray')
plt.axvline(2.4, c='gray')
plt.xlabel('Cart Position')
plt.ylabel('Pole Angle')
plt.savefig(os.path.join(args.output_dir, model_dir, 'vis', fname))

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--num_samples', type=int, default=None) # for debugging
parser.add_argument('--num_epochs', type=int, default=1000)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--latent_size', type=int, default=25)
parser.add_argument('--hidden_size', type=int, default=64)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--weight_decay', type=float, default=0.0)
parser.add_argument('--temperature_decay', type=bool, default=False)
parser.add_argument('--input_dir', type=str, default='outputs')
parser.add_argument('--output_dir', type=str, default='outputs')
args = parser.parse_args()

torch.manual_seed(args.seed)
random.seed(args.seed)
np.random.seed(args.seed)

model_dir = f"./trial={datetime.now().isoformat(timespec='minutes').replace('-','').replace(':','')}," + \
f"encoder_model=both,seed={args.seed}," + \
f"dZ={args.latent_size},dH={args.hidden_size},lr={args.lr},weight_decay={args.weight_decay}/"
os.makedirs(os.path.join(args.output_dir, model_dir, 'vis'), exist_ok=True)

buffer = load_h5_dataset(os.path.join(args.input_dir, 'CartPole-v1_ppo.h5'))
full_dataset = SAS_Dataset(buffer['observations'], buffer['actions'], buffer['next_observations'])

# TRAINING
if args.num_samples is not None: # Limit sample size for debugging
full_dataset = torch.utils.data.Subset(full_dataset, range(args.num_samples))

train_dataset, val_dataset = random_split(
full_dataset,
[len(full_dataset) // 2, len(full_dataset) // 2],
generator=torch.Generator().manual_seed(42)
)

homer_encoder = HOMEREncoder_CartPole(
obs_dim=4, action_dim=2,
latent_size=args.latent_size,
hidden_size=args.hidden_size,
log_dir=os.path.join(args.output_dir, model_dir),
)

homer_encoder.train(
train_dataset,
val_dataset,
lr=args.lr,
weight_decay=args.weight_decay,
batch_size=args.batch_size,
num_epochs=args.num_epochs,
temperature_decay=args.temperature_decay,
)

# INFERENCE
homer_encoder._visualize(fname='latent_state.png')
2 changes: 1 addition & 1 deletion examples/continuous_grid/train_homer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,4 @@
)

# INFERENCE
homer_encoder._visualize(fname='latent_state.png')
homer_encoder._visualize(fname='latent_state.png')
Loading