
Small fixes: checkpoint loading and ott-jax version bump #31

Merged
jannisborn merged 9 commits into main from ot_trainer-monge on Aug 31, 2025
Conversation

@DriessenA (Collaborator)

Tested only the conditional Monge trainer after the version bump. This PR now also makes some minor changes to the Monge gap trainer to ensure checkpoint loading and compatibility with the ott-jax version bump.

@DriessenA DriessenA requested a review from jannisborn July 29, 2025 17:21
@jannisborn (Collaborator) left a comment

Nice work! Maybe I'm just lacking the context, but I'm not sure we need this PR, because MLP still exists in current ott-jax. PotentialMLP behaves differently; it also models the potential.

from ott.neural.methods.neuraldual import W2NeuralDual
from ott.neural.networks.icnn import ICNN
from ott.neural.networks.potentials import MLP
from ott.neural.networks.potentials import PotentialMLP
@jannisborn (Collaborator)

This is a breaking change (unless PotentialMLP is an alias of MLP in all older ott versions), so we need to bump the lower bound to a version that already has PotentialMLP.
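If one did not want to rely on the version bound alone, the rename could also be smoothed over at import time. A minimal sketch (the module path is taken from the imports above; the try/except fallback is an assumption, not code from this PR):

```python
# Version-tolerant import sketch: newer ott-jax exposes PotentialMLP,
# while some older releases only shipped the class under the name MLP.
try:
    from ott.neural.networks.potentials import PotentialMLP
except ImportError:
    try:
        # Older ott-jax layout where the class was still named MLP.
        from ott.neural.networks.potentials import MLP as PotentialMLP
    except ImportError:
        PotentialMLP = None  # ott-jax is not installed in this environment
```

Pinning the lower bound (as done in this PR) is the cleaner fix; the fallback above only helps a library that must support both layouts at once.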

@DriessenA (Collaborator, Author)

Ah, I see what you mean: PotentialMLP is the alias for MLP in the older versions. I thought the lower bound is already ott-jax==0.5.0, so that should be alright.

@jannisborn (Collaborator)

I don't think it's an alias, as the source code is using different classes, and PotentialMLP seems to calculate some extra things. Do we need those?

@DriessenA (Collaborator, Author)

The main reason we need PotentialMLP and not MLP is that we need it to have the create_train_state method. In the newer version, MLP doesn't inherit from any model base class that provides this method, but PotentialMLP does. Beyond that, the MLP (whichever we choose) is also used in the NeuralDualTrainer, where I guess these extra functionalities might be used.

So, the old-version MLP (see here):

class MLP(ModelBase):
  """A non-convex MLP.

  Args:
    dim_hidden: sequence specifying size of hidden dimensions. The output
      dimension of the last layer is automatically set to 1 if
      :attr:`is_potential` is ``True``, or the dimension of the input otherwise
    is_potential: Model the potential if ``True``, otherwise
      model the gradient of the potential
    act_fn: Activation function
  """

  dim_hidden: Sequence[int]
  is_potential: bool = True
  act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.leaky_relu

  @nn.compact
  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:  # noqa: D102
    squeeze = x.ndim == 1
    if squeeze:
      x = jnp.expand_dims(x, 0)
    assert x.ndim == 2, x.ndim
    n_input = x.shape[-1]

    z = x
    for n_hidden in self.dim_hidden:
      Wx = nn.Dense(n_hidden, use_bias=True)
      z = self.act_fn(Wx(z))

    if self.is_potential:
      Wx = nn.Dense(1, use_bias=True)
      z = Wx(z).squeeze(-1)

      quad_term = 0.5 * jax.vmap(jnp.dot)(x, x)
      z += quad_term
    else:
      Wx = nn.Dense(n_input, use_bias=True)
      z = x + Wx(z)

    return z.squeeze(0) if squeeze else z

Current version PotentialMLP (see here):

class PotentialMLP(BasePotential):
  """Potential MLP.

  Args:
    dim_hidden: Sequence specifying the size of hidden dimensions. The output
      dimension of the last layer is automatically set to 1 if
      :attr:`is_potential` is ``True``, or the dimension of the input otherwise.
    is_potential: Model the potential if ``True``, otherwise
      model the gradient of the potential.
    act_fn: Activation function.
  """

  dim_hidden: Sequence[int]
  is_potential: bool = True
  act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.leaky_relu

  @nn.compact
  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:  # noqa: D102
    squeeze = x.ndim == 1
    if squeeze:
      x = jnp.expand_dims(x, 0)
    assert x.ndim == 2, x.ndim
    n_input = x.shape[-1]

    z = x
    for n_hidden in self.dim_hidden:
      Wx = nn.Dense(n_hidden, use_bias=True)
      z = self.act_fn(Wx(z))

    if self.is_potential:
      Wx = nn.Dense(1, use_bias=True)
      z = Wx(z).squeeze(-1)

      quad_term = 0.5 * jax.vmap(jnp.dot)(x, x)
      z += quad_term
    else:
      Wx = nn.Dense(n_input, use_bias=True)
      z = x + Wx(z)

    return z.squeeze(0) if squeeze else z

@jannisborn (Collaborator)

Alright, thanks for digging and giving the context! This means that #23 silently introduced a bug and our current PyPI release is unstable. So let's bump the version of cmonge in this PR. Also, it would be great to add a small test that tries to use the create_train_state function: basically, a test that would have failed when trying to merge #23.
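A pure-Python stand-in for such a test (toy classes, not the real ott-jax ones) illustrates the failure mode the test should catch: create_train_state only exists when the network inherits from a base class that defines it.

```python
# Toy stand-ins (NOT the real ott-jax classes) showing why the suggested
# regression test would have caught the bug in #23: only subclasses of the
# potential base class expose create_train_state.
class BasePotential:
    """Stand-in for ott's potential base class, which defines create_train_state."""

    def create_train_state(self, params):
        # The real method builds a flax train state; here we just echo inputs.
        return {"model": type(self).__name__, "params": params}


class PotentialMLP(BasePotential):
    """Stand-in: inherits create_train_state from the base class."""


class MLP:
    """Stand-in: no base class, hence no create_train_state (the broken path)."""


def test_create_train_state_exists():
    state = PotentialMLP().create_train_state({"lr": 1e-3})
    assert state["model"] == "PotentialMLP"
    # This is the attribute access that broke after switching to the bare MLP:
    assert not hasattr(MLP(), "create_train_state")


test_create_train_state_exists()
```

The real test in cmonge would instead instantiate the actual PotentialMLP and call its create_train_state, so it fails whenever an ott-jax bump removes or renames the method.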

@jannisborn jannisborn merged commit d903d23 into main Aug 31, 2025
1 check passed
@jannisborn jannisborn deleted the ot_trainer-monge branch August 31, 2025 19:48