diff --git a/crazyflow/control/core.py b/crazyflow/control/core.py index fde4393..bf53bc2 100644 --- a/crazyflow/control/core.py +++ b/crazyflow/control/core.py @@ -34,18 +34,16 @@ def parametrize( device: The device to use. If None, the device is inferred from the xp module. Example: - ```python - import numpy as np - from crazyflow.control import parametrize - from crazyflow.control.mellinger import state2attitude - - ctrl = parametrize(state2attitude, "cf2x_L250") - pos = np.zeros(3) - quat = np.array([0.0, 0.0, 0.0, 1.0]) - vel = np.zeros(3) - cmd = np.zeros(13) - rpyt, int_pos_err = ctrl(pos, quat, vel, cmd) - ``` + ```python + import numpy as np + from crazyflow.control import parametrize + from crazyflow.control.mellinger import state2attitude + + ctrl = parametrize(state2attitude, "cf2x_L250") + pos, quat = np.zeros(3), np.array([0.0, 0.0, 0.0, 1.0]) + vel, cmd = np.zeros(3), np.zeros(13) + rpyt, int_pos_err = ctrl(pos, quat, vel, cmd) + ``` Returns: The parametrized controller function with all keyword argument only parameters filled in. diff --git a/crazyflow/dynamics/__init__.py b/crazyflow/dynamics/__init__.py index b8426d3..bd90eed 100644 --- a/crazyflow/dynamics/__init__.py +++ b/crazyflow/dynamics/__init__.py @@ -48,12 +48,12 @@ def dynamics_features(dynamics: Callable) -> dict[str, bool]: ``ValueError``. Example: - ```python - from crazyflow.dynamics import dynamics_features - from crazyflow.dynamics.first_principles import dynamics + ```python + from crazyflow.dynamics import dynamics_features + from crazyflow.dynamics.first_principles import dynamics - dynamics_features(dynamics) # {'rotor_dynamics': True} - ``` + dynamics_features(dynamics) # {'rotor_dynamics': True} + ``` """ if hasattr(dynamics, "func"): # Is a partial function return dynamics_features(dynamics.func) diff --git a/crazyflow/dynamics/core.py b/crazyflow/dynamics/core.py index 31c22ae..20127c2 100644 --- a/crazyflow/dynamics/core.py +++ b/crazyflow/dynamics/core.py @@ -55,15 +55,19 @@ def parametrize( device: The device to use. If none, the device is inferred from the xp module. Example: - ```{ .python notest } - from crazyflow.dynamics.core import parametrize - from crazyflow.dynamics.first_principles import dynamics - - dynamics_fn = parametrize(dynamics, drone="cf2x_L250") - pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot = dynamics_fn( - pos=pos, quat=quat, vel=vel, ang_vel=ang_vel, cmd=cmd, rotor_vel=rotor_vel - ) - ``` + ```python + import numpy as np + from crazyflow.dynamics.core import parametrize + from crazyflow.dynamics.first_principles import dynamics + + dynamics_fn = parametrize(dynamics, drone="cf2x_L250") + pos, quat = np.zeros(3), np.array([0.0, 0.0, 0.0, 1.0]) + vel, ang_vel = np.zeros(3), np.zeros(3) + rotor_vel, cmd = np.zeros(4), np.zeros(4) + pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot = dynamics_fn( + pos=pos, quat=quat, vel=vel, ang_vel=ang_vel, cmd=cmd, rotor_vel=rotor_vel + ) + ``` Returns: The parametrized dynamics function with all keyword argument only parameters filled in. diff --git a/docs/examples/index.md b/docs/examples/index.md index f4baaf7..ec3e825 100644 --- a/docs/examples/index.md +++ b/docs/examples/index.md @@ -8,6 +8,7 @@ These runnable examples cover control, JAX transformations, pipeline extensions, A single drone commanded to hold a fixed height using state control. This is the minimal end-to-end loop: create a `Sim`, reset it, apply a state command, and step forward. + ```{ .python notest } --8<-- "examples/control/hover.py" ``` @@ -22,6 +23,7 @@ python examples/control/hover.py Commanding roll, pitch, yaw, and collective thrust directly. This level bypasses the Mellinger position loop and is typical for RL agents that output attitude targets. + ```{ .python notest } --8<-- "examples/control/attitude.py" ``` @@ -42,6 +44,7 @@ python examples/control/sampling.py Because the simulator is built entirely from JAX operations, `jax.grad` can differentiate through it. Starting the drone above the target height keeps it away from the floor, so the floor-clipping stage never fires and gradients flow freely through the entire trajectory. + ```{ .python notest } --8<-- "examples/jax/gradient.py" ``` @@ -52,6 +55,7 @@ Because the simulator is built entirely from JAX operations, `jax.grad` can diff Randomizing mass and inertia through the reset pipeline. An optional mask limits randomization to selected worlds. + ```{ .python notest } --8<-- "examples/plugins/randomize.py" ``` @@ -66,6 +70,7 @@ python examples/plugins/randomize.py Inserting a random external force and torque into the step pipeline. The disturbance fires on every dynamics tick, so the drone fights wind-like perturbations. + ```{ .python notest } --8<-- "examples/plugins/disturbance.py" ``` @@ -80,6 +85,7 @@ Offscreen rendering returns RGB and depth images on every frame. The FPV camera RGB and depth camera outputs from a Crazyflow drone simulation + ```{ .python notest } --8<-- "examples/rendering/cameras.py" ``` @@ -98,6 +104,7 @@ python examples/rendering/cameras.py Crazyflow drones with runtime-controlled LED deck materials + ```{ .python notest } --8<-- "examples/rendering/led_deck.py" ``` @@ -121,6 +128,7 @@ The default collision geometry is a sphere around the drone frame. `use_box_coll + ```{ .python notest } --8<-- "examples/contacts/contacts.py" ``` @@ -131,6 +139,7 @@ The default collision geometry is a sphere around the drone frame. `use_box_coll `render_depth` fires rays from a camera and returns per-pixel distances. This is faster than full RGB rendering and useful for obstacle sensing or depth-based controllers. + ```{ .python notest } --8<-- "examples/rendering/raycasting.py" ``` @@ -145,6 +154,7 @@ python examples/rendering/raycasting.py Evaluating a random policy in the figure-8 environment. The env wraps `Sim` behind the standard Gymnasium `VectorEnv` interface. + ```{ .python notest } --8<-- "examples/environments/figure8.py" ``` diff --git a/docs/user-guide/control/parametrize.md b/docs/user-guide/control/parametrize.md index d9a1220..4742a7a 100644 --- a/docs/user-guide/control/parametrize.md +++ b/docs/user-guide/control/parametrize.md @@ -91,6 +91,7 @@ float(params["mass"]) # 0.029 By default parameters are stored as NumPy arrays. Pass `xp` to convert them upfront, which avoids per-call conversion overhead in frameworks like PyTorch or JAX: + ```{ .python notest } import torch from crazyflow.control import parametrize diff --git a/docs/user-guide/dynamics/parametrize.md b/docs/user-guide/dynamics/parametrize.md index 519769c..59e549f 100644 --- a/docs/user-guide/dynamics/parametrize.md +++ b/docs/user-guide/dynamics/parametrize.md @@ -37,6 +37,7 @@ If your drone is not listed, you can identify the parameters from flight data us By default, `parametrize` stores parameters as NumPy arrays. For frameworks that would otherwise need to convert those arrays on every call — such as PyTorch, where NumPy arrays must become tensors — passing `xp` converts the parameters upfront. The backend of the outputs is always inferred from whatever arrays you pass in at call time. + ```{ .python notest } import torch import jax.numpy as jnp @@ -49,6 +50,7 @@ dynamics_jax = parametrize(dynamics, drone="cf2x_L250", xp=jnp) You can also specify a compute device — for example, to move JAX parameters to GPU at construction time: + ```{ .python notest } import jax import jax.numpy as jnp diff --git a/docs/user-guide/dynamics/system-identification.md b/docs/user-guide/dynamics/system-identification.md index 7818ecd..4a46bff 100644 --- a/docs/user-guide/dynamics/system-identification.md +++ b/docs/user-guide/dynamics/system-identification.md @@ -20,6 +20,7 @@ After `preprocessing` + [`derivatives_svf`][crazyflow.dynamics.utils.data_utils. ## Full pipeline + ```{ .python notest } from crazyflow.dynamics.utils.data_utils import preprocessing, derivatives_svf from crazyflow.dynamics.utils.identification import sys_id_translation, sys_id_rotation @@ -61,6 +62,7 @@ See [`sys_id_translation`][crazyflow.dynamics.utils.identification.sys_id_transl To check that the identified parameters generalise to unseen flight regimes, collect a second dataset of different trajectories and pass it as `data_validation`. RMSE and R² are then reported on both the training data and the validation data. + ```{ .python notest } # Preprocess the validation dataset independently — it must come from # different trajectories, not a split of the same recording. @@ -97,6 +99,7 @@ cmd_rpy_coef = [196.18, 196.18, 390.27] # from rot_params["cmd_rpy_coef" Once the entry is in the TOML file, load the dynamics as usual: + ```{ .python notest } from crazyflow.dynamics import parametrize from crazyflow.dynamics.so_rpy_rotor_drag import dynamics diff --git a/docs/user-guide/functional-api.md b/docs/user-guide/functional-api.md index dc43dd1..9fc6fb0 100644 --- a/docs/user-guide/functional-api.md +++ b/docs/user-guide/functional-api.md @@ -4,30 +4,6 @@ The object-oriented API is convenient for scripting, but it relies on Python-lev The functional API addresses this by expressing the same operations as pure functions that take `SimData` and return updated `SimData`. There is no hidden state, so JAX can trace, compile, and differentiate through arbitrary compositions of these functions. -## What does not work inside JAX transformations - -The object-oriented `Sim` methods mutate `sim.data` in place through Python calls. JAX cannot trace through Python-level state mutations, so these methods cannot be used inside `jax.jit`, `jax.grad`, or `jax.lax.scan`: - -```{ .python notest } -import jax -import jax.numpy as jnp -from crazyflow.sim import Sim -from crazyflow.control import Control - -sim = Sim(control=Control.attitude) -sim.reset() - -@jax.jit -def broken(cmd): - sim.attitude_control(cmd) # mutates sim.data — JAX traces the ops but leaks the tracer - sim.step(1) - return sim.data.states.pos # sim.data now holds a leaked tracer; accessing it outside JIT raises UnexpectedTracerError -``` - -## What does work - -The purely functional counterpart passes `SimData` explicitly and returns updated `SimData`. Every operation is a plain JAX function with no Python-level mutation, so the full simulation pipeline is traceable by any JAX transformation: - ```python import jax import jax.numpy as jnp diff --git a/docs/user-guide/mujoco.md b/docs/user-guide/mujoco.md index 34b0be1..5868437 100644 --- a/docs/user-guide/mujoco.md +++ b/docs/user-guide/mujoco.md @@ -75,6 +75,7 @@ sim.reset() Loading from a file works identically: + ```{ .python notest } import mujoco gate_spec = mujoco.MjSpec.from_file("assets/gate.xml") @@ -100,6 +101,7 @@ After `sim.step()` or `sim.reset()`, `mjx_synced` is set to `False`. The `sim.re These run only once per render or contact call, regardless of how many dynamics steps were taken since the last sync. + ```{ .python notest } for i in range(10): sim.step(5) # JAX dynamics only, mjx_synced = False @@ -113,6 +115,7 @@ for i in range(10): This means the order of calls matters. Grouping all rendering and contact queries together after a step lets them share a single sync: + ```{ .python notest } sim.step(5) contacts = sim.contacts() # sync runs here @@ -121,6 +124,7 @@ sim.render(mode="rgb_array") # flag already set, no second sync Interleaving a step between them forces two syncs: + ```{ .python notest } contacts = sim.contacts() # sync runs here sim.step(5) # flag cleared @@ -135,12 +139,16 @@ The solution is to **close over** `mjx_data` rather than pass it as an argument. The drone racing environment in [lsy_drone_racing](https://github.com/learnsyslab/lsy_drone_racing) uses this pattern to build a contact check function: -```{ .python notest } +```python from jax import Array +from crazyflow.sim import Sim from crazyflow.sim.sim import sync_sim2mjx from crazyflow.sim.data import SimData +sim = Sim(n_worlds=1, n_drones=1) +sim.reset() + _mjx_data = sim.mjx_data # captured in closure def check_contacts(sim_data: SimData, obstacle_mocap_pos: Array) -> Array: @@ -148,6 +156,8 @@ def check_contacts(sim_data: SimData, obstacle_mocap_pos: Array) -> Array: mjx_data = _mjx_data.replace(mocap_pos=obstacle_mocap_pos) _, mjx_data = sync_sim2mjx(sim_data, mjx_data, sim.mjx_model) return mjx_data._impl.contact.dist < 0 + +in_contact = check_contacts(sim.data, sim.mjx_data.mocap_pos) ``` `_mjx_data` is fused into the closure and compiled as a constant. Only `sim_data` and the obstacle positions cross the JIT boundary at runtime — a much smaller pytree than passing the full `mjx_data`. diff --git a/docs/user-guide/oo-api.md b/docs/user-guide/oo-api.md index 207dbe3..d4e2af7 100644 --- a/docs/user-guide/oo-api.md +++ b/docs/user-guide/oo-api.md @@ -186,6 +186,7 @@ pos_w0_d1 = sim.data.states.pos[0, 1] # (3,) `sim.render()` opens an interactive MuJoCo viewer or returns an image array for offscreen rendering. + ```{ .python notest } sim.render() # interactive window, world 0 sim.render(mode="rgb_array") # returns (H, W, 3) uint8 diff --git a/docs/user-guide/visualization.md b/docs/user-guide/visualization.md index 9daec03..17fdb6c 100644 --- a/docs/user-guide/visualization.md +++ b/docs/user-guide/visualization.md @@ -21,6 +21,7 @@ Crazyflow supports onscreen interactive rendering and offscreen RGB/depth captur | `"depth_array"` | `(H, W) float32` | Offscreen depth frame in metres | | `"rgbd_tuple"` | `(rgb, depth)` | Both channels as a tuple | + ```{ .python notest } sim.render() # interactive window rgb = sim.render(mode="rgb_array") # numpy array (H, W, 3) @@ -33,6 +34,7 @@ sim.close() # close the viewer Pass a camera name or integer ID to select which camera to render from. The default (`camera=-1`) uses the free camera. Each drone ships with a first-person view camera named `fpv_cam:`: + ```{ .python notest } sim.render(camera="fpv_cam:0") # first-person view from drone 0 sim.render(camera=0) # camera by integer ID @@ -42,6 +44,7 @@ sim.render(camera=0) # camera by integer ID For obstacle sensing or perception-based controllers, `render_depth` fires a ray from each camera pixel and returns per-pixel distances — faster than full RGB rendering because it skips lighting and colour computation: + ```{ .python notest } import jax.numpy as jnp from crazyflow.sim.sensors import build_render_depth_fn, render_depth @@ -64,6 +67,7 @@ dist = render_fn(sim) `change_material` updates the RGBA colour and emission intensity of any named material on any subset of drones without rebuilding the model: + ```{ .python notest } import numpy as np from crazyflow.sim.visualize import change_material @@ -77,6 +81,7 @@ sim.render() `sim.render()` always renders a single world at a time. Pass `world=` to choose which one: + ```{ .python notest } sim.render(world=0) # default sim.render(world=3) # render world 3