Conversation
20b91ca to
9f878f1
Compare
6e718c6 to
2da826b
Compare
There was a problem hiding this comment.
This should probably go into the mujoco warp repo. It could also use a bit more thorough review as I'm not super fluent with the FFI semantics.
493b601 to
7f0f35f
Compare
|
This should be ready to be reviewed. I see some tests failing locally like this: Details$ pytest
============================= test session starts =============================
platform darwin -- Python 3.12.10, pytest-9.0.2, pluggy-1.6.0
rootdir: /mujoco/main/mjx
configfile: pyproject.toml
collected 718 items
mujoco/mjx/_src/collision_driver_test.py .............................. [ 4%]
....... [ 5%]
mujoco/mjx/_src/constraint_test.py .......... [ 6%]
mujoco/mjx/_src/dataclasses_test.py . [ 6%]
mujoco/mjx/_src/forward_test.py ......... [ 7%]
mujoco/mjx/_src/inverse_test.py ....... [ 8%]
mujoco/mjx/_src/io_test.py ............................................ [ 15%]
........................................................ [ 22%]
mujoco/mjx/_src/math_test.py .......................................... [ 28%]
....................................................................... [ 38%]
........................ [ 41%]
mujoco/mjx/_src/mesh_test.py ... [ 42%]
mujoco/mjx/_src/passive_test.py . [ 42%]
mujoco/mjx/_src/ray_test.py ........... [ 44%]
mujoco/mjx/_src/scan_test.py .... [ 44%]
mujoco/mjx/_src/sensor_test.py .......... [ 45%]
mujoco/mjx/_src/smooth_test.py ............................. [ 50%]
mujoco/mjx/_src/solver_test.py ............. [ 51%]
mujoco/mjx/_src/support_test.py .............. [ 53%]
mujoco/mjx/integration_test/collision_driver_test.py .................. [ 56%]
....................................................................... [ 66%]
....................................................................... [ 76%]
....................................................................... [ 85%]
......................... [ 89%]
mujoco/mjx/integration_test/forward_test.py ........................... [ 93%]
... [ 93%]
mujoco/mjx/integration_test/smooth_test.py ............................ [ 97%]
.. [ 97%]
mujoco/mjx/warp/collision_driver_test.py . [ 97%]
mujoco/mjx/warp/forward_test.py .F..FF..... [ 99%]
mujoco/mjx/warp/smooth_test.py .... [100%]
================================== FAILURES ===================================
__________________________ ForwardTest.test_forward1 __________________________
self = <mjx.warp.forward_test.ForwardTest testMethod=test_forward1>
xml = 'humanoid/humanoid.xml', batch_size = 7
@parameterized.product(
xml=(
'humanoid/humanoid.xml',
'pendula.xml',
),
batch_size=(1, 7),
)
def test_forward(self, xml: str, batch_size: int):
if not _FORCE_TEST:
if not mjxw.WARP_INSTALLED:
self.skipTest('Warp not installed.')
m = test_util.load_test_file(xml)
m.opt.iterations = 10
m.opt.ls_iterations = 10
m.opt.jacobian = mujoco.mjtJacobian.mjJAC_DENSE
mx = mjx.put_model(m, impl='warp')
d = mujoco.MjData(m)
worldids = jp.arange(batch_size)
dx_batch = jax.vmap(functools.partial(tu.make_data, m))(worldids)
dx_batch = jax.jit(jax.vmap(forward.forward, in_axes=(None, 0)))(
mx, dx_batch
)
for i in range(batch_size):
dx = dx_batch[i]
d.qpos[:] = dx.qpos
d.qvel[:] = dx.qvel
d.ctrl[:] = dx.ctrl
d.mocap_pos[:] = dx.mocap_pos
d.mocap_quat[:] = dx.mocap_quat
mujoco.mj_forward(m, d)
# fwd_position
tu.assert_attr_eq(dx, d, 'xpos')
tu.assert_attr_eq(dx, d, 'xquat')
tu.assert_attr_eq(dx, d, 'xipos')
tu.assert_eq(d.ximat.reshape((-1, 3, 3)), dx.ximat, 'ximat')
tu.assert_attr_eq(dx, d, 'xanchor')
tu.assert_attr_eq(dx, d, 'xaxis')
tu.assert_attr_eq(dx, d, 'geom_xpos')
tu.assert_eq(dx.geom_xmat, d.geom_xmat.reshape((-1, 3, 3)), 'geom_xmat')
if m.nsite:
tu.assert_attr_eq(dx, d, 'site_xpos')
tu.assert_eq(dx.site_xmat, d.site_xmat.reshape((-1, 3, 3)), 'site_xmat')
tu.assert_attr_eq(dx, d, 'cdof')
tu.assert_attr_eq(dx._impl, d, 'cinert')
tu.assert_attr_eq(dx, d, 'subtree_com')
if m.nlight:
tu.assert_attr_eq(dx._impl, d, 'light_xpos')
tu.assert_attr_eq(dx._impl, d, 'light_xdir')
if m.ncam:
tu.assert_attr_eq(dx, d, 'cam_xpos')
tu.assert_eq(dx.cam_xmat, d.cam_xmat.reshape((-1, 3, 3)), 'cam_xmat')
tu.assert_attr_eq(dx, d, 'ten_length')
tu.assert_attr_eq(dx._impl, d, 'ten_J')
tu.assert_attr_eq(dx._impl, d, 'ten_wrapadr')
tu.assert_attr_eq(dx._impl, d, 'ten_wrapnum')
tu.assert_attr_eq(dx._impl, d, 'wrap_xpos')
tu.assert_attr_eq(dx._impl, d, 'wrap_obj')
tu.assert_attr_eq(dx._impl, d, 'crb')
qm = np.zeros((m.nv, m.nv))
mujoco.mj_fullM(m, qm, d.qM)
# mjwarp adds padding to qM
tu.assert_eq(qm, dx._impl.qM[: m.nv, : m.nv], 'qM')
# qLD is fused in a cholesky factorize and solve, and not written to.
tu.assert_contact_eq(d, dx, worldid=i)
tu.assert_attr_eq(dx, d, 'actuator_length')
actuator_moment = np.zeros((m.nu, m.nv))
mujoco.mju_sparse2dense(
actuator_moment,
d.actuator_moment,
d.moment_rownnz,
d.moment_rowadr,
d.moment_colind,
)
tu.assert_eq(dx._impl.actuator_moment, actuator_moment, 'actuator_moment')
# fwd_velocity
tu.assert_attr_eq(dx._impl, d, 'actuator_velocity')
tu.assert_attr_eq(dx, d, 'cvel')
tu.assert_attr_eq(dx, d, 'cdof_dot')
tu.assert_attr_eq(dx._impl, d, 'qfrc_spring')
tu.assert_attr_eq(dx._impl, d, 'qfrc_damper')
tu.assert_attr_eq(dx, d, 'qfrc_gravcomp')
tu.assert_attr_eq(dx, d, 'qfrc_fluid')
tu.assert_attr_eq(dx, d, 'qfrc_passive')
tu.assert_attr_eq(dx, d, 'qfrc_bias')
> tu.assert_efc_eq(d, dx, worldid=i)
mujoco/mjx/warp/forward_test.py:179:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
mujoco/mjx/warp/test_util.py:199: in assert_efc_eq
assert_eq(jp_, j, 'J')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Array([[-0.9875637 , 0.15721938, 1. , -0.00917677, 1.335037 ,
-0.10656909, -0.10577767, 1.1439778 ... -0.03383656, 0. , 0. , 0. , 0. ,
0. , 0. ]], dtype=float32)
b = array([[-0.9875637 , 0.15721938, 1. , -0.00917675, 1.33503689,
-0.10656907, -0.10577765, 1.14397789....07388643,
-0.0338367 , 0. , 0. , 0. , 0. ,
0. , 0. ]])
name = 'J'
def assert_eq(a, b, name):
tol = _TOLERANCE * 10 # avoid test noise
err_msg = f'mismatch: {name}'
> np.testing.assert_allclose(a, b, err_msg=err_msg, atol=tol, rtol=tol)
E AssertionError:
E Not equal to tolerance rtol=0.0005, atol=0.0005
E mismatch: J
E Mismatched elements: 160 / 432 (37%)
E First 5 mismatches are at indices:
E [4, 0]: -0.9871430993080139 (ACTUAL), -0.999479684061378 (DESIRED)
E [4, 1]: 0.15983904898166656 (ACTUAL), -0.03225462987801541 (DESIRED)
E [4, 3]: 0.16618438065052032 (ACTUAL), -0.22991350005486522 (DESIRED)
E [4, 4]: 1.2810968160629272 (ACTUAL), 1.3359474985903166 (DESIRED)
E [4, 5]: 0.044975683093070984 (ACTUAL), -0.0601160972922109 (DESIRED)
E Max absolute difference among violations: 0.87907806
E Max relative difference among violations: 12.74774404
E ACTUAL: array([[-0.987564, 0.157219, 1. , -0.009177, 1.335037, -0.106569,
E -0.105778, 1.143978, -0.053232, 0.039397, -0.004723, 0.880073,
E -0.497201, 0.173519, -0.011014, 0. , 0. , 0. ,...
E DESIRED: array([[-0.987564, 0.157219, 1. , -0.009177, 1.335037, -0.106569,
E -0.105778, 1.143978, -0.053232, 0.039397, -0.004723, 0.880073,
E -0.497201, 0.173519, -0.011014, 0. , 0. , 0. ,...
mujoco/mjx/warp/test_util.py:36: AssertionError
________________________ ForwardTest.test_jit_caching0 ________________________
self = <mjx.warp.forward_test.ForwardTest testMethod=test_jit_caching0>
xml = 'pendula.xml'
@parameterized.parameters(
'pendula.xml',
'humanoid/humanoid.xml',
)
def test_jit_caching(self, xml):
"""Tests jit caching on the full step function."""
if not _FORCE_TEST:
if not mjxw.WARP_INSTALLED:
self.skipTest('Warp not installed.')
batch_size = 7
m = test_util.load_test_file(xml)
mx = mjx.put_model(m, impl='warp')
keys = jp.arange(batch_size)
dx_batch = jax.vmap(functools.partial(tu.make_data, m))(keys)
step_fn = jax.jit(jax.vmap(forward.step, in_axes=(None, 0)))
dx_batch1 = step_fn(mx, dx_batch)
jax.tree_util.tree_map(lambda x: x.block_until_ready(), dx_batch1)
> self.assertEqual(step_fn._cache_size(), 1)
E AssertionError: 0 != 1
mujoco/mjx/warp/forward_test.py:76: AssertionError
---------------------------- Captured stdout call -----------------------------
Module mul_m_sparse_diag__locals___mul_m_sparse_diag_1de23634 4949128 load on device 'cpu' took 195.10 ms (compiled)
Module mul_m_sparse_ij__locals___mul_m_sparse_ij_b1bb5fb3 8b21b7f load on device 'cpu' took 204.99 ms (compiled)
Module update_gradient_JTDAJ_sparse_tiled__locals__kernel_8f59ead1 6963331 load on device 'cpu' took 269.85 ms (compiled)
________________________ ForwardTest.test_jit_caching1 ________________________
self = <mjx.warp.forward_test.ForwardTest testMethod=test_jit_caching1>
xml = 'humanoid/humanoid.xml'
@parameterized.parameters(
'pendula.xml',
'humanoid/humanoid.xml',
)
def test_jit_caching(self, xml):
"""Tests jit caching on the full step function."""
if not _FORCE_TEST:
if not mjxw.WARP_INSTALLED:
self.skipTest('Warp not installed.')
batch_size = 7
m = test_util.load_test_file(xml)
mx = mjx.put_model(m, impl='warp')
keys = jp.arange(batch_size)
dx_batch = jax.vmap(functools.partial(tu.make_data, m))(keys)
step_fn = jax.jit(jax.vmap(forward.step, in_axes=(None, 0)))
dx_batch1 = step_fn(mx, dx_batch)
jax.tree_util.tree_map(lambda x: x.block_until_ready(), dx_batch1)
> self.assertEqual(step_fn._cache_size(), 1)
E AssertionError: 0 != 1
mujoco/mjx/warp/forward_test.py:76: AssertionError
============================== warnings summary ===============================
mujoco/mjx/_src/collision_driver_test.py::ConvexTest::test_box_box
mujoco/mjx/_src/collision_driver_test.py::ConvexTest::test_box_box_edge
mujoco/mjx/_src/collision_driver_test.py::ConvexTest::test_convex_convex
mujoco/mjx/_src/collision_driver_test.py::ConvexTest::test_convex_convex_edge
mujoco/mjx/_src/collision_driver_test.py::HFieldTest::test_hfield_deep
mujoco/mjx/_src/support_test.py::SupportTest::test_bind
mujoco/mjx/integration_test/smooth_test.py::TransmissionIntegrationTest::test_transmission14
/mujoco/main/.venv/lib/python3.12/site-packages/jax/_src/abstract_arrays.py:135: RuntimeWarning: overflow encountered in cast
return literals.TypedNdArray(np.asarray(x, dtype), weak_type=False)
mujoco/mjx/_src/forward_test.py::ActuatorTest::test_actuator2
mujoco/mjx/_src/io_test.py::DataIOTest::test_qm_mapm2m0
mujoco/mjx/_src/passive_test.py::PassiveTest::test_passive
mujoco/mjx/integration_test/collision_driver_test.py::CollisionDriverIntegrationTest::test_collision_driver0
mujoco/mjx/integration_test/collision_driver_test.py::CollisionDriverIntegrationTest::test_collision_driver115
mujoco/mjx/integration_test/collision_driver_test.py::CollisionDriverIntegrationTest::test_collision_driver125
/mujoco/main/.venv/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py:2412: DeprecationWarning: Passing arguments 'a', 'a_min' or 'a_max' to jax.numpy.clip is deprecated. Please use 'arr', 'min' or 'max' respectively instead.
ans_pytree = fun(*args, **kwargs)
mujoco/mjx/_src/io_test.py::DataIOTest::test_make_data_warp
/mujoco/main/mjx/mujoco/mjx/_src/io_test.py:468: DeprecationWarning: nconmax will be deprecated in mujoco-mjx>=3.5. Use naconmax instead.
d = mjx.make_data(m, impl='warp', nconmax=9, njmax=23)
mujoco/mjx/integration_test/collision_driver_test.py: 490 warnings
/mujoco/main/mjx/mujoco/mjx/integration_test/collision_driver_test.py:80: DeprecationWarning: Accessing `contact` directly from `Data` is deprecated. Access it via `data._impl.contact` instead.
idx_mjx = list(zip(dx.contact.geom1, dx.contact.geom2))
mujoco/mjx/integration_test/collision_driver_test.py: 245 warnings
/mujoco/main/mjx/mujoco/mjx/integration_test/collision_driver_test.py:86: DeprecationWarning: Accessing `contact` directly from `Data` is deprecated. Access it via `data._impl.contact` instead.
lambda x: x.take(np.array(idx), axis=0), dx.contact
mujoco/mjx/integration_test/collision_driver_test.py: 11 warnings
/mujoco/main/mjx/mujoco/mjx/integration_test/collision_driver_test.py:75: DeprecationWarning: Accessing `contact` directly from `Data` is deprecated. Access it via `data._impl.contact` instead.
self.assertTrue((dx.contact.dist > 0).all())
mujoco/mjx/integration_test/smooth_test.py: 30 warnings
/mujoco/main/mjx/mujoco/mjx/integration_test/smooth_test.py:86: DeprecationWarning: Accessing `actuator_moment` directly from `Data` is deprecated. Access it via `data._impl.actuator_moment` instead.
dx.actuator_moment,
mujoco/mjx/warp/forward_test.py: 11 warnings
mujoco/mjx/warp/smooth_test.py: 3 warnings
/mujoco/main/mjx/mujoco/mjx/warp/test_util.py:47: DeprecationWarning: nconmax will be deprecated in mujoco-mjx>=3.5. Use naconmax instead.
dx = mjx.make_data(m, impl='warp', nconmax=nconmax, njmax=njmax)
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================== short test summary info ===========================
FAILED mujoco/mjx/warp/forward_test.py::ForwardTest::test_forward1 - AssertionError:
FAILED mujoco/mjx/warp/forward_test.py::ForwardTest::test_jit_caching0 - AssertionError: 0 != 1
FAILED mujoco/mjx/warp/forward_test.py::ForwardTest::test_jit_caching1 - AssertionError: 0 != 1
=========== 3 failed, 715 passed, 804 warnings in 357.33s (0:05:57) ===========
WARNING: Nan, Inf or huge value in QACC at DOF 0. The simulation is unstable. Time = 0.0000.
WARNING: Nan, Inf or huge value in QACC at DOF 9. The simulation is unstable. Time = 0.0680.Weirdly enough, the two cache tests don't fail if I run the Detailspytest ./mjx/mujoco/mjx/warp/forward_test.py
================================ test session starts ================================
platform darwin -- Python 3.12.10, pytest-9.0.2, pluggy-1.6.0
rootdir: /mujoco/mjx
configfile: pyproject.toml
collected 11 items
mjx/mujoco/mjx/warp/forward_test.py .F......... [100%]
===================================== FAILURES ======================================
_____________________________ ForwardTest.test_forward1 _____________________________
self = <mjx.warp.forward_test.ForwardTest testMethod=test_forward1>
xml = 'humanoid/humanoid.xml', batch_size = 7
@parameterized.product(
xml=(
'humanoid/humanoid.xml',
'pendula.xml',
),
batch_size=(1, 7),
)
def test_forward(self, xml: str, batch_size: int):
if not _FORCE_TEST:
if not mjxw.WARP_INSTALLED:
self.skipTest('Warp not installed.')
m = test_util.load_test_file(xml)
m.opt.iterations = 10
m.opt.ls_iterations = 10
m.opt.jacobian = mujoco.mjtJacobian.mjJAC_DENSE
mx = mjx.put_model(m, impl='warp')
d = mujoco.MjData(m)
worldids = jp.arange(batch_size)
dx_batch = jax.vmap(functools.partial(tu.make_data, m))(worldids)
dx_batch = jax.jit(jax.vmap(forward.forward, in_axes=(None, 0)))(
mx, dx_batch
)
for i in range(batch_size):
dx = dx_batch[i]
d.qpos[:] = dx.qpos
d.qvel[:] = dx.qvel
d.ctrl[:] = dx.ctrl
d.mocap_pos[:] = dx.mocap_pos
d.mocap_quat[:] = dx.mocap_quat
mujoco.mj_forward(m, d)
# fwd_position
tu.assert_attr_eq(dx, d, 'xpos')
tu.assert_attr_eq(dx, d, 'xquat')
tu.assert_attr_eq(dx, d, 'xipos')
tu.assert_eq(d.ximat.reshape((-1, 3, 3)), dx.ximat, 'ximat')
tu.assert_attr_eq(dx, d, 'xanchor')
tu.assert_attr_eq(dx, d, 'xaxis')
tu.assert_attr_eq(dx, d, 'geom_xpos')
tu.assert_eq(dx.geom_xmat, d.geom_xmat.reshape((-1, 3, 3)), 'geom_xmat')
if m.nsite:
tu.assert_attr_eq(dx, d, 'site_xpos')
tu.assert_eq(dx.site_xmat, d.site_xmat.reshape((-1, 3, 3)), 'site_xmat')
tu.assert_attr_eq(dx, d, 'cdof')
tu.assert_attr_eq(dx._impl, d, 'cinert')
tu.assert_attr_eq(dx, d, 'subtree_com')
if m.nlight:
tu.assert_attr_eq(dx._impl, d, 'light_xpos')
tu.assert_attr_eq(dx._impl, d, 'light_xdir')
if m.ncam:
tu.assert_attr_eq(dx, d, 'cam_xpos')
tu.assert_eq(dx.cam_xmat, d.cam_xmat.reshape((-1, 3, 3)), 'cam_xmat')
tu.assert_attr_eq(dx, d, 'ten_length')
tu.assert_attr_eq(dx._impl, d, 'ten_J')
tu.assert_attr_eq(dx._impl, d, 'ten_wrapadr')
tu.assert_attr_eq(dx._impl, d, 'ten_wrapnum')
tu.assert_attr_eq(dx._impl, d, 'wrap_xpos')
tu.assert_attr_eq(dx._impl, d, 'wrap_obj')
tu.assert_attr_eq(dx._impl, d, 'crb')
qm = np.zeros((m.nv, m.nv))
mujoco.mj_fullM(m, qm, d.qM)
# mjwarp adds padding to qM
tu.assert_eq(qm, dx._impl.qM[: m.nv, : m.nv], 'qM')
# qLD is fused in a cholesky factorize and solve, and not written to.
tu.assert_contact_eq(d, dx, worldid=i)
tu.assert_attr_eq(dx, d, 'actuator_length')
actuator_moment = np.zeros((m.nu, m.nv))
mujoco.mju_sparse2dense(
actuator_moment,
d.actuator_moment,
d.moment_rownnz,
d.moment_rowadr,
d.moment_colind,
)
tu.assert_eq(dx._impl.actuator_moment, actuator_moment, 'actuator_moment')
# fwd_velocity
tu.assert_attr_eq(dx._impl, d, 'actuator_velocity')
tu.assert_attr_eq(dx, d, 'cvel')
tu.assert_attr_eq(dx, d, 'cdof_dot')
tu.assert_attr_eq(dx._impl, d, 'qfrc_spring')
tu.assert_attr_eq(dx._impl, d, 'qfrc_damper')
tu.assert_attr_eq(dx, d, 'qfrc_gravcomp')
tu.assert_attr_eq(dx, d, 'qfrc_fluid')
tu.assert_attr_eq(dx, d, 'qfrc_passive')
tu.assert_attr_eq(dx, d, 'qfrc_bias')
# NOTE(user): This fails due to some weird sorting of keys.
> tu.assert_efc_eq(d, dx, worldid=i)
mjx/mujoco/mjx/warp/forward_test.py:179:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
mjx/mujoco/mjx/warp/test_util.py:199: in assert_efc_eq
assert_eq(jp_, j, 'J')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
a = Array([[-0.9875637 , 0.15721938, 1. , -0.00917677, 1.335037 ,
-0.10656909, -0.10577767, 1.1439778 ... -0.03383656, 0. , 0. , 0. , 0. ,
0. , 0. ]], dtype=float32)
b = array([[-0.9875637 , 0.15721938, 1. , -0.00917675, 1.33503689,
-0.10656907, -0.10577765, 1.14397789....07388643,
-0.0338367 , 0. , 0. , 0. , 0. ,
0. , 0. ]])
name = 'J'
def assert_eq(a, b, name):
tol = _TOLERANCE * 10 # avoid test noise
err_msg = f'mismatch: {name}'
> np.testing.assert_allclose(a, b, err_msg=err_msg, atol=tol, rtol=tol)
E AssertionError:
E Not equal to tolerance rtol=0.0005, atol=0.0005
E mismatch: J
E Mismatched elements: 160 / 432 (37%)
E First 5 mismatches are at indices:
E [4, 0]: -0.9871430993080139 (ACTUAL), -0.999479684061378 (DESIRED)
E [4, 1]: 0.15983904898166656 (ACTUAL), -0.03225462987801541 (DESIRED)
E [4, 3]: 0.16618438065052032 (ACTUAL), -0.22991350005486522 (DESIRED)
E [4, 4]: 1.2810968160629272 (ACTUAL), 1.3359474985903166 (DESIRED)
E [4, 5]: 0.044975683093070984 (ACTUAL), -0.0601160972922109 (DESIRED)
E Max absolute difference among violations: 0.87907806
E Max relative difference among violations: 12.74774404
E ACTUAL: array([[-0.987564, 0.157219, 1. , -0.009177, 1.335037, -0.106569,
E -0.105778, 1.143978, -0.053232, 0.039397, -0.004723, 0.880073,
E -0.497201, 0.173519, -0.011014, 0. , 0. , 0. ,...
E DESIRED: array([[-0.987564, 0.157219, 1. , -0.009177, 1.335037, -0.106569,
E -0.105778, 1.143978, -0.053232, 0.039397, -0.004723, 0.880073,
E -0.497201, 0.173519, -0.011014, 0. , 0. , 0. ,...
mjx/mujoco/mjx/warp/test_util.py:36: AssertionError
================================= warnings summary ==================================
mujoco/mjx/warp/forward_test.py: 13 warnings
/mujoco/mjx/mujoco/mjx/warp/test_util.py:47: DeprecationWarning: nconmax will be deprecated in mujoco-mjx>=3.5. Use naconmax instead.
dx = mjx.make_data(m, impl='warp', nconmax=nconmax, njmax=njmax)
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================== short test summary info ==============================
FAILED mjx/mujoco/mjx/warp/forward_test.py::ForwardTest::test_forward1 - AssertionError:
==================== 1 failed, 10 passed, 13 warnings in 28.00s =====================I tracked the numerical errors down to mujoco/mjx/mujoco/mjx/warp/test_util.py Line 156 in 07d7bc9 Changing this line to the following fixes the issue: keys_sorted = np.lexsort((-np.round(efc_pos, 4), efc_type, np.round(efc_d, 4)))I'm not sure if the |
7f0f35f to
965bef3
Compare
9fd58de to
732877b
Compare
| efc_type = select(dx._impl.efc__type)[:nefc] | ||
| efc_d = select(dx._impl.efc__D)[:nefc] | ||
| keys_sorted = np.lexsort((-efc_pos, efc_type, efc_d)) | ||
| keys_sorted = np.lexsort((np.round(-efc_pos, 12), efc_type, np.round(efc_d, 12))) |
There was a problem hiding this comment.
Note these rounding changes. Let me know if you want these to be in a separate PR.
31b5501 to
c38d802
Compare
This is not ready yet but just demonstrates that it should be possible to enable warp backend on MacOS. ```console $ uv pip install -U --extra-index-url="https://py.mujoco.org" "mujoco>=3.7.0.dev0,<3.8.0" && \ uv pip install -U warp-lang && \ uv pip install -U -e ./mjx /Users/google-deepmind/mujoco/main/mjx/mujoco/mjx/_src/io_test.py:472: DeprecationWarning: nconmax will be deprecated in mujoco-mjx>=3.5. Use naconmax instead. d = mjx.make_data(m, impl='warp', nconmax=9, njmax=23) Warp 1.12.0 initialized: CUDA not enabled in this build Devices: "cpu" : "arm" Kernel cache: /private/var/folders/ml/rrlg98ln26l7xvxgq_yqfn4c0000gn/T/tmpoklv_7xg Warp DeprecationWarning: The symbol `warp.types.warp_type_to_np_dtype` will soon be removed from the public API. It can still be accessed from `warp._src.types.warp_type_to_np_dtype` but might be changed or removed without notice. ./opt/homebrew/Cellar/python@3.12/3.12.10/Frameworks/Python.framework/Versions/3.12/lib/python3.12/tempfile.py:940: ResourceWarning: Implicitly cleaning up <TemporaryDirectory '/var/folders/ml/rrlg98ln26l7xvxgq_yqfn4c0000gn/T/tmpoklv_7xg'> _warnings.warn(warn_message, ResourceWarning) ---------------------------------------------------------------------- Ran 1 test in 1.105s OK ```
When using warp on CPU, there are small numerical differences that cause the `np.lexsort` to return inconsistently-ordered `efc-pos` and `efc_D`. This changes the sorting to round the values, fixing the flakiness.
485cc2f to
877f79a
Compare
See #2947.
With the existing changes and
warp-lang>=1.11.0, we can now enable warp backend for mujoco on MacOS. More generally, mujoco warp now also works on CPU and not just cuda.