diff --git a/mjx/mujoco/mjx/_src/io.py b/mjx/mujoco/mjx/_src/io.py index 253a62a5d2..cd71ec8caf 100644 --- a/mjx/mujoco/mjx/_src/io.py +++ b/mjx/mujoco/mjx/_src/io.py @@ -90,23 +90,20 @@ def _resolve_device( logging.debug('Picking default device: %s.', device_0) return device_0 + if impl == types.Impl.WARP: + if has_cuda_gpu_device(): + cuda_gpus = jax.devices('cuda') + logging.debug('Picking default device: %s', cuda_gpus[0]) + device_0 = cuda_gpus[0] + else: + device_0 = jax.devices('cpu')[0] + return device_0 + if impl == types.Impl.C or impl == types.Impl.CPP: cpu_0 = jax.devices('cpu')[0] logging.debug('Picking default device: %s', cpu_0) return cpu_0 - if impl == types.Impl.WARP: - # WARP implementation requires a CUDA GPU. - cuda_gpus = [d for d in jax.devices('cuda')] - if not cuda_gpus: - raise AssertionError( - 'No CUDA GPU devices found in' - f' jax.devices("cuda")={jax.devices("cuda")}.' - ) - - logging.debug('Picking default device: %s', cuda_gpus[0]) - return cuda_gpus[0] - raise ValueError(f'Unsupported implementation: {impl}') @@ -121,9 +118,12 @@ def _check_impl_device_compatibility( impl = types.Impl(impl) if impl == types.Impl.WARP: - if not _is_cuda_gpu_device(device): + is_cuda_device = _is_cuda_gpu_device(device) + is_cpu_device = device.platform == 'cpu' + if not (is_cuda_device or is_cpu_device): raise AssertionError( - f'Warp implementation requires a CUDA GPU device, got {device}.' + 'Warp implementation requires a CUDA GPU or CPU device, got ' + f'{device}.' ) _check_warp_installed() diff --git a/mjx/mujoco/mjx/_src/io_test.py b/mjx/mujoco/mjx/_src/io_test.py index 9910e9b594..fdf05e73aa 100644 --- a/mjx/mujoco/mjx/_src/io_test.py +++ b/mjx/mujoco/mjx/_src/io_test.py @@ -142,8 +142,6 @@ def setUp(self): def test_put_model(self, xml, impl): if impl == 'warp' and not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if impl == 'warp' and not mjx_io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') m = mujoco.MjModel.from_xml_string(xml) mx = mjx.put_model(m, impl=impl) @@ -315,8 +313,6 @@ def test_put_model_warp_has_expected_shapes(self): """Tests that put_model produces expected shapes for MuJoCo Warp.""" if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not mjx_io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') m = mujoco.MjModel.from_xml_string(_MULTIPLE_CONSTRAINTS) mx = mjx.put_model(m, impl='warp') @@ -339,8 +335,6 @@ def test_put_model_warp_graph_mode(self, mode: str | None): """Tests that put_model accepts graph_mode parameter.""" if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not mjx_io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') if mode is None: graph_mode = None @@ -487,8 +481,6 @@ def test_make_data(self, impl: str): def test_make_data_warp(self): if not mjxw.WARP_INSTALLED: self.skipTest('Warp is not installed.') - if not mjx_io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device.') m = mujoco.MjModel.from_xml_string(_MULTIPLE_CONVEX_OBJECTS) d = mjx.make_data(m, impl='warp', nconmax=9, njmax=23) self.assertEqual(d._impl.contact__dist.shape[0], 9) @@ -500,8 +492,6 @@ def test_put_data(self, impl: str): if impl == 'warp': if not mjxw.WARP_INSTALLED: self.skipTest('Warp is not installed.') - if not mjx_io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device.') m = mujoco.MjModel.from_xml_string(_MULTIPLE_CONSTRAINTS) d = mujoco.MjData(m) @@ -623,8 +613,6 @@ def test_put_data_warp_ndim(self): """Tests that put_data produces expected dimensions for Warp fields.""" if not mjxw.WARP_INSTALLED: self.skipTest('Warp is not installed.') - if not mjx_io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device.') m = mujoco.MjModel.from_xml_string(_MULTIPLE_CONSTRAINTS) d = mujoco.MjData(m) @@ -792,8 +780,6 @@ def test_get_data_into_warp(self): # and remove this test. if not mjxw.WARP_INSTALLED: self.skipTest('Warp is not installed.') - if not mjx_io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device.') m = mujoco.MjModel.from_xml_string('') d = mujoco.MjData(m) @@ -869,8 +855,6 @@ def test_make_data_warp_has_expected_shapes(self): """Tests that make_data produces expected shapes for MuJoCo Warp.""" if not mjxw.WARP_INSTALLED: self.skipTest('Warp is not installed.') - if not mjx_io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device.') m = mujoco.MjModel.from_xml_string(_MULTIPLE_CONSTRAINTS) dx = mjx.make_data(m, impl='warp') @@ -893,8 +877,6 @@ def test_data_slice(self, impl): """Tests that slice on Data works as expected.""" if impl == 'warp' and not mjxw.WARP_INSTALLED: self.skipTest('Warp is not installed.') - if impl == 'warp' and not mjx_io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device.') m = mujoco.MjModel.from_xml_string(_MULTIPLE_CONSTRAINTS) dx = jax.vmap(lambda x: mjx.make_data(m, impl=impl))(jp.arange(10)) @@ -956,8 +938,8 @@ def put_data(dummy_arg_for_batching): ('gpu-nvidia', 'jax', ('gpu', Impl.JAX)), ('tpu', 'jax', ('tpu', Impl.JAX)), # WARP backend specified. - ('cpu', 'warp', ('cpu', 'error')), - ('gpu-notnvidia', 'warp', ('cpu', 'error')), + ('cpu', 'warp', ('cpu', Impl.WARP)), + ('gpu-notnvidia', 'warp', ('gpu', 'error')), ('gpu-nvidia', 'warp', ('gpu', Impl.WARP)), ('tpu', 'warp', ('tpu', 'error')), # C backend specified. @@ -984,10 +966,10 @@ def put_data(dummy_arg_for_batching): ('gpu-nvidia', 'jax', ('gpu', Impl.JAX)), ('tpu', 'jax', ('tpu', Impl.JAX)), # WARP backend impl specified. - ('cpu', 'warp', ('cpu', 'error')), - ('gpu-notnvidia', 'warp', ('cpu', 'error')), + ('cpu', 'warp', ('cpu', Impl.WARP)), + ('gpu-notnvidia', 'warp', ('cpu', Impl.WARP)), ('gpu-nvidia', 'warp', ('gpu', Impl.WARP)), - ('tpu', 'warp', ('tpu', 'error')), + ('tpu', 'warp', ('cpu', Impl.WARP)), # C backend impl specified, CPU should always be available. ('cpu', 'c', ('cpu', Impl.C)), ('gpu-notnvidia', 'c', ('cpu', Impl.C)), @@ -1162,15 +1144,6 @@ def backends_side_effect(): self.mock_jax_backends.side_effect = backends_side_effect expected_device, expected_impl = expected - if ( - expected_impl == 'error' - and default_device_str != 'gpu-nvidia' - and impl_str == 'warp' - ): - with self.assertRaisesRegex(RuntimeError, 'cuda backend not supported'): - mjx_io._resolve_impl_and_device(impl=impl_str, device=None) - return - if expected_impl == 'error': with self.assertRaises(AssertionError): mjx_io._resolve_impl_and_device(impl=impl_str, device=None) diff --git a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/io.py b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/io.py index e6533dd6e5..2c9cc3bb38 100644 --- a/mjx/mujoco/mjx/third_party/mujoco_warp/_src/io.py +++ b/mjx/mujoco/mjx/third_party/mujoco_warp/_src/io.py @@ -2321,7 +2321,7 @@ def create_render_context( cam_res_arr = wp.array(active_cam_res, dtype=wp.vec2i) - if render_rgb and isinstance(render_rgb, bool): + if isinstance(render_rgb, bool): render_rgb = [render_rgb] * ncam elif render_rgb is None: # TODO: remove after mjwarp depends on mujoco >= 3.4.1 in pyproject.toml @@ -2330,7 +2330,7 @@ def create_render_context( else: render_rgb = [True] * ncam - if render_depth and isinstance(render_depth, bool): + if isinstance(render_depth, bool): render_depth = [render_depth] * ncam elif render_depth is None: # TODO: remove after mjwarp depends on mujoco >= 3.4.1 in pyproject.toml diff --git a/mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/ffi.py b/mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/ffi.py index f5c925dd44..83d5884eb0 100644 --- a/mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/ffi.py +++ b/mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/ffi.py @@ -199,6 +199,7 @@ def __init__( ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p) ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value) jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA") + jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="Host") def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None): num_inputs = len(args) @@ -297,8 +298,7 @@ def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None): except Exception: # ignore unsupported devices like TPUs pass - # we only support CUDA devices for now - if dev.is_cuda: + if dev.is_cuda or dev.is_cpu: self.kernel.module.load(dev) # save launch data to be retrieved by callback @@ -320,10 +320,11 @@ def ffi_callback(self, call_frame): metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension)) metadata_ext.contents.metadata.contents.api_version.major_version = 0 metadata_ext.contents.metadata.contents.api_version.minor_version = 1 - # Turn on CUDA graphs for this handler. - metadata_ext.contents.metadata.contents.traits = ( - XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE - ) + # Turn on CUDA graphs for this handler if CUDA is available. + if wp.is_cuda_available(): + metadata_ext.contents.metadata.contents.traits = ( + XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE + ) return None # Lock is required to prevent race conditions when callback is invoked @@ -377,24 +378,36 @@ def ffi_callback(self, call_frame): arg_refs.append(arg) # keep a reference # get device and stream - device = wp.get_cuda_device(get_device_ordinal_from_callframe(call_frame.contents)) - stream = get_stream_from_callframe(call_frame.contents) + if wp.is_cuda_available(): + device = wp.get_cuda_device(get_device_ordinal_from_callframe(call_frame.contents)) + stream = get_stream_from_callframe(call_frame.contents) + else: + device = wp.get_device("cpu") + stream = None # get kernel hooks hooks = self.kernel.module.get_kernel_hooks(self.kernel, device) assert hooks.forward, "Failed to find kernel entry point" # launch the kernel - wp._src.context.runtime.core.wp_cuda_launch_kernel( - device.context, - hooks.forward, - launch_bounds.size, - 0, - 256, - hooks.forward_smem_bytes, - kernel_params, - stream, - ) + if device.is_cuda: + wp._src.context.runtime.core.wp_cuda_launch_kernel( + device.context, + hooks.forward, + launch_bounds.size, + 0, + 256, + hooks.forward_smem_bytes, + kernel_params, + stream, + ) + else: + wp._src.context.runtime.core.wp_cpu_launch_kernel( + device.context, + hooks.forward, + launch_bounds.size, + kernel_params, + ) except Exception as e: print(traceback.format_exc()) @@ -552,6 +565,7 @@ def __init__( ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p) ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value) jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA") + jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="Host") def __call__(self, *args, output_dims=None, vmap_method=None): num_inputs = len(args) @@ -642,8 +656,7 @@ def __call__(self, *args, output_dims=None, vmap_method=None): except Exception: # ignore unsupported devices like TPUs pass - # we only support CUDA devices for now - if dev.is_cuda: + if dev.is_cuda or dev.is_cpu: module.load(dev) # save call data to be retrieved by callback @@ -664,8 +677,8 @@ def ffi_callback(self, call_frame): metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension)) metadata_ext.contents.metadata.contents.api_version.major_version = 0 metadata_ext.contents.metadata.contents.api_version.minor_version = 1 - # Turn on CUDA graphs for this handler. - if self.graph_mode is GraphMode.JAX: + # Turn on CUDA graphs for this handler if CUDA is available. + if self.graph_mode is GraphMode.JAX and wp.is_cuda_available(): metadata_ext.contents.metadata.contents.traits = ( XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE ) @@ -692,8 +705,12 @@ def ffi_callback(self, call_frame): assert num_inputs == self.num_inputs assert num_outputs == self.num_outputs - cuda_stream = get_stream_from_callframe(call_frame.contents) - device_ordinal = get_device_ordinal_from_callframe(call_frame.contents) + if wp.is_cuda_available(): + cuda_stream = get_stream_from_callframe(call_frame.contents) + device_ordinal = get_device_ordinal_from_callframe(call_frame.contents) + else: + cuda_stream = None + device_ordinal = 0 if self.graph_mode == GraphMode.WARP: # check if we already captured an identical call @@ -797,9 +814,12 @@ def ffi_callback(self, call_frame): # early out return - device_ordinal = get_device_ordinal_from_callframe(call_frame.contents) - device = wp.get_cuda_device(device_ordinal) - stream = wp.Stream(device, cuda_stream=cuda_stream) + if wp.is_cuda_available(): + device = wp.get_cuda_device(device_ordinal) + stream = wp.Stream(device, cuda_stream=cuda_stream) + else: + device = wp.get_device("cpu") + stream = None # reconstruct the argument list arg_list = [] @@ -824,8 +844,8 @@ def ffi_callback(self, call_frame): arg_list.append(arr) # call the Python function with reconstructed arguments - with wp.ScopedStream(stream, sync_enter=False): - if stream.is_capturing: + with wp.ScopedStream(stream, sync_enter=False) if stream else wp.ScopedDevice(device): + if stream and stream.is_capturing: # capturing with JAX with wp.ScopedCapture(external=True) as capture: self.func(*arg_list) @@ -833,7 +853,7 @@ def ffi_callback(self, call_frame): # keep a reference to the capture object to prevent required modules getting unloaded call_desc.capture = capture - elif self.graph_mode == GraphMode.WARP: + elif self.graph_mode == GraphMode.WARP and device.is_cuda: # capturing with WARP with wp.ScopedCapture() as capture: self.func(*arg_list) @@ -846,7 +866,7 @@ def ffi_callback(self, call_frame): if self._graph_cache_max is not None and len(self.captures) > self._graph_cache_max: self.captures.popitem(last=False) - elif self.graph_mode == GraphMode.WARP_STAGED_EX: + elif self.graph_mode == GraphMode.WARP_STAGED_EX and device.is_cuda: # capturing with WARP using staging buffers and memcopies done outside of the graph wp_memcpy_batch = wp._src.context.runtime.core.wp_memcpy_batch @@ -889,7 +909,7 @@ def ffi_callback(self, call_frame): # TODO: we should have a way of freeing this call_desc.capture = capture - elif self.graph_mode == GraphMode.WARP_STAGED: + elif self.graph_mode == GraphMode.WARP_STAGED and device.is_cuda: # capturing with WARP using staging buffers and memcopies done inside of the graph wp_cuda_graph_insert_memcpy_batch = ( wp._src.context.runtime.core.wp_cuda_graph_insert_memcpy_batch @@ -967,7 +987,7 @@ def ffi_callback(self, call_frame): call_desc.capture = capture else: - # not capturing + # not capturing or on CPU self.func(*arg_list) except Exception as e: @@ -1537,7 +1557,7 @@ def ffi_callback(call_frame): metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension)) metadata_ext.contents.metadata.contents.api_version.major_version = 0 metadata_ext.contents.metadata.contents.api_version.minor_version = 1 - if graph_compatible: + if graph_compatible and wp.is_cuda_available(): # Turn on CUDA graphs for this handler. metadata_ext.contents.metadata.contents.traits = ( XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE @@ -1576,6 +1596,7 @@ def ffi_callback(call_frame): ffi_ccall_address = ctypes.cast(callback_func, ctypes.c_void_p) ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value) jax.ffi.register_ffi_target(name, ffi_capsule, platform="CUDA") + jax.ffi.register_ffi_target(name, ffi_capsule, platform="Host") ############################################################################### diff --git a/mjx/mujoco/mjx/warp/collision_driver_test.py b/mjx/mujoco/mjx/warp/collision_driver_test.py index 36109eefb1..20246389b0 100644 --- a/mjx/mujoco/mjx/warp/collision_driver_test.py +++ b/mjx/mujoco/mjx/warp/collision_driver_test.py @@ -76,8 +76,6 @@ def test_collision_nested_vmap(self): if not _FORCE_TEST: if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') m = mujoco.MjModel.from_xml_string(self._SPHERE_SPHERE) d = mujoco.MjData(m) diff --git a/mjx/mujoco/mjx/warp/forward_test.py b/mjx/mujoco/mjx/warp/forward_test.py index 3eb936a73f..693b809a6e 100644 --- a/mjx/mujoco/mjx/warp/forward_test.py +++ b/mjx/mujoco/mjx/warp/forward_test.py @@ -67,8 +67,6 @@ def test_jit_caching(self, xml): if not _FORCE_TEST: if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') batch_size = 7 m = test_util.load_test_file(xml) @@ -100,8 +98,6 @@ def test_forward(self, xml: str, batch_size: int): if not _FORCE_TEST: if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') m = test_util.load_test_file(xml) m.opt.iterations = 10 @@ -242,8 +238,6 @@ def test_step(self, xml: str, batch_size: int, graph_mode: str): if not _FORCE_TEST: if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') m = test_util.load_test_file(xml) m.opt.iterations = 10 @@ -287,8 +281,6 @@ def test_step_leading_dim_mismatch(self): if not _FORCE_TEST: if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') xml = 'humanoid/humanoid.xml' batch_size = 7 diff --git a/mjx/mujoco/mjx/warp/smooth_test.py b/mjx/mujoco/mjx/warp/smooth_test.py index 92c028751a..a6e38cbd65 100644 --- a/mjx/mujoco/mjx/warp/smooth_test.py +++ b/mjx/mujoco/mjx/warp/smooth_test.py @@ -63,8 +63,6 @@ def test_kinematics(self): if not _FORCE_TEST: if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') m = tu.load_test_file('pendula.xml') @@ -102,10 +100,9 @@ def test_kinematics(self): def test_kinematics_vmap(self): """Tests kinematics with batched data.""" - if not mjxw.WARP_INSTALLED: - self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') + if not _FORCE_TEST: + if not mjxw.WARP_INSTALLED: + self.skipTest('Warp not installed.') m = tu.load_test_file('pendula.xml') @@ -149,8 +146,6 @@ def test_kinematics_nested_vmap(self): if not _FORCE_TEST: if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') m = tu.load_test_file('pendula.xml') @@ -196,8 +191,6 @@ def test_kinematics_model_vmap(self): if not _FORCE_TEST: if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') m = tu.load_test_file('pendula.xml') diff --git a/mjx/mujoco/mjx/warp/test_util.py b/mjx/mujoco/mjx/warp/test_util.py index 1752f3c943..6f7e368e10 100644 --- a/mjx/mujoco/mjx/warp/test_util.py +++ b/mjx/mujoco/mjx/warp/test_util.py @@ -153,7 +153,7 @@ def _mjx_efc(dx, worldid: int): efc_pos = select(dx._impl.efc__pos)[:nefc] efc_type = select(dx._impl.efc__type)[:nefc] efc_d = select(dx._impl.efc__D)[:nefc] - keys_sorted = np.lexsort((-efc_pos, efc_type, efc_d)) + keys_sorted = np.lexsort((-np.round(efc_pos, 4), efc_type, np.round(efc_d, 4))) keys = keys[keys_sorted] nefc = len(keys) @@ -180,7 +180,7 @@ def _mj_efc(d): else: efc_j = d.efc_J.reshape((-1, d.qvel.shape[0])) - keys = np.lexsort((-d.efc_pos, d.efc_type, d.efc_D)) + keys = np.lexsort((-np.round(d.efc_pos, 4), d.efc_type, np.round(d.efc_D, 4))) type_ = d.efc_type[keys] pos = d.efc_pos[keys] efc_j = efc_j[keys] diff --git a/mjx/pyproject.toml b/mjx/pyproject.toml index ff7133cdd6..fb6484cefa 100644 --- a/mjx/pyproject.toml +++ b/mjx/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ "etils[epath]", "jax", "jaxlib", - "mujoco>=3.7.0.dev0", + "mujoco>=3.6.0", "scipy", "trimesh", ]