Skip to content

Commit 0c65a7e

Browse files
committed
Upgrade console logger and allow multiple backends
1 parent fb22883 commit 0c65a7e

9 files changed

Lines changed: 179 additions & 68 deletions

File tree

docs/callbacks/index.md

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ via the `callback` argument to `learn`.
3131
from jax import random as jr
3232

3333
from lerax.algorithm import PPO
34-
from lerax.callback import LoggingCallback, ProgressBarCallback, TensorBoardBackend
34+
from lerax.callback import ConsoleBackend, LoggingCallback, TensorBoardBackend
3535
from lerax.env.classic_control import CartPole
3636
from lerax.policy import MLPActorCriticPolicy
3737

@@ -41,29 +41,29 @@ env = CartPole()
4141
policy = MLPActorCriticPolicy(env=env, key=policy_key)
4242
algo = PPO()
4343

44-
logger = LoggingCallback(TensorBoardBackend(), env=env, policy=policy)
45-
callbacks = [
46-
ProgressBarCallback(total_timesteps=2**16, env=env, policy=policy),
47-
logger,
48-
]
44+
logger = LoggingCallback(
45+
[TensorBoardBackend(), ConsoleBackend(total_timesteps=2**16)],
46+
env=env,
47+
policy=policy,
48+
)
4949

5050
policy = algo.learn(
5151
env,
5252
policy,
5353
total_timesteps=2**16,
5454
key=learn_key,
55-
callback=callbacks,
55+
callback=logger,
5656
)
5757
logger.close()
5858
```
5959

6060
## Built-in callbacks
6161

62-
- [`ProgressBarCallback`](progress_bar.md):
63-
Rich-based progress bar showing iterations, elapsed/remaining time, and iterations per second.
64-
6562
- [`LoggingCallback`](logging.md):
66-
Logs training metrics (learning rate, training log entries, episode return/length EMAs) to TensorBoard, Aim, or Weights & Biases via a pluggable backend.
63+
Logs training metrics (learning rate, training log entries, episode return/length EMAs) to one or more pluggable backends. Use `ConsoleBackend` for a live terminal display with progress bar and metrics table, `TensorBoardBackend` for TensorBoard, or `WandbBackend` for Weights & Biases.
64+
65+
- [`ProgressBarCallback`](progress_bar.md):
66+
Standalone Rich progress bar callback. For most use cases, prefer `ConsoleBackend` inside `LoggingCallback`, which provides both a progress bar and a live metrics table.
6767

6868
- `CallbackList`:
6969
Aggregates multiple callbacks and forwards all hooks to each one. Used automatically when you pass a list of callbacks.

docs/callbacks/logging.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,18 @@ backend = WandbBackend(
8888

8989
### ConsoleBackend
9090

91-
Prints metrics to the terminal using Rich. Useful for quick debugging without a logging server.
91+
Displays a live metrics table and progress bar in the terminal using [Rich](https://rich.readthedocs.io/).
92+
While the live display is active, the metrics table is redrawn in place on each iteration rather than appended, which keeps the output compact; the progress bar is rendered below the table.
93+
94+
When `total_timesteps` is provided, the progress bar shows completion, elapsed/remaining time, and throughput. Without it, metrics are printed as simple key=value lines.
9295

9396
```py
9497
from lerax.callback import ConsoleBackend
9598

