
Fix model device transfer bug #379

Merged
mzouink merged 1 commit into janelia-cellmap:main from eschombu:device-fix
May 12, 2025

Conversation

eschombu (Contributor) commented May 9, 2025

At the validation interval during model training (using dacapo.train::train_run), weights_store.store_weights(run, i + 1) left the model on the wrong device, which caused the validation step that follows to fail. The cause is that LocalWeightsStore.save_trace transfers the model to the CPU but never transfers it back to its original device. This PR takes the device-detection strategy used in dacapo.predict_local::predict, wraps it in a new Model.get_device() method, and uses that method to fix the bug.
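The restore-the-device pattern described above can be sketched roughly as follows. This is a minimal illustration, not the code from this PR: FakeParam, FakeModel, and on_cpu are hypothetical stand-ins so the example runs without PyTorch or a GPU, and get_device here is a free function that only mirrors the idea behind Model.get_device() (read the device of the first parameter). In DaCapo the model would be a torch.nn.Module.

```python
from contextlib import contextmanager
from dataclasses import dataclass, field


@dataclass
class FakeParam:
    """Hypothetical stand-in for a model parameter; only tracks a device name."""
    device: str


@dataclass
class FakeModel:
    """Hypothetical stand-in for a torch.nn.Module, for illustration only."""
    params: list = field(default_factory=lambda: [FakeParam("cuda")])

    def parameters(self):
        return iter(self.params)

    def to(self, device):
        # Mimics torch.nn.Module.to: moves parameters in place and returns self.
        for p in self.params:
            p.device = device
        return self


def get_device(model):
    """Report the device of the model's first parameter."""
    return next(model.parameters()).device


@contextmanager
def on_cpu(model):
    """Temporarily move `model` to the CPU, then restore its original device."""
    original = get_device(model)
    model.to("cpu")
    try:
        yield model
    finally:
        model.to(original)  # runs even if the body raises


model = FakeModel()
with on_cpu(model) as m:
    device_during_trace = get_device(m)  # tracing/saving would happen here
device_after = get_device(model)
```

Putting the restore step in a finally clause means the model returns to its original device even if the tracing step fails partway through.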

Please let me know if you have any requests for design/style/testing changes to the code in this PR. This would be my first contribution to DaCapo :)

Here's the exception stack trace, to demo what happens when I run things before this fix:

AssertionError                            Traceback (most recent call last)
Cell In[16], line 13
     10 run = Run(config_store.retrieve_run_config(run_config.name))
     12 if __name__ == "__main__":
---> 13     train_run(run)

File ~/code/dacapo/dacapo/train.py:137, in train_run(run, validate, save_snapshots)
    133 weights_store.store_weights(run, i + 1)
    135 if validate:
    136     # VALIDATE
--> 137     validate_run(
    138         run,
    139         i + 1,
    140     )
    141     stats_store.store_validation_iteration_scores(
    142         run.name, run.validation_scores
    143     )

File ~/code/dacapo/dacapo/validate.py:177, in validate_run(run, iteration, datasets_config)
    172     logger.info("validation inputs already copied!")
    174 prediction_array_identifier = array_store.validation_prediction_array(
    175     run.name, iteration, validation_dataset
    176 )
--> 177 predict(
    178     run.model,
    179     input_raw_array_identifier,
    180     prediction_array_identifier,
    181     output_roi=validation_dataset.gt.roi,
    182 )
    184 post_processor.set_prediction(prediction_array_identifier)
    186 dataset_iteration_scores = []

File ~/code/dacapo/dacapo/predict_local.py:78, in predict(model, raw_array_identifier, prediction_array_identifier, output_roi)
     74 device = compute_context.device
     76 model_device = str(next(model.parameters()).device).split(":")[0]
---> 78 assert model_device == str(
     79     device
     80 ), f"Model is not on the right device, Model: {model_device}, Compute device: {device}"
     82 def predict_fn(block):
     83     raw_input = raw_array.to_ndarray(block.read_roi)

AssertionError: Model is not on the right device, Model: cpu, Compute device: cuda
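The check that fails in the traceback compares device types rather than full device names: line 76 of predict_local.py strips any index suffix from the model's device with split(":")[0] before comparing. A small sketch of that comparison on plain strings is below. The helper names device_type and devices_match are hypothetical, and unlike the original assertion this sketch normalizes both sides.

```python
def device_type(device) -> str:
    """Return the device type without an index, e.g. 'cuda:0' -> 'cuda'."""
    return str(device).split(":")[0]


def devices_match(model_device, compute_device) -> bool:
    """Compare two devices by type only, ignoring any index suffix."""
    return device_type(model_device) == device_type(compute_device)


indexed_gpu_matches = devices_match("cuda:0", "cuda")  # index is ignored
stale_cpu_model = devices_match("cpu", "cuda")         # the situation in the traceback
```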

mzouink (Member) commented May 9, 2025

Thanks @eschombu for the valid contribution

@pattonw, can you please double-check which tests are failing?

eschombu (Contributor Author) commented May 9, 2025

Oh yes, some of the failures are related to issues I had getting dacapo working. bioimageio.core==0.8.0 has breaking changes, so I had to install bioimageio.core==0.7.0 to fix the ImportError.

pattonw (Contributor) commented May 12, 2025

Great addition. This looks ready to merge.
I just had to pin the bioimageio version to get the tests to pass. black is failing only because it can't push to your branch, and mypy is currently not in a state to pass.

mzouink merged commit 93500ae into janelia-cellmap:main on May 12, 2025
0 of 4 checks passed
mzouink (Member) commented May 12, 2025

Thanks @eschombu and @pattonw! Merged.
