diff --git a/crazyflow/control/core.py b/crazyflow/control/core.py
index fde4393..bf53bc2 100644
--- a/crazyflow/control/core.py
+++ b/crazyflow/control/core.py
@@ -34,18 +34,16 @@ def parametrize(
device: The device to use. If None, the device is inferred from the xp module.
Example:
- ```python
- import numpy as np
- from crazyflow.control import parametrize
- from crazyflow.control.mellinger import state2attitude
-
- ctrl = parametrize(state2attitude, "cf2x_L250")
- pos = np.zeros(3)
- quat = np.array([0.0, 0.0, 0.0, 1.0])
- vel = np.zeros(3)
- cmd = np.zeros(13)
- rpyt, int_pos_err = ctrl(pos, quat, vel, cmd)
- ```
+ ```python
+ import numpy as np
+ from crazyflow.control import parametrize
+ from crazyflow.control.mellinger import state2attitude
+
+ ctrl = parametrize(state2attitude, "cf2x_L250")
+ pos, quat = np.zeros(3), np.array([0.0, 0.0, 0.0, 1.0])
+ vel, cmd = np.zeros(3), np.zeros(13)
+ rpyt, int_pos_err = ctrl(pos, quat, vel, cmd)
+ ```
Returns:
The parametrized controller function with all keyword argument only parameters filled in.
diff --git a/crazyflow/dynamics/__init__.py b/crazyflow/dynamics/__init__.py
index b8426d3..bd90eed 100644
--- a/crazyflow/dynamics/__init__.py
+++ b/crazyflow/dynamics/__init__.py
@@ -48,12 +48,12 @@ def dynamics_features(dynamics: Callable) -> dict[str, bool]:
``ValueError``.
Example:
- ```python
- from crazyflow.dynamics import dynamics_features
- from crazyflow.dynamics.first_principles import dynamics
+ ```python
+ from crazyflow.dynamics import dynamics_features
+ from crazyflow.dynamics.first_principles import dynamics
- dynamics_features(dynamics) # {'rotor_dynamics': True}
- ```
+ dynamics_features(dynamics) # {'rotor_dynamics': True}
+ ```
"""
if hasattr(dynamics, "func"): # Is a partial function
return dynamics_features(dynamics.func)
diff --git a/crazyflow/dynamics/core.py b/crazyflow/dynamics/core.py
index 31c22ae..20127c2 100644
--- a/crazyflow/dynamics/core.py
+++ b/crazyflow/dynamics/core.py
@@ -55,15 +55,19 @@ def parametrize(
device: The device to use. If none, the device is inferred from the xp module.
Example:
- ```{ .python notest }
- from crazyflow.dynamics.core import parametrize
- from crazyflow.dynamics.first_principles import dynamics
-
- dynamics_fn = parametrize(dynamics, drone="cf2x_L250")
- pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot = dynamics_fn(
- pos=pos, quat=quat, vel=vel, ang_vel=ang_vel, cmd=cmd, rotor_vel=rotor_vel
- )
- ```
+ ```python
+ import numpy as np
+ from crazyflow.dynamics.core import parametrize
+ from crazyflow.dynamics.first_principles import dynamics
+
+ dynamics_fn = parametrize(dynamics, drone="cf2x_L250")
+ pos, quat = np.zeros(3), np.array([0.0, 0.0, 0.0, 1.0])
+ vel, ang_vel = np.zeros(3), np.zeros(3)
+ rotor_vel, cmd = np.zeros(4), np.zeros(4)
+ pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot = dynamics_fn(
+ pos=pos, quat=quat, vel=vel, ang_vel=ang_vel, cmd=cmd, rotor_vel=rotor_vel
+ )
+ ```
Returns:
The parametrized dynamics function with all keyword argument only parameters filled in.
diff --git a/docs/examples/index.md b/docs/examples/index.md
index f4baaf7..ec3e825 100644
--- a/docs/examples/index.md
+++ b/docs/examples/index.md
@@ -8,6 +8,7 @@ These runnable examples cover control, JAX transformations, pipeline extensions,
A single drone commanded to hold a fixed height using state control. This is the minimal end-to-end loop: create a `Sim`, reset it, apply a state command, and step forward.
+
```{ .python notest }
--8<-- "examples/control/hover.py"
```
@@ -22,6 +23,7 @@ python examples/control/hover.py
Commanding roll, pitch, yaw, and collective thrust directly. This level bypasses the Mellinger position loop and is typical for RL agents that output attitude targets.
+
```{ .python notest }
--8<-- "examples/control/attitude.py"
```
@@ -42,6 +44,7 @@ python examples/control/sampling.py
Because the simulator is built entirely from JAX operations, `jax.grad` can differentiate through it. Starting the drone above the target height keeps it away from the floor, so the floor-clipping stage never fires and gradients flow freely through the entire trajectory.
+
```{ .python notest }
--8<-- "examples/jax/gradient.py"
```
@@ -52,6 +55,7 @@ Because the simulator is built entirely from JAX operations, `jax.grad` can diff
Randomizing mass and inertia through the reset pipeline. An optional mask limits randomization to selected worlds.
+
```{ .python notest }
--8<-- "examples/plugins/randomize.py"
```
@@ -66,6 +70,7 @@ python examples/plugins/randomize.py
Inserting a random external force and torque into the step pipeline. The disturbance fires on every dynamics tick, so the drone fights wind-like perturbations.
+
```{ .python notest }
--8<-- "examples/plugins/disturbance.py"
```
@@ -80,6 +85,7 @@ Offscreen rendering returns RGB and depth images on every frame. The FPV camera
+
```{ .python notest }
--8<-- "examples/rendering/cameras.py"
```
@@ -98,6 +104,7 @@ python examples/rendering/cameras.py
+
```{ .python notest }
--8<-- "examples/rendering/led_deck.py"
```
@@ -121,6 +128,7 @@ The default collision geometry is a sphere around the drone frame. `use_box_coll
+
```{ .python notest }
--8<-- "examples/contacts/contacts.py"
```
@@ -131,6 +139,7 @@ The default collision geometry is a sphere around the drone frame. `use_box_coll
`render_depth` fires rays from a camera and returns per-pixel distances. This is faster than full RGB rendering and useful for obstacle sensing or depth-based controllers.
+
```{ .python notest }
--8<-- "examples/rendering/raycasting.py"
```
@@ -145,6 +154,7 @@ python examples/rendering/raycasting.py
Evaluating a random policy in the figure-8 environment. The env wraps `Sim` behind the standard Gymnasium `VectorEnv` interface.
+
```{ .python notest }
--8<-- "examples/environments/figure8.py"
```
diff --git a/docs/user-guide/control/parametrize.md b/docs/user-guide/control/parametrize.md
index d9a1220..4742a7a 100644
--- a/docs/user-guide/control/parametrize.md
+++ b/docs/user-guide/control/parametrize.md
@@ -91,6 +91,7 @@ float(params["mass"]) # 0.029
By default parameters are stored as NumPy arrays. Pass `xp` to convert them upfront, which avoids per-call conversion overhead in frameworks like PyTorch or JAX:
+
```{ .python notest }
import torch
from crazyflow.control import parametrize
diff --git a/docs/user-guide/dynamics/parametrize.md b/docs/user-guide/dynamics/parametrize.md
index 519769c..59e549f 100644
--- a/docs/user-guide/dynamics/parametrize.md
+++ b/docs/user-guide/dynamics/parametrize.md
@@ -37,6 +37,7 @@ If your drone is not listed, you can identify the parameters from flight data us
By default, `parametrize` stores parameters as NumPy arrays. For frameworks that would otherwise need to convert those arrays on every call — such as PyTorch, where NumPy arrays must become tensors — passing `xp` converts the parameters upfront. The backend of the outputs is always inferred from whatever arrays you pass in at call time.
+
```{ .python notest }
import torch
import jax.numpy as jnp
@@ -49,6 +50,7 @@ dynamics_jax = parametrize(dynamics, drone="cf2x_L250", xp=jnp)
You can also specify a compute device — for example, to move JAX parameters to GPU at construction time:
+
```{ .python notest }
import jax
import jax.numpy as jnp
diff --git a/docs/user-guide/dynamics/system-identification.md b/docs/user-guide/dynamics/system-identification.md
index 7818ecd..4a46bff 100644
--- a/docs/user-guide/dynamics/system-identification.md
+++ b/docs/user-guide/dynamics/system-identification.md
@@ -20,6 +20,7 @@ After `preprocessing` + [`derivatives_svf`][crazyflow.dynamics.utils.data_utils.
## Full pipeline
+
```{ .python notest }
from crazyflow.dynamics.utils.data_utils import preprocessing, derivatives_svf
from crazyflow.dynamics.utils.identification import sys_id_translation, sys_id_rotation
@@ -61,6 +62,7 @@ See [`sys_id_translation`][crazyflow.dynamics.utils.identification.sys_id_transl
To check that the identified parameters generalise to unseen flight regimes, collect a second dataset of different trajectories and pass it as `data_validation`. RMSE and R² are then reported on both the training data and the validation data.
+
```{ .python notest }
# Preprocess the validation dataset independently — it must come from
# different trajectories, not a split of the same recording.
@@ -97,6 +99,7 @@ cmd_rpy_coef = [196.18, 196.18, 390.27] # from rot_params["cmd_rpy_coef"
Once the entry is in the TOML file, load the dynamics as usual:
+
```{ .python notest }
from crazyflow.dynamics import parametrize
from crazyflow.dynamics.so_rpy_rotor_drag import dynamics
diff --git a/docs/user-guide/functional-api.md b/docs/user-guide/functional-api.md
index dc43dd1..9fc6fb0 100644
--- a/docs/user-guide/functional-api.md
+++ b/docs/user-guide/functional-api.md
@@ -4,30 +4,6 @@ The object-oriented API is convenient for scripting, but it relies on Python-lev
The functional API addresses this by expressing the same operations as pure functions that take `SimData` and return updated `SimData`. There is no hidden state, so JAX can trace, compile, and differentiate through arbitrary compositions of these functions.
-## What does not work inside JAX transformations
-
-The object-oriented `Sim` methods mutate `sim.data` in place through Python calls. JAX cannot trace through Python-level state mutations, so these methods cannot be used inside `jax.jit`, `jax.grad`, or `jax.lax.scan`:
-
-```{ .python notest }
-import jax
-import jax.numpy as jnp
-from crazyflow.sim import Sim
-from crazyflow.control import Control
-
-sim = Sim(control=Control.attitude)
-sim.reset()
-
-@jax.jit
-def broken(cmd):
- sim.attitude_control(cmd) # mutates sim.data — JAX traces the ops but leaks the tracer
- sim.step(1)
- return sim.data.states.pos # sim.data now holds a leaked tracer; accessing it outside JIT raises UnexpectedTracerError
-```
-
-## What does work
-
-The purely functional counterpart passes `SimData` explicitly and returns updated `SimData`. Every operation is a plain JAX function with no Python-level mutation, so the full simulation pipeline is traceable by any JAX transformation:
-
```python
import jax
import jax.numpy as jnp
diff --git a/docs/user-guide/mujoco.md b/docs/user-guide/mujoco.md
index 34b0be1..5868437 100644
--- a/docs/user-guide/mujoco.md
+++ b/docs/user-guide/mujoco.md
@@ -75,6 +75,7 @@ sim.reset()
Loading from a file works identically:
+
```{ .python notest }
import mujoco
gate_spec = mujoco.MjSpec.from_file("assets/gate.xml")
@@ -100,6 +101,7 @@ After `sim.step()` or `sim.reset()`, `mjx_synced` is set to `False`. The `sim.re
These run only once per render or contact call, regardless of how many dynamics steps were taken since the last sync.
+
```{ .python notest }
for i in range(10):
sim.step(5) # JAX dynamics only, mjx_synced = False
@@ -113,6 +115,7 @@ for i in range(10):
This means the order of calls matters. Grouping all rendering and contact queries together after a step lets them share a single sync:
+
```{ .python notest }
sim.step(5)
contacts = sim.contacts() # sync runs here
@@ -121,6 +124,7 @@ sim.render(mode="rgb_array") # flag already set, no second sync
Interleaving a step between them forces two syncs:
+
```{ .python notest }
contacts = sim.contacts() # sync runs here
sim.step(5) # flag cleared
@@ -135,12 +139,16 @@ The solution is to **close over** `mjx_data` rather than pass it as an argument.
The drone racing environment in [lsy_drone_racing](https://github.com/learnsyslab/lsy_drone_racing) uses this pattern to build a contact check function:
-```{ .python notest }
+```python
from jax import Array
+from crazyflow.sim import Sim
from crazyflow.sim.sim import sync_sim2mjx
from crazyflow.sim.data import SimData
+sim = Sim(n_worlds=1, n_drones=1)
+sim.reset()
+
_mjx_data = sim.mjx_data # captured in closure
def check_contacts(sim_data: SimData, obstacle_mocap_pos: Array) -> Array:
@@ -148,6 +156,8 @@ def check_contacts(sim_data: SimData, obstacle_mocap_pos: Array) -> Array:
mjx_data = _mjx_data.replace(mocap_pos=obstacle_mocap_pos)
_, mjx_data = sync_sim2mjx(sim_data, mjx_data, sim.mjx_model)
return mjx_data._impl.contact.dist < 0
+
+in_contact = check_contacts(sim.data, sim.mjx_data.mocap_pos)
```
`_mjx_data` is fused into the closure and compiled as a constant. Only `sim_data` and the obstacle positions cross the JIT boundary at runtime — a much smaller pytree than passing the full `mjx_data`.
diff --git a/docs/user-guide/oo-api.md b/docs/user-guide/oo-api.md
index 207dbe3..d4e2af7 100644
--- a/docs/user-guide/oo-api.md
+++ b/docs/user-guide/oo-api.md
@@ -186,6 +186,7 @@ pos_w0_d1 = sim.data.states.pos[0, 1] # (3,)
`sim.render()` opens an interactive MuJoCo viewer or returns an image array for offscreen rendering.
+
```{ .python notest }
sim.render() # interactive window, world 0
sim.render(mode="rgb_array") # returns (H, W, 3) uint8
diff --git a/docs/user-guide/visualization.md b/docs/user-guide/visualization.md
index 9daec03..17fdb6c 100644
--- a/docs/user-guide/visualization.md
+++ b/docs/user-guide/visualization.md
@@ -21,6 +21,7 @@ Crazyflow supports onscreen interactive rendering and offscreen RGB/depth captur
| `"depth_array"` | `(H, W) float32` | Offscreen depth frame in metres |
| `"rgbd_tuple"` | `(rgb, depth)` | Both channels as a tuple |
+
```{ .python notest }
sim.render() # interactive window
rgb = sim.render(mode="rgb_array") # numpy array (H, W, 3)
@@ -33,6 +34,7 @@ sim.close() # close the viewer
Pass a camera name or integer ID to select which camera to render from. The default (`camera=-1`) uses the free camera. Each drone ships with a first-person view camera named `fpv_cam:`:
+
```{ .python notest }
sim.render(camera="fpv_cam:0") # first-person view from drone 0
sim.render(camera=0) # camera by integer ID
@@ -42,6 +44,7 @@ sim.render(camera=0) # camera by integer ID
For obstacle sensing or perception-based controllers, `render_depth` fires a ray from each camera pixel and returns per-pixel distances — faster than full RGB rendering because it skips lighting and colour computation:
+
```{ .python notest }
import jax.numpy as jnp
from crazyflow.sim.sensors import build_render_depth_fn, render_depth
@@ -64,6 +67,7 @@ dist = render_fn(sim)
`change_material` updates the RGBA colour and emission intensity of any named material on any subset of drones without rebuilding the model:
+
```{ .python notest }
import numpy as np
from crazyflow.sim.visualize import change_material
@@ -77,6 +81,7 @@ sim.render()
`sim.render()` always renders a single world at a time. Pass `world=` to choose which one:
+
```{ .python notest }
sim.render(world=0) # default
sim.render(world=3) # render world 3