Fix model device transfer bug #379
Merged
mzouink merged 1 commit into janelia-cellmap:main on May 12, 2025
…ghtsStore.save_trace()
Oh yes, some of the failures are related to issues I had getting dacapo working.
Great addition. This looks ready to merge.
At the validation interval during model training (using `dacapo.train::train_run`), `weights_store.store_weights(run, i + 1)` raised an error because the model weights were not on the correct device. The cause was that the `LocalWeightsStore.save_trace` method transfers the model to the CPU but did not transfer it back to the original device. This PR takes the strategy for detecting the model device used in `dacapo.predict_local::predict`, wraps it in a new `Model.get_device()` method, and uses that method to fix the bug.

Please let me know if you have any requests for design, style, or testing changes to the code in this PR; this would be my first contribution to DaCapo :)
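For readers who want to see the shape of the fix, the detect-then-restore pattern described above can be sketched roughly as follows. This is an illustration only, not the code merged in this PR: the standalone `get_device` function and the `on_cpu` context manager are hypothetical names, and the sketch assumes the model behaves like a `torch.nn.Module` (that is, it exposes `parameters()` and an in-place `to(device)`).

```python
from contextlib import contextmanager


def get_device(model):
    """Return the device of the model's first parameter, or None if it has none.

    Hypothetical helper: assumes `model` exposes `parameters()` like a
    torch.nn.Module. It is not DaCapo's actual `Model.get_device()`.
    """
    first_param = next(iter(model.parameters()), None)
    return None if first_param is None else first_param.device


@contextmanager
def on_cpu(model):
    """Temporarily move `model` to the CPU, then restore its original device."""
    original_device = get_device(model)  # remember where the model lived
    model.to("cpu")
    try:
        yield model
    finally:
        # Move the model back even if the body raised, so that later
        # training steps find the weights on the device they expect.
        if original_device is not None:
            model.to(original_device)
```

Any CPU-only work, such as the step in `save_trace` that needs the model on the CPU, would go inside a `with on_cpu(model):` block. The `try`/`finally` is a design choice of this sketch: it restores the device even when the body fails, so an error during saving cannot leave the model stranded on the CPU.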
Here's the exception stack trace, to demo what happens when I run things before this fix: