Currently, the `StatefulTrainer` cannot use [GradientTransformationExtraArgs](https://optax.readthedocs.io/en/latest/api/transformations.html#optax.GradientTransformationExtraArgs). Support could be added fairly easily with an extra keyword argument whose contents are forwarded to the optimizer's `update` call.