On this line it looks like `dmgas_dt` is being computed via finite differences. But @alexalar: since `mgas` is computed differentiably, shouldn't it be possible to define `dmgas_dt` using `jax.grad` instead?
This is currently the only use of the `_jax_get_dt_array` function in the repo.
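For illustration, here is a rough sketch of the idea, assuming `mgas` can be written as a scalar function of a single time `t`. `mgas_hypothetical` is a made-up exponential-decay stand-in, not the repo's function, and `dmgas_dt_finite_diff` only mimics the general finite-difference approach, not `_jax_get_dt_array` itself. Because `jax.grad` requires a scalar output, the autodiff version differentiates with respect to one scalar `t` and uses `jax.vmap` to map over the time array.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for the repo's mgas(t, ...): a smooth gas-mass
# history mgas(t) = m0 * exp(-t / tau). Not the real model.
def mgas_hypothetical(t, m0=1.0e10, tau=2.0):
    return m0 * jnp.exp(-t / tau)


def dmgas_dt_finite_diff(t_arr):
    # Forward differences on the time grid (returns one fewer point).
    # Illustrates the finite-difference approach only.
    m = mgas_hypothetical(t_arr)
    return jnp.diff(m) / jnp.diff(t_arr)


# jax.grad differentiates w.r.t. the first positional argument (t) and
# needs a scalar output, so vmap it over an array of times.
dmgas_dt_autodiff = jax.vmap(jax.grad(mgas_hypothetical))

t = jnp.linspace(0.1, 10.0, 50)
exact = -mgas_hypothetical(t) / 2.0  # analytic derivative for tau = 2
autodiff = dmgas_dt_autodiff(t)
finite_diff = dmgas_dt_finite_diff(t)
```

The autodiff derivative matches the analytic one to floating-point precision at every grid point, whereas the forward difference carries a truncation error that grows with the grid spacing and drops a point at the end, which also removes the need for a dt-array helper.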