Add use_latlon support to OlmoEarth fine-tuning wrapper#677
Draft
pjreddie wants to merge 3 commits into
Draft
Conversation
Adds a `use_latlon` option to the OlmoEarth wrapper. When enabled, each sample's crop-center lat/lon is computed from its SampleMetadata (crop_bounds + projection CRS, transformed to WGS84) and passed to the encoder via the MaskedOlmoEarthSample.latlon field, so models pretrained with a geographic (lat/lon) encoding can use it during fine-tuning. Because the latlon encoding's dropout is train-only and some checkpoints were pretrained with a non-zero latlon_dropout_rate (e.g. rope_simple_v1 used 0.5), the wrapper forces latlon_dropout_rate=0 when use_latlon is set so the encoding is active on every fine-tuning step. Defaults to False (no behavior change for existing configs). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
favyen2
reviewed
Jun 26, 2026
3 tasks
favyen2
reviewed
Jun 26, 2026
| if use_latlon and hasattr(self.model, "composite_encodings"): | ||
| encodings = self.model.composite_encodings | ||
| if hasattr(encodings, "latlon_dropout_rate"): | ||
| encodings.latlon_dropout_rate = 0.0 |
Collaborator
There was a problem hiding this comment.
We changed band_dropout_rate to be the reverse, where it is initially set to 0.0 and only set to the configured value during training. Can we do the same here, so that users of olmoearth_pretrain don't need to worry about disabling latlon_dropout_rate?
Per review on #677/#574: move the rasterio imports out of the method to the top of the file, and use rslearn's STGeometry.to_projection to convert the crop-center from the window projection to WGS84 instead of calling rasterio.warp.transform directly. Verified to produce identical lat/lon. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
olmoearth_pretrain's SeparateEncodings now keeps latlon dropout inactive by default (only enabled by the pretraining loop via enable_latlon_dropout), mirroring band_dropout_rate. So the wrapper no longer needs to reach into the model and force latlon_dropout_rate=0 -- loaded models already use the full latlon encoding with no dropout. Removes that hook; use_latlon now only computes and injects the lat/lon. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds a
use_latlonoption to theOlmoEarthfine-tuning wrapper. When enabled, each sample's crop-center lat/lon is computed from itsSampleMetadata(crop_bounds+projectionCRS, transformed to WGS84) and passed to the encoder viaMaskedOlmoEarthSample.latlon, so models pretrained with a geographic (lat/lon) encoding (e.g. theseparate/simple-encoding OlmoEarth variants) can use it during fine-tuning.The latlon encoding's dropout is train-only, and some checkpoints were pretrained with a non-zero
latlon_dropout_rate(e.g.rope_simple_v1used 0.5). To keep the encoding active on every fine-tuning step, the wrapper forceslatlon_dropout_rate = 0whenuse_latlonis set.use_latlondefaults toFalse, so existing configs are unaffected.Changes
OlmoEarth.__init__: newuse_latlon: bool = Falsearg; when set (and the loaded encoder exposescomposite_encodings.latlon_dropout_rate), force the rate to 0.OlmoEarth._compute_latlon_from_metadata(...): crop-center → WGS84(lat, lon)per sample.OlmoEarth.forward: whenuse_latlon, inject the computed lat/lon into the sample before the encoder.Testing
rope_simple_v1) withuse_latlon=True; confirmedlatlon_dropout_rateis forced to 0.fast_dev_runof a segmentation fine-tune (encoder +UNetDecoder+SegmentationHead) withuse_latlon: true,use_legacy_timestamps: false, andUNetDecoder(use_batch_norm: true)completes a train+val step (loss ≈ ln(num_classes) at init, as expected).🤖 Generated with Claude Code