99+
# With progress bar and live metrics table
100+
backend = ConsoleBackend(total_timesteps=2**16)
101+
102+
# Without progress bar (plain text output)
96103
backend = ConsoleBackend()
97104
```
98105

docs/callbacks/progress_bar.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@ description: Display a live Rich progress bar during training.
55

66
# ProgressBarCallback
77

8+
!!! tip "Prefer ConsoleBackend"
9+
For most use cases, prefer [`ConsoleBackend`](logging.md#consolebackend) inside
10+
`LoggingCallback` instead. It provides the same progress bar plus a live
11+
metrics table, and avoids needing a separate callback.
12+
813
`ProgressBarCallback` displays a terminal progress bar using [Rich](https://rich.readthedocs.io/).
914
It shows:
1015

examples/gym_environment.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from jax import random as jr
33

44
from lerax.algorithm import PPO
5-
from lerax.callback import LoggingCallback, ProgressBarCallback, TensorBoardBackend
5+
from lerax.callback import ConsoleBackend, LoggingCallback, TensorBoardBackend
66
from lerax.compatibility.gym import GymToLeraxEnv
77
from lerax.policy import MLPActorCriticPolicy
88

@@ -12,10 +12,11 @@
1212
env = GymToLeraxEnv(gym_env)
1313
policy = MLPActorCriticPolicy(env=env, key=policy_key)
1414
algo = PPO(num_envs=1) # Vectorization is not supported for Gym environments
15-
logger = LoggingCallback(TensorBoardBackend(), env=env, policy=policy)
16-
callbacks = [ProgressBarCallback(2**16), logger]
17-
18-
policy = algo.learn(
19-
env, policy, total_timesteps=2**16, key=learn_key, callback=callbacks
15+
logger = LoggingCallback(
16+
[TensorBoardBackend(), ConsoleBackend(total_timesteps=2**16)],
17+
env=env,
18+
policy=policy,
2019
)
20+
21+
policy = algo.learn(env, policy, total_timesteps=2**16, key=learn_key, callback=logger)
2122
logger.close()

examples/gymnax_environment.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from jax import random as jr
33

44
from lerax.algorithm import PPO
5-
from lerax.callback import LoggingCallback, ProgressBarCallback, TensorBoardBackend
5+
from lerax.callback import ConsoleBackend, LoggingCallback, TensorBoardBackend
66
from lerax.compatibility.gymnax import GymnaxToLeraxEnv
77
from lerax.policy import MLPActorCriticPolicy
88

@@ -13,10 +13,11 @@
1313

1414
policy = MLPActorCriticPolicy(env=env, key=policy_key)
1515
algo = PPO()
16-
logger = LoggingCallback(TensorBoardBackend(), env=env, policy=policy)
17-
callbacks = [ProgressBarCallback(2**16), logger]
18-
19-
policy = algo.learn(
20-
env, policy, total_timesteps=2**16, key=learn_key, callback=callbacks
16+
logger = LoggingCallback(
17+
[TensorBoardBackend(), ConsoleBackend(total_timesteps=2**16)],
18+
env=env,
19+
policy=policy,
2120
)
21+
22+
policy = algo.learn(env, policy, total_timesteps=2**16, key=learn_key, callback=logger)
2223
logger.close()

examples/ppo.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from jax import random as jr
22

33
from lerax.algorithm import PPO
4-
from lerax.callback import LoggingCallback, ProgressBarCallback, WandbBackend
4+
from lerax.callback import ConsoleBackend, LoggingCallback, TensorBoardBackend
55
from lerax.env.classic_control import CartPole
66
from lerax.policy import MLPActorCriticPolicy
77

@@ -10,12 +10,15 @@
1010
env = CartPole()
1111
policy = MLPActorCriticPolicy(env=env, key=policy_key)
1212
algo = PPO()
13+
total_timesteps = 2**16
1314
logger = LoggingCallback(
14-
WandbBackend(project="lerax"), env=env, policy=policy, video_interval=1
15+
[TensorBoardBackend(), ConsoleBackend(total_timesteps=total_timesteps)],
16+
env=env,
17+
policy=policy,
18+
video_interval=1,
1519
)
16-
callbacks = [ProgressBarCallback(2**16), logger]
1720

1821
policy = algo.learn(
19-
env, policy, total_timesteps=2**16, key=learn_key, callback=callbacks
22+
env, policy, total_timesteps=total_timesteps, key=learn_key, callback=logger
2023
)
2124
logger.close()

src/lerax/callback/logging/callback.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import concurrent.futures
44
import dataclasses
55
import os
6-
from collections.abc import Callable
6+
from collections.abc import Callable, Sequence
77
from datetime import datetime
88
from functools import partial
99
from typing import Any
@@ -205,7 +205,7 @@ def _make_video_recorder(
205205
video_width: int,
206206
video_height: int,
207207
video_fps: float,
208-
backend: AbstractLoggingBackend,
208+
backends: list[AbstractLoggingBackend],
209209
executor: concurrent.futures.ThreadPoolExecutor,
210210
) -> Callable[..., None]:
211211
"""
@@ -224,7 +224,7 @@ def _make_video_recorder(
224224
video_width: Render width in pixels.
225225
video_height: Render height in pixels.
226226
video_fps: Playback frames per second.
227-
backend: Logging backend to forward video frames to.
227+
backends: Logging backends to forward video frames to.
228228
executor: Thread pool to run the recording work in.
229229
"""
230230

@@ -349,7 +349,8 @@ def render_frame(env_state) -> np.ndarray:
349349
renderer.close()
350350

351351
frames_arr = np.stack(frames).astype(np.uint8)
352-
backend.log_video("eval/video", frames_arr, step, fps=video_fps)
352+
for b in backends:
353+
b.log_video("eval/video", frames_arr, step, fps=video_fps)
353354
except Exception as exc:
354355
import warnings
355356

@@ -407,7 +408,7 @@ class LoggingCallback(AbstractCallback[EmptyCallbackState, LoggingCallbackStepSt
407408
alpha: EMA smoothing factor for episode statistics.
408409
409410
Args:
410-
backend: Logging backend to send metrics to.
411+
backend: Logging backend, or a sequence of backends, to send metrics to.
411412
name: Explicit run name. When ``None``, a name is generated from the
412413
environment name, policy name, and a timestamp. If neither ``env``
413414
nor ``policy`` are provided, falls back to a plain timestamp.
@@ -423,7 +424,7 @@ class LoggingCallback(AbstractCallback[EmptyCallbackState, LoggingCallbackStepSt
423424
video_fps: Playback frames per second.
424425
"""
425426

426-
_backend: AbstractLoggingBackend = eqx.field(static=True)
427+
_backends: list[AbstractLoggingBackend] = eqx.field(static=True)
427428
_name: str | None = eqx.field(static=True)
428429
_hparams: dict[str, Any] | None = eqx.field(static=True)
429430
alpha: float
@@ -434,7 +435,7 @@ class LoggingCallback(AbstractCallback[EmptyCallbackState, LoggingCallbackStepSt
434435

435436
def __init__(
436437
self,
437-
backend: AbstractLoggingBackend,
438+
backend: AbstractLoggingBackend | Sequence[AbstractLoggingBackend],
438439
name: str | None = None,
439440
env: AbstractEnvLike | None = None,
440441
policy: AbstractPolicy | None = None,
@@ -446,7 +447,11 @@ def __init__(
446447
video_height: int = 480,
447448
video_fps: float = 50.0,
448449
) -> None:
449-
self._backend = backend
450+
if isinstance(backend, AbstractLoggingBackend):
451+
self._backends = [backend]
452+
else:
453+
self._backends = list(backend)
454+
450455
self._hparams = hparams
451456
self.alpha = alpha
452457

@@ -461,7 +466,8 @@ def __init__(
461466
name = "_".join(parts)
462467
self._name = name
463468

464-
backend.open(name)
469+
for b in self._backends:
470+
b.open(name)
465471

466472
if video_interval > 0:
467473
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
@@ -472,7 +478,7 @@ def __init__(
472478
video_width,
473479
video_height,
474480
video_fps,
475-
backend,
481+
self._backends,
476482
executor,
477483
)
478484
else:
@@ -513,9 +519,8 @@ def on_iteration(
513519
scalars["episode/return"] = step_state.average_return.mean()
514520
scalars["episode/length"] = step_state.average_length.mean()
515521

516-
callback_with_numpy_wrapper(self._backend.log_scalars, ordered=True)(
517-
scalars, last_step
518-
)
522+
for b in self._backends:
523+
callback_with_numpy_wrapper(b.log_scalars, ordered=True)(scalars, last_step)
519524

520525
if self._record_video_fn is not None:
521526
video_key, key = jr.split(key)
@@ -537,7 +542,8 @@ def on_training_start(
537542
)
538543
hparams.update(self._hparams or {})
539544

540-
callback_wrapper(lambda: self._backend.log_hparams(hparams), ordered=True)()
545+
for b in self._backends:
546+
callback_wrapper(lambda b=b: b.log_hparams(hparams), ordered=True)()
541547
return ctx.state
542548

543549
def on_training_end(
@@ -548,13 +554,14 @@ def on_training_end(
548554
def close(self) -> None:
549555
"""Flush pending data and release backend resources.
550556
551-
Call this after all ``learn()`` calls are complete. The backend
552-
remains open between ``learn()`` calls so that metrics from
557+
Call this after all ``learn()`` calls are complete. The backends
558+
remain open between ``learn()`` calls so that metrics from
553559
multiple stages are logged to the same run.
554560
"""
555561
if self._video_executor is not None:
556562
self._video_executor.shutdown(wait=True)
557-
self._backend.close()
563+
for b in self._backends:
564+
b.close()
558565

559566
def continue_training(
560567
self, ctx: IterationContext, *, key: Key[Array, ""]

0 commit comments

Comments
 (0)