Skip to content

Enable mjw on MacOS#2948

Open
hartikainen wants to merge 5 commits intogoogle-deepmind:mainfrom
hartikainen:mjw-on-macos
Open

Enable mjw on MacOS#2948
hartikainen wants to merge 5 commits intogoogle-deepmind:mainfrom
hartikainen:mjw-on-macos

Conversation

@hartikainen
Copy link
Contributor

@hartikainen hartikainen commented Nov 16, 2025

See #2947.

With the existing changes and warp-lang>=1.11.0, we can now enable warp backend for mujoco on MacOS. More generally, mujoco warp now also works on CPU and not just cuda.

@hartikainen hartikainen force-pushed the mjw-on-macos branch 3 times, most recently from 6e718c6 to 2da826b Compare February 1, 2026 02:24
@hartikainen hartikainen marked this pull request as ready for review February 1, 2026 02:27
Copy link
Contributor Author

@hartikainen hartikainen Feb 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably go into the mujoco warp repo. It could also use a bit more thorough review as I'm not super fluent with the FFI semantics.

@hartikainen
Copy link
Contributor Author

hartikainen commented Feb 26, 2026

This should be ready to be reviewed. I see some tests failing locally like this:

Details
$ pytest
============================= test session starts =============================
platform darwin -- Python 3.12.10, pytest-9.0.2, pluggy-1.6.0
rootdir: /mujoco/main/mjx
configfile: pyproject.toml
collected 718 items

mujoco/mjx/_src/collision_driver_test.py .............................. [  4%]
.......                                                                 [  5%]
mujoco/mjx/_src/constraint_test.py ..........                           [  6%]
mujoco/mjx/_src/dataclasses_test.py .                                   [  6%]
mujoco/mjx/_src/forward_test.py .........                               [  7%]
mujoco/mjx/_src/inverse_test.py .......                                 [  8%]
mujoco/mjx/_src/io_test.py ............................................ [ 15%]
........................................................                [ 22%]
mujoco/mjx/_src/math_test.py .......................................... [ 28%]
....................................................................... [ 38%]
........................                                                [ 41%]
mujoco/mjx/_src/mesh_test.py ...                                        [ 42%]
mujoco/mjx/_src/passive_test.py .                                       [ 42%]
mujoco/mjx/_src/ray_test.py ...........                                 [ 44%]
mujoco/mjx/_src/scan_test.py ....                                       [ 44%]
mujoco/mjx/_src/sensor_test.py ..........                               [ 45%]
mujoco/mjx/_src/smooth_test.py .............................            [ 50%]
mujoco/mjx/_src/solver_test.py .............                            [ 51%]
mujoco/mjx/_src/support_test.py ..............                          [ 53%]
mujoco/mjx/integration_test/collision_driver_test.py .................. [ 56%]
....................................................................... [ 66%]
....................................................................... [ 76%]
....................................................................... [ 85%]
.........................                                               [ 89%]
mujoco/mjx/integration_test/forward_test.py ........................... [ 93%]
...                                                                     [ 93%]
mujoco/mjx/integration_test/smooth_test.py ............................ [ 97%]
..                                                                      [ 97%]
mujoco/mjx/warp/collision_driver_test.py .                              [ 97%]
mujoco/mjx/warp/forward_test.py .F..FF.....                             [ 99%]
mujoco/mjx/warp/smooth_test.py ....                                     [100%]

================================== FAILURES ===================================
__________________________ ForwardTest.test_forward1 __________________________

