Skip to content

SBRA-739 redefine the graph creation method#79

Merged
A669015 merged 13 commits intomainfrom
SBRA-739-redefine-the-graph-creation-method
Apr 24, 2025
Merged

SBRA-739 redefine the graph creation method#79
A669015 merged 13 commits intomainfrom
SBRA-739-redefine-the-graph-creation-method

Conversation

@A669015
Copy link
Contributor

@A669015 A669015 commented Feb 11, 2025

This PR refactor the code to avoid the computation of the graph topology for each sample.
The graph topology is now computed on the datamodule side at initialization, and passed to the ligthning module through the argument linking feature of ligthning. Then at training time, the dataloader only send the features to the model.

@A669015 A669015 requested a review from a team as a code owner February 11, 2025 16:23
@github-actions
Copy link

github-actions bot commented Feb 11, 2025

Coverage report for weather-forecast/gravity-wave-drag/cnns model project

File Coverage Missing
All files 100%

Minimum allowed coverage is 0%

Generated by 🐒 cobertura-action against db03d17

@github-actions
Copy link

github-actions bot commented Feb 11, 2025

Coverage report for reactive-flows/cnf-combustion/unets model project

File Coverage Missing
All files 97%
data.py 97% 134 148
unet.py 96% 95 125

Minimum allowed coverage is 0%

Generated by 🐒 cobertura-action against db03d17

@github-actions
Copy link

github-actions bot commented Feb 11, 2025

Coverage report for reactive-flows/cnf-combustion/gnns model project

File Coverage Missing
All files 94%
data.py 93% 82 160 214 222 239 245 317-321
inferer.py 89% 214-240
plotters.py 94% 156-157 160-161 182 191

Minimum allowed coverage is 0%

Generated by 🐒 cobertura-action against db03d17

@github-actions
Copy link

github-actions bot commented Feb 11, 2025

Coverage report for weather-forecast/ecrad-3d-correction/unets model project

File Coverage Missing
All files 96%
data.py 98% 97
dataproc.py 93% 164 168-170 215-216
models.py 98% 52
unet.py 94% 92 122 125

Minimum allowed coverage is 0%

Generated by 🐒 cobertura-action against db03d17

@A669015 A669015 requested a review from elasto February 18, 2025 08:57
@A669015 A669015 marked this pull request as draft February 20, 2025 16:03
@A669015 A669015 force-pushed the SBRA-739-redefine-the-graph-creation-method branch from 164a8d7 to ef0dfa1 Compare February 24, 2025 13:52
@A669015 A669015 force-pushed the SBRA-739-redefine-the-graph-creation-method branch 2 times, most recently from 901f8df to 176c29b Compare March 31, 2025 12:57
@A669015 A669015 force-pushed the SBRA-739-redefine-the-graph-creation-method branch from 176c29b to b745bce Compare March 31, 2025 12:58
@A669015 A669015 marked this pull request as ready for review March 31, 2025 12:59
Copy link

@mikael10j mikael10j left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a quick review

@cbovalo
Copy link
Contributor

cbovalo commented Apr 23, 2025

Just a few comments:

  • in the pyg_data object, I suggest setting edge_index to None to make it clearer that the adjacency matrix will be given later by graph_topology,
  • I think the egde_index should be a buffer of the nn.Module to be put in the correct device automatically (right now it stays on cpu),
  • I try to think to a better way to pass the grid shape (for plotting) instead of using the position. I mean do we need a tensor of shape [num_nodes, num_dimensions] to get only 3 numbers ?

@A669015
Copy link
Contributor Author

A669015 commented Apr 23, 2025

  • I think the egde_index should be a buffer of the nn.Module to be put in the correct device automatically (right now it stays on cpu),

Do you suggest to register the edge_index tensor in the state_dict, so in the model checkpoint ? We can make it non-persistent to keep the checkpoint as light as possible.

  • I try to think to a better way to pass the grid shape (for plotting) instead of using the position. I mean do we need a tensor of shape [num_nodes, num_dimensions] to get only 3 numbers ?

What do you mean by "a tensor of shape [num_nodes, num_dimensions] to get only 3 numbers" ? The grid_shape is a tuple of 3 values.

@cbovalo
Copy link
Contributor

cbovalo commented Apr 24, 2025

  • I think the egde_index should be a buffer of the nn.Module to be put in the correct device automatically (right now it stays on cpu),

Do you suggest to register the edge_index tensor in the state_dict, so in the model checkpoint ? We can make it non-persistent to keep the checkpoint as light as possible.

Not in the state_dict (setting persistent=False in register_buffer)

  • I try to think to a better way to pass the grid shape (for plotting) instead of using the position. I mean do we need a tensor of shape [num_nodes, num_dimensions] to get only 3 numbers ?

What do you mean by "a tensor of shape [num_nodes, num_dimensions] to get only 3 numbers" ? The grid_shape is a tuple of 3 values.

Exactly! What I mean is that you add to graph_topology an attribute pos which is a tensor with shape (num_nodes, 3) - Cartesian coordinates in 3D. You retrieve the min/max in each axis to get the domain size. Maybe this shape could be returned by the Datamodule and be liked to the LightningModule.

@A669015 A669015 merged commit a81d912 into main Apr 24, 2025
18 checks passed
@A669015 A669015 deleted the SBRA-739-redefine-the-graph-creation-method branch April 24, 2025 12:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants