From 721a12f2cbde36ff713ae985dd33d9d6f6adbe2f Mon Sep 17 00:00:00 2001 From: Kristian Hartikainen Date: Sun, 16 Nov 2025 10:06:30 -0500 Subject: [PATCH 1/5] Enable `mjw` on `MacOS` This is not ready yet but just demonstrates that it should be possible to enable warp backend on MacOS. ```console $ uv pip install -U --extra-index-url="https://py.mujoco.org" "mujoco>=3.7.0.dev0,<3.8.0" && \ uv pip install -U warp-lang && \ uv pip install -U -e ./mjx /Users/google-deepmind/mujoco/main/mjx/mujoco/mjx/_src/io_test.py:472: DeprecationWarning: nconmax will be deprecated in mujoco-mjx>=3.5. Use naconmax instead. d = mjx.make_data(m, impl='warp', nconmax=9, njmax=23) Warp 1.12.0 initialized: CUDA not enabled in this build Devices: "cpu" : "arm" Kernel cache: /private/var/folders/ml/rrlg98ln26l7xvxgq_yqfn4c0000gn/T/tmpoklv_7xg Warp DeprecationWarning: The symbol `warp.types.warp_type_to_np_dtype` will soon be removed from the public API. It can still be accessed from `warp._src.types.warp_type_to_np_dtype` but might be changed or removed without notice. 
./opt/homebrew/Cellar/python@3.12/3.12.10/Frameworks/Python.framework/Versions/3.12/lib/python3.12/tempfile.py:940: ResourceWarning: Implicitly cleaning up _warnings.warn(warn_message, ResourceWarning) ---------------------------------------------------------------------- Ran 1 test in 1.105s OK ``` --- mjx/mujoco/mjx/_src/io.py | 28 +++++++++---------- mjx/mujoco/mjx/_src/io_test.py | 29 ++++---------------- mjx/mujoco/mjx/warp/collision_driver_test.py | 2 -- mjx/mujoco/mjx/warp/forward_test.py | 8 ------ mjx/mujoco/mjx/warp/smooth_test.py | 13 ++------- 5 files changed, 22 insertions(+), 58 deletions(-) diff --git a/mjx/mujoco/mjx/_src/io.py b/mjx/mujoco/mjx/_src/io.py index d9ef50b359..2756bad411 100644 --- a/mjx/mujoco/mjx/_src/io.py +++ b/mjx/mujoco/mjx/_src/io.py @@ -90,23 +90,20 @@ def _resolve_device( logging.debug('Picking default device: %s.', device_0) return device_0 + if impl == types.Impl.WARP: + if has_cuda_gpu_device(): + cuda_gpus = jax.devices('cuda') + logging.debug('Picking default device: %s', cuda_gpus[0]) + device_0 = cuda_gpus[0] + else: + device_0 = jax.devices('cpu')[0] + return device_0 + if impl == types.Impl.CPP: cpu_0 = jax.devices('cpu')[0] logging.debug('Picking default device: %s', cpu_0) return cpu_0 - if impl == types.Impl.WARP: - # WARP implementation requires a CUDA GPU. - cuda_gpus = [d for d in jax.devices('cuda')] - if not cuda_gpus: - raise AssertionError( - 'No CUDA GPU devices found in' - f' jax.devices("cuda")={jax.devices("cuda")}.' 
- ) - - logging.debug('Picking default device: %s', cuda_gpus[0]) - return cuda_gpus[0] - raise ValueError(f'Unsupported implementation: {impl}') @@ -121,9 +118,12 @@ def _check_impl_device_compatibility( impl = types.Impl(impl) if impl == types.Impl.WARP: - if not _is_cuda_gpu_device(device): + is_cuda_device = _is_cuda_gpu_device(device) + is_cpu_device = device.platform == 'cpu' + if not (is_cuda_device or is_cpu_device): raise AssertionError( - f'Warp implementation requires a CUDA GPU device, got {device}.' + 'Warp implementation requires a CUDA GPU or CPU device, got ' + f'{device}.' ) _check_warp_installed() diff --git a/mjx/mujoco/mjx/_src/io_test.py b/mjx/mujoco/mjx/_src/io_test.py index dfcb541d70..1adcc98960 100644 --- a/mjx/mujoco/mjx/_src/io_test.py +++ b/mjx/mujoco/mjx/_src/io_test.py @@ -142,8 +142,6 @@ def setUp(self): def test_put_model(self, xml, impl): if impl == 'warp' and not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if impl == 'warp' and not mjx_io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') m = mujoco.MjModel.from_xml_string(xml) mx = mjx.put_model(m, impl=impl) @@ -311,8 +309,6 @@ def test_put_model_warp_has_expected_shapes(self): """Tests that put_model produces expected shapes for MuJoCo Warp.""" if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not mjx_io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') m = mujoco.MjModel.from_xml_string(_MULTIPLE_CONSTRAINTS) mx = mjx.put_model(m, impl='warp') @@ -472,8 +468,6 @@ def test_make_data(self, impl: str): def test_make_data_warp(self): if not mjxw.WARP_INSTALLED: self.skipTest('Warp is not installed.') - if not mjx_io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device.') m = mujoco.MjModel.from_xml_string(_MULTIPLE_CONVEX_OBJECTS) d = mjx.make_data(m, impl='warp', nconmax=9, njmax=23) self.assertEqual(d._impl.contact__dist.shape[0], 9) @@ -847,8 +841,6 @@ def 
test_make_data_warp_has_expected_shapes(self): """Tests that make_data produces expected shapes for MuJoCo Warp.""" if not mjxw.WARP_INSTALLED: self.skipTest('Warp is not installed.') - if not mjx_io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device.') m = mujoco.MjModel.from_xml_string(_MULTIPLE_CONSTRAINTS) dx = mjx.make_data(m, impl='warp') @@ -871,8 +863,6 @@ def test_data_slice(self, impl): """Tests that slice on Data works as expected.""" if impl == 'warp' and not mjxw.WARP_INSTALLED: self.skipTest('Warp is not installed.') - if impl == 'warp' and not mjx_io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device.') m = mujoco.MjModel.from_xml_string(_MULTIPLE_CONSTRAINTS) dx = jax.vmap(lambda x: mjx.make_data(m, impl=impl))(jp.arange(10)) @@ -934,8 +924,8 @@ def put_data(dummy_arg_for_batching): ('gpu-nvidia', 'jax', ('gpu', Impl.JAX)), ('tpu', 'jax', ('tpu', Impl.JAX)), # WARP backend specified. - ('cpu', 'warp', ('cpu', 'error')), - ('gpu-notnvidia', 'warp', ('cpu', 'error')), + ('cpu', 'warp', ('cpu', Impl.WARP)), + ('gpu-notnvidia', 'warp', ('gpu', 'error')), ('gpu-nvidia', 'warp', ('gpu', Impl.WARP)), ('tpu', 'warp', ('tpu', 'error')), # CPP backend specified. @@ -962,10 +952,10 @@ def put_data(dummy_arg_for_batching): ('gpu-nvidia', 'jax', ('gpu', Impl.JAX)), ('tpu', 'jax', ('tpu', Impl.JAX)), # WARP backend impl specified. - ('cpu', 'warp', ('cpu', 'error')), - ('gpu-notnvidia', 'warp', ('cpu', 'error')), + ('cpu', 'warp', ('cpu', Impl.WARP)), + ('gpu-notnvidia', 'warp', ('cpu', Impl.WARP)), ('gpu-nvidia', 'warp', ('gpu', Impl.WARP)), - ('tpu', 'warp', ('tpu', 'error')), + ('tpu', 'warp', ('cpu', Impl.WARP)), # CPP backend impl specified, CPU should always be available. 
('cpu', 'cpp', ('cpu', Impl.CPP)), ('gpu-notnvidia', 'cpp', ('cpu', Impl.CPP)), @@ -1140,15 +1130,6 @@ def backends_side_effect(): self.mock_jax_backends.side_effect = backends_side_effect expected_device, expected_impl = expected - if ( - expected_impl == 'error' - and default_device_str != 'gpu-nvidia' - and impl_str == 'warp' - ): - with self.assertRaisesRegex(RuntimeError, 'cuda backend not supported'): - mjx_io._resolve_impl_and_device(impl=impl_str, device=None) - return - if expected_impl == 'error': with self.assertRaises(AssertionError): mjx_io._resolve_impl_and_device(impl=impl_str, device=None) diff --git a/mjx/mujoco/mjx/warp/collision_driver_test.py b/mjx/mujoco/mjx/warp/collision_driver_test.py index 36109eefb1..20246389b0 100644 --- a/mjx/mujoco/mjx/warp/collision_driver_test.py +++ b/mjx/mujoco/mjx/warp/collision_driver_test.py @@ -76,8 +76,6 @@ def test_collision_nested_vmap(self): if not _FORCE_TEST: if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') m = mujoco.MjModel.from_xml_string(self._SPHERE_SPHERE) d = mujoco.MjData(m) diff --git a/mjx/mujoco/mjx/warp/forward_test.py b/mjx/mujoco/mjx/warp/forward_test.py index 15e874e3eb..316fcf5f32 100644 --- a/mjx/mujoco/mjx/warp/forward_test.py +++ b/mjx/mujoco/mjx/warp/forward_test.py @@ -67,8 +67,6 @@ def test_jit_caching(self, xml): if not _FORCE_TEST: if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') batch_size = 7 m = test_util.load_test_file(xml) @@ -100,8 +98,6 @@ def test_forward(self, xml: str, batch_size: int): if not _FORCE_TEST: if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') m = test_util.load_test_file(xml) m.opt.iterations = 10 @@ -250,8 +246,6 @@ def test_step(self, xml: str, batch_size: int, 
graph_mode: str): if not _FORCE_TEST: if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') m = test_util.load_test_file(xml) m.opt.iterations = 10 @@ -295,8 +289,6 @@ def test_step_leading_dim_mismatch(self): if not _FORCE_TEST: if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') xml = 'humanoid/humanoid.xml' batch_size = 7 diff --git a/mjx/mujoco/mjx/warp/smooth_test.py b/mjx/mujoco/mjx/warp/smooth_test.py index 92c028751a..a6e38cbd65 100644 --- a/mjx/mujoco/mjx/warp/smooth_test.py +++ b/mjx/mujoco/mjx/warp/smooth_test.py @@ -63,8 +63,6 @@ def test_kinematics(self): if not _FORCE_TEST: if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') m = tu.load_test_file('pendula.xml') @@ -102,10 +100,9 @@ def test_kinematics(self): def test_kinematics_vmap(self): """Tests kinematics with batched data.""" - if not mjxw.WARP_INSTALLED: - self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') + if not _FORCE_TEST: + if not mjxw.WARP_INSTALLED: + self.skipTest('Warp not installed.') m = tu.load_test_file('pendula.xml') @@ -149,8 +146,6 @@ def test_kinematics_nested_vmap(self): if not _FORCE_TEST: if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') m = tu.load_test_file('pendula.xml') @@ -196,8 +191,6 @@ def test_kinematics_model_vmap(self): if not _FORCE_TEST: if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device available.') m = tu.load_test_file('pendula.xml') From 462333e456c297d7529f289d5b1324e9941a435e Mon Sep 17 00:00:00 2001 From: Kristian 
Hartikainen Date: Sat, 31 Jan 2026 19:15:10 -0500 Subject: [PATCH 2/5] Enable jax-warp ffi for cpu --- .../warp/_src/jax_experimental/ffi.py | 85 ++++++++++++------- 1 file changed, 55 insertions(+), 30 deletions(-) diff --git a/mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/ffi.py b/mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/ffi.py index f5c925dd44..0982c4064c 100644 --- a/mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/ffi.py +++ b/mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/ffi.py @@ -199,6 +199,7 @@ def __init__( ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p) ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value) jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA") + jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="Host") def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None): num_inputs = len(args) @@ -298,6 +299,7 @@ def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None): # ignore unsupported devices like TPUs pass # we only support CUDA devices for now + # TODO(hartikainen): Should this be `dev.is_cuda or dev.is_cpu`? if dev.is_cuda: self.kernel.module.load(dev) @@ -320,10 +322,11 @@ def ffi_callback(self, call_frame): metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension)) metadata_ext.contents.metadata.contents.api_version.major_version = 0 metadata_ext.contents.metadata.contents.api_version.minor_version = 1 - # Turn on CUDA graphs for this handler. - metadata_ext.contents.metadata.contents.traits = ( - XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE - ) + # Turn on CUDA graphs for this handler if CUDA is available. 
+ if wp.is_cuda_available(): + metadata_ext.contents.metadata.contents.traits = ( + XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE + ) return None # Lock is required to prevent race conditions when callback is invoked @@ -377,24 +380,36 @@ def ffi_callback(self, call_frame): arg_refs.append(arg) # keep a reference # get device and stream - device = wp.get_cuda_device(get_device_ordinal_from_callframe(call_frame.contents)) - stream = get_stream_from_callframe(call_frame.contents) + if wp.is_cuda_available(): + device = wp.get_cuda_device(get_device_ordinal_from_callframe(call_frame.contents)) + stream = get_stream_from_callframe(call_frame.contents) + else: + device = wp.get_device("cpu") + stream = None # get kernel hooks hooks = self.kernel.module.get_kernel_hooks(self.kernel, device) assert hooks.forward, "Failed to find kernel entry point" # launch the kernel - wp._src.context.runtime.core.wp_cuda_launch_kernel( - device.context, - hooks.forward, - launch_bounds.size, - 0, - 256, - hooks.forward_smem_bytes, - kernel_params, - stream, - ) + if device.is_cuda: + wp._src.context.runtime.core.wp_cuda_launch_kernel( + device.context, + hooks.forward, + launch_bounds.size, + 0, + 256, + hooks.forward_smem_bytes, + kernel_params, + stream, + ) + else: + wp._src.context.runtime.core.wp_cpu_launch_kernel( + device.context, + hooks.forward, + launch_bounds.size, + kernel_params, + ) except Exception as e: print(traceback.format_exc()) @@ -552,6 +567,7 @@ def __init__( ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p) ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value) jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA") + jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="Host") def __call__(self, *args, output_dims=None, vmap_method=None): num_inputs = len(args) @@ -643,6 +659,7 @@ def __call__(self, *args, output_dims=None, vmap_method=None): # ignore unsupported devices like TPUs pass # we only support CUDA 
devices for now + # TODO(hartikainen): Should this be `dev.is_cuda or dev.is_cpu`? if dev.is_cuda: module.load(dev) @@ -664,8 +681,8 @@ def ffi_callback(self, call_frame): metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension)) metadata_ext.contents.metadata.contents.api_version.major_version = 0 metadata_ext.contents.metadata.contents.api_version.minor_version = 1 - # Turn on CUDA graphs for this handler. - if self.graph_mode is GraphMode.JAX: + # Turn on CUDA graphs for this handler if CUDA is available. + if self.graph_mode is GraphMode.JAX and wp.is_cuda_available(): metadata_ext.contents.metadata.contents.traits = ( XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE ) @@ -692,8 +709,12 @@ def ffi_callback(self, call_frame): assert num_inputs == self.num_inputs assert num_outputs == self.num_outputs - cuda_stream = get_stream_from_callframe(call_frame.contents) - device_ordinal = get_device_ordinal_from_callframe(call_frame.contents) + if wp.is_cuda_available(): + cuda_stream = get_stream_from_callframe(call_frame.contents) + device_ordinal = get_device_ordinal_from_callframe(call_frame.contents) + else: + cuda_stream = None + device_ordinal = 0 if self.graph_mode == GraphMode.WARP: # check if we already captured an identical call @@ -797,9 +818,12 @@ def ffi_callback(self, call_frame): # early out return - device_ordinal = get_device_ordinal_from_callframe(call_frame.contents) - device = wp.get_cuda_device(device_ordinal) - stream = wp.Stream(device, cuda_stream=cuda_stream) + if wp.is_cuda_available(): + device = wp.get_cuda_device(device_ordinal) + stream = wp.Stream(device, cuda_stream=cuda_stream) + else: + device = wp.get_device("cpu") + stream = None # reconstruct the argument list arg_list = [] @@ -824,8 +848,8 @@ def ffi_callback(self, call_frame): arg_list.append(arr) # call the Python function with reconstructed arguments - with wp.ScopedStream(stream, sync_enter=False): - if stream.is_capturing: + with 
wp.ScopedStream(stream, sync_enter=False) if stream else wp.ScopedDevice(device): + if stream and stream.is_capturing: # capturing with JAX with wp.ScopedCapture(external=True) as capture: self.func(*arg_list) @@ -833,7 +857,7 @@ def ffi_callback(self, call_frame): # keep a reference to the capture object to prevent required modules getting unloaded call_desc.capture = capture - elif self.graph_mode == GraphMode.WARP: + elif self.graph_mode == GraphMode.WARP and device.is_cuda: # capturing with WARP with wp.ScopedCapture() as capture: self.func(*arg_list) @@ -846,7 +870,7 @@ def ffi_callback(self, call_frame): if self._graph_cache_max is not None and len(self.captures) > self._graph_cache_max: self.captures.popitem(last=False) - elif self.graph_mode == GraphMode.WARP_STAGED_EX: + elif self.graph_mode == GraphMode.WARP_STAGED_EX and device.is_cuda: # capturing with WARP using staging buffers and memcopies done outside of the graph wp_memcpy_batch = wp._src.context.runtime.core.wp_memcpy_batch @@ -889,7 +913,7 @@ def ffi_callback(self, call_frame): # TODO: we should have a way of freeing this call_desc.capture = capture - elif self.graph_mode == GraphMode.WARP_STAGED: + elif self.graph_mode == GraphMode.WARP_STAGED and device.is_cuda: # capturing with WARP using staging buffers and memcopies done inside of the graph wp_cuda_graph_insert_memcpy_batch = ( wp._src.context.runtime.core.wp_cuda_graph_insert_memcpy_batch @@ -967,7 +991,7 @@ def ffi_callback(self, call_frame): call_desc.capture = capture else: - # not capturing + # not capturing or on CPU self.func(*arg_list) except Exception as e: @@ -1537,7 +1561,7 @@ def ffi_callback(call_frame): metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension)) metadata_ext.contents.metadata.contents.api_version.major_version = 0 metadata_ext.contents.metadata.contents.api_version.minor_version = 1 - if graph_compatible: + if graph_compatible and wp.is_cuda_available(): # Turn on CUDA graphs for this 
handler. metadata_ext.contents.metadata.contents.traits = ( XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE @@ -1576,6 +1600,7 @@ def ffi_callback(call_frame): ffi_ccall_address = ctypes.cast(callback_func, ctypes.c_void_p) ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value) jax.ffi.register_ffi_target(name, ffi_capsule, platform="CUDA") + jax.ffi.register_ffi_target(name, ffi_capsule, platform="Host") ############################################################################### From 792da0b79946f6d3d91b954ffe48ed259bf2d655 Mon Sep 17 00:00:00 2001 From: Kristian Hartikainen Date: Thu, 19 Mar 2026 08:51:17 -0400 Subject: [PATCH 3/5] Fix jax-warp ffi --- .../warp/_src/jax_experimental/ffi.py | 74 +++++++++++-------- 1 file changed, 43 insertions(+), 31 deletions(-) diff --git a/mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/ffi.py b/mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/ffi.py index 0982c4064c..c71f9940fe 100644 --- a/mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/ffi.py +++ b/mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/ffi.py @@ -195,11 +195,15 @@ def __init__( # register the callback FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame)) - self.callback_func = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame)) - ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p) - ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value) - jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA") - jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="Host") + self.callback_func_cuda = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="CUDA")) + ffi_ccall_address_cuda = ctypes.cast(self.callback_func_cuda, ctypes.c_void_p) + ffi_capsule_cuda = jax.ffi.pycapsule(ffi_ccall_address_cuda.value) + jax.ffi.register_ffi_target(self.name, ffi_capsule_cuda, platform="CUDA") + + self.callback_func_host = FFI_CCALLFUNC(lambda call_frame: 
self.ffi_callback(call_frame, platform="Host")) + ffi_ccall_address_host = ctypes.cast(self.callback_func_host, ctypes.c_void_p) + ffi_capsule_host = jax.ffi.pycapsule(ffi_ccall_address_host.value) + jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host") def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None): num_inputs = len(args) @@ -299,8 +303,7 @@ def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None): # ignore unsupported devices like TPUs pass # we only support CUDA devices for now - # TODO(hartikainen): Should this be `dev.is_cuda or dev.is_cpu`? - if dev.is_cuda: + if dev.is_cuda or dev.is_cpu: self.kernel.module.load(dev) # save launch data to be retrieved by callback @@ -310,7 +313,7 @@ def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None): return call(*args, launch_id=launch_id) - def ffi_callback(self, call_frame): + def ffi_callback(self, call_frame, platform="CUDA"): try: # On the first call, XLA runtime will query the API version and traits # metadata using the |extension| field. Let us respond to that query @@ -322,8 +325,8 @@ def ffi_callback(self, call_frame): metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension)) metadata_ext.contents.metadata.contents.api_version.major_version = 0 metadata_ext.contents.metadata.contents.api_version.minor_version = 1 - # Turn on CUDA graphs for this handler if CUDA is available. - if wp.is_cuda_available(): + # Turn on CUDA graphs for this handler if on CUDA platform. 
+ if platform == "CUDA": metadata_ext.contents.metadata.contents.traits = ( XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE ) @@ -380,7 +383,7 @@ def ffi_callback(self, call_frame): arg_refs.append(arg) # keep a reference # get device and stream - if wp.is_cuda_available(): + if platform == "CUDA": device = wp.get_cuda_device(get_device_ordinal_from_callframe(call_frame.contents)) stream = get_stream_from_callframe(call_frame.contents) else: @@ -563,11 +566,15 @@ def __init__( # register the callback FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame)) - self.callback_func = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame)) - ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p) - ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value) - jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA") - jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="Host") + self.callback_func_cuda = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="CUDA")) + ffi_ccall_address_cuda = ctypes.cast(self.callback_func_cuda, ctypes.c_void_p) + ffi_capsule_cuda = jax.ffi.pycapsule(ffi_ccall_address_cuda.value) + jax.ffi.register_ffi_target(self.name, ffi_capsule_cuda, platform="CUDA") + + self.callback_func_host = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame, platform="Host")) + ffi_ccall_address_host = ctypes.cast(self.callback_func_host, ctypes.c_void_p) + ffi_capsule_host = jax.ffi.pycapsule(ffi_ccall_address_host.value) + jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host") def __call__(self, *args, output_dims=None, vmap_method=None): num_inputs = len(args) @@ -659,8 +666,7 @@ def __call__(self, *args, output_dims=None, vmap_method=None): # ignore unsupported devices like TPUs pass # we only support CUDA devices for now - # TODO(hartikainen): Should this be `dev.is_cuda or dev.is_cpu`? 
- if dev.is_cuda: + if dev.is_cuda or dev.is_cpu: module.load(dev) # save call data to be retrieved by callback @@ -669,7 +675,7 @@ def __call__(self, *args, output_dims=None, vmap_method=None): self.call_id += 1 return call(*args, call_id=call_id) - def ffi_callback(self, call_frame): + def ffi_callback(self, call_frame, platform="CUDA"): try: # On the first call, XLA runtime will query the API version and traits # metadata using the |extension| field. Let us respond to that query @@ -681,8 +687,8 @@ def ffi_callback(self, call_frame): metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension)) metadata_ext.contents.metadata.contents.api_version.major_version = 0 metadata_ext.contents.metadata.contents.api_version.minor_version = 1 - # Turn on CUDA graphs for this handler if CUDA is available. - if self.graph_mode is GraphMode.JAX and wp.is_cuda_available(): + # Turn on CUDA graphs for this handler if on CUDA platform. + if self.graph_mode is GraphMode.JAX and platform == "CUDA": metadata_ext.contents.metadata.contents.traits = ( XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE ) @@ -709,7 +715,7 @@ def ffi_callback(self, call_frame): assert num_inputs == self.num_inputs assert num_outputs == self.num_outputs - if wp.is_cuda_available(): + if platform == "CUDA": cuda_stream = get_stream_from_callframe(call_frame.contents) device_ordinal = get_device_ordinal_from_callframe(call_frame.contents) else: @@ -818,7 +824,7 @@ def ffi_callback(self, call_frame): # early out return - if wp.is_cuda_available(): + if platform == "CUDA": device = wp.get_cuda_device(device_ordinal) stream = wp.Stream(device, cuda_stream=cuda_stream) else: @@ -1549,7 +1555,7 @@ def register_ffi_callback(name: str, func: Callable, graph_compatible: bool = Tr # TODO check that the name is not already registered - def ffi_callback(call_frame): + def ffi_callback(call_frame, platform="CUDA"): try: extension = call_frame.contents.extension_start # On the first call, XLA 
runtime will query the API version and traits @@ -1561,7 +1567,7 @@ def ffi_callback(call_frame): metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension)) metadata_ext.contents.metadata.contents.api_version.major_version = 0 metadata_ext.contents.metadata.contents.api_version.minor_version = 1 - if graph_compatible and wp.is_cuda_available(): + if graph_compatible and platform == "CUDA": # Turn on CUDA graphs for this handler. metadata_ext.contents.metadata.contents.traits = ( XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE @@ -1594,13 +1600,19 @@ def ffi_callback(call_frame): return None FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame)) - callback_func = FFI_CCALLFUNC(ffi_callback) + callback_func_cuda = FFI_CCALLFUNC(lambda call_frame: ffi_callback(call_frame, platform="CUDA")) + callback_func_host = FFI_CCALLFUNC(lambda call_frame: ffi_callback(call_frame, platform="Host")) with _FFI_REGISTRY_LOCK: - _FFI_CALLBACK_REGISTRY[name] = callback_func - ffi_ccall_address = ctypes.cast(callback_func, ctypes.c_void_p) - ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value) - jax.ffi.register_ffi_target(name, ffi_capsule, platform="CUDA") - jax.ffi.register_ffi_target(name, ffi_capsule, platform="Host") + _FFI_CALLBACK_REGISTRY[f"{name}_cuda"] = callback_func_cuda + _FFI_CALLBACK_REGISTRY[f"{name}_host"] = callback_func_host + + ffi_ccall_address_cuda = ctypes.cast(callback_func_cuda, ctypes.c_void_p) + ffi_capsule_cuda = jax.ffi.pycapsule(ffi_ccall_address_cuda.value) + jax.ffi.register_ffi_target(name, ffi_capsule_cuda, platform="CUDA") + + ffi_ccall_address_host = ctypes.cast(callback_func_host, ctypes.c_void_p) + ffi_capsule_host = jax.ffi.pycapsule(ffi_ccall_address_host.value) + jax.ffi.register_ffi_target(name, ffi_capsule_host, platform="Host") ############################################################################### From 2ad32d7cfdb140b0f1071e0885d7f98ab0bf68a2 Mon Sep 17 00:00:00 2001 
From: Kristian Hartikainen Date: Thu, 19 Mar 2026 10:20:09 -0400 Subject: [PATCH 4/5] Fix unstable test constraint sorting on CPU warp When using warp on CPU, there are small numerical differences that cause the `np.lexsort` to return inconsistently-ordered `efc_pos` and `efc_D`. This changes the sorting to round the values, fixing the flakiness. --- mjx/mujoco/mjx/warp/test_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mjx/mujoco/mjx/warp/test_util.py b/mjx/mujoco/mjx/warp/test_util.py index 1752f3c943..3bac36f5ef 100644 --- a/mjx/mujoco/mjx/warp/test_util.py +++ b/mjx/mujoco/mjx/warp/test_util.py @@ -153,7 +153,7 @@ def _mjx_efc(dx, worldid: int): efc_pos = select(dx._impl.efc__pos)[:nefc] efc_type = select(dx._impl.efc__type)[:nefc] efc_d = select(dx._impl.efc__D)[:nefc] - keys_sorted = np.lexsort((-efc_pos, efc_type, efc_d)) + keys_sorted = np.lexsort((np.round(-efc_pos, 12), efc_type, np.round(efc_d, 12))) keys = keys[keys_sorted] nefc = len(keys) @@ -180,7 +180,7 @@ def _mj_efc(d): else: efc_j = d.efc_J.reshape((-1, d.qvel.shape[0])) - keys = np.lexsort((-d.efc_pos, d.efc_type, d.efc_D)) + keys = np.lexsort((np.round(-d.efc_pos, 12), d.efc_type, np.round(d.efc_D, 12))) type_ = d.efc_type[keys] pos = d.efc_pos[keys] efc_j = efc_j[keys] From f88be6e77472f5da07b6806c31601c881fe033b5 Mon Sep 17 00:00:00 2001 From: Kristian Hartikainen Date: Thu, 19 Mar 2026 14:02:13 -0400 Subject: [PATCH 5/5] Enable io tests on non-cuda device --- mjx/mujoco/mjx/_src/io_test.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/mjx/mujoco/mjx/_src/io_test.py b/mjx/mujoco/mjx/_src/io_test.py index 1adcc98960..05f35d9e55 100644 --- a/mjx/mujoco/mjx/_src/io_test.py +++ b/mjx/mujoco/mjx/_src/io_test.py @@ -331,8 +331,6 @@ def test_put_model_warp_graph_mode(self, mode: str | None): """Tests that put_model accepts graph_mode parameter.""" if not mjxw.WARP_INSTALLED: self.skipTest('Warp not installed.') - if not mjx_io.has_cuda_gpu_device(): -
self.skipTest('No CUDA GPU device available.') if mode is None: graph_mode = None @@ -479,8 +477,6 @@ def test_put_data(self, impl: str): if impl == 'warp': if not mjxw.WARP_INSTALLED: self.skipTest('Warp is not installed.') - if not mjx_io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device.') m = mujoco.MjModel.from_xml_string(_MULTIPLE_CONSTRAINTS) d = mujoco.MjData(m) @@ -598,8 +594,6 @@ def test_put_data_warp_ndim(self): """Tests that put_data produces expected dimensions for Warp fields.""" if not mjxw.WARP_INSTALLED: self.skipTest('Warp is not installed.') - if not mjx_io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device.') m = mujoco.MjModel.from_xml_string(_MULTIPLE_CONSTRAINTS) d = mujoco.MjData(m) @@ -764,8 +758,6 @@ def test_get_data_into_warp(self): # and remove this test. if not mjxw.WARP_INSTALLED: self.skipTest('Warp is not installed.') - if not mjx_io.has_cuda_gpu_device(): - self.skipTest('No CUDA GPU device.') m = mujoco.MjModel.from_xml_string('') d = mujoco.MjData(m)