self = <mjx.warp.forward_test.ForwardTest testMethod=test_forward1>
xml = 'humanoid/humanoid.xml', batch_size = 7

    @parameterized.product(
        xml=(
            'humanoid/humanoid.xml',
            'pendula.xml',
        ),
        batch_size=(1, 7),
    )
    def test_forward(self, xml: str, batch_size: int):
      if not _FORCE_TEST:
        if not mjxw.WARP_INSTALLED:
          self.skipTest('Warp not installed.')

      m = test_util.load_test_file(xml)
      m.opt.iterations = 10
      m.opt.ls_iterations = 10
      m.opt.jacobian = mujoco.mjtJacobian.mjJAC_DENSE
      mx = mjx.put_model(m, impl='warp')

      d = mujoco.MjData(m)
      worldids = jp.arange(batch_size)
      dx_batch = jax.vmap(functools.partial(tu.make_data, m))(worldids)

      dx_batch = jax.jit(jax.vmap(forward.forward, in_axes=(None, 0)))(
          mx, dx_batch
      )

      for i in range(batch_size):
        dx = dx_batch[i]

        d.qpos[:] = dx.qpos
        d.qvel[:] = dx.qvel
        d.ctrl[:] = dx.ctrl
        d.mocap_pos[:] = dx.mocap_pos
        d.mocap_quat[:] = dx.mocap_quat
        mujoco.mj_forward(m, d)

        # fwd_position
        tu.assert_attr_eq(dx, d, 'xpos')
        tu.assert_attr_eq(dx, d, 'xquat')
        tu.assert_attr_eq(dx, d, 'xipos')
        tu.assert_eq(d.ximat.reshape((-1, 3, 3)), dx.ximat, 'ximat')
        tu.assert_attr_eq(dx, d, 'xanchor')
        tu.assert_attr_eq(dx, d, 'xaxis')
        tu.assert_attr_eq(dx, d, 'geom_xpos')
        tu.assert_eq(dx.geom_xmat, d.geom_xmat.reshape((-1, 3, 3)), 'geom_xmat')
        if m.nsite:
          tu.assert_attr_eq(dx, d, 'site_xpos')
          tu.assert_eq(dx.site_xmat, d.site_xmat.reshape((-1, 3, 3)), 'site_xmat')
        tu.assert_attr_eq(dx, d, 'cdof')
        tu.assert_attr_eq(dx._impl, d, 'cinert')
        tu.assert_attr_eq(dx, d, 'subtree_com')
        if m.nlight:
          tu.assert_attr_eq(dx._impl, d, 'light_xpos')
          tu.assert_attr_eq(dx._impl, d, 'light_xdir')
        if m.ncam:
          tu.assert_attr_eq(dx, d, 'cam_xpos')
          tu.assert_eq(dx.cam_xmat, d.cam_xmat.reshape((-1, 3, 3)), 'cam_xmat')
        tu.assert_attr_eq(dx, d, 'ten_length')
        tu.assert_attr_eq(dx._impl, d, 'ten_J')
        tu.assert_attr_eq(dx._impl, d, 'ten_wrapadr')
        tu.assert_attr_eq(dx._impl, d, 'ten_wrapnum')
        tu.assert_attr_eq(dx._impl, d, 'wrap_xpos')
        tu.assert_attr_eq(dx._impl, d, 'wrap_obj')
        tu.assert_attr_eq(dx._impl, d, 'crb')

        qm = np.zeros((m.nv, m.nv))
        mujoco.mj_fullM(m, qm, d.qM)
        # mjwarp adds padding to qM
        tu.assert_eq(qm, dx._impl.qM[: m.nv, : m.nv], 'qM')
        # qLD is fused in a cholesky factorize and solve, and not written to.

        tu.assert_contact_eq(d, dx, worldid=i)

        tu.assert_attr_eq(dx, d, 'actuator_length')
        actuator_moment = np.zeros((m.nu, m.nv))
        mujoco.mju_sparse2dense(
            actuator_moment,
            d.actuator_moment,
            d.moment_rownnz,
            d.moment_rowadr,
            d.moment_colind,
        )
        tu.assert_eq(dx._impl.actuator_moment, actuator_moment, 'actuator_moment')

        # fwd_velocity
        tu.assert_attr_eq(dx._impl, d, 'actuator_velocity')
        tu.assert_attr_eq(dx, d, 'cvel')
        tu.assert_attr_eq(dx, d, 'cdof_dot')
        tu.assert_attr_eq(dx._impl, d, 'qfrc_spring')
        tu.assert_attr_eq(dx._impl, d, 'qfrc_damper')
        tu.assert_attr_eq(dx, d, 'qfrc_gravcomp')
        tu.assert_attr_eq(dx, d, 'qfrc_fluid')
        tu.assert_attr_eq(dx, d, 'qfrc_passive')
        tu.assert_attr_eq(dx, d, 'qfrc_bias')
>       tu.assert_efc_eq(d, dx, worldid=i)

mujoco/mjx/warp/forward_test.py:179:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
mujoco/mjx/warp/test_util.py:199: in assert_efc_eq
    assert_eq(jp_, j, 'J')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

