diff --git a/ambersim/logger/logger.py b/ambersim/logger/logger.py new file mode 100644 index 00000000..ced0f586 --- /dev/null +++ b/ambersim/logger/logger.py @@ -0,0 +1,142 @@ +import jax +import wandb +from torch.utils.tensorboard import SummaryWriter + + +class BaseLogger: + """Base logger interface that defines common methods for logging metrics and parameters.""" + + def __init__(self, log_dir: Union[str, Path]=None): + """Initializes the BaseLogger with a specified log directory. + + Args: + log_dir (str): Directory to store the logs. If None, uses default log directory. + """ + self.log_dir = log_dir + + def log_metric(self, key, value, step=None): + """Logs a metric value. + + Args: + key (str): The name of the metric. + value (float): The value of the metric. + step (int, optional): The step number at which the metric is logged. + """ + raise NotImplementedError + + def log_params(self, params): + """Logs parameters. + + Args: + params (dict): A dictionary containing parameter names and their values. + """ + raise NotImplementedError + + def log_progress(self, step, state_info): + """Logs the state of a process using the log_metric method. + + Args: + state_info (dict): A dictionary containing state information. + step (int, optional): The step number at which the state is logged. + """ + for key, value in state_info.items(): + if isinstance(value, jax.Array): + value = float(value) # we need floats for logging + self.log_metric(key, value, step) + + +class TensorBoardLogger(BaseLogger): + """Logger that implements logging functionality using TensorBoard. + + Inherits from BaseLogger and implements its methods for TensorBoard specific logging. + """ + + def __init__(self, log_dir=None): + """Initializes the TensorBoardLogger with a specified log directory. + + Args: + log_dir (str): Directory to store TensorBoard logs. If None, uses default log directory. + """ + super().__init__(log_dir) + self.writer = SummaryWriter(log_dir) + + def log_metric(self, key, value, step=None): + """Logs a metric to TensorBoard. + + Args: + key (str): The name of the metric. + value (float): The value of the metric. + step (int, optional): The step number at which the metric is logged. + """ + self.writer.add_scalar(key, value, step) + + def log_params(self, params): + """Logs parameters to TensorBoard. + + Args: + params (dict): A dictionary of parameters to log. + """ + self.writer.add_hparams(params) + + +class WandbLogger(BaseLogger): + """Logger that implements logging functionality using Weights & Biases (wandb). + + Inherits from BaseLogger and implements its methods for wandb specific logging. + """ + + def __init__(self, log_dir=None, project_name=None): + """Initializes the WandbLogger with a specified log directory and project name. + + Args: + log_dir (str): Directory to store local wandb logs. If None, uses default wandb directory. + project_name (str): Name of the wandb project. If None, a default project is used. + """ + super().__init__(log_dir) + wandb.init(dir=log_dir, project=project_name) + + def log_metric(self, key, value, step=None): + """Logs a metric to wandb. + + Args: + key (str): The name of the metric. + value (float): The value of the metric. + step (int, optional): The step number at which the metric is logged. + """ + wandb.log({key: value}, step=step) + + def log_params(self, params): + """Logs parameters to wandb. + + Args: + params (dict): A dictionary of parameters to log. + """ + wandb.config.update(params) + + +class LoggerFactory: + """Factory class to create logger instances based on specified logger type. + + Supports creation of different types of loggers like TensorBoardLogger and WandbLogger. + """ + + @staticmethod + def get_logger(logger_type, log_dir=None): + """Creates and returns a logger instance based on the specified logger type. + + Args: + logger_type (str): The type of logger to create ('tensorboard' or 'wandb'). + log_dir (str, optional): Directory to store the logs. Specific to the logger type. + + Returns: + BaseLogger: An instance of the requested logger type. + + Raises: + ValueError: If an unsupported logger type is specified. + """ + if logger_type == "tensorboard": + return TensorBoardLogger(log_dir) + elif logger_type == "wandb": + return WandbLogger(log_dir) + else: + raise ValueError("Unsupported logger type") diff --git a/examples/rl/pendulum/ex_logger.py b/examples/rl/pendulum/ex_logger.py new file mode 100644 index 00000000..5fd2d3ff --- /dev/null +++ b/examples/rl/pendulum/ex_logger.py @@ -0,0 +1,64 @@ +import functools +import os +from datetime import datetime + +import jax +from brax import envs +from brax.training.agents.ppo import networks as ppo_networks +from brax.training.agents.ppo import train as ppo + +from ambersim.logger.logger import LoggerFactory +from ambersim.rl.pendulum.swingup import PendulumSwingupEnv + +""" +A pendulum swingup example that uses a custom logger to log training +progress in real time. +""" + +if __name__ == "__main__": + # Initialize the environment + envs.register_environment("pendulum_swingup", PendulumSwingupEnv) + env = envs.get_environment("pendulum_swingup") + + # Define the training function + network_factory = functools.partial( + ppo_networks.make_ppo_networks, + policy_hidden_layer_sizes=(64,) * 3, + ) + train_fn = functools.partial( + ppo.train, + num_timesteps=100_000, + num_evals=50, + reward_scaling=0.1, + episode_length=200, + normalize_observations=True, + action_repeat=1, + unroll_length=10, + num_minibatches=32, + num_updates_per_batch=8, + discounting=0.97, + learning_rate=3e-4, + entropy_cost=0, + num_envs=1024, + batch_size=512, + network_factory=network_factory, + seed=0, + ) + + # Save the log in the current directory + log_dir = os.path.join(os.path.abspath(os.getcwd()), "logs") + if not os.path.exists(log_dir): + os.makedirs(log_dir) + + print(f"Setting up Tensorboard logging in {log_dir}") + logger = LoggerFactory.get_logger("tensorboard", log_dir) + + # Define a callback to log progress + times = [datetime.now()] + + # Do the training + print("Training...") + make_inference_fn, params, _ = train_fn( + environment=env, + progress_fn=logger.log_progress, + )