Enable partial support for MacOS/MPS training #626

Open
dsocolobsky wants to merge 2 commits into main from dy/enable-macos-training
@dsocolobsky
Contributor

Note: this is not intended for production training, but it is helpful during development for testing changes locally without uploading the code to a Linux server, as long as you don't need heavy compute power or parallelism.

Before this PR, the code compiled on Mac hosts, but a client crashed as soon as it joined a run because MPS does not support some of the PyTorch operations involved.

This PR introduces compatibility changes. Nothing should change for Linux/CUDA clients, but macOS clients avoid the crashes by falling back to CPU computation in certain situations. I've been able to do a local test run with 1 client (my MacBook Air can't handle more), and also to join a Mac client to a remote run in a cluster with other Linux clients.

List of changes

  • MPS does not support BFloat16, so on Mac hosts we load and train the model in Float32, but convert compressed gradients to BFloat16 before transmission to maintain compatibility with CUDA clients.
  • When compressing/decompressing gradient indices, on Mac hosts we perform the UInt16/UInt32 casts on CPU, as MPS does not support unsigned integer types beyond UInt8. The result (UInt8) is moved back to MPS, maintaining compatibility with CUDA clients.
  • Replaced certain Tensor::cat calls with a wrapper that, for MPS clients, performs the concatenation on CPU, working around an MPS crash.
  • For the attention computation, on MPS hosts we use Eager attention as MPS does not support the SDPA backward pass. This should produce the same values, just less efficiently.
  • In distro.rs the DCT basis matrix computation uses FFT/complex operations which MPS doesn't support, so we compute it on CPU.
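
The Float32 to BFloat16 conversion mentioned in the first item can be sketched in plain Rust. This is only an illustration and not the code in this PR: the helper names (f32_to_bf16_bits, bf16_bits_to_f32, encode_bf16) are hypothetical, and the actual change presumably relies on a tensor dtype cast rather than a hand-written conversion. It shows why the wire format stays compatible: BFloat16 is the upper 16 bits of a Float32, so a Float32 host can produce the same 16-bit payload a BFloat16 host would send.

```rust
/// Convert an f32 to its bf16 bit pattern using round-to-nearest-even.
/// Hypothetical helper for illustration; not taken from this PR.
fn f32_to_bf16_bits(x: f32) -> u16 {
    let bits = x.to_bits();
    if x.is_nan() {
        // Truncation could zero the mantissa and turn a NaN into infinity,
        // so force a mantissa bit to keep the value a (quiet) NaN.
        return ((bits >> 16) as u16) | 0x0040;
    }
    // Add 0x7FFF plus the lowest kept bit: ties round to the even neighbour.
    let rounding_bias = 0x7FFF + ((bits >> 16) & 1);
    (bits.wrapping_add(rounding_bias) >> 16) as u16
}

/// Widen a bf16 bit pattern back to f32. This direction is exact.
fn bf16_bits_to_f32(b: u16) -> f32 {
    f32::from_bits((b as u32) << 16)
}

/// Encode a slice of f32 values (e.g. compressed gradient values) as bf16 bits.
fn encode_bf16(values: &[f32]) -> Vec<u16> {
    values.iter().map(|&v| f32_to_bf16_bits(v)).collect()
}
```

For example, f32_to_bf16_bits(1.0) is 0x3F80, and sending 0.1 through f32_to_bf16_bits and back through bf16_bits_to_f32 gives 0.10009765625. The precision loss happens only at transmission time; the local Float32 weights are unaffected.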

How to test

Just run your usual test run (on a Mac host, obviously):

  1. just setup-solana-localnet-light-test-run
  2. just start-training-localnet-light-client

Ideally, also join it to a run with a Linux client to test compatibility.

What still doesn't work

  • Running the decentralized tests locally, as they require building the Docker image which requires Linux.
  • Building Docker images, e.g. with just nix build_docker_solana_client, for the same reason.
  • Parallelism, as Macs have only one GPU.

@ethernet8023
Contributor

ethernet8023 commented Mar 9, 2026

my understanding was that a lot of this worked ~6 months ago - #189 seemed to have it working. the bfloat16 stuff did work for a while (though emulated in software) - maybe we broke our tch-rs fork or something?
pytorch/pytorch#176296
pytorch/pytorch#176298
seems that uint ops are broken on cpu too in some cases lol

pytorch 2.9.1 should have sdpa for mps too.. see pytorch/pytorch#163598

> Replaced certain Tensor::cat calls with a wrapper that for MPS clients performs the concatenation on CPU, working around an MPS crash.

oof. we got a pytorch issue open?

@dsocolobsky
Contributor Author

Yes, I remember that PR did some good work, but something must have broken at some point because I'm not able to join a run on main. Also, with the other PR we had problems with M3 Macs, and now it seems to be working.

Let me check the related PRs you've provided though, perhaps we're missing something in the fork or we can remove some of the changes here and make it work with a more proper fix. Thanks!

@dsocolobsky
Contributor Author

Okay, the SDPA fix was not needed! Whatever I thought was related to that must've been fixed by one of the other changes. I will check the rest, although apparently we'll need to wait for a new release since there's an ongoing discussion.

@dsocolobsky dsocolobsky marked this pull request as ready for review March 13, 2026 17:59
enable macos/mps compatibility with other clients

revert to using SDPA in Metal as it's actually supported
@dsocolobsky dsocolobsky force-pushed the dy/enable-macos-training branch from 27e0582 to cc54945 Compare March 13, 2026 18:00