-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_state.py
More file actions
23 lines (17 loc) · 893 Bytes
/
train_state.py
File metadata and controls
23 lines (17 loc) · 893 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from typing import Callable, Any
from flax import core, struct
from flax.training import train_state
import jmp
class TrainState(train_state.TrainState):
    """Extension of `flax.training.train_state.TrainState` with additional fields.

    The subclass only declares extra dataclass fields; it adds no methods and
    overrides nothing from the base class. None of the extra fields has a
    default value, so all of them must be supplied when a state is built.

    Fields declared with ``pytree_node=True`` are part of the JAX pytree
    (their leaves are traced/transformed); fields declared with
    ``pytree_node=False`` are carried as static metadata instead.
    """

    # Fields inherited from the base class (listed for reference only):
    #   step: int | jax.Array
    #   apply_fn: Callable = struct.field(pytree_node=False)
    #   params: core.FrozenDict[str, Any] = struct.field(pytree_node=True)
    #   tx: optax.GradientTransformation = struct.field(pytree_node=False)
    #   opt_state: optax.OptState = struct.field(pytree_node=True)

    # Additional callable kept alongside the inherited `apply_fn`; static.
    # NOTE(review): its exact role is not visible in this file -- confirm
    # against the call sites before relying on it.
    field_apply_fn: Callable = struct.field(pytree_node=False)

    # Extra collection stored next to `params` as pytree leaves.
    # Presumably the model's batch-norm running statistics (going by the
    # name) -- TODO confirm.
    batch_stats: core.FrozenDict[str, Any] = struct.field(pytree_node=True)

    # `jmp` mixed-precision policy object; static metadata.
    mp_policy: jmp.Policy = struct.field(pytree_node=False)

    # `jmp` dynamic loss-scale object; stored as pytree leaves.
    loss_scale: jmp.DynamicLossScale = struct.field(pytree_node=True)

    # String identifier; static metadata.
    # Presumably the Weights & Biases run id (going by the name) -- TODO confirm.
    wandb_id: str = struct.field(pytree_node=False)