Add MLX backend support for Nutpie compilation #254
cetagostini wants to merge 8 commits into pymc-devs:main from
Conversation
Introduces MLX as a backend option in compile_pymc_model, allowing gradient computation via MLX or PyTensor. Updates dependency groups to include MLX, extends internal functions to handle MLX mode, and adds corresponding tests for deterministic sampling with MLX.
Thanks, that looks great!
Bump MLX version requirement to >=0.29.0 in pyproject.toml for all relevant extras. In compile_pymc.py, JIT compile the logp function using mx.compile for improved performance, aligning with JAX backend behavior.
Good point, that simple addition brings between 5% and 20% more performance! @aseyboldt
@aseyboldt I solved the test issue; it occurred only on Macs with Intel chips.
@aseyboldt can you give me a hand? The failing test is strange. Everything passes locally for me.
aseyboldt
left a comment
That failure is annoying. For some reason the results seem to differ between different machines? I think we really should figure out what's going on here. Maybe it helps if we print the first couple of values in warmup_posterior to see if the initial values are already different, or if small differences accumulate?
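A minimal numpy sketch of that debugging idea (the helper name, array shapes, and data are made up; the arrays are stand-ins for warmup draws from two machines):

```python
# Hypothetical debugging sketch: compare warmup draws from two machines to see
# whether they diverge at the very first draw or drift apart gradually.
import numpy as np

def first_divergence(warmup_a, warmup_b, atol=1e-12):
    """Return the index of the first warmup draw where the two runs differ."""
    diffs = np.abs(np.asarray(warmup_a) - np.asarray(warmup_b))
    mismatched = np.where(diffs.max(axis=-1) > atol)[0]
    return int(mismatched[0]) if mismatched.size else None

# Stand-in data: identical early draws, tiny differences appearing later.
rng = np.random.default_rng(0)
a = rng.normal(size=(100, 3))
b = a.copy()
b[50:] += 1e-9 * rng.normal(size=(50, 3))  # divergence starts at draw 50

print(first_divergence(a, b))  # draws 0-49 match, so the first mismatch is at 50
```

If the index is 0, the initial values already differ between machines; a later index suggests small floating-point differences accumulating during warmup.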
    updated.update(**updates)

    # Convert to MLX arrays if using MLX backend (indicated by force_single_core)
    if self._force_single_core:
We should not use that argument to detect MLX.
How about we add an attribute _convert_data_item or so to the dataclass, containing a function that transforms data arrays? We could then also use that for JAX.
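A sketch of that suggestion; the dataclass shape and names besides `_convert_data_item` are hypothetical, modeled only on the comment above. The idea is to carry a per-backend conversion function instead of inferring the backend from `_force_single_core`:

```python
# Hypothetical sketch: each backend installs its own data-conversion function
# on the dataclass, so no flag needs to be reused to detect MLX.
from dataclasses import dataclass
from typing import Callable, Optional

import numpy as np

@dataclass
class CompiledData:
    values: dict
    # Optional per-backend converter, e.g. mx.array for MLX or jnp.asarray
    # for JAX; None means data arrays are stored as given.
    _convert_data_item: Optional[Callable] = None

    def update(self, **updates):
        if self._convert_data_item is not None:
            updates = {k: self._convert_data_item(v) for k, v in updates.items()}
        self.values.update(updates)

# Using np.asarray as the converter for demonstration purposes.
data = CompiledData(values={}, _convert_data_item=np.asarray)
data.update(x=[1.0, 2.0])
print(type(data.values["x"]))  # numpy.ndarray
```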