class BaseLogger:
    """Base logger interface that defines common methods for logging metrics and parameters.

    Attributes:
        log_dir (str): Directory where logs are stored. If None, default locations are used.
    """

    def __init__(self, log_dir=None):
        """Initializes the BaseLogger with a specified log directory.

        Args:
            log_dir (str): Directory to store the logs. If None, uses default log directory.
        """
        self.log_dir = log_dir

    def log_metric(self, key, value, step=None):
        """Logs a metric value.

        Args:
            key (str): The name of the metric.
            value (float): The value of the metric.
            step (int, optional): The step number at which the metric is logged.
        """
        raise NotImplementedError

    def log_params(self, params):
        """Logs parameters.

        Args:
            params (dict): A dictionary containing parameter names and their values.
        """
        raise NotImplementedError


class TensorBoardLogger(BaseLogger):
    """Logger that implements logging functionality using TensorBoard.

    Inherits from BaseLogger and implements its methods for TensorBoard specific logging.
    """

    def __init__(self, log_dir=None):
        """Initializes the TensorBoardLogger with a specified log directory.

        Args:
            log_dir (str): Directory to store TensorBoard logs. If None, uses default log directory.
        """
        super().__init__(log_dir)
        # Imported lazily so that importing this module does not require
        # torch/tensorboard when only another backend (e.g. wandb) is used.
        from torch.utils.tensorboard import SummaryWriter

        self.writer = SummaryWriter(log_dir)

    def log_metric(self, key, value, step=None):
        """Logs a metric to TensorBoard.

        Args:
            key (str): The name of the metric.
            value (float): The value of the metric.
            step (int, optional): The step number at which the metric is logged.
        """
        self.writer.add_scalar(key, value, step)

    def log_params(self, params):
        """Logs parameters to TensorBoard.

        Args:
            params (dict): A dictionary of parameters to log.
        """
        # add_hparams requires BOTH an hparam dict and a metric dict; the
        # previous single-argument call raised TypeError. Pass an empty
        # metric dict since no summary metrics are associated here.
        self.writer.add_hparams(params, {})


class WandbLogger(BaseLogger):
    """Logger that implements logging functionality using Weights & Biases (wandb).

    Inherits from BaseLogger and implements its methods for wandb specific logging.
    """

    def __init__(self, log_dir=None, project_name=None):
        """Initializes the WandbLogger with a specified log directory and project name.

        Args:
            log_dir (str): Directory to store local wandb logs. If None, uses default wandb directory.
            project_name (str): Name of the wandb project. If None, a default project is used.
        """
        super().__init__(log_dir)
        # Imported lazily: wandb is an optional backend dependency.
        import wandb

        wandb.init(dir=log_dir, project=project_name)

    def log_metric(self, key, value, step=None):
        """Logs a metric to wandb.

        Args:
            key (str): The name of the metric.
            value (float): The value of the metric.
            step (int, optional): The step number at which the metric is logged.
        """
        import wandb

        wandb.log({key: value}, step=step)

    def log_params(self, params):
        """Logs parameters to wandb.

        Args:
            params (dict): A dictionary of parameters to log.
        """
        import wandb

        wandb.config.update(params)


class LoggerFactory:
    """Factory class to create logger instances based on specified logger type.

    Supports creation of different types of loggers like TensorBoardLogger and WandbLogger.
    """

    @staticmethod
    def get_logger(logger_type, log_dir=None, project_name=None):
        """Creates and returns a logger instance based on the specified logger type.

        Args:
            logger_type (str): The type of logger to create ('tensorboard' or 'wandb').
            log_dir (str, optional): Directory to store the logs. Specific to the logger type.
            project_name (str, optional): wandb project name; forwarded only to WandbLogger.
                Previously this could not be set through the factory at all.

        Returns:
            BaseLogger: An instance of the requested logger type.

        Raises:
            ValueError: If an unsupported logger type is specified.
        """
        if logger_type == "tensorboard":
            return TensorBoardLogger(log_dir)
        elif logger_type == "wandb":
            return WandbLogger(log_dir, project_name)
        else:
            # Include the offending value so the caller can see what was wrong.
            raise ValueError(f"Unsupported logger type: {logger_type!r}")
+""" + +if __name__ == "__main__": + # Initialize the environment + envs.register_environment("pendulum_swingup", PendulumSwingupEnv) + env = envs.get_environment("pendulum_swingup") + + # Define the training function + network_factory = functools.partial( + ppo_networks.make_ppo_networks, + policy_hidden_layer_sizes=(64,) * 3, + ) + train_fn = functools.partial( + ppo.train, + num_timesteps=100_000, + num_evals=50, + reward_scaling=0.1, + episode_length=200, + normalize_observations=True, + action_repeat=1, + unroll_length=10, + num_minibatches=32, + num_updates_per_batch=8, + discounting=0.97, + learning_rate=3e-4, + entropy_cost=0, + num_envs=1024, + batch_size=512, + network_factory=network_factory, + seed=0, + ) + + # Save the log in the current directory + log_dir = os.path.join(os.path.abspath(os.getcwd()), "logs") + if not os.path.exists(log_dir): + os.makedirs(log_dir) + + print(f"Setting up Tensorboard logging in {log_dir}") + logger = LoggerFactory.get_logger("tensorboard", log_dir) + + # Define a callback to log progress + times = [datetime.now()] + + def progress(num_steps, metrics): + """Logs progress during RL.""" + print(f" Steps: {num_steps}, Reward: {metrics['eval/episode_reward']}") + times.append(datetime.now()) + + # Write all metrics to tensorboard + for key, val in metrics.items(): + if isinstance(val, jax.Array): + val = float(val) # we need floats for logging + logger.log_metric(key, val, num_steps) + + # Do the training + print("Training...") + make_inference_fn, params, _ = train_fn( + environment=env, + progress_fn=progress, + ) + + print(f"Time to jit: {times[1] - times[0]}") + print(f"Time to train: {times[-1] - times[1]}") From 3ed391987dfceca38dda676770b88a1b8af87e02 Mon Sep 17 00:00:00 2001 From: Amy Li Date: Fri, 15 Dec 2023 13:57:03 -0800 Subject: [PATCH 2/3] add log_progress function in base class --- ambersim/logger/logger.py | 13 +++++++++++++ examples/rl/pendulum/ex_logger.py | 16 +--------------- 2 files changed, 14 
def log_progress(self, step, state_info):
    """Logs a dictionary of progress metrics via the backend's log_metric method.

    The signature matches brax's ``progress_fn(num_steps, metrics)`` callback,
    so a bound ``logger.log_progress`` can be passed directly as ``progress_fn``.

    Args:
        step (int): The step number at which the state is logged.
        state_info (dict): A dictionary mapping metric names to values.
    """
    for key, value in state_info.items():
        # Logging backends want plain Python floats, not JAX arrays.
        if isinstance(value, jax.Array):
            value = float(value)
        self.log_metric(key, value, step)
class BaseLogger:
    """Base logger interface that defines common methods for logging metrics and parameters."""

    def __init__(self, log_dir=None):
        """Initializes the BaseLogger with a specified log directory.

        Args:
            log_dir (str | pathlib.Path | None): Directory to store the logs.
                If None, the backend's default log directory is used.
        """
        # NOTE(review): the reviewed patch annotated this parameter as
        # Union[str, Path] without importing Union or Path, which raises
        # NameError at class-definition time. The accepted types are
        # documented in the docstring instead.
        self.log_dir = log_dir