Small fixes: checkpoint loading and ott-jax version bump #31
Conversation
jannisborn left a comment
Nice work! Maybe I'm just lacking the context, but I'm not sure we need this PR, because MLP still exists in current ott-jax. PotentialMLP behaves differently; it also models the potential.
```diff
  from ott.neural.methods.neuraldual import W2NeuralDual
  from ott.neural.networks.icnn import ICNN
- from ott.neural.networks.potentials import MLP
+ from ott.neural.networks.potentials import PotentialMLP
```
This is a breaking change (unless PotentialMLP is an alias of MLP in all older ott versions), so we need to bump the lower bound to a version that already has PotentialMLP.
Aah, I see what you mean: PotentialMLP is the alias of MLP in the older versions. I thought the lower bound was already ott-jax>=0.5.0, so that should be alright.
I don't think it's an alias, as the source code uses different classes and PotentialMLP seems to calculate some extra things. Do we need those?
The main reason why we need PotentialMLP rather than MLP is that we need the create_train_state function. In the newer version, MLP doesn't inherit from any model base class that defines this function, but PotentialMLP does. Beyond that, the MLP (whichever we choose) is also used in the NeuralDualTrainer, where these extra functionalities might be used.
So old version MLP (see here):
```python
class MLP(ModelBase):
  """A non-convex MLP.

  Args:
    dim_hidden: sequence specifying size of hidden dimensions. The output
      dimension of the last layer is automatically set to 1 if
      :attr:`is_potential` is ``True``, or the dimension of the input otherwise
    is_potential: Model the potential if ``True``, otherwise
      model the gradient of the potential
    act_fn: Activation function
  """

  dim_hidden: Sequence[int]
  is_potential: bool = True
  act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.leaky_relu

  @nn.compact
  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:  # noqa: D102
    squeeze = x.ndim == 1
    if squeeze:
      x = jnp.expand_dims(x, 0)
    assert x.ndim == 2, x.ndim
    n_input = x.shape[-1]

    z = x
    for n_hidden in self.dim_hidden:
      Wx = nn.Dense(n_hidden, use_bias=True)
      z = self.act_fn(Wx(z))

    if self.is_potential:
      Wx = nn.Dense(1, use_bias=True)
      z = Wx(z).squeeze(-1)

      quad_term = 0.5 * jax.vmap(jnp.dot)(x, x)
      z += quad_term
    else:
      Wx = nn.Dense(n_input, use_bias=True)
      z = x + Wx(z)

    return z.squeeze(0) if squeeze else z
```
Current version PotentialMLP (see here):
```python
class PotentialMLP(BasePotential):
  """Potential MLP.

  Args:
    dim_hidden: Sequence specifying the size of hidden dimensions. The output
      dimension of the last layer is automatically set to 1 if
      :attr:`is_potential` is ``True``, or the dimension of the input otherwise.
    is_potential: Model the potential if ``True``, otherwise
      model the gradient of the potential.
    act_fn: Activation function.
  """

  dim_hidden: Sequence[int]
  is_potential: bool = True
  act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.leaky_relu

  @nn.compact
  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:  # noqa: D102
    squeeze = x.ndim == 1
    if squeeze:
      x = jnp.expand_dims(x, 0)
    assert x.ndim == 2, x.ndim
    n_input = x.shape[-1]

    z = x
    for n_hidden in self.dim_hidden:
      Wx = nn.Dense(n_hidden, use_bias=True)
      z = self.act_fn(Wx(z))

    if self.is_potential:
      Wx = nn.Dense(1, use_bias=True)
      z = Wx(z).squeeze(-1)

      quad_term = 0.5 * jax.vmap(jnp.dot)(x, x)
      z += quad_term
    else:
      Wx = nn.Dense(n_input, use_bias=True)
      z = x + Wx(z)

    return z.squeeze(0) if squeeze else z
```
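For intuition, the forward logic shared by both classes can be sketched in plain NumPy (a simplified sketch with explicit weight lists, not the actual flax module): with is_potential=True the network outputs one scalar per sample plus the quadratic term 0.5 * ||x||^2, otherwise it models the gradient of the potential via a residual connection.

```python
import numpy as np

def mlp_forward(x, weights, biases, is_potential=True,
                act=lambda z: np.where(z > 0, z, 0.01 * z)):
    """Sketch of the MLP/PotentialMLP forward pass (leaky-ReLU activation)."""
    squeeze = x.ndim == 1
    if squeeze:
        x = x[None, :]
    z = x
    # hidden layers: z = act(W z + b)
    for W, b in zip(weights[:-1], biases[:-1]):
        z = act(z @ W + b)
    if is_potential:
        # final layer maps to a scalar; add the quadratic term 0.5 * ||x||^2
        z = (z @ weights[-1] + biases[-1]).squeeze(-1)
        z = z + 0.5 * np.einsum("bi,bi->b", x, x)
    else:
        # residual connection: the network models the gradient of the potential
        z = x + (z @ weights[-1] + biases[-1])
    return z.squeeze(0) if squeeze else z

rng = np.random.default_rng(0)
dim_in, dim_hidden = 3, 8
# potential head: the last layer outputs a single value
Ws = [rng.normal(size=(dim_in, dim_hidden)), rng.normal(size=(dim_hidden, 1))]
bs = [np.zeros(dim_hidden), np.zeros(1)]
x = rng.normal(size=(5, dim_in))
out = mlp_forward(x, Ws, bs, is_potential=True)
print(out.shape)  # (5,): one scalar potential value per sample
```

Since the two `__call__` bodies are identical, the only behavioral difference between MLP and PotentialMLP in the snippets above is the base class they inherit from.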
Alright, thanks for digging and giving the context! This means that #23 silently introduced a bug and our current PyPI release is unstable. So let's bump the version of cmonge in this PR. Also, it would be great to add a small test that tries to use the create_train_state function, basically a test that would have failed when trying to merge #23.
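A minimal version of such a regression test could look like this (a sketch only; the import path and constructor arguments are taken from the snippets above and may need adjusting to the pinned ott-jax version):

```python
# Guards against the regression from #23: the class cmonge uses must
# provide create_train_state, which plain MLP in newer ott-jax does not.
try:
    from ott.neural.networks.potentials import PotentialMLP
except ImportError:  # ott-jax not installed or older than the pinned bound
    PotentialMLP = None

def test_potential_mlp_has_create_train_state():
    if PotentialMLP is None:
        return  # effectively skipped when ott-jax is unavailable
    mlp = PotentialMLP(dim_hidden=[16, 16])
    assert hasattr(mlp, "create_train_state")

test_potential_mlp_has_create_train_state()
```

A stronger variant would actually call create_train_state with an optimizer and input dimension and assert the returned state has parameters, but even the attribute check above would have failed on #23.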
Co-authored-by: Jannis Born <jannis.born@gmx.de>
…e into ot_trainer-monge
Tested only the conditional Monge after the version bump. Now also some minor changes to the Monge gap to ensure checkpoint loading and compatibility with the ott-jax version bump.