From 463c0ac21503c6472871dfafd2c0b60e9a47f689 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 4 Mar 2026 12:55:01 +0100 Subject: [PATCH 01/60] rename caption folders --- .../v1.json | 0 .../v3.json | 0 .../v4.json | 0 .../v5.json | 0 notebooks/04-TvdP_generate-caption-templates.ipynb | 4 ++-- src/data/base_caption_builder.py | 6 ++++-- tests/conftest.py | 7 +++---- tests/test_captions.py | 13 +++++-------- 8 files changed, 14 insertions(+), 16 deletions(-) rename data/s2bms/{caption_templates => location_caption_templates}/v1.json (100%) rename data/s2bms/{caption_templates => location_caption_templates}/v3.json (100%) rename data/s2bms/{caption_templates => location_caption_templates}/v4.json (100%) rename data/s2bms/{caption_templates => location_caption_templates}/v5.json (100%) diff --git a/data/s2bms/caption_templates/v1.json b/data/s2bms/location_caption_templates/v1.json similarity index 100% rename from data/s2bms/caption_templates/v1.json rename to data/s2bms/location_caption_templates/v1.json diff --git a/data/s2bms/caption_templates/v3.json b/data/s2bms/location_caption_templates/v3.json similarity index 100% rename from data/s2bms/caption_templates/v3.json rename to data/s2bms/location_caption_templates/v3.json diff --git a/data/s2bms/caption_templates/v4.json b/data/s2bms/location_caption_templates/v4.json similarity index 100% rename from data/s2bms/caption_templates/v4.json rename to data/s2bms/location_caption_templates/v4.json diff --git a/data/s2bms/caption_templates/v5.json b/data/s2bms/location_caption_templates/v5.json similarity index 100% rename from data/s2bms/caption_templates/v5.json rename to data/s2bms/location_caption_templates/v5.json diff --git a/notebooks/04-TvdP_generate-caption-templates.ipynb b/notebooks/04-TvdP_generate-caption-templates.ipynb index b5c90f1..39a6ef9 100644 --- a/notebooks/04-TvdP_generate-caption-templates.ipynb +++ b/notebooks/04-TvdP_generate-caption-templates.ipynb @@ -44,11 +44,11 @@ "metadata": {}, 
"outputs": [], "source": [ - "# tmp = cg.generate_captions(n=20, seed=0, save_path=os.path.join(os.environ['DATA_DIR'], 's2bms/caption_templates'))\n", + "# tmp = cg.generate_captions(n=20, seed=0, save_path=os.path.join(os.environ['DATA_DIR'], 's2bms/location_caption_templates'))\n", "tmp = cg.generate_captions(\n", " n=50,\n", " seed=0,\n", - " save_path=os.path.join(os.environ[\"PROJECT_ROOT\"], \"data/s2bms/caption_templates\"),\n", + " save_path=os.path.join(os.environ[\"PROJECT_ROOT\"], \"data/s2bms/location_caption_templates\"),\n", ")" ] } diff --git a/src/data/base_caption_builder.py b/src/data/base_caption_builder.py index 3ca4b77..82de4f0 100644 --- a/src/data/base_caption_builder.py +++ b/src/data/base_caption_builder.py @@ -21,7 +21,7 @@ def __init__(self, templates_fname: str, data_dir: str, seed: int) -> None: """ self.data_dir = data_dir - templates_path = os.path.join(self.data_dir, "caption_templates", templates_fname) + templates_path = os.path.join(self.data_dir, "location_caption_templates", templates_fname) self.templates = json.load(open(templates_path)) self.tokens_in_template = [self._extract_tokens(t) for t in self.templates] @@ -42,7 +42,9 @@ def sync_with_dataset(self, dataset: BaseDataset) -> None: @staticmethod def _extract_tokens(template: str) -> List[str]: """Extract tokens in template and return a list of tokens.""" - return re.findall(r"<([^<>]+)>", template) + tokens = re.findall(r"<([^<>]+)>", template) + # TODO: check if those columns exist in the dataset maps + return tokens @staticmethod def _fill(template: str, fillers: Dict[str, str]) -> str: diff --git a/tests/conftest.py b/tests/conftest.py index e4e9b2d..a3ff6c9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,12 +7,11 @@ import pandas as pd import pytest import rootutils -import torch from hydra import compose, initialize from hydra.core.global_hydra import GlobalHydra from omegaconf import DictConfig, open_dict -from src.data.base_caption_builder import 
BaseCaptionBuilder, DummyCaptionBuilder +from src.data.base_caption_builder import DummyCaptionBuilder from src.data.base_datamodule import BaseDataModule from src.data.butterfly_dataset import ButterflyDataset @@ -165,8 +164,8 @@ def create_butterfly_dataset(request, sample_csv, tmp_path): mock=use_mock, ) - templates_path = tmp_path / "caption_templates" / "v1.json" - os.makedirs(str(tmp_path / "caption_templates"), exist_ok=True) + templates_path = tmp_path / "location_caption_templates" / "v1.json" + os.makedirs(str(tmp_path / "location_caption_templates"), exist_ok=True) print(f"Mock captions written to {templates_path}") templates_path.write_text(json.dumps([" text"])) diff --git a/tests/test_captions.py b/tests/test_captions.py index 66b132f..61a2476 100644 --- a/tests/test_captions.py +++ b/tests/test_captions.py @@ -1,10 +1,7 @@ import json import os -import pandas as pd -import pytest - -from src.data.base_caption_builder import BaseCaptionBuilder, DummyCaptionBuilder +from src.data.base_caption_builder import DummyCaptionBuilder from src.data.base_datamodule import BaseDataModule from src.data.butterfly_caption_builder import ButterflyCaptionBuilder from src.data.butterfly_dataset import ButterflyDataset @@ -12,8 +9,8 @@ def test_datamodule_uses_collate_when_aux_data(request, sample_csv, tmp_path): use_mock = request.config.getoption("--use-mock") - templates_path = tmp_path / "caption_templates" / "v1.json" - os.makedirs(str(tmp_path / "caption_templates"), exist_ok=True) + templates_path = tmp_path / "location_caption_templates" / "v1.json" + os.makedirs(str(tmp_path / "location_caption_templates"), exist_ok=True) print(f"Mock captions written to {templates_path}") templates_path.write_text(json.dumps([" text"])) @@ -57,8 +54,8 @@ def test_captionbuilder_generic_properties(tmp_path): # templates_path = os.path.join(repo_root, "data", "s2bms") # else: templates_path = tmp_path - templates_fpath = templates_path / "caption_templates" / templates_fname - 
os.makedirs(str(templates_path / "caption_templates"), exist_ok=True) + templates_fpath = templates_path / "location_caption_templates" / templates_fname + os.makedirs(str(templates_path / "location_caption_templates"), exist_ok=True) templates_fpath.write_text(json.dumps([" text"])) print(f"written to {templates_path}") From ba10999811e1f0f1d525ba64938157f930f13029 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 4 Mar 2026 12:57:37 +0100 Subject: [PATCH 02/60] Introduce concept captions and validation logic with top-k metrics --- configs/data/butterfly_coords_text.yaml | 1 + .../data/butterfly_full_param_example.yaml | 3 +- src/data/base_caption_builder.py | 12 ++- src/data/base_datamodule.py | 1 + src/data/butterfly_caption_builder.py | 8 +- src/data/collate_fns.py | 3 - src/models/text_alignment_model.py | 93 ++++++++++++++++--- 7 files changed, 101 insertions(+), 20 deletions(-) diff --git a/configs/data/butterfly_coords_text.yaml b/configs/data/butterfly_coords_text.yaml index 92b6bd7..9c98968 100644 --- a/configs/data/butterfly_coords_text.yaml +++ b/configs/data/butterfly_coords_text.yaml @@ -14,6 +14,7 @@ dataset: caption_builder: _target_: src.data.butterfly_caption_builder.ButterflyCaptionBuilder templates_fname: v3.json + concepts_fname: v1.json data_dir: ${paths.data_dir}/s2bms seed: ${seed} diff --git a/configs/data/butterfly_full_param_example.yaml b/configs/data/butterfly_full_param_example.yaml index 541f4c2..78923f5 100644 --- a/configs/data/butterfly_full_param_example.yaml +++ b/configs/data/butterfly_full_param_example.yaml @@ -22,7 +22,8 @@ dataset: caption_builder: _target_: src.data.butterfly_caption_builder.ButterflyCaptionBuilder - templates_fname: caption_templates.json + templates_fname: v3.json + concepts_fname: v1.json data_dir: ${paths.data_dir}/s2bms seed: ${seed} diff --git a/src/data/base_caption_builder.py b/src/data/base_caption_builder.py index 82de4f0..c265a57 100644 --- a/src/data/base_caption_builder.py +++ 
b/src/data/base_caption_builder.py @@ -11,7 +11,9 @@ class BaseCaptionBuilder(ABC): - def __init__(self, templates_fname: str, data_dir: str, seed: int) -> None: + def __init__( + self, templates_fname: str, concepts_fname: str, data_dir: str, seed: int + ) -> None: """Interface of caption builder class for converting numerical auxiliary data values into textual descriptions from provided caption templates. @@ -25,6 +27,9 @@ def __init__(self, templates_fname: str, data_dir: str, seed: int) -> None: self.templates = json.load(open(templates_path)) self.tokens_in_template = [self._extract_tokens(t) for t in self.templates] + concepts_path = os.path.join(self.data_dir, "concept_captions", concepts_fname) + self.concepts = json.load(open(concepts_path)) + self.column_to_metadata_map: Dict[str] | None = None self.seed = seed random.seed(self.seed) @@ -98,8 +103,9 @@ def all(self, aux_values) -> List[str]: return formatted_rows - def build_concepts(self, aux_values) -> List[str]: - pass + def sync_concepts(self) -> List[str]: + for concept in self.concepts: + concept["id"] = self.column_to_metadata_map["aux"][concept["col"]]["id"] class DummyCaptionBuilder(BaseCaptionBuilder): diff --git a/src/data/base_datamodule.py b/src/data/base_datamodule.py index 3b639ec..201c671 100644 --- a/src/data/base_datamodule.py +++ b/src/data/base_datamodule.py @@ -63,6 +63,7 @@ def __init__( assert caption_builder is not None, "Caption_builder cannot be None" self.caption_builder = caption_builder self.caption_builder.sync_with_dataset(self.dataset) + self.concept_configs = caption_builder.concepts self.split_data() diff --git a/src/data/butterfly_caption_builder.py b/src/data/butterfly_caption_builder.py index 66b60a3..d6dee33 100644 --- a/src/data/butterfly_caption_builder.py +++ b/src/data/butterfly_caption_builder.py @@ -16,8 +16,10 @@ class ButterflyCaptionBuilder(BaseCaptionBuilder): - def __init__(self, templates_fname: str, data_dir: str, seed: int): - 
super().__init__(templates_fname, data_dir, seed) + def __init__( + self, templates_fname: str, concepts_fname: str, data_dir: str, seed: int + ) -> None: + super().__init__(templates_fname, concepts_fname, data_dir, seed) @override def sync_with_dataset(self, dataset: BaseDataset) -> None: @@ -42,6 +44,8 @@ def sync_with_dataset(self, dataset: BaseDataset) -> None: "units": units, } + self.sync_concepts() + def get_corine_column_keys(self): """Returns metadata for corine columns.""" if not os.path.isfile(os.path.join(self.data_dir, "corine_classes.csv")): diff --git a/src/data/collate_fns.py b/src/data/collate_fns.py index 4e37844..2c001b4 100644 --- a/src/data/collate_fns.py +++ b/src/data/collate_fns.py @@ -39,10 +39,7 @@ def collate_fn( # convert aux into captions if mode == "train": batch_collected["text"] = caption_builder.random(batch_collected["aux"]) - elif mode == "val": - batch_collected["text"] = caption_builder.all(batch_collected["aux"]) else: batch_collected["text"] = caption_builder.all(batch_collected["aux"]) - # batch_collected['concepts'] = caption_builder.build_concepts(batch_collected["aux"]) return batch_collected diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index 21b476a..bd04a16 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -1,3 +1,4 @@ +from io import text_encoding from typing import Dict, Tuple, override import torch @@ -7,6 +8,7 @@ from src.models.components.eo_encoders.base_eo_encoder import BaseEOEncoder from src.models.components.eo_encoders.multimodal_encoder import MultiModalEncoder from src.models.components.loss_fns.base_loss_fn import BaseLossFn +from src.models.components.metrics.contrastive_validation import ContrastiveValidation from src.models.components.metrics.metrics_wrapper import MetricsWrapper from src.models.components.pred_heads.linear_pred_head import ( BasePredictionHead, @@ -29,6 +31,7 @@ def __init__( num_classes: int | None = 
None, tabular_dim: int | None = None, prediction_head: BasePredictionHead | None = None, + ks: list[int] | None = [5, 10, 15], ) -> None: """Implementation of contrastive text-eo modality alignment model. @@ -47,6 +50,9 @@ def __init__( trainable_modules, optimizer, scheduler, loss_fn, metrics, num_classes, tabular_dim ) + self.ks = ks + self.log_kwargs = dict(on_step=False, on_epoch=True, prog_bar=True, sync_dist=True) + # Encoders configuration self.eo_encoder = eo_encoder # TODO: move to multi-modal eo encoder @@ -81,6 +87,21 @@ def __init__( # Freezing requested parts self.freezer() + def setup(self, stage: str) -> None: + self.concept_configs = self.trainer.datamodule.concept_configs + self.concepts = [c["concept_caption"] for c in self.concept_configs] + + self.contrastive_val = ContrastiveValidation(self.ks, self.concept_configs) + self.outputs_epoch_memory = [] + + for trainable_module in self.trainable_modules: + if "text" in trainable_module: + self.concept_embeds = None + return + + # Encode concepts if text branch is frozen + self.concept_embeds = self.text_encoder({"text": self.concepts}, mode="train") + @override def forward( self, @@ -95,7 +116,7 @@ def forward( return eo_feats, text_feats @override - def _step(self, batch: Dict[str, torch.Tensor], mode: str = "train") -> torch.Tensor: + def _step(self, batch: Dict[str, torch.Tensor], mode: str = "train"): """Model step logic.""" # Embed @@ -122,18 +143,68 @@ def _step(self, batch: Dict[str, torch.Tensor], mode: str = "train") -> torch.Te ) # Logging - log_kwargs = dict( - on_step=False, - on_epoch=True, - prog_bar=True, - sync_dist=True, - batch_size=local_batch_size, - ) - self.log(f"{mode}_loss", loss, **log_kwargs) + + self.log(f"{mode}_loss", loss, batch_size=local_batch_size, **self.log_kwargs) if self.loss_fn.__getattr__("log_temp") and mode == "train": - self.log("temp", self.loss_fn.__getattr__("log_temp").exp(), **log_kwargs) + self.log( + "temp", + 
self.loss_fn.__getattr__("log_temp").exp(), +                 batch_size=local_batch_size, +                 **self.log_kwargs, +             ) + +         self.log_dict(metrics, batch_size=local_batch_size, **self.log_kwargs) -         self.log_dict(metrics, **log_kwargs) +         if mode == "val": +             self.outputs_epoch_memory.append( +                 { +                     "eo_feats": eo_feats.detach(), +                     "aux_vals": batch.get("aux", {}).get("aux").detach(), +                 } +             )          return loss + +     @override +     def on_validation_epoch_end(self): + +         # Combine batches +         eo_feats = torch.cat([x["eo_feats"] for x in self.outputs_epoch_memory], dim=0) + +         aux_vals = torch.cat([x["aux_vals"] for x in self.outputs_epoch_memory], dim=0) + +         # Rank on similarity +         similarity = self.concept_similarities(eo_feats) + +         avr_scores, concept_scores = self.contrastive_val(similarity, aux_values=aux_vals) + +         self.log_dict(avr_scores) +         for i, result in enumerate(concept_scores): +             print(f'\nConcept "{self.concepts[i]}" average top-k accuracies:') +             for k, v in result.items(): +                 print(f"Top-{k}: {v:.1f}%") + +         # Reset memory +         self.outputs_epoch_memory.clear() + +     def concept_similarities(self, eo_embeds, concept=None) -> torch.Tensor: +         # Get concept embeddings +         if concept is not None: +             # If only one concept is provided +             if isinstance(concept, str): +                 concept = [concept] + +             concept_embeds = self.text_encoder({"text": concept}, mode="train") + +         elif self.concept_embeds is not None: +             concept_embeds = self.concept_embeds +         else: +             concept_embeds = self.text_encoder({"text": self.concepts}, mode="train") + +         # Similarity +         eo_embeds = F.normalize(eo_embeds, dim=1) +         concept_embeds = F.normalize(concept_embeds, dim=1) +         similarity_matrix = concept_embeds @ eo_embeds.T + +         return similarity_matrix From ed48c48e08573048b3896510a55fa57bd93c5b5e Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 4 Mar 2026 13:01:22 +0100 Subject: [PATCH 03/60] Add contrastive concept eval for test split --- src/models/text_alignment_model.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff 
--git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index bd04a16..0ea3ac1 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -156,7 +156,7 @@ def _step(self, batch: Dict[str, torch.Tensor], mode: str = "train"): self.log_dict(metrics, batch_size=local_batch_size, **self.log_kwargs) - if mode == "val": + if mode in ["val", "test"]: self.outputs_epoch_memory.append( { "eo_feats": eo_feats.detach(), @@ -167,7 +167,7 @@ def _step(self, batch: Dict[str, torch.Tensor], mode: str = "train"): return loss @override - def on_validation_epoch_end(self): + def on_x_epoch_end(self): # Combine batches eo_feats = torch.cat([x["eo_feats"] for x in self.outputs_epoch_memory], dim=0) @@ -188,6 +188,14 @@ def on_validation_epoch_end(self): # Reset memory self.outputs_epoch_memory.clear() + @override + def on_validation_epoch_end(self): + return self.on_x_epoch_end() + + @override + def on_test_epoch_end(self): + return self.on_x_epoch_end() + def concept_similarities(self, eo_embeds, concept=None) -> torch.Tensor: # Get concept embeddings if concept is not None: From f486bc493841aabb26f154daa7458396401457b9 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 4 Mar 2026 13:27:47 +0100 Subject: [PATCH 04/60] add concept caption v1 --- data/s2bms/concept_captions/v1.json | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 data/s2bms/concept_captions/v1.json diff --git a/data/s2bms/concept_captions/v1.json b/data/s2bms/concept_captions/v1.json new file mode 100644 index 0000000..1a09fbf --- /dev/null +++ b/data/s2bms/concept_captions/v1.json @@ -0,0 +1,14 @@ +[ + { + "concept_caption": "Forested area", + "is_max": true, + "theta_k": 0.5, + "col": "aux_corine_frac_311" + }, + { + "concept_caption": "Sparsely populated area", + "is_max": false, + "theta_k": 0.2, + "col": "aux_corine_frac_111" + } +] From 4f9467a185a5bea564c9cbba0477141642a57005 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 4 
Mar 2026 14:22:43 +0100 Subject: [PATCH 05/60] add contrastive concept top-k metrics --- .../metrics/contrastive_validation.py | 93 +++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 src/models/components/metrics/contrastive_validation.py diff --git a/src/models/components/metrics/contrastive_validation.py b/src/models/components/metrics/contrastive_validation.py new file mode 100644 index 0000000..3dfb2fa --- /dev/null +++ b/src/models/components/metrics/contrastive_validation.py @@ -0,0 +1,93 @@ +from typing import Any, Dict, List, override + +import torch + +from src.models.components.metrics.base_metrics import BaseMetrics + + +class ContrastiveValidation(BaseMetrics): +    def __init__(self, ks: List[Any], concept_configs: List[Any]) -> None: +        """Evaluates how many eo embeddings are retrieved in top-k metrics based on the GT labels. + +        :param ks: k values for top-k metrics +        :param concept_configs: concept configurations containing details about min/max mode, which +            aux_col to use as GT. 
+        """ +        super().__init__() + +        self.concept_configs = concept_configs + +        self.ks = ks +        if any("theta_k" in c for c in self.concept_configs): +            self.ks.append("theta_k") + +    @override +    def forward( +        self, +        similarity_matrix: torch.Tensor, +        aux_values: torch.Tensor, +        **kwargs, +    ) -> torch.Tensor | Dict[str, torch.Tensor]: +        """Calculates top-k metrics based on the GT (aux-derived) labels.""" + +        aux_vals = aux_values.T + +        avr_scores = {k: [] for k in self.ks} +        concept_scores = [] +        for i, configs in enumerate(self.concept_configs): +            idx = configs["id"] +            is_max = configs["is_max"] +            theta_k = configs.get("theta_k") +            if theta_k and aux_vals is not None: +                aux_val = aux_vals[idx] +                theta_k = ( +                    sum(aux_val >= theta_k).item() if is_max else sum(aux_val <= theta_k).item() +                ) + +            score = self.topk_rank_agreement( +                aux_val, similarity_matrix[i], self.ks, is_max, theta_k +            ) +            concept_scores.append(score) +            for k, v in score.items(): +                avr_scores[k].append(v) + +        return_scores = {} +        for k, v in avr_scores.items(): +            return_scores[f"avr_top-{k}"] = sum(v) / len(v) + +        return return_scores, concept_scores + +    @staticmethod +    def topk_rank_agreement(gt_vals, pred_vals, ks, is_max=True, theta_k=None): +        """Get how many of the top-k concept retrievals are predicted correctly.""" +        num_candidates = len(gt_vals) + +        gt_order = torch.argsort(gt_vals, descending=True) +        pred_order = torch.argsort(pred_vals, descending=True) + +        gt_rank_pos = torch.empty_like(gt_order) +        gt_rank_pos[gt_order] = torch.arange(num_candidates, device=gt_order.device) + +        pred_rank_pos = torch.empty_like(pred_order) +        pred_rank_pos[pred_order] = torch.arange(num_candidates, device=pred_order.device) + +        results = {} + +        for k in ks: +            k_key = k +            if k == "theta_k": +                if theta_k != 0: +                    k = theta_k +                else: +                    continue + +            if is_max: +                gt_mask = gt_rank_pos < k +                pred_mask = pred_rank_pos < k +            else: +                k_inverted = num_candidates - k +                gt_mask = gt_rank_pos >= k_inverted +                pred_mask = pred_rank_pos 
>= k_inverted + results[k_key] = (gt_mask & pred_mask).sum().item() / k * 100 + + return results From cd9b840322b52dae05639dad787d0d0140e78f90 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 4 Mar 2026 14:26:00 +0100 Subject: [PATCH 06/60] test fixes --- src/data/base_caption_builder.py | 6 ++-- src/data/butterfly_dataset.py | 4 +-- src/models/text_alignment_model.py | 1 + tests/conftest.py | 15 ++++++++++ tests/test_captions.py | 46 +++++++++++++++++++++++++----- 5 files changed, 61 insertions(+), 11 deletions(-) diff --git a/src/data/base_caption_builder.py b/src/data/base_caption_builder.py index c265a57..54a020f 100644 --- a/src/data/base_caption_builder.py +++ b/src/data/base_caption_builder.py @@ -111,8 +111,10 @@ def sync_concepts(self) -> List[str]: class DummyCaptionBuilder(BaseCaptionBuilder): """Dummy caption builder for testing purposes.""" - def __init__(self, templates_fname: str, data_dir: str, seed: int) -> None: - super().__init__(templates_fname, data_dir, seed) + def __init__( + self, templates_fname: str, concepts_fname: str, data_dir: str, seed: int + ) -> None: + super().__init__(templates_fname, concepts_fname, data_dir, seed) def sync_with_dataset(self, dataset) -> None: pass diff --git a/src/data/butterfly_dataset.py b/src/data/butterfly_dataset.py index 422f8c2..5ecacf7 100644 --- a/src/data/butterfly_dataset.py +++ b/src/data/butterfly_dataset.py @@ -17,7 +17,7 @@ def __init__( data_dir: str, modalities: dict, use_target_data: bool = True, - use_aux_data: Dict[str, List[str] | str] | None = None, + use_aux_data: Any = None, seed: int = 12345, cache_dir: str = None, mock: bool = False, @@ -28,7 +28,7 @@ def __init__( :param modalities: a dict of modalities needed as EO data (for EO encoder) (e.g., {"coords": None, "s2": {"channels": "rgb", "preprocessing": "zscored"}}) :param use_target_data: if target values should be returned - :param use_aux_data: if auxiliary values should be returned + :param use_aux_data: which (if any) 
auxiliary values should be returned :param seed: random seed :param cache_dir: path to cache dir :param mock: whether to mock csv file diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index 0ea3ac1..31c1bdc 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -178,6 +178,7 @@ def on_x_epoch_end(self): similarity = self.concept_similarities(eo_feats) avr_scores, concept_scores = self.contrastive_val(similarity, aux_values=aux_vals) + # TODO pearson self.log_dict(avr_scores) for i, result in enumerate(concept_scores): diff --git a/tests/conftest.py b/tests/conftest.py index a3ff6c9..e4dcbf4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -169,8 +169,23 @@ def create_butterfly_dataset(request, sample_csv, tmp_path): print(f"Mock captions written to {templates_path}") templates_path.write_text(json.dumps([" text"])) + concepts_path = tmp_path / "concept_captions" / "v1.json" + os.makedirs(str(tmp_path / "concept_captions"), exist_ok=True) + print(f"Concept captions written to {concepts_path}") + concepts_path.write_text( + json.dumps( + """[{ + "concept_caption": "Forested area", + "is_max": true, + "theta_k": 0.5, + "col": "aux_corine_frac_311" + }]""" + ) + ) + caption_builder = DummyCaptionBuilder( templates_fname="v1.json", + concepts_fname="v1.json", data_dir=str(tmp_path), seed=0, ) diff --git a/tests/test_captions.py b/tests/test_captions.py index 61a2476..85951dd 100644 --- a/tests/test_captions.py +++ b/tests/test_captions.py @@ -14,7 +14,21 @@ def test_datamodule_uses_collate_when_aux_data(request, sample_csv, tmp_path): print(f"Mock captions written to {templates_path}") templates_path.write_text(json.dumps([" text"])) - caption_builder = DummyCaptionBuilder("v1.json", data_dir=str(tmp_path), seed=0) + concepts_path = tmp_path / "concept_captions" / "v1.json" + os.makedirs(str(tmp_path / "concept_captions"), exist_ok=True) + print(f"Concept captions written to 
{concepts_path}") + concepts_path.write_text( + json.dumps( + """[{ + "concept_caption": "Forested area", + "is_max": true, + "theta_k": 0.5, + "col": "aux_corine_frac_311" + }]""" + ) + ) + + caption_builder = DummyCaptionBuilder("v1.json", "v1.json", data_dir=str(tmp_path), seed=0) dataset = ButterflyDataset( data_dir=sample_csv, @@ -46,6 +60,7 @@ def test_captionbuilder_generic_properties(tmp_path): dict_caption_builders = {"butterfly": ButterflyCaptionBuilder, "dummy": DummyCaptionBuilder} templates_fname = "v1.json" + concepts_fname = "v1.json" for name_cb, cb_class in dict_caption_builders.items(): # There is no data on git anymore @@ -53,21 +68,38 @@ def test_captionbuilder_generic_properties(tmp_path): # repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) # templates_path = os.path.join(repo_root, "data", "s2bms") # else: - templates_path = tmp_path - templates_fpath = templates_path / "location_caption_templates" / templates_fname - os.makedirs(str(templates_path / "location_caption_templates"), exist_ok=True) - templates_fpath.write_text(json.dumps([" text"])) - print(f"written to {templates_path}") + templates_path = tmp_path / "location_caption_templates" / templates_fname + os.makedirs(str(tmp_path / "location_caption_templates"), exist_ok=True) + print(f"Mock captions written to {templates_path}") + templates_path.write_text(json.dumps([" text"])) + + concepts_path = tmp_path / "concept_captions" / concepts_fname + os.makedirs(str(tmp_path / "concept_captions"), exist_ok=True) + print(f"Concept captions written to {concepts_path}") + concepts_path.write_text( + json.dumps( + """[{ + "concept_caption": "Forested area", + "is_max": true, + "theta_k": 0.5, + "col": "aux_corine_frac_311" + }]""" + ) + ) caption_builder = cb_class( templates_fname=templates_fname, - data_dir=templates_path, + concepts_fname=concepts_fname, + data_dir=tmp_path, seed=0, ) assert hasattr( caption_builder, "templates" ), f"'templates' attribute missing 
in {cb_class.__name__}." + assert hasattr( + caption_builder, "concepts" + ), f"'concepts' attribute missing in {cb_class.__name__}." assert hasattr( caption_builder, "data_dir" ), f"'data_dir' attribute missing in {cb_class.__name__}." From 0aaa7e05cfb8465a782ee7f0a0d05a5da66a95c8 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 5 Mar 2026 08:55:42 +0100 Subject: [PATCH 07/60] Change column ids --- data/s2bms/concept_captions/v1.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/data/s2bms/concept_captions/v1.json b/data/s2bms/concept_captions/v1.json index 1a09fbf..323089f 100644 --- a/data/s2bms/concept_captions/v1.json +++ b/data/s2bms/concept_captions/v1.json @@ -3,12 +3,12 @@ "concept_caption": "Forested area", "is_max": true, "theta_k": 0.5, - "col": "aux_corine_frac_311" + "col": "aux_corine_frac_3" }, { "concept_caption": "Sparsely populated area", "is_max": false, "theta_k": 0.2, - "col": "aux_corine_frac_111" + "col": "aux_corine_frac_11" } ] From b02edd640beeaea4f9f223797a15ea5e51030943 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 5 Mar 2026 09:37:13 +0100 Subject: [PATCH 08/60] Move auc out of if statement and rename k_threshold --- .../components/metrics/contrastive_validation.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/models/components/metrics/contrastive_validation.py b/src/models/components/metrics/contrastive_validation.py index 3dfb2fa..9886aaf 100644 --- a/src/models/components/metrics/contrastive_validation.py +++ b/src/models/components/metrics/contrastive_validation.py @@ -37,15 +37,18 @@ def forward( for i, configs in enumerate(self.concept_configs): idx = configs["id"] is_max = configs["is_max"] - theta_k = configs.get("theta_k") - if theta_k and aux_vals is not None: - aux_val = aux_vals[idx] - theta_k = ( - sum(aux_val >= theta_k).item() if is_max else sum(aux_val <= theta_k).item() + k_threshold = configs.get("theta_k") + aux_val = aux_vals[idx] + + if 
k_threshold: + k_threshold = ( + sum(aux_val >= k_threshold).item() + if is_max + else sum(aux_val <= k_threshold).item() ) score = self.topk_rank_agreement( - aux_val, similarity_matrix[i], self.ks, is_max, theta_k + aux_val, similarity_matrix[i], self.ks, is_max, k_threshold ) concept_scores.append(score) for k, v in score.items(): From 4bf4f0f360f43e7a8649d03f0592ff1849e5ae7e Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 5 Mar 2026 09:38:30 +0100 Subject: [PATCH 09/60] Rename k_threshold --- .../components/metrics/contrastive_validation.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/models/components/metrics/contrastive_validation.py b/src/models/components/metrics/contrastive_validation.py index 9886aaf..dc5be7b 100644 --- a/src/models/components/metrics/contrastive_validation.py +++ b/src/models/components/metrics/contrastive_validation.py @@ -19,7 +19,7 @@ def __init__(self, ks: List[Any], concept_configs: List[Any]) -> None: self.ks = ks if any("theta_k" in c for c in self.concept_configs): - self.ks.append("theta_k") + self.ks.append("dynamic_k") @override def forward( @@ -61,7 +61,7 @@ def forward( return return_scores, concept_scores @staticmethod - def topk_rank_agreement(gt_vals, pred_vals, ks, is_max=True, theta_k=None): + def topk_rank_agreement(gt_vals, pred_vals, ks, is_max=True, dynamic_k=None): """Get how much of top-k concept retrievals are predicted correctly.""" num_candidates = len(gt_vals) @@ -78,9 +78,9 @@ def topk_rank_agreement(gt_vals, pred_vals, ks, is_max=True, theta_k=None): for k in ks: k_key = k - if k == "theta_k": - if theta_k != 0: - k = theta_k + if k == "dynamic_k": + if dynamic_k != 0: + k = dynamic_k else: continue From ddf30fc456a90975f259933ee381f3d7f83194da Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 5 Mar 2026 09:42:24 +0100 Subject: [PATCH 10/60] Add no_grad for text embedding --- src/models/text_alignment_model.py | 19 ++++++++++--------- 1 file changed, 10 
insertions(+), 9 deletions(-) diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index 31c1bdc..3594ec3 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -100,7 +100,8 @@ def setup(self, stage: str) -> None: return # Encode concepts if text branch is frozen - self.concept_embeds = self.text_encoder({"text": self.concepts}, mode="train") + with torch.no_grad(): + self.concept_embeds = self.text_encoder({"text": self.concepts}, mode="train") @override def forward( @@ -143,7 +144,6 @@ def _step(self, batch: Dict[str, torch.Tensor], mode: str = "train"): ) # Logging - self.log(f"{mode}_loss", loss, batch_size=local_batch_size, **self.log_kwargs) if self.loss_fn.__getattr__("log_temp") and mode == "train": @@ -167,7 +167,7 @@ def _step(self, batch: Dict[str, torch.Tensor], mode: str = "train"): return loss @override - def on_x_epoch_end(self): + def _on_epoch_end(self, mode: str): # Combine batches eo_feats = torch.cat([x["eo_feats"] for x in self.outputs_epoch_memory], dim=0) @@ -182,7 +182,7 @@ def on_x_epoch_end(self): self.log_dict(avr_scores) for i, result in enumerate(concept_scores): - print(f'\nConcept "{self.concepts[i]}" average top-k accuracies:') + print(f'\nConcept "{self.concepts[i]}" average top-k accuracies in {mode} split:') for k, v in result.items(): print(f"Top-{k}: {v:.1f}%") @@ -191,11 +191,11 @@ def on_x_epoch_end(self): @override def on_validation_epoch_end(self): - return self.on_x_epoch_end() + return self._on_epoch_end("val") @override def on_test_epoch_end(self): - return self.on_x_epoch_end() + return self._on_epoch_end("test") def concept_similarities(self, eo_embeds, concept=None) -> torch.Tensor: # Get concept embeddings @@ -203,13 +203,14 @@ def concept_similarities(self, eo_embeds, concept=None) -> torch.Tensor: # If only one concept is provided if isinstance(concept, str): concept = [concept] - - concept_embeds = self.text_encoder({"text": concept}, 
mode="train") + with torch.no_grad(): + concept_embeds = self.text_encoder({"text": concept}, mode="train") elif self.concept_embeds is not None: concept_embeds = self.concept_embeds else: - concept_embeds = self.text_encoder({"text": self.concepts}, mode="train") + with torch.no_grad(): + concept_embeds = self.text_encoder({"text": self.concepts}, mode="train") # Similarity eo_embeds = F.normalize(eo_embeds, dim=1) From c19a3049c3498946545037ed7c979aa3fdba3af0 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 5 Mar 2026 10:23:26 +0100 Subject: [PATCH 11/60] Move data bound model configuration into setup methods --- src/models/base_model.py | 7 ++--- src/models/predictive_model.py | 32 ++++++++++++++-------- src/models/text_alignment_model.py | 44 +++++++++++++++++++----------- src/train.py | 4 +-- 4 files changed, 51 insertions(+), 36 deletions(-) diff --git a/src/models/base_model.py b/src/models/base_model.py index 407d08a..0a8f259 100644 --- a/src/models/base_model.py +++ b/src/models/base_model.py @@ -16,8 +16,6 @@ def __init__( scheduler: torch.optim.lr_scheduler, loss_fn: BaseLossFn, metrics: MetricsWrapper, - num_classes: int | None = None, - tabular_dim: int | None = None, ) -> None: """Interface for any model. 
@@ -26,7 +24,6 @@ def __init__( :param scheduler: scheduler for the model weight update :param loss_fn: loss function :param metrics: metrics to track for model performance estimation - :param num_classes: number of classes to predict """ super().__init__() self.save_hyperparameters( @@ -34,8 +31,8 @@ def __init__( ) self.trainable_modules = trainable_modules - self.num_classes = num_classes - self.tabular_dim = tabular_dim + self.num_classes: int | None = None + self.tabular_dim: int | None = None self.loss_fn = loss_fn self.metrics = metrics diff --git a/src/models/predictive_model.py b/src/models/predictive_model.py index f46a4dc..49044d1 100644 --- a/src/models/predictive_model.py +++ b/src/models/predictive_model.py @@ -23,8 +23,6 @@ def __init__( scheduler: torch.optim.lr_scheduler, loss_fn: BaseLossFn, metrics: MetricsWrapper, - num_classes: int | None = None, - tabular_dim: int | None = None, ) -> None: """Implementation of the predictive model with replaceable EO encoder, and prediction head. 
@@ -39,31 +37,41 @@ def __init__( :param tabular_dim: number of tabular features """ - super().__init__( - trainable_modules, optimizer, scheduler, loss_fn, metrics, num_classes, tabular_dim - ) + super().__init__(trainable_modules, optimizer, scheduler, loss_fn, metrics) # EO encoder configuration self.eo_encoder = eo_encoder + # Prediction head + self.prediction_head = prediction_head + + def setup(self, stage: str) -> None: + """Setup the predictive model.""" + self.num_classes = self.trainer.datamodule.num_classes + self.tabular_dim = self.trainer.datamodule.tabular_dim + + self.setup_encoders_adapters() + + # Freezing requested parts + self.freezer() + + def setup_encoders_adapters(self): + """Set up encoders and missing adapters/projectors.""" # TODO: move to multi-modal eo encoder if ( isinstance(self.eo_encoder, MultiModalEncoder) and self.eo_encoder.use_tabular and not self.eo_encoder._tabular_ready ): - self.eo_encoder.build_tabular_branch(tabular_dim) + self.eo_encoder.build_tabular_branch(self.tabular_dim) - # Prediction head - self.prediction_head = prediction_head - self.prediction_head.set_dim(input_dim=self.eo_encoder.output_dim, output_dim=num_classes) + self.prediction_head.set_dim( + input_dim=self.eo_encoder.output_dim, output_dim=self.num_classes + ) self.prediction_head.configure_nn() if "prediction_head" not in self.trainable_modules: self.trainable_modules.append("prediction_head") - # Freezing requested parts - self.freezer() - @override def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: feats = self.eo_encoder(batch) diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index 3594ec3..2447cfd 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -28,8 +28,6 @@ def __init__( loss_fn: BaseLossFn, trainable_modules: list[str], metrics: MetricsWrapper, - num_classes: int | None = None, - tabular_dim: int | None = None, prediction_head: BasePredictionHead | 
None = None, ks: list[int] | None = [5, 10, 15], ) -> None: @@ -46,36 +44,53 @@ def __init__( :param tabular_dim: number of tabular features :param prediction_head: prediction head """ - super().__init__( - trainable_modules, optimizer, scheduler, loss_fn, metrics, num_classes, tabular_dim - ) + super().__init__(trainable_modules, optimizer, scheduler, loss_fn, metrics) self.ks = ks self.log_kwargs = dict(on_step=False, on_epoch=True, prog_bar=True, sync_dist=True) # Encoders configuration self.eo_encoder = eo_encoder + self.text_encoder = text_encoder + + # Prediction head + self.prediction_head = prediction_head + + def setup(self, stage: str) -> None: + """Configures model based on the parameters provided by dataset (through datamodule) This + method is called after trainer is initialized and datamodule is available.""" + self.num_classes = self.trainer.datamodule.num_classes + self.tabular_dim = self.trainer.datamodule.tabular_dim + + # Set up encoders and missing adapters/projectors + self.setup_encoders_adapters() + + # Freeze requested parts + self.freezer() + + # Configure contrastive retrieval evaluation + self.setup_retrieval_evaluation() + + def setup_encoders_adapters(self): + """Set up encoders and missing adapters/projectors.""" # TODO: move to multi-modal eo encoder if ( isinstance(self.eo_encoder, MultiModalEncoder) and self.eo_encoder.use_tabular and not self.eo_encoder._tabular_ready ): - self.eo_encoder.build_tabular_branch(tabular_dim) - - self.text_encoder = text_encoder - # TODO: if eo==geoclip_img pass on shared mlp + self.eo_encoder.build_tabular_branch(self.tabular_dim) # Extra projector for text encoder if eo and text dim not match if self.eo_encoder.output_dim != self.text_encoder.output_dim: self.text_encoder.add_projector(projected_dim=self.eo_encoder.output_dim) self.trainable_modules.append("text_encoder.extra_projector") - # Prediction head - self.prediction_head = prediction_head + # TODO: if eo==geoclip_img pass on shared mlp + if 
self.prediction_head is not None: self.prediction_head.set_dim( - input_dim=self.eo_encoder.output_dim, output_dim=num_classes + input_dim=self.eo_encoder.output_dim, output_dim=self.num_classes ) self.prediction_head.configure_nn() @@ -84,10 +99,7 @@ def __init__( self.eo_encoder = self.eo_encoder.to(self.text_encoder.dtype) print(f"Eo encoder dtype changed to {self.eo_encoder.dtype}") - # Freezing requested parts - self.freezer() - - def setup(self, stage: str) -> None: + def setup_retrieval_evaluation(self): self.concept_configs = self.trainer.datamodule.concept_configs self.concepts = [c["concept_caption"] for c in self.concept_configs] diff --git a/src/train.py b/src/train.py index 34f4a99..8348fbb 100644 --- a/src/train.py +++ b/src/train.py @@ -51,9 +51,7 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: datamodule: BaseDataModule = hydra.utils.instantiate(cfg.data) log.info(f"Instantiating model <{cfg.model._target_}>") - model: LightningModule = hydra.utils.instantiate( - cfg.model, num_classes=datamodule.num_classes, tabular_dim=datamodule.tabular_dim - ) + model: LightningModule = hydra.utils.instantiate(cfg.model) log.info("Instantiating callbacks...") callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks")) From a529674045ba64ccb5c5546880e27e91f611653e Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 5 Mar 2026 11:21:08 +0100 Subject: [PATCH 12/60] return dic of scores and average them later --- .../metrics/contrastive_validation.py | 17 ++++++++--------- src/models/text_alignment_model.py | 18 +++++++++++++----- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/src/models/components/metrics/contrastive_validation.py b/src/models/components/metrics/contrastive_validation.py index dc5be7b..bf3ab5a 100644 --- a/src/models/components/metrics/contrastive_validation.py +++ b/src/models/components/metrics/contrastive_validation.py @@ -5,7 +5,7 @@ from src.models.components.metrics.base_metrics import 
BaseMetrics -class ContrastiveValidation(BaseMetrics): +class RetrievalContrastiveValidation(BaseMetrics): def __init__(self, ks: List[Any], concept_configs: List[Any]) -> None: """Evaluates how many eo embeddings are retrieved in top-k metrics based the GT labels. @@ -32,9 +32,10 @@ def forward( aux_vals = aux_values.T - avr_scores = {k: [] for k in self.ks} - concept_scores = [] + concept_scores = {} for i, configs in enumerate(self.concept_configs): + avr_scores = {k: [] for k in self.ks} + idx = configs["id"] is_max = configs["is_max"] k_threshold = configs.get("theta_k") @@ -50,15 +51,13 @@ def forward( score = self.topk_rank_agreement( aux_val, similarity_matrix[i], self.ks, is_max, k_threshold ) - concept_scores.append(score) + for k, v in score.items(): - avr_scores[k].append(v) + avr_scores[k] = v - return_scores = {} - for k, v in avr_scores.items(): - return_scores[f"avr_top-{k}"] = sum(v) / len(v) + concept_scores[i] = avr_scores - return return_scores, concept_scores + return concept_scores @staticmethod def topk_rank_agreement(gt_vals, pred_vals, ks, is_max=True, dynamic_k=None): diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index 3594ec3..a476f37 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -8,7 +8,9 @@ from src.models.components.eo_encoders.base_eo_encoder import BaseEOEncoder from src.models.components.eo_encoders.multimodal_encoder import MultiModalEncoder from src.models.components.loss_fns.base_loss_fn import BaseLossFn -from src.models.components.metrics.contrastive_validation import ContrastiveValidation +from src.models.components.metrics.contrastive_validation import ( + RetrievalContrastiveValidation, +) from src.models.components.metrics.metrics_wrapper import MetricsWrapper from src.models.components.pred_heads.linear_pred_head import ( BasePredictionHead, @@ -91,7 +93,7 @@ def setup(self, stage: str) -> None: self.concept_configs = 
self.trainer.datamodule.concept_configs self.concepts = [c["concept_caption"] for c in self.concept_configs] - self.contrastive_val = ContrastiveValidation(self.ks, self.concept_configs) + self.contrastive_val = RetrievalContrastiveValidation(self.ks, self.concept_configs) self.outputs_epoch_memory = [] for trainable_module in self.trainable_modules: @@ -177,14 +179,20 @@ def _on_epoch_end(self, mode: str): # Rank on similarity similarity = self.concept_similarities(eo_feats) - avr_scores, concept_scores = self.contrastive_val(similarity, aux_values=aux_vals) + concept_scores = self.contrastive_val(similarity, aux_values=aux_vals) # TODO pearson - self.log_dict(avr_scores) - for i, result in enumerate(concept_scores): + avr_scores = {f"{mode}_avr_top-{k}": [] for k in self.ks} + for i, result in concept_scores.items(): print(f'\nConcept "{self.concepts[i]}" average top-k accuracies in {mode} split:') for k, v in result.items(): print(f"Top-{k}: {v:.1f}%") + avr_scores[f"{mode}_avr_top-{k}"].append(v) + + for k, v in avr_scores.items(): + avr_scores[k] = sum(v) / len(v) + + self.log_dict(avr_scores) # Reset memory self.outputs_epoch_memory.clear() From 339ef885f143caf8393391811d4f8caed9875b62 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 5 Mar 2026 11:52:45 +0100 Subject: [PATCH 13/60] Fix tests --- tests/test_configs.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_configs.py b/tests/test_configs.py index 9bb9c21..e34abf6 100644 --- a/tests/test_configs.py +++ b/tests/test_configs.py @@ -15,8 +15,7 @@ def test_train_config(cfg_train: DictConfig) -> None: HydraConfig().set_config(cfg_train) - datamodule = hydra.utils.instantiate(cfg_train.data) - hydra.utils.instantiate(cfg_train.model, num_classes=datamodule.num_classes) + hydra.utils.instantiate(cfg_train.model) hydra.utils.instantiate(cfg_train.trainer) From a4a164e8233447385f2b49cfbecde96b79128d27 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 5 Mar 2026 11:59:53 
+0100 Subject: [PATCH 14/60] Resolve threshold_k naming confusion and remove redundant dictionary --- .../components/metrics/contrastive_validation.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/models/components/metrics/contrastive_validation.py b/src/models/components/metrics/contrastive_validation.py index bf3ab5a..ce2385a 100644 --- a/src/models/components/metrics/contrastive_validation.py +++ b/src/models/components/metrics/contrastive_validation.py @@ -34,28 +34,24 @@ def forward( concept_scores = {} for i, configs in enumerate(self.concept_configs): - avr_scores = {k: [] for k in self.ks} - idx = configs["id"] is_max = configs["is_max"] k_threshold = configs.get("theta_k") aux_val = aux_vals[idx] if k_threshold: - k_threshold = ( + dynamic_k = ( sum(aux_val >= k_threshold).item() if is_max else sum(aux_val <= k_threshold).item() ) + else: + dynamic_k = None - score = self.topk_rank_agreement( - aux_val, similarity_matrix[i], self.ks, is_max, k_threshold - ) - - for k, v in score.items(): - avr_scores[k] = v + sim_val = similarity_matrix[i] + scores = self.topk_rank_agreement(aux_val, sim_val, self.ks, is_max, dynamic_k) - concept_scores[i] = avr_scores + concept_scores[i] = scores return concept_scores From a8da730659c2b651cd3057d26b0f8942d5cb4507 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 5 Mar 2026 13:35:21 +0100 Subject: [PATCH 15/60] Add setup to basemodel --- src/models/base_model.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/models/base_model.py b/src/models/base_model.py index 0a8f259..0a0aeaf 100644 --- a/src/models/base_model.py +++ b/src/models/base_model.py @@ -36,6 +36,11 @@ def __init__( self.loss_fn = loss_fn self.metrics = metrics + @abstractmethod + def setup(self, stage: str) -> None: + """Updates model based data-bound configurations.""" + pass + @final def freezer(self) -> None: """Freezes modules based on provided trainable modules.""" From 
c04cc06d1d7535deac319aef3267fc4af27a6a86 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 5 Mar 2026 13:36:45 +0100 Subject: [PATCH 16/60] abstracting setup method --- src/models/base_model.py | 3 ++- src/models/predictive_model.py | 2 +- src/models/text_alignment_model.py | 3 +-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/models/base_model.py b/src/models/base_model.py index 0a0aeaf..e3a2a16 100644 --- a/src/models/base_model.py +++ b/src/models/base_model.py @@ -38,7 +38,8 @@ def __init__( @abstractmethod def setup(self, stage: str) -> None: - """Updates model based data-bound configurations.""" + """Updates the model based on data-bound configurations (through the datamodule). This method is + called after the trainer is initialized and the datamodule is available.""" pass @final diff --git a/src/models/predictive_model.py b/src/models/predictive_model.py index 49044d1..a0b4212 100644 --- a/src/models/predictive_model.py +++ b/src/models/predictive_model.py @@ -45,8 +45,8 @@ def __init__( # Prediction head self.prediction_head = prediction_head + @override def setup(self, stage: str) -> None: - """Setup the predictive model.""" self.num_classes = self.trainer.datamodule.num_classes self.tabular_dim = self.trainer.datamodule.tabular_dim diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index 2447cfd..9bc1ad5 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -56,9 +56,8 @@ def __init__( # Prediction head self.prediction_head = prediction_head + @override def setup(self, stage: str) -> None: - """Configures model based on the parameters provided by dataset (through datamodule) This - method is called after trainer is initialized and datamodule is available.""" self.num_classes = self.trainer.datamodule.num_classes self.tabular_dim = self.trainer.datamodule.tabular_dim From 2f118bd23c25b2c3b333e638ea614e452dc8d1b3 Mon Sep 17 00:00:00 2001 From: Thijs van der Plas Date: Fri, 6 Mar
2026 13:46:07 +0100 Subject: [PATCH 17/60] minor gee update in case outside of radius_max, return radius_max --- src/data_preprocessing/gee_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/data_preprocessing/gee_utils.py b/src/data_preprocessing/gee_utils.py index def17a9..f1f73bf 100644 --- a/src/data_preprocessing/gee_utils.py +++ b/src/data_preprocessing/gee_utils.py @@ -250,8 +250,8 @@ def get_distance_to_road_within_aoi(aoi, cell_size=30, radius_max=5000): reducer=ee.Reducer.mean(), geometry=aoi, scale=cell_size, maxPixels=1e9 ) return { - "maxdist_road": int(max_distance.get("distance").getInfo()), - "meandist_road": int(mean_distance.get("distance").getInfo()), + "maxdist_road": int(max_distance.get("distance").getInfo() or radius_max), + "meandist_road": int(mean_distance.get("distance").getInfo() or radius_max), } From ccd7ac08cf944066d9916db0b672f8bca6db3b60 Mon Sep 17 00:00:00 2001 From: Rob Knapen Date: Sat, 7 Mar 2026 14:41:19 +0100 Subject: [PATCH 18/60] Crop Yield Africa use case - dataset and example tabular regression only experiment. 
--- configs/data/yield_africa_all.yaml | 28 + .../experiment/yield_africa_tabular_reg.yaml | 25 + configs/metrics/yield_africa_regression.yaml | 7 + configs/model/yield_tabular_reg.yaml | 33 + src/data/yield_africa_dataset.py | 120 +++ .../make_model_ready_yield_africa.py | 735 ++++++++++++++++++ src/models/components/loss_fns/huber_loss.py | 29 + src/models/predictive_model.py | 12 +- 8 files changed, 983 insertions(+), 6 deletions(-) create mode 100644 configs/data/yield_africa_all.yaml create mode 100644 configs/experiment/yield_africa_tabular_reg.yaml create mode 100644 configs/metrics/yield_africa_regression.yaml create mode 100644 configs/model/yield_tabular_reg.yaml create mode 100644 src/data/yield_africa_dataset.py create mode 100644 src/data_preprocessing/make_model_ready_yield_africa.py create mode 100644 src/models/components/loss_fns/huber_loss.py diff --git a/configs/data/yield_africa_all.yaml b/configs/data/yield_africa_all.yaml new file mode 100644 index 0000000..f44a407 --- /dev/null +++ b/configs/data/yield_africa_all.yaml @@ -0,0 +1,28 @@ +_target_: src.data.base_datamodule.BaseDataModule + +dataset: + _target_: src.data.yield_africa_dataset.YieldAfricaDataset + data_dir: ${paths.data_dir} + modalities: + coords: {} + use_target_data: true + use_features: true + use_aux_data: none + seed: ${seed} + cache_dir: ${paths.cache_dir} + # Country/year filters — set to a list to restrict, null to include all. + # countries and years select only the listed values; + # exclude_countries and exclude_years drop the listed values. 
+ countries: null + years: null + exclude_countries: ["BF"] + exclude_years: null + +batch_size: 64 +num_workers: 0 +pin_memory: false + +split_mode: "random" +train_val_test_split: [0.7, 0.15, 0.15] +save_split: false +seed: ${seed} diff --git a/configs/experiment/yield_africa_tabular_reg.yaml b/configs/experiment/yield_africa_tabular_reg.yaml new file mode 100644 index 0000000..fbcaedf --- /dev/null +++ b/configs/experiment/yield_africa_tabular_reg.yaml @@ -0,0 +1,25 @@ +# @package _global_ +# configs/experiment/yield_africa_tabular_reg.yaml +# Variant: Tabular features only, full dataset + +defaults: + - override /model: yield_tabular_reg + - override /data: yield_africa_all + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "tabular_only", "regression"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 50 + +data: + batch_size: 64 + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" diff --git a/configs/metrics/yield_africa_regression.yaml b/configs/metrics/yield_africa_regression.yaml new file mode 100644 index 0000000..79c441d --- /dev/null +++ b/configs/metrics/yield_africa_regression.yaml @@ -0,0 +1,7 @@ +_target_: src.models.components.metrics.metrics_wrapper.MetricsWrapper + +metrics: + - _target_: src.models.components.loss_fns.mse_loss.MSELoss + - _target_: src.models.components.loss_fns.rmse_loss.RMSELoss + - _target_: src.models.components.loss_fns.mae_loss.MAELoss + - _target_: src.models.components.metrics.r2.RSquared diff --git a/configs/model/yield_tabular_reg.yaml b/configs/model/yield_tabular_reg.yaml new file mode 100644 index 0000000..146c683 --- /dev/null +++ b/configs/model/yield_tabular_reg.yaml @@ -0,0 +1,33 @@ +_target_: src.models.predictive_model.PredictiveModel + +eo_encoder: + _target_: src.models.components.eo_encoders.multimodal_encoder.MultiModalEncoder + use_coords: false + use_tabular: true + tab_embed_dim: 128 + +prediction_head: + _target_: 
src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead + nn_layers: 2 + hidden_dim: 64 + +# Both encoder and head have trainable parameters. +trainable_modules: [eo_encoder, prediction_head] + +metrics: ${metrics} + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.001 + weight_decay: 0.0001 + +scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + mode: min + factor: 0.1 + patience: 10 + +loss_fn: + _target_: src.models.components.loss_fns.huber_loss.HuberLoss diff --git a/src/data/yield_africa_dataset.py b/src/data/yield_africa_dataset.py new file mode 100644 index 0000000..65534cc --- /dev/null +++ b/src/data/yield_africa_dataset.py @@ -0,0 +1,120 @@ +"""Yield Africa dataset. + +Location: src/data/yield_africa_dataset.py + +Crop yield regression use case for East/Southern Africa. +Tabular features (soil, climate, etc.) live in the model-ready CSV as feat_* +columns and are picked up automatically by BaseDataset.get_records(). +They do NOT need to be listed in `modalities`. +""" + +import logging +from typing import Any, Dict, List, override + +import torch + +from src.data.base_dataset import BaseDataset + +log = logging.getLogger(__name__) + + +class YieldAfricaDataset(BaseDataset): + """Dataset for the crop yield regression use case (East/Southern Africa). + + CSV layout expected: + - name_loc : unique location identifier + - lat, lon : WGS84 coordinates + - target_* : crop yield target(s) [t/ha] + - feat_* : tabular features (soil properties, climate indices, etc.) + - aux_* : auxiliary data columns (optional) + - country, year : metadata columns used for optional filtering + + Modality design note + -------------------- + `implemented_mod = {"coords"}` because tabular features live directly in + the model-ready CSV and are picked up via the `feat_` column prefix. + They do NOT need to be listed in `modalities`. 
+ """ + + def __init__( + self, + data_dir: str, + modalities: dict, + use_target_data: bool = True, + use_aux_data: Dict[str, Any] | str = None, + seed: int = 12345, + cache_dir: str = None, + mock: bool = False, + use_features: bool = True, + countries: List[str] | None = None, + years: List[int] | None = None, + exclude_countries: List[str] | None = None, + exclude_years: List[int] | None = None, + ) -> None: + super().__init__( + data_dir=data_dir, + modalities=modalities, + use_target_data=use_target_data, + use_aux_data=use_aux_data, + dataset_name="yield_africa", + seed=seed, + cache_dir=cache_dir, + implemented_mod={"coords"}, + mock=mock, + use_features=use_features, + ) + + # Apply country/year filters to self.df and rebuild records if needed. + # BaseDataset.__init__ has already loaded the CSV; filtering here avoids + # touching BaseDataset and keeps the logic use-case specific. + n_before = len(self.df) + if countries is not None and "country" in self.df.columns: + self.df = self.df[self.df["country"].isin(countries)].reset_index(drop=True) + if years is not None and "year" in self.df.columns: + self.df = self.df[self.df["year"].isin(years)].reset_index(drop=True) + if exclude_countries is not None and "country" in self.df.columns: + self.df = self.df[~self.df["country"].isin(exclude_countries)].reset_index(drop=True) + if exclude_years is not None and "year" in self.df.columns: + self.df = self.df[~self.df["year"].isin(exclude_years)].reset_index(drop=True) + + n_after = len(self.df) + if n_after != n_before: + log.info(f"Country/year filter: {n_before} → {n_after} records ({n_before - n_after} excluded)") + self.records = self.get_records() + + def setup(self) -> None: + """No files to download or prepare for this dataset.""" + return + + @override + def __getitem__(self, idx: int) -> Dict[str, Any]: + row = self.records[idx] + sample: Dict[str, Any] = {"eo": {}} + + for modality in self.modalities: + if modality == "coords": + sample["eo"]["coords"] 
= torch.tensor( + [row["lat"], row["lon"]], dtype=torch.float32 + ) + + if self.use_features and self.feat_names: + sample["eo"]["tabular"] = torch.tensor( + [row[k] for k in self.feat_names], dtype=torch.float32 + ) + + if self.use_target_data: + sample["target"] = torch.tensor( + [row[k] for k in self.target_names], dtype=torch.float32 + ) + + if self.use_aux_data: + sample["aux"] = {} + for aux_cat, vals in self.use_aux_data.items(): + if aux_cat == "aux": + sample["aux"][aux_cat] = torch.tensor( + [row[v] for v in vals], dtype=torch.float32 + ) + else: + sample["aux"][aux_cat] = [row[v] for v in vals] + + return sample diff --git a/src/data_preprocessing/make_model_ready_yield_africa.py b/src/data_preprocessing/make_model_ready_yield_africa.py new file mode 100644 index 0000000..814d901 --- /dev/null +++ b/src/data_preprocessing/make_model_ready_yield_africa.py @@ -0,0 +1,735 @@ +"""Build model-ready CSV/Parquet for the crop yield Africa use case +(data/yield_africa/model_ready_yield-africa.csv). 
+ +Features: +- Load raw dataset (CSV or Parquet) +- Compute derived features (CN_ratio, layer deltas, WHC proxy, aridity index) +- Apply log transforms to skewed features +- Fit StandardScaler on train split only +- Encode categorical features as integer indices +- Remove yield outliers beyond 3 IQR +- Preserve metadata columns +- Save fitted transformers for inference-time reuse +- Calculate and save spatial cross-validation splits +""" + +import argparse +import json +import logging +import warnings +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import joblib +import numpy as np +import pandas as pd +import torch +from sklearn.preprocessing import LabelEncoder, StandardScaler + +log = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +MODEL_READY_DATA_NAME = "yield_africa" + +TRAIN_COUNTRIES = ["BUR", "ETH", "KEN", "MAL", "RWA", "TAN", "ZAM"] + +SPATIAL_SPLIT_BLOCK_SIZE_KM = 50.0 +SPATIAL_SPLIT_N_SPLITS = 7 + +CONTINUOUS_FEATURES = [ + # Soil features + "C_0_20", + "C_20_50", + "N_0_20", + "N_20_50", + "P_0_20", + "P_20_50", + "MA_0_20", + "MA_20_50", + "PO_0_20", + "PO_20_50", + "pH_0_20", + "pH_20_50", + "BD_0_20", + "BD_20_50", + "ECX_0_20", + "ECX_20_50", + "CA_0_20", + "CA_20_50", + # Climate features + "PrecJJA", + "PrecMAM", + "PrecSON", + "PrecDJF", + "TaveJJA", + "TaveMAM", + "TaveSON", + "TaveDJF", + "TmaxJJA", + "TmaxMAM", + "TmaxSON", + "TmaxDJF", + "TminJJA", + "TminMAM", + "TminSON", + "TminDJF", + "CMD", + "Eref", + "MAP", + "MAT", + "TD", + "MWMT", + "MCMT", + "DD_above_5", + "DD_above_18", + "DD_below_18", + # Terrain features + "DEM", + "Slope", + "Aspect", + "CHILI", + "Top_div", + # Land cover / context + "Tree_c", + "Dist_water", + "Paved", + "Unpaved", + "Pop_10km", + # Derived features (computed automatically) + "CN_ratio", + "C_layer_delta", + "BD_layer_delta", + 
"WHC_proxy", + "aridity_index", +] + +# Categorical columns that are actual tabular inputs to the regression model (feat_ prefix). +# TX_*_cl are soil texture classes — they are not derived from a paired numerical column. +TABULAR_CATEGORICAL_FEATURES = [ + "TX_0_20_cl", + "TX_20_50_cl", +] + +# Categorical columns derived from their paired numerical columns (same name without _cl). +# Used for caption generation (aux_ prefix). +AUX_FEATURES = [ + # target (classified) + "Yld_ton_ha_cl", + # soil features + "C_0_20_cl", + "C_20_50_cl", + "N_0_20_cl", + "N_20_50_cl", + "P_0_20_cl", + "P_20_50_cl", + "MA_0_20_cl", + "MA_20_50_cl", + "PO_0_20_cl", + "PO_20_50_cl", + "pH_0_20_cl", + "pH_20_50_cl", + "BD_0_20_cl", + "BD_20_50_cl", + "ECX_0_20_cl", + "ECX_20_50_cl", + "CA_0_20_cl", + "CA_20_50_cl", + # climate features + "PrecJJA_cl", + "PrecMAM_cl", + "PrecSON_cl", + "PrecDJF_cl", + "TaveJJA_cl", + "TaveMAM_cl", + "TaveSON_cl", + "TaveDJF_cl", + "TmaxJJA_cl", + "TmaxMAM_cl", + "TmaxSON_cl", + "TmaxDJF_cl", + "TminJJA_cl", + "TminMAM_cl", + "TminSON_cl", + "TminDJF_cl", + "CMD_cl", + "Eref_cl", + "MAP_cl", + "MAT_cl", + "TD_cl", + "MCMT_cl", + "MWMT_cl", + "DD_above_5_cl", + "DD_above_18_cl", + "DD_below_18_cl", + # terrain features + "DEM_cl", + "Slope_cl", + "Aspect_cl", + "Landform_cl", + "CHILI_cl", + "Top_div_cl", + # land cover / context + "GLAD_cl", + "Tree_c_cl", + "Dist_water_cl", + "Paved_cl", + "Unpaved_cl", + "Pop_10km_cl", +] + +LOG_TRANSFORM_FEATURES = ["Dist_water", "Paved", "Unpaved", "Pop_10km"] + +TARGET_COLUMNS = ["Yld_ton_ha"] + +NAME_LOC_COLUMN = "ID" + +METADATA_COLUMNS = ["Lat", "Lon", "Country", "Year", "Location_accuracy"] + +# Saxton & Rawls (2006) Table 3-derived AWC (FC - WP) +# Conditions: ~2.5% OM, no salinity/gravel/density adjustment +# "Plant avail." 
(%v) converted to mm/m via mm/m = (%v) * 10 +# (because 1% v/v = 0.01 m³/m³ = 10 mm per m soil) +# Values are in mm/m (approximate field capacity - wilting point) +WHC_LOOKUP_SAXTON_RAWLS_2006_OM2P5 = { + "Sand": 50, + "Loamy sand": 70, + "Sandy loam": 100, + "Loam": 140, + "Silt loam": 200, + "Silt": 250, + "Sandy clay loam": 100, + "Clay loam": 140, + "Silty clay loam": 170, + "Silty clay": 140, + "Sandy clay": 110, + "Clay": 120, +} + +# --------------------------------------------------------------------------- +# Preprocessing functions +# --------------------------------------------------------------------------- + +def build_column_rename_map( + continuous_features: List[str], + tabular_categorical_features: List[str], + aux_features: List[str], + target_columns: List[str], + name_loc_column: str, + metadata_columns: List[str], +) -> Dict[str, str]: + """Build a column rename mapping that standardises predictor and target names. + + Convention: + - Numerical predictors and tabular categorical features: ``feat_{original.lower()}`` + - Aux/caption features (derived categorical classes): ``aux_{original.lower()}`` + - Target columns: ``target_{original.lower()}`` + - Name-location column: ``name_loc`` + - Metadata columns: ``{original.lower()}`` + """ + rename: Dict[str, str] = {} + for col in continuous_features: + rename[col] = f"feat_{col.lower()}" + for col in tabular_categorical_features: + rename[col] = f"feat_{col.lower()}" + for col in aux_features: + rename[col] = f"aux_{col.lower()}" + for col in target_columns: + rename[col] = f"target_{col.lower()}" + for col in metadata_columns: + rename[col] = col.lower() + if name_loc_column is not None: + rename[name_loc_column] = "name_loc" + return rename + + +def compute_derived_features(df: pd.DataFrame) -> pd.DataFrame: + """Compute derived features from raw measurements.""" + df = df.copy() + + # C:N ratio (guard against division by zero) + df["CN_ratio"] = np.where( + df["N_0_20"] > 0, + df["C_0_20"] / 
df["N_0_20"], + np.nan, + ) + + # Layer deltas (stratification indicators) + df["C_layer_delta"] = df["C_0_20"] - df["C_20_50"] + df["BD_layer_delta"] = df["BD_0_20"] - df["BD_20_50"] + + # Water Holding Capacity proxy from texture lookup, adjusted by bulk density + _whc_lookup_lower = {k.lower(): v for k, v in WHC_LOOKUP_SAXTON_RAWLS_2006_OM2P5.items()} + df["WHC_proxy"] = ( + df["TX_0_20_cl"] + .astype(str) + .str.lower() + .map(_whc_lookup_lower) + .fillna(WHC_LOOKUP_SAXTON_RAWLS_2006_OM2P5["Sandy loam"]) + ) + + # Adjust WHC by bulk density (inverse relationship, reference BD = 1.3 g/cm³) + bd_factor = np.where(df["BD_0_20"] > 0, 1.3 / df["BD_0_20"], 1.0) + df["WHC_proxy"] = df["WHC_proxy"] * bd_factor + + # Aridity index (guard against MAP=0) + df["aridity_index"] = np.where(df["MAP"] > 0, df["CMD"] / df["MAP"], np.nan) + + return df + + +def apply_log_transforms(df: pd.DataFrame, log_transform_features: List[str]) -> pd.DataFrame: + """Apply log(x + 1) transform to skewed features.""" + df = df.copy() + for col in log_transform_features: + if col in df.columns: + df[col] = np.log1p(np.maximum(df[col], 0)) + return df + + +def remove_yield_outliers( + df: pd.DataFrame, + target_col: str = "Yld_ton_ha", + iqr_multiplier: float = 3.0, +) -> Tuple[pd.DataFrame, pd.Series]: + """Remove yield outliers beyond IQR threshold.""" + if target_col not in df.columns: + warnings.warn(f"Target column '{target_col}' not found; skipping outlier removal") + return df, pd.Series([False] * len(df), index=df.index) + + q1 = df[target_col].quantile(0.25) + q3 = df[target_col].quantile(0.75) + iqr = q3 - q1 + lower_bound = q1 - iqr_multiplier * iqr + upper_bound = q3 + iqr_multiplier * iqr + outlier_mask = (df[target_col] < lower_bound) | (df[target_col] > upper_bound) + + n_outliers = outlier_mask.sum() + if n_outliers > 0: + log.info( + f"Removing {n_outliers} yield outliers (< {lower_bound:.2f} or > {upper_bound:.2f} t/ha)" + ) + + return df[~outlier_mask].copy(), outlier_mask + 
+ +def fit_scaler(df: pd.DataFrame, continuous_features: List[str]) -> StandardScaler: + """Fit StandardScaler on continuous features.""" + available_features = [f for f in continuous_features if f in df.columns] + if len(available_features) < len(continuous_features): + missing = set(continuous_features) - set(available_features) + warnings.warn(f"Missing continuous features: {missing}") + scaler = StandardScaler() + scaler.fit(df[available_features]) + return scaler + + +def apply_scaler( + df: pd.DataFrame, + scaler: StandardScaler, + continuous_features: List[str], +) -> pd.DataFrame: + """Apply fitted StandardScaler to continuous features.""" + df = df.copy() + available_features = [f for f in continuous_features if f in df.columns] + df[available_features] = scaler.transform(df[available_features]) + return df + + +def fit_label_encoders( + df: pd.DataFrame, + categorical_features: List[str], +) -> Dict[str, LabelEncoder]: + """Fit LabelEncoders for categorical features.""" + encoders = {} + for col in categorical_features: + if col not in df.columns: + warnings.warn(f"Categorical feature '{col}' not found; skipping") + continue + encoder = LabelEncoder() + encoder.fit(df[col].dropna()) + encoders[col] = encoder + log.info(f" {col}: {len(encoder.classes_)} classes") + return encoders + + +def apply_label_encoders( + df: pd.DataFrame, + encoders: Dict[str, LabelEncoder], + categorical_features: List[str], +) -> pd.DataFrame: + """Apply fitted LabelEncoders to categorical features.""" + df = df.copy() + for col in categorical_features: + if col not in df.columns: + continue + if col not in encoders: + warnings.warn(f"No encoder found for '{col}'; skipping") + continue + encoder = encoders[col] + df[col] = df[col].apply( + lambda x: encoder.transform([x])[0] if x in encoder.classes_ else -1 + ) + return df + + +def calculate_spatial_splits( + df: pd.DataFrame, + block_size_km: float = 50.0, + n_splits: int = 7, + lat_col: str = "lat", + lon_col: str = "lon", + 
name_loc_col: str = "name_loc", + save_path: str | Path | None = None, +) -> Dict[str, Any]: + """Calculate spatial cross-validation splits using a grid-based blocking approach.""" + log.info(f"Calculating spatial splits with block size {block_size_km}km and {n_splits} folds") + + if "split" in df.columns: + train_df = df[df["split"] == "train"].copy() + log.info(f"Filtered to training split: {len(train_df)} samples") + else: + train_df = df.copy() + log.info(f"No 'split' column found; using all {len(train_df)} samples for spatial splits") + + # Approx conversion: 1 deg ~ 111 km + block_size_deg = block_size_km / 111.0 + train_df["lat_grid"] = np.floor(train_df[lat_col] / block_size_deg) + train_df["lon_grid"] = np.floor(train_df[lon_col] / block_size_deg) + train_df["block_id"] = ( + train_df["lat_grid"].astype(str) + "_" + train_df["lon_grid"].astype(str) + ) + + unique_blocks = train_df["block_id"].unique() + log.info(f"Created {len(unique_blocks)} spatial blocks") + + # Greedy bin packing: assign blocks largest-first to the smallest fold + block_counts = train_df["block_id"].value_counts().sort_values(ascending=False) + fold_samples = [0] * n_splits + fold_block_ids: List[List[str]] = [[] for _ in range(n_splits)] + + for block_id, count in block_counts.items(): + smallest_fold = int(np.argmin(fold_samples)) + fold_samples[smallest_fold] += count + fold_block_ids[smallest_fold].append(block_id) + + block_to_names = train_df.groupby("block_id")[name_loc_col].unique().to_dict() + + splits: Dict[str, Any] = {} + for fold in range(n_splits): + val_names = [ + name + for bid in fold_block_ids[fold] + for name in block_to_names[bid].tolist() + ] + train_names = [ + name + for f in range(n_splits) + if f != fold + for bid in fold_block_ids[f] + for name in block_to_names[bid].tolist() + ] + splits[f"fold_{fold}"] = {"train": train_names, "val": val_names} + log.info( + f" Fold {fold}: {len(train_names)} train locations, {len(val_names)} val locations " + 
f"({fold_samples[fold]} samples)" + ) + + if save_path: + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + torch.save(splits, save_path) + log.info(f"Saved spatial splits to {save_path}") + + return splits + + +# --------------------------------------------------------------------------- +# Main pipeline +# --------------------------------------------------------------------------- + + +def main( + source_csv: str, + out_csv: str, + out_parquet: str | None = None, + spatial_splits: bool = False, + countries: List[str] | None = None, + years: List[int] | None = None, + exclude_countries: List[str] | None = None, + exclude_years: List[int] | None = None, +) -> pd.DataFrame: + """Preprocessing workflow for the crop yield Africa dataset.""" + data_path = Path(source_csv) + out_csv_path = Path(out_csv) + data_dir = out_csv_path.parent + + scaler_path = data_dir / f"fitted_scaler_{MODEL_READY_DATA_NAME}.pkl" + encoders_path = data_dir / f"label_encoders_{MODEL_READY_DATA_NAME}.pkl" + spatial_split_path = ( + data_dir + / "splits" + / ( + f"split_spatial_{SPATIAL_SPLIT_N_SPLITS}_folds_with_" + f"{SPATIAL_SPLIT_BLOCK_SIZE_KM}km_blocks_{MODEL_READY_DATA_NAME}.pth" + ) + ) + + log.info("Starting preprocessing pipeline...") + log.info(f"Input: {data_path}") + + # Load raw data + raw_df = pd.read_csv(data_path) + n_raw = len(raw_df) + log.info(f"Loaded {n_raw} rows from {data_path}") + + # Filter by country and year before any other processing + df = raw_df.copy() + if countries is not None: + before = len(df) + df = df[df["Country"].isin(countries)].copy() + log.info( + f"Country filter ({', '.join(sorted(countries))}): " + f"kept {len(df)}, excluded {before - len(df)}" + ) + if years is not None: + before = len(df) + df = df[df["Year"].isin(years)].copy() + log.info( + f"Year filter ({', '.join(str(y) for y in sorted(years))}): " + f"kept {len(df)}, excluded {before - len(df)}" + ) + if exclude_countries is not None: + before = len(df) + df 
= df[~df["Country"].isin(exclude_countries)].copy() + log.info( + f"Exclude countries ({', '.join(sorted(exclude_countries))}): " + f"kept {len(df)}, excluded {before - len(df)}" + ) + if exclude_years is not None: + before = len(df) + df = df[~df["Year"].isin(exclude_years)].copy() + log.info( + f"Exclude years ({', '.join(str(y) for y in sorted(exclude_years))}): " + f"kept {len(df)}, excluded {before - len(df)}" + ) + n_after_filter = len(df) + + # Determine train mask on the filtered data + train_mask = df["Country"].isin(TRAIN_COUNTRIES) + log.info( + f"Train mask: {train_mask.sum()} rows ({', '.join(TRAIN_COUNTRIES)}) out of {n_after_filter} total" + ) + + # Compute derived features + log.info("Computing derived features...") + df = compute_derived_features(df) + + # Remove yield outliers + log.info("Removing yield outliers...") + df, outlier_mask = remove_yield_outliers(df, TARGET_COLUMNS[0], iqr_multiplier=3.0) + # Re-align train_mask after outlier removal + train_mask = train_mask[df.index] + + # Apply log transforms + log.info("Applying log transforms...") + df = apply_log_transforms(df, LOG_TRANSFORM_FEATURES) + + train_df = df[train_mask] + log.info(f"Training set size: {len(train_df)} samples") + + # Fit transformers on train split only + log.info("Fitting StandardScaler on train split...") + scaler = fit_scaler(train_df, CONTINUOUS_FEATURES) + + log.info("Fitting LabelEncoders on train split...") + all_categorical = TABULAR_CATEGORICAL_FEATURES + AUX_FEATURES + encoders = fit_label_encoders(train_df, all_categorical) + + # Apply transformations to full dataset + log.info("Applying transformations to full dataset...") + df = apply_scaler(df, scaler, CONTINUOUS_FEATURES) + df = apply_label_encoders(df, encoders, all_categorical) + + # Save transformers + scaler_path.parent.mkdir(parents=True, exist_ok=True) + joblib.dump(scaler, scaler_path) + log.info(f"Saved scaler to {scaler_path}") + + encoders_path.parent.mkdir(parents=True, exist_ok=True) + 
joblib.dump(encoders, encoders_path) + log.info(f"Saved encoders to {encoders_path}") + + # Validation checks + log.info("Validation checks:") + derived_cols = ["CN_ratio", "C_layer_delta", "BD_layer_delta", "WHC_proxy", "aridity_index"] + for col in derived_cols: + if col in df.columns: + n_nan = df[col].isna().sum() + n_inf = np.isinf(df[col]).sum() + if n_nan > 0 or n_inf > 0: + warnings.warn(f" {col}: {n_nan} NaN, {n_inf} Inf values") + else: + log.info(f" {col}: no NaN or Inf") + + for col in all_categorical: + if col in df.columns and col in encoders: + n_classes = len(encoders[col].classes_) + min_val = df[col].min() + max_val = df[col].max() + if min_val < 0 or max_val >= n_classes: + warnings.warn( + f" {col}: indices [{min_val}, {max_val}] out of range [0, {n_classes-1}]" + ) + + # Apply canonical column rename (feat_/aux_/target_ prefixes, lowercase meta) + rename_map = build_column_rename_map( + continuous_features=CONTINUOUS_FEATURES, + tabular_categorical_features=TABULAR_CATEGORICAL_FEATURES, + aux_features=AUX_FEATURES, + target_columns=TARGET_COLUMNS, + name_loc_column=NAME_LOC_COLUMN, + metadata_columns=METADATA_COLUMNS, + ) + df = df.rename(columns=rename_map) + + # Prefix name_loc IDs with country name + if "name_loc" in df.columns and "country" in df.columns: + df["name_loc"] = df["country"].astype(str).str.upper() + "_" + df["name_loc"].astype(str) + log.info("Prefixed name_loc IDs with country names") + n_duplicates = df["name_loc"].duplicated().sum() + if n_duplicates > 0: + warnings.warn(f"Found {n_duplicates} duplicate name_loc IDs after prefixing") + else: + log.info(f" No duplicates in name_loc ({df['name_loc'].nunique()} unique IDs)") + + # Convert location_accuracy from text to numerical values + if "location_accuracy" in df.columns: + accuracy_map = { + "high location accuracy": 1, + "medium location accuracy": 2, + "low location accuracy": 3, + } + df["location_accuracy"] = df["location_accuracy"].str.lower().map(accuracy_map) + + # 
Keep scaler metadata in sync with new column names + if hasattr(scaler, "feature_names_in_") and scaler.feature_names_in_ is not None: + scaler.feature_names_in_ = np.array( + [rename_map.get(n, n) for n in scaler.feature_names_in_] + ) + encoders = {rename_map.get(k, k): v for k, v in encoders.items()} + + # Calculate and save spatial splits (optional) + if spatial_splits: + calculate_spatial_splits( + df=df, + block_size_km=SPATIAL_SPLIT_BLOCK_SIZE_KM, + n_splits=SPATIAL_SPLIT_N_SPLITS, + save_path=spatial_split_path, + ) + else: + log.info("Skipping spatial split calculation (use --spatial_splits to enable)") + + # Save outputs + out_csv_path.parent.mkdir(parents=True, exist_ok=True) + df.to_csv(out_csv_path, index=False, float_format="%.7f") + log.info(f"Saved CSV to {out_csv_path}") + + if out_parquet is not None: + out_parquet_path = Path(out_parquet) + out_parquet_path.parent.mkdir(parents=True, exist_ok=True) + df.to_parquet(out_parquet_path, index=False) + log.info(f"Saved Parquet to {out_parquet_path}") + + log.info("=== Summary ===") + log.info(f" Raw rows loaded: {n_raw}") + log.info(f" Rows excluded by country/year filter: {n_raw - n_after_filter}") + log.info(f" Rows in output: {len(df)}") + log.info(f" Continuous features: {len(CONTINUOUS_FEATURES)}") + log.info(f" Tabular categorical features (feat_): {len(TABULAR_CATEGORICAL_FEATURES)}") + log.info(f" Aux/caption features (aux_): {len(AUX_FEATURES)}") + log.info( + f" Yield range: {df['target_yld_ton_ha'].min():.2f} - {df['target_yld_ton_ha'].max():.2f} t/ha" + ) + log.info(f" Mean yield: {df['target_yld_ton_ha'].mean():.2f} t/ha") + log.info(" Records per country and year:") + counts = df.groupby(["country", "year"]).size().unstack(fill_value=0) + for country, row in counts.iterrows(): + year_counts = [f"{year}: {count}" for year, count in row.items() if count > 0] + log.info(f" {country}: {', '.join(year_counts)}") + + return df + + +if __name__ == "__main__": + logging.basicConfig( + 
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + + ap = argparse.ArgumentParser( + description="Build model-ready CSV/Parquet for the crop yield Africa use case." + ) + ap.add_argument( + "--source_csv", + required=True, + help="Path to the raw input CSV (e.g. data/yield_africa/yield_africa_v20260218.csv)", + ) + ap.add_argument( + "--out_csv", + required=True, + help="Path for the output model-ready CSV (e.g. data/yield_africa/model_ready_yield-africa.csv)", + ) + ap.add_argument( + "--out_parquet", + default=None, + help="Optional path for an additional Parquet output.", + ) + ap.add_argument( + "--spatial_splits", + action="store_true", + default=False, + help="Calculate and save spatial cross-validation splits (default: off).", + ) + ap.add_argument( + "--countries", + nargs="+", + default=None, + metavar="CODE", + help="Country codes to include (e.g. --countries ETH KEN TAN). Default: all countries.", + ) + ap.add_argument( + "--years", + nargs="+", + type=int, + default=None, + metavar="YEAR", + help="Years to include (e.g. --years 2018 2019 2020). Default: all years.", + ) + ap.add_argument( + "--exclude_countries", + nargs="+", + default=None, + metavar="CODE", + help="Country codes to exclude (e.g. --exclude_countries MOZ ZIM).", + ) + ap.add_argument( + "--exclude_years", + nargs="+", + type=int, + default=None, + metavar="YEAR", + help="Years to exclude (e.g. 
--exclude_years 2015 2016).", + ) + args = ap.parse_args() + main( + args.source_csv, + args.out_csv, + args.out_parquet, + args.spatial_splits, + args.countries, + args.years, + args.exclude_countries, + args.exclude_years, + ) \ No newline at end of file diff --git a/src/models/components/loss_fns/huber_loss.py b/src/models/components/loss_fns/huber_loss.py new file mode 100644 index 0000000..0fa2ff1 --- /dev/null +++ b/src/models/components/loss_fns/huber_loss.py @@ -0,0 +1,29 @@ +from typing import Dict, override + +import torch + +from src.models.components.loss_fns.base_loss_fn import BaseLossFn + + +class HuberLoss(BaseLossFn): + def __init__(self) -> None: + super().__init__() + self.criterion = torch.nn.HuberLoss(delta=1.0, reduction="mean") + self.name = "huber_loss" + + @override + def forward( + self, + pred: torch.Tensor, + labels: torch.Tensor | None = None, + batch: Dict[str, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor | Dict[str, torch.Tensor]: + + labels = labels if labels is not None else batch.get("target") + huber_loss = self.criterion(pred, labels) + + if "return_label" in kwargs: + return {self.name: huber_loss} + else: + return huber_loss \ No newline at end of file diff --git a/src/models/predictive_model.py b/src/models/predictive_model.py index a0b4212..24b8327 100644 --- a/src/models/predictive_model.py +++ b/src/models/predictive_model.py @@ -1,7 +1,6 @@ from typing import Dict, override import torch -import torch.nn.functional as F from src.models.base_model import BaseModel from src.models.components.eo_encoders.base_eo_encoder import BaseEOEncoder @@ -75,22 +74,23 @@ def setup_encoders_adapters(self): @override def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: feats = self.eo_encoder(batch) - feats = F.normalize(feats, dim=-1) return self.prediction_head(feats) @override def _step(self, batch: Dict[str, torch.Tensor], mode: str = "train") -> torch.Tensor: - feats = self.forward(batch) + preds = 
self.forward(batch) log_kwargs = dict( - on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=feats.size(0) + on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=preds.size(0) ) - loss = self.loss_fn(feats, batch.get("target")) + loss = self.loss_fn(preds, batch.get("target")) self.log(f"{mode}_loss", loss, **log_kwargs) - metrics = self.metrics(pred=feats, batch=batch, mode=mode) + metrics = self.metrics(pred=preds, batch=batch, mode=mode) self.log_dict(metrics, **log_kwargs) + return loss + if __name__ == "__main__": _ = PredictiveModel(None, None, None, None, None, None, None) From 8fa2a9fb6619aa821ff69fce199007257bfcd3c6 Mon Sep 17 00:00:00 2001 From: Rob Knapen Date: Sat, 7 Mar 2026 16:45:56 +0100 Subject: [PATCH 19/60] Crop Yield Africa use case - example coords and fusion experiments; improved epoch r2 metric calculation. --- configs/data/yield_africa_all.yaml | 11 +- .../experiment/yield_africa_coords_reg.yaml | 25 ++ .../experiment/yield_africa_fusion_reg.yaml | 25 ++ configs/model/yield_fusion_reg.yaml | 33 ++ configs/model/yield_geoclip_reg.yaml | 32 ++ configs/model/yield_tabular_reg.yaml | 4 +- .../components/metrics/metrics_wrapper.py | 2 +- src/models/components/metrics/r2.py | 27 +- tests/test_datasets_and_datamodules.py | 3 +- tests/test_yield_africa.py | 310 ++++++++++++++++++ 10 files changed, 458 insertions(+), 14 deletions(-) create mode 100644 configs/experiment/yield_africa_coords_reg.yaml create mode 100644 configs/experiment/yield_africa_fusion_reg.yaml create mode 100644 configs/model/yield_fusion_reg.yaml create mode 100644 configs/model/yield_geoclip_reg.yaml create mode 100644 tests/test_yield_africa.py diff --git a/configs/data/yield_africa_all.yaml b/configs/data/yield_africa_all.yaml index f44a407..8667a4c 100644 --- a/configs/data/yield_africa_all.yaml +++ b/configs/data/yield_africa_all.yaml @@ -13,15 +13,20 @@ dataset: # Country/year filters — set to a list to restrict, null to include all. 
 # countries and years select only the listed values; # exclude_countries and exclude_years drop the listed values. - countries: null - years: null - exclude_countries: ["BF"] + countries: ["KEN", "RWA", "TAN", "ZAM", "MAL"] + years: [2016, 2017, 2018, 2019, 2020, 2021] + exclude_countries: null exclude_years: null batch_size: 64 num_workers: 0 pin_memory: false +# todo - use spatial split (pre-calculate and then load from file) +# - hold out country/year block for validation +# - or leave one country out for validation +# - normalize data by country (after filtering) + split_mode: "random" train_val_test_split: [0.7, 0.15, 0.15] save_split: false diff --git a/configs/experiment/yield_africa_coords_reg.yaml b/configs/experiment/yield_africa_coords_reg.yaml new file mode 100644 index 0000000..5e967e2 --- /dev/null +++ b/configs/experiment/yield_africa_coords_reg.yaml @@ -0,0 +1,25 @@ +# @package _global_ +# configs/experiment/yield_africa_coords_reg.yaml +# Variant: Coordinates (GeoClip encoder) only, full dataset + +defaults: + - override /model: yield_geoclip_reg + - override /data: yield_africa_all + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "coords_only", "regression"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 50 + +data: + batch_size: 64 + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" diff --git a/configs/experiment/yield_africa_fusion_reg.yaml b/configs/experiment/yield_africa_fusion_reg.yaml new file mode 100644 index 0000000..2cf0093 --- /dev/null +++ b/configs/experiment/yield_africa_fusion_reg.yaml @@ -0,0 +1,25 @@ +# @package _global_ +# configs/experiment/yield_africa_fusion_reg.yaml +# Variant C: GeoClip + tabular fusion + +defaults: + - override /model: yield_fusion_reg + - override /data: yield_africa_all + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "fusion", "regression"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 50 + +data: + 
batch_size: 64 + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" diff --git a/configs/model/yield_fusion_reg.yaml b/configs/model/yield_fusion_reg.yaml new file mode 100644 index 0000000..3f64972 --- /dev/null +++ b/configs/model/yield_fusion_reg.yaml @@ -0,0 +1,33 @@ +_target_: src.models.predictive_model.PredictiveModel + +eo_encoder: + _target_: src.models.components.eo_encoders.multimodal_encoder.MultiModalEncoder + use_coords: true + use_tabular: true +# tab_embed_dim: 64 + +prediction_head: + _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead + nn_layers: 2 + hidden_dim: 256 + +# GeoClip frozen; tabular projection + head are trained. +trainable_modules: [eo_encoder, prediction_head] + +metrics: ${metrics} + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.001 + weight_decay: 0.0 + +scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + mode: min + factor: 0.1 + patience: 10 + +loss_fn: + _target_: src.models.components.loss_fns.huber_loss.HuberLoss diff --git a/configs/model/yield_geoclip_reg.yaml b/configs/model/yield_geoclip_reg.yaml new file mode 100644 index 0000000..7bee09b --- /dev/null +++ b/configs/model/yield_geoclip_reg.yaml @@ -0,0 +1,32 @@ +_target_: src.models.predictive_model.PredictiveModel + +eo_encoder: + _target_: src.models.components.eo_encoders.multimodal_encoder.MultiModalEncoder + use_coords: true + use_tabular: false + +prediction_head: + _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead + nn_layers: 2 + hidden_dim: 256 + +# Only the prediction head is trained; GeoClip encoder is frozen. 
+trainable_modules: [prediction_head] + +metrics: ${metrics} + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.001 + weight_decay: 0.0 + +scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + mode: min + factor: 0.1 + patience: 10 + +loss_fn: + _target_: src.models.components.loss_fns.huber_loss.HuberLoss diff --git a/configs/model/yield_tabular_reg.yaml b/configs/model/yield_tabular_reg.yaml index 146c683..22251bf 100644 --- a/configs/model/yield_tabular_reg.yaml +++ b/configs/model/yield_tabular_reg.yaml @@ -9,7 +9,7 @@ eo_encoder: prediction_head: _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead nn_layers: 2 - hidden_dim: 64 + hidden_dim: 256 # Both encoder and head have trainable parameters. trainable_modules: [eo_encoder, prediction_head] @@ -20,7 +20,7 @@ optimizer: _target_: torch.optim.Adam _partial_: true lr: 0.001 - weight_decay: 0.0001 + weight_decay: 0.0 scheduler: _target_: torch.optim.lr_scheduler.ReduceLROnPlateau diff --git a/src/models/components/metrics/metrics_wrapper.py b/src/models/components/metrics/metrics_wrapper.py index c15f395..33d86de 100644 --- a/src/models/components/metrics/metrics_wrapper.py +++ b/src/models/components/metrics/metrics_wrapper.py @@ -10,7 +10,7 @@ class MetricsWrapper(nn.Module): def __init__(self, metrics: List[BaseMetrics | BaseLossFn]) -> None: super().__init__() - self.metrics = metrics + self.metrics = nn.ModuleList(metrics) def forward(self, mode="train", **kwargs) -> Dict[str, torch.float]: """Calculates all metrics and adds all the results into one dictionary for logging.""" diff --git a/src/models/components/metrics/r2.py b/src/models/components/metrics/r2.py index cd2d38e..baccb06 100644 --- a/src/models/components/metrics/r2.py +++ b/src/models/components/metrics/r2.py @@ -1,14 +1,29 @@ from typing import Dict, override import torch +from torch import nn +from torchmetrics.regression import R2Score from 
src.models.components.metrics.base_metrics import BaseMetrics +_MODES = ("train", "val", "test") + class RSquared(BaseMetrics): + """Epoch-level R² using torchmetrics.R2Score. + + A separate R2Score accumulator is kept per mode so that train, val, and + test statistics never mix. Lightning detects the returned torchmetrics + Metric objects and calls .compute()/.reset() at epoch boundaries, giving + a correct epoch-wide R² instead of an average of per-batch R² values. + """ + def __init__(self) -> None: super().__init__() self.name = "r2" + # Keys are prefixed to avoid clashing with nn.Module attribute names + # (e.g. "train" conflicts with nn.Module.train()). + self._r2 = nn.ModuleDict({f"mode_{m}": R2Score() for m in _MODES}) @override def forward( @@ -17,12 +32,10 @@ def forward( labels: torch.Tensor | None = None, batch: Dict[str, torch.Tensor] | None = None, **kwargs, - ) -> torch.Tensor | Dict[str, torch.Tensor]: - + ) -> Dict[str, torch.Tensor]: labels = labels if labels is not None else batch.get("target") + mode = kwargs.get("mode", "train") - ss_res = torch.sum((labels - pred) ** 2) - ss_tot = torch.sum((labels - torch.mean(labels)) ** 2) + 1e-12 - r2 = 1.0 - ss_res / ss_tot - - return {self.name: r2} + metric = self._r2[f"mode_{mode}"] + metric.update(pred.squeeze(-1), labels.squeeze(-1)) + return {self.name: metric} \ No newline at end of file diff --git a/tests/test_datasets_and_datamodules.py b/tests/test_datasets_and_datamodules.py index 2141987..20c94be 100644 --- a/tests/test_datasets_and_datamodules.py +++ b/tests/test_datasets_and_datamodules.py @@ -2,11 +2,12 @@ from src.data.butterfly_dataset import ButterflyDataset from src.data.heat_guatemala_dataset import HeatGuatemalaDataset from src.data.satbird_dataset import SatBirdDataset +from src.data.yield_africa_dataset import YieldAfricaDataset def test_datasets_generic_properties(request, tmp_path, sample_csv): """This test checks that all datasets implement the basic properties and methods.""" - 
list_datasets = [ButterflyDataset, SatBirdDataset, HeatGuatemalaDataset] + list_datasets = [ButterflyDataset, SatBirdDataset, HeatGuatemalaDataset, YieldAfricaDataset] use_mock = request.config.getoption("--use-mock") if use_mock: csv_dir = sample_csv diff --git a/tests/test_yield_africa.py b/tests/test_yield_africa.py new file mode 100644 index 0000000..30f32d2 --- /dev/null +++ b/tests/test_yield_africa.py @@ -0,0 +1,310 @@ +"""Tests for the yield_africa use case. + +The mock CSV mirrors the schema produced by make_model_ready_yield_africa.py: + - name_loc, lat, lon + - target_yld_ton_ha + - feat_* (continuous soil/climate/terrain features + tabular categorical soil texture) + - aux_* (derived classification columns, used for caption generation) + - metadata: country, year, location_accuracy +""" + +import hydra +import pandas as pd +import pytest +import torch +from hydra import compose, initialize +from hydra.core.global_hydra import GlobalHydra + +from src.data.base_datamodule import BaseDataModule +from src.data.yield_africa_dataset import YieldAfricaDataset + +# --------------------------------------------------------------------------- +# Representative column sets that match the real model_ready_yield-africa.csv +# --------------------------------------------------------------------------- + +MOCK_FEAT_COLS = { + # continuous soil features + "feat_c_0_20": [1.2, 0.9, 1.5, 1.1, 0.8, 1.4, 1.6, 1.0, 1.3, 1.1], + "feat_n_0_20": [0.12, 0.09, 0.15, 0.11, 0.08, 0.14, 0.16, 0.10, 0.13, 0.11], + "feat_ph_0_20": [6.1, 5.8, 6.5, 6.0, 5.5, 6.3, 6.8, 5.9, 6.2, 6.4], + # continuous climate features + "feat_map": [820, 750, 910, 860, 700, 880, 930, 770, 815, 875], + "feat_mat": [22.1, 21.5, 23.0, 22.5, 21.0, 22.8, 23.3, 21.9, 22.2, 22.7], + # continuous terrain feature + "feat_dem": [450, 380, 510, 470, 360, 490, 530, 400, 460, 500], + # tabular categorical: soil texture class (real columns, not derived) + "feat_tx_0_20_cl": [2, 3, 1, 2, 4, 1, 3, 2, 1, 3], + 
"feat_tx_20_50_cl": [2, 2, 1, 3, 4, 1, 2, 3, 1, 2], +} + +MOCK_AUX_COLS = { + # derived classification columns (paired with the continuous feat_* above) + "aux_yld_ton_ha_cl": [1, 0, 2, 1, 0, 2, 2, 0, 1, 1], + "aux_c_0_20_cl": [1, 0, 2, 1, 0, 2, 2, 0, 1, 1], + "aux_ph_0_20_cl": [2, 1, 2, 2, 0, 2, 3, 1, 2, 2], + "aux_map_cl": [1, 0, 2, 1, 0, 2, 2, 0, 1, 2], +} + +MOCK_N_ROWS = 10 +MOCK_TABULAR_DIM = len(MOCK_FEAT_COLS) # 8 +MOCK_N_AUX = len(MOCK_AUX_COLS) # 4 + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def yield_africa_csv(tmp_path) -> str: + """Mock CSV with column names matching the real model_ready_yield-africa.csv.""" + data = { + "name_loc": [f"ETH_{i:04d}" for i in range(MOCK_N_ROWS)], + "lat": [5.0 + i * 0.5 for i in range(MOCK_N_ROWS)], + "lon": [30.0 + i * 0.5 for i in range(MOCK_N_ROWS)], + "target_yld_ton_ha": [2.1, 1.8, 3.0, 2.5, 1.2, 2.8, 3.3, 1.9, 2.0, 2.7], + "country": ["ETH"] * MOCK_N_ROWS, + "year": [2019] * MOCK_N_ROWS, + "location_accuracy": [1] * MOCK_N_ROWS, + } + data.update(MOCK_FEAT_COLS) + data.update(MOCK_AUX_COLS) + + mock_dir = tmp_path / "mock" + mock_dir.mkdir(parents=True, exist_ok=True) + pd.DataFrame(data).to_csv(mock_dir / "model_ready_mock.csv", index=False) + return str(tmp_path) + + +@pytest.fixture +def yield_africa_dataset(yield_africa_csv, tmp_path): + """YieldAfricaDataset backed by mock data, features enabled, no aux.""" + return YieldAfricaDataset( + data_dir=yield_africa_csv, + cache_dir=str(tmp_path / "cache"), + modalities={"coords": {}}, + use_target_data=True, + use_aux_data="none", + seed=42, + mock=True, + use_features=True, + ) + + +@pytest.fixture +def yield_africa_dataset_with_aux(yield_africa_csv, tmp_path): + """YieldAfricaDataset backed by mock data, features enabled, aux enabled.""" + return YieldAfricaDataset( + data_dir=yield_africa_csv, + 
cache_dir=str(tmp_path / "cache"), + modalities={"coords": {}}, + use_target_data=True, + use_aux_data="all", + seed=42, + mock=True, + use_features=True, + ) + + +@pytest.fixture +def yield_africa_datamodule(yield_africa_csv, tmp_path): + """BaseDataModule wrapping YieldAfricaDataset with a random split.""" + dataset = YieldAfricaDataset( + data_dir=yield_africa_csv, + cache_dir=str(tmp_path / "cache"), + modalities={"coords": {}}, + use_target_data=True, + use_aux_data="none", + seed=42, + mock=True, + use_features=True, + ) + return BaseDataModule( + dataset=dataset, + batch_size=4, + train_val_test_split=(7, 2, 1), + num_workers=0, + pin_memory=False, + split_mode="random", + save_split=False, + seed=42, + ) + + +# --------------------------------------------------------------------------- +# Dataset tests +# --------------------------------------------------------------------------- + + +def test_yield_africa_dataset_length(yield_africa_dataset): + assert len(yield_africa_dataset) == MOCK_N_ROWS + + +def test_yield_africa_dataset_sample_keys(yield_africa_dataset): + sample = yield_africa_dataset[0] + assert "eo" in sample + assert "coords" in sample["eo"] + assert "tabular" in sample["eo"] + assert "target" in sample + + +def test_yield_africa_dataset_sample_shapes(yield_africa_dataset): + sample = yield_africa_dataset[0] + assert sample["eo"]["coords"].shape == (2,) + assert sample["eo"]["tabular"].shape == (MOCK_TABULAR_DIM,) + assert sample["target"].shape == (1,) + + +def test_yield_africa_dataset_sample_dtypes(yield_africa_dataset): + sample = yield_africa_dataset[0] + assert sample["eo"]["coords"].dtype == torch.float32 + assert sample["eo"]["tabular"].dtype == torch.float32 + assert sample["target"].dtype == torch.float32 + + +def test_yield_africa_dataset_target_name(yield_africa_dataset): + assert yield_africa_dataset.target_names == ["target_yld_ton_ha"] + + +def test_yield_africa_dataset_attributes(yield_africa_dataset): + assert 
yield_africa_dataset.num_classes == 1 + assert yield_africa_dataset.tabular_dim == MOCK_TABULAR_DIM + assert set(yield_africa_dataset.feat_names) == set(MOCK_FEAT_COLS.keys()) + + +def test_yield_africa_dataset_feat_prefix(yield_africa_dataset): + """All tabular features must carry the feat_ prefix.""" + for name in yield_africa_dataset.feat_names: + assert name.startswith("feat_"), f"Unexpected feature name: {name}" + + +def test_yield_africa_dataset_coords_values(yield_africa_dataset): + """Coordinates returned must match the CSV values.""" + sample = yield_africa_dataset[0] + coords = sample["eo"]["coords"] + assert coords[0].item() == pytest.approx(5.0) # lat of row 0 + assert coords[1].item() == pytest.approx(30.0) # lon of row 0 + + +def test_yield_africa_dataset_target_values(yield_africa_dataset): + """Target values returned must match the CSV values.""" + expected = [2.1, 1.8, 3.0, 2.5, 1.2, 2.8, 3.3, 1.9, 2.0, 2.7] + for idx, exp in enumerate(expected): + sample = yield_africa_dataset[idx] + assert sample["target"][0].item() == pytest.approx(exp, rel=1e-5) + + +def test_yield_africa_dataset_no_features(yield_africa_csv, tmp_path): + """With use_features=False, tabular is absent and tabular_dim is None.""" + ds = YieldAfricaDataset( + data_dir=yield_africa_csv, + cache_dir=str(tmp_path / "cache"), + modalities={"coords": {}}, + use_target_data=True, + use_aux_data="none", + seed=0, + mock=True, + use_features=False, + ) + sample = ds[0] + assert "tabular" not in sample["eo"] + assert ds.tabular_dim is None + + +def test_yield_africa_dataset_aux_keys(yield_africa_dataset_with_aux): + """When aux is enabled, sample must contain an 'aux' dict.""" + sample = yield_africa_dataset_with_aux[0] + assert "aux" in sample + assert "aux" in sample["aux"] + + +def test_yield_africa_dataset_aux_columns(yield_africa_dataset_with_aux): + """Aux columns picked up must match the aux_* columns in the mock CSV.""" + resolved_aux = 
yield_africa_dataset_with_aux.use_aux_data["aux"] + assert set(resolved_aux) == set(MOCK_AUX_COLS.keys()) + + +def test_yield_africa_dataset_aux_shape(yield_africa_dataset_with_aux): + """Aux tensor shape must equal the number of aux_* columns.""" + sample = yield_africa_dataset_with_aux[0] + assert sample["aux"]["aux"].shape == (MOCK_N_AUX,) + + +# --------------------------------------------------------------------------- +# Datamodule tests +# --------------------------------------------------------------------------- + + +def test_yield_africa_datamodule_split_sizes(yield_africa_datamodule): + dm = yield_africa_datamodule + assert len(dm.data_train) == 7 + assert len(dm.data_val) == 2 + assert len(dm.data_test) == 1 + + +def test_yield_africa_datamodule_train_loader(yield_africa_datamodule): + dm = yield_africa_datamodule + dm.setup() + batch = next(iter(dm.train_dataloader())) + assert "eo" in batch + assert "coords" in batch["eo"] + assert "tabular" in batch["eo"] + assert batch["eo"]["coords"].shape == (4, 2) + assert batch["eo"]["tabular"].shape == (4, MOCK_TABULAR_DIM) + assert batch["target"].shape == (4, 1) + + +def test_yield_africa_datamodule_split_deterministic(yield_africa_csv, tmp_path): + def make_dm(): + dataset = YieldAfricaDataset( + data_dir=yield_africa_csv, + cache_dir=str(tmp_path / "cache"), + modalities={"coords": {}}, + use_target_data=True, + use_aux_data="none", + seed=42, + mock=True, + ) + return BaseDataModule( + dataset=dataset, + batch_size=4, + train_val_test_split=(7, 2, 1), + num_workers=0, + split_mode="random", + save_split=False, + seed=42, + ) + + dm1, dm2 = make_dm(), make_dm() + assert dm1.data_train.indices == dm2.data_train.indices + assert dm1.data_val.indices == dm2.data_val.indices + + +# --------------------------------------------------------------------------- +# Config tests +# --------------------------------------------------------------------------- + + +def test_yield_africa_config_loads(): + 
GlobalHydra.instance().clear() + with initialize(version_base="1.3", config_path="../configs"): + cfg = compose( + config_name="train.yaml", + overrides=["experiment=yield_africa_tabular_reg", "hydra.job.chdir=false"], + ) + assert cfg.data._target_ == "src.data.base_datamodule.BaseDataModule" + assert cfg.data.dataset._target_ == "src.data.yield_africa_dataset.YieldAfricaDataset" + assert cfg.model._target_ == "src.models.predictive_model.PredictiveModel" + GlobalHydra.instance().clear() + + +def test_yield_africa_model_instantiates(): + GlobalHydra.instance().clear() + with initialize(version_base="1.3", config_path="../configs"): + cfg = compose( + config_name="train.yaml", + overrides=["experiment=yield_africa_tabular_reg", "hydra.job.chdir=false"], + ) + model = hydra.utils.instantiate(cfg.model) + assert model is not None + GlobalHydra.instance().clear() \ No newline at end of file From a03d2285836665aeaa90ff35f16a0d9f2a04c310 Mon Sep 17 00:00:00 2001 From: Rob Knapen Date: Sat, 7 Mar 2026 17:26:06 +0100 Subject: [PATCH 20/60] The eo_encoders to geo_encoders renaming and refactoring --- configs/model/geoclip_alignment.yaml | 4 +- configs/model/geoclip_llm2clip_alignment.yaml | 4 +- configs/model/heat_fusion_reg.yaml | 6 +-- configs/model/heat_geoclip_reg.yaml | 4 +- configs/model/heat_tabular_reg.yaml | 8 ++-- configs/model/predictive_cnn_s2.yaml | 8 ++-- configs/model/predictive_geoclip.yaml | 4 +- configs/model/yield_fusion_reg.yaml | 6 +-- configs/model/yield_geoclip_reg.yaml | 4 +- configs/model/yield_tabular_reg.yaml | 6 +-- ...06-TvdP-inference-language-alignment.ipynb | 4 +- src/models/base_model.py | 2 +- .../{eo_encoders => geo_encoders}/__init__.py | 0 .../average_encoder.py | 28 ++++++------- .../base_geo_encoder.py} | 14 +++---- .../cnn_encoder.py | 30 ++++++------- .../{eo_encoders => geo_encoders}/geoclip.py | 20 ++++----- .../multimodal_encoder.py | 6 +-- src/models/predictive_model.py | 24 +++++------ src/models/text_alignment_model.py | 32 
+++++++------- tests/test_eo_encoders.py | 42 +++++++++---------- tests/test_pred_heads.py | 2 +- 22 files changed, 129 insertions(+), 129 deletions(-) rename src/models/components/{eo_encoders => geo_encoders}/__init__.py (100%) rename src/models/components/{eo_encoders => geo_encoders}/average_encoder.py (68%) rename src/models/components/{eo_encoders/base_eo_encoder.py => geo_encoders/base_geo_encoder.py} (65%) rename src/models/components/{eo_encoders => geo_encoders}/cnn_encoder.py (86%) rename src/models/components/{eo_encoders => geo_encoders}/geoclip.py (50%) rename src/models/components/{eo_encoders => geo_encoders}/multimodal_encoder.py (93%) diff --git a/configs/model/geoclip_alignment.yaml b/configs/model/geoclip_alignment.yaml index 0753e79..52d6485 100644 --- a/configs/model/geoclip_alignment.yaml +++ b/configs/model/geoclip_alignment.yaml @@ -1,7 +1,7 @@ _target_: src.models.text_alignment_model.TextAlignmentModel -eo_encoder: - _target_: src.models.components.eo_encoders.geoclip.GeoClipCoordinateEncoder +geo_encoder: + _target_: src.models.components.geo_encoders.geoclip.GeoClipCoordinateEncoder text_encoder: _target_: src.models.components.text_encoders.clip_text_encoder.ClipTextEncoder diff --git a/configs/model/geoclip_llm2clip_alignment.yaml b/configs/model/geoclip_llm2clip_alignment.yaml index 76f8878..8ff882c 100644 --- a/configs/model/geoclip_llm2clip_alignment.yaml +++ b/configs/model/geoclip_llm2clip_alignment.yaml @@ -1,7 +1,7 @@ _target_: src.models.text_alignment_model.TextAlignmentModel -eo_encoder: - _target_: src.models.components.eo_encoders.geoclip.GeoClipCoordinateEncoder +geo_encoder: + _target_: src.models.components.geo_encoders.geoclip.GeoClipCoordinateEncoder text_encoder: _target_: src.models.components.text_encoders.llm2clip_text_encoder.LLM2CLIPTextEncoder diff --git a/configs/model/heat_fusion_reg.yaml b/configs/model/heat_fusion_reg.yaml index e507dd1..ff5a69b 100644 --- a/configs/model/heat_fusion_reg.yaml +++ 
b/configs/model/heat_fusion_reg.yaml @@ -8,8 +8,8 @@ _target_: src.models.predictive_model.PredictiveModel -eo_encoder: - _target_: src.models.components.eo_encoders.multimodal_encoder.MultiModalEncoder +geo_encoder: + _target_: src.models.components.geo_encoders.multimodal_encoder.MultiModalEncoder use_coords: true use_tabular: true # tab_embed_dim: 64 @@ -20,7 +20,7 @@ prediction_head: hidden_dim: 256 # GeoClip frozen; tabular projection + head are trained. -trainable_modules: [eo_encoder, prediction_head] +trainable_modules: [geo_encoder, prediction_head] metrics: ${metrics} diff --git a/configs/model/heat_geoclip_reg.yaml b/configs/model/heat_geoclip_reg.yaml index a33b976..0e8f11a 100644 --- a/configs/model/heat_geoclip_reg.yaml +++ b/configs/model/heat_geoclip_reg.yaml @@ -8,8 +8,8 @@ _target_: src.models.predictive_model.PredictiveModel -eo_encoder: - _target_: src.models.components.eo_encoders.multimodal_encoder.MultiModalEncoder +geo_encoder: + _target_: src.models.components.geo_encoders.multimodal_encoder.MultiModalEncoder use_coords: true use_tabular: false diff --git a/configs/model/heat_tabular_reg.yaml b/configs/model/heat_tabular_reg.yaml index affadab..8ca1184 100644 --- a/configs/model/heat_tabular_reg.yaml +++ b/configs/model/heat_tabular_reg.yaml @@ -10,14 +10,14 @@ # 1. HeatGuatemalaDataset.tabular_dim reads len(feat_names) from the CSV. # 2. BaseDataModule.tabular_dim delegates to the train dataset. # 3. 
PredictiveRegressionModel.setup() calls -# self.eo_encoder.build_tabular_branch(self.trainer.datamodule.tabular_dim) +# self.geo_encoder.build_tabular_branch(self.trainer.datamodule.tabular_dim) _target_: src.models.predictive_model.PredictiveModel metrics: ${metrics} -eo_encoder: - _target_: src.models.components.eo_encoders.multimodal_encoder.MultiModalEncoder +geo_encoder: + _target_: src.models.components.geo_encoders.multimodal_encoder.MultiModalEncoder use_coords: false use_tabular: true tab_embed_dim: 64 @@ -28,7 +28,7 @@ prediction_head: hidden_dim: 256 # Both encoder and head have trainable parameters. -trainable_modules: [eo_encoder, prediction_head] +trainable_modules: [geo_encoder, prediction_head] optimizer: _target_: torch.optim.Adam diff --git a/configs/model/predictive_cnn_s2.yaml b/configs/model/predictive_cnn_s2.yaml index 628ba66..52c46d9 100644 --- a/configs/model/predictive_cnn_s2.yaml +++ b/configs/model/predictive_cnn_s2.yaml @@ -1,15 +1,15 @@ _target_: src.models.predictive_model.PredictiveModel -eo_encoder: - _target_: src.models.components.eo_encoders.cnn_encoder.CNNEncoder +geo_encoder: + _target_: src.models.components.geo_encoders.cnn_encoder.CNNEncoder resnet_version: 18 freezing_strategy: none - eo_data_name: s2 + geo_data_name: s2 prediction_head: _target_: src.models.components.pred_heads.mlp_pred_head.MLPPredictionHead -trainable_modules: [eo_encoder, prediction_head] +trainable_modules: [geo_encoder, prediction_head] metrics: ${metrics} diff --git a/configs/model/predictive_geoclip.yaml b/configs/model/predictive_geoclip.yaml index ca7390f..7d9c0c5 100644 --- a/configs/model/predictive_geoclip.yaml +++ b/configs/model/predictive_geoclip.yaml @@ -1,7 +1,7 @@ _target_: src.models.predictive_model.PredictiveModel -eo_encoder: - _target_: src.models.components.eo_encoders.geoclip.GeoClipCoordinateEncoder +geo_encoder: + _target_: src.models.components.geo_encoders.geoclip.GeoClipCoordinateEncoder prediction_head: _target_: 
src.models.components.pred_heads.mlp_pred_head.MLPPredictionHead diff --git a/configs/model/yield_fusion_reg.yaml b/configs/model/yield_fusion_reg.yaml index 3f64972..e4e75f6 100644 --- a/configs/model/yield_fusion_reg.yaml +++ b/configs/model/yield_fusion_reg.yaml @@ -1,7 +1,7 @@ _target_: src.models.predictive_model.PredictiveModel -eo_encoder: - _target_: src.models.components.eo_encoders.multimodal_encoder.MultiModalEncoder +geo_encoder: + _target_: src.models.components.geo_encoders.multimodal_encoder.MultiModalEncoder use_coords: true use_tabular: true # tab_embed_dim: 64 @@ -12,7 +12,7 @@ prediction_head: hidden_dim: 256 # GeoClip frozen; tabular projection + head are trained. -trainable_modules: [eo_encoder, prediction_head] +trainable_modules: [geo_encoder, prediction_head] metrics: ${metrics} diff --git a/configs/model/yield_geoclip_reg.yaml b/configs/model/yield_geoclip_reg.yaml index 7bee09b..f333798 100644 --- a/configs/model/yield_geoclip_reg.yaml +++ b/configs/model/yield_geoclip_reg.yaml @@ -1,7 +1,7 @@ _target_: src.models.predictive_model.PredictiveModel -eo_encoder: - _target_: src.models.components.eo_encoders.multimodal_encoder.MultiModalEncoder +geo_encoder: + _target_: src.models.components.geo_encoders.multimodal_encoder.MultiModalEncoder use_coords: true use_tabular: false diff --git a/configs/model/yield_tabular_reg.yaml b/configs/model/yield_tabular_reg.yaml index 22251bf..af96374 100644 --- a/configs/model/yield_tabular_reg.yaml +++ b/configs/model/yield_tabular_reg.yaml @@ -1,7 +1,7 @@ _target_: src.models.predictive_model.PredictiveModel -eo_encoder: - _target_: src.models.components.eo_encoders.multimodal_encoder.MultiModalEncoder +geo_encoder: + _target_: src.models.components.geo_encoders.multimodal_encoder.MultiModalEncoder use_coords: false use_tabular: true tab_embed_dim: 128 @@ -12,7 +12,7 @@ prediction_head: hidden_dim: 256 # Both encoder and head have trainable parameters. 
-trainable_modules: [eo_encoder, prediction_head] +trainable_modules: [geo_encoder, prediction_head] metrics: ${metrics} diff --git a/notebooks/06-TvdP-inference-language-alignment.ipynb b/notebooks/06-TvdP-inference-language-alignment.ipynb index 4c4ab73..227578b 100644 --- a/notebooks/06-TvdP-inference-language-alignment.ipynb +++ b/notebooks/06-TvdP-inference-language-alignment.ipynb @@ -119,8 +119,8 @@ "metadata": {}, "outputs": [], "source": [ - "from src.models.components.eo_encoders.cnn_encoder import CNNEncoder\n", - "from src.models.components.eo_encoders.geoclip import GeoClipCoordinateEncoder\n", + "from src.models.components.geo_encoders.cnn_encoder import CNNEncoder\n", + "from src.models.components.geo_encoders.geoclip import GeoClipCoordinateEncoder\n", "from src.models.components.loss_fns.bce_loss import BCELoss\n", "from src.models.components.loss_fns.clip_loss import ClipLoss\n", "from src.models.components.pred_heads.mlp_pred_head import MLPPredictionHead\n", diff --git a/src/models/base_model.py b/src/models/base_model.py index e3a2a16..3e931b7 100644 --- a/src/models/base_model.py +++ b/src/models/base_model.py @@ -27,7 +27,7 @@ def __init__( """ super().__init__() self.save_hyperparameters( - ignore=["loss_fn", "eo_encoder", "prediction_head", "text_encoder", "metrics"] + ignore=["loss_fn", "geo_encoder", "prediction_head", "text_encoder", "metrics"] ) self.trainable_modules = trainable_modules diff --git a/src/models/components/eo_encoders/__init__.py b/src/models/components/geo_encoders/__init__.py similarity index 100% rename from src/models/components/eo_encoders/__init__.py rename to src/models/components/geo_encoders/__init__.py diff --git a/src/models/components/eo_encoders/average_encoder.py b/src/models/components/geo_encoders/average_encoder.py similarity index 68% rename from src/models/components/eo_encoders/average_encoder.py rename to src/models/components/geo_encoders/average_encoder.py index ccec36e..a312153 100644 --- 
a/src/models/components/eo_encoders/average_encoder.py +++ b/src/models/components/geo_encoders/average_encoder.py @@ -4,36 +4,36 @@ import torch.nn.functional as F from torch import nn -from src.models.components.eo_encoders.base_eo_encoder import BaseEOEncoder +from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder -class AverageEncoder(BaseEOEncoder): +class AverageEncoder(BaseGeoEncoder): def __init__( self, output_dim: int | None = None, - eo_data_name="aef", + geo_data_name="aef", ) -> None: super().__init__() dict_n_bands_default = {"s2": 4, "aef": 64, "tessera": 128} - self.allowed_eo_data_names: list[str] = list(dict_n_bands_default.keys()) + self.allowed_geo_data_names: list[str] = list(dict_n_bands_default.keys()) assert ( - eo_data_name in dict_n_bands_default - ), f"eo_data_name must be one of {self.allowed_eo_data_names}, got {eo_data_name}" - self.eo_data_name = eo_data_name + geo_data_name in dict_n_bands_default + ), f"geo_data_name must be one of {self.allowed_geo_data_names}, got {geo_data_name}" + self.geo_data_name = geo_data_name - if output_dim is None or output_dim == dict_n_bands_default[eo_data_name]: - self.output_dim = dict_n_bands_default[eo_data_name] + if output_dim is None or output_dim == dict_n_bands_default[geo_data_name]: + self.output_dim = dict_n_bands_default[geo_data_name] self.extra_projector = None - self.eo_encoder = self._average + self.geo_encoder = self._average else: assert ( type(output_dim) is int and output_dim > 0 ), f"output_dim must be positive int, got {output_dim}" self.output_dim = output_dim - self.extra_projector = nn.Linear(dict_n_bands_default[eo_data_name], output_dim) - self.eo_encoder = self._average_and_project + self.extra_projector = nn.Linear(dict_n_bands_default[geo_data_name], output_dim) + self.geo_encoder = self._average_and_project def _average(self, x: torch.Tensor) -> torch.Tensor: """Averages the input tensor over spatial dimensions. 
@@ -59,11 +59,11 @@ def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: dtype = self.dtype if eo_data.dtype != dtype: eo_data = eo_data.to(dtype) - feats = self.eo_encoder(eo_data[self.eo_data_name]) + feats = self.geo_encoder(eo_data[self.geo_data_name]) # n_nans = torch.sum(torch.isnan(feats)).item() # assert ( # n_nans == 0 - # ), f"AverageEncoder output contains {n_nans}/{feats.numel()} NaNs PRIOR to normalization with data min {eo_data[self.eo_data_name].min()} and max {eo_data[self.eo_data_name].max()}." + # ), f"AverageEncoder output contains {n_nans}/{feats.numel()} NaNs PRIOR to normalization with data min {eo_data[self.geo_data_name].min()} and max {eo_data[self.geo_data_name].max()}." return feats.to(dtype) diff --git a/src/models/components/eo_encoders/base_eo_encoder.py b/src/models/components/geo_encoders/base_geo_encoder.py similarity index 65% rename from src/models/components/eo_encoders/base_eo_encoder.py rename to src/models/components/geo_encoders/base_geo_encoder.py index 643250c..162fcc9 100644 --- a/src/models/components/eo_encoders/base_eo_encoder.py +++ b/src/models/components/geo_encoders/base_geo_encoder.py @@ -5,15 +5,15 @@ from torch import nn -class BaseEOEncoder(nn.Module, ABC): +class BaseGeoEncoder(nn.Module, ABC): def __init__(self) -> None: super().__init__() - self.eo_encoder: nn.Module | None = None + self.geo_encoder: nn.Module | None = None self.output_dim: int | None = None # placeholders - self.allowed_eo_data_names: list[str] | None = None - self.eo_data_name: str | None = None + self.allowed_geo_data_names: list[str] | None = None + self.geo_data_name: str | None = None @abstractmethod def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: @@ -23,16 +23,16 @@ def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: def device(self) -> torch.device: devices = {p.device for p in self.parameters()} if len(devices) != 1: - raise RuntimeError("EO encoder is on multiple devices") + raise 
RuntimeError("GEO encoder is on multiple devices") return devices.pop() @property def dtype(self) -> torch.dtype: dtypes = {p.dtype for p in self.parameters()} if len(dtypes) != 1: - raise RuntimeError("EO encoder has multiple dtypes") + raise RuntimeError("GEO encoder has multiple dtypes") return dtypes.pop() if __name__ == "__main__": - _ = BaseEOEncoder(None) + _ = BaseGeoEncoder(None) diff --git a/src/models/components/eo_encoders/cnn_encoder.py b/src/models/components/geo_encoders/cnn_encoder.py similarity index 86% rename from src/models/components/eo_encoders/cnn_encoder.py rename to src/models/components/geo_encoders/cnn_encoder.py index d939a38..d76a4e3 100644 --- a/src/models/components/eo_encoders/cnn_encoder.py +++ b/src/models/components/geo_encoders/cnn_encoder.py @@ -5,10 +5,10 @@ from torch import nn from torch.nn import functional as F -from src.models.components.eo_encoders.base_eo_encoder import BaseEOEncoder +from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder -class CNNEncoder(BaseEOEncoder): +class CNNEncoder(BaseGeoEncoder): """Convolutional neural network EO encoder. Adapted from PECL. 
:param backbone: backbone model to use (resnet) @@ -25,7 +25,7 @@ def __init__( pretrained_cnn="imagenet", resnet_version=18, freezing_strategy="all", - eo_data_name="s2", + geo_data_name="s2", input_n_bands: int | None = None, output_dim=512, ) -> None: @@ -36,9 +36,9 @@ def __init__( self.resnet_version = resnet_version self.freezing_strategy = freezing_strategy - self.allowed_eo_data_names = ["s2", "aef", "tessera"] - assert eo_data_name in self.allowed_eo_data_names - self.eo_data_name = eo_data_name + self.allowed_geo_data_names = ["s2", "aef", "tessera"] + assert geo_data_name in self.allowed_geo_data_names + self.geo_data_name = geo_data_name self.set_n_input_bands(input_n_bands) assert ( @@ -46,23 +46,23 @@ def __init__( ), f"input_n_bands must be int >=3, got {self.input_n_bands}" self.output_dim = output_dim - self.eo_encoder = self.get_backbone() + self.geo_encoder = self.get_backbone() def set_n_input_bands(self, n_bands: int | None = None) -> None: - """Sets number of input bands based on eo_data_name if n_bands is None. + """Sets number of input bands based on geo_data_name if n_bands is None. 
:param n_bands: number of input bands :return: None """ - if n_bands is None: # infer from eo_data_name - if self.eo_data_name == "s2": + if n_bands is None: # infer from geo_data_name + if self.geo_data_name == "s2": self.input_n_bands = 4 - elif self.eo_data_name == "aef": + elif self.geo_data_name == "aef": self.input_n_bands = 64 - elif self.eo_data_name == "tessera": + elif self.geo_data_name == "tessera": self.input_n_bands = 128 print( - f"[CNNEncoder] Inferred input_n_bands={self.input_n_bands} for eo_data_name='{self.eo_data_name}'" + f"[CNNEncoder] Inferred input_n_bands={self.input_n_bands} for geo_data_name='{self.geo_data_name}'" ) else: self.input_n_bands = n_bands @@ -147,11 +147,11 @@ def forward( dtype = self.dtype if eo_data.dtype != dtype: eo_data = eo_data.to(dtype) - feats = self.eo_encoder(eo_data[self.eo_data_name]) + feats = self.geo_encoder(eo_data[self.geo_data_name]) # n_nans = torch.sum(torch.isnan(feats)).item() # assert ( # n_nans == 0 - # ), f"CNNEncoder output contains {n_nans}/{feats.numel()} NaNs PRIOR to normalization with data min {eo_data[self.eo_data_name].min()} and max {eo_data[self.eo_data_name].max()}." + # ), f"CNNEncoder output contains {n_nans}/{feats.numel()} NaNs PRIOR to normalization with data min {eo_data[self.geo_data_name].min()} and max {eo_data[self.geo_data_name].max()}." 
return feats.to(dtype) diff --git a/src/models/components/eo_encoders/geoclip.py b/src/models/components/geo_encoders/geoclip.py similarity index 50% rename from src/models/components/eo_encoders/geoclip.py rename to src/models/components/geo_encoders/geoclip.py index e9d1834..bd40aa9 100644 --- a/src/models/components/eo_encoders/geoclip.py +++ b/src/models/components/geo_encoders/geoclip.py @@ -4,22 +4,22 @@ from geoclip import LocationEncoder from torch.nn import functional as F -from src.models.components.eo_encoders.base_eo_encoder import BaseEOEncoder +from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder -class GeoClipCoordinateEncoder(BaseEOEncoder): +class GeoClipCoordinateEncoder(BaseGeoEncoder): def __init__( self, - eo_data_name="coords", + geo_data_name="coords", ) -> None: super().__init__() - self.eo_encoder = LocationEncoder() - self.output_dim = self.eo_encoder.LocEnc0.head[0].out_features - self.allowed_eo_data_names = ["coords"] + self.geo_encoder = LocationEncoder() + self.output_dim = self.geo_encoder.LocEnc0.head[0].out_features + self.allowed_geo_data_names = ["coords"] assert ( - eo_data_name in self.allowed_eo_data_names - ), f"eo_data_name must be one of {self.allowed_eo_data_names}, got {eo_data_name}" - self.eo_data_name = eo_data_name + geo_data_name in self.allowed_geo_data_names + ), f"geo_data_name must be one of {self.allowed_geo_data_names}, got {geo_data_name}" + self.geo_data_name = geo_data_name @override def forward( @@ -32,7 +32,7 @@ def forward( dtype = self.dtype if coords.dtype != dtype: coords = coords.to(dtype) - feats = self.eo_encoder(coords) + feats = self.geo_encoder(coords) return feats.to(dtype) diff --git a/src/models/components/eo_encoders/multimodal_encoder.py b/src/models/components/geo_encoders/multimodal_encoder.py similarity index 93% rename from src/models/components/eo_encoders/multimodal_encoder.py rename to src/models/components/geo_encoders/multimodal_encoder.py index 
7dbf61d..dbf1159 100644 --- a/src/models/components/eo_encoders/multimodal_encoder.py +++ b/src/models/components/geo_encoders/multimodal_encoder.py @@ -10,11 +10,11 @@ import torch from torch import nn -from src.models.components.eo_encoders.base_eo_encoder import BaseEOEncoder -from src.models.components.eo_encoders.geoclip import GeoClipCoordinateEncoder +from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder +from src.models.components.geo_encoders.geoclip import GeoClipCoordinateEncoder -class MultiModalEncoder(BaseEOEncoder): +class MultiModalEncoder(BaseGeoEncoder): """ - coords only (use_coords=True, use_tabular=False) - tabular only (use_coords=False, use_tabular=True) diff --git a/src/models/predictive_model.py b/src/models/predictive_model.py index 24b8327..be69a34 100644 --- a/src/models/predictive_model.py +++ b/src/models/predictive_model.py @@ -3,8 +3,8 @@ import torch from src.models.base_model import BaseModel -from src.models.components.eo_encoders.base_eo_encoder import BaseEOEncoder -from src.models.components.eo_encoders.multimodal_encoder import MultiModalEncoder +from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder +from src.models.components.geo_encoders.multimodal_encoder import MultiModalEncoder from src.models.components.loss_fns.base_loss_fn import BaseLossFn from src.models.components.metrics.metrics_wrapper import MetricsWrapper from src.models.components.pred_heads.linear_pred_head import ( @@ -15,7 +15,7 @@ class PredictiveModel(BaseModel): def __init__( self, - eo_encoder: BaseEOEncoder, + geo_encoder: BaseGeoEncoder, prediction_head: BasePredictionHead, trainable_modules: list[str], optimizer: torch.optim.Optimizer, @@ -25,7 +25,7 @@ def __init__( ) -> None: """Implementation of the predictive model with replaceable EO encoder, and prediction head. 
- :param eo_encoder: eo encoder module (replaceable) + :param geo_encoder: geo encoder module (replaceable) :param prediction_head: prediction head module (replaceable) :param trainable_modules: list of modules to train (parts/modules or modules, modules) :param optimizer: optimizer to use for training @@ -38,8 +38,8 @@ def __init__( super().__init__(trainable_modules, optimizer, scheduler, loss_fn, metrics) - # EO encoder configuration - self.eo_encoder = eo_encoder + # Geo encoder configuration + self.geo_encoder = geo_encoder # Prediction head self.prediction_head = prediction_head @@ -58,14 +58,14 @@ def setup_encoders_adapters(self): """Set up encoders and missing adapters/projectors.""" # TODO: move to multi-modal eo encoder if ( - isinstance(self.eo_encoder, MultiModalEncoder) - and self.eo_encoder.use_tabular - and not self.eo_encoder._tabular_ready + isinstance(self.geo_encoder, MultiModalEncoder) + and self.geo_encoder.use_tabular + and not self.geo_encoder._tabular_ready ): - self.eo_encoder.build_tabular_branch(self.tabular_dim) + self.geo_encoder.build_tabular_branch(self.tabular_dim) self.prediction_head.set_dim( - input_dim=self.eo_encoder.output_dim, output_dim=self.num_classes + input_dim=self.geo_encoder.output_dim, output_dim=self.num_classes ) self.prediction_head.configure_nn() if "prediction_head" not in self.trainable_modules: @@ -73,7 +73,7 @@ def setup_encoders_adapters(self): @override def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: - feats = self.eo_encoder(batch) + feats = self.geo_encoder(batch) return self.prediction_head(feats) @override diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index 2d62c38..78819b3 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -5,8 +5,8 @@ import torch.nn.functional as F from src.models.base_model import BaseModel -from src.models.components.eo_encoders.base_eo_encoder import BaseEOEncoder -from 
src.models.components.eo_encoders.multimodal_encoder import MultiModalEncoder +from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder +from src.models.components.geo_encoders.multimodal_encoder import MultiModalEncoder from src.models.components.loss_fns.base_loss_fn import BaseLossFn from src.models.components.metrics.contrastive_validation import ( RetrievalContrastiveValidation, @@ -23,7 +23,7 @@ class TextAlignmentModel(BaseModel): def __init__( self, - eo_encoder: BaseEOEncoder, + geo_encoder: BaseGeoEncoder, text_encoder: BaseTextEncoder, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler, @@ -35,7 +35,7 @@ def __init__( ) -> None: """Implementation of contrastive text-eo modality alignment model. - :param eo_encoder: eo encoder module (replaceable) + :param geo_encoder: geo encoder module (replaceable) :param text_encoder: text encoder module (replaceable) :param optimizer: optimizer to use for training :param scheduler: scheduler to use for training @@ -52,7 +52,7 @@ def __init__( self.log_kwargs = dict(on_step=False, on_epoch=True, prog_bar=True, sync_dist=True) # Encoders configuration - self.eo_encoder = eo_encoder + self.geo_encoder = geo_encoder self.text_encoder = text_encoder # Prediction head @@ -76,29 +76,29 @@ def setup_encoders_adapters(self): """Set up encoders and missing adapters/projectors.""" # TODO: move to multi-modal eo encoder if ( - isinstance(self.eo_encoder, MultiModalEncoder) - and self.eo_encoder.use_tabular - and not self.eo_encoder._tabular_ready + isinstance(self.geo_encoder, MultiModalEncoder) + and self.geo_encoder.use_tabular + and not self.geo_encoder._tabular_ready ): - self.eo_encoder.build_tabular_branch(self.tabular_dim) + self.geo_encoder.build_tabular_branch(self.tabular_dim) # Extra projector for text encoder if eo and text dim not match - if self.eo_encoder.output_dim != self.text_encoder.output_dim: - self.text_encoder.add_projector(projected_dim=self.eo_encoder.output_dim) 
+ if self.geo_encoder.output_dim != self.text_encoder.output_dim: + self.text_encoder.add_projector(projected_dim=self.geo_encoder.output_dim) self.trainable_modules.append("text_encoder.extra_projector") # TODO: if eo==geoclip_img pass on shared mlp if self.prediction_head is not None: self.prediction_head.set_dim( - input_dim=self.eo_encoder.output_dim, output_dim=self.num_classes + input_dim=self.geo_encoder.output_dim, output_dim=self.num_classes ) self.prediction_head.configure_nn() # Unify dtypes - if self.eo_encoder.dtype != self.text_encoder.dtype: - self.eo_encoder = self.eo_encoder.to(self.text_encoder.dtype) - print(f"Eo encoder dtype changed to {self.eo_encoder.dtype}") + if self.geo_encoder.dtype != self.text_encoder.dtype: + self.geo_encoder = self.geo_encoder.to(self.text_encoder.dtype) + print(f"Geo encoder dtype changed to {self.geo_encoder.dtype}") def setup_retrieval_evaluation(self): self.concept_configs = self.trainer.datamodule.concept_configs @@ -125,7 +125,7 @@ def forward( """Model forward logic.""" # Embed modalities - eo_feats = self.eo_encoder(batch) + eo_feats = self.geo_encoder(batch) text_feats = self.text_encoder(batch, mode) return eo_feats, text_feats diff --git a/tests/test_eo_encoders.py b/tests/test_eo_encoders.py index 635919d..70eea47 100644 --- a/tests/test_eo_encoders.py +++ b/tests/test_eo_encoders.py @@ -5,17 +5,17 @@ import pytest import torch -from src.models.components.eo_encoders.average_encoder import AverageEncoder -from src.models.components.eo_encoders.base_eo_encoder import BaseEOEncoder -from src.models.components.eo_encoders.cnn_encoder import CNNEncoder -from src.models.components.eo_encoders.geoclip import GeoClipCoordinateEncoder -from src.models.components.eo_encoders.multimodal_encoder import MultiModalEncoder +from src.models.components.geo_encoders.average_encoder import AverageEncoder +from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder +from 
src.models.components.geo_encoders.cnn_encoder import CNNEncoder +from src.models.components.geo_encoders.geoclip import GeoClipCoordinateEncoder +from src.models.components.geo_encoders.multimodal_encoder import MultiModalEncoder # @pytest.mark.slow -def test_eo_encoder_generic_properties(create_butterfly_dataset): +def test_geo_encoder_generic_properties(create_butterfly_dataset): """This test checks that all EO encoders implement the basic properties and methods.""" - dict_eo_encoders = { + dict_geo_encoders = { "geoclip_coords": GeoClipCoordinateEncoder, "cnn": CNNEncoder, "average": AverageEncoder, @@ -24,28 +24,28 @@ def test_eo_encoder_generic_properties(create_butterfly_dataset): ds, dm = create_butterfly_dataset batch = next(iter(dm.train_dataloader())) - for eo_encoder_name, eo_encoder_class in dict_eo_encoders.items(): - eo_encoder = eo_encoder_class() + for geo_encoder_name, geo_encoder_class in dict_geo_encoders.items(): + geo_encoder = geo_encoder_class() assert hasattr( - eo_encoder, "eo_encoder" - ), f"'eo_encoder' attribute missing in {eo_encoder_class.__name__}." + geo_encoder, "geo_encoder" + ), f"'geo_encoder' attribute missing in {geo_encoder_class.__name__}." assert hasattr( - eo_encoder, "output_dim" - ), f"'output_dim' attribute missing in {eo_encoder_class.__name__}." + geo_encoder, "output_dim" + ), f"'output_dim' attribute missing in {geo_encoder_class.__name__}." assert hasattr( - eo_encoder, "forward" - ), f"'forward' method missing in {eo_encoder_class.__name__}." + geo_encoder, "forward" + ), f"'forward' method missing in {geo_encoder_class.__name__}." assert callable( - getattr(eo_encoder, "forward") - ), f"'forward' is not callable in {eo_encoder_class.__name__}." + getattr(geo_encoder, "forward") + ), f"'forward' is not callable in {geo_encoder_class.__name__}." - if eo_encoder_name == "geoclip_coords": + if geo_encoder_name == "geoclip_coords": # TODO: try more EO encoders when (mock) test data also includes images. 
- feats = eo_encoder.forward(batch) + feats = geo_encoder.forward(batch) assert isinstance( feats, torch.Tensor - ), f"'forward' method of {eo_encoder_class.__name__} does not return a torch.Tensor." + ), f"'forward' method of {geo_encoder_class.__name__} does not return a torch.Tensor." assert ( feats.shape[0] == dm.batch_size_per_device - ), f"Output batch size mismatch in {eo_encoder_class.__name__}." + ), f"Output batch size mismatch in {geo_encoder_class.__name__}." diff --git a/tests/test_pred_heads.py b/tests/test_pred_heads.py index 09f8a9a..9bd2c62 100644 --- a/tests/test_pred_heads.py +++ b/tests/test_pred_heads.py @@ -5,7 +5,7 @@ import pytest import torch -from src.models.components.eo_encoders.geoclip import GeoClipCoordinateEncoder +from src.models.components.geo_encoders.geoclip import GeoClipCoordinateEncoder from src.models.components.pred_heads.base_pred_head import BasePredictionHead from src.models.components.pred_heads.linear_pred_head import LinearPredictionHead from src.models.components.pred_heads.mlp_pred_head import MLPPredictionHead From 1a20366d179a37df9f9fd8ad3b4dfe754256455c Mon Sep 17 00:00:00 2001 From: Rob Knapen Date: Sat, 7 Mar 2026 17:30:08 +0100 Subject: [PATCH 21/60] Missed renames --- tests/{test_eo_encoders.py => test_geo_encoders.py} | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) rename tests/{test_eo_encoders.py => test_geo_encoders.py} (91%) diff --git a/tests/test_eo_encoders.py b/tests/test_geo_encoders.py similarity index 91% rename from tests/test_eo_encoders.py rename to tests/test_geo_encoders.py index 70eea47..9a54732 100644 --- a/tests/test_eo_encoders.py +++ b/tests/test_geo_encoders.py @@ -14,7 +14,7 @@ # @pytest.mark.slow def test_geo_encoder_generic_properties(create_butterfly_dataset): - """This test checks that all EO encoders implement the basic properties and methods.""" + """This test checks that all GEO encoders implement the basic properties and methods.""" dict_geo_encoders = { 
"geoclip_coords": GeoClipCoordinateEncoder, "cnn": CNNEncoder, @@ -41,7 +41,7 @@ def test_geo_encoder_generic_properties(create_butterfly_dataset): ), f"'forward' is not callable in {geo_encoder_class.__name__}." if geo_encoder_name == "geoclip_coords": - # TODO: try more EO encoders when (mock) test data also includes images. + # TODO: try more GEO encoders when (mock) test data also includes images. feats = geo_encoder.forward(batch) assert isinstance( feats, torch.Tensor From 9c205b0611ae0cc511242e633e69524274877c86 Mon Sep 17 00:00:00 2001 From: Rob Knapen Date: Sun, 8 Mar 2026 20:38:25 +0100 Subject: [PATCH 22/60] Crop Yield use case: tessera embeddings download script --- .gitignore | 1 + src/data_preprocessing/tessera_embeds.py | 34 ++- .../yield_africa_tessera_preprocess.py | 285 ++++++++++++++++++ 3 files changed, 313 insertions(+), 7 deletions(-) create mode 100644 src/data_preprocessing/yield_africa_tessera_preprocess.py diff --git a/.gitignore b/.gitignore index 0bec17c..2dfb5d6 100644 --- a/.gitignore +++ b/.gitignore @@ -228,3 +228,4 @@ notebooks/01-TvdP-tmp.ipynb */source/* *.tif # for now ..env.swp +/data/yield_africa/ diff --git a/src/data_preprocessing/tessera_embeds.py b/src/data_preprocessing/tessera_embeds.py index b4e6852..eff5e33 100644 --- a/src/data_preprocessing/tessera_embeds.py +++ b/src/data_preprocessing/tessera_embeds.py @@ -1,7 +1,11 @@ import math import os +import threading import numpy as np + +# Serialises concurrent reads/writes to the per-directory meta.csv log file. +_meta_csv_lock = threading.Lock() import pandas as pd import rasterio from geotessera import GeoTessera @@ -122,6 +126,12 @@ def get_tessera_embeds( if reproject_memfile: memfiles.append(reproject_memfile) + if not tiles: + print(f"No TESSERA tiles found for {name_loc} at ({lon:.4f}, {lat:.4f}) year={year}. 
Skipping.") + for mf in memfiles: + mf.close() + return + mosaic, mosaic_transform = merge(tiles) mosaic = mosaic.transpose(1, 2, 0) @@ -134,11 +144,18 @@ def get_tessera_embeds( col, row = crs_to_pixel_coords(lon_utm, lat_utm, mosaic_transform) half = tile_size // 2 row_min = row - half - row_max = row + half + row_max = row + tile_size - half # tile_size - half ensures correct size for odd tile_size col_min = col - half - col_max = col + half + col_max = col + tile_size - half crop = mosaic[row_min:row_max, col_min:col_max, :] + if crop.shape[0] != tile_size or crop.shape[1] != tile_size: + print( + f"Unexpected crop shape {crop.shape} for {name_loc} " + f"(expected {tile_size}x{tile_size}). Skipping." + ) + return + # Save array os.makedirs(save_dir, exist_ok=True) np.save(embed_tile_name, crop) @@ -151,11 +168,14 @@ def get_tessera_embeds( meta_file = f"{save_dir}/meta.csv" - if os.path.exists(meta_file): - meta_df = pd.concat([meta_df, pd.read_csv(meta_file)], ignore_index=True) - - meta_df.to_csv(meta_file, index=False) - print(f"Meta data logged to {meta_file}") + with _meta_csv_lock: + try: + if os.path.exists(meta_file): + meta_df = pd.concat([meta_df, pd.read_csv(meta_file)], ignore_index=True) + meta_df.to_csv(meta_file, index=False) + print(f"Meta data logged to {meta_file}") + except Exception as e: + print(f"Warning: could not update meta.csv ({e}). Tile was saved successfully.") def tessera_from_df( diff --git a/src/data_preprocessing/yield_africa_tessera_preprocess.py b/src/data_preprocessing/yield_africa_tessera_preprocess.py new file mode 100644 index 0000000..3fcfbfa --- /dev/null +++ b/src/data_preprocessing/yield_africa_tessera_preprocess.py @@ -0,0 +1,285 @@ +"""Fetch and cache TESSERA embedding tiles for the yield_africa dataset. 
+ +Location: src/data_preprocessing/yield_africa_tessera_preprocess.py + +Tiles are saved as NumPy arrays to: + {data_dir}/yield_africa/eo/tessera/tessera_{name_loc}.npy + +This matches the path built by BaseDataset.add_modality_paths_to_df() and +loaded by BaseDataset.setup_tessera() at training time. + +Unlike tessera_from_df() (which takes a single fixed year), this script +uses each record's own `year` column so that per-record inter-annual +phenology is captured — the key signal missing from the static tabular +features. + +The script is resumable: get_tessera_embeds() skips files that already +exist, so interrupted runs can be continued safely. + +Usage +----- + # All records + python src/data_preprocessing/yield_africa_tessera_preprocess.py \\ + --data_dir data/ + + # Single country, useful for incremental fetching + python src/data_preprocessing/yield_africa_tessera_preprocess.py \\ + --data_dir data/ --countries KEN RWA + + # Smaller tile size (faster, less context) + python src/data_preprocessing/yield_africa_tessera_preprocess.py \\ + --data_dir data/ --tile_size 5 +""" + +import argparse +import logging +import os +import socket +import sys +import threading +from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait +from pathlib import Path + +# Ensure the project root is on sys.path when the script is run directly. +sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +import pandas as pd +from geotessera import GeoTessera + +from src.data_preprocessing.tessera_embeds import get_tessera_embeds + +log = logging.getLogger(__name__) + +DATASET_NAME = "yield_africa" +MODEL_READY_CSV = f"model_ready_{DATASET_NAME}.csv" + +# Tile size in pixels. A small tile (e.g. 9) captures local context around +# each plot point without pulling in large surrounding areas. Consistent +# with the typical smallholder farm size in the region. 
+DEFAULT_TILE_SIZE = 9 + + +def fetch_tessera_tiles( + data_dir: str, + tile_size: int = DEFAULT_TILE_SIZE, + countries: list[str] | None = None, + years: list[int] | None = None, + cache_dir: str | None = None, + embeddings_dir: str | None = None, + workers: int = 2, +) -> None: + """Fetch TESSERA tiles for every record in the yield_africa CSV. + + :param data_dir: root data directory (same as ``paths.data_dir`` in configs) + :param tile_size: spatial extent of each tile in pixels + :param countries: optional list of country codes to restrict fetching + :param years: optional list of years to restrict fetching + :param cache_dir: directory for GeoTessera's internal registry cache; + defaults to ``{data_dir}/cache/tessera`` + :param embeddings_dir: directory where GeoTessera stores the raw downloaded + embedding tiles (``global_0.1_degree_representation/`` etc.). Defaults + to the current working directory when not set, which can silently fill + the project folder with tens of GB of data. Point this at an external + drive when disk space is limited. + :param workers: number of parallel download threads. Each thread keeps its + own GeoTessera instance to avoid shared state. Default: 2 (external + drive I/O is usually the bottleneck; more workers add contention). 
+ """ + dataset_dir = Path(data_dir) / DATASET_NAME + csv_path = dataset_dir / MODEL_READY_CSV + save_dir = dataset_dir / "eo" / "tessera" + + if not csv_path.exists(): + raise FileNotFoundError(f"Model-ready CSV not found: {csv_path}") + + save_dir.mkdir(parents=True, exist_ok=True) + + if cache_dir is None: + cache_dir = str(Path(data_dir) / "cache" / "tessera") + + df = pd.read_csv(csv_path) + + # Optional filters (consistent with YieldAfricaDataset filter params) + if countries is not None: + df = df[df["country"].isin(countries)] + log.info(f"Filtered to countries {countries}: {len(df)} records") + if years is not None: + df = df[df["year"].isin(years)] + log.info(f"Filtered to years {years}: {len(df)} records") + + n_total = len(df) + n_existing = sum( + 1 for _, row in df.iterrows() + if (save_dir / f"tessera_{row.name_loc}.npy").exists() + ) + n_to_fetch = n_total - n_existing + + print( + f"Records: {n_total} total, {n_existing} already cached, " + f"{n_to_fetch} to fetch (tile_size={tile_size}, workers={workers})" + ) + + # Build GeoTessera constructor kwargs shared across all threads. + # Each thread creates its own instance (thread-local) to avoid sharing + # internal state such as open file handles and rasterio MemoryFiles. + _default_registry_dir = Path.home() / ".cache" / "geotessera" + _use_local_registry = (_default_registry_dir / "registry.parquet").exists() + _gt_kwargs: dict = { + # Skip SHA-256 hash verification after each tile download. Verification + # reads the entire (potentially large) file again after download, adding + # noticeable CPU time per tile and making progress look stalled. 
+ "verify_hashes": False, + } + if embeddings_dir is not None: + _gt_kwargs["embeddings_dir"] = embeddings_dir + + _thread_local = threading.local() + + def _get_gt() -> GeoTessera: + """Return a thread-local GeoTessera instance, creating it on first use.""" + if not hasattr(_thread_local, "gt"): + if _use_local_registry: + _thread_local.gt = GeoTessera(registry_dir=_default_registry_dir, **_gt_kwargs) + else: + _thread_local.gt = GeoTessera(cache_dir=cache_dir, **_gt_kwargs) + return _thread_local.gt + + def _fetch_one(row) -> str: + get_tessera_embeds( + lon=row.lon, + lat=row.lat, + name_loc=row.name_loc, + year=int(row.year), + save_dir=str(save_dir), + tile_size=tile_size, + tessera_con=_get_gt(), + ) + return row.name_loc + + # Bound all socket operations (urllib HTTP requests inside geotessera). + # Without this, a stalled connection blocks the thread until the OS TCP + # keepalive fires, which can take many minutes. + SOCKET_TIMEOUT = 60 # seconds per socket operation + HEARTBEAT = 30 # print a heartbeat when no future completes this fast + TILE_TIMEOUT = 600 # give up warning after 10 min of complete silence + socket.setdefaulttimeout(SOCKET_TIMEOUT) + + rows = [row for _, row in df.iterrows()] + done = 0 + pending: set = set() + silent_seconds = 0 + + try: + with ThreadPoolExecutor(max_workers=workers) as pool: + # Submit all jobs up-front; the pool queues them internally. + futures = {pool.submit(_fetch_one, row): row.name_loc for row in rows} + pending = set(futures) + + while pending: + finished, pending = wait(pending, timeout=HEARTBEAT, return_when=FIRST_COMPLETED) + + if not finished: + silent_seconds += HEARTBEAT + print( + f" ... 
still working — {done}/{n_total} done, " + f"{len(pending)} pending, {silent_seconds}s since last completion" + ) + if silent_seconds >= TILE_TIMEOUT: + print(f" WARNING: no progress in {TILE_TIMEOUT}s, something may be stuck.") + continue + + silent_seconds = 0 + for fut in finished: + done += 1 + if done % 100 == 0 or done == n_total: + print(f" {done}/{n_total}") + try: + fut.result() + except Exception as exc: + print(f" ERROR fetching {futures[fut]}: {exc}") + + except KeyboardInterrupt: + print("\nInterrupted — cancelling queued futures (in-flight downloads will finish).") + for fut in pending: + fut.cancel() + + print(f"Done. Tiles saved to: {save_dir}") + + +def main() -> None: + logging.basicConfig(level=logging.INFO, format="%(message)s") + + parser = argparse.ArgumentParser( + description="Fetch TESSERA embedding tiles for the yield_africa dataset." + ) + parser.add_argument( + "--data_dir", + type=str, + default="data/", + help="Root data directory (same as paths.data_dir in configs). Default: data/", + ) + parser.add_argument( + "--tile_size", + type=int, + default=DEFAULT_TILE_SIZE, + help=f"Tile size in pixels. Default: {DEFAULT_TILE_SIZE}", + ) + parser.add_argument( + "--countries", + nargs="+", + default=None, + metavar="CODE", + help="Country codes to restrict fetching (e.g. KEN RWA). Default: all", + ) + parser.add_argument( + "--years", + nargs="+", + type=int, + default=None, + metavar="YEAR", + help="Years to restrict fetching (e.g. 2019 2020). Default: all", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="GeoTessera internal registry cache directory. Default: {data_dir}/cache/tessera", + ) + parser.add_argument( + "--embeddings_dir", + type=str, + default=os.environ.get("TESSERA_EMBEDDINGS_DIR"), + help=( + "Directory for GeoTessera raw source tiles " + "(global_0.1_degree_representation/ etc.). " + "Falls back to the TESSERA_EMBEDDINGS_DIR env var, then the current " + "working directory. 
Set TESSERA_EMBEDDINGS_DIR in .env to avoid " + "passing this flag every run." + ), + ) + parser.add_argument( + "--workers", + type=int, + default=2, + help="Number of parallel download threads. Default: 2", + ) + args = parser.parse_args() + + print( + f"Fetching TESSERA tiles data_dir={args.data_dir} " + f"tile_size={args.tile_size} countries={args.countries or 'all'} " + f"years={args.years or 'all'}" + ) + fetch_tessera_tiles( + data_dir=args.data_dir, + tile_size=args.tile_size, + countries=args.countries, + years=args.years, + cache_dir=args.cache_dir, + embeddings_dir=args.embeddings_dir, + workers=args.workers, + ) + + +if __name__ == "__main__": + main() \ No newline at end of file From 2ef555afd5b2e3d30d5e959eb079049d29bdcb06 Mon Sep 17 00:00:00 2001 From: Rob Knapen Date: Sun, 8 Mar 2026 20:41:22 +0100 Subject: [PATCH 23/60] Crop Yield use case: tessera modality added to dataset --- src/data/yield_africa_dataset.py | 84 ++++++++++++++++++++++++++++++-- tests/test_yield_africa.py | 9 +++- 2 files changed, 86 insertions(+), 7 deletions(-) diff --git a/src/data/yield_africa_dataset.py b/src/data/yield_africa_dataset.py index 65534cc..f2867d8 100644 --- a/src/data/yield_africa_dataset.py +++ b/src/data/yield_africa_dataset.py @@ -9,14 +9,24 @@ """ import logging +import os from typing import Any, Dict, List, override +import pandas as pd import torch from src.data.base_dataset import BaseDataset +# Number of channels in a TESSERA embedding tile (fixed by the geotessera model). +_TESSERA_CHANNELS = 128 + log = logging.getLogger(__name__) +# Fixed ordered list of all countries in the full dataset. +# Used to produce a consistent one-hot encoding regardless of which +# countries are present after filtering. +_ALL_COUNTRIES = ["BF", "BUR", "ETH", "KEN", "MAL", "RWA", "TAN", "ZAM"] + class YieldAfricaDataset(BaseDataset): """Dataset for the crop yield regression use case (East/Southern Africa). 
@@ -34,6 +44,12 @@ class YieldAfricaDataset(BaseDataset): `implemented_mod = {"coords"}` because tabular features live directly in the model-ready CSV and are picked up via the `feat_` column prefix. They do NOT need to be listed in `modalities`. + + In addition to the CSV feat_* columns, `year` and one-hot `country` + encodings are injected as `feat_year` and `feat_country_{CODE}` so that + the model can condition on inter-annual and cross-country variation. + The one-hot set always covers `_ALL_COUNTRIES` (8 countries) so that + `tabular_dim` is stable regardless of the country filter applied. """ def __init__( @@ -59,12 +75,27 @@ def __init__( dataset_name="yield_africa", seed=seed, cache_dir=cache_dir, - implemented_mod={"coords"}, + implemented_mod={"coords", "tessera"}, mock=mock, use_features=use_features, ) - # Apply country/year filters to self.df and rebuild records if needed. + # Inject year and country one-hot columns as feat_* so that + # get_records() picks them up automatically. Build all new columns in + # one concat to avoid pandas PerformanceWarning from repeated assignment. + if use_features and "year" in self.df.columns and "country" in self.df.columns: + # Normalise feat_year to the same scale as the pre-scaled CSV feat_* columns + # (roughly zero-mean, unit-std) so it doesn't dominate LayerNorm. + _YEAR_MEAN = 2018.0 + _YEAR_STD = 2.0 + new_cols: Dict[str, Any] = { + "feat_year": (self.df["year"].astype(float) - _YEAR_MEAN) / _YEAR_STD + } + for code in _ALL_COUNTRIES: + new_cols[f"feat_country_{code}"] = (self.df["country"] == code).astype(float) + self.df = pd.concat([self.df, pd.DataFrame(new_cols, index=self.df.index)], axis=1) + + # Apply country/year filters to self.df and rebuild records. # BaseDataset.__init__ has already loaded the CSV; filtering here avoids # touching BaseDataset and keeps the logic use-case specific. 
n_before = len(self.df) @@ -80,11 +111,44 @@ def __init__( n_after = len(self.df) if n_after != n_before: log.info(f"Country/year filter: {n_before} → {n_after} records ({n_before - n_after} excluded)") - self.records = self.get_records() + + # get_records() mutates self.use_aux_data in place (replacing pattern + # dicts with resolved column-name lists), so reset it from the original + # parameter before calling it a second time. + if use_aux_data is None or use_aux_data == "all": + self.use_aux_data = { + "aux": {"pattern": "^aux_(?!.*top).*"}, + "top": {"pattern": "^aux_.*top.*"}, + } + elif isinstance(use_aux_data, dict): + self.use_aux_data = use_aux_data + else: + self.use_aux_data = None + + # Always rebuild so feat_year / feat_country_* are reflected in + # self.feat_names and self.tabular_dim. + self.records = self.get_records() def setup(self) -> None: - """No files to download or prepare for this dataset.""" - return + """Check for requested modality data; warn if TESSERA tiles are absent. + + Unlike other datasets, TESSERA tiles for yield_africa vary per record + year and must be pre-fetched with the preprocessing script: + python src/data_preprocessing/yield_africa_tessera_preprocess.py + + setup_tessera() is intentionally not called here because it uses a + single fixed year for bulk download, which is incompatible with the + multi-year nature of this dataset. + """ + if "tessera" in self.modalities: + tessera_dir = os.path.join(self.data_dir, "eo", "tessera") + if not os.path.exists(tessera_dir) or len(os.listdir(tessera_dir)) == 0: + log.warning( + "TESSERA tiles not found at %s. " + "Run src/data_preprocessing/yield_africa_tessera_preprocess.py " + "to pre-fetch tiles. 
Missing tiles will fall back to zero tensors.", + tessera_dir, + ) @override def __getitem__(self, idx: int) -> Dict[str, Any]: @@ -96,6 +160,16 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: sample["eo"]["coords"] = torch.tensor( [row["lat"], row["lon"]], dtype=torch.float32 ) + elif modality == "tessera": + tile_path = row["tessera_path"] + if os.path.exists(tile_path): + sample["eo"]["tessera"] = self.load_npy(tile_path) + else: + size = self.modalities["tessera"].get("size", 9) + log.debug("TESSERA tile missing: %s — using zero fallback.", tile_path) + sample["eo"]["tessera"] = torch.zeros( + _TESSERA_CHANNELS, size, size, dtype=torch.float32 + ) if self.use_features and self.feat_names: sample["eo"]["tabular"] = torch.tensor( diff --git a/tests/test_yield_africa.py b/tests/test_yield_africa.py index 30f32d2..b1e6436 100644 --- a/tests/test_yield_africa.py +++ b/tests/test_yield_africa.py @@ -46,7 +46,11 @@ } MOCK_N_ROWS = 10 -MOCK_TABULAR_DIM = len(MOCK_FEAT_COLS) # 8 +# feat_year (1) + feat_country_{code} (8) are injected by YieldAfricaDataset +# when country and year columns are present, so the effective tabular dim grows. 
+from src.data.yield_africa_dataset import _ALL_COUNTRIES +MOCK_INJECTED_FEAT_NAMES = {"feat_year"} | {f"feat_country_{c}" for c in _ALL_COUNTRIES} +MOCK_TABULAR_DIM = len(MOCK_FEAT_COLS) + len(MOCK_INJECTED_FEAT_NAMES) # 8 + 9 = 17 MOCK_N_AUX = len(MOCK_AUX_COLS) # 4 @@ -169,7 +173,8 @@ def test_yield_africa_dataset_target_name(yield_africa_dataset): def test_yield_africa_dataset_attributes(yield_africa_dataset): assert yield_africa_dataset.num_classes == 1 assert yield_africa_dataset.tabular_dim == MOCK_TABULAR_DIM - assert set(yield_africa_dataset.feat_names) == set(MOCK_FEAT_COLS.keys()) + expected_feat_names = set(MOCK_FEAT_COLS.keys()) | MOCK_INJECTED_FEAT_NAMES + assert set(yield_africa_dataset.feat_names) == expected_feat_names def test_yield_africa_dataset_feat_prefix(yield_africa_dataset): From 90b6a49eb0a83da3727907706a92323d35e84936 Mon Sep 17 00:00:00 2001 From: Rob Knapen Date: Sun, 8 Mar 2026 20:45:16 +0100 Subject: [PATCH 24/60] Configurable coords encoder in MultiModalEncoder --- .../geo_encoders/average_encoder.py | 18 +++-- .../geo_encoders/multimodal_encoder.py | 70 ++++++++++++++++--- .../pred_heads/mlp_regression_head.py | 5 +- 3 files changed, 71 insertions(+), 22 deletions(-) diff --git a/src/models/components/geo_encoders/average_encoder.py b/src/models/components/geo_encoders/average_encoder.py index a312153..86a6b9b 100644 --- a/src/models/components/geo_encoders/average_encoder.py +++ b/src/models/components/geo_encoders/average_encoder.py @@ -55,16 +55,14 @@ def _average_and_project(self, x: torch.Tensor) -> torch.Tensor: @override def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: - eo_data = batch.get("eo", {}) - dtype = self.dtype - if eo_data.dtype != dtype: - eo_data = eo_data.to(dtype) - feats = self.geo_encoder(eo_data[self.geo_data_name]) - # n_nans = torch.sum(torch.isnan(feats)).item() - # assert ( - # n_nans == 0 - # ), f"AverageEncoder output contains {n_nans}/{feats.numel()} NaNs PRIOR to normalization with 
data min {eo_data[self.geo_data_name].min()} and max {eo_data[self.geo_data_name].max()}." - + tile = batch.get("eo", {}).get(self.geo_data_name) + # Determine target dtype from parameters when available (e.g. when the + # optional projection layer exists); otherwise keep the input dtype. + params = list(self.parameters()) + dtype = params[0].dtype if params else tile.dtype + if tile.dtype != dtype: + tile = tile.to(dtype) + feats = self.geo_encoder(tile) return feats.to(dtype) diff --git a/src/models/components/geo_encoders/multimodal_encoder.py b/src/models/components/geo_encoders/multimodal_encoder.py index dbf1159..2cb1e47 100644 --- a/src/models/components/geo_encoders/multimodal_encoder.py +++ b/src/models/components/geo_encoders/multimodal_encoder.py @@ -1,8 +1,10 @@ """Unified multimodal encoder for EO data. Controlled entirely via constructor flags: - - use_coords: encode lat/lon with GeoClip - - use_tabular: encode feat_* tabular columns + - use_coords: activate the spatial/geo encoder branch + - use_tabular: encode feat_* tabular columns + - geo_encoder_cfg: pluggable geo encoder (any BaseGeoEncoder subclass); + when None and use_coords=True, defaults to GeoClipCoordinateEncoder """ from typing import Dict, override @@ -16,9 +18,25 @@ class MultiModalEncoder(BaseGeoEncoder): """ - - coords only (use_coords=True, use_tabular=False) - - tabular only (use_coords=False, use_tabular=True) - - coords + tabular (use_coords=True, use_tabular=True) + Modes (controlled by use_coords / use_tabular flags): + + - geo only (use_coords=True, use_tabular=False) + - tabular only (use_coords=False, use_tabular=True) + - geo + tabular (use_coords=True, use_tabular=True) + + The geo encoder defaults to GeoClipCoordinateEncoder but can be replaced + with any BaseGeoEncoder via the geo_encoder_cfg parameter. Hydra + instantiates geo_encoder_cfg before passing it here, so it arrives as a + ready-to-use nn.Module (e.g. AverageEncoder for TESSERA tiles). 
+ + Example config (TESSERA + tabular fusion): + geo_encoder: + _target_: ...MultiModalEncoder + use_coords: true + use_tabular: true + geo_encoder_cfg: + _target_: ...AverageEncoder + geo_data_name: tessera """ def __init__( @@ -26,7 +44,9 @@ def __init__( use_coords: bool = True, use_tabular: bool = False, tab_embed_dim: int = 64, + tabular_dropout: float = 0.0, tabular_dim: int = None, + geo_encoder_cfg: BaseGeoEncoder | None = None, ) -> None: super().__init__() @@ -35,12 +55,17 @@ def __init__( self.use_coords = use_coords self.use_tabular = use_tabular self.tab_embed_dim = tab_embed_dim + self.tabular_dropout = tabular_dropout self._tabular_ready = False + self.fusion_norm = None # set in build_tabular_branch when both branches active coords_dim = 0 if use_coords: - self.coords_encoder = GeoClipCoordinateEncoder() - coords_dim = self.coords_encoder.output_dim # 512 + if geo_encoder_cfg is not None: + self.coords_encoder = geo_encoder_cfg + else: + self.coords_encoder = GeoClipCoordinateEncoder() + coords_dim = self.coords_encoder.output_dim self._coords_dim = coords_dim @@ -57,20 +82,40 @@ def __init__( # ------------------------------------------------------------------ def build_tabular_branch(self, tabular_dim: int) -> None: - """Build (or rebuild) the tabular projection layer.""" + """Build (or rebuild) the tabular projection MLP. + + Architecture: LayerNorm → Linear(in, h) → ReLU → Dropout → + Linear(h, h//2) → ReLU → Dropout → Linear(h//2, out) + where h = max(tab_embed_dim * 2, 128). 
+ """ if self._tabular_ready and hasattr(self, "_last_tabular_dim"): if self._last_tabular_dim == tabular_dim: return # already built with correct dim + hidden = max(self.tab_embed_dim * 2, 128) + drop = self.tabular_dropout self.tabular_proj = nn.Sequential( nn.LayerNorm(tabular_dim), - nn.Linear(tabular_dim, self.tab_embed_dim), + nn.Linear(tabular_dim, hidden), + nn.ReLU(), + nn.Dropout(drop), + nn.Linear(hidden, hidden // 2), nn.ReLU(), + nn.Dropout(drop), + nn.Linear(hidden // 2, self.tab_embed_dim), ) self._last_tabular_dim = tabular_dim self._tabular_ready = True self.output_dim = self._coords_dim + self.tab_embed_dim + # Normalise the fused representation when both branches are active. + # The geo encoder output and the tabular projection may have different + # scales, so a LayerNorm stabilises training after concat. + if self.use_coords: + self.fusion_norm = nn.LayerNorm(self.output_dim) + else: + self.fusion_norm = None + # ------------------------------------------------------------------ # Forward # ------------------------------------------------------------------ @@ -80,7 +125,7 @@ def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: parts = [] if self.use_coords: - parts.append(self.coords_encoder(batch)) # (B, 512) + parts.append(self.coords_encoder(batch)) # (B, coords_encoder.output_dim) if self.use_tabular: assert self._tabular_ready, ( @@ -90,4 +135,7 @@ def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: tab = batch["eo"]["tabular"].float() # (B, tabular_dim) parts.append(self.tabular_proj(tab)) # (B, tab_embed_dim) - return torch.cat(parts, dim=-1) + fused = torch.cat(parts, dim=-1) + if self.fusion_norm is not None: + fused = self.fusion_norm(fused) + return fused diff --git a/src/models/components/pred_heads/mlp_regression_head.py b/src/models/components/pred_heads/mlp_regression_head.py index c52509a..818da44 100644 --- a/src/models/components/pred_heads/mlp_regression_head.py +++ 
b/src/models/components/pred_heads/mlp_regression_head.py @@ -19,10 +19,11 @@ class MLPRegressionPredictionHead(BasePredictionHead): """MLP prediction head for regression tasks (outputs a continuous value).""" - def __init__(self, nn_layers: int = 2, hidden_dim: int = 256) -> None: + def __init__(self, nn_layers: int = 2, hidden_dim: int = 256, dropout: float = 0.0) -> None: super().__init__() self.nn_layers = nn_layers self.hidden_dim = hidden_dim + self.dropout = dropout @override def forward(self, feats: torch.Tensor) -> torch.Tensor: @@ -39,6 +40,8 @@ def configure_nn(self) -> None: for _ in range(self.nn_layers - 1): layers.append(nn.Linear(in_dim, self.hidden_dim)) layers.append(nn.ReLU()) + if self.dropout > 0.0: + layers.append(nn.Dropout(self.dropout)) in_dim = self.hidden_dim layers.append(nn.Linear(in_dim, self.output_dim)) From 20e7d51abf138d3084c6ca2129377edf69d1bc07 Mon Sep 17 00:00:00 2001 From: Rob Knapen Date: Sun, 8 Mar 2026 20:47:22 +0100 Subject: [PATCH 25/60] Crop Yield use case: leave-one-country-out split script --- .../yield_africa_loco_splits.py | 171 ++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 src/data_preprocessing/yield_africa_loco_splits.py diff --git a/src/data_preprocessing/yield_africa_loco_splits.py b/src/data_preprocessing/yield_africa_loco_splits.py new file mode 100644 index 0000000..1bfb1b7 --- /dev/null +++ b/src/data_preprocessing/yield_africa_loco_splits.py @@ -0,0 +1,171 @@ +"""Generate leave-one-country-out (LOCO) split files for the yield_africa dataset. + +Location: src/data_preprocessing/yield_africa_loco_splits.py + +For each held-out country one `.pth` file is written to +`{data_dir}/yield_africa/splits/split_loco_{COUNTRY}.pth`. 
+ +Split layout +------------ +- test : all records from the held-out country +- train : 80 % of records from the remaining countries (random, seeded) +- val : 20 % of records from the remaining countries (random, seeded) + +The files are consumed by BaseDataModule when `split_mode: from_file` and +`saved_split_file_name: split_loco_{COUNTRY}.pth`. + +Usage +----- + python src/data_preprocessing/yield_africa_loco_splits.py --data_dir data/ + python src/data_preprocessing/yield_africa_loco_splits.py --data_dir data/ --country KEN +""" + +import argparse +import logging +from pathlib import Path + +import numpy as np +import pandas as pd +import torch + +log = logging.getLogger(__name__) + +# All countries present in the full dataset (must match _ALL_COUNTRIES in +# yield_africa_dataset.py so that the feature encoding is consistent). +ALL_COUNTRIES = ["BF", "BUR", "ETH", "KEN", "MAL", "RWA", "TAN", "ZAM"] + +DATASET_NAME = "yield_africa" +MODEL_READY_CSV = f"model_ready_{DATASET_NAME}.csv" + + +def make_loco_split( + df: pd.DataFrame, + test_country: str, + val_fraction: float = 0.2, + seed: int = 12345, +) -> dict: + """Return a split-indices dict for one held-out country. 
+ + :param df: full model-ready dataframe (must contain 'country' and 'name_loc') + :param test_country: country code to hold out as the test set + :param val_fraction: fraction of the non-test pool to use for validation + :param seed: random seed for the train/val shuffle + :return: dict with 'train_indices', 'val_indices', 'test_indices' as pd.Series of name_locs + """ + test_mask = df["country"] == test_country + test_locs = df.loc[test_mask, "name_loc"].reset_index(drop=True) + + remaining = df.loc[~test_mask, "name_loc"].reset_index(drop=True) + rng = np.random.default_rng(seed) + shuffled = remaining.sample(frac=1, random_state=seed).reset_index(drop=True) + n_val = int(len(shuffled) * val_fraction) + val_locs = shuffled.iloc[:n_val] + train_locs = shuffled.iloc[n_val:] + + return { + "train_indices": train_locs, + "val_indices": val_locs, + "test_indices": test_locs, + } + + +def generate_splits( + data_dir: str, + countries: list[str] | None = None, + val_fraction: float = 0.2, + seed: int = 12345, +) -> None: + """Generate and save LOCO split files for the requested countries. 
+ + :param data_dir: root data directory (same as `paths.data_dir` in configs) + :param countries: list of country codes to generate splits for; None means all + :param val_fraction: fraction of non-test data to use for validation + :param seed: random seed + """ + dataset_dir = Path(data_dir) / DATASET_NAME + csv_path = dataset_dir / MODEL_READY_CSV + splits_dir = dataset_dir / "splits" + + if not csv_path.exists(): + raise FileNotFoundError(f"Model-ready CSV not found: {csv_path}") + + splits_dir.mkdir(parents=True, exist_ok=True) + + df = pd.read_csv(csv_path) + if "country" not in df.columns or "name_loc" not in df.columns: + raise ValueError("CSV must contain 'country' and 'name_loc' columns") + + available = sorted(df["country"].unique().tolist()) + targets = countries if countries is not None else available + + for country in targets: + if country not in available: + log.warning(f"Country '{country}' not found in CSV (available: {available}), skipping") + continue + + split = make_loco_split(df, country, val_fraction=val_fraction, seed=seed) + n_train = len(split["train_indices"]) + n_val = len(split["val_indices"]) + n_test = len(split["test_indices"]) + + out_path = splits_dir / f"split_loco_{country}.pth" + torch.save(split, out_path) + + log.info( + f" {country}: train={n_train}, val={n_val}, test={n_test} " + f"-> {out_path.name}" + ) + print( + f" Saved split_loco_{country}.pth " + f"(train={n_train}, val={n_val}, test={n_test})" + ) + + +def main() -> None: + logging.basicConfig(level=logging.INFO, format="%(message)s") + + parser = argparse.ArgumentParser( + description="Generate leave-one-country-out split files for yield_africa." + ) + parser.add_argument( + "--data_dir", + type=str, + default="data/", + help="Root data directory (same as paths.data_dir in configs). Default: data/", + ) + parser.add_argument( + "--country", + type=str, + default=None, + help="Single country code to generate a split for. 
Omit to generate all.", + ) + parser.add_argument( + "--val_fraction", + type=float, + default=0.2, + help="Fraction of non-test records used for validation. Default: 0.2", + ) + parser.add_argument( + "--seed", + type=int, + default=12345, + help="Random seed for the train/val shuffle. Default: 12345", + ) + args = parser.parse_args() + + countries = [args.country] if args.country else None + print( + f"Generating LOCO splits data_dir={args.data_dir} " + f"countries={countries or 'all'} val_fraction={args.val_fraction} seed={args.seed}" + ) + generate_splits( + data_dir=args.data_dir, + countries=countries, + val_fraction=args.val_fraction, + seed=args.seed, + ) + print("Done.") + + +if __name__ == "__main__": + main() \ No newline at end of file From 3e1aec0fe6215ce978f7d81d29b85e85ca20d959 Mon Sep 17 00:00:00 2001 From: Rob Knapen Date: Sun, 8 Mar 2026 20:49:01 +0100 Subject: [PATCH 26/60] Crop Yield use case: Hydra configs for various experiments --- configs/data/yield_africa_all.yaml | 4 +- configs/data/yield_africa_loco.yaml | 33 ++++++++++++++ configs/data/yield_africa_tessera.yaml | 31 +++++++++++++ .../experiment/yield_africa_coords_reg.yaml | 2 +- .../experiment/yield_africa_fusion_reg.yaml | 2 +- .../experiment/yield_africa_tabular_loco.yaml | 30 +++++++++++++ .../experiment/yield_africa_tabular_reg.yaml | 2 +- .../yield_africa_tessera_fusion_reg.yaml | 31 +++++++++++++ .../experiment/yield_africa_tessera_reg.yaml | 27 ++++++++++++ configs/model/yield_fusion_reg.yaml | 6 ++- configs/model/yield_geoclip_reg.yaml | 3 +- configs/model/yield_tabular_reg.yaml | 6 ++- configs/model/yield_tessera_fusion_reg.yaml | 43 +++++++++++++++++++ configs/model/yield_tessera_reg.yaml | 32 ++++++++++++++ 14 files changed, 242 insertions(+), 10 deletions(-) create mode 100644 configs/data/yield_africa_loco.yaml create mode 100644 configs/data/yield_africa_tessera.yaml create mode 100644 configs/experiment/yield_africa_tabular_loco.yaml create mode 100644 
configs/experiment/yield_africa_tessera_fusion_reg.yaml create mode 100644 configs/experiment/yield_africa_tessera_reg.yaml create mode 100644 configs/model/yield_tessera_fusion_reg.yaml create mode 100644 configs/model/yield_tessera_reg.yaml diff --git a/configs/data/yield_africa_all.yaml b/configs/data/yield_africa_all.yaml index 8667a4c..e40a8a2 100644 --- a/configs/data/yield_africa_all.yaml +++ b/configs/data/yield_africa_all.yaml @@ -13,8 +13,8 @@ dataset: # Country/year filters — set to a list to restrict, null to include all. # countries and years select only the listed values; # exclude_countries and exclude_years drop the listed values. - countries: ["KEN", "RWA", "TAN", "ZAM", "MAL"] - years: [2016, 2017, 2018, 2019, 2020, 2021] + countries: ["BF", "BUR", "ETH", "KEN", "MAL", "RWA", "TAN", "ZAM"] + years: [2014, 2016, 2017, 2018, 2019, 2020, 2021, 2023, 2024] exclude_countries: null exclude_years: null diff --git a/configs/data/yield_africa_loco.yaml b/configs/data/yield_africa_loco.yaml new file mode 100644 index 0000000..3a20b3e --- /dev/null +++ b/configs/data/yield_africa_loco.yaml @@ -0,0 +1,33 @@ +_target_: src.data.base_datamodule.BaseDataModule + +dataset: + _target_: src.data.yield_africa_dataset.YieldAfricaDataset + data_dir: ${paths.data_dir} + modalities: + coords: {} + use_target_data: true + use_features: true + use_aux_data: none + seed: ${seed} + cache_dir: ${paths.cache_dir} + # Include all countries and years so the split file determines the partition. + countries: ["BF", "BUR", "ETH", "KEN", "MAL", "RWA", "TAN", "ZAM"] + years: [2014, 2016, 2017, 2018, 2019, 2020, 2021, 2023, 2024] + exclude_countries: null + exclude_years: null + +batch_size: 64 +num_workers: 0 +pin_memory: false + +# Leave-one-country-out split loaded from a pre-generated file. 
+# Generate split files first: +# python src/data_preprocessing/yield_africa_loco_splits.py --data_dir +# +# Override saved_split_file_name at the command line to change the held-out country: +# python src/train.py experiment=yield_africa_tabular_loco \ +# data.saved_split_file_name=split_loco_RWA.pth +split_mode: "from_file" +saved_split_file_name: "split_loco_KEN.pth" +save_split: false +seed: ${seed} \ No newline at end of file diff --git a/configs/data/yield_africa_tessera.yaml b/configs/data/yield_africa_tessera.yaml new file mode 100644 index 0000000..e997f19 --- /dev/null +++ b/configs/data/yield_africa_tessera.yaml @@ -0,0 +1,31 @@ +_target_: src.data.base_datamodule.BaseDataModule + +dataset: + _target_: src.data.yield_africa_dataset.YieldAfricaDataset + data_dir: ${paths.data_dir} + modalities: + tessera: + # size must match the tile_size used when running the preprocessing script. + # Default: 9 pixels (set by yield_africa_tessera_preprocess.py --tile_size). + size: 9 + format: npy + # year is intentionally omitted: yield_africa fetches per-record year tiles + # via the preprocessing script rather than a single bulk-year download. 
+ use_target_data: true + use_features: true + use_aux_data: none + seed: ${seed} + cache_dir: ${paths.cache_dir} + countries: ["BF", "BUR", "ETH", "KEN", "MAL", "RWA", "TAN", "ZAM"] + years: [2014, 2016, 2017, 2018, 2019, 2020, 2021, 2023, 2024] + exclude_countries: null + exclude_years: null + +batch_size: 64 +num_workers: 0 +pin_memory: false + +split_mode: "random" +train_val_test_split: [0.7, 0.15, 0.15] +save_split: false +seed: ${seed} \ No newline at end of file diff --git a/configs/experiment/yield_africa_coords_reg.yaml b/configs/experiment/yield_africa_coords_reg.yaml index 5e967e2..5690656 100644 --- a/configs/experiment/yield_africa_coords_reg.yaml +++ b/configs/experiment/yield_africa_coords_reg.yaml @@ -12,7 +12,7 @@ seed: 12345 trainer: min_epochs: 1 - max_epochs: 50 + max_epochs: 150 data: batch_size: 64 diff --git a/configs/experiment/yield_africa_fusion_reg.yaml b/configs/experiment/yield_africa_fusion_reg.yaml index 2cf0093..fa1fbdd 100644 --- a/configs/experiment/yield_africa_fusion_reg.yaml +++ b/configs/experiment/yield_africa_fusion_reg.yaml @@ -12,7 +12,7 @@ seed: 12345 trainer: min_epochs: 1 - max_epochs: 50 + max_epochs: 150 data: batch_size: 64 diff --git a/configs/experiment/yield_africa_tabular_loco.yaml b/configs/experiment/yield_africa_tabular_loco.yaml new file mode 100644 index 0000000..865ab92 --- /dev/null +++ b/configs/experiment/yield_africa_tabular_loco.yaml @@ -0,0 +1,30 @@ +# @package _global_ +# configs/experiment/yield_africa_tabular_loco.yaml +# Tabular-only model evaluated with leave-one-country-out split. +# Default held-out country: KEN (largest, most representative test set). 
+# +# To evaluate on a different held-out country: +# python src/train.py experiment=yield_africa_tabular_loco \ +# data.saved_split_file_name=split_loco_RWA.pth + +defaults: + - override /model: yield_tabular_reg + - override /data: yield_africa_loco + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "tabular_only", "regression", "loco"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 150 + +data: + batch_size: 64 + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" \ No newline at end of file diff --git a/configs/experiment/yield_africa_tabular_reg.yaml b/configs/experiment/yield_africa_tabular_reg.yaml index fbcaedf..57f6b36 100644 --- a/configs/experiment/yield_africa_tabular_reg.yaml +++ b/configs/experiment/yield_africa_tabular_reg.yaml @@ -12,7 +12,7 @@ seed: 12345 trainer: min_epochs: 1 - max_epochs: 50 + max_epochs: 150 data: batch_size: 64 diff --git a/configs/experiment/yield_africa_tessera_fusion_reg.yaml b/configs/experiment/yield_africa_tessera_fusion_reg.yaml new file mode 100644 index 0000000..c9edb54 --- /dev/null +++ b/configs/experiment/yield_africa_tessera_fusion_reg.yaml @@ -0,0 +1,31 @@ +# @package _global_ +# configs/experiment/yield_africa_tessera_fusion_reg.yaml +# Variant: TESSERA spatial encoder + tabular features fusion. +# Requires: +# 1. TESSERA tiles pre-fetched: +# python src/data_preprocessing/yield_africa_tessera_preprocess.py --data_dir +# 2. MultiModalEncoder geo_encoder_cfg support. 
+ +defaults: + - override /model: yield_tessera_fusion_reg + - override /data: yield_africa_tessera + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "tessera_fusion", "regression"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 150 + +data: + batch_size: 64 + dataset: + use_features: true + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" \ No newline at end of file diff --git a/configs/experiment/yield_africa_tessera_reg.yaml b/configs/experiment/yield_africa_tessera_reg.yaml new file mode 100644 index 0000000..7e0c93b --- /dev/null +++ b/configs/experiment/yield_africa_tessera_reg.yaml @@ -0,0 +1,27 @@ +# @package _global_ +# configs/experiment/yield_africa_tessera_reg.yaml +# Variant: TESSERA spatial encoder only (no tabular features). +# Requires tiles pre-fetched by: +# python src/data_preprocessing/yield_africa_tessera_preprocess.py --data_dir + +defaults: + - override /model: yield_tessera_reg + - override /data: yield_africa_tessera + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "tessera_only", "regression"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 150 + +data: + batch_size: 64 + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" \ No newline at end of file diff --git a/configs/model/yield_fusion_reg.yaml b/configs/model/yield_fusion_reg.yaml index e4e75f6..20ecfd7 100644 --- a/configs/model/yield_fusion_reg.yaml +++ b/configs/model/yield_fusion_reg.yaml @@ -4,12 +4,14 @@ geo_encoder: _target_: src.models.components.geo_encoders.multimodal_encoder.MultiModalEncoder use_coords: true use_tabular: true -# tab_embed_dim: 64 + tab_embed_dim: 256 + tabular_dropout: 0.2 prediction_head: _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead nn_layers: 2 hidden_dim: 256 + dropout: 0.2 # GeoClip frozen; tabular projection + head are trained. 
trainable_modules: [geo_encoder, prediction_head] @@ -20,7 +22,7 @@ optimizer: _target_: torch.optim.Adam _partial_: true lr: 0.001 - weight_decay: 0.0 + weight_decay: 1e-4 scheduler: _target_: torch.optim.lr_scheduler.ReduceLROnPlateau diff --git a/configs/model/yield_geoclip_reg.yaml b/configs/model/yield_geoclip_reg.yaml index f333798..8d5eb5e 100644 --- a/configs/model/yield_geoclip_reg.yaml +++ b/configs/model/yield_geoclip_reg.yaml @@ -9,6 +9,7 @@ prediction_head: _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead nn_layers: 2 hidden_dim: 256 + dropout: 0.2 # Only the prediction head is trained; GeoClip encoder is frozen. trainable_modules: [prediction_head] @@ -19,7 +20,7 @@ optimizer: _target_: torch.optim.Adam _partial_: true lr: 0.001 - weight_decay: 0.0 + weight_decay: 1e-4 scheduler: _target_: torch.optim.lr_scheduler.ReduceLROnPlateau diff --git a/configs/model/yield_tabular_reg.yaml b/configs/model/yield_tabular_reg.yaml index af96374..fc95575 100644 --- a/configs/model/yield_tabular_reg.yaml +++ b/configs/model/yield_tabular_reg.yaml @@ -4,12 +4,14 @@ geo_encoder: _target_: src.models.components.geo_encoders.multimodal_encoder.MultiModalEncoder use_coords: false use_tabular: true - tab_embed_dim: 128 + tab_embed_dim: 256 + tabular_dropout: 0.2 prediction_head: _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead nn_layers: 2 hidden_dim: 256 + dropout: 0.2 # Both encoder and head have trainable parameters. 
trainable_modules: [geo_encoder, prediction_head] @@ -20,7 +22,7 @@ optimizer: _target_: torch.optim.Adam _partial_: true lr: 0.001 - weight_decay: 0.0 + weight_decay: 1e-4 scheduler: _target_: torch.optim.lr_scheduler.ReduceLROnPlateau diff --git a/configs/model/yield_tessera_fusion_reg.yaml b/configs/model/yield_tessera_fusion_reg.yaml new file mode 100644 index 0000000..68ffa88 --- /dev/null +++ b/configs/model/yield_tessera_fusion_reg.yaml @@ -0,0 +1,43 @@ +_target_: src.models.predictive_model.PredictiveModel + +# MultiModalEncoder with a pluggable geo encoder. +# The geo_encoder_cfg replaces the hardcoded GeoClipCoordinateEncoder with +# AverageEncoder(tessera), so the spatial branch uses inter-annual phenology +# instead of static coordinate embeddings. +geo_encoder: + _target_: src.models.components.geo_encoders.multimodal_encoder.MultiModalEncoder + use_coords: true + use_tabular: true + tab_embed_dim: 256 + tabular_dropout: 0.2 + geo_encoder_cfg: + _target_: src.models.components.geo_encoders.average_encoder.AverageEncoder + geo_data_name: tessera + # output_dim defaults to 128; no projection needed. + +prediction_head: + _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead + nn_layers: 2 + hidden_dim: 256 + dropout: 0.2 + +# geo_encoder includes the tessera AverageEncoder + tabular branch; head is always trained. 
+trainable_modules: [geo_encoder, prediction_head] + +metrics: ${metrics} + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.001 + weight_decay: 1e-4 + +scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + mode: min + factor: 0.1 + patience: 10 + +loss_fn: + _target_: src.models.components.loss_fns.huber_loss.HuberLoss \ No newline at end of file diff --git a/configs/model/yield_tessera_reg.yaml b/configs/model/yield_tessera_reg.yaml new file mode 100644 index 0000000..dcd84e2 --- /dev/null +++ b/configs/model/yield_tessera_reg.yaml @@ -0,0 +1,32 @@ +_target_: src.models.predictive_model.PredictiveModel + +geo_encoder: + _target_: src.models.components.geo_encoders.average_encoder.AverageEncoder + geo_data_name: tessera + # output_dim defaults to 128 (the native tessera channel count); no projection. + +prediction_head: + _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead + nn_layers: 2 + hidden_dim: 256 + dropout: 0.2 + +trainable_modules: [geo_encoder, prediction_head] + +metrics: ${metrics} + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.001 + weight_decay: 1e-4 + +scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + mode: min + factor: 0.1 + patience: 10 + +loss_fn: + _target_: src.models.components.loss_fns.huber_loss.HuberLoss \ No newline at end of file From eb7d802409bea651a5d5c90c52d76cda1078b816 Mon Sep 17 00:00:00 2001 From: Rob Knapen Date: Sun, 8 Mar 2026 23:10:08 +0100 Subject: [PATCH 27/60] Fixed the freezer --- src/models/base_model.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/models/base_model.py b/src/models/base_model.py index 3e931b7..7b3cb0a 100644 --- a/src/models/base_model.py +++ b/src/models/base_model.py @@ -61,9 +61,23 @@ def freezer(self) -> None: # Freeze the rest param.requires_grad = False - # Set module modes correctly + # Set module modes 
correctly. + # A module should be in train() if: + # - it IS a trainable module (name == t), or + # - it is a CHILD of a trainable module (name starts with t + "."), or + # - it is an ANCESTOR of a trainable module (t starts with name + "."), + # so that container modules reflect the correct mode, or + # - it is the root module (""), which must be train when any child is. + def _in_train_scope(name: str) -> bool: + if not name: # root module + return bool(self.trainable_modules) + for t in self.trainable_modules: + if name == t or name.startswith(t + ".") or t.startswith(name + "."): + return True + return False + for name, module in self.named_modules(): - if any(t.startswith(name) for t in self.trainable_modules): + if _in_train_scope(name): module.train() else: module.eval() From e6e00b5737f7d5abdf9760e76d12a55f3c933cc2 Mon Sep 17 00:00:00 2001 From: Rob Knapen Date: Sun, 8 Mar 2026 23:16:30 +0100 Subject: [PATCH 28/60] Adds tessera embeddings directory to .env.example --- .env.example | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.env.example b/.env.example index 33713cb..4b8c52e 100644 --- a/.env.example +++ b/.env.example @@ -10,6 +10,10 @@ TRAINER_PROFILE="gpu" # cpu/gpu/mps/ddp HF_HOME="${PROJECT_ROOT}/.cache/huggingface/" # set or will default to './.cache/huggingface/' DATA_DIR="${PROJECT_ROOT}/data/" # set to your local data folder (for aether), or will default to '${PROJECT_ROOT}/data/' +# When using (and downloading) TESSERA embeddings (e.g., crop yield use case) +# Note that this folder can get large ... +TESSERA_EMBEDDINGS_DIR="${PROJECT_ROOT}/data/cache/tessera/" + # Working directories # STORAGE_MODE=# or "shared" # SHARED_CACHE=# or "/path/to/shared/.cache" From 7e71170736cbaac72940157ef334d3d39d4f3348 Mon Sep 17 00:00:00 2001 From: Rob Knapen Date: Tue, 10 Mar 2026 11:19:04 +0100 Subject: [PATCH 29/60] Makes L2 normalization optional in predictive model. 
Switching it off preserves feature magnitudes which tends to work better for regression tasks. --- configs/model/yield_fusion_reg.yaml | 2 ++ configs/model/yield_geoclip_reg.yaml | 2 ++ configs/model/yield_tabular_reg.yaml | 2 ++ configs/model/yield_tessera_fusion_reg.yaml | 2 ++ configs/model/yield_tessera_reg.yaml | 2 ++ src/models/predictive_model.py | 10 +++++++++- 6 files changed, 19 insertions(+), 1 deletion(-) diff --git a/configs/model/yield_fusion_reg.yaml b/configs/model/yield_fusion_reg.yaml index 20ecfd7..df20414 100644 --- a/configs/model/yield_fusion_reg.yaml +++ b/configs/model/yield_fusion_reg.yaml @@ -15,6 +15,8 @@ prediction_head: # GeoClip frozen; tabular projection + head are trained. trainable_modules: [geo_encoder, prediction_head] +# Disable L2 normalization before MLP regression to keep magnitude of features +normalize_features: false metrics: ${metrics} diff --git a/configs/model/yield_geoclip_reg.yaml b/configs/model/yield_geoclip_reg.yaml index 8d5eb5e..39b3e97 100644 --- a/configs/model/yield_geoclip_reg.yaml +++ b/configs/model/yield_geoclip_reg.yaml @@ -13,6 +13,8 @@ prediction_head: # Only the prediction head is trained; GeoClip encoder is frozen. trainable_modules: [prediction_head] +# Disable L2 normalization before MLP regression to keep magnitude of features +normalize_features: false metrics: ${metrics} diff --git a/configs/model/yield_tabular_reg.yaml b/configs/model/yield_tabular_reg.yaml index fc95575..c3b659d 100644 --- a/configs/model/yield_tabular_reg.yaml +++ b/configs/model/yield_tabular_reg.yaml @@ -15,6 +15,8 @@ prediction_head: # Both encoder and head have trainable parameters. 
trainable_modules: [geo_encoder, prediction_head] +# Disable L2 normalization before MLP regression to keep magnitude of features +normalize_features: false metrics: ${metrics} diff --git a/configs/model/yield_tessera_fusion_reg.yaml b/configs/model/yield_tessera_fusion_reg.yaml index 68ffa88..a8e0e8a 100644 --- a/configs/model/yield_tessera_fusion_reg.yaml +++ b/configs/model/yield_tessera_fusion_reg.yaml @@ -23,6 +23,8 @@ prediction_head: # geo_encoder includes the tessera AverageEncoder + tabular branch; head is always trained. trainable_modules: [geo_encoder, prediction_head] +# Disable L2 normalization before MLP regression to keep magnitude of features +normalize_features: false metrics: ${metrics} diff --git a/configs/model/yield_tessera_reg.yaml b/configs/model/yield_tessera_reg.yaml index dcd84e2..69adb0b 100644 --- a/configs/model/yield_tessera_reg.yaml +++ b/configs/model/yield_tessera_reg.yaml @@ -12,6 +12,8 @@ prediction_head: dropout: 0.2 trainable_modules: [geo_encoder, prediction_head] +# Disable L2 normalization before MLP regression to keep magnitude of features +normalize_features: false metrics: ${metrics} diff --git a/src/models/predictive_model.py b/src/models/predictive_model.py index be69a34..4f37412 100644 --- a/src/models/predictive_model.py +++ b/src/models/predictive_model.py @@ -1,6 +1,7 @@ from typing import Dict, override import torch +import torch.nn.functional as F from src.models.base_model import BaseModel from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder @@ -22,8 +23,9 @@ def __init__( scheduler: torch.optim.lr_scheduler, loss_fn: BaseLossFn, metrics: MetricsWrapper, + normalize_features: bool = True, ) -> None: - """Implementation of the predictive model with replaceable EO encoder, and prediction head. + """Implementation of the predictive model with replaceable GEO encoder, and prediction head. 
:param geo_encoder: geo encoder module (replaceable) :param prediction_head: prediction head module (replaceable) @@ -34,6 +36,8 @@ def __init__( :param metrics: metrics to use for model performance evaluation :param num_classes: number of target classes :param tabular_dim: number of tabular features + :param normalize_features: if True, apply L2 normalisation to encoder output before + the prediction head (default: True) """ super().__init__(trainable_modules, optimizer, scheduler, loss_fn, metrics) @@ -44,6 +48,8 @@ def __init__( # Prediction head self.prediction_head = prediction_head + self.normalize_features = normalize_features + @override def setup(self, stage: str) -> None: self.num_classes = self.trainer.datamodule.num_classes @@ -74,6 +80,8 @@ def setup_encoders_adapters(self): @override def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: feats = self.geo_encoder(batch) + if self.normalize_features: + feats = F.normalize(feats, dim=-1) return self.prediction_head(feats) @override From c4f75b8c8fefac883bea5275cc8870ffc61e8ea4 Mon Sep 17 00:00:00 2001 From: Rob Knapen Date: Tue, 10 Mar 2026 12:06:46 +0100 Subject: [PATCH 30/60] Crop Yield use case: Consolidated / clarified cache use for downloaded raw tessera tiles --- .env.example | 6 ++- .../make_model_ready_yield_africa.py | 2 +- .../yield_africa_tessera_preprocess.py | 49 +++++++++---------- 3 files changed, 29 insertions(+), 28 deletions(-) diff --git a/.env.example b/.env.example index 4b8c52e..4e8a6ee 100644 --- a/.env.example +++ b/.env.example @@ -10,8 +10,10 @@ TRAINER_PROFILE="gpu" # cpu/gpu/mps/ddp HF_HOME="${PROJECT_ROOT}/.cache/huggingface/" # set or will default to './.cache/huggingface/' DATA_DIR="${PROJECT_ROOT}/data/" # set to your local data folder (for aether), or will default to '${PROJECT_ROOT}/data/' -# When using (and downloading) TESSERA embeddings (e.g., crop yield use case) -# Note that this folder can get large ... +# Base cache directory for TESSERA. 
+# GeoTessera registry/metadata is stored here; large raw source tiles go in the +# raw/ subfolder. This folder can get very large — point it at an external drive +# if needed. TESSERA_EMBEDDINGS_DIR="${PROJECT_ROOT}/data/cache/tessera/" # Working directories diff --git a/src/data_preprocessing/make_model_ready_yield_africa.py b/src/data_preprocessing/make_model_ready_yield_africa.py index 814d901..25c4399 100644 --- a/src/data_preprocessing/make_model_ready_yield_africa.py +++ b/src/data_preprocessing/make_model_ready_yield_africa.py @@ -679,7 +679,7 @@ def main( ap.add_argument( "--out_csv", required=True, - help="Path for the output model-ready CSV (e.g. data/yield_africa/model_ready_yield-africa.csv)", + help="Path for the output model-ready CSV (e.g. data/yield_africa/model_ready_yield_africa.csv)", ) ap.add_argument( "--out_parquet", diff --git a/src/data_preprocessing/yield_africa_tessera_preprocess.py b/src/data_preprocessing/yield_africa_tessera_preprocess.py index 3fcfbfa..15823d1 100644 --- a/src/data_preprocessing/yield_africa_tessera_preprocess.py +++ b/src/data_preprocessing/yield_africa_tessera_preprocess.py @@ -65,7 +65,6 @@ def fetch_tessera_tiles( countries: list[str] | None = None, years: list[int] | None = None, cache_dir: str | None = None, - embeddings_dir: str | None = None, workers: int = 2, ) -> None: """Fetch TESSERA tiles for every record in the yield_africa CSV. @@ -74,13 +73,12 @@ def fetch_tessera_tiles( :param tile_size: spatial extent of each tile in pixels :param countries: optional list of country codes to restrict fetching :param years: optional list of years to restrict fetching - :param cache_dir: directory for GeoTessera's internal registry cache; - defaults to ``{data_dir}/cache/tessera`` - :param embeddings_dir: directory where GeoTessera stores the raw downloaded - embedding tiles (``global_0.1_degree_representation/`` etc.). 
Defaults - to the current working directory when not set, which can silently fill - the project folder with tens of GB of data. Point this at an external - drive when disk space is limited. + :param cache_dir: base directory for all TESSERA cache files. GeoTessera's + internal registry is stored here; the large raw downloaded source tiles + (``global_0.1_degree_representation/`` etc.) are kept in the ``raw/`` + subfolder. Defaults to the ``TESSERA_EMBEDDINGS_DIR`` env var when set, + otherwise ``{data_dir}/cache/tessera``. Point this at an external drive + when disk space is limited. :param workers: number of parallel download threads. Each thread keeps its own GeoTessera instance to avoid shared state. Default: 2 (external drive I/O is usually the bottleneck; more workers add contention). @@ -95,7 +93,9 @@ def fetch_tessera_tiles( save_dir.mkdir(parents=True, exist_ok=True) if cache_dir is None: - cache_dir = str(Path(data_dir) / "cache" / "tessera") + cache_dir = os.environ.get("TESSERA_EMBEDDINGS_DIR") or str(Path(data_dir) / "cache" / "tessera") + + embeddings_dir = str(Path(cache_dir) / "raw") df = pd.read_csv(csv_path) @@ -116,7 +116,9 @@ def fetch_tessera_tiles( print( f"Records: {n_total} total, {n_existing} already cached, " - f"{n_to_fetch} to fetch (tile_size={tile_size}, workers={workers})" + f"{n_to_fetch} to fetch (tile_size={tile_size}, workers={workers})\n" + f" cache_dir : {cache_dir}\n" + f" embeddings_dir: {embeddings_dir}" ) # Build GeoTessera constructor kwargs shared across all threads. @@ -130,8 +132,7 @@ def fetch_tessera_tiles( # noticeable CPU time per tile and making progress look stalled. "verify_hashes": False, } - if embeddings_dir is not None: - _gt_kwargs["embeddings_dir"] = embeddings_dir + _gt_kwargs["embeddings_dir"] = embeddings_dir _thread_local = threading.local() @@ -243,25 +244,24 @@ def main() -> None: "--cache_dir", type=str, default=None, - help="GeoTessera internal registry cache directory. 
Default: {data_dir}/cache/tessera", - ) - parser.add_argument( - "--embeddings_dir", - type=str, - default=os.environ.get("TESSERA_EMBEDDINGS_DIR"), help=( - "Directory for GeoTessera raw source tiles " - "(global_0.1_degree_representation/ etc.). " - "Falls back to the TESSERA_EMBEDDINGS_DIR env var, then the current " - "working directory. Set TESSERA_EMBEDDINGS_DIR in .env to avoid " - "passing this flag every run." + "Base directory for all TESSERA cache files. " + "GeoTessera's registry is stored here; large raw source tiles go in " + "the raw/ subfolder. " + "Falls back to the TESSERA_EMBEDDINGS_DIR env var, then " + "{data_dir}/cache/tessera. Set TESSERA_EMBEDDINGS_DIR in .env to " + "avoid passing this flag every run." ), ) parser.add_argument( "--workers", type=int, default=2, - help="Number of parallel download threads. Default: 2", + help=( + "Number of parallel download threads. Default: 2. " + "When writing to an external drive too many workers can cause I/O " + "bottlenecks. Reduce the number of workers to improve throughput." 
+ ), ) args = parser.parse_args() @@ -276,7 +276,6 @@ def main() -> None: countries=args.countries, years=args.years, cache_dir=args.cache_dir, - embeddings_dir=args.embeddings_dir, workers=args.workers, ) From 829dd67c2c406ff0b621bdfa7f35e00cc557c343 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Tue, 10 Mar 2026 14:59:06 +0100 Subject: [PATCH 31/60] Create tabular encoder --- .../geo_encoders/tabular_encoder.py | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 src/models/components/geo_encoders/tabular_encoder.py diff --git a/src/models/components/geo_encoders/tabular_encoder.py b/src/models/components/geo_encoders/tabular_encoder.py new file mode 100644 index 0000000..47bc621 --- /dev/null +++ b/src/models/components/geo_encoders/tabular_encoder.py @@ -0,0 +1,60 @@ +from typing import Dict, override + +import torch +from torch import nn + +from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder + + +class TabularEncoder(BaseGeoEncoder): + """Tabular data encoder.""" + + def __init__( + self, + output_dim: int, + input_dim: int | None = None, + hidden_dim: int | None = None, + dropout_prob: float = 0.0, + geo_data_name: str = "tabular", + ): + super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.hidden_dim = hidden_dim + self.dropout_prob = dropout_prob + + self.geo_encoder: nn.Module | None = None + + self.allowed_geo_data_names = ["tabular"] + assert ( + geo_data_name in self.allowed_geo_data_names + ), f"geo_data_name must be one of {self.allowed_geo_data_names}, got {geo_data_name}" + self.geo_data_name = geo_data_name + + def configure_nn(self, input_dim: int) -> None: + self.input_dim = input_dim + if self.hidden_dim is None: + self.hidden_dim = max(self.input_dim * 2, 128) + + self.geo_encoder = nn.Sequential( + nn.LayerNorm(self.input_dim), + nn.Linear(self.input_dim, self.hidden_dim), + nn.ReLU(), + nn.Dropout(self.dropout_prob), + nn.Linear(self.hidden_dim, 
self.hidden_dim // 2), + nn.ReLU(), + nn.Dropout(self.dropout_prob), + nn.Linear(self.hidden_dim // 2, self.output_dim), + ) + + @override + def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + tab_data = batch.get("eo", {}).get("tabular") + + dtype = self.dtype + if tab_data.dtype != dtype: + tab_data = tab_data.to(dtype) + feats = self.geo_encoder(tab_data) + + return feats.to(dtype) From a215e275cab67d59d5e7035dd4ef41fa1c7ff371 Mon Sep 17 00:00:00 2001 From: Rob Knapen Date: Wed, 11 Mar 2026 10:45:20 +0100 Subject: [PATCH 32/60] Makes DBScan clustering more efficient and much faster. --- src/data/base_datamodule.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/src/data/base_datamodule.py b/src/data/base_datamodule.py index 201c671..1ef0a48 100644 --- a/src/data/base_datamodule.py +++ b/src/data/base_datamodule.py @@ -1,12 +1,12 @@ import copy import os +import time from functools import partial from typing import Any, Dict, List, Tuple import numpy as np import pandas as pd import torch -from geopy.distance import distance as geodist # avoid naming confusion from lightning import LightningDataModule from sklearn.cluster import DBSCAN from sklearn.model_selection import GroupShuffleSplit @@ -121,20 +121,27 @@ def split_data(self) -> None: } elif self.hparams.split_mode == "spatial_clusters": - print("Splitting dataset using spatial clusters. This can take a while...") - coords = np.array([self.dataset.df.lat, self.dataset.df.lon]).T - if len(coords) > 2000: - print( - "Warning: DBSCAN clustering on more than 2000 samples may be slow. Maybe set n_jobs in DBScan?" - ) - # 4000 m distance between points. Use geodist to calculate true distance. min_dist = self.hparams.spatial_split_distance_m + coords = np.array([self.dataset.df.lat, self.dataset.df.lon]).T + n = len(coords) + print( + f"Splitting {n} samples into spatial clusters " + f"(eps={min_dist / 1000:.1f} km, haversine, n_jobs=-1)..." 
+ ) + # Convert (lat, lon) degrees to radians for sklearn's haversine metric. + # haversine returns arc length on the unit sphere, so eps must be in radians. + _EARTH_RADIUS_M = 6_371_000 + coords_rad = np.radians(coords) + eps_rad = min_dist / _EARTH_RADIUS_M + t0 = time.time() clustering = DBSCAN( - eps=min_dist, - metric=lambda u, v: geodist(u, v).meters, + eps=eps_rad, + metric="haversine", + algorithm="ball_tree", min_samples=2, - ).fit(coords) - print("Clustering done. Creating splits and saving.") + n_jobs=-1, + ).fit(coords_rad) + print(f"DBSCAN done in {time.time() - t0:.1f}s. Creating splits...") # Non-clustered points are labeled -1. Change to new cluster label. clusters = copy.deepcopy(clustering.labels_) new_cl = np.max(clusters) + 1 From 221a495d49a33380e1fd9c7bae7e0868b89dc734 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 11 Mar 2026 10:46:36 +0100 Subject: [PATCH 33/60] Create mlp projector/adapter --- .../components/geo_encoders/mlp_projector.py | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 src/models/components/geo_encoders/mlp_projector.py diff --git a/src/models/components/geo_encoders/mlp_projector.py b/src/models/components/geo_encoders/mlp_projector.py new file mode 100644 index 0000000..5417eb5 --- /dev/null +++ b/src/models/components/geo_encoders/mlp_projector.py @@ -0,0 +1,44 @@ +import torch +from torch import nn + +from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder + + +class MLPProjector(BaseGeoEncoder): + def __init__( + self, + output_dim: int, + input_dim: int | None = None, + nn_layers: int = 2, + hidden_dim: int = 256, + ) -> None: + super().__init__() + + self.output_dim = output_dim + self.input_dim = input_dim + self.nn_layers = nn_layers + self.hidden_dim = hidden_dim + + # Placeholder + self.net: nn.Module | None = None + + def set_input_dim(self, input_dim: int) -> None: + self.input_dim = input_dim + + def configure_nn(self) -> None: + """Configure the MLP 
network.""" + assert self.input_dim is not None, "input_dim must be defined" + assert self.output_dim is not None, "output_dim must be defined" + layers = [] + input_dim = self.input_dim + + for i in range(self.nn_layers - 1): + layers.append(nn.Linear(input_dim, self.hidden_dim)) + layers.append(nn.ReLU()) + input_dim = self.hidden_dim + + layers.append(nn.Linear(input_dim, self.output_dim)) + self.net = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) From 6e083b192c38adeccde73eaa8dfa44885b954407 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Wed, 11 Mar 2026 10:47:11 +0100 Subject: [PATCH 34/60] Replace multi-modal encoder with wrapper --- configs/model/heat_fusion_reg.yaml | 15 ++- configs/model/heat_geoclip_reg.yaml | 4 +- configs/model/heat_tabular_reg.yaml | 7 +- configs/model/yield_fusion_reg.yaml | 17 +++- configs/model/yield_geoclip_reg.yaml | 4 +- configs/model/yield_tabular_reg.yaml | 9 +- configs/model/yield_tessera_fusion_reg.yaml | 25 +++-- .../geo_encoders/encoder_wrapper.py | 95 +++++++++++++++++++ src/models/predictive_model.py | 20 ++-- 9 files changed, 153 insertions(+), 43 deletions(-) create mode 100644 src/models/components/geo_encoders/encoder_wrapper.py diff --git a/configs/model/heat_fusion_reg.yaml b/configs/model/heat_fusion_reg.yaml index ff5a69b..3671555 100644 --- a/configs/model/heat_fusion_reg.yaml +++ b/configs/model/heat_fusion_reg.yaml @@ -9,10 +9,17 @@ _target_: src.models.predictive_model.PredictiveModel geo_encoder: - _target_: src.models.components.geo_encoders.multimodal_encoder.MultiModalEncoder - use_coords: true - use_tabular: true -# tab_embed_dim: 64 + _target_: src.models.components.geo_encoders.encoder_wrapper.EncoderWrapper + + encoder_branches: + - encoder: + _target_: src.models.components.geo_encoders.geoclip.GeoClipCoordinateEncoder + - encoder: + _target_: src.models.components.geo_encoders.tabular_encoder.TabularEncoder + output_dim: 64 + geo_data_name: 
tabular + + fusion_strategy: "concat" prediction_head: _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead diff --git a/configs/model/heat_geoclip_reg.yaml b/configs/model/heat_geoclip_reg.yaml index 0e8f11a..d29701c 100644 --- a/configs/model/heat_geoclip_reg.yaml +++ b/configs/model/heat_geoclip_reg.yaml @@ -9,9 +9,7 @@ _target_: src.models.predictive_model.PredictiveModel geo_encoder: - _target_: src.models.components.geo_encoders.multimodal_encoder.MultiModalEncoder - use_coords: true - use_tabular: false + _target_: src.models.components.geo_encoders.geoclip.GeoClipCoordinateEncoder prediction_head: _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead diff --git a/configs/model/heat_tabular_reg.yaml b/configs/model/heat_tabular_reg.yaml index 8ca1184..c7f9fff 100644 --- a/configs/model/heat_tabular_reg.yaml +++ b/configs/model/heat_tabular_reg.yaml @@ -17,10 +17,9 @@ _target_: src.models.predictive_model.PredictiveModel metrics: ${metrics} geo_encoder: - _target_: src.models.components.geo_encoders.multimodal_encoder.MultiModalEncoder - use_coords: false - use_tabular: true - tab_embed_dim: 64 + _target_: src.models.components.geo_encoders.tabular_encoder.TabularEncoder + output_dim: 64 + geo_data_name: tabular prediction_head: _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead diff --git a/configs/model/yield_fusion_reg.yaml b/configs/model/yield_fusion_reg.yaml index df20414..05b680f 100644 --- a/configs/model/yield_fusion_reg.yaml +++ b/configs/model/yield_fusion_reg.yaml @@ -1,11 +1,18 @@ _target_: src.models.predictive_model.PredictiveModel geo_encoder: - _target_: src.models.components.geo_encoders.multimodal_encoder.MultiModalEncoder - use_coords: true - use_tabular: true - tab_embed_dim: 256 - tabular_dropout: 0.2 + _target_: src.models.components.geo_encoders.encoder_wrapper.EncoderWrapper + + encoder_branches: + - encoder: + _target_: 
src.models.components.geo_encoders.geoclip.GeoClipCoordinateEncoder + - encoder: + _target_: src.models.components.geo_encoders.tabular_encoder.TabularEncoder + output_dim: 256 + dropout_prob: 0.2 + geo_data_name: tabular + + fusion_strategy: "concat" prediction_head: _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead diff --git a/configs/model/yield_geoclip_reg.yaml b/configs/model/yield_geoclip_reg.yaml index 39b3e97..20978e6 100644 --- a/configs/model/yield_geoclip_reg.yaml +++ b/configs/model/yield_geoclip_reg.yaml @@ -1,9 +1,7 @@ _target_: src.models.predictive_model.PredictiveModel geo_encoder: - _target_: src.models.components.geo_encoders.multimodal_encoder.MultiModalEncoder - use_coords: true - use_tabular: false + _target_: src.models.components.geo_encoders.geoclip.GeoClipCoordinateEncoder prediction_head: _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead diff --git a/configs/model/yield_tabular_reg.yaml b/configs/model/yield_tabular_reg.yaml index c3b659d..22d431a 100644 --- a/configs/model/yield_tabular_reg.yaml +++ b/configs/model/yield_tabular_reg.yaml @@ -1,11 +1,10 @@ _target_: src.models.predictive_model.PredictiveModel geo_encoder: - _target_: src.models.components.geo_encoders.multimodal_encoder.MultiModalEncoder - use_coords: false - use_tabular: true - tab_embed_dim: 256 - tabular_dropout: 0.2 + _target_: src.models.components.geo_encoders.tabular_encoder.TabularEncoder + output_dim: 256 + dropout_prob: 0.2 + geo_data_name: tabular prediction_head: _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead diff --git a/configs/model/yield_tessera_fusion_reg.yaml b/configs/model/yield_tessera_fusion_reg.yaml index a8e0e8a..0fc49d4 100644 --- a/configs/model/yield_tessera_fusion_reg.yaml +++ b/configs/model/yield_tessera_fusion_reg.yaml @@ -4,16 +4,21 @@ _target_: src.models.predictive_model.PredictiveModel # The geo_encoder_cfg 
replaces the hardcoded GeoClipCoordinateEncoder with # AverageEncoder(tessera), so the spatial branch uses inter-annual phenology # instead of static coordinate embeddings. + geo_encoder: - _target_: src.models.components.geo_encoders.multimodal_encoder.MultiModalEncoder - use_coords: true - use_tabular: true - tab_embed_dim: 256 - tabular_dropout: 0.2 - geo_encoder_cfg: - _target_: src.models.components.geo_encoders.average_encoder.AverageEncoder - geo_data_name: tessera - # output_dim defaults to 128; no projection needed. + _target_: src.models.components.geo_encoders.encoder_wrapper.EncoderWrapper + + encoder_branches: + - encoder: + _target_: src.models.components.geo_encoders.average_encoder.AverageEncoder + geo_data_name: tessera + - encoder: + _target_: src.models.components.geo_encoders.tabular_encoder.TabularEncoder + output_dim: 256 + dropout_prob: 0.2 + geo_data_name: tabular + + fusion_strategy: "concat" prediction_head: _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead @@ -42,4 +47,4 @@ scheduler: patience: 10 loss_fn: - _target_: src.models.components.loss_fns.huber_loss.HuberLoss \ No newline at end of file + _target_: src.models.components.loss_fns.huber_loss.HuberLoss diff --git a/src/models/components/geo_encoders/encoder_wrapper.py b/src/models/components/geo_encoders/encoder_wrapper.py new file mode 100644 index 0000000..bbc8453 --- /dev/null +++ b/src/models/components/geo_encoders/encoder_wrapper.py @@ -0,0 +1,95 @@ +from typing import Any, Dict, List, override + +import torch + +from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder +from src.models.components.geo_encoders.tabular_encoder import TabularEncoder + + +class EncoderWrapper(BaseGeoEncoder): + """Wrapper class for multi-modal encoders.""" + + def __init__( + self, + encoder_branches: List[Dict[str, Any]], + fusion_strategy: str, + ): + super().__init__() + + self.encoder_branches = encoder_branches + assert 
fusion_strategy in ["mean", "concat", "none"], ValueError( + f'Unsupported fusion strategy "{fusion_strategy}"' + ) + self.fusion_strategy = fusion_strategy + self.output_dim = None + + # Configure/initialise missing/conditional parts + for branch in self.encoder_branches: + intermediate_dim = branch.get("encoder").output_dim + projector = branch.get("projector", None) + if projector is not None: + projector.set_input_dim(input_dim=intermediate_dim) + projector.configure_nn() + + def configure_nn(self, tabular_dim: int) -> None: + output_dims = [] + new_parts = set() + for branch in self.encoder_branches: + if isinstance(branch["encoder"], TabularEncoder): + branch["encoder"].configure_nn(tabular_dim) + new_parts.add("ta") + if branch.get("projector"): + output_dims.append(branch["projector"].output_dim) + else: + output_dims.append(branch["encoder"].output_dim) + + if self.fusion_strategy == "concat": + self.output_dim = sum(output_dims) + elif self.fusion_strategy == "mean": + assert len(set(output_dims)) == 1, ValueError( + f"Encoder branches produce different output dimensions {output_dims} and cannot be averaged."
+ ), ) + self.output_dim = output_dims[0] + + @override + def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + branch_feats = [] + for branch in self.encoder_branches: + feats = branch["encoder"].forward(batch) # each encoder knows what modality it needs + + if branch.get("projector", None): + feats = branch["projector"].forward(feats) + + branch_feats.append(feats) + + if self.fusion_strategy == "concat": + return torch.cat(branch_feats, dim=1) + # "mean" fusion: average across branches (stack along a new leading dim); + # torch.mean does not accept a Python list of tensors. + return torch.stack(branch_feats, dim=0).mean(dim=0) + + @property + def device(self): + devices = set() + for branch in self.encoder_branches: + encoder = branch["encoder"] + devices.update({p.device for p in encoder.parameters()}) + projector = branch.get("projector") + if projector is not None: + devices.update({p.device for p in projector.parameters()}) + + if len(devices) != 1: + raise RuntimeError("GEO encoder is on multiple devices") + return devices.pop() + + @property + def dtype(self) -> torch.dtype: + dtypes = set() + for branch in self.encoder_branches: + encoder = branch["encoder"] + dtypes.update({p.dtype for p in encoder.parameters()}) + projector = branch.get("projector") + if projector is not None: + dtypes.update({p.dtype for p in projector.parameters()}) + + if len(dtypes) != 1: + raise RuntimeError("GEO encoder has multiple dtypes") + return dtypes.pop() diff --git a/src/models/predictive_model.py b/src/models/predictive_model.py index 4f37412..2954584 100644 --- a/src/models/predictive_model.py +++ b/src/models/predictive_model.py @@ -5,7 +5,8 @@ from src.models.base_model import BaseModel from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder -from src.models.components.geo_encoders.multimodal_encoder import MultiModalEncoder +from src.models.components.geo_encoders.encoder_wrapper import EncoderWrapper +from src.models.components.geo_encoders.tabular_encoder import TabularEncoder from src.models.components.loss_fns.base_loss_fn import BaseLossFn from
src.models.components.metrics.metrics_wrapper import MetricsWrapper from src.models.components.pred_heads.linear_pred_head import ( @@ -25,7 +26,8 @@ def __init__( metrics: MetricsWrapper, normalize_features: bool = True, ) -> None: - """Implementation of the predictive model with replaceable GEO encoder, and prediction head. + """Implementation of the predictive model with replaceable GEO encoder, and prediction + head. :param geo_encoder: geo encoder module (replaceable) :param prediction_head: prediction head module (replaceable) @@ -36,8 +38,8 @@ def __init__( :param metrics: metrics to use for model performance evaluation :param num_classes: number of target classes :param tabular_dim: number of tabular features - :param normalize_features: if True, apply L2 normalisation to encoder output before - the prediction head (default: True) + :param normalize_features: if True, apply L2 normalisation to encoder output before the + prediction head (default: True) """ super().__init__(trainable_modules, optimizer, scheduler, loss_fn, metrics) @@ -63,12 +65,12 @@ def setup(self, stage: str) -> None: def setup_encoders_adapters(self): """Set up encoders and missing adapters/projectors.""" # TODO: move to multi-modal eo encoder - if ( - isinstance(self.geo_encoder, MultiModalEncoder) - and self.geo_encoder.use_tabular - and not self.geo_encoder._tabular_ready + if isinstance(self.geo_encoder, TabularEncoder) or isinstance( + self.geo_encoder, EncoderWrapper ): - self.geo_encoder.build_tabular_branch(self.tabular_dim) + self.geo_encoder.configure_nn(self.tabular_dim) + if self.tabular_dim: + self.trainable_modules.append("tabular_encoder") self.prediction_head.set_dim( input_dim=self.geo_encoder.output_dim, output_dim=self.num_classes From f4aa97b468853761aac54f3e547cc6c6c6d91c33 Mon Sep 17 00:00:00 2001 From: Rob Knapen Date: Wed, 11 Mar 2026 10:53:12 +0100 Subject: [PATCH 35/60] Crop Yield use case: spatial splitting --- .../yield_africa_spatial_splits.py | 326 
++++++++++++++++++ 1 file changed, 326 insertions(+) create mode 100644 src/data_preprocessing/yield_africa_spatial_splits.py diff --git a/src/data_preprocessing/yield_africa_spatial_splits.py b/src/data_preprocessing/yield_africa_spatial_splits.py new file mode 100644 index 0000000..4c9aa98 --- /dev/null +++ b/src/data_preprocessing/yield_africa_spatial_splits.py @@ -0,0 +1,326 @@ +"""Generate spatial-cluster split files for the yield_africa dataset. + +Location: src/data_preprocessing/yield_africa_spatial_splits.py + +Uses DBSCAN with a haversine distance metric to group nearby field locations +into clusters, then assigns whole clusters to train/val/test so that no +geographically close points straddle a split boundary. + +One `.pth` file is written per distance threshold to +`{data_dir}/yield_africa/splits/split_spatial_{distance_km}km.pth`. + +Split layout +------------ +- train : ~70 % of records (cluster-aligned) +- val : ~15 % of records (cluster-aligned) +- test : ~15 % of records (cluster-aligned) + +Proportions are approximate because whole clusters are kept intact. + +The files are consumed by BaseDataModule when `split_mode: from_file` and +`saved_split_file_name: split_spatial_{distance_km}km.pth`. + +Usage +----- + # Generate the default set of splits (10 km, 25 km, 50 km) + python src/data_preprocessing/yield_africa_spatial_splits.py --data_dir data/ + + # Generate a single split at a specific distance + python src/data_preprocessing/yield_africa_spatial_splits.py --data_dir data/ --distance_km 25 + + # Generate multiple distances in one run + python src/data_preprocessing/yield_africa_spatial_splits.py --data_dir data/ --distance_km 10 25 50 + +Notes +----- +- DBSCAN uses sklearn's built-in haversine metric with a BallTree spatial index + and n_jobs=-1, which is significantly faster than a Python geodesic lambda. + Haversine vs. true geodesic error is < 0.1% at distances up to ~100 km. 
+- `min_samples=2` means a pair of fields within `distance_km` of each other + forms a cluster; isolated fields each become their own singleton cluster. +- All clusters are kept intact across the split boundary, so the test set + contains no locations geographically close to any training location. +""" + +import argparse +import copy +import logging +import time +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +from sklearn.cluster import DBSCAN + +log = logging.getLogger(__name__) + +DATASET_NAME = "yield_africa" +MODEL_READY_CSV = f"model_ready_{DATASET_NAME}.csv" + +# Default distances to generate when no --distance_km is supplied. +DEFAULT_DISTANCES_KM = [10, 25, 50] + +# Split proportions (must sum to 1.0). +TRAIN_FRAC = 0.70 +VAL_FRAC = 0.15 +TEST_FRAC = 0.15 + +# Fixed random seed for GroupShuffleSplit. +SEED = 12345 + + +def make_spatial_split( + df: pd.DataFrame, + distance_m: int, + train_val_test_split: tuple[float, float, float] = (TRAIN_FRAC, VAL_FRAC, TEST_FRAC), + seed: int = SEED, +) -> dict: + """Return a split-indices dict using DBSCAN spatial clustering. + + :param df: full model-ready dataframe (must contain 'lat', 'lon', 'name_loc') + :param distance_m: DBSCAN eps in metres — pairs of fields closer than this + value are assigned to the same cluster + :param train_val_test_split: (train, val, test) proportions, must sum to 1.0 + :param seed: random seed for GroupShuffleSplit + :return: dict with 'train_indices', 'val_indices', 'test_indices' as + pd.Series of name_loc strings, plus 'clusters' as a numpy array of + cluster labels (same length as df) + """ + # Deduplicate to unique (lat, lon) locations before clustering. + # yield_africa has ~9 rows per location (one per year); running DBSCAN on all + # rows produces giant clusters whose row counts are unequal, causing + # GroupShuffleSplit (which splits by cluster count) to produce badly skewed + # train/val/test proportions. 
Clustering unique locations and propagating + # the split back to all rows fixes this. + unique_locs = df.drop_duplicates(subset=["lat", "lon"]).reset_index(drop=True) + n_unique = len(unique_locs) + n_total = len(df) + if n_unique < n_total: + print( + f" Deduplicating: {n_unique} unique locations from {n_total} rows " + f"(~{n_total / n_unique:.1f} rows/location)." + ) + + # Convert (lat, lon) degrees to radians for sklearn's haversine metric. + # haversine returns arc length on the unit sphere, so eps must be in radians. + # Error vs. true geodesic is < 0.1% at distances up to ~100 km. + _EARTH_RADIUS_M = 6_371_000 + coords_rad = np.radians(np.array([unique_locs["lat"].values, unique_locs["lon"].values]).T) + eps_rad = distance_m / _EARTH_RADIUS_M + + print( + f" Running DBSCAN (eps={distance_m / 1000:.1f} km, haversine, " + f"n={n_unique} locations, n_jobs=-1)..." + ) + t0 = time.time() + clustering = DBSCAN( + eps=eps_rad, + metric="haversine", + algorithm="ball_tree", + min_samples=2, + n_jobs=-1, + ).fit(coords_rad) + print(f" DBSCAN done in {time.time() - t0:.1f}s.") + + # Noise points (label -1) each become their own unique cluster so that + # GroupShuffleSplit can assign them individually to a split partition. + clusters = copy.deepcopy(clustering.labels_) + next_label = int(np.max(clusters)) + 1 + for i, label in enumerate(clusters): + if label == -1: + clusters[i] = next_label + next_label += 1 + + n_clusters = len(np.unique(clusters)) + n_noise = int(np.sum(clustering.labels_ == -1)) + print(f" Clustering done: {n_clusters} location clusters ({n_noise} singleton noise points).") + + train_prop, val_prop, test_prop = train_val_test_split + + # Greedy size-aware cluster assignment. + # + # GroupShuffleSplit splits by cluster *count*, not by sample count. When the + # cluster size distribution is heavily skewed (a few mega-clusters + many + # tiny 2-location clusters), this produces badly imbalanced splits. 
+ # + # Instead: shuffle clusters for randomness, sort by size descending, then + # assign each cluster to whichever split is furthest below its sample-count + # target. Each cluster goes to exactly one split, so there is no overlap. + rng = np.random.default_rng(seed) + unique_clusters, cluster_sizes = np.unique(clusters, return_counts=True) + + # Shuffle first so ties are broken randomly, then sort by descending size. + shuffle_order = rng.permutation(len(unique_clusters)) + unique_clusters = unique_clusters[shuffle_order] + cluster_sizes = cluster_sizes[shuffle_order] + size_order = np.argsort(-cluster_sizes) + unique_clusters = unique_clusters[size_order] + cluster_sizes = cluster_sizes[size_order] + + target_train = n_unique * train_prop + target_val = n_unique * val_prop + target_test = n_unique * test_prop + train_clusters, val_clusters, test_clusters = [], [], [] + count_train, count_val, count_test = 0, 0, 0 + + for cluster_id, size in zip(unique_clusters, cluster_sizes): + deficit_train = target_train - count_train + deficit_val = target_val - count_val + deficit_test = target_test - count_test + if deficit_train >= deficit_val and deficit_train >= deficit_test: + train_clusters.append(cluster_id) + count_train += size + elif deficit_val >= deficit_test: + val_clusters.append(cluster_id) + count_val += size + else: + test_clusters.append(cluster_id) + count_test += size + + train_loc_mask = np.isin(clusters, train_clusters) + val_loc_mask = np.isin(clusters, val_clusters) + test_loc_mask = np.isin(clusters, test_clusters) + + # Sanity checks: every location assigned, no cluster in multiple splits. 
+ assert train_loc_mask.sum() + val_loc_mask.sum() + test_loc_mask.sum() == n_unique + assert len(set(train_clusters) & set(val_clusters)) == 0 + assert len(set(train_clusters) & set(test_clusters)) == 0 + assert len(set(val_clusters) & set(test_clusters)) == 0 + + print( + f" Split (locations): train={train_loc_mask.sum()}, " + f"val={val_loc_mask.sum()}, test={test_loc_mask.sum()}" + ) + + # Propagate location-level split assignments back to all rows by (lat, lon). + train_latlon = set( + zip(unique_locs.loc[train_loc_mask, "lat"], unique_locs.loc[train_loc_mask, "lon"]) + ) + val_latlon = set( + zip(unique_locs.loc[val_loc_mask, "lat"], unique_locs.loc[val_loc_mask, "lon"]) + ) + test_latlon = set( + zip(unique_locs.loc[test_loc_mask, "lat"], unique_locs.loc[test_loc_mask, "lon"]) + ) + row_latlon = list(zip(df["lat"], df["lon"])) + train_mask = np.array([ll in train_latlon for ll in row_latlon]) + val_mask = np.array([ll in val_latlon for ll in row_latlon]) + test_mask = np.array([ll in test_latlon for ll in row_latlon]) + + assert train_mask.sum() + val_mask.sum() + test_mask.sum() == n_total, ( + "Not all rows were assigned to a split — check for (lat, lon) values that " + "don't match any unique location after deduplication." + ) + + name_locs = df["name_loc"].reset_index(drop=True) + return { + "train_indices": name_locs[train_mask].reset_index(drop=True), + "val_indices": name_locs[val_mask].reset_index(drop=True), + "test_indices": name_locs[test_mask].reset_index(drop=True), + "clusters": clusters, + } + + +def generate_splits( + data_dir: str, + distances_km: list[int] | None = None, + seed: int = SEED, +) -> None: + """Generate and save spatial-cluster split files for the requested distances. 
+ + :param data_dir: root data directory (same as `paths.data_dir` in configs) + :param distances_km: list of DBSCAN cluster distances in kilometres; None + uses DEFAULT_DISTANCES_KM + :param seed: random seed for GroupShuffleSplit + """ + if distances_km is None: + distances_km = DEFAULT_DISTANCES_KM + + dataset_dir = Path(data_dir) / DATASET_NAME + csv_path = dataset_dir / MODEL_READY_CSV + splits_dir = dataset_dir / "splits" + + if not csv_path.exists(): + raise FileNotFoundError(f"Model-ready CSV not found: {csv_path}") + + splits_dir.mkdir(parents=True, exist_ok=True) + + df = pd.read_csv(csv_path) + for col in ("lat", "lon", "name_loc"): + if col not in df.columns: + raise ValueError(f"CSV must contain a '{col}' column") + + print(f"Loaded {len(df)} rows from {csv_path}") + + for dist_km in distances_km: + dist_m = dist_km * 1000 + print(f"\nGenerating spatial split at {dist_km} km ({dist_m} m)...") + + split = make_spatial_split(df, distance_m=dist_m, seed=seed) + n_train = len(split["train_indices"]) + n_val = len(split["val_indices"]) + n_test = len(split["test_indices"]) + + out_name = f"split_spatial_{dist_km}km.pth" + out_path = splits_dir / out_name + torch.save(split, out_path) + + print( + f" Saved {out_name} " + f"(train={n_train}, val={n_val}, test={n_test}, " + f"total={n_train + n_val + n_test}/{len(df)})" + ) + log.info( + f" {dist_km}km: train={n_train}, val={n_val}, test={n_test} -> {out_name}" + ) + + +def main() -> None: + logging.basicConfig(level=logging.INFO, format="%(message)s") + + parser = argparse.ArgumentParser( + description="Generate spatial-cluster split files for yield_africa.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument( + "--data_dir", + type=str, + default="data/", + help="Root data directory (same as paths.data_dir in configs). 
Default: data/", + ) + parser.add_argument( + "--distance_km", + type=int, + nargs="+", + default=None, + metavar="KM", + help=( + "Cluster distance threshold(s) in km. " + f"Omit to generate the default set: {DEFAULT_DISTANCES_KM} km." + ), + ) + parser.add_argument( + "--seed", + type=int, + default=SEED, + help=f"Random seed for GroupShuffleSplit. Default: {SEED}", + ) + args = parser.parse_args() + + distances = args.distance_km # None means use defaults + print( + f"Generating spatial splits data_dir={args.data_dir} " + f"distances_km={distances or DEFAULT_DISTANCES_KM} seed={args.seed}" + ) + generate_splits( + data_dir=args.data_dir, + distances_km=distances, + seed=args.seed, + ) + print("\nDone.") + + +if __name__ == "__main__": + main() From ac5724281d3f1a9669b6169e878f406b6ba2a4c7 Mon Sep 17 00:00:00 2001 From: Rob Knapen Date: Wed, 11 Mar 2026 10:54:17 +0100 Subject: [PATCH 36/60] Crop Yield use case: configs for various experiments --- configs/data/yield_africa_spatial.yaml | 33 ++++++++++++++++ configs/data/yield_africa_tessera_loco.yaml | 39 +++++++++++++++++++ .../data/yield_africa_tessera_spatial.yaml | 39 +++++++++++++++++++ .../experiment/yield_africa_fusion_loco.yaml | 33 ++++++++++++++++ .../yield_africa_fusion_spatial.yaml | 33 ++++++++++++++++ .../yield_africa_tabular_spatial.yaml | 33 ++++++++++++++++ .../yield_africa_tessera_fusion_loco.yaml | 38 ++++++++++++++++++ .../yield_africa_tessera_fusion_spatial.yaml | 38 ++++++++++++++++++ 8 files changed, 286 insertions(+) create mode 100644 configs/data/yield_africa_spatial.yaml create mode 100644 configs/data/yield_africa_tessera_loco.yaml create mode 100644 configs/data/yield_africa_tessera_spatial.yaml create mode 100644 configs/experiment/yield_africa_fusion_loco.yaml create mode 100644 configs/experiment/yield_africa_fusion_spatial.yaml create mode 100644 configs/experiment/yield_africa_tabular_spatial.yaml create mode 100644 configs/experiment/yield_africa_tessera_fusion_loco.yaml create 
mode 100644 configs/experiment/yield_africa_tessera_fusion_spatial.yaml diff --git a/configs/data/yield_africa_spatial.yaml b/configs/data/yield_africa_spatial.yaml new file mode 100644 index 0000000..9313100 --- /dev/null +++ b/configs/data/yield_africa_spatial.yaml @@ -0,0 +1,33 @@ +_target_: src.data.base_datamodule.BaseDataModule + +dataset: + _target_: src.data.yield_africa_dataset.YieldAfricaDataset + data_dir: ${paths.data_dir} + modalities: + coords: {} + use_target_data: true + use_features: true + use_aux_data: none + seed: ${seed} + cache_dir: ${paths.cache_dir} + # Include all countries and years so the split file determines the partition. + countries: ["BF", "BUR", "ETH", "KEN", "MAL", "RWA", "TAN", "ZAM"] + years: [2014, 2016, 2017, 2018, 2019, 2020, 2021, 2023, 2024] + exclude_countries: null + exclude_years: null + +batch_size: 64 +num_workers: 0 +pin_memory: false + +# Spatial-cluster split loaded from a pre-generated file. +# Generate split files first (produces 10 km, 25 km, and 50 km variants): +# python src/data_preprocessing/yield_africa_spatial_splits.py --data_dir +# +# Override saved_split_file_name at the command line to change the cluster distance: +# python src/train.py experiment=yield_africa_tabular_spatial \ +# data.saved_split_file_name=split_spatial_10km.pth +split_mode: "from_file" +saved_split_file_name: "split_spatial_25km.pth" +save_split: false +seed: ${seed} diff --git a/configs/data/yield_africa_tessera_loco.yaml b/configs/data/yield_africa_tessera_loco.yaml new file mode 100644 index 0000000..0be62c3 --- /dev/null +++ b/configs/data/yield_africa_tessera_loco.yaml @@ -0,0 +1,39 @@ +_target_: src.data.base_datamodule.BaseDataModule + +dataset: + _target_: src.data.yield_africa_dataset.YieldAfricaDataset + data_dir: ${paths.data_dir} + modalities: + tessera: + # size must match the tile_size used when running the preprocessing script. + # Default: 9 pixels (set by yield_africa_tessera_preprocess.py --tile_size). 
+ size: 9 + format: npy + # year is intentionally omitted: yield_africa fetches per-record year tiles + # via the preprocessing script rather than a single bulk-year download. + use_target_data: true + use_features: true + use_aux_data: none + seed: ${seed} + cache_dir: ${paths.cache_dir} + # Include all countries and years so the split file determines the partition. + countries: ["BF", "BUR", "ETH", "KEN", "MAL", "RWA", "TAN", "ZAM"] + years: [2014, 2016, 2017, 2018, 2019, 2020, 2021, 2023, 2024] + exclude_countries: null + exclude_years: null + +batch_size: 64 +num_workers: 0 +pin_memory: false + +# Leave-one-country-out split loaded from a pre-generated file. +# Generate split files first: +# python src/data_preprocessing/yield_africa_loco_splits.py --data_dir +# +# Override saved_split_file_name at the command line to change the held-out country: +# python src/train.py experiment=yield_africa_tessera_fusion_loco \ +# data.saved_split_file_name=split_loco_RWA.pth +split_mode: "from_file" +saved_split_file_name: "split_loco_KEN.pth" +save_split: false +seed: ${seed} diff --git a/configs/data/yield_africa_tessera_spatial.yaml b/configs/data/yield_africa_tessera_spatial.yaml new file mode 100644 index 0000000..9424801 --- /dev/null +++ b/configs/data/yield_africa_tessera_spatial.yaml @@ -0,0 +1,39 @@ +_target_: src.data.base_datamodule.BaseDataModule + +dataset: + _target_: src.data.yield_africa_dataset.YieldAfricaDataset + data_dir: ${paths.data_dir} + modalities: + tessera: + # size must match the tile_size used when running the preprocessing script. + # Default: 9 pixels (set by yield_africa_tessera_preprocess.py --tile_size). + size: 9 + format: npy + # year is intentionally omitted: yield_africa fetches per-record year tiles + # via the preprocessing script rather than a single bulk-year download. 
+ use_target_data: true + use_features: true + use_aux_data: none + seed: ${seed} + cache_dir: ${paths.cache_dir} + # Include all countries and years so the split file determines the partition. + countries: ["BF", "BUR", "ETH", "KEN", "MAL", "RWA", "TAN", "ZAM"] + years: [2014, 2016, 2017, 2018, 2019, 2020, 2021, 2023, 2024] + exclude_countries: null + exclude_years: null + +batch_size: 64 +num_workers: 0 +pin_memory: false + +# Spatial-cluster split loaded from a pre-generated file. +# Generate split files first (produces 10 km, 25 km, and 50 km variants): +# python src/data_preprocessing/yield_africa_spatial_splits.py --data_dir +# +# Override saved_split_file_name at the command line to change the cluster distance: +# python src/train.py experiment=yield_africa_tessera_fusion_spatial \ +# data.saved_split_file_name=split_spatial_10km.pth +split_mode: "from_file" +saved_split_file_name: "split_spatial_25km.pth" +save_split: false +seed: ${seed} diff --git a/configs/experiment/yield_africa_fusion_loco.yaml b/configs/experiment/yield_africa_fusion_loco.yaml new file mode 100644 index 0000000..2540642 --- /dev/null +++ b/configs/experiment/yield_africa_fusion_loco.yaml @@ -0,0 +1,33 @@ +# @package _global_ +# configs/experiment/yield_africa_fusion_loco.yaml +# GeoClip + tabular fusion model evaluated with leave-one-country-out split. +# Default held-out country: KEN (largest, most representative test set). 
+# +# Generate split files first: +# python src/data_preprocessing/yield_africa_loco_splits.py --data_dir +# +# To evaluate on a different held-out country: +# python src/train.py experiment=yield_africa_fusion_loco \ +# data.saved_split_file_name=split_loco_RWA.pth + +defaults: + - override /model: yield_fusion_reg + - override /data: yield_africa_loco + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "fusion", "regression", "loco"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 150 + +data: + batch_size: 64 + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" diff --git a/configs/experiment/yield_africa_fusion_spatial.yaml b/configs/experiment/yield_africa_fusion_spatial.yaml new file mode 100644 index 0000000..98c4221 --- /dev/null +++ b/configs/experiment/yield_africa_fusion_spatial.yaml @@ -0,0 +1,33 @@ +# @package _global_ +# configs/experiment/yield_africa_fusion_spatial.yaml +# GeoClip + tabular fusion model evaluated with a spatial-cluster split. +# Default cluster distance: 25 km (split_spatial_25km.pth). 
+# +# Generate split files first: +# python src/data_preprocessing/yield_africa_spatial_splits.py --data_dir +# +# To evaluate at a different cluster distance: +# python src/train.py experiment=yield_africa_fusion_spatial \ +# data.saved_split_file_name=split_spatial_10km.pth + +defaults: + - override /model: yield_fusion_reg + - override /data: yield_africa_spatial + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "fusion", "regression", "spatial"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 150 + +data: + batch_size: 64 + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" diff --git a/configs/experiment/yield_africa_tabular_spatial.yaml b/configs/experiment/yield_africa_tabular_spatial.yaml new file mode 100644 index 0000000..9c57961 --- /dev/null +++ b/configs/experiment/yield_africa_tabular_spatial.yaml @@ -0,0 +1,33 @@ +# @package _global_ +# configs/experiment/yield_africa_tabular_spatial.yaml +# Tabular-only model evaluated with a spatial-cluster split. +# Default cluster distance: 25 km (split_spatial_25km.pth). 
+# +# Generate split files first: +# python src/data_preprocessing/yield_africa_spatial_splits.py --data_dir +# +# To evaluate at a different cluster distance: +# python src/train.py experiment=yield_africa_tabular_spatial \ +# data.saved_split_file_name=split_spatial_10km.pth + +defaults: + - override /model: yield_tabular_reg + - override /data: yield_africa_spatial + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "tabular_only", "regression", "spatial"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 150 + +data: + batch_size: 64 + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" diff --git a/configs/experiment/yield_africa_tessera_fusion_loco.yaml b/configs/experiment/yield_africa_tessera_fusion_loco.yaml new file mode 100644 index 0000000..ee9aa9d --- /dev/null +++ b/configs/experiment/yield_africa_tessera_fusion_loco.yaml @@ -0,0 +1,38 @@ +# @package _global_ +# configs/experiment/yield_africa_tessera_fusion_loco.yaml +# TESSERA + tabular fusion model evaluated with leave-one-country-out split. +# Default held-out country: KEN (largest, most representative test set). +# +# Requires: +# 1. TESSERA tiles pre-fetched: +# python src/data_preprocessing/yield_africa_tessera_preprocess.py --data_dir +# 2. 
LOCO split files pre-generated: +# python src/data_preprocessing/yield_africa_loco_splits.py --data_dir +# +# To evaluate on a different held-out country: +# python src/train.py experiment=yield_africa_tessera_fusion_loco \ +# data.saved_split_file_name=split_loco_RWA.pth + +defaults: + - override /model: yield_tessera_fusion_reg + - override /data: yield_africa_tessera_loco + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "tessera_fusion", "regression", "loco"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 150 + +data: + batch_size: 64 + dataset: + use_features: true + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" diff --git a/configs/experiment/yield_africa_tessera_fusion_spatial.yaml b/configs/experiment/yield_africa_tessera_fusion_spatial.yaml new file mode 100644 index 0000000..b0eaf9d --- /dev/null +++ b/configs/experiment/yield_africa_tessera_fusion_spatial.yaml @@ -0,0 +1,38 @@ +# @package _global_ +# configs/experiment/yield_africa_tessera_fusion_spatial.yaml +# TESSERA + tabular fusion model evaluated with a spatial-cluster split. +# Default cluster distance: 25 km (split_spatial_25km.pth). +# +# Requires: +# 1. TESSERA tiles pre-fetched: +# python src/data_preprocessing/yield_africa_tessera_preprocess.py --data_dir +# 2. 
Spatial split files pre-generated: +# python src/data_preprocessing/yield_africa_spatial_splits.py --data_dir +# +# To evaluate at a different cluster distance: +# python src/train.py experiment=yield_africa_tessera_fusion_spatial \ +# data.saved_split_file_name=split_spatial_10km.pth + +defaults: + - override /model: yield_tessera_fusion_reg + - override /data: yield_africa_tessera_spatial + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "tessera_fusion", "regression", "spatial"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 150 + +data: + batch_size: 64 + dataset: + use_features: true + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" From 944814dce8a9d30f76815a08ae1cb0eae47f221d Mon Sep 17 00:00:00 2001 From: Rob Knapen Date: Wed, 11 Mar 2026 11:31:04 +0100 Subject: [PATCH 37/60] Adds RRMSE loss function for crop yield error comparison --- configs/metrics/yield_africa_regression.yaml | 3 +- src/models/components/loss_fns/rrmse_loss.py | 44 ++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) create mode 100644 src/models/components/loss_fns/rrmse_loss.py diff --git a/configs/metrics/yield_africa_regression.yaml b/configs/metrics/yield_africa_regression.yaml index 79c441d..7960283 100644 --- a/configs/metrics/yield_africa_regression.yaml +++ b/configs/metrics/yield_africa_regression.yaml @@ -1,7 +1,8 @@ _target_: src.models.components.metrics.metrics_wrapper.MetricsWrapper metrics: - - _target_: src.models.components.loss_fns.mse_loss.MSELoss + - _target_: src.models.components.loss_fns.huber_loss.HuberLoss - _target_: src.models.components.loss_fns.rmse_loss.RMSELoss - _target_: src.models.components.loss_fns.mae_loss.MAELoss + - _target_: src.models.components.loss_fns.rrmse_loss.RRMSELoss - _target_: src.models.components.metrics.r2.RSquared diff --git a/src/models/components/loss_fns/rrmse_loss.py b/src/models/components/loss_fns/rrmse_loss.py new file mode 100644 
index 0000000..1720f1b --- /dev/null +++ b/src/models/components/loss_fns/rrmse_loss.py @@ -0,0 +1,44 @@ +from typing import Dict, override + +import torch + +from src.models.components.loss_fns.base_loss_fn import BaseLossFn + + +class RRMSELoss(BaseLossFn): + """Relative Root Mean Squared Error (RRMSE). + + RRMSE = RMSE / mean(|labels|) + + Normalises RMSE by the mean absolute value of the target, giving a + unit-free percentage error. This makes results comparable across crops + and regions with different absolute yield scales (e.g. t/ha ranges + differ significantly between maize in Zambia and rice in Rwanda). + + Returns a fraction (e.g. 0.15 = 15 % error). Multiply by 100 for + percentage when reporting. + """ + + def __init__(self) -> None: + super().__init__() + self.criterion = torch.nn.MSELoss() + self.name = "rrmse_loss" + + @override + def forward( + self, + pred: torch.Tensor, + labels: torch.Tensor | None = None, + batch: Dict[str, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor | Dict[str, torch.Tensor]: + + labels = labels if labels is not None else batch.get("target") + rmse = torch.sqrt(self.criterion(pred, labels)) + mean_abs = torch.mean(torch.abs(labels)) + loss = rmse / (mean_abs + 1e-8) + + if "return_label" in kwargs: + return {self.name: loss} + else: + return loss From 5c90f4fafc36d0070ccdd3857b53fe50526108d6 Mon Sep 17 00:00:00 2001 From: Rob Knapen Date: Wed, 11 Mar 2026 12:27:47 +0100 Subject: [PATCH 38/60] Crop Yield use case: Adds Fourier harmonics as engineered location features. 
--- src/data/yield_africa_dataset.py | 52 +++++++++++++++++++++++++++++--- tests/test_yield_africa.py | 13 +++++--- 2 files changed, 56 insertions(+), 9 deletions(-) diff --git a/src/data/yield_africa_dataset.py b/src/data/yield_africa_dataset.py index f2867d8..bf549c0 100644 --- a/src/data/yield_africa_dataset.py +++ b/src/data/yield_africa_dataset.py @@ -12,6 +12,7 @@ import os from typing import Any, Dict, List, override +import numpy as np import pandas as pd import torch @@ -27,6 +28,16 @@ # countries are present after filtering. _ALL_COUNTRIES = ["BF", "BUR", "ETH", "KEN", "MAL", "RWA", "TAN", "ZAM"] +# Study-area bounds used to normalise coordinates before computing Fourier +# harmonics. Normalising to the actual data extent (rather than ±90°/±180°) +# makes the harmonics maximally discriminative within the dataset. +# Latitude : 30°S – 15°N → centre −7.5°, half-range 22.5° +# Longitude : 10°E – 45°E → centre 27.5°, half-range 17.5° +_LAT_CENTER = -7.5 +_LAT_HALF_RANGE = 22.5 +_LON_CENTER = 27.5 +_LON_HALF_RANGE = 17.5 + class YieldAfricaDataset(BaseDataset): """Dataset for the crop yield regression use case (East/Southern Africa). @@ -45,11 +56,20 @@ class YieldAfricaDataset(BaseDataset): the model-ready CSV and are picked up via the `feat_` column prefix. They do NOT need to be listed in `modalities`. - In addition to the CSV feat_* columns, `year` and one-hot `country` - encodings are injected as `feat_year` and `feat_country_{CODE}` so that - the model can condition on inter-annual and cross-country variation. - The one-hot set always covers `_ALL_COUNTRIES` (8 countries) so that - `tabular_dim` is stable regardless of the country filter applied. 
+ In addition to the CSV feat_* columns, the following features are injected: + - ``feat_year`` : normalised year (zero-mean, unit-std) + - ``feat_country_{CODE}`` : one-hot country encoding (always 8 columns, + stable across country filters) + - ``feat_lat_sin1/cos1`` : fundamental latitude harmonic, normalised to + the study-area extent (30°S–15°N) + - ``feat_lat_sin2/cos2`` : second latitude harmonic (captures bimodal vs. + unimodal rainfall boundary near the equator) + - ``feat_lon_sin1/cos1`` : fundamental longitude harmonic, normalised to + the study-area extent (10°E–45°E) + + The Fourier harmonics encode the ITCZ-driven latitudinal climate gradient at + interpretable frequencies, complementing GeoCLIP's photo-derived coordinate + embedding and enabling richer text captions for the explainability component. """ def __init__( @@ -93,6 +113,28 @@ def __init__( } for code in _ALL_COUNTRIES: new_cols[f"feat_country_{code}"] = (self.df["country"] == code).astype(float) + + # Fourier harmonics of coordinates, normalised to the study-area extent. + # + # Africa's agricultural patterns follow the ITCZ-driven latitudinal climate + # gradient: rainfall regime (uni- vs. bimodal), growing-season length, and + # temperature vary sinusoidally with latitude. Explicit harmonics give the + # model these signals directly and at interpretable frequencies, complementing + # GeoCLIP's learned (but photo-derived) coordinate embedding. + # + # lat_norm / lon_norm ∈ [-1, 1] within the study area; π * norm ∈ [-π, π]. + # Two harmonics for latitude (captures both the broad N-S gradient and the + # equatorial-bimodal / southern-unimodal boundary); one for longitude + # (east-west Indian Ocean moisture gradient). 
+ lat_norm = (self.df["lat"].astype(float) - _LAT_CENTER) / _LAT_HALF_RANGE + lon_norm = (self.df["lon"].astype(float) - _LON_CENTER) / _LON_HALF_RANGE + new_cols["feat_lat_sin1"] = np.sin(np.pi * lat_norm) + new_cols["feat_lat_cos1"] = np.cos(np.pi * lat_norm) + new_cols["feat_lat_sin2"] = np.sin(2.0 * np.pi * lat_norm) + new_cols["feat_lat_cos2"] = np.cos(2.0 * np.pi * lat_norm) + new_cols["feat_lon_sin1"] = np.sin(np.pi * lon_norm) + new_cols["feat_lon_cos1"] = np.cos(np.pi * lon_norm) + self.df = pd.concat([self.df, pd.DataFrame(new_cols, index=self.df.index)], axis=1) # Apply country/year filters to self.df and rebuild records. diff --git a/tests/test_yield_africa.py b/tests/test_yield_africa.py index b1e6436..9124c96 100644 --- a/tests/test_yield_africa.py +++ b/tests/test_yield_africa.py @@ -46,11 +46,16 @@ } MOCK_N_ROWS = 10 -# feat_year (1) + feat_country_{code} (8) are injected by YieldAfricaDataset -# when country and year columns are present, so the effective tabular dim grows. +# YieldAfricaDataset injects extra feat_* columns when country and year columns +# are present: feat_year (1) + feat_country_{code} (8) + Fourier harmonics (6). 
from src.data.yield_africa_dataset import _ALL_COUNTRIES -MOCK_INJECTED_FEAT_NAMES = {"feat_year"} | {f"feat_country_{c}" for c in _ALL_COUNTRIES} -MOCK_TABULAR_DIM = len(MOCK_FEAT_COLS) + len(MOCK_INJECTED_FEAT_NAMES) # 8 + 9 = 17 +MOCK_INJECTED_FEAT_NAMES = ( + {"feat_year"} + | {f"feat_country_{c}" for c in _ALL_COUNTRIES} + | {"feat_lat_sin1", "feat_lat_cos1", "feat_lat_sin2", "feat_lat_cos2", + "feat_lon_sin1", "feat_lon_cos1"} +) +MOCK_TABULAR_DIM = len(MOCK_FEAT_COLS) + len(MOCK_INJECTED_FEAT_NAMES) # 8 + 15 = 23 MOCK_N_AUX = len(MOCK_AUX_COLS) # 4 From 5c40f0700a2b288786440d161c46fc9f4f65f28f Mon Sep 17 00:00:00 2001 From: Thijs van der Plas Date: Wed, 11 Mar 2026 14:52:11 +0100 Subject: [PATCH 39/60] concept captions manually picked relevant captions for biodiv UC, inspecting histograms to look for sensible thresholds. --- data/s2bms/concept_captions/v2.json | 103 ++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 data/s2bms/concept_captions/v2.json diff --git a/data/s2bms/concept_captions/v2.json b/data/s2bms/concept_captions/v2.json new file mode 100644 index 0000000..9b2bdd5 --- /dev/null +++ b/data/s2bms/concept_captions/v2.json @@ -0,0 +1,103 @@ +[ + { + "concept_caption": "Densely populated area with many houses", + "is_max": true, + "theta_k": 0.3, + "col": "aux_corine_frac_11" + }, + { + "concept_caption": "Very sparsely populated area with few houses", + "is_max": false, + "theta_k": 0.05, + "col": "aux_corine_frac_11" + },{ + "concept_caption": "Area with infrastructure such as roads, railways, airport, ports and heavy industry.", + "is_max": true, + "theta_k": 0.1, + "col": "aux_corine_frac_12" + }, + { + "concept_caption": "Arable land with crops for agriculture", + "is_max": true, + "theta_k": 0.65, + "col": "aux_corine_frac_21" + }, + { + "concept_caption": "Pasture fields with grass for grazing animals", + "is_max": true, + "theta_k": 0.6, + "col": "aux_corine_frac_231" + }, + { + "concept_caption": 
"Agricultural land used for crops, pasture or mixed farming", + "is_max": true, + "theta_k": 0.05, + "col": "aux_corine_frac_24" + }, + { + "concept_caption": "Forested area with many trees", + "is_max": true, + "theta_k": 0.25, + "col": "aux_corine_frac_31" + }, + { + "concept_caption": "Scrub area with trees, shrub, moors.", + "is_max": true, + "theta_k": 0.2, + "col": "aux_corine_frac_32" + }, + { + "concept_caption": "Moorlands and heathlands with low vegetation", + "is_max": true, + "theta_k": 0.2, + "col": "aux_corine_frac_322" + }, + { + "concept_caption": "Wetlands such as marshes, swamps, mudflats and bogs.", + "is_max": true, + "theta_k": 0.2, + "col": "aux_corine_frac_4" + }, + { + "concept_caption": "Peat bogs", + "is_max": true, + "theta_k": 0.2, + "col": "aux_corine_frac_412" + }, + { + "concept_caption": "Water bodies such as lakes, rivers and sea", + "is_max": true, + "theta_k": 0.2, + "col": "aux_corine_frac_5" + }, + { + "concept_caption": "Warm area with high summer temperatures", + "is_max": true, + "theta_k": 22, + "col": "aux_bioclim_05" + }, + { + "concept_caption": "Cold area with low winter temperatures", + "is_max": false, + "theta_k": 0, + "col": "aux_bioclim_06" + }, + { + "concept_caption": "Wet area with a lot of rainfall", + "is_max": true, + "theta_k": 950, + "col": "aux_bioclim_12" + }, + { + "concept_caption": "Remote area far from roads and infrastructure", + "is_max": true, + "theta_k": 1500, + "col": "aux_meandist_road" + }, + { + "concept_caption": "Densely populated area with many houses", + "is_max": true, + "theta_k": 1500, + "col": "aux_pop_density" + } +] From 8782f37666d3712156d1eab1561fc874db8179ed Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 12 Mar 2026 09:49:41 +0100 Subject: [PATCH 40/60] Fix processor to clip to max sequence length of CLIP text encoder --- src/models/components/text_encoders/clip_text_encoder.py | 9 +++++++-- tests/test_yield_africa.py | 9 +++++---- 2 files changed, 12 insertions(+), 6 
deletions(-) diff --git a/src/models/components/text_encoders/clip_text_encoder.py b/src/models/components/text_encoders/clip_text_encoder.py index 027c503..7642865 100644 --- a/src/models/components/text_encoders/clip_text_encoder.py +++ b/src/models/components/text_encoders/clip_text_encoder.py @@ -2,7 +2,6 @@ import torch from geoclip import GeoCLIP -from torch.nn import functional as F from transformers import CLIPModel, CLIPProcessor from src.models.components.text_encoders.base_text_encoder import ( @@ -38,7 +37,13 @@ def forward(self, batch: Dict[str, torch.Tensor], mode: str) -> torch.Tensor: avr_embeds = [] for captions_per_row in text_input: # Tokenize and embed - text_tokens = self.processor(text=captions_per_row, return_tensors="pt", padding=True) + text_tokens = self.processor( + text=captions_per_row, + return_tensors="pt", + padding=True, + truncation=True, + max_length=77, + ) device = next(self.model.parameters()).device text_tokens = {k: v.to(device) for k, v in text_tokens.items()} diff --git a/tests/test_yield_africa.py b/tests/test_yield_africa.py index b1e6436..4eb8152 100644 --- a/tests/test_yield_africa.py +++ b/tests/test_yield_africa.py @@ -49,9 +49,10 @@ # feat_year (1) + feat_country_{code} (8) are injected by YieldAfricaDataset # when country and year columns are present, so the effective tabular dim grows. 
from src.data.yield_africa_dataset import _ALL_COUNTRIES + MOCK_INJECTED_FEAT_NAMES = {"feat_year"} | {f"feat_country_{c}" for c in _ALL_COUNTRIES} MOCK_TABULAR_DIM = len(MOCK_FEAT_COLS) + len(MOCK_INJECTED_FEAT_NAMES) # 8 + 9 = 17 -MOCK_N_AUX = len(MOCK_AUX_COLS) # 4 +MOCK_N_AUX = len(MOCK_AUX_COLS) # 4 # --------------------------------------------------------------------------- @@ -178,7 +179,7 @@ def test_yield_africa_dataset_attributes(yield_africa_dataset): def test_yield_africa_dataset_feat_prefix(yield_africa_dataset): - """All tabular features must carry the feat_ prefix.""" + """All tabular features must carry the feat prefix.""" for name in yield_africa_dataset.feat_names: assert name.startswith("feat_"), f"Unexpected feature name: {name}" @@ -187,7 +188,7 @@ def test_yield_africa_dataset_coords_values(yield_africa_dataset): """Coordinates returned must match the CSV values.""" sample = yield_africa_dataset[0] coords = sample["eo"]["coords"] - assert coords[0].item() == pytest.approx(5.0) # lat of row 0 + assert coords[0].item() == pytest.approx(5.0) # lat of row 0 assert coords[1].item() == pytest.approx(30.0) # lon of row 0 @@ -312,4 +313,4 @@ def test_yield_africa_model_instantiates(): ) model = hydra.utils.instantiate(cfg.model) assert model is not None - GlobalHydra.instance().clear() \ No newline at end of file + GlobalHydra.instance().clear() From 9af0d5c970cbef1b2a291eab2633d8b4b4b25900 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 12 Mar 2026 09:56:19 +0100 Subject: [PATCH 41/60] Add setup method to all geo-encoders --- src/models/components/geo_encoders/cnn_encoder.py | 9 +++++++-- src/models/components/geo_encoders/geoclip.py | 11 ++++++++--- .../components/geo_encoders/tabular_encoder.py | 12 +++++++++++- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/src/models/components/geo_encoders/cnn_encoder.py b/src/models/components/geo_encoders/cnn_encoder.py index d76a4e3..d73a7b3 100644 --- 
a/src/models/components/geo_encoders/cnn_encoder.py +++ b/src/models/components/geo_encoders/cnn_encoder.py @@ -1,9 +1,8 @@ -from typing import Dict, override +from typing import Dict, List, override import torch import torchvision.models as models from torch import nn -from torch.nn import functional as F from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder @@ -132,6 +131,12 @@ def get_backbone(self): else: raise ValueError(f"Unsupported backbone: {self.backbone}") + @override + def setup(self) -> List[str]: + # TODO: could you make sure new layers are returned here to be added to trainable parts? + # Maybe move the get_backbone method in here? + return [] + @override def forward( self, diff --git a/src/models/components/geo_encoders/geoclip.py b/src/models/components/geo_encoders/geoclip.py index bd40aa9..e530dd1 100644 --- a/src/models/components/geo_encoders/geoclip.py +++ b/src/models/components/geo_encoders/geoclip.py @@ -1,4 +1,4 @@ -from typing import Dict, override +from typing import Dict, List, override import torch from geoclip import LocationEncoder @@ -13,14 +13,19 @@ def __init__( geo_data_name="coords", ) -> None: super().__init__() - self.geo_encoder = LocationEncoder() - self.output_dim = self.geo_encoder.LocEnc0.head[0].out_features + self.allowed_geo_data_names = ["coords"] assert ( geo_data_name in self.allowed_geo_data_names ), f"geo_data_name must be one of {self.allowed_geo_data_names}, got {geo_data_name}" self.geo_data_name = geo_data_name + @override + def setup(self) -> List[str]: + self.geo_encoder = LocationEncoder() + self.output_dim = self.geo_encoder.LocEnc0.head[0].out_features + return [] + @override def forward( self, diff --git a/src/models/components/geo_encoders/tabular_encoder.py b/src/models/components/geo_encoders/tabular_encoder.py index 47bc621..09af55c 100644 --- a/src/models/components/geo_encoders/tabular_encoder.py +++ b/src/models/components/geo_encoders/tabular_encoder.py @@ -32,8 +32,18 @@ 
def __init__( ), f"geo_data_name must be one of {self.allowed_geo_data_names}, got {geo_data_name}" self.geo_data_name = geo_data_name - def configure_nn(self, input_dim: int) -> None: + @override + def setup(self, input_dim: int = None) -> list[str]: + self.configure_nn(input_dim) + return ["tabular_encoder"] + + def set_tabular_input_dim(self, input_dim: int) -> None: self.input_dim = input_dim + + def configure_nn(self, input_dim: int = None) -> None: + input_dim = input_dim or self.input_dim + assert input_dim is not None, "input_dim must be defined" + if self.hidden_dim is None: self.hidden_dim = max(self.input_dim * 2, 128) From d85d3292b8bbc6182688ca0c463422e7e0932996 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 12 Mar 2026 09:56:32 +0100 Subject: [PATCH 42/60] Add setup method and fix devices/dtypes --- .../geo_encoders/base_geo_encoder.py | 32 +++++++++++++++---- 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/src/models/components/geo_encoders/base_geo_encoder.py b/src/models/components/geo_encoders/base_geo_encoder.py index 162fcc9..f79f3f4 100644 --- a/src/models/components/geo_encoders/base_geo_encoder.py +++ b/src/models/components/geo_encoders/base_geo_encoder.py @@ -20,19 +20,39 @@ def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: pass @property - def device(self) -> torch.device: + def device(self) -> torch.device | None: devices = {p.device for p in self.parameters()} - if len(devices) != 1: + if len(devices) > 1: raise RuntimeError("GEO encoder is on multiple devices") + elif len(devices) == 0: + return None return devices.pop() @property - def dtype(self) -> torch.dtype: + def dtype(self) -> torch.dtype | None: dtypes = {p.dtype for p in self.parameters()} - if len(dtypes) != 1: + if len(dtypes) > 1: raise RuntimeError("GEO encoder has multiple dtypes") + elif len(dtypes) == 0: + return None return dtypes.pop() + @abstractmethod + def setup(self) -> list[str]: + """Configures networks, data-dependent parts. 
+ + Gets called in model.setup() method. Returns names of any new module configured to be added + to the trainable modules list. + """ + pass + + def add_projector(self, projected_dim: int) -> None: + """Adds an extra linear projection layer to the geo encoder. -if __name__ == "__main__": - _ = BaseGeoEncoder(None) + NB: is not used by default, needs to be called explicitly in forward(). + """ + self.extra_projector = nn.Linear(self.output_dim, projected_dim, dtype=self.dtype) + print( + f"Extra linear projection layer added with mapping dimension {self.output_dim} to {projected_dim}" + ) + self.output_dim = projected_dim From ffdf9858403884c2989865b6f6b8ffbc00a097fb Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 12 Mar 2026 09:58:45 +0100 Subject: [PATCH 43/60] Add setup method and simplify architecture --- .../geo_encoders/average_encoder.py | 64 ++++++------------- 1 file changed, 19 insertions(+), 45 deletions(-) diff --git a/src/models/components/geo_encoders/average_encoder.py b/src/models/components/geo_encoders/average_encoder.py index 86a6b9b..885b175 100644 --- a/src/models/components/geo_encoders/average_encoder.py +++ b/src/models/components/geo_encoders/average_encoder.py @@ -1,7 +1,6 @@ -from typing import Dict, override +from typing import Dict, List, override import torch -import torch.nn.functional as F from torch import nn from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder @@ -10,61 +9,36 @@ class AverageEncoder(BaseGeoEncoder): def __init__( self, - output_dim: int | None = None, geo_data_name="aef", ) -> None: + """Encoder to avreage tile values into a 1D vector. 
+ + :param geo_data_name: modality name + """ super().__init__() - dict_n_bands_default = {"s2": 4, "aef": 64, "tessera": 128} - self.allowed_geo_data_names: list[str] = list(dict_n_bands_default.keys()) + self.dict_n_bands_default = {"s2": 4, "aef": 64, "tessera": 128} + self.allowed_geo_data_names: list[str] = list(self.dict_n_bands_default.keys()) assert ( - geo_data_name in dict_n_bands_default + geo_data_name in self.allowed_geo_data_names ), f"geo_data_name must be one of {self.allowed_geo_data_names}, got {geo_data_name}" self.geo_data_name = geo_data_name - if output_dim is None or output_dim == dict_n_bands_default[geo_data_name]: - self.output_dim = dict_n_bands_default[geo_data_name] - self.extra_projector = None - self.geo_encoder = self._average - else: - assert ( - type(output_dim) is int and output_dim > 0 - ), f"output_dim must be positive int, got {output_dim}" - self.output_dim = output_dim - self.extra_projector = nn.Linear(dict_n_bands_default[geo_data_name], output_dim) - self.geo_encoder = self._average_and_project - - def _average(self, x: torch.Tensor) -> torch.Tensor: - """Averages the input tensor over spatial dimensions. - - :param x: input tensor of shape (B, C, H, W) - :return: averaged tensor of shape (B, C) - """ - return x.mean(dim=(-2, -1)) - - def _average_and_project(self, x: torch.Tensor) -> torch.Tensor: - """Averages the input tensor over spatial dimensions and projects to output_dim. + @override + def setup(self) -> List[str]: + """Configures networks, data-dependent parts. - :param x: input tensor of shape (B, C, H, W) - :return: projected tensor of shape (B, output_dim) + Gets called in model.setup() method. Returns names of any new module configured to be added + to the trainable modules list. 
""" - x_avg = x.mean(dim=(-2, -1)) - x_proj = self.extra_projector(x_avg) - return x_proj + self.output_dim = self.dict_n_bands_default[self.geo_data_name] + self.geo_encoder = nn.Identity() + return [] @override def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + """Data forward pass through the encoder.""" tile = batch.get("eo", {}).get(self.geo_data_name) - # Determine target dtype from parameters when available (e.g. when the - # optional projection layer exists); otherwise keep the input dtype. - params = list(self.parameters()) - dtype = params[0].dtype if params else tile.dtype - if tile.dtype != dtype: - tile = tile.to(dtype) - feats = self.geo_encoder(tile) - return feats.to(dtype) - - -if __name__ == "__main__": - _ = AverageEncoder(None, None) + feats = self.geo_encoder(tile.mean(dim=(-2, -1))) + return feats From 62422ffbfa676a91129a3c6ef23bbafb6e989f02 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 12 Mar 2026 09:59:10 +0100 Subject: [PATCH 44/60] Change how trainable parts are reported/printed --- src/models/base_model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/models/base_model.py b/src/models/base_model.py index 7b3cb0a..a28c085 100644 --- a/src/models/base_model.py +++ b/src/models/base_model.py @@ -55,8 +55,7 @@ def freezer(self) -> None: # Enable exceptions if name.startswith(self.trainable_modules): param.requires_grad = True - top_name = name.split(".", 2)[:2] - trainable.add(".".join(top_name)) + trainable.add(name) else: # Freeze the rest param.requires_grad = False From aea39af74fc30ebe4128fb1a3799e7e7ff45250a Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 12 Mar 2026 10:10:26 +0100 Subject: [PATCH 45/60] Add setup to prediction heads + docs --- .../components/pred_heads/base_pred_head.py | 23 +++++++++++++--- .../components/pred_heads/linear_pred_head.py | 22 +++++++++++++-- .../components/pred_heads/mlp_pred_head.py | 25 +++++++++++++++-- .../pred_heads/mlp_regression_head.py | 
27 +++++++++++++++++-- 4 files changed, 87 insertions(+), 10 deletions(-) diff --git a/src/models/components/pred_heads/base_pred_head.py b/src/models/components/pred_heads/base_pred_head.py index 7a4f7e1..3615fa5 100644 --- a/src/models/components/pred_heads/base_pred_head.py +++ b/src/models/components/pred_heads/base_pred_head.py @@ -7,6 +7,7 @@ class BasePredictionHead(nn.Module, ABC): def __init__(self) -> None: + """Base prediction head interface class.""" super().__init__() self.net: nn.Module | None = None self.input_dim: int | None = None @@ -14,16 +15,30 @@ def __init__(self) -> None: @abstractmethod def forward(self, feats: torch.Tensor) -> torch.Tensor: + """Forward pass through the prediction head.""" pass @final def set_dim(self, input_dim: int, output_dim: int) -> None: + """Set dimensions for the prediction head configuration. + + :param input_dim: input dimension + :param output_dim: output dimension + """ + assert isinstance(self.input_dim, int), TypeError( + "Input dimension must be specified as integer" + ) + assert isinstance(self.output_dim, int), TypeError( + "Output dimension must be specified as integer" + ) self.input_dim = input_dim self.output_dim = output_dim - assert type(self.input_dim) is int, self.input_dim - if output_dim is not None: - assert type(self.output_dim) is int, self.output_dim @abstractmethod - def configure_nn(self) -> None: + def setup(self) -> None: + """Configures networks, data-dependent parts. + + Gets called in model.setup() method. Returns names of any new module configured to be added + to the trainable modules list. 
+ """ pass diff --git a/src/models/components/pred_heads/linear_pred_head.py b/src/models/components/pred_heads/linear_pred_head.py index aa8ee55..94338a7 100644 --- a/src/models/components/pred_heads/linear_pred_head.py +++ b/src/models/components/pred_heads/linear_pred_head.py @@ -7,15 +7,33 @@ class LinearPredictionHead(BasePredictionHead): - def __init__(self): + def __init__( + self, + input_dim: int | None = None, + output_dim: int | None = None, + ) -> None: + """Linear prediction head for classification. + + :param input_dim: the size of input dimension + :param output_dim: the size of output dimension + """ super().__init__() + if input_dim and output_dim: + self.set_dim(input_dim, output_dim) @override def forward(self, feats: torch.Tensor) -> torch.Tensor: + """Forward pass through the prediction head.""" + return torch.sigmoid(self.net(feats)) @override - def configure_nn(self) -> None: + def setup(self) -> None: + """Configures networks, data-dependent parts. + + Gets called in model.setup() method. Returns names of any new module configured to be added + to the trainable modules list. + """ assert type(self.input_dim) is int, self.input_dim assert type(self.output_dim) is int, self.output_dim self.net = nn.Linear(self.input_dim, self.output_dim) diff --git a/src/models/components/pred_heads/mlp_pred_head.py b/src/models/components/pred_heads/mlp_pred_head.py index 28db550..144d602 100644 --- a/src/models/components/pred_heads/mlp_pred_head.py +++ b/src/models/components/pred_heads/mlp_pred_head.py @@ -7,17 +7,38 @@ class MLPPredictionHead(BasePredictionHead): - def __init__(self, nn_layers: int = 2, hidden_dim: int = 256) -> None: + def __init__( + self, + nn_layers: int = 2, + hidden_dim: int = 256, + input_dim: int | None = None, + output_dim: int | None = None, + ) -> None: + """MLP prediction head for classification. 
+ + :param nn_layers: number of layers in MLP + :param hidden_dim: the size of hidden dimensions + :param input_dim: the size of input dimension + :param output_dim: the size of output dimension + """ super().__init__() self.nn_layers = nn_layers self.hidden_dim = hidden_dim + if input_dim and output_dim: + self.set_dim(input_dim, output_dim) @override def forward(self, feats: torch.Tensor) -> torch.Tensor: + """Forward pass through the prediction head.""" return torch.sigmoid(self.net(feats)) @override - def configure_nn(self) -> None: + def setup(self) -> None: + """Configures networks, data-dependent parts. + + Gets called in model.setup() method. Returns names of any new module configured to be added + to the trainable modules list. + """ assert type(self.input_dim) is int, self.input_dim assert type(self.output_dim) is int, self.output_dim layers = [] diff --git a/src/models/components/pred_heads/mlp_regression_head.py b/src/models/components/pred_heads/mlp_regression_head.py index 818da44..a179efe 100644 --- a/src/models/components/pred_heads/mlp_regression_head.py +++ b/src/models/components/pred_heads/mlp_regression_head.py @@ -19,18 +19,41 @@ class MLPRegressionPredictionHead(BasePredictionHead): """MLP prediction head for regression tasks (outputs a continuous value).""" - def __init__(self, nn_layers: int = 2, hidden_dim: int = 256, dropout: float = 0.0) -> None: + def __init__( + self, + nn_layers: int = 2, + hidden_dim: int = 256, + dropout: float = 0.0, + input_dim: int | None = None, + output_dim: int | None = None, + ) -> None: + """MLP prediction head for regression tasks. 
+ + :param nn_layers: number of layers in MLP + :param hidden_dim: the size of hidden dimensions + :param dropout: the dropout rate + :param input_dim: the size of input dimension + :param output_dim: the size of output dimension + """ super().__init__() self.nn_layers = nn_layers self.hidden_dim = hidden_dim self.dropout = dropout + if input_dim and output_dim: + self.set_dim(input_dim, output_dim) @override def forward(self, feats: torch.Tensor) -> torch.Tensor: + """Forward pass through the prediction head.""" return self.net(feats) @override - def configure_nn(self) -> None: + def setup(self) -> None: + """Configures networks, data-dependent parts. + + Gets called in model.setup() method. Returns names of any new module configured to be added + to the trainable modules list. + """ assert isinstance(self.input_dim, int), self.input_dim assert isinstance(self.output_dim, int), self.output_dim From ca43ec8ed86557d0d96bd5fc47dfbcd6b994cb26 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 12 Mar 2026 10:11:28 +0100 Subject: [PATCH 46/60] pre-commit hook changes --- .../make_model_ready_yield_africa.py | 9 +++------ src/data_preprocessing/tessera_embeds.py | 4 +++- .../yield_africa_loco_splits.py | 8 +++----- .../yield_africa_tessera_preprocess.py | 19 +++++++++++-------- src/models/components/loss_fns/huber_loss.py | 2 +- src/models/components/metrics/r2.py | 10 +++++----- 6 files changed, 26 insertions(+), 26 deletions(-) diff --git a/src/data_preprocessing/make_model_ready_yield_africa.py b/src/data_preprocessing/make_model_ready_yield_africa.py index 25c4399..f2b6dd0 100644 --- a/src/data_preprocessing/make_model_ready_yield_africa.py +++ b/src/data_preprocessing/make_model_ready_yield_africa.py @@ -212,6 +212,7 @@ # Preprocessing functions # --------------------------------------------------------------------------- + def build_column_rename_map( continuous_features: List[str], tabular_categorical_features: List[str], @@ -419,11 +420,7 @@ def 
calculate_spatial_splits( splits: Dict[str, Any] = {} for fold in range(n_splits): - val_names = [ - name - for bid in fold_block_ids[fold] - for name in block_to_names[bid].tolist() - ] + val_names = [name for bid in fold_block_ids[fold] for name in block_to_names[bid].tolist()] train_names = [ name for f in range(n_splits) @@ -732,4 +729,4 @@ def main( args.years, args.exclude_countries, args.exclude_years, - ) \ No newline at end of file + ) diff --git a/src/data_preprocessing/tessera_embeds.py b/src/data_preprocessing/tessera_embeds.py index eff5e33..fcc6c32 100644 --- a/src/data_preprocessing/tessera_embeds.py +++ b/src/data_preprocessing/tessera_embeds.py @@ -127,7 +127,9 @@ def get_tessera_embeds( memfiles.append(reproject_memfile) if not tiles: - print(f"No TESSERA tiles found for {name_loc} at ({lon:.4f}, {lat:.4f}) year={year}. Skipping.") + print( + f"No TESSERA tiles found for {name_loc} at ({lon:.4f}, {lat:.4f}) year={year}. Skipping." + ) for mf in memfiles: mf.close() return diff --git a/src/data_preprocessing/yield_africa_loco_splits.py b/src/data_preprocessing/yield_africa_loco_splits.py index 1bfb1b7..cb563b8 100644 --- a/src/data_preprocessing/yield_africa_loco_splits.py +++ b/src/data_preprocessing/yield_africa_loco_splits.py @@ -112,12 +112,10 @@ def generate_splits( torch.save(split, out_path) log.info( - f" {country}: train={n_train}, val={n_val}, test={n_test} " - f"-> {out_path.name}" + f" {country}: train={n_train}, val={n_val}, test={n_test} " f"-> {out_path.name}" ) print( - f" Saved split_loco_{country}.pth " - f"(train={n_train}, val={n_val}, test={n_test})" + f" Saved split_loco_{country}.pth " f"(train={n_train}, val={n_val}, test={n_test})" ) @@ -168,4 +166,4 @@ def main() -> None: if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/data_preprocessing/yield_africa_tessera_preprocess.py b/src/data_preprocessing/yield_africa_tessera_preprocess.py index 15823d1..62ae298 100644 --- 
a/src/data_preprocessing/yield_africa_tessera_preprocess.py +++ b/src/data_preprocessing/yield_africa_tessera_preprocess.py @@ -93,7 +93,9 @@ def fetch_tessera_tiles( save_dir.mkdir(parents=True, exist_ok=True) if cache_dir is None: - cache_dir = os.environ.get("TESSERA_EMBEDDINGS_DIR") or str(Path(data_dir) / "cache" / "tessera") + cache_dir = os.environ.get("TESSERA_EMBEDDINGS_DIR") or str( + Path(data_dir) / "cache" / "tessera" + ) embeddings_dir = str(Path(cache_dir) / "raw") @@ -109,8 +111,7 @@ def fetch_tessera_tiles( n_total = len(df) n_existing = sum( - 1 for _, row in df.iterrows() - if (save_dir / f"tessera_{row.name_loc}.npy").exists() + 1 for _, row in df.iterrows() if (save_dir / f"tessera_{row.name_loc}.npy").exists() ) n_to_fetch = n_total - n_existing @@ -160,9 +161,9 @@ def _fetch_one(row) -> str: # Bound all socket operations (urllib HTTP requests inside geotessera). # Without this, a stalled connection blocks the thread until the OS TCP # keepalive fires, which can take many minutes. - SOCKET_TIMEOUT = 60 # seconds per socket operation - HEARTBEAT = 30 # print a heartbeat when no future completes this fast - TILE_TIMEOUT = 600 # give up warning after 10 min of complete silence + SOCKET_TIMEOUT = 60 # seconds per socket operation + HEARTBEAT = 30 # print a heartbeat when no future completes this fast + TILE_TIMEOUT = 600 # give up warning after 10 min of complete silence socket.setdefaulttimeout(SOCKET_TIMEOUT) rows = [row for _, row in df.iterrows()] @@ -186,7 +187,9 @@ def _fetch_one(row) -> str: f"{len(pending)} pending, {silent_seconds}s since last completion" ) if silent_seconds >= TILE_TIMEOUT: - print(f" WARNING: no progress in {TILE_TIMEOUT}s, something may be stuck.") + print( + f" WARNING: no progress in {TILE_TIMEOUT}s, something may be stuck." 
+ ) continue silent_seconds = 0 @@ -281,4 +284,4 @@ def main() -> None: if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/models/components/loss_fns/huber_loss.py b/src/models/components/loss_fns/huber_loss.py index 0fa2ff1..57c0bed 100644 --- a/src/models/components/loss_fns/huber_loss.py +++ b/src/models/components/loss_fns/huber_loss.py @@ -26,4 +26,4 @@ def forward( if "return_label" in kwargs: return {self.name: huber_loss} else: - return huber_loss \ No newline at end of file + return huber_loss diff --git a/src/models/components/metrics/r2.py b/src/models/components/metrics/r2.py index baccb06..e9d2c4e 100644 --- a/src/models/components/metrics/r2.py +++ b/src/models/components/metrics/r2.py @@ -12,10 +12,10 @@ class RSquared(BaseMetrics): """Epoch-level R² using torchmetrics.R2Score. - A separate R2Score accumulator is kept per mode so that train, val, and - test statistics never mix. Lightning detects the returned torchmetrics - Metric objects and calls .compute()/.reset() at epoch boundaries, giving - a correct epoch-wide R² instead of an average of per-batch R² values. + A separate R2Score accumulator is kept per mode so that train, val, and test statistics never + mix. Lightning detects the returned torchmetrics Metric objects and calls .compute()/.reset() + at epoch boundaries, giving a correct epoch-wide R² instead of an average of per-batch R² + values. 
""" def __init__(self) -> None: @@ -38,4 +38,4 @@ def forward( metric = self._r2[f"mode_{mode}"] metric.update(pred.squeeze(-1), labels.squeeze(-1)) - return {self.name: metric} \ No newline at end of file + return {self.name: metric} From afef0b6958801e8f1e1c26b8737cf1a7e0edf4bc Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 12 Mar 2026 10:12:13 +0100 Subject: [PATCH 47/60] Setup method for mlp projector --- src/models/components/geo_encoders/mlp_projector.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/models/components/geo_encoders/mlp_projector.py b/src/models/components/geo_encoders/mlp_projector.py index 5417eb5..f4c1137 100644 --- a/src/models/components/geo_encoders/mlp_projector.py +++ b/src/models/components/geo_encoders/mlp_projector.py @@ -1,3 +1,5 @@ +from typing import List, override + import torch from torch import nn @@ -22,6 +24,11 @@ def __init__( # Placeholder self.net: nn.Module | None = None + @override + def setup(self) -> List[str]: + self.configure_nn() + return ["net"] + def set_input_dim(self, input_dim: int) -> None: self.input_dim = input_dim From 6200a78b47d56dbfcf4fa9dea7ba9b12fd4e009f Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 12 Mar 2026 10:12:29 +0100 Subject: [PATCH 48/60] Introduce encoder wrapper to remove multi-modal encoder --- configs/data/yield_africa_loco.yaml | 2 +- configs/data/yield_africa_tessera.yaml | 2 +- .../experiment/yield_africa_tabular_loco.yaml | 2 +- .../yield_africa_tessera_fusion_reg.yaml | 2 +- .../experiment/yield_africa_tessera_reg.yaml | 2 +- .../model/example_for_encoder_wrapper.yaml | 15 ++ configs/model/yield_tessera_fusion_reg.yaml | 5 + configs/model/yield_tessera_reg.yaml | 2 +- src/data/yield_africa_dataset.py | 4 +- .../geo_encoders/encoder_wrapper.py | 109 ++++++++++---- .../geo_encoders/multimodal_encoder.py | 141 ------------------ src/models/predictive_model.py | 14 +- src/models/text_alignment_model.py | 80 ++++++---- tests/test_pred_heads.py | 2 +- 14 files 
changed, 172 insertions(+), 210 deletions(-) create mode 100644 configs/model/example_for_encoder_wrapper.yaml delete mode 100644 src/models/components/geo_encoders/multimodal_encoder.py diff --git a/configs/data/yield_africa_loco.yaml b/configs/data/yield_africa_loco.yaml index 3a20b3e..92559d0 100644 --- a/configs/data/yield_africa_loco.yaml +++ b/configs/data/yield_africa_loco.yaml @@ -30,4 +30,4 @@ pin_memory: false split_mode: "from_file" saved_split_file_name: "split_loco_KEN.pth" save_split: false -seed: ${seed} \ No newline at end of file +seed: ${seed} diff --git a/configs/data/yield_africa_tessera.yaml b/configs/data/yield_africa_tessera.yaml index e997f19..533b483 100644 --- a/configs/data/yield_africa_tessera.yaml +++ b/configs/data/yield_africa_tessera.yaml @@ -28,4 +28,4 @@ pin_memory: false split_mode: "random" train_val_test_split: [0.7, 0.15, 0.15] save_split: false -seed: ${seed} \ No newline at end of file +seed: ${seed} diff --git a/configs/experiment/yield_africa_tabular_loco.yaml b/configs/experiment/yield_africa_tabular_loco.yaml index 865ab92..cce363d 100644 --- a/configs/experiment/yield_africa_tabular_loco.yaml +++ b/configs/experiment/yield_africa_tabular_loco.yaml @@ -27,4 +27,4 @@ logger: tags: ${tags} group: "yield_africa" aim: - experiment: "yield_africa" \ No newline at end of file + experiment: "yield_africa" diff --git a/configs/experiment/yield_africa_tessera_fusion_reg.yaml b/configs/experiment/yield_africa_tessera_fusion_reg.yaml index c9edb54..93c2052 100644 --- a/configs/experiment/yield_africa_tessera_fusion_reg.yaml +++ b/configs/experiment/yield_africa_tessera_fusion_reg.yaml @@ -28,4 +28,4 @@ logger: tags: ${tags} group: "yield_africa" aim: - experiment: "yield_africa" \ No newline at end of file + experiment: "yield_africa" diff --git a/configs/experiment/yield_africa_tessera_reg.yaml b/configs/experiment/yield_africa_tessera_reg.yaml index 7e0c93b..ca4e2ed 100644 --- a/configs/experiment/yield_africa_tessera_reg.yaml +++ 
b/configs/experiment/yield_africa_tessera_reg.yaml @@ -24,4 +24,4 @@ logger: tags: ${tags} group: "yield_africa" aim: - experiment: "yield_africa" \ No newline at end of file + experiment: "yield_africa" diff --git a/configs/model/example_for_encoder_wrapper.yaml b/configs/model/example_for_encoder_wrapper.yaml new file mode 100644 index 0000000..c2dc55e --- /dev/null +++ b/configs/model/example_for_encoder_wrapper.yaml @@ -0,0 +1,15 @@ +_target_: src.models.components.geo_encoders.encoder_wrapper.EncoderWrapper + +encoder_branches: + - encoder: + _target_: src.models.components.geo_encoders.average_encoder.AverageEncoder + geo_data_name: aef + projector: + _target_: src.models.components.geo_encoders.mlp_projector.MLPProjector + nn_layers: 2 + hidden_dim: 512 + output_dim: 512 +# - encoder: # another branch +# __target__: + +fusion_strategy: "concat" diff --git a/configs/model/yield_tessera_fusion_reg.yaml b/configs/model/yield_tessera_fusion_reg.yaml index 0fc49d4..c30b6af 100644 --- a/configs/model/yield_tessera_fusion_reg.yaml +++ b/configs/model/yield_tessera_fusion_reg.yaml @@ -12,6 +12,11 @@ geo_encoder: - encoder: _target_: src.models.components.geo_encoders.average_encoder.AverageEncoder geo_data_name: tessera + projector: + _target_: src.models.components.geo_encoders.mlp_projector.MLPProjector + nn_layers: 2 + hidden_dim: 512 + output_dim: 512 - encoder: _target_: src.models.components.geo_encoders.tabular_encoder.TabularEncoder output_dim: 256 diff --git a/configs/model/yield_tessera_reg.yaml b/configs/model/yield_tessera_reg.yaml index 69adb0b..b5ef731 100644 --- a/configs/model/yield_tessera_reg.yaml +++ b/configs/model/yield_tessera_reg.yaml @@ -31,4 +31,4 @@ scheduler: patience: 10 loss_fn: - _target_: src.models.components.loss_fns.huber_loss.HuberLoss \ No newline at end of file + _target_: src.models.components.loss_fns.huber_loss.HuberLoss diff --git a/src/data/yield_africa_dataset.py b/src/data/yield_africa_dataset.py index f2867d8..8c09801 
100644 --- a/src/data/yield_africa_dataset.py +++ b/src/data/yield_africa_dataset.py @@ -110,7 +110,9 @@ def __init__( n_after = len(self.df) if n_after != n_before: - log.info(f"Country/year filter: {n_before} → {n_after} records ({n_before - n_after} excluded)") + log.info( + f"Country/year filter: {n_before} → {n_after} records ({n_before - n_after} excluded)" + ) # get_records() mutates self.use_aux_data in place (replacing pattern # dicts with resolved column-name lists), so reset it from the original diff --git a/src/models/components/geo_encoders/encoder_wrapper.py b/src/models/components/geo_encoders/encoder_wrapper.py index bbc8453..9cdf8ad 100644 --- a/src/models/components/geo_encoders/encoder_wrapper.py +++ b/src/models/components/geo_encoders/encoder_wrapper.py @@ -1,6 +1,7 @@ from typing import Any, Dict, List, override import torch +import torch.nn as nn from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder from src.models.components.geo_encoders.tabular_encoder import TabularEncoder @@ -12,43 +13,101 @@ class EncoderWrapper(BaseGeoEncoder): def __init__( self, encoder_branches: List[Dict[str, Any]], - fusion_strategy: str, + fusion_strategy: str = "concat", ): super().__init__() - self.encoder_branches = encoder_branches + self.output_dim = None + + self._reformat_set_branches(encoder_branches) + assert fusion_strategy in ["mean", "concat", "none"], ValueError( f'Unsupported fusion strategy "{fusion_strategy}"' ) self.fusion_strategy = fusion_strategy - self.output_dim = None + + def _reformat_set_branches(self, encoder_branches: List[Dict[str, Any]]): + """Reformatting to allow registering modules properly.""" + self.encoder_branches = nn.ModuleList() + + for branch in encoder_branches: + module_dict = nn.ModuleDict({"encoder": branch["encoder"]}) + + if branch.get("projector") is not None: + module_dict["projector"] = branch["projector"] + + self.encoder_branches.append(module_dict) + + @override + def setup(self) -> 
List[str]: + new_modules = [] # Configure/initialise missing/conditional parts - for branch in self.encoder_branches: - intermediate_dim = branch.get("encoder").output_dim - projector = branch.get("projector", None) - if projector is not None: + for i, branch in enumerate(self.encoder_branches): + # Setup encoder + encoder = branch["encoder"] + + # Configure tabular encoder + if isinstance(encoder, TabularEncoder): + if self.tabular_dim is None: + raise ValueError("TabularEncoder requires tabular_dim") + encoder.set_tabular_input_dim(self.tabular_dim) + + new_parts = encoder.setup() + new_modules.extend( + [f"encoder_branches.{str(i)}.encoder.{p}" for p in new_parts] + if len(new_parts) != 0 + else [] + ) + + # Configure adapter/projector if requested + if "projector" in branch: + projector = branch["projector"] + + intermediate_dim = encoder.output_dim projector.set_input_dim(input_dim=intermediate_dim) - projector.configure_nn() + new_parts = projector.setup() + new_modules.extend( + [f"encoder_branches.{str(i)}.projector.{p}" for p in new_parts] + if len(new_parts) != 0 + else [] + ) + + self.set_output_dim() + return new_modules + + def set_tabular_input_dim(self, tabular_dim=None): + """Set tabular dimension if there is tabular encoder.""" + self.tabular_dim = None - def configure_nn(self, tabular_dim: int) -> None: + for branch in self.encoder_branches: + branch_out_dim = branch["encoder"] + if isinstance(branch_out_dim, TabularEncoder): + self.tabular_dim = tabular_dim + return + + def set_output_dim(self): + """Calculates the output dimension.""" + + # Collect all output dimensions output_dims = [] - new_parts = set() for branch in self.encoder_branches: - if isinstance(branch["encoder"], TabularEncoder): - branch["encoder"].configure_nn(tabular_dim) - new_parts.add("ta") - if branch.get("projector"): - output_dims.append(branch["projector"].output_dim) - else: - output_dims.append(branch["encoder"].output_dim) + branch_out_dim = 
branch["encoder"].output_dim + if "projector" in branch: + projector = branch["projector"] + branch_out_dim = projector.output_dim + + output_dims.append(branch_out_dim) + + # Combine output dimensions if self.fusion_strategy == "concat": self.output_dim = sum(output_dims) elif self.fusion_strategy == "mean": - assert set(output_dims) == 1, ValueError( - f"Encoder branches produces different output dimensions {output_dims} and cannot be averaged." - ) + if set(output_dims) != 1: + raise ValueError( + f"Encoder branches produces different output dimensions {output_dims} and cannot be averaged." + ) self.output_dim = output_dims[0] @override @@ -57,7 +116,7 @@ def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: for branch in self.encoder_branches: feats = branch["encoder"].forward(batch) # each encoder knows what modality it needs - if branch.get("projector", None): + if "projector" in branch: feats = branch["projector"].forward(feats) branch_feats.append(feats) @@ -72,8 +131,8 @@ def device(self): for branch in self.encoder_branches: encoder = branch["encoder"] devices.update({p.device for p in encoder.parameters()}) - projector = branch.get("projector") - if projector is not None: + if "projector" in branch: + projector = branch["projector"] devices.update({p.device for p in projector.parameters()}) if len(devices) != 1: @@ -86,8 +145,8 @@ def dtype(self) -> torch.dtype: for branch in self.encoder_branches: encoder = branch["encoder"] dtypes.update({p.dtype for p in encoder.parameters()}) - projector = branch.get("projector") - if projector is not None: + if "projector" in branch: + projector = branch["projector"] dtypes.update({p.dtype for p in projector.parameters()}) if len(dtypes) != 1: diff --git a/src/models/components/geo_encoders/multimodal_encoder.py b/src/models/components/geo_encoders/multimodal_encoder.py deleted file mode 100644 index 2cb1e47..0000000 --- a/src/models/components/geo_encoders/multimodal_encoder.py +++ /dev/null @@ -1,141 
+0,0 @@ -"""Unified multimodal encoder for EO data. - -Controlled entirely via constructor flags: - - use_coords: activate the spatial/geo encoder branch - - use_tabular: encode feat_* tabular columns - - geo_encoder_cfg: pluggable geo encoder (any BaseGeoEncoder subclass); - when None and use_coords=True, defaults to GeoClipCoordinateEncoder -""" - -from typing import Dict, override - -import torch -from torch import nn - -from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder -from src.models.components.geo_encoders.geoclip import GeoClipCoordinateEncoder - - -class MultiModalEncoder(BaseGeoEncoder): - """ - Modes (controlled by use_coords / use_tabular flags): - - - geo only (use_coords=True, use_tabular=False) - - tabular only (use_coords=False, use_tabular=True) - - geo + tabular (use_coords=True, use_tabular=True) - - The geo encoder defaults to GeoClipCoordinateEncoder but can be replaced - with any BaseGeoEncoder via the geo_encoder_cfg parameter. Hydra - instantiates geo_encoder_cfg before passing it here, so it arrives as a - ready-to-use nn.Module (e.g. AverageEncoder for TESSERA tiles). - - Example config (TESSERA + tabular fusion): - geo_encoder: - _target_: ...MultiModalEncoder - use_coords: true - use_tabular: true - geo_encoder_cfg: - _target_: ...AverageEncoder - geo_data_name: tessera - """ - - def __init__( - self, - use_coords: bool = True, - use_tabular: bool = False, - tab_embed_dim: int = 64, - tabular_dropout: float = 0.0, - tabular_dim: int = None, - geo_encoder_cfg: BaseGeoEncoder | None = None, - ) -> None: - super().__init__() - - assert use_coords or use_tabular, "At least one of use_coords or use_tabular must be True." 
- - self.use_coords = use_coords - self.use_tabular = use_tabular - self.tab_embed_dim = tab_embed_dim - self.tabular_dropout = tabular_dropout - self._tabular_ready = False - self.fusion_norm = None # set in build_tabular_branch when both branches active - - coords_dim = 0 - if use_coords: - if geo_encoder_cfg is not None: - self.coords_encoder = geo_encoder_cfg - else: - self.coords_encoder = GeoClipCoordinateEncoder() - coords_dim = self.coords_encoder.output_dim - - self._coords_dim = coords_dim - - # Built only if dim is already known - if use_tabular and tabular_dim is not None: - self.build_tabular_branch(tabular_dim) - elif use_tabular: - self.tabular_proj = None - else: - self.output_dim = coords_dim - - # ------------------------------------------------------------------ - # Public API - # ------------------------------------------------------------------ - - def build_tabular_branch(self, tabular_dim: int) -> None: - """Build (or rebuild) the tabular projection MLP. - - Architecture: LayerNorm → Linear(in, h) → ReLU → Dropout → - Linear(h, h//2) → ReLU → Dropout → Linear(h//2, out) - where h = max(tab_embed_dim * 2, 128). - """ - if self._tabular_ready and hasattr(self, "_last_tabular_dim"): - if self._last_tabular_dim == tabular_dim: - return # already built with correct dim - - hidden = max(self.tab_embed_dim * 2, 128) - drop = self.tabular_dropout - self.tabular_proj = nn.Sequential( - nn.LayerNorm(tabular_dim), - nn.Linear(tabular_dim, hidden), - nn.ReLU(), - nn.Dropout(drop), - nn.Linear(hidden, hidden // 2), - nn.ReLU(), - nn.Dropout(drop), - nn.Linear(hidden // 2, self.tab_embed_dim), - ) - self._last_tabular_dim = tabular_dim - self._tabular_ready = True - self.output_dim = self._coords_dim + self.tab_embed_dim - - # Normalise the fused representation when both branches are active. - # The geo encoder output and the tabular projection may have different - # scales, so a LayerNorm stabilises training after concat. 
- if self.use_coords: - self.fusion_norm = nn.LayerNorm(self.output_dim) - else: - self.fusion_norm = None - - # ------------------------------------------------------------------ - # Forward - # ------------------------------------------------------------------ - - @override - def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: - parts = [] - - if self.use_coords: - parts.append(self.coords_encoder(batch)) # (B, coords_encoder.output_dim) - - if self.use_tabular: - assert self._tabular_ready, ( - "Tabular branch not built yet. Call build_tabular_branch(tabular_dim) first, " - "or pass tabular_dim to the constructor." - ) - tab = batch["eo"]["tabular"].float() # (B, tabular_dim) - parts.append(self.tabular_proj(tab)) # (B, tab_embed_dim) - - fused = torch.cat(parts, dim=-1) - if self.fusion_norm is not None: - fused = self.fusion_norm(fused) - return fused diff --git a/src/models/predictive_model.py b/src/models/predictive_model.py index 2954584..16bae50 100644 --- a/src/models/predictive_model.py +++ b/src/models/predictive_model.py @@ -50,6 +50,7 @@ def __init__( # Prediction head self.prediction_head = prediction_head + # Normalise features boolean self.normalize_features = normalize_features @override @@ -65,17 +66,22 @@ def setup(self, stage: str) -> None: def setup_encoders_adapters(self): """Set up encoders and missing adapters/projectors.""" # TODO: move to multi-modal eo encoder + + # If tabular encoder used, we need to specify tabular dim if isinstance(self.geo_encoder, TabularEncoder) or isinstance( self.geo_encoder, EncoderWrapper ): - self.geo_encoder.configure_nn(self.tabular_dim) - if self.tabular_dim: - self.trainable_modules.append("tabular_encoder") + self.geo_encoder.set_tabular_input_dim(self.tabular_dim) + + # Setup encoders that need data-depended configurations + new_modules = [f"geo_encoder.{i}]" for i in self.geo_encoder.setup()] + self.trainable_modules.extend(new_modules) + # Configure prediction head based on geo-encoder 
output_dim self.prediction_head.set_dim( input_dim=self.geo_encoder.output_dim, output_dim=self.num_classes ) - self.prediction_head.configure_nn() + self.prediction_head.setup() if "prediction_head" not in self.trainable_modules: self.trainable_modules.append("prediction_head") diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index 78819b3..2630ff3 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -32,6 +32,7 @@ def __init__( metrics: MetricsWrapper, prediction_head: BasePredictionHead | None = None, ks: list[int] | None = [5, 10, 15], + match_to_geo: bool = True, ) -> None: """Implementation of contrastive text-eo modality alignment model. @@ -45,15 +46,20 @@ def __init__( :param num_classes: number of target classes :param tabular_dim: number of tabular features :param prediction_head: prediction head + :param ks: list of ks + :param match_to_geo: whether to match dimensions of text encoder to geo_encoder or visa- + versa """ super().__init__(trainable_modules, optimizer, scheduler, loss_fn, metrics) + # Metrics self.ks = ks self.log_kwargs = dict(on_step=False, on_epoch=True, prog_bar=True, sync_dist=True) # Encoders configuration self.geo_encoder = geo_encoder self.text_encoder = text_encoder + self.match_to_geo = match_to_geo # Prediction head self.prediction_head = prediction_head @@ -74,31 +80,38 @@ def setup(self, stage: str) -> None: def setup_encoders_adapters(self): """Set up encoders and missing adapters/projectors.""" - # TODO: move to multi-modal eo encoder - if ( - isinstance(self.geo_encoder, MultiModalEncoder) - and self.geo_encoder.use_tabular - and not self.geo_encoder._tabular_ready - ): - self.geo_encoder.build_tabular_branch(self.tabular_dim) + # We don't use tabular encoders for wrapping + # if ( + # isinstance(self.geo_encoder, MultiModalEncoder) + # and self.geo_encoder.use_tabular + # and not self.geo_encoder._tabular_ready + # ): + # 
self.geo_encoder.build_tabular_branch(self.tabular_dim) + + # Setup encoders that need data-depended configurations + new_modules = [f"geo_encoder.{i}" for i in self.geo_encoder.setup()] + self.trainable_modules.extend(new_modules) # Extra projector for text encoder if eo and text dim not match if self.geo_encoder.output_dim != self.text_encoder.output_dim: - self.text_encoder.add_projector(projected_dim=self.geo_encoder.output_dim) - self.trainable_modules.append("text_encoder.extra_projector") - - # TODO: if eo==geoclip_img pass on shared mlp - + if self.match_to_geo: + self.text_encoder.add_projector(projected_dim=self.geo_encoder.output_dim) + self.trainable_modules.append("text_encoder.extra_projector") + else: + self.geo_encoder.add_projector(projected_dim=self.text_encoder.output_dim) + self.trainable_modules.append("geo_encoder.extra_projector") + + # Configure prediction head based on geo-encoder output_dim if self.prediction_head is not None: self.prediction_head.set_dim( input_dim=self.geo_encoder.output_dim, output_dim=self.num_classes ) - self.prediction_head.configure_nn() + self.prediction_head.setup() - # Unify dtypes - if self.geo_encoder.dtype != self.text_encoder.dtype: - self.geo_encoder = self.geo_encoder.to(self.text_encoder.dtype) - print(f"Geo encoder dtype changed to {self.geo_encoder.dtype}") + # # Unify dtypes -> moving to data part, rather than changing parameter type + # if self.geo_encoder.dtype != self.text_encoder.dtype: + # self.geo_encoder = self.geo_encoder.to(self.text_encoder.dtype) + # print(f"Geo encoder dtype changed to {self.geo_encoder.dtype}") def setup_retrieval_evaluation(self): self.concept_configs = self.trainer.datamodule.concept_configs @@ -125,33 +138,37 @@ def forward( """Model forward logic.""" # Embed modalities - eo_feats = self.geo_encoder(batch) + geo_feats = self.geo_encoder(batch) text_feats = self.text_encoder(batch, mode) - return eo_feats, text_feats + + # Change dtype of geo data if it does not match 
text dtype + if geo_feats.dtype != text_feats.dtype: + geo_feats = geo_feats.to(text_feats.dtype) + return geo_feats, text_feats @override def _step(self, batch: Dict[str, torch.Tensor], mode: str = "train"): """Model step logic.""" # Embed - eo_feats, text_feats = self.forward(batch, mode) - local_batch_size = eo_feats.size(0) + geo_feats, text_feats = self.forward(batch, mode) + local_batch_size = geo_feats.size(0) # batch recomposing in ddp if self.trainer.world_size > 1: - feats = torch.stack([eo_feats, text_feats], dim=0) + feats = torch.stack([geo_feats, text_feats], dim=0) feats = self.all_gather(feats) feats = feats.reshape(2, -1, feats.size(-1)) - eo_feats, text_feats = feats[0], feats[1] + geo_feats, text_feats = feats[0], feats[1] # Get loss - loss = self.loss_fn(eo_feats, text_feats) + loss = self.loss_fn(geo_feats, text_feats) # Get similarities with torch.no_grad(): metrics = self.metrics( mode=mode, - eo_feats=eo_feats, + geo_feats=geo_feats, text_feats=text_feats, local_batch_size=local_batch_size, ) @@ -172,23 +189,22 @@ def _step(self, batch: Dict[str, torch.Tensor], mode: str = "train"): if mode in ["val", "test"]: self.outputs_epoch_memory.append( { - "eo_feats": eo_feats.detach(), + "geo_feats": geo_feats.detach(), "aux_vals": batch.get("aux", {}).get("aux").detach(), } ) return loss - @override def _on_epoch_end(self, mode: str): # Combine batches - eo_feats = torch.cat([x["eo_feats"] for x in self.outputs_epoch_memory], dim=0) + geo_feats = torch.cat([x["geo_feats"] for x in self.outputs_epoch_memory], dim=0) aux_vals = torch.cat([x["aux_vals"] for x in self.outputs_epoch_memory], dim=0) # Rank on similarity - similarity = self.concept_similarities(eo_feats) + similarity = self.concept_similarities(geo_feats) concept_scores = self.contrastive_val(similarity, aux_values=aux_vals) # TODO pearson @@ -216,7 +232,7 @@ def on_validation_epoch_end(self): def on_test_epoch_end(self): return self._on_epoch_end("test") - def concept_similarities(self, 
eo_embeds, concept=None) -> torch.Tensor: + def concept_similarities(self, geo_embeds, concept=None) -> torch.Tensor: # Get concept embeddings if concept is not None: # If only one concept is provided @@ -232,8 +248,8 @@ def concept_similarities(self, eo_embeds, concept=None) -> torch.Tensor: concept_embeds = self.text_encoder({"text": self.concepts}, mode="train") # Similarity - eo_embeds = F.normalize(eo_embeds, dim=1) + geo_embeds = F.normalize(geo_embeds, dim=1) concept_embeds = F.normalize(concept_embeds, dim=1) - similarity_matrix = concept_embeds @ eo_embeds.T + similarity_matrix = concept_embeds @ geo_embeds.T return similarity_matrix diff --git a/tests/test_pred_heads.py b/tests/test_pred_heads.py index 9bd2c62..15553ea 100644 --- a/tests/test_pred_heads.py +++ b/tests/test_pred_heads.py @@ -43,7 +43,7 @@ def test_pred_head_generic_properties(create_butterfly_dataset): assert callable( getattr(pred_head, "configure_nn") ), f"'configure_nn' is not callable in {pred_head_class.__name__}." - pred_head.configure_nn() + pred_head.setup() assert hasattr(pred_head, "net"), f"'net' attribute missing in {pred_head_class.__name__}." 
assert hasattr( pred_head, "forward" From 94832e054919e171614fadf9e6d64c8757ec30b1 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 12 Mar 2026 10:51:49 +0100 Subject: [PATCH 49/60] Fix encoder tests --- .../components/pred_heads/base_pred_head.py | 4 ++-- src/models/text_alignment_model.py | 1 - tests/test_geo_encoders.py | 16 ++++++++++++---- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/models/components/pred_heads/base_pred_head.py b/src/models/components/pred_heads/base_pred_head.py index 3615fa5..5ed63c4 100644 --- a/src/models/components/pred_heads/base_pred_head.py +++ b/src/models/components/pred_heads/base_pred_head.py @@ -25,10 +25,10 @@ def set_dim(self, input_dim: int, output_dim: int) -> None: :param input_dim: input dimension :param output_dim: output dimension """ - assert isinstance(self.input_dim, int), TypeError( + assert isinstance(input_dim, int), TypeError( "Input dimension must be specified as integer" ) - assert isinstance(self.output_dim, int), TypeError( + assert isinstance(output_dim, int), TypeError( "Output dimension must be specified as integer" ) self.input_dim = input_dim diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index 2630ff3..9cffc5e 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -6,7 +6,6 @@ from src.models.base_model import BaseModel from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder -from src.models.components.geo_encoders.multimodal_encoder import MultiModalEncoder from src.models.components.loss_fns.base_loss_fn import BaseLossFn from src.models.components.metrics.contrastive_validation import ( RetrievalContrastiveValidation, diff --git a/tests/test_geo_encoders.py b/tests/test_geo_encoders.py index 9a54732..57ff921 100644 --- a/tests/test_geo_encoders.py +++ b/tests/test_geo_encoders.py @@ -6,10 +6,10 @@ import torch from src.models.components.geo_encoders.average_encoder import 
AverageEncoder -from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder from src.models.components.geo_encoders.cnn_encoder import CNNEncoder from src.models.components.geo_encoders.geoclip import GeoClipCoordinateEncoder -from src.models.components.geo_encoders.multimodal_encoder import MultiModalEncoder +from src.models.components.geo_encoders.mlp_projector import MLPProjector +from src.models.components.geo_encoders.tabular_encoder import TabularEncoder # @pytest.mark.slow @@ -19,13 +19,21 @@ def test_geo_encoder_generic_properties(create_butterfly_dataset): "geoclip_coords": GeoClipCoordinateEncoder, "cnn": CNNEncoder, "average": AverageEncoder, - "multimodal_coords": MultiModalEncoder, + "tabular": TabularEncoder, + "mlp_projector": MLPProjector, } ds, dm = create_butterfly_dataset batch = next(iter(dm.train_dataloader())) for geo_encoder_name, geo_encoder_class in dict_geo_encoders.items(): - geo_encoder = geo_encoder_class() + if geo_encoder_class is MLPProjector: + geo_encoder = geo_encoder_class(output_dim=64, input_dim=128) + elif geo_encoder_class is TabularEncoder: + geo_encoder = geo_encoder_class(output_dim=64, input_dim=128, hidden_dim=128) + else: + geo_encoder = geo_encoder_class() + + geo_encoder.setup() assert hasattr( geo_encoder, "geo_encoder" From 004b803f37783d9777ce75f325454acdfa6d2077 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 12 Mar 2026 12:02:45 +0100 Subject: [PATCH 50/60] fix tests --- .../components/metrics/contrastive_similarities.py | 4 ++-- tests/test_pred_heads.py | 12 +++++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/models/components/metrics/contrastive_similarities.py b/src/models/components/metrics/contrastive_similarities.py index c3aae7e..4a72635 100644 --- a/src/models/components/metrics/contrastive_similarities.py +++ b/src/models/components/metrics/contrastive_similarities.py @@ -15,7 +15,7 @@ def __init__(self, k_list=None) -> None: def forward( self, mode: 
str, - eo_feats: torch.Tensor, + geo_feats: torch.Tensor, text_feats: torch.Tensor, local_batch_size: int, **kwargs, @@ -23,7 +23,7 @@ def forward( """Calculate cosine similarity between eo and text embeddings and logs it.""" # Similarity matrix - cos_sim_matrix = F.cosine_similarity(eo_feats[:, None, :], text_feats[None, :, :], dim=-1) + cos_sim_matrix = F.cosine_similarity(geo_feats[:, None, :], text_feats[None, :, :], dim=-1) # Average for positive and negative pairs # TODO change label option if we change what gets treated to be pos/neg diff --git a/tests/test_pred_heads.py b/tests/test_pred_heads.py index 15553ea..abc99b4 100644 --- a/tests/test_pred_heads.py +++ b/tests/test_pred_heads.py @@ -19,11 +19,13 @@ def test_pred_head_generic_properties(create_butterfly_dataset): ds, dm = create_butterfly_dataset batch = next(iter(dm.train_dataloader())) eo_encoder = GeoClipCoordinateEncoder() + eo_encoder.setup() feats = eo_encoder.forward(batch) list_pred_heads = [LinearPredictionHead, MLPPredictionHead, MLPRegressionPredictionHead] for pred_head_class in list_pred_heads: - pred_head = pred_head_class() + pred_head = pred_head_class(input_dim=64, output_dim=64) + pred_head.setup() assert hasattr( pred_head, "set_dim" ), f"'set_dim' method missing in {pred_head_class.__name__}." @@ -38,11 +40,11 @@ def test_pred_head_generic_properties(create_butterfly_dataset): pred_head, "output_dim" ), f"'output_dim' attribute missing in {pred_head_class.__name__}." assert hasattr( - pred_head, "configure_nn" - ), f"'configure_nn' method missing in {pred_head_class.__name__}." + pred_head, "setup" + ), f"'setup' method missing in {pred_head_class.__name__}." assert callable( - getattr(pred_head, "configure_nn") - ), f"'configure_nn' is not callable in {pred_head_class.__name__}." + getattr(pred_head, "setup") + ), f"'setup' is not callable in {pred_head_class.__name__}." pred_head.setup() assert hasattr(pred_head, "net"), f"'net' attribute missing in {pred_head_class.__name__}." 
assert hasattr( From 0855b4ae25a5e24ce1d071b2afa6dfeae0d279f1 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 12 Mar 2026 12:17:49 +0100 Subject: [PATCH 51/60] fix tests --- src/models/base_model.py | 8 ++++---- src/models/predictive_model.py | 4 ++++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/models/base_model.py b/src/models/base_model.py index a28c085..89fbaf6 100644 --- a/src/models/base_model.py +++ b/src/models/base_model.py @@ -45,7 +45,7 @@ def setup(self, stage: str) -> None: @final def freezer(self) -> None: """Freezes modules based on provided trainable modules.""" - self.trainable_modules = tuple(self.trainable_modules) or tuple() + trainable_modules = tuple(self.trainable_modules) or tuple() # Store higher level module names for printing of trainable parts trainable = set() @@ -53,7 +53,7 @@ def freezer(self) -> None: # Freeze modules for name, param in self.named_parameters(): # Enable exceptions - if name.startswith(self.trainable_modules): + if name.startswith(trainable_modules): param.requires_grad = True trainable.add(name) else: @@ -69,8 +69,8 @@ def freezer(self) -> None: # - it is the root module (""), which must be train when any child is. 
def _in_train_scope(name: str) -> bool: if not name: # root module - return bool(self.trainable_modules) - for t in self.trainable_modules: + return bool(trainable_modules) + for t in trainable_modules: if name == t or name.startswith(t + ".") or t.startswith(name + "."): return True return False diff --git a/src/models/predictive_model.py b/src/models/predictive_model.py index 16bae50..b46a3f4 100644 --- a/src/models/predictive_model.py +++ b/src/models/predictive_model.py @@ -58,6 +58,10 @@ def setup(self, stage: str) -> None: self.num_classes = self.trainer.datamodule.num_classes self.tabular_dim = self.trainer.datamodule.tabular_dim + if stage != "fit": + if isinstance(self.trainable_modules, tuple): + self.trainable_modules = list(self.trainable_modules) + self.setup_encoders_adapters() # Freezing requested parts From 215e6034fae57c300bb860628627713d6b5f2c86 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Thu, 12 Mar 2026 14:52:30 +0100 Subject: [PATCH 52/60] Fix depth of summary report for modules --- configs/callbacks/default.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/callbacks/default.yaml b/configs/callbacks/default.yaml index 149a92f..e161020 100644 --- a/configs/callbacks/default.yaml +++ b/configs/callbacks/default.yaml @@ -19,4 +19,4 @@ early_stopping: mode: "min" model_summary: - max_depth: 2 + max_depth: 1 From eb571bc2791b2e46d66e086940717c58e4c840d8 Mon Sep 17 00:00:00 2001 From: Rob Knapen Date: Thu, 12 Mar 2026 15:44:45 +0100 Subject: [PATCH 53/60] Crop Yield use case: Reduced MLP projector, equal contribution of spatial and tabular encoders (for now). 
--- configs/model/yield_tessera_fusion_reg.yaml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/configs/model/yield_tessera_fusion_reg.yaml b/configs/model/yield_tessera_fusion_reg.yaml index c30b6af..1e8c175 100644 --- a/configs/model/yield_tessera_fusion_reg.yaml +++ b/configs/model/yield_tessera_fusion_reg.yaml @@ -14,9 +14,8 @@ geo_encoder: geo_data_name: tessera projector: _target_: src.models.components.geo_encoders.mlp_projector.MLPProjector - nn_layers: 2 - hidden_dim: 512 - output_dim: 512 + nn_layers: 1 + output_dim: 256 - encoder: _target_: src.models.components.geo_encoders.tabular_encoder.TabularEncoder output_dim: 256 From e27ecdc8c73408270fc502e8832487b496152704 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Sun, 15 Mar 2026 13:22:22 +0100 Subject: [PATCH 54/60] fix value 0 being ignored --- src/models/components/metrics/contrastive_validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models/components/metrics/contrastive_validation.py b/src/models/components/metrics/contrastive_validation.py index ce2385a..76d216a 100644 --- a/src/models/components/metrics/contrastive_validation.py +++ b/src/models/components/metrics/contrastive_validation.py @@ -39,7 +39,7 @@ def forward( k_threshold = configs.get("theta_k") aux_val = aux_vals[idx] - if k_threshold: + if k_threshold is not None: dynamic_k = ( sum(aux_val >= k_threshold).item() if is_max From e6115bf9fcfbf5f9a1c77fb7b44905efd3db3d79 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Sun, 15 Mar 2026 13:24:19 +0100 Subject: [PATCH 55/60] Add model set up print statements --- src/models/components/geo_encoders/average_encoder.py | 2 ++ src/models/components/geo_encoders/cnn_encoder.py | 1 + src/models/components/geo_encoders/geoclip.py | 1 + src/models/components/geo_encoders/mlp_projector.py | 1 + src/models/components/geo_encoders/tabular_encoder.py | 1 + src/models/components/text_encoders/clip_text_encoder.py | 5 +++++ 
src/models/predictive_model.py | 2 ++ src/models/text_alignment_model.py | 2 ++ 8 files changed, 15 insertions(+) diff --git a/src/models/components/geo_encoders/average_encoder.py b/src/models/components/geo_encoders/average_encoder.py index 885b175..0c0eaf6 100644 --- a/src/models/components/geo_encoders/average_encoder.py +++ b/src/models/components/geo_encoders/average_encoder.py @@ -34,6 +34,8 @@ def setup(self) -> List[str]: """ self.output_dim = self.dict_n_bands_default[self.geo_data_name] self.geo_encoder = nn.Identity() + print(f"Model set up with average geo-encoder for {self.geo_data_name}") + return [] @override diff --git a/src/models/components/geo_encoders/cnn_encoder.py b/src/models/components/geo_encoders/cnn_encoder.py index d73a7b3..34f8f48 100644 --- a/src/models/components/geo_encoders/cnn_encoder.py +++ b/src/models/components/geo_encoders/cnn_encoder.py @@ -135,6 +135,7 @@ def get_backbone(self): def setup(self) -> List[str]: # TODO: could you make sure new layers are returned here to be added to trainable parts? # Maybe move the get_backbone method in here? 
+ print(f"Model setup with cnn geo-encoder for {self.geo_data_name}") return [] @override diff --git a/src/models/components/geo_encoders/geoclip.py b/src/models/components/geo_encoders/geoclip.py index e530dd1..38dffdb 100644 --- a/src/models/components/geo_encoders/geoclip.py +++ b/src/models/components/geo_encoders/geoclip.py @@ -24,6 +24,7 @@ def __init__( def setup(self) -> List[str]: self.geo_encoder = LocationEncoder() self.output_dim = self.geo_encoder.LocEnc0.head[0].out_features + print("Model setup with GeoClip coordinate encoder") return [] @override diff --git a/src/models/components/geo_encoders/mlp_projector.py b/src/models/components/geo_encoders/mlp_projector.py index f4c1137..e622216 100644 --- a/src/models/components/geo_encoders/mlp_projector.py +++ b/src/models/components/geo_encoders/mlp_projector.py @@ -27,6 +27,7 @@ def __init__( @override def setup(self) -> List[str]: self.configure_nn() + print("Model setup with MLP projector") return ["net"] def set_input_dim(self, input_dim: int) -> None: diff --git a/src/models/components/geo_encoders/tabular_encoder.py b/src/models/components/geo_encoders/tabular_encoder.py index 09af55c..1ae4b1d 100644 --- a/src/models/components/geo_encoders/tabular_encoder.py +++ b/src/models/components/geo_encoders/tabular_encoder.py @@ -35,6 +35,7 @@ def __init__( @override def setup(self, input_dim: int = None) -> list[str]: self.configure_nn(input_dim) + print("Model setup with Tabular geo-encoder") return ["tabular_encoder"] def set_tabular_input_dim(self, input_dim: int) -> None: diff --git a/src/models/components/text_encoders/clip_text_encoder.py b/src/models/components/text_encoders/clip_text_encoder.py index 7642865..075a390 100644 --- a/src/models/components/text_encoders/clip_text_encoder.py +++ b/src/models/components/text_encoders/clip_text_encoder.py @@ -24,8 +24,13 @@ def __init__(self, hf_cache_dir: str = "../.cache", output_normalization="l2") - self.projector = GeoCLIP().image_encoder.mlp + 
self.model.vision_model = None + self.model.visual_projection = None + self.output_dim = 512 + print("Model set up with CLIP text encoder") + @override def forward(self, batch: Dict[str, torch.Tensor], mode: str) -> torch.Tensor: # Get text inputs diff --git a/src/models/predictive_model.py b/src/models/predictive_model.py index b46a3f4..7b951c2 100644 --- a/src/models/predictive_model.py +++ b/src/models/predictive_model.py @@ -62,7 +62,9 @@ def setup(self, stage: str) -> None: if isinstance(self.trainable_modules, tuple): self.trainable_modules = list(self.trainable_modules) + print("-------Model------------") self.setup_encoders_adapters() + print("------------------------") # Freezing requested parts self.freezer() diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index 9cffc5e..49b8a91 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -69,7 +69,9 @@ def setup(self, stage: str) -> None: self.tabular_dim = self.trainer.datamodule.tabular_dim # Set up encoders and missing adapters/projectors + print("-------Model------------") self.setup_encoders_adapters() + print("------------------------") # Freeze requested parts self.freezer() From dbdd47683dde7b02a70b9a109e550677963bf69e Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Sun, 15 Mar 2026 13:46:00 +0100 Subject: [PATCH 56/60] Guatemala UC tessera --- pyproject.toml | 1 + src/data/heat_guatemala_dataset.py | 10 +++++++++- src/data_preprocessing/tessera_embeds.py | 12 +++++++++++- 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4c51d60..910e924 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "peft>=0.18.1", "llm2vec", "setuptools<81", + "geotessera>=0.7.3", ] [project.optional-dependencies] diff --git a/src/data/heat_guatemala_dataset.py b/src/data/heat_guatemala_dataset.py index c20b686..f4fabac 100644 --- a/src/data/heat_guatemala_dataset.py +++ 
b/src/data/heat_guatemala_dataset.py @@ -56,7 +56,7 @@ def __init__( dataset_name="heat_guatemala", seed=seed, cache_dir=cache_dir, - implemented_mod={"coords"}, + implemented_mod={"coords", "tessera"}, mock=mock, use_features=use_features, ) @@ -67,6 +67,14 @@ def __init__( def setup(self) -> None: """No files to download / prepare for this dataset.""" + # Set up each requested modality + for mod in self.modalities.keys(): + if mod == "coords" and len(self.modalities.keys()) == 1: + return + elif mod == "tessera": + self.setup_tessera() + # elif mod == "aef": + # self.setup_aef() return @override diff --git a/src/data_preprocessing/tessera_embeds.py b/src/data_preprocessing/tessera_embeds.py index fcc6c32..c2414b8 100644 --- a/src/data_preprocessing/tessera_embeds.py +++ b/src/data_preprocessing/tessera_embeds.py @@ -199,7 +199,7 @@ def tessera_from_df( # Tessera connection cache_dir = os.path.join(cache_dir, "tessera") - gt = GeoTessera(cache_dir=cache_dir) + gt = GeoTessera(cache_dir=cache_dir, embeddings_dir=cache_dir, dataset_version="v1") # Iter each coord n = len(model_ready_df) @@ -267,3 +267,13 @@ def inspect_np_arr_as_tiff( dst.write(arr_to_write[i], i + 1) print(f"Tiff version of np array saved to {file_path}") + + +if __name__ == "__main__": + os.chdir("../..") + + df = pd.read_csv("data/heat_guatemala/model_ready_heat_guatemala.csv") + + tessera_from_df( + df, "data/heat_guatemala/eo/tessera_2024", year=2024, tile_size=10, cache_dir="data/cache" + ) From 6891a51b3b7df17a688ceeb4f784c76ffd80d611 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Sun, 15 Mar 2026 13:48:23 +0100 Subject: [PATCH 57/60] Alignment training --- configs/paths/shared.yaml | 4 ++-- scripts/schedule.sh | 20 ++++++++++++-------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/configs/paths/shared.yaml b/configs/paths/shared.yaml index 76e6046..a4702b0 100644 --- a/configs/paths/shared.yaml +++ b/configs/paths/shared.yaml @@ -5,8 +5,8 @@ root_dir: 
${oc.env:PROJECT_ROOT,./} # path to data directory -data_dir: ${oc.env:DATA_DIR,oc.env:SHARED_ROOT/data/,${paths.root_dir}/data/} -cache_dir: ${oc.env:CACHE_DIR,${paths.data_dir}/cache} +data_dir: ${oc.env:DATA_DIR,${oc.env:SHARED_ROOT,${paths.root_dir}}/data} +cache_dir: ${oc.env:SHARED_CACHE,${paths.data_dir}/cache} # path to logging directory log_dir: ${oc.env:SHARED_ROOT}/logs/ diff --git a/scripts/schedule.sh b/scripts/schedule.sh index 8054b63..8bab438 100644 --- a/scripts/schedule.sh +++ b/scripts/schedule.sh @@ -1,23 +1,27 @@ #!/bin/bash -#SBATCH--cpus-per-task=8 -#SBATCH--partition=gpu -#SBATCH--gpus=1 -#SBATCH--job-name=aether -#SBATCH--mem=100G -#SBATCH--time=100 +#SBATCH --cpus-per-task=8 +#SBATCH --partition=gpu +#SBATCH --gpus=1 +#SBATCH --job-name=aether +#SBATCH --mem=100G +#SBATCH --time=100 +#SBATCH --output=logs/out_%j.out +#SBATCH --error=logs/err_%j.err # Schedule execution of many runs # Run from root folder with: bash scripts/schedule.sh # Variables +# shellcheck disable=SC1091 source .env -# Environment +#Environment +# shellcheck disable=SC1091 source .venv/bin/activate # Runs #srun python src/train.py experiment=alignment trainer=$TRAINER_PROFILE logger=$LOGGER -#srun python src/train.py experiment=prediction logger=wandb +srun python -u src/train.py experiment=alignment_v1 # example runs with overwritten configs #srun python src/train.py experiment=alignment trainer=ddp_sim trainer.max_epochs=10 data.pin_memory=false From 0c710d946de2f7d1dd9a1fe1016de9135865c1d4 Mon Sep 17 00:00:00 2001 From: gabriele Date: Thu, 19 Mar 2026 13:45:08 +0100 Subject: [PATCH 58/60] De-duplicate geotessera requirements Remove from optional --- pyproject.toml | 3 --- 1 file changed, 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 910e924..f49df39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,9 +41,6 @@ create-data = [ "geemap>=0.36.6", "pipreqs>=0.5.0", ] -geotessera = [ - "geotessera>=0.7.3", -] [tool.pytest.ini_options] addopts = [ 
From 9e406a467778c7d4384f6681e77994d7e7ed0bc8 Mon Sep 17 00:00:00 2001 From: gabriele Date: Thu, 19 Mar 2026 13:47:06 +0100 Subject: [PATCH 59/60] Create input and output dimensions as attributes --- src/models/components/pred_heads/linear_pred_head.py | 3 +-- src/models/components/pred_heads/mlp_pred_head.py | 3 +-- src/models/components/pred_heads/mlp_regression_head.py | 3 +-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/models/components/pred_heads/linear_pred_head.py b/src/models/components/pred_heads/linear_pred_head.py index 94338a7..61efc7d 100644 --- a/src/models/components/pred_heads/linear_pred_head.py +++ b/src/models/components/pred_heads/linear_pred_head.py @@ -18,8 +18,7 @@ def __init__( :param output_dim: the size of output dimension """ super().__init__() - if input_dim and output_dim: - self.set_dim(input_dim, output_dim) + self.set_dim(input_dim, output_dim) @override def forward(self, feats: torch.Tensor) -> torch.Tensor: diff --git a/src/models/components/pred_heads/mlp_pred_head.py b/src/models/components/pred_heads/mlp_pred_head.py index 144d602..282b3bf 100644 --- a/src/models/components/pred_heads/mlp_pred_head.py +++ b/src/models/components/pred_heads/mlp_pred_head.py @@ -24,8 +24,7 @@ def __init__( super().__init__() self.nn_layers = nn_layers self.hidden_dim = hidden_dim - if input_dim and output_dim: - self.set_dim(input_dim, output_dim) + self.set_dim(input_dim, output_dim) @override def forward(self, feats: torch.Tensor) -> torch.Tensor: diff --git a/src/models/components/pred_heads/mlp_regression_head.py b/src/models/components/pred_heads/mlp_regression_head.py index a179efe..6894bb4 100644 --- a/src/models/components/pred_heads/mlp_regression_head.py +++ b/src/models/components/pred_heads/mlp_regression_head.py @@ -39,8 +39,7 @@ def __init__( self.nn_layers = nn_layers self.hidden_dim = hidden_dim self.dropout = dropout - if input_dim and output_dim: - self.set_dim(input_dim, output_dim) + 
self.set_dim(input_dim, output_dim) @override def forward(self, feats: torch.Tensor) -> torch.Tensor: From d35806692cfa40b92aa4b036814dd5f4d6ada9a5 Mon Sep 17 00:00:00 2001 From: GabrieleTi Date: Mon, 23 Mar 2026 16:19:30 +0100 Subject: [PATCH 60/60] Fix broken tests --- src/models/components/pred_heads/linear_pred_head.py | 3 ++- src/models/components/pred_heads/mlp_pred_head.py | 4 +++- src/models/components/pred_heads/mlp_regression_head.py | 4 +++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/models/components/pred_heads/linear_pred_head.py b/src/models/components/pred_heads/linear_pred_head.py index 61efc7d..94338a7 100644 --- a/src/models/components/pred_heads/linear_pred_head.py +++ b/src/models/components/pred_heads/linear_pred_head.py @@ -18,7 +18,8 @@ def __init__( :param output_dim: the size of output dimension """ super().__init__() - self.set_dim(input_dim, output_dim) + if input_dim and output_dim: + self.set_dim(input_dim, output_dim) @override def forward(self, feats: torch.Tensor) -> torch.Tensor: diff --git a/src/models/components/pred_heads/mlp_pred_head.py b/src/models/components/pred_heads/mlp_pred_head.py index 282b3bf..6d4c124 100644 --- a/src/models/components/pred_heads/mlp_pred_head.py +++ b/src/models/components/pred_heads/mlp_pred_head.py @@ -24,7 +24,9 @@ def __init__( super().__init__() self.nn_layers = nn_layers self.hidden_dim = hidden_dim - self.set_dim(input_dim, output_dim) + + if input_dim and output_dim: + self.set_dim(input_dim, output_dim) @override def forward(self, feats: torch.Tensor) -> torch.Tensor: diff --git a/src/models/components/pred_heads/mlp_regression_head.py b/src/models/components/pred_heads/mlp_regression_head.py index 6894bb4..d9553ec 100644 --- a/src/models/components/pred_heads/mlp_regression_head.py +++ b/src/models/components/pred_heads/mlp_regression_head.py @@ -39,7 +39,9 @@ def __init__( self.nn_layers = nn_layers self.hidden_dim = hidden_dim self.dropout = dropout - 
self.set_dim(input_dim, output_dim) + + if input_dim and output_dim: + self.set_dim(input_dim, output_dim) @override def forward(self, feats: torch.Tensor) -> torch.Tensor: