Status: Open
Labels: enhancement (New feature or request)
Description of feature
Once we have merged #228, #235, #239, and #240, we can start with the torch backend implementation. Therefore:
- we will keep one `CellFlow` class for both backends
- we will have a `src/backends/jax` and a `src/backends/torch` directory
- `CellFlow.prepare_model` will have an argument `backend: Literal["jax", "torch"]`
- for now, let's not implement GENOT for torch; let's just go with `OTFlowMatching` to keep things simple
- the only problem that arises is that `prepare_model` takes backend-specific arguments, namely `match_fn`, `optimizer`, and `vf_act_fn`.
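A minimal, pure-Python sketch of how a single `CellFlow` front-end could dispatch on the `backend` argument. Only `CellFlow`, `prepare_model`, and `backend: Literal["jax", "torch"]` come from this issue; the `solver` argument, its `"otfm"`/`"genot"` values, and the `SUPPORTED_SOLVERS` registry are hypothetical, and no real backend modules are imported.

```python
from __future__ import annotations

from typing import Literal, Optional, get_args

Backend = Literal["jax", "torch"]

# Hypothetical registry of solvers each backend supports. Following the plan
# above, torch only offers OTFlowMatching ("otfm") for now, no GENOT.
SUPPORTED_SOLVERS = {
    "jax": ("otfm", "genot"),
    "torch": ("otfm",),
}


class CellFlow:
    """One front-end class shared by both backends (illustrative only)."""

    def __init__(self) -> None:
        self.backend: Optional[str] = None
        self.solver: Optional[str] = None

    def prepare_model(self, solver: str = "otfm", backend: Backend = "jax") -> None:
        allowed = get_args(Backend)  # the Literal's values: ("jax", "torch")
        if backend not in allowed:
            raise ValueError(f"unknown backend {backend!r}; expected one of {allowed}")
        if solver not in SUPPORTED_SOLVERS[backend]:
            raise NotImplementedError(
                f"solver {solver!r} is not implemented for backend {backend!r}"
            )
        # A real implementation would import the model and solver from
        # src/backends/<backend> at this point; here we only record the choice.
        self.backend = backend
        self.solver = solver
```

Raising `NotImplementedError` for a GENOT + torch request keeps the unsupported combination explicit instead of silently falling back to jax.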
For the last point, I see the following as the best solution:
allow passing both jax and torch instances, with `None` as the default; describe the default in the docs; and instantiate the defaults later, in the solver classes.
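A minimal sketch of the "default to `None`, resolve in the solver" idea, assuming a hypothetical `OTFlowMatchingSolver` dataclass and a hypothetical `DEFAULTS` table. The strings are placeholders for real backend objects (for example an optax optimizer for jax or a `torch.optim` optimizer for torch), so the sketch runs without either library and does not claim what the actual defaults are.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any

# Hypothetical placeholders for per-backend defaults; real code would build
# backend-specific objects here instead of strings.
DEFAULTS = {
    "jax": {
        "match_fn": "<jax default match_fn>",
        "optimizer": "<jax default optimizer>",
        "vf_act_fn": "<jax default vf_act_fn>",
    },
    "torch": {
        "match_fn": "<torch default match_fn>",
        "optimizer": "<torch default optimizer>",
        "vf_act_fn": "<torch default vf_act_fn>",
    },
}


@dataclass
class OTFlowMatchingSolver:
    """Hypothetical solver that resolves None arguments to backend defaults."""

    backend: str
    match_fn: Any = None
    optimizer: Any = None
    vf_act_fn: Any = None

    def __post_init__(self) -> None:
        defaults = DEFAULTS[self.backend]
        # Fill in only the arguments the user left as None; objects passed in
        # explicitly (jax or torch instances) are kept untouched.
        for name in ("match_fn", "optimizer", "vf_act_fn"):
            if getattr(self, name) is None:
                setattr(self, name, defaults[name])
```

With this split, `prepare_model` can stay backend-agnostic and simply forward `match_fn`, `optimizer`, and `vf_act_fn`; only the solver for the chosen backend needs to know how to build its defaults.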