Enable partial support for MacOS/MPS training #626
Conversation
My understanding was that a lot of this worked ~6 months ago - #189 seemed to have it working. The bfloat16 stuff did work for a while (though emulated in software) - maybe we broke our tch-rs fork or something? PyTorch 2.9.1 should have SDPA for MPS too; see pytorch/pytorch#163598
Yes, I remember that PR had done some good work, but at some point something must've broken, because I'm not able to join a run on main. Also, with the other PR we had problems with M3 Macs, and now it seems to be working. Let me check the related PRs you've provided though; perhaps we're missing something in the fork, or we can remove some of the changes here and make it work with a more proper fix. Thanks!
Okay, the SDPA fix was not needed! Whatever I thought was related to it must've been fixed by one of the other changes. Will check the rest, although apparently we'll need to wait on a new release since there's an ongoing discussion.
enable macos/mps compatibility with other clients
revert to using SDPA in Metal as it's actually supported
Note: this is not intended for production/real training, but it can be really helpful during development for locally testing changes without having to upload the code to a Linux server, as long as you don't require heavy compute power or parallelism.
Before this PR, one could compile on Mac hosts; however, once you joined as a client it would crash immediately, since some PyTorch operations are unsupported on MPS.
This PR introduces some compatibility changes. Nothing should change for Linux/CUDA clients, but crashes on MacOS clients are avoided by falling back to CPU computation in certain situations. I've been able to do a test run locally with 1 client (my MacBook Air can't handle more), and also to join a Mac client to a remote run in a cluster with other Linux clients.
List of changes
- MPS doesn't support `BFloat16`, so on Mac hosts we load and train the model in `Float32`, but convert compressed gradients to `BFloat16` before transmission to maintain compatibility with CUDA clients.
- Replaced `Tensor::cat` calls with a wrapper that, for MPS clients, performs the concatenation on CPU, working around an MPS crash.
- Switched to `Eager` attention, as MPS does not support the `SDPA` backward pass. This should produce the same values, just less efficiently.
- In `distro.rs`, the DCT basis matrix computation uses FFT/complex operations, which MPS doesn't support, so we compute it on CPU.

How to test
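The `Tensor::cat` workaround above boils down to a small device-dispatch wrapper: if the operands live on MPS, move them to CPU, do the op there, and move the result back. The sketch below models that pattern with a toy `Tensor` type, since the real code uses the tch-rs API; the names `Device`, `Tensor`, and `cat_mps_safe` are illustrative, not the actual identifiers in the PR.

```rust
// Toy model of the MPS fallback pattern described above. In the real
// code the types come from tch-rs; here they are stand-ins so the
// sketch is self-contained.

#[derive(Clone, Copy, PartialEq, Debug)]
enum Device {
    Cpu,
    Mps,
}

#[derive(Clone, Debug)]
struct Tensor {
    data: Vec<f32>,
    device: Device,
}

impl Tensor {
    // Stand-in for a device transfer (tch-rs: `Tensor::to_device`).
    fn to(&self, device: Device) -> Tensor {
        Tensor { data: self.data.clone(), device }
    }

    // Plain concatenation; imagine this being the op that crashes on MPS.
    fn cat(tensors: &[Tensor]) -> Tensor {
        let device = tensors[0].device;
        let data = tensors.iter().flat_map(|t| t.data.clone()).collect();
        Tensor { data, device }
    }
}

/// Wrapper: on MPS, concatenate on CPU and move the result back;
/// on any other device, concatenate directly.
fn cat_mps_safe(tensors: &[Tensor]) -> Tensor {
    if tensors[0].device == Device::Mps {
        let on_cpu: Vec<Tensor> =
            tensors.iter().map(|t| t.to(Device::Cpu)).collect();
        Tensor::cat(&on_cpu).to(Device::Mps)
    } else {
        Tensor::cat(tensors)
    }
}

fn main() {
    let a = Tensor { data: vec![1.0, 2.0], device: Device::Mps };
    let b = Tensor { data: vec![3.0], device: Device::Mps };
    let c = cat_mps_safe(&[a, b]);
    // The result ends up back on the caller's original device.
    println!("{:?} on {:?}", c.data, c.device);
}
```

The same shape of wrapper covers the DCT basis computation in `distro.rs`: build the matrix on CPU, then transfer it to the training device once.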
Just run your usual test run (on a Mac host, obviously):

```
just setup-solana-localnet-light-test-run
just start-training-localnet-light-client
```

It'd also be ideal to join a run with a Linux client to test compatibility.
What still doesn't work
- `just nix build_docker_solana_client`, for the same reason.