From 57956ace0e2e75848c741419ab6f0c56f4889df6 Mon Sep 17 00:00:00 2001 From: Tarik Kelestemur Date: Thu, 12 Mar 2026 01:15:18 -0400 Subject: [PATCH 1/3] Enable MuJoCo Warp (mjw) backend on macOS Allow impl="warp" to run on macOS by falling back to CPU when no CUDA GPU is available. Previously, the warp backend required CUDA, making it unusable on macOS for local development and debugging. Changes: - Device resolution falls back to CPU when no CUDA GPU is found - Device compatibility check accepts both CUDA GPU and CPU for warp - FFI targets registered on both CUDA and Host platforms - Kernel launch dispatches to wp_cpu_launch_kernel on CPU - CUDA graph traits and capture modes gated on wp.is_cuda_available() - Module preloading extended to CPU devices - Removed has_cuda_gpu_device() skip conditions from warp tests - Fixed lexsort precision issue in test_util.py causing flaky test on macOS due to floating point sort key differences See #2947 --- mjx/mujoco/mjx/_src/io.py | 28 +++--- mjx/mujoco/mjx/_src/io_test.py | 37 ++------ .../warp/_src/jax_experimental/ffi.py | 89 ++++++++++++------- mjx/mujoco/mjx/warp/collision_driver_test.py | 2 - mjx/mujoco/mjx/warp/forward_test.py | 8 -- mjx/mujoco/mjx/warp/smooth_test.py | 13 +-- mjx/mujoco/mjx/warp/test_util.py | 4 +- 7 files changed, 79 insertions(+), 102 deletions(-) diff --git a/mjx/mujoco/mjx/_src/io.py b/mjx/mujoco/mjx/_src/io.py index 253a62a5d2..cd71ec8caf 100644 --- a/mjx/mujoco/mjx/_src/io.py +++ b/mjx/mujoco/mjx/_src/io.py @@ -90,23 +90,20 @@ def _resolve_device( logging.debug('Picking default device: %s.', device_0) return device_0 + if impl == types.Impl.WARP: + if has_cuda_gpu_device(): + cuda_gpus = jax.devices('cuda') + logging.debug('Picking default device: %s', cuda_gpus[0]) + device_0 = cuda_gpus[0] + else: + device_0 = jax.devices('cpu')[0] + return device_0 + if impl == types.Impl.C or impl == types.Impl.CPP: cpu_0 = jax.devices('cpu')[0] logging.debug('Picking default device: %s', cpu_0) return cpu_0 - if impl == types.Impl.WARP: - # WARP implementation requires a CUDA GPU. - cuda_gpus = [d for d in jax.devices('cuda')] - if not cuda_gpus: - raise AssertionError( - 'No CUDA GPU devices found in' - f' jax.devices("cuda")={jax.devices("cuda")}.' - ) - - logging.debug('Picking default device: %s', cuda_gpus[0]) - return cuda_gpus[0] - raise ValueError(f'Unsupported implementation: {impl}') @@ -121,9 +118,12 @@ def _check_impl_device_compatibility( impl = types.Impl(impl) if impl == types.Impl.WARP: - if not _is_cuda_gpu_device(device): + is_cuda_device = _is_cuda_gpu_device(device) + is_cpu_device = device.platform == 'cpu' + if not (is_cuda_device or is_cpu_device): raise AssertionError( - f'Warp implementation requires a CUDA GPU device, got {device}.' + 'Warp implementation requires a CUDA GPU or CPU device, got ' + f'{device}.' ) _check_warp_installed() diff --git a/mjx/mujoco/mjx/_src/io_test.py b/mjx/mujoco/mjx/_src/io_test.py index 9910e9b594..fdf05e73aa 100644 --- a/mjx/mujoco/mjx/_src/io_test.py +++ b/mjx/mujoco/mjx/_src/io_test.py @@ -142,8 +142,6 @@ def setUp(self): def test_put_model(self, xml, impl): if impl == 'warp' and not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if impl == 'warp' and not mjx_io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') m = mujoco.MjModel.from_xml_string(xml) mx = mjx.put_model(m, impl=impl) @@ -315,8 +313,6 @@ def test_put_model_warp_has_expected_shapes(self): """Tests that put_model produces expected shapes for MuJoCo Warp.""" if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not mjx_io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') m = mujoco.MjModel.from_xml_string(_MULTIPLE_CONSTRAINTS) mx = mjx.put_model(m, impl='warp') @@ -339,8 +335,6 @@ def test_put_model_warp_graph_mode(self, mode: str | None): """Tests that put_model accepts graph_mode parameter.""" if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not mjx_io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') if mode is None: graph_mode = None @@ -487,8 +481,6 @@ def test_make_data(self, impl: str): def test_make_data_warp(self): if not mjxw.WARP_INSTALLED: self.skipTest('Warp is not installed.') - if not mjx_io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device.') m = mujoco.MjModel.from_xml_string(_MULTIPLE_CONVEX_OBJECTS) d = mjx.make_data(m, impl='warp', nconmax=9, njmax=23) self.assertEqual(d._impl.contact__dist.shape[0], 9) @@ -500,8 +492,6 @@ def test_put_data(self, impl: str): if impl == 'warp': if not mjxw.WARP_INSTALLED: self.skipTest('Warp is not installed.') - if not mjx_io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device.') m = mujoco.MjModel.from_xml_string(_MULTIPLE_CONSTRAINTS) d = mujoco.MjData(m) @@ -623,8 +613,6 @@ def test_put_data_warp_ndim(self): """Tests that put_data produces expected dimensions for Warp fields.""" if not mjxw.WARP_INSTALLED: self.skipTest('Warp is not installed.') - if not mjx_io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device.') m = mujoco.MjModel.from_xml_string(_MULTIPLE_CONSTRAINTS) d = mujoco.MjData(m) @@ -792,8 +780,6 @@ def test_get_data_into_warp(self): # and remove this test. if not mjxw.WARP_INSTALLED: self.skipTest('Warp is not installed.') - if not mjx_io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device.') m = mujoco.MjModel.from_xml_string('') d = mujoco.MjData(m) @@ -869,8 +855,6 @@ def test_make_data_warp_has_expected_shapes(self): """Tests that make_data produces expected shapes for MuJoCo Warp.""" if not mjxw.WARP_INSTALLED: self.skipTest('Warp is not installed.') - if not mjx_io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device.') m = mujoco.MjModel.from_xml_string(_MULTIPLE_CONSTRAINTS) dx = mjx.make_data(m, impl='warp') @@ -893,8 +877,6 @@ def test_data_slice(self, impl): """Tests that slice on Data works as expected.""" if impl == 'warp' and not mjxw.WARP_INSTALLED: self.skipTest('Warp is not installed.') - if impl == 'warp' and not mjx_io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device.') m = mujoco.MjModel.from_xml_string(_MULTIPLE_CONSTRAINTS) dx = jax.vmap(lambda x: mjx.make_data(m, impl=impl))(jp.arange(10)) @@ -956,8 +938,8 @@ def put_data(dummy_arg_for_batching): ('gpu-nvidia', 'jax', ('gpu', Impl.JAX)), ('tpu', 'jax', ('tpu', Impl.JAX)), # WARP backend specified. - ('cpu', 'warp', ('cpu', 'error')), - ('gpu-notnvidia', 'warp', ('cpu', 'error')), + ('cpu', 'warp', ('cpu', Impl.WARP)), + ('gpu-notnvidia', 'warp', ('gpu', 'error')), ('gpu-nvidia', 'warp', ('gpu', Impl.WARP)), ('tpu', 'warp', ('tpu', 'error')), # C backend specified. @@ -984,10 +966,10 @@ def put_data(dummy_arg_for_batching): ('gpu-nvidia', 'jax', ('gpu', Impl.JAX)), ('tpu', 'jax', ('tpu', Impl.JAX)), # WARP backend impl specified. - ('cpu', 'warp', ('cpu', 'error')), - ('gpu-notnvidia', 'warp', ('cpu', 'error')), + ('cpu', 'warp', ('cpu', Impl.WARP)), + ('gpu-notnvidia', 'warp', ('cpu', Impl.WARP)), ('gpu-nvidia', 'warp', ('gpu', Impl.WARP)), - ('tpu', 'warp', ('tpu', 'error')), + ('tpu', 'warp', ('cpu', Impl.WARP)), # C backend impl specified, CPU should always be available. ('cpu', 'c', ('cpu', Impl.C)), ('gpu-notnvidia', 'c', ('cpu', Impl.C)), @@ -1162,15 +1144,6 @@ def backends_side_effect(): self.mock_jax_backends.side_effect = backends_side_effect expected_device, expected_impl = expected - if ( - expected_impl == 'error' - and default_device_str != 'gpu-nvidia' - and impl_str == 'warp' - ): - with self.assertRaisesRegex(RuntimeError, 'cuda backend not supported'): - mjx_io._resolve_impl_and_device(impl=impl_str, device=None) - return - if expected_impl == 'error': with self.assertRaises(AssertionError): mjx_io._resolve_impl_and_device(impl=impl_str, device=None) diff --git a/mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/ffi.py b/mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/ffi.py index f5c925dd44..83d5884eb0 100644 --- a/mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/ffi.py +++ b/mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/ffi.py @@ -199,6 +199,7 @@ def __init__( ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p) ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value) jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA") + jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="Host") def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None): num_inputs = len(args) @@ -297,8 +298,7 @@ def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None): except Exception: # ignore unsupported devices like TPUs pass - # we only support CUDA devices for now - if dev.is_cuda: + if dev.is_cuda or dev.is_cpu: self.kernel.module.load(dev) # save launch data to be retrieved by callback @@ -320,10 +320,11 @@ def ffi_callback(self, call_frame): metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension)) metadata_ext.contents.metadata.contents.api_version.major_version = 0 metadata_ext.contents.metadata.contents.api_version.minor_version = 1 - # Turn on CUDA graphs for this handler. - metadata_ext.contents.metadata.contents.traits = ( - XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE - ) + # Turn on CUDA graphs for this handler if CUDA is available. + if wp.is_cuda_available(): + metadata_ext.contents.metadata.contents.traits = ( + XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE + ) return None # Lock is required to prevent race conditions when callback is invoked @@ -377,24 +378,36 @@ def ffi_callback(self, call_frame): arg_refs.append(arg) # keep a reference # get device and stream - device = wp.get_cuda_device(get_device_ordinal_from_callframe(call_frame.contents)) - stream = get_stream_from_callframe(call_frame.contents) + if wp.is_cuda_available(): + device = wp.get_cuda_device(get_device_ordinal_from_callframe(call_frame.contents)) + stream = get_stream_from_callframe(call_frame.contents) + else: + device = wp.get_device("cpu") + stream = None # get kernel hooks hooks = self.kernel.module.get_kernel_hooks(self.kernel, device) assert hooks.forward, "Failed to find kernel entry point" # launch the kernel - wp._src.context.runtime.core.wp_cuda_launch_kernel( - device.context, - hooks.forward, - launch_bounds.size, - 0, - 256, - hooks.forward_smem_bytes, - kernel_params, - stream, - ) + if device.is_cuda: + wp._src.context.runtime.core.wp_cuda_launch_kernel( + device.context, + hooks.forward, + launch_bounds.size, + 0, + 256, + hooks.forward_smem_bytes, + kernel_params, + stream, + ) + else: + wp._src.context.runtime.core.wp_cpu_launch_kernel( + device.context, + hooks.forward, + launch_bounds.size, + kernel_params, + ) except Exception as e: print(traceback.format_exc()) @@ -552,6 +565,7 @@ def __init__( ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p) ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value) jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA") + jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="Host") def __call__(self, *args, output_dims=None, vmap_method=None): num_inputs = len(args) @@ -642,8 +656,7 @@ def __call__(self, *args, output_dims=None, vmap_method=None): except Exception: # ignore unsupported devices like TPUs pass - # we only support CUDA devices for now - if dev.is_cuda: + if dev.is_cuda or dev.is_cpu: module.load(dev) # save call data to be retrieved by callback @@ -664,8 +677,8 @@ def ffi_callback(self, call_frame): metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension)) metadata_ext.contents.metadata.contents.api_version.major_version = 0 metadata_ext.contents.metadata.contents.api_version.minor_version = 1 - # Turn on CUDA graphs for this handler. - if self.graph_mode is GraphMode.JAX: + # Turn on CUDA graphs for this handler if CUDA is available. + if self.graph_mode is GraphMode.JAX and wp.is_cuda_available(): metadata_ext.contents.metadata.contents.traits = ( XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE ) @@ -692,8 +705,12 @@ def ffi_callback(self, call_frame): assert num_inputs == self.num_inputs assert num_outputs == self.num_outputs - cuda_stream = get_stream_from_callframe(call_frame.contents) - device_ordinal = get_device_ordinal_from_callframe(call_frame.contents) + if wp.is_cuda_available(): + cuda_stream = get_stream_from_callframe(call_frame.contents) + device_ordinal = get_device_ordinal_from_callframe(call_frame.contents) + else: + cuda_stream = None + device_ordinal = 0 if self.graph_mode == GraphMode.WARP: # check if we already captured an identical call @@ -797,9 +814,12 @@ def ffi_callback(self, call_frame): # early out return - device_ordinal = get_device_ordinal_from_callframe(call_frame.contents) - device = wp.get_cuda_device(device_ordinal) - stream = wp.Stream(device, cuda_stream=cuda_stream) + if wp.is_cuda_available(): + device = wp.get_cuda_device(device_ordinal) + stream = wp.Stream(device, cuda_stream=cuda_stream) + else: + device = wp.get_device("cpu") + stream = None # reconstruct the argument list arg_list = [] @@ -824,8 +844,8 @@ def ffi_callback(self, call_frame): arg_list.append(arr) # call the Python function with reconstructed arguments - with wp.ScopedStream(stream, sync_enter=False): - if stream.is_capturing: + with wp.ScopedStream(stream, sync_enter=False) if stream else wp.ScopedDevice(device): + if stream and stream.is_capturing: # capturing with JAX with wp.ScopedCapture(external=True) as capture: self.func(*arg_list) @@ -833,7 +853,7 @@ def ffi_callback(self, call_frame): # keep a reference to the capture object to prevent required modules getting unloaded call_desc.capture = capture - elif self.graph_mode == GraphMode.WARP: + elif self.graph_mode == GraphMode.WARP and device.is_cuda: # capturing with WARP with wp.ScopedCapture() as capture: self.func(*arg_list) @@ -846,7 +866,7 @@ def ffi_callback(self, call_frame): if self._graph_cache_max is not None and len(self.captures) > self._graph_cache_max: self.captures.popitem(last=False) - elif self.graph_mode == GraphMode.WARP_STAGED_EX: + elif self.graph_mode == GraphMode.WARP_STAGED_EX and device.is_cuda: # capturing with WARP using staging buffers and memcopies done outside of the graph wp_memcpy_batch = wp._src.context.runtime.core.wp_memcpy_batch @@ -889,7 +909,7 @@ def ffi_callback(self, call_frame): # TODO: we should have a way of freeing this call_desc.capture = capture - elif self.graph_mode == GraphMode.WARP_STAGED: + elif self.graph_mode == GraphMode.WARP_STAGED and device.is_cuda: # capturing with WARP using staging buffers and memcopies done inside of the graph wp_cuda_graph_insert_memcpy_batch = ( wp._src.context.runtime.core.wp_cuda_graph_insert_memcpy_batch @@ -967,7 +987,7 @@ def ffi_callback(self, call_frame): call_desc.capture = capture else: - # not capturing + # not capturing or on CPU self.func(*arg_list) except Exception as e: @@ -1537,7 +1557,7 @@ def ffi_callback(call_frame): metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension)) metadata_ext.contents.metadata.contents.api_version.major_version = 0 metadata_ext.contents.metadata.contents.api_version.minor_version = 1 - if graph_compatible: + if graph_compatible and wp.is_cuda_available(): # Turn on CUDA graphs for this handler. metadata_ext.contents.metadata.contents.traits = ( XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE @@ -1576,6 +1596,7 @@ def ffi_callback(call_frame): ffi_ccall_address = ctypes.cast(callback_func, ctypes.c_void_p) ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value) jax.ffi.register_ffi_target(name, ffi_capsule, platform="CUDA") + jax.ffi.register_ffi_target(name, ffi_capsule, platform="Host") ############################################################################### diff --git a/mjx/mujoco/mjx/warp/collision_driver_test.py b/mjx/mujoco/mjx/warp/collision_driver_test.py index 36109eefb1..20246389b0 100644 --- a/mjx/mujoco/mjx/warp/collision_driver_test.py +++ b/mjx/mujoco/mjx/warp/collision_driver_test.py @@ -76,8 +76,6 @@ def test_collision_nested_vmap(self): if not _FORCE_TEST: if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') m = mujoco.MjModel.from_xml_string(self._SPHERE_SPHERE) d = mujoco.MjData(m) diff --git a/mjx/mujoco/mjx/warp/forward_test.py b/mjx/mujoco/mjx/warp/forward_test.py index 3eb936a73f..693b809a6e 100644 --- a/mjx/mujoco/mjx/warp/forward_test.py +++ b/mjx/mujoco/mjx/warp/forward_test.py @@ -67,8 +67,6 @@ def test_jit_caching(self, xml): if not _FORCE_TEST: if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') batch_size = 7 m = test_util.load_test_file(xml) @@ -100,8 +98,6 @@ def test_forward(self, xml: str, batch_size: int): if not _FORCE_TEST: if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') m = test_util.load_test_file(xml) m.opt.iterations = 10 @@ -242,8 +238,6 @@ def test_step(self, xml: str, batch_size: int, graph_mode: str): if not _FORCE_TEST: if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') m = test_util.load_test_file(xml) m.opt.iterations = 10 @@ -287,8 +281,6 @@ def test_step_leading_dim_mismatch(self): if not _FORCE_TEST: if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') xml = 'humanoid/humanoid.xml' batch_size = 7 diff --git a/mjx/mujoco/mjx/warp/smooth_test.py b/mjx/mujoco/mjx/warp/smooth_test.py index 92c028751a..a6e38cbd65 100644 --- a/mjx/mujoco/mjx/warp/smooth_test.py +++ b/mjx/mujoco/mjx/warp/smooth_test.py @@ -63,8 +63,6 @@ def test_kinematics(self): if not _FORCE_TEST: if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') m = tu.load_test_file('pendula.xml') @@ -102,10 +100,9 @@ def test_kinematics(self): def test_kinematics_vmap(self): """Tests kinematics with batched data.""" - if not mjxw.WARP_INSTALLED: - self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') + if not _FORCE_TEST: + if not mjxw.WARP_INSTALLED: + self.skipTest('Warp not installed.') m = tu.load_test_file('pendula.xml') @@ -149,8 +146,6 @@ def test_kinematics_nested_vmap(self): if not _FORCE_TEST: if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') m = tu.load_test_file('pendula.xml') @@ -196,8 +191,6 @@ def test_kinematics_model_vmap(self): if not _FORCE_TEST: if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') m = tu.load_test_file('pendula.xml') diff --git a/mjx/mujoco/mjx/warp/test_util.py b/mjx/mujoco/mjx/warp/test_util.py index 1752f3c943..6f7e368e10 100644 --- a/mjx/mujoco/mjx/warp/test_util.py +++ b/mjx/mujoco/mjx/warp/test_util.py @@ -153,7 +153,7 @@ def _mjx_efc(dx, worldid: int): efc_pos = select(dx._impl.efc__pos)[:nefc] efc_type = select(dx._impl.efc__type)[:nefc] efc_d = select(dx._impl.efc__D)[:nefc] - keys_sorted = np.lexsort((-efc_pos, efc_type, efc_d)) + keys_sorted = np.lexsort((-np.round(efc_pos, 4), efc_type, np.round(efc_d, 4))) keys = keys[keys_sorted] nefc = len(keys) @@ -180,7 +180,7 @@ def _mj_efc(d): else: efc_j = d.efc_J.reshape((-1, d.qvel.shape[0])) - keys = np.lexsort((-d.efc_pos, d.efc_type, d.efc_D)) + keys = np.lexsort((-np.round(d.efc_pos, 4), d.efc_type, np.round(d.efc_D, 4))) type_ = d.efc_type[keys] pos = d.efc_pos[keys] efc_j = efc_j[keys] From a081f2c1eed0cc74684d4d66d16fdaee9da84946 Mon Sep 17 00:00:00 2001 From: Tarik Kelestemur Date: Thu, 12 Mar 2026 01:29:33 -0400 Subject: [PATCH 2/3] Relax mujoco version requirement to >=3.6.0 Allow mujoco-mjx to be installed with mujoco 3.6.0 since the macOS warp changes do not depend on any 3.7.0-specific APIs. --- mjx/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mjx/pyproject.toml b/mjx/pyproject.toml index ff7133cdd6..fb6484cefa 100644 --- a/mjx/pyproject.toml +++ b/mjx/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ "etils[epath]", "jax", "jaxlib", - "mujoco>=3.7.0.dev0", + "mujoco>=3.6.0", "scipy", "trimesh", ] From 1f1390d54b6903fc9c83f6f6c211be8261aefec0 Mon Sep 17 00:00:00 2001 From: Tarik Kelestemur Date: Thu, 12 Mar 2026 02:15:27 -0400 Subject: [PATCH 3/3] Fix render_rgb/render_depth boolean handling in create_render_context When render_rgb=False or render_depth=False, the condition `if render_rgb and isinstance(render_rgb, bool)` short-circuits because False is falsy, skipping the bool-to-list expansion. This causes a TypeError downstream. Fix by checking isinstance first. --- mjx/mujoco/mjx/third_party/mujoco_warp/_src/io.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/io.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/io.py index e6533dd6e5..2c9cc3bb38 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/io.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/io.py @@ -2321,7 +2321,7 @@ def create_render_context( cam_res_arr = wp.array(active_cam_res, dtype=wp.vec2i) - if render_rgb and isinstance(render_rgb, bool): + if isinstance(render_rgb, bool): render_rgb = [render_rgb] * ncam elif render_rgb is None: # TODO: remove after mjwarp depends on mujoco >= 3.4.1 in pyproject.toml @@ -2330,7 +2330,7 @@ def create_render_context( else: render_rgb = [True] * ncam - if render_depth and isinstance(render_depth, bool): + if isinstance(render_depth, bool): render_depth = [render_depth] * ncam elif render_depth is None: # TODO: remove after mjwarp depends on mujoco >= 3.4.1 in pyproject.toml