Add RandomForest decoder and embedding classification datamodule#1160

Open
arishofmann wants to merge 1 commit into terrastackai:main from arishofmann:main-clean

Conversation

@arishofmann

Here is my PR adding the RF decoder and the embedding classification datamodule.

@Isabelle-Wittmann
Collaborator

Hi @arishofmann, thanks a lot for opening a PR. Next time, feel free to also tag us as reviewers so that we get a notification :) I have some high-level questions:

  • Why do we need an extra embedding dataset and loader? I would prefer to use the generic datasets for this instead. As far as I can see, we only need to extend the generic dataset so that it can also read the .pt files in the expected format.
  • Looking at the current task implementation, I'm not sure the inheritance from the other TerraTorch tasks / BaseTask makes sense. I think it would be cleaner to keep the existing TerraTorchTask for all gradient-based training and have an alternative base task for other types of embedding-decoding mechanisms like RF. For this, it would be best to define a general EmbeddingDecodingTask (that does not inherit from TerraTorchTask), from which a specific classification task can inherit if needed. I don't see why it needs to be RandomForest-specific; it would be much better to keep it generic, e.g. for other supervised sklearn methods.
  • With the above changes, the decoder wrapper itself can also become a bit simpler.

What do you think?

@arishofmann
Author

Hi @Isabelle-Wittmann, thanks for the review.

On the embedding dataset: yeah, the generic datasets assume image I/O (GeoTIFFs, bands, transforms, spatial dims), so I made a separate one for .pt tensors. But I agree, extending the generic dataset to also handle .pt inputs would be cleaner. I can work on that.
On the task: I inherited from MultiLabelClassificationTask mainly to reuse the metrics and val/test logging. The trade-off was workarounds like automatic_optimization = False. A separate EmbeddingDecodingTask makes sense. Should it inherit from TerraTorchTask (reuses model factory and freeze logic) or directly from LightningModule? Just want to align before I refactor.
On making it generic: good point, I will generalize the decoder wrapper to accept any sklearn estimator instead of hardcoding RF.

@Isabelle-Wittmann
Collaborator

Hi @arishofmann, great - thanks a lot! For the task module, for me it would make more sense to not inherit from a TerraTorchTask but build a separate new (base) task that fits the sklearn decoders better.
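
To make the idea concrete, here is a minimal, framework-free sketch of what a non-gradient decoding task could look like. It is illustrative only and is not the code in this PR: the class names `EmbeddingDecodingTaskSketch`, `SupportsFitPredict`, and `MajorityClassEstimator`, the batch keys `"embedding"` and `"label"`, and the hook names are all assumptions chosen to mirror Lightning-style hooks. The point it shows is that the task only accumulates embeddings per batch and calls `fit` once, so no optimizer or `automatic_optimization` workaround is involved.

```python
from typing import Any, Protocol, Sequence


class SupportsFitPredict(Protocol):
    """Anything with sklearn-style fit/predict (hypothetical interface name)."""

    def fit(self, X: Sequence[Sequence[float]], y: Sequence[int]) -> Any: ...
    def predict(self, X: Sequence[Sequence[float]]) -> Sequence[int]: ...


class EmbeddingDecodingTaskSketch:
    """Collects precomputed embeddings batch by batch, then fits once.

    No optimizer and no backward pass: the estimator is fitted in a single
    call after all training batches have been seen.
    """

    def __init__(self, estimator: SupportsFitPredict) -> None:
        self.estimator = estimator
        self._X: list = []
        self._y: list = []
        self.is_fitted = False

    def training_step(self, batch: dict) -> None:
        # Only accumulate; nothing is learned per batch.
        self._X.extend(batch["embedding"])
        self._y.extend(batch["label"])

    def on_train_epoch_end(self) -> None:
        # Fit on everything collected during the epoch, then clear the buffers.
        self.estimator.fit(self._X, self._y)
        self.is_fitted = True
        self._X, self._y = [], []

    def validation_step(self, batch: dict) -> float:
        # Returns plain accuracy on this batch.
        if not self.is_fitted:
            raise RuntimeError("fit the estimator before validating")
        preds = self.estimator.predict(batch["embedding"])
        correct = sum(int(p == t) for p, t in zip(preds, batch["label"]))
        return correct / len(batch["label"])


class MajorityClassEstimator:
    """Tiny stand-in estimator so the sketch runs without scikit-learn."""

    def fit(self, X, y):
        self.majority_ = max(set(y), key=list(y).count)
        return self

    def predict(self, X):
        return [self.majority_ for _ in X]
```

Any object with `fit` and `predict`, such as a scikit-learn classifier, could be passed in place of the stand-in estimator, which is what would keep the task independent of RandomForest.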

@arishofmann force-pushed the main-clean branch 3 times, most recently from d49b2da to ac49c3f (April 15, 2026 09:50)
@arishofmann
Author

Hi @Isabelle-Wittmann, I worked through your three points:

  • I made the decoder generic and renamed RandomForestDecoder to SklearnDecoder. It now takes any sklearn estimator as a string parameter.
  • I created a separate EmbeddingDecodingTask base class with an EmbeddingClassificationTask subclass. Both inherit directly from BaseTask instead of TerraTorchTask, so no gradient-based training assumptions.
  • For the dataset I kept EmbeddingDataset as a separate class. The workflow loads bulk .pt tensors rather than individual image files, so extending the generic datasets felt forced. Happy to revisit if you see a clean way to integrate it.
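
For readers who have not seen the diff, one way an estimator could be selected from a string is sketched below. This is an assumption about the mechanism, not the actual SklearnDecoder code: the helper name `build_estimator` and the choice of a full dotted class path (rather than a short name) are hypothetical.

```python
import importlib
from typing import Any, Optional


def build_estimator(class_path: str, params: Optional[dict] = None) -> Any:
    """Instantiate a class from a dotted path.

    Example path: "sklearn.ensemble.RandomForestClassifier"
    (that particular path needs scikit-learn to be installed).
    """
    module_name, _, class_name = class_path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'package.module.ClassName', got {class_path!r}")
    # Import the module lazily so only the requested library is loaded.
    module = importlib.import_module(module_name)
    try:
        cls = getattr(module, class_name)
    except AttributeError as err:
        raise ValueError(f"{module_name!r} has no attribute {class_name!r}") from err
    # Forward keyword arguments (for example from a YAML config) to the constructor.
    return cls(**(params or {}))
```

With scikit-learn installed, `build_estimator("sklearn.ensemble.RandomForestClassifier", {"n_estimators": 100})` would return an unfitted random forest, and swapping the string is enough to try another supervised estimator.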

Tested it on CCC with RF on my Finland embeddings and it runs through and produces validation metrics.

@Isabelle-Wittmann
Collaborator

Thanks a lot, looks good! I'll test it tomorrow and then merge :)

