Thanks for building this! It looks very promising, provided it turns out to be performant and reliable. Would love to see support for multiple arguments in the input function. In JAX this is handled by the `argnums` argument, which takes an index or a sequence of indices specifying which positional arguments to differentiate with respect to.
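For reference, here is a minimal sketch of how `argnums` behaves in JAX itself. This is JAX's `jax.grad`, not this library's API, and the three-argument function `f` is just a hypothetical example:

```python
import jax
import jax.numpy as jnp

def f(x, y, z):
    # Hypothetical scalar-valued function of three positional arguments.
    return x * y + jnp.sin(z)

# A single index: gradient with respect to x only (argnums=0 is the default).
df_dx = jax.grad(f, argnums=0)

# A sequence of indices: gradients with respect to x and z, returned as a tuple.
df_dxz = jax.grad(f, argnums=(0, 2))

gx = df_dx(2.0, 3.0, 0.0)         # df/dx = y
gx2, gz = df_dxz(2.0, 3.0, 0.0)   # (df/dx, df/dz) = (y, cos(z))
```

A single integer gives back one gradient, while a tuple of indices gives back a tuple of gradients in the same order. Something along those lines here would cover most multi-argument use cases.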