a = Array([[-0.9875637 ,  0.15721938,  1.        , -0.00917677,  1.335037  ,
        -0.10656909, -0.10577767,  1.1439778 ...    -0.03383656,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ]], dtype=float32)
b = array([[-0.9875637 ,  0.15721938,  1.        , -0.00917675,  1.33503689,
        -0.10656907, -0.10577765,  1.14397789....07388643,
        -0.0338367 ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ]])
name = 'J'

    def assert_eq(a, b, name):
      tol = _TOLERANCE * 10  # avoid test noise
      err_msg = f'mismatch: {name}'
>     np.testing.assert_allclose(a, b, err_msg=err_msg, atol=tol, rtol=tol)
E     AssertionError:
E     Not equal to tolerance rtol=0.0005, atol=0.0005
E     mismatch: J
E     Mismatched elements: 160 / 432 (37%)
E     First 5 mismatches are at indices:
E      [4, 0]: -0.9871430993080139 (ACTUAL), -0.999479684061378 (DESIRED)
E      [4, 1]: 0.15983904898166656 (ACTUAL), -0.03225462987801541 (DESIRED)
E      [4, 3]: 0.16618438065052032 (ACTUAL), -0.22991350005486522 (DESIRED)
E      [4, 4]: 1.2810968160629272 (ACTUAL), 1.3359474985903166 (DESIRED)
E      [4, 5]: 0.044975683093070984 (ACTUAL), -0.0601160972922109 (DESIRED)
E     Max absolute difference among violations: 0.87907806
E     Max relative difference among violations: 12.74774404
E      ACTUAL: array([[-0.987564,  0.157219,  1.      , -0.009177,  1.335037, -0.106569,
E             -0.105778,  1.143978, -0.053232,  0.039397, -0.004723,  0.880073,
E             -0.497201,  0.173519, -0.011014,  0.      ,  0.      ,  0.      ,...
E      DESIRED: array([[-0.987564,  0.157219,  1.      , -0.009177,  1.335037, -0.106569,
E             -0.105778,  1.143978, -0.053232,  0.039397, -0.004723,  0.880073,
E             -0.497201,  0.173519, -0.011014,  0.      ,  0.      ,  0.      ,...

mujoco/mjx/warp/test_util.py:36: AssertionError
________________________ ForwardTest.test_jit_caching0 ________________________

self = <mjx.warp.forward_test.ForwardTest testMethod=test_jit_caching0>
xml = 'pendula.xml'

    @parameterized.parameters(
        'pendula.xml',
        'humanoid/humanoid.xml',
    )
    def test_jit_caching(self, xml):
      """Tests jit caching on the full step function."""
      if not _FORCE_TEST:
        if not mjxw.WARP_INSTALLED:
          self.skipTest('Warp not installed.')

      batch_size = 7
      m = test_util.load_test_file(xml)
      mx = mjx.put_model(m, impl='warp')

      keys = jp.arange(batch_size)
      dx_batch = jax.vmap(functools.partial(tu.make_data, m))(keys)

      step_fn = jax.jit(jax.vmap(forward.step, in_axes=(None, 0)))
      dx_batch1 = step_fn(mx, dx_batch)
      jax.tree_util.tree_map(lambda x: x.block_until_ready(), dx_batch1)
>     self.assertEqual(step_fn._cache_size(), 1)
E     AssertionError: 0 != 1

mujoco/mjx/warp/forward_test.py:76: AssertionError
---------------------------- Captured stdout call -----------------------------
Module mul_m_sparse_diag__locals___mul_m_sparse_diag_1de23634 4949128 load on device 'cpu' took 195.10 ms  (compiled)
Module mul_m_sparse_ij__locals___mul_m_sparse_ij_b1bb5fb3 8b21b7f load on device 'cpu' took 204.99 ms  (compiled)
Module update_gradient_JTDAJ_sparse_tiled__locals__kernel_8f59ead1 6963331 load on device 'cpu' took 269.85 ms  (compiled)
________________________ ForwardTest.test_jit_caching1 ________________________

self = <mjx.warp.forward_test.ForwardTest testMethod=test_jit_caching1>
xml = 'humanoid/humanoid.xml'

    @parameterized.parameters(
        'pendula.xml',
        'humanoid/humanoid.xml',
    )
    def test_jit_caching(self, xml):
      """Tests jit caching on the full step function."""
      if not _FORCE_TEST:
        if not mjxw.WARP_INSTALLED:
          self.skipTest('Warp not installed.')

      batch_size = 7
      m = test_util.load_test_file(xml)
      mx = mjx.put_model(m, impl='warp')

      keys = jp.arange(batch_size)
      dx_batch = jax.vmap(functools.partial(tu.make_data, m))(keys)

      step_fn = jax.jit(jax.vmap(forward.step, in_axes=(None, 0)))
      dx_batch1 = step_fn(mx, dx_batch)
      jax.tree_util.tree_map(lambda x: x.block_until_ready(), dx_batch1)
>     self.assertEqual(step_fn._cache_size(), 1)
E     AssertionError: 0 != 1

mujoco/mjx/warp/forward_test.py:76: AssertionError
============================== warnings summary ===============================
mujoco/mjx/_src/collision_driver_test.py::ConvexTest::test_box_box
mujoco/mjx/_src/collision_driver_test.py::ConvexTest::test_box_box_edge
mujoco/mjx/_src/collision_driver_test.py::ConvexTest::test_convex_convex
mujoco/mjx/_src/collision_driver_test.py::ConvexTest::test_convex_convex_edge
mujoco/mjx/_src/collision_driver_test.py::HFieldTest::test_hfield_deep
mujoco/mjx/_src/support_test.py::SupportTest::test_bind
mujoco/mjx/integration_test/smooth_test.py::TransmissionIntegrationTest::test_transmission14
  /mujoco/main/.venv/lib/python3.12/site-packages/jax/_src/abstract_arrays.py:135: RuntimeWarning: overflow encountered in cast
    return literals.TypedNdArray(np.asarray(x, dtype), weak_type=False)

mujoco/mjx/_src/forward_test.py::ActuatorTest::test_actuator2
mujoco/mjx/_src/io_test.py::DataIOTest::test_qm_mapm2m0
mujoco/mjx/_src/passive_test.py::PassiveTest::test_passive
mujoco/mjx/integration_test/collision_driver_test.py::CollisionDriverIntegrationTest::test_collision_driver0
mujoco/mjx/integration_test/collision_driver_test.py::CollisionDriverIntegrationTest::test_collision_driver115
mujoco/mjx/integration_test/collision_driver_test.py::CollisionDriverIntegrationTest::test_collision_driver125
  /mujoco/main/.venv/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py:2412: DeprecationWarning: Passing arguments 'a', 'a_min' or 'a_max' to jax.numpy.clip is deprecated. Please use 'arr', 'min' or 'max' respectively instead.
    ans_pytree = fun(*args, **kwargs)

mujoco/mjx/_src/io_test.py::DataIOTest::test_make_data_warp
  /mujoco/main/mjx/mujoco/mjx/_src/io_test.py:468: DeprecationWarning: nconmax will be deprecated in mujoco-mjx>=3.5. Use naconmax instead.
    d = mjx.make_data(m, impl='warp', nconmax=9, njmax=23)

mujoco/mjx/integration_test/collision_driver_test.py: 490 warnings
  /mujoco/main/mjx/mujoco/mjx/integration_test/collision_driver_test.py:80: DeprecationWarning: Accessing `contact` directly from `Data` is deprecated. Access it via `data._impl.contact` instead.
    idx_mjx = list(zip(dx.contact.geom1, dx.contact.geom2))

mujoco/mjx/integration_test/collision_driver_test.py: 245 warnings
  /mujoco/main/mjx/mujoco/mjx/integration_test/collision_driver_test.py:86: DeprecationWarning: Accessing `contact` directly from `Data` is deprecated. Access it via `data._impl.contact` instead.
    lambda x: x.take(np.array(idx), axis=0), dx.contact

mujoco/mjx/integration_test/collision_driver_test.py: 11 warnings
  /mujoco/main/mjx/mujoco/mjx/integration_test/collision_driver_test.py:75: DeprecationWarning: Accessing `contact` directly from `Data` is deprecated. Access it via `data._impl.contact` instead.
    self.assertTrue((dx.contact.dist > 0).all())

mujoco/mjx/integration_test/smooth_test.py: 30 warnings
  /mujoco/main/mjx/mujoco/mjx/integration_test/smooth_test.py:86: DeprecationWarning: Accessing `actuator_moment` directly from `Data` is deprecated. Access it via `data._impl.actuator_moment` instead.
    dx.actuator_moment,

mujoco/mjx/warp/forward_test.py: 11 warnings
mujoco/mjx/warp/smooth_test.py: 3 warnings
  /mujoco/main/mjx/mujoco/mjx/warp/test_util.py:47: DeprecationWarning: nconmax will be deprecated in mujoco-mjx>=3.5. Use naconmax instead.
    dx = mjx.make_data(m, impl='warp', nconmax=nconmax, njmax=njmax)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================== short test summary info ===========================
FAILED mujoco/mjx/warp/forward_test.py::ForwardTest::test_forward1 - AssertionError:
FAILED mujoco/mjx/warp/forward_test.py::ForwardTest::test_jit_caching0 - AssertionError: 0 != 1
FAILED mujoco/mjx/warp/forward_test.py::ForwardTest::test_jit_caching1 - AssertionError: 0 != 1
=========== 3 failed, 715 passed, 804 warnings in 357.33s (0:05:57) ===========
WARNING: Nan, Inf or huge value in QACC at DOF 0. The simulation is unstable. Time = 0.0000.

WARNING: Nan, Inf or huge value in QACC at DOF 9. The simulation is unstable. Time = 0.0680.

Weirdly enough, the two cache tests don't fail if I run the forward_test.py directly:

Details
pytest ./mjx/mujoco/mjx/warp/forward_test.py
================================ test session starts ================================
platform darwin -- Python 3.12.10, pytest-9.0.2, pluggy-1.6.0
rootdir: /mujoco/mjx
configfile: pyproject.toml
collected 11 items

mjx/mujoco/mjx/warp/forward_test.py .F.........                               [100%]

===================================== FAILURES ======================================
_____________________________ ForwardTest.test_forward1 _____________________________

self = <mjx.warp.forward_test.ForwardTest testMethod=test_forward1>
xml = 'humanoid/humanoid.xml', batch_size = 7

    @parameterized.product(
        xml=(
            'humanoid/humanoid.xml',
            'pendula.xml',
        ),
        batch_size=(1, 7),
    )
    def test_forward(self, xml: str, batch_size: int):
      if not _FORCE_TEST:
        if not mjxw.WARP_INSTALLED:
          self.skipTest('Warp not installed.')

      m = test_util.load_test_file(xml)
      m.opt.iterations = 10
      m.opt.ls_iterations = 10
      m.opt.jacobian = mujoco.mjtJacobian.mjJAC_DENSE
      mx = mjx.put_model(m, impl='warp')

      d = mujoco.MjData(m)
      worldids = jp.arange(batch_size)
      dx_batch = jax.vmap(functools.partial(tu.make_data, m))(worldids)

      dx_batch = jax.jit(jax.vmap(forward.forward, in_axes=(None, 0)))(
          mx, dx_batch
      )

      for i in range(batch_size):
        dx = dx_batch[i]

        d.qpos[:] = dx.qpos
        d.qvel[:] = dx.qvel
        d.ctrl[:] = dx.ctrl
        d.mocap_pos[:] = dx.mocap_pos
        d.mocap_quat[:] = dx.mocap_quat
        mujoco.mj_forward(m, d)

        # fwd_position
        tu.assert_attr_eq(dx, d, 'xpos')
        tu.assert_attr_eq(dx, d, 'xquat')
        tu.assert_attr_eq(dx, d, 'xipos')
        tu.assert_eq(d.ximat.reshape((-1, 3, 3)), dx.ximat, 'ximat')
        tu.assert_attr_eq(dx, d, 'xanchor')
        tu.assert_attr_eq(dx, d, 'xaxis')
        tu.assert_attr_eq(dx, d, 'geom_xpos')
        tu.assert_eq(dx.geom_xmat, d.geom_xmat.reshape((-1, 3, 3)), 'geom_xmat')
        if m.nsite:
          tu.assert_attr_eq(dx, d, 'site_xpos')
          tu.assert_eq(dx.site_xmat, d.site_xmat.reshape((-1, 3, 3)), 'site_xmat')
        tu.assert_attr_eq(dx, d, 'cdof')
        tu.assert_attr_eq(dx._impl, d, 'cinert')
        tu.assert_attr_eq(dx, d, 'subtree_com')
        if m.nlight:
          tu.assert_attr_eq(dx._impl, d, 'light_xpos')
          tu.assert_attr_eq(dx._impl, d, 'light_xdir')
        if m.ncam:
          tu.assert_attr_eq(dx, d, 'cam_xpos')
          tu.assert_eq(dx.cam_xmat, d.cam_xmat.reshape((-1, 3, 3)), 'cam_xmat')
        tu.assert_attr_eq(dx, d, 'ten_length')
        tu.assert_attr_eq(dx._impl, d, 'ten_J')
        tu.assert_attr_eq(dx._impl, d, 'ten_wrapadr')
        tu.assert_attr_eq(dx._impl, d, 'ten_wrapnum')
        tu.assert_attr_eq(dx._impl, d, 'wrap_xpos')
        tu.assert_attr_eq(dx._impl, d, 'wrap_obj')
        tu.assert_attr_eq(dx._impl, d, 'crb')

        qm = np.zeros((m.nv, m.nv))
        mujoco.mj_fullM(m, qm, d.qM)
        # mjwarp adds padding to qM
        tu.assert_eq(qm, dx._impl.qM[: m.nv, : m.nv], 'qM')
        # qLD is fused in a cholesky factorize and solve, and not written to.

        tu.assert_contact_eq(d, dx, worldid=i)

        tu.assert_attr_eq(dx, d, 'actuator_length')
        actuator_moment = np.zeros((m.nu, m.nv))
        mujoco.mju_sparse2dense(
            actuator_moment,
            d.actuator_moment,
            d.moment_rownnz,
            d.moment_rowadr,
            d.moment_colind,
        )
        tu.assert_eq(dx._impl.actuator_moment, actuator_moment, 'actuator_moment')

        # fwd_velocity
        tu.assert_attr_eq(dx._impl, d, 'actuator_velocity')
        tu.assert_attr_eq(dx, d, 'cvel')
        tu.assert_attr_eq(dx, d, 'cdof_dot')
        tu.assert_attr_eq(dx._impl, d, 'qfrc_spring')
        tu.assert_attr_eq(dx._impl, d, 'qfrc_damper')
        tu.assert_attr_eq(dx, d, 'qfrc_gravcomp')
        tu.assert_attr_eq(dx, d, 'qfrc_fluid')
        tu.assert_attr_eq(dx, d, 'qfrc_passive')
        tu.assert_attr_eq(dx, d, 'qfrc_bias')
        # NOTE(user): This fails due to some weird sorting of keys.
>       tu.assert_efc_eq(d, dx, worldid=i)

mjx/mujoco/mjx/warp/forward_test.py:179:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
mjx/mujoco/mjx/warp/test_util.py:199: in assert_efc_eq
    assert_eq(jp_, j, 'J')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

a = Array([[-0.9875637 ,  0.15721938,  1.        , -0.00917677,  1.335037  ,
        -0.10656909, -0.10577767,  1.1439778 ...    -0.03383656,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ]], dtype=float32)
b = array([[-0.9875637 ,  0.15721938,  1.        , -0.00917675,  1.33503689,
        -0.10656907, -0.10577765,  1.14397789....07388643,
        -0.0338367 ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ]])
name = 'J'

    def assert_eq(a, b, name):
      tol = _TOLERANCE * 10  # avoid test noise
      err_msg = f'mismatch: {name}'
>     np.testing.assert_allclose(a, b, err_msg=err_msg, atol=tol, rtol=tol)
E     AssertionError:
E     Not equal to tolerance rtol=0.0005, atol=0.0005
E     mismatch: J
E     Mismatched elements: 160 / 432 (37%)
E     First 5 mismatches are at indices:
E      [4, 0]: -0.9871430993080139 (ACTUAL), -0.999479684061378 (DESIRED)
E      [4, 1]: 0.15983904898166656 (ACTUAL), -0.03225462987801541 (DESIRED)
E      [4, 3]: 0.16618438065052032 (ACTUAL), -0.22991350005486522 (DESIRED)
E      [4, 4]: 1.2810968160629272 (ACTUAL), 1.3359474985903166 (DESIRED)
E      [4, 5]: 0.044975683093070984 (ACTUAL), -0.0601160972922109 (DESIRED)
E     Max absolute difference among violations: 0.87907806
E     Max relative difference among violations: 12.74774404
E      ACTUAL: array([[-0.987564,  0.157219,  1.      , -0.009177,  1.335037, -0.106569,
E             -0.105778,  1.143978, -0.053232,  0.039397, -0.004723,  0.880073,
E             -0.497201,  0.173519, -0.011014,  0.      ,  0.      ,  0.      ,...
E      DESIRED: array([[-0.987564,  0.157219,  1.      , -0.009177,  1.335037, -0.106569,
E             -0.105778,  1.143978, -0.053232,  0.039397, -0.004723,  0.880073,
E             -0.497201,  0.173519, -0.011014,  0.      ,  0.      ,  0.      ,...

mjx/mujoco/mjx/warp/test_util.py:36: AssertionError
================================= warnings summary ==================================
mujoco/mjx/warp/forward_test.py: 13 warnings
  /mujoco/mjx/mujoco/mjx/warp/test_util.py:47: DeprecationWarning: nconmax will be deprecated in mujoco-mjx>=3.5. Use naconmax instead.
    dx = mjx.make_data(m, impl='warp', nconmax=nconmax, njmax=njmax)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================== short test summary info ==============================
FAILED mjx/mujoco/mjx/warp/forward_test.py::ForwardTest::test_forward1 - AssertionError:
==================== 1 failed, 10 passed, 13 warnings in 28.00s =====================

I tracked the numerical errors down to lexsort in test_util.py

keys_sorted = np.lexsort((-efc_pos, efc_type, efc_d))

Changing this line to the following fixes the issue:

  keys_sorted = np.lexsort((-np.round(efc_pos, 4), efc_type, np.round(efc_d, 4)))

I'm not sure if the lexsort rounding is worth including here given that the tests pass on CI. Happy to open a PR for it though.

efc_type = select(dx._impl.efc__type)[:nefc]
efc_d = select(dx._impl.efc__D)[:nefc]
keys_sorted = np.lexsort((-efc_pos, efc_type, efc_d))
keys_sorted = np.lexsort((np.round(-efc_pos, 12), efc_type, np.round(efc_d, 12)))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note these rounding changes. Let me know if you want these to be in a separate PR.

@hartikainen
Copy link
Contributor Author

hartikainen commented Mar 19, 2026

@thowell @btaba would either of you have time to take a look at this one? It would be really useful for our team to be able to visualize the warp models on macos. I'm pretty confident that the macos build error is unrelated to my changes.

@hartikainen hartikainen force-pushed the mjw-on-macos branch 2 times, most recently from 31b5501 to c38d802 Compare March 19, 2026 17:56
This is not ready yet but just demonstrates that it should be possible
to enable warp backend on MacOS.

```console
$
uv pip install -U --extra-index-url="https://py.mujoco.org" "mujoco>=3.7.0.dev0,<3.8.0" && \
uv pip install -U warp-lang && \
uv pip install -U -e ./mjx

/Users/google-deepmind/mujoco/main/mjx/mujoco/mjx/_src/io_test.py:472: DeprecationWarning: nconmax will be deprecated in mujoco-mjx>=3.5. Use naconmax instead.
  d = mjx.make_data(m, impl='warp', nconmax=9, njmax=23)
Warp 1.12.0 initialized:
   CUDA not enabled in this build
   Devices:
     "cpu"      : "arm"
   Kernel cache:
     /private/var/folders/ml/rrlg98ln26l7xvxgq_yqfn4c0000gn/T/tmpoklv_7xg
Warp DeprecationWarning: The symbol `warp.types.warp_type_to_np_dtype` will soon be removed from the public API. It can still be accessed from `warp._src.types.warp_type_to_np_dtype` but might be changed or removed without notice.
./opt/homebrew/Cellar/python@3.12/3.12.10/Frameworks/Python.framework/Versions/3.12/lib/python3.12/tempfile.py:940: ResourceWarning: Implicitly cleaning up <TemporaryDirectory '/var/folders/ml/rrlg98ln26l7xvxgq_yqfn4c0000gn/T/tmpoklv_7xg'>
  _warnings.warn(warn_message, ResourceWarning)

----------------------------------------------------------------------
Ran 1 test in 1.105s

OK
```
When using warp on CPU, there are small numerical differences that cause
the `np.lexsort` to return inconsistently-ordered `efc-pos` and `efc_D`.
This changes the sorting to round the values, fixing the flakiness.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant