
Spatial Distributed class. #749

Closed
odiazib wants to merge 38 commits into main from oscar/sp-distributed-class

Conversation

@odiazib commented Jan 21, 2026

The SpatialTorchDistributed class is copied from the TorchDistributed class. I have not yet added all the methods and modifications needed for spatial parallelism. So far, I have done the following:

  • The comm object (the distributed object from Makani) is initialized in __init__.
  • The comm.py file was copied into fme.core.distributed. This file requires nvidia-physicsnemo, so nvidia-physicsnemo was added to requirements.txt.
  • The local_batch_size method was modified to divide the batch size by the size of the "data" parallel group.

Changes:

  • symbol (e.g. fme.core.my_function) or script and concise description of changes or added feature

  • Can group multiple related symbols on a single bullet

  • Tests added

  • If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated

Resolves # (delete if none)

@odiazib requested review from elynnwu and mcgibbon January 21, 2026 17:08
@odiazib force-pushed the oscar/sp-distributed-class branch 3 times, most recently from c910a2a to ee2d73a on January 21, 2026 18:04
Comment thread fme/core/distributed/spatial_torch_distributed.py Outdated
Comment thread fme/core/distributed/comm.py Outdated
Comment thread fme/core/distributed/spatial_torch_distributed.py Outdated
Comment thread fme/core/distributed/spatial_torch_distributed.py Outdated
Comment thread fme/core/distributed/comm.py Outdated


# initialization routine
def init(model_parallel_sizes=[1, 1, 1, 1], model_parallel_names=["h", "w", "fin", "fout"], data_parallel_sizes=[1, -1], data_parallel_names=["ensemble", "batch"], verbose=False):
Contributor

Will we ever configure model_parallel_sizes and data_parallel_sizes to something else?

Author

Yes, here we will configure comm with values other than 1:

comm.init(model_parallel_sizes=params["model_parallel_sizes"], model_parallel_names=params["model_parallel_names"], verbose=False)

Contributor

In particular, "h" and "w" parallelism (the main focus of this code) are done using model_parallel_sizes.
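For illustration (the values here are hypothetical, not from this PR), a 2x2 spatial decomposition over four model-parallel ranks would be requested roughly like this:

# Hypothetical example: partition the lat/lon grid over 4 ranks (2 along "h", 2 along "w"),
# leaving the "fin"/"fout" feature groups unpartitioned.
comm.init(
    model_parallel_sizes=[2, 2, 1, 1],
    model_parallel_names=["h", "w", "fin", "fout"],
    verbose=False,
)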

Comment thread fme/core/distributed/spatial_torch_distributed.py Outdated
Comment thread requirements.txt Outdated
@odiazib force-pushed the oscar/sp-distributed-class branch 2 times, most recently from 7bbb0a9 to a58046c on January 22, 2026 00:18
@mahf708 (Contributor) left a comment

One overarching comment about naming: I think it is worthwhile considering naming this something other than "Spatial" because while we are only doing the H and W decomposition here, we may do other types of decomposition in the future. I think a good alternative is simply "Makani" and I would prefix "makani_" in front of all files copied from the nvidia/makani repo, e.g., makani_comm.py for https://github.com/NVIDIA/makani/blob/main/makani/utils/comm.py. And I would also specify atop the file the specific link and commit it came from, e.g., https://github.com/NVIDIA/makani/blob/dbcf2c1dc82cdbc544c81193eecd8ac4a6be337c/makani/utils/comm.py.

See how the team did this for sht_fix.py: https://github.com/ai2cm/ace/blob/main/fme/sht_fix.py
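For concreteness, the kind of provenance header being suggested might look something like this (a sketch; the exact wording is up to the authors):

# makani_comm.py
#
# Copied from the NVIDIA Makani repository, with local modifications noted inline:
# https://github.com/NVIDIA/makani/blob/dbcf2c1dc82cdbc544c81193eecd8ac4a6be337c/makani/utils/comm.py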

@odiazib (Author) commented Jan 22, 2026

I think it is worthwhile considering naming this something other than "Spatial" because while we are only doing the H and W decomposition here, we may do other types of decomposition in the future. I think a good alternative is simply "Makani" and I would prefix "makani_" in front of all files copied from the nvidia/makani repo, e.g., makani_co

Yes, we can adapt this naming convention. @elynnwu, @mcgibbon, @oliverwm1, do you have any recommendations on how to name this class and the Makani files? I like @mahf708’s idea of prefixing the filenames with “makani” and using “MakaniTorchDistributed” as the class name.

Comment thread fme/core/distributed/comm.py Outdated
Comment thread fme/core/distributed/comm.py Outdated
Comment thread fme/core/distributed/spatial_torch_distributed.py Outdated
Comment thread fme/core/distributed/spatial_torch_distributed.py Outdated
Comment thread fme/core/distributed/spatial_torch_distributed.py Outdated
Comment on lines +52 to +54
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
Contributor

Potential Issue: these torch backend flags don't appear to be related to spatial parallelism.

Author

I will remove these flags.

Comment thread fme/core/distributed/spatial_torch_distributed.py Outdated
Comment thread fme/core/distributed/spatial_torch_distributed.py Outdated
Comment on lines +61 to +68
def is_available(cls) -> bool:
    """Check if torch distributed is available."""
    h_parallel_size = int(os.environ.get("H_PARALLEL_SIZE", 1))
    w_parallel_size = int(os.environ.get("W_PARALLEL_SIZE", 1))
    spatial_parallelism = False
    if (h_parallel_size > 1) or (w_parallel_size > 1):
        spatial_parallelism = True
    return torch.distributed.is_available() and spatial_parallelism
Contributor

Suggestion: Split off self._spatial_parallelism_enabled() logic into a separate helper function to make this function a bit clearer. Specifically, torch.distributed.is_available() is a very high level operation, while the spatial parallelism check here is very low-level. Splitting it to a helper keeps everything in is_available at the same level of abstraction.
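A minimal sketch of the suggested split (assuming the os and torch imports already present in the module; the helper name follows the comment above):

@classmethod
def is_available(cls) -> bool:
    """Check if torch distributed is available and spatial parallelism is requested."""
    return torch.distributed.is_available() and cls._spatial_parallelism_enabled()

@classmethod
def _spatial_parallelism_enabled(cls) -> bool:
    h_parallel_size = int(os.environ.get("H_PARALLEL_SIZE", 1))
    w_parallel_size = int(os.environ.get("W_PARALLEL_SIZE", 1))
    return h_parallel_size > 1 or w_parallel_size > 1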


def reduce_mean(self, tensor: torch.Tensor) -> torch.Tensor | None:
    torch.distributed.all_reduce(tensor)
    return tensor / self.total_ranks
Contributor

Is this correct? I would expect you need to handle spatial parallelism in these functions. Specifically, I expect you need to pass a group argument to all_reduce which tells it which spot on the globe it's in, so each process only reduces the mean in its patch. As written, I expect every rank will get the mean across all "tiles" and the spatial areas will be mixed into one.

This is done, for example, in physicsnemo, which uses group=DistributedManager().group(process_group) (not sure what process_group is) to set a group argument for the reduction override applied to param.grad (link).

Unfortunately I can't really find any other uses of distributed reduction in the physicsnemo code; I'm not sure whether they have inline spatial-map metrics the way we do.

Author

I did not modify this method in the original branch. I added a couple of unit tests and the implementation works at least for the training part. I believe the only reductions involved are those that compute the loss and the metrics. This implementation should be fine in cases where spatial dimensions are not involved. If this approach is acceptable, I would like to address the remaining issue in a follow-up PR.

Comment thread requirements.txt Outdated
@mcgibbon (Contributor) commented Jan 22, 2026

I would discourage using "makani" as the name since the code is primarily from physicsnemo (not makani). Maybe we can use "Model", which is the terminology physicsnemo uses, instead of "Spatial".

Comment thread fme/core/distributed/spatial_torch_distributed.py
@odiazib (Author) commented Jan 27, 2026

I would discourage using "makani" as the name since the code is primarily from physicsnemo (not makani). Maybe we can use "Model", which is the terminology physicsnemo uses, instead of "Spatial".

The file comm.py was a direct copy from Makani. It uses routines from PhysicsNemo. @mahf708, what do you think about using 'Model' instead of 'Makani' for the class name (ModelTorchDistributed)? I guess we will not use 'Makani' in the name of the files. Is this correct, @mcgibbon?

@mahf708 (Contributor) commented Jan 27, 2026

Yes, any name is good for me as long as it's clear. I defer to Jeremy and Elynn, so anything they decide is good for me

@odiazib force-pushed the oscar/sp-distributed-class branch 2 times, most recently from d24b1c2 to 2309d5e on February 3, 2026 22:55
    return lat, lon, nlat, nlon, batch_size, input_tensor


@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="requires multi-GPU machine")
Contributor

Can you relax this criterion to 2 GPUs? We can change the test to use a 1x2 decomposition instead. I believe the GPU tests will run once you do that.

@odiazib (Author) commented Feb 4, 2026

Hi @elynnwu and @mcgibbon

Thanks for all your help reviewing this PR. I believe I’ve addressed most of your comments, except for the remark about reduce_mean in the ModelDistributedClass. For that item, I added a couple of unit tests (based on the original branch), and in those tests we don’t need to modify the reduce_mean method since we are computing global means.
I understand Jeremy’s comment (#749 (comment)), but I’d prefer to tackle that in a separate PR.

A couple of other points:

Finally, I added the new unit tests under fme/core/distributed/test_model_torch_distributed. Should I move them to a different directory?

def local_batch_size(self, batch_size: int) -> int:
    return batch_size // comm.get_size("data")

def reduce_mean(self, tensor: torch.Tensor) -> torch.Tensor | None:
@mcgibbon (Contributor) commented Feb 4, 2026

This code is currently incorrect. Take for example a basic case where n_lat=1, n_lon=2, and W_PARALLEL_SIZE=2, running on 4 ranks. Let's say you have a batch_size=2 array with shape [n_batch, n_lat, n_lon], where the first batch member has data [[1, -1]], and the second batch member has data [[2, -2]].

If we were not using spatial parallelism, we would have two ranks with the above data, the first rank having [[1, -1]] and the second rank having [[2, -2]], and the mean across ranks would be [[1.5, -1.5]]. This is the result we want to get regardless of whether or not we're using spatial parallelism.

If you're using spatial parallelism with 4 ranks, the ranks each have data [[1]], [[-1]], [[2]], and [[-2]]. The current code will reduce this to [[0]] on all ranks, but what we actually want is that the ranks in the "west" group reduce to [[1.5]], and that the ranks in the "east" group reduce to [[-1.5]], so that if we stitched it back to the global data we'd get [[1.5, -1.5]]. The current result of [[0, 0]] is a bug.

To fix it, we have to pass the reduction operation a keyword argument telling it what "group" this process is in. There is a helper method in physicsnemo for this, that I referenced earlier.

@mcgibbon (Contributor) commented Feb 4, 2026

I think it's very important to fix this before merging those methods, so that we aren't confused by incorrect answers on main. If we just aren't using the methods yet and want to implement them later, we should have the methods raise NotImplementedError() instead of giving the answers they currently do.

Author

@mcgibbon, I reviewed the Makani implementation and found a couple of places where the “group” parameter was passed to the reduce operation. I ended up using “data” as the group in the torch.distributed.all_reduce routine. I also added a couple of unit tests for this method:

Test 1 uses your simple example.

Test 2 uses random values for the input tensor.
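For reference, the shape of that change is roughly the following (a sketch, not the exact diff; it assumes the comm module's get_group and get_size helpers):

def reduce_mean(self, tensor: torch.Tensor) -> torch.Tensor | None:
    # Reduce only across the "data" (batch/ensemble) group, so ranks holding
    # different spatial patches are not averaged together.
    torch.distributed.all_reduce(tensor, group=comm.get_group("data"))
    return tensor / comm.get_size("data")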

Comment thread fme/core/distributed/test_model_torch_distributed/test_coordinates_sp.py Outdated (×6)
def get_local_rank(self) -> int:
    return self._device_id

def get_local_slices(self, crop_shape):
Contributor

Suggestion: the code here is pretty complex and this function will be widely used - add a unit test for get_local_slices. You may need to move the logic to some kind of helper function and test just the helper, since this object is difficult to construct in a test.

Contributor

Question: What is the meaning of "crop" shape? I thought this takes the full domain shape as its input argument.

Author

The unit test for get_local_slices is located here.

In this test, a tensor is created and then scattered across processes; its local slices are gathered into a new tensor. The test passes if the original tensor and the gathered tensor are equal.

I renamed crop_shape to tensor_shape.
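A self-contained sketch of the same idea, written against a hypothetical slicing helper rather than the distributed class (the real test runs under torch.distributed; this version only checks that the per-rank slices tile the full tensor):

import torch


def local_slices(tensor_shape, h_rank, w_rank, h_size, w_size):
    # Hypothetical helper: slices selecting the (h_rank, w_rank) patch of a [..., lat, lon] tensor.
    *leading, nlat, nlon = tensor_shape
    h_step, w_step = nlat // h_size, nlon // w_size
    return tuple(
        [slice(None)] * len(leading)
        + [
            slice(h_rank * h_step, (h_rank + 1) * h_step),
            slice(w_rank * w_step, (w_rank + 1) * w_step),
        ]
    )


full = torch.arange(2 * 4 * 8, dtype=torch.float32).reshape(2, 4, 8)
stitched = torch.zeros_like(full)
for h in range(2):  # 2x2 spatial decomposition
    for w in range(2):
        patch = local_slices(full.shape, h, w, h_size=2, w_size=2)
        stitched[patch] = full[patch]
torch.testing.assert_close(stitched, full)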

    torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MAX)
    return tensor

def gather(self, tensor: torch.Tensor) -> list[torch.Tensor] | None:
Contributor

Suggestion: Raise NotImplementedError() instead or add a unit test for this and the other spatially-involved distributed methods (like gather_irregular).

    return self.world_size

def get_local_rank(self) -> int:
    return self._device_id
Contributor

Suggestion (optional): rename self._device_id to self._local_rank, since it is a little less confusing to see/read "my device id is my local rank" than reading "my local rank is my device id". Up to you, if you think the current way is clearer.

Author

I did not see _local_rank defined in the base class. Should I add it?

Comment thread fme/core/distributed/test_model_torch_distributed/test_loss_sp.py Outdated (×4)
odiazib and others added 5 commits February 9, 2026 19:32
…ch Distributed class, but in the __init__ method, the comm object from Makani is initialized. Most of the changes required for this class have not been added yet. In addition, the comm file is also included. Finally, nvidia-physicsnemo is added to the requirements.txt file.
Co-authored-by: Jeremy McGibbon <jeremym@allenai.org>
Co-authored-by: Jeremy McGibbon <jeremym@allenai.org>
… my large PR.

Only run this test if the number of GPUs is greater than four.

First, invoke the single version in loss test.
…arios where spatial parallelism is not involved.
…the RuntimeError “Boolean value of Tensor with more than one value is ambiguous.” Therefore, we need to compare tensors using torch.testing.assert_close.
…est. Since both tests now run simultaneously, saving data to the local directory is no longer necessary.
@odiazib force-pushed the oscar/sp-distributed-class branch from 456a247 to 8c1f7d6 on February 10, 2026 03:33
@mahf708 (Contributor) left a comment

Looks great, @odiazib --- thanks a lot! I added some minor comments for your consideration, but can punt to later PRs if desired

if ModelTorchDistributed.is_available() and not force_non_distributed:
    self._distributed: DistributedBackend = ModelTorchDistributed()
elif TorchDistributed.is_available() and not force_non_distributed:
    self._distributed = TorchDistributed()
Contributor

do you wanna preserve the type like before, i.e.

Suggested change
- self._distributed = TorchDistributed()
+ self._distributed: DistributedBackend = TorchDistributed()

Author

I tried to add the type DistributedBackend to this line. However, when running pre-commit I get:

fme/core/distributed/distributed.py:34: error: Attribute "_distributed" already defined on line 32  [no-redef]
Found 1 error in 1 file (checked 1 source file)

I ran a single test, and it ran fine.

For this reason, I removed the type.

@mcgibbon, any suggestions?
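One way around the no-redef error (a sketch, not the PR's code; it assumes both backends implement DistributedBackend) is to annotate the attribute once and then assign in each branch without repeating the annotation:

self._distributed: DistributedBackend  # declare the attribute type once
if ModelTorchDistributed.is_available() and not force_non_distributed:
    self._distributed = ModelTorchDistributed()
elif TorchDistributed.is_available() and not force_non_distributed:
    self._distributed = TorchDistributed()
# ...remaining fallback branch unchanged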

Contributor

ah interesting linting error. @odiazib feel free to disregard this comment and all my other comments; they were very minor

Comment thread fme/core/distributed/distributed.py Outdated
Comment thread fme/core/distributed/non_distributed.py Outdated
Comment thread fme/core/testing/distributed.py Outdated
    torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
    return tensor

def reduce_max(self, tensor: torch.Tensor) -> torch.Tensor | None:
Contributor

not important, but for all reductions, you could potentially do a generic all-purpose reduce op that can reroute to the correct one or potentially do an einops-style one, e.g., the signature

def reduce(tensor, op):
  # if op ...
  #   ...
  return tensor

Author

@mcgibbon, any suggestions on this? I just copied and pasted from the torch.distributed class.

@mahf708 (Contributor) commented Feb 11, 2026

Oh, if that follows a pattern already established, then please disregard my comment (sorry I should've checked)

Comment on lines +26 to +27
h_parallel_size = int(os.environ.get("H_PARALLEL_SIZE", 1))
w_parallel_size = int(os.environ.get("W_PARALLEL_SIZE", 1))
Contributor

I am a tiny bit uncomfortable with changing runtime behavior with env variables, but maybe that's how it's done? I defer to Jeremy and Elynn --- I've seen stuff like this in the og distributed class, so maybe that's the preferred way here

Author

Using environment variables was an easy way to get things started with spatial parallelism. At this point, however, I believe we can consider an alternative.

Contributor

it's ok to keep them (if that's the easiest thing)

Comment thread fme/core/distributed/model_torch_distributed/test_model_torch_distributed.py Outdated
Comment on lines +49 to +50
# Set H_PARALLEL_SIZE back to 1.
os.environ["H_PARALLEL_SIZE"] = "1"
Contributor

this is why I feel a little uneasy about changing runtime behavior with env variables: it requires this type of clean-up 👀
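As an aside, pytest's built-in monkeypatch fixture restores environment variables automatically when the test ends, which would avoid manual cleanup (a sketch; the test name is hypothetical):

def test_model_torch_distributed_2x1(monkeypatch):
    # monkeypatch.setenv undoes the change after the test finishes.
    monkeypatch.setenv("H_PARALLEL_SIZE", "2")
    monkeypatch.setenv("W_PARALLEL_SIZE", "1")
    ...  # run the decomposed case and compare against the single-process result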

def get_local_slices(self, tensor_shape):
    return tuple(slice(None, None) for _ in tensor_shape)

def reduce_mean(self, tensor: torch.Tensor) -> torch.Tensor:
Contributor

Possibly a bug from previous changes: group=None is missing here - it looks like all reduce_mean calls take group as an input.
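Presumably the fix is just to mirror the distributed signature (a sketch; the non-distributed backend stays a no-op):

def reduce_mean(self, tensor: torch.Tensor, group=None) -> torch.Tensor:
    # Single process: nothing to reduce; group is accepted only for interface parity.
    return tensor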

from unittest.mock import patch

import pytest
from model_torch_distributed import ModelTorchDistributed
Contributor

Changed to from fme.core.distributed.model_torch_distributed.model_torch_distributed import ModelTorchDistributed so that this import always works and to be consistent with your other test imports.

@@ -0,0 +1,86 @@
import logging
Contributor

Can you remove test_reduce_mean.py? Why are there two versions of this?

@elynnwu closed this in #847 on Feb 19, 2026
spencerkclark pushed a commit that referenced this pull request Feb 19, 2026
add ModelTorchDistributed with tests

Changes:
- fme.core.distributed has a new ModelDistributedBackend that allows for
parallelism over spatial dimensions as well as batch/data.
- torch is pinned with a minimum of 2.4.0 to use new facilities for
distributed, etc.

- [x] Tests added
- [ ] If dependencies changed, "deps only" image rebuilt and
"latest_deps_only_image.txt" file updated

Closes #749
Closes #842
William-gregory pushed a commit to William-gregory/ace that referenced this pull request Mar 3, 2026
add ModelTorchDistributed with tests (same commit message as above; Closes ai2cm#749, Closes ai2cm#842)