Fox is an OCaml library for automatic differentiation and numerical computing, heavily inspired by JAX. It provides a flexible framework for automatically differentiating tensor-based computations and JIT-compiling them with XLA.
Please note that this library is still under active development and lacks many features.
- Automatic Differentiation: Support for forward-mode and reverse-mode automatic differentiation
- Higher-Order Derivatives: Compute nth-order derivatives of functions
- JIT Compilation: Just-in-time compilation to XLA
- Tree-based Value Representation: Efficient handling of complex data structures through tree-based representations
Fox targets OxCaml: it uses OxCaml's effect-handler syntax and the modal Jane
Street libraries, and builds against a local opam switch on
ocaml-variants.5.2.0+ox with a committed lockfile (fox.opam.locked). It depends
on a pinned xla, which binds a prebuilt XLA extension blob.
Download the XLA extension (elixir-nx/xla v0.4.4) and point XLA_EXTENSION_DIR
at it. The C++ stubs need its headers at build time; the runtime rpath is
baked in, so no LD_LIBRARY_PATH is needed once built:
wget https://github.com/elixir-nx/xla/releases/download/v0.4.4/xla_extension-x86_64-linux-gnu-cpu.tar.gz
tar -xzf xla_extension-x86_64-linux-gnu-cpu.tar.gz # -> ./xla_extension
export XLA_EXTENSION_DIR=$PWD/xla_extension
(On platforms other than Linux x86_64, download the matching archive instead.)
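Since the build needs the extension's headers, it can help to check the directory before creating the switch. The following is a minimal sketch, not part of Fox: `check_xla_extension_dir` is a hypothetical helper, and it assumes the unpacked archive contains `include/` and `lib/` subdirectories.

```shell
# Hypothetical helper (not shipped with Fox): checks that XLA_EXTENSION_DIR,
# or a directory passed as the first argument, looks like an unpacked archive.
# Assumption: the archive unpacks to a directory holding include/ and lib/.
check_xla_extension_dir() {
  dir="${1:-${XLA_EXTENSION_DIR:-}}"
  if [ -z "$dir" ]; then
    echo "XLA_EXTENSION_DIR is not set" >&2
    return 1
  fi
  for sub in include lib; do
    if [ ! -d "$dir/$sub" ]; then
      echo "missing directory: $dir/$sub" >&2
      return 1
    fi
  done
  echo "ok: $dir"
}
```

Running `check_xla_extension_dir` with no argument checks the exported variable. A non-zero exit status means the variable is unset or the archive was unpacked somewhere else.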
Create the switch from the lockfile:
opam switch create . 5.2.0+ox \
--repos ox=git+https://github.com/oxcaml/opam-repository.git,default \
--locked
Build and run the tests:
XLA_EXTENSION_DIR=$PWD/xla_extension dune build @default @runtest
After changing dependencies, refresh fox.opam.locked:
XLA_EXTENSION_DIR=$PWD/xla_extension opam install . --deps-only --with-test
opam lock .
Planned work:
- Support non-singleton tensors in vjp
- Better shape and type story
- Basic tensor operations for simple neural network example
  - 2D matmuls
  - sum
  - random other things
- Print XLA HLO module
- Add JIT caching story
- PyTorch backend
- Testing framework for diffing op backends (XLA, PyTorch, OCaml)
  - Quickcheck generators
  - Diff XLA backend and OCaml
  - Diff OCaml/XLA vjp and PyTorch
- Support various types (ints, bools, etc.)
- Shape inference?
- Mixed-precision support
- Custom operator support