diff --git a/.env.example b/.env.example index 33713cb..4e8a6ee 100644 --- a/.env.example +++ b/.env.example @@ -10,6 +10,12 @@ TRAINER_PROFILE="gpu" # cpu/gpu/mps/ddp HF_HOME="${PROJECT_ROOT}/.cache/huggingface/" # set or will default to './.cache/huggingface/' DATA_DIR="${PROJECT_ROOT}/data/" # set to your local data folder (for aether), or will default to '${PROJECT_ROOT}/data/' +# Base cache directory for TESSERA. +# GeoTessera registry/metadata is stored here; large raw source tiles go in the +# raw/ subfolder. This folder can get very large — point it at an external drive +# if needed. +TESSERA_EMBEDDINGS_DIR="${PROJECT_ROOT}/data/cache/tessera/" + # Working directories # STORAGE_MODE=# or "shared" # SHARED_CACHE=# or "/path/to/shared/.cache" diff --git a/.gitignore b/.gitignore index 0bec17c..2dfb5d6 100644 --- a/.gitignore +++ b/.gitignore @@ -228,3 +228,4 @@ notebooks/01-TvdP-tmp.ipynb */source/* *.tif # for now ..env.swp +/data/yield_africa/ diff --git a/configs/callbacks/default.yaml b/configs/callbacks/default.yaml index 149a92f..e161020 100644 --- a/configs/callbacks/default.yaml +++ b/configs/callbacks/default.yaml @@ -19,4 +19,4 @@ early_stopping: mode: "min" model_summary: - max_depth: 2 + max_depth: 1 diff --git a/configs/data/butterfly_coords_text.yaml b/configs/data/butterfly_coords_text.yaml index 92b6bd7..9c98968 100644 --- a/configs/data/butterfly_coords_text.yaml +++ b/configs/data/butterfly_coords_text.yaml @@ -14,6 +14,7 @@ dataset: caption_builder: _target_: src.data.butterfly_caption_builder.ButterflyCaptionBuilder templates_fname: v3.json + concepts_fname: v1.json data_dir: ${paths.data_dir}/s2bms seed: ${seed} diff --git a/configs/data/butterfly_full_param_example.yaml b/configs/data/butterfly_full_param_example.yaml index 541f4c2..78923f5 100644 --- a/configs/data/butterfly_full_param_example.yaml +++ b/configs/data/butterfly_full_param_example.yaml @@ -22,7 +22,8 @@ dataset: caption_builder: _target_: 
src.data.butterfly_caption_builder.ButterflyCaptionBuilder - templates_fname: caption_templates.json + templates_fname: v3.json + concepts_fname: v1.json data_dir: ${paths.data_dir}/s2bms seed: ${seed} diff --git a/configs/data/yield_africa_all.yaml b/configs/data/yield_africa_all.yaml new file mode 100644 index 0000000..e40a8a2 --- /dev/null +++ b/configs/data/yield_africa_all.yaml @@ -0,0 +1,33 @@ +_target_: src.data.base_datamodule.BaseDataModule + +dataset: + _target_: src.data.yield_africa_dataset.YieldAfricaDataset + data_dir: ${paths.data_dir} + modalities: + coords: {} + use_target_data: true + use_features: true + use_aux_data: none + seed: ${seed} + cache_dir: ${paths.cache_dir} + # Country/year filters — set to a list to restrict, null to include all. + # countries and years select only the listed values; + # exclude_countries and exclude_years drop the listed values. + countries: ["BF", "BUR", "ETH", "KEN", "MAL", "RWA", "TAN", "ZAM"] + years: [2014, 2016, 2017, 2018, 2019, 2020, 2021, 2023, 2024] + exclude_countries: null + exclude_years: null + +batch_size: 64 +num_workers: 0 +pin_memory: false + +# todo - use spatial split (pre-calculate and then load from file) +# - hold out country/year block for validation +# - or leave one country out for validation +# - normalize data by country (after filtering) + +split_mode: "random" +train_val_test_split: [0.7, 0.15, 0.15] +save_split: false +seed: ${seed} diff --git a/configs/data/yield_africa_loco.yaml b/configs/data/yield_africa_loco.yaml new file mode 100644 index 0000000..92559d0 --- /dev/null +++ b/configs/data/yield_africa_loco.yaml @@ -0,0 +1,33 @@ +_target_: src.data.base_datamodule.BaseDataModule + +dataset: + _target_: src.data.yield_africa_dataset.YieldAfricaDataset + data_dir: ${paths.data_dir} + modalities: + coords: {} + use_target_data: true + use_features: true + use_aux_data: none + seed: ${seed} + cache_dir: ${paths.cache_dir} + # Include all countries and years so the split file 
determines the partition. + countries: ["BF", "BUR", "ETH", "KEN", "MAL", "RWA", "TAN", "ZAM"] + years: [2014, 2016, 2017, 2018, 2019, 2020, 2021, 2023, 2024] + exclude_countries: null + exclude_years: null + +batch_size: 64 +num_workers: 0 +pin_memory: false + +# Leave-one-country-out split loaded from a pre-generated file. +# Generate split files first: +# python src/data_preprocessing/yield_africa_loco_splits.py --data_dir +# +# Override saved_split_file_name at the command line to change the held-out country: +# python src/train.py experiment=yield_africa_tabular_loco \ +# data.saved_split_file_name=split_loco_RWA.pth +split_mode: "from_file" +saved_split_file_name: "split_loco_KEN.pth" +save_split: false +seed: ${seed} diff --git a/configs/data/yield_africa_spatial.yaml b/configs/data/yield_africa_spatial.yaml new file mode 100644 index 0000000..9313100 --- /dev/null +++ b/configs/data/yield_africa_spatial.yaml @@ -0,0 +1,33 @@ +_target_: src.data.base_datamodule.BaseDataModule + +dataset: + _target_: src.data.yield_africa_dataset.YieldAfricaDataset + data_dir: ${paths.data_dir} + modalities: + coords: {} + use_target_data: true + use_features: true + use_aux_data: none + seed: ${seed} + cache_dir: ${paths.cache_dir} + # Include all countries and years so the split file determines the partition. + countries: ["BF", "BUR", "ETH", "KEN", "MAL", "RWA", "TAN", "ZAM"] + years: [2014, 2016, 2017, 2018, 2019, 2020, 2021, 2023, 2024] + exclude_countries: null + exclude_years: null + +batch_size: 64 +num_workers: 0 +pin_memory: false + +# Spatial-cluster split loaded from a pre-generated file. 
+# Generate split files first (produces 10 km, 25 km, and 50 km variants): +# python src/data_preprocessing/yield_africa_spatial_splits.py --data_dir +# +# Override saved_split_file_name at the command line to change the cluster distance: +# python src/train.py experiment=yield_africa_tabular_spatial \ +# data.saved_split_file_name=split_spatial_10km.pth +split_mode: "from_file" +saved_split_file_name: "split_spatial_25km.pth" +save_split: false +seed: ${seed} diff --git a/configs/data/yield_africa_tessera.yaml b/configs/data/yield_africa_tessera.yaml new file mode 100644 index 0000000..533b483 --- /dev/null +++ b/configs/data/yield_africa_tessera.yaml @@ -0,0 +1,31 @@ +_target_: src.data.base_datamodule.BaseDataModule + +dataset: + _target_: src.data.yield_africa_dataset.YieldAfricaDataset + data_dir: ${paths.data_dir} + modalities: + tessera: + # size must match the tile_size used when running the preprocessing script. + # Default: 9 pixels (set by yield_africa_tessera_preprocess.py --tile_size). + size: 9 + format: npy + # year is intentionally omitted: yield_africa fetches per-record year tiles + # via the preprocessing script rather than a single bulk-year download. 
+ use_target_data: true + use_features: true + use_aux_data: none + seed: ${seed} + cache_dir: ${paths.cache_dir} + countries: ["BF", "BUR", "ETH", "KEN", "MAL", "RWA", "TAN", "ZAM"] + years: [2014, 2016, 2017, 2018, 2019, 2020, 2021, 2023, 2024] + exclude_countries: null + exclude_years: null + +batch_size: 64 +num_workers: 0 +pin_memory: false + +split_mode: "random" +train_val_test_split: [0.7, 0.15, 0.15] +save_split: false +seed: ${seed} diff --git a/configs/data/yield_africa_tessera_loco.yaml b/configs/data/yield_africa_tessera_loco.yaml new file mode 100644 index 0000000..0be62c3 --- /dev/null +++ b/configs/data/yield_africa_tessera_loco.yaml @@ -0,0 +1,39 @@ +_target_: src.data.base_datamodule.BaseDataModule + +dataset: + _target_: src.data.yield_africa_dataset.YieldAfricaDataset + data_dir: ${paths.data_dir} + modalities: + tessera: + # size must match the tile_size used when running the preprocessing script. + # Default: 9 pixels (set by yield_africa_tessera_preprocess.py --tile_size). + size: 9 + format: npy + # year is intentionally omitted: yield_africa fetches per-record year tiles + # via the preprocessing script rather than a single bulk-year download. + use_target_data: true + use_features: true + use_aux_data: none + seed: ${seed} + cache_dir: ${paths.cache_dir} + # Include all countries and years so the split file determines the partition. + countries: ["BF", "BUR", "ETH", "KEN", "MAL", "RWA", "TAN", "ZAM"] + years: [2014, 2016, 2017, 2018, 2019, 2020, 2021, 2023, 2024] + exclude_countries: null + exclude_years: null + +batch_size: 64 +num_workers: 0 +pin_memory: false + +# Leave-one-country-out split loaded from a pre-generated file. 
+# Generate split files first: +# python src/data_preprocessing/yield_africa_loco_splits.py --data_dir +# +# Override saved_split_file_name at the command line to change the held-out country: +# python src/train.py experiment=yield_africa_tessera_fusion_loco \ +# data.saved_split_file_name=split_loco_RWA.pth +split_mode: "from_file" +saved_split_file_name: "split_loco_KEN.pth" +save_split: false +seed: ${seed} diff --git a/configs/data/yield_africa_tessera_spatial.yaml b/configs/data/yield_africa_tessera_spatial.yaml new file mode 100644 index 0000000..9424801 --- /dev/null +++ b/configs/data/yield_africa_tessera_spatial.yaml @@ -0,0 +1,39 @@ +_target_: src.data.base_datamodule.BaseDataModule + +dataset: + _target_: src.data.yield_africa_dataset.YieldAfricaDataset + data_dir: ${paths.data_dir} + modalities: + tessera: + # size must match the tile_size used when running the preprocessing script. + # Default: 9 pixels (set by yield_africa_tessera_preprocess.py --tile_size). + size: 9 + format: npy + # year is intentionally omitted: yield_africa fetches per-record year tiles + # via the preprocessing script rather than a single bulk-year download. + use_target_data: true + use_features: true + use_aux_data: none + seed: ${seed} + cache_dir: ${paths.cache_dir} + # Include all countries and years so the split file determines the partition. + countries: ["BF", "BUR", "ETH", "KEN", "MAL", "RWA", "TAN", "ZAM"] + years: [2014, 2016, 2017, 2018, 2019, 2020, 2021, 2023, 2024] + exclude_countries: null + exclude_years: null + +batch_size: 64 +num_workers: 0 +pin_memory: false + +# Spatial-cluster split loaded from a pre-generated file. 
+# Generate split files first (produces 10 km, 25 km, and 50 km variants): +# python src/data_preprocessing/yield_africa_spatial_splits.py --data_dir +# +# Override saved_split_file_name at the command line to change the cluster distance: +# python src/train.py experiment=yield_africa_tessera_fusion_spatial \ +# data.saved_split_file_name=split_spatial_10km.pth +split_mode: "from_file" +saved_split_file_name: "split_spatial_25km.pth" +save_split: false +seed: ${seed} diff --git a/configs/experiment/yield_africa_coords_reg.yaml b/configs/experiment/yield_africa_coords_reg.yaml new file mode 100644 index 0000000..5690656 --- /dev/null +++ b/configs/experiment/yield_africa_coords_reg.yaml @@ -0,0 +1,25 @@ +# @package _global_ +# configs/experiment/yield_africa_coords_reg.yaml +# Variant: GeoClip coordinate features only, full dataset + +defaults: + - override /model: yield_geoclip_reg + - override /data: yield_africa_all + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "coords_only", "regression"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 150 + +data: + batch_size: 64 + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" diff --git a/configs/experiment/yield_africa_fusion_loco.yaml b/configs/experiment/yield_africa_fusion_loco.yaml new file mode 100644 index 0000000..2540642 --- /dev/null +++ b/configs/experiment/yield_africa_fusion_loco.yaml @@ -0,0 +1,33 @@ +# @package _global_ +# configs/experiment/yield_africa_fusion_loco.yaml +# GeoClip + tabular fusion model evaluated with leave-one-country-out split. +# Default held-out country: KEN (largest, most representative test set). 
+# +# Generate split files first: +# python src/data_preprocessing/yield_africa_loco_splits.py --data_dir +# +# To evaluate on a different held-out country: +# python src/train.py experiment=yield_africa_fusion_loco \ +# data.saved_split_file_name=split_loco_RWA.pth + +defaults: + - override /model: yield_fusion_reg + - override /data: yield_africa_loco + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "fusion", "regression", "loco"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 150 + +data: + batch_size: 64 + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" diff --git a/configs/experiment/yield_africa_fusion_reg.yaml b/configs/experiment/yield_africa_fusion_reg.yaml new file mode 100644 index 0000000..fa1fbdd --- /dev/null +++ b/configs/experiment/yield_africa_fusion_reg.yaml @@ -0,0 +1,25 @@ +# @package _global_ +# configs/experiment/yield_africa_fusion_reg.yaml +# Variant: GeoClip + tabular fusion + +defaults: + - override /model: yield_fusion_reg + - override /data: yield_africa_all + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "fusion", "regression"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 150 + +data: + batch_size: 64 + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" diff --git a/configs/experiment/yield_africa_fusion_spatial.yaml b/configs/experiment/yield_africa_fusion_spatial.yaml new file mode 100644 index 0000000..98c4221 --- /dev/null +++ b/configs/experiment/yield_africa_fusion_spatial.yaml @@ -0,0 +1,33 @@ +# @package _global_ +# configs/experiment/yield_africa_fusion_spatial.yaml +# GeoClip + tabular fusion model evaluated with a spatial-cluster split. +# Default cluster distance: 25 km (split_spatial_25km.pth). 
+# +# Generate split files first: +# python src/data_preprocessing/yield_africa_spatial_splits.py --data_dir +# +# To evaluate at a different cluster distance: +# python src/train.py experiment=yield_africa_fusion_spatial \ +# data.saved_split_file_name=split_spatial_10km.pth + +defaults: + - override /model: yield_fusion_reg + - override /data: yield_africa_spatial + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "fusion", "regression", "spatial"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 150 + +data: + batch_size: 64 + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" diff --git a/configs/experiment/yield_africa_tabular_loco.yaml b/configs/experiment/yield_africa_tabular_loco.yaml new file mode 100644 index 0000000..cce363d --- /dev/null +++ b/configs/experiment/yield_africa_tabular_loco.yaml @@ -0,0 +1,30 @@ +# @package _global_ +# configs/experiment/yield_africa_tabular_loco.yaml +# Tabular-only model evaluated with leave-one-country-out split. +# Default held-out country: KEN (largest, most representative test set). 
+# +# To evaluate on a different held-out country: +# python src/train.py experiment=yield_africa_tabular_loco \ +# data.saved_split_file_name=split_loco_RWA.pth + +defaults: + - override /model: yield_tabular_reg + - override /data: yield_africa_loco + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "tabular_only", "regression", "loco"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 150 + +data: + batch_size: 64 + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" diff --git a/configs/experiment/yield_africa_tabular_reg.yaml b/configs/experiment/yield_africa_tabular_reg.yaml new file mode 100644 index 0000000..57f6b36 --- /dev/null +++ b/configs/experiment/yield_africa_tabular_reg.yaml @@ -0,0 +1,25 @@ +# @package _global_ +# configs/experiment/yield_africa_tabular_reg.yaml +# Variant: Tabular features only, full dataset + +defaults: + - override /model: yield_tabular_reg + - override /data: yield_africa_all + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "tabular_only", "regression"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 150 + +data: + batch_size: 64 + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" diff --git a/configs/experiment/yield_africa_tabular_spatial.yaml b/configs/experiment/yield_africa_tabular_spatial.yaml new file mode 100644 index 0000000..9c57961 --- /dev/null +++ b/configs/experiment/yield_africa_tabular_spatial.yaml @@ -0,0 +1,33 @@ +# @package _global_ +# configs/experiment/yield_africa_tabular_spatial.yaml +# Tabular-only model evaluated with a spatial-cluster split. +# Default cluster distance: 25 km (split_spatial_25km.pth). 
+# +# Generate split files first: +# python src/data_preprocessing/yield_africa_spatial_splits.py --data_dir +# +# To evaluate at a different cluster distance: +# python src/train.py experiment=yield_africa_tabular_spatial \ +# data.saved_split_file_name=split_spatial_10km.pth + +defaults: + - override /model: yield_tabular_reg + - override /data: yield_africa_spatial + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "tabular_only", "regression", "spatial"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 150 + +data: + batch_size: 64 + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" diff --git a/configs/experiment/yield_africa_tessera_fusion_loco.yaml b/configs/experiment/yield_africa_tessera_fusion_loco.yaml new file mode 100644 index 0000000..ee9aa9d --- /dev/null +++ b/configs/experiment/yield_africa_tessera_fusion_loco.yaml @@ -0,0 +1,38 @@ +# @package _global_ +# configs/experiment/yield_africa_tessera_fusion_loco.yaml +# TESSERA + tabular fusion model evaluated with leave-one-country-out split. +# Default held-out country: KEN (largest, most representative test set). +# +# Requires: +# 1. TESSERA tiles pre-fetched: +# python src/data_preprocessing/yield_africa_tessera_preprocess.py --data_dir +# 2. 
LOCO split files pre-generated: +# python src/data_preprocessing/yield_africa_loco_splits.py --data_dir +# +# To evaluate on a different held-out country: +# python src/train.py experiment=yield_africa_tessera_fusion_loco \ +# data.saved_split_file_name=split_loco_RWA.pth + +defaults: + - override /model: yield_tessera_fusion_reg + - override /data: yield_africa_tessera_loco + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "tessera_fusion", "regression", "loco"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 150 + +data: + batch_size: 64 + dataset: + use_features: true + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" diff --git a/configs/experiment/yield_africa_tessera_fusion_reg.yaml b/configs/experiment/yield_africa_tessera_fusion_reg.yaml new file mode 100644 index 0000000..93c2052 --- /dev/null +++ b/configs/experiment/yield_africa_tessera_fusion_reg.yaml @@ -0,0 +1,31 @@ +# @package _global_ +# configs/experiment/yield_africa_tessera_fusion_reg.yaml +# Variant: TESSERA spatial encoder + tabular features fusion. +# Requires: +# 1. TESSERA tiles pre-fetched: +# python src/data_preprocessing/yield_africa_tessera_preprocess.py --data_dir +# 2. MultiModalEncoder geo_encoder_cfg support. 
+ +defaults: + - override /model: yield_tessera_fusion_reg + - override /data: yield_africa_tessera + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "tessera_fusion", "regression"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 150 + +data: + batch_size: 64 + dataset: + use_features: true + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" diff --git a/configs/experiment/yield_africa_tessera_fusion_spatial.yaml b/configs/experiment/yield_africa_tessera_fusion_spatial.yaml new file mode 100644 index 0000000..b0eaf9d --- /dev/null +++ b/configs/experiment/yield_africa_tessera_fusion_spatial.yaml @@ -0,0 +1,38 @@ +# @package _global_ +# configs/experiment/yield_africa_tessera_fusion_spatial.yaml +# TESSERA + tabular fusion model evaluated with a spatial-cluster split. +# Default cluster distance: 25 km (split_spatial_25km.pth). +# +# Requires: +# 1. TESSERA tiles pre-fetched: +# python src/data_preprocessing/yield_africa_tessera_preprocess.py --data_dir +# 2. 
Spatial split files pre-generated: +# python src/data_preprocessing/yield_africa_spatial_splits.py --data_dir +# +# To evaluate at a different cluster distance: +# python src/train.py experiment=yield_africa_tessera_fusion_spatial \ +# data.saved_split_file_name=split_spatial_10km.pth + +defaults: + - override /model: yield_tessera_fusion_reg + - override /data: yield_africa_tessera_spatial + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "tessera_fusion", "regression", "spatial"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 150 + +data: + batch_size: 64 + dataset: + use_features: true + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" diff --git a/configs/experiment/yield_africa_tessera_reg.yaml b/configs/experiment/yield_africa_tessera_reg.yaml new file mode 100644 index 0000000..ca4e2ed --- /dev/null +++ b/configs/experiment/yield_africa_tessera_reg.yaml @@ -0,0 +1,27 @@ +# @package _global_ +# configs/experiment/yield_africa_tessera_reg.yaml +# Variant: TESSERA spatial encoder only (no tabular features). 
+# Requires tiles pre-fetched by: +# python src/data_preprocessing/yield_africa_tessera_preprocess.py --data_dir + +defaults: + - override /model: yield_tessera_reg + - override /data: yield_africa_tessera + - override /metrics: yield_africa_regression + +tags: ["yield_africa", "tessera_only", "regression"] +seed: 12345 + +trainer: + min_epochs: 1 + max_epochs: 150 + +data: + batch_size: 64 + +logger: + wandb: + tags: ${tags} + group: "yield_africa" + aim: + experiment: "yield_africa" diff --git a/configs/metrics/yield_africa_regression.yaml b/configs/metrics/yield_africa_regression.yaml new file mode 100644 index 0000000..7960283 --- /dev/null +++ b/configs/metrics/yield_africa_regression.yaml @@ -0,0 +1,8 @@ +_target_: src.models.components.metrics.metrics_wrapper.MetricsWrapper + +metrics: + - _target_: src.models.components.loss_fns.huber_loss.HuberLoss + - _target_: src.models.components.loss_fns.rmse_loss.RMSELoss + - _target_: src.models.components.loss_fns.mae_loss.MAELoss + - _target_: src.models.components.loss_fns.rrmse_loss.RRMSELoss + - _target_: src.models.components.metrics.r2.RSquared diff --git a/configs/model/example_for_encoder_wrapper.yaml b/configs/model/example_for_encoder_wrapper.yaml new file mode 100644 index 0000000..c2dc55e --- /dev/null +++ b/configs/model/example_for_encoder_wrapper.yaml @@ -0,0 +1,15 @@ +_target_: src.models.components.geo_encoders.encoder_wrapper.EncoderWrapper + +encoder_branches: + - encoder: + _target_: src.models.components.geo_encoders.average_encoder.AverageEncoder + geo_data_name: aef + projector: + _target_: src.models.components.geo_encoders.mlp_projector.MLPProjector + nn_layers: 2 + hidden_dim: 512 + output_dim: 512 +# - encoder: # another branch +# _target_: + +fusion_strategy: "concat" diff --git a/configs/model/geoclip_alignment.yaml b/configs/model/geoclip_alignment.yaml index 0753e79..52d6485 100644 --- a/configs/model/geoclip_alignment.yaml +++ b/configs/model/geoclip_alignment.yaml @@ -1,7 +1,7 @@ 
_target_: src.models.text_alignment_model.TextAlignmentModel -eo_encoder: - _target_: src.models.components.eo_encoders.geoclip.GeoClipCoordinateEncoder +geo_encoder: + _target_: src.models.components.geo_encoders.geoclip.GeoClipCoordinateEncoder text_encoder: _target_: src.models.components.text_encoders.clip_text_encoder.ClipTextEncoder diff --git a/configs/model/geoclip_llm2clip_alignment.yaml b/configs/model/geoclip_llm2clip_alignment.yaml index 76f8878..8ff882c 100644 --- a/configs/model/geoclip_llm2clip_alignment.yaml +++ b/configs/model/geoclip_llm2clip_alignment.yaml @@ -1,7 +1,7 @@ _target_: src.models.text_alignment_model.TextAlignmentModel -eo_encoder: - _target_: src.models.components.eo_encoders.geoclip.GeoClipCoordinateEncoder +geo_encoder: + _target_: src.models.components.geo_encoders.geoclip.GeoClipCoordinateEncoder text_encoder: _target_: src.models.components.text_encoders.llm2clip_text_encoder.LLM2CLIPTextEncoder diff --git a/configs/model/heat_fusion_reg.yaml b/configs/model/heat_fusion_reg.yaml index e507dd1..3671555 100644 --- a/configs/model/heat_fusion_reg.yaml +++ b/configs/model/heat_fusion_reg.yaml @@ -8,11 +8,18 @@ _target_: src.models.predictive_model.PredictiveModel -eo_encoder: - _target_: src.models.components.eo_encoders.multimodal_encoder.MultiModalEncoder - use_coords: true - use_tabular: true -# tab_embed_dim: 64 +geo_encoder: + _target_: src.models.components.geo_encoders.encoder_wrapper.EncoderWrapper + + encoder_branches: + - encoder: + _target_: src.models.components.geo_encoders.geoclip.GeoClipCoordinateEncoder + - encoder: + _target_: src.models.components.geo_encoders.tabular_encoder.TabularEncoder + output_dim: 64 + geo_data_name: tabular + + fusion_strategy: "concat" prediction_head: _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead @@ -20,7 +27,7 @@ prediction_head: hidden_dim: 256 # GeoClip frozen; tabular projection + head are trained. 
-trainable_modules: [eo_encoder, prediction_head] +trainable_modules: [geo_encoder, prediction_head] metrics: ${metrics} diff --git a/configs/model/heat_geoclip_reg.yaml b/configs/model/heat_geoclip_reg.yaml index a33b976..d29701c 100644 --- a/configs/model/heat_geoclip_reg.yaml +++ b/configs/model/heat_geoclip_reg.yaml @@ -8,10 +8,8 @@ _target_: src.models.predictive_model.PredictiveModel -eo_encoder: - _target_: src.models.components.eo_encoders.multimodal_encoder.MultiModalEncoder - use_coords: true - use_tabular: false +geo_encoder: + _target_: src.models.components.geo_encoders.geoclip.GeoClipCoordinateEncoder prediction_head: _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead diff --git a/configs/model/heat_tabular_reg.yaml b/configs/model/heat_tabular_reg.yaml index affadab..c7f9fff 100644 --- a/configs/model/heat_tabular_reg.yaml +++ b/configs/model/heat_tabular_reg.yaml @@ -10,17 +10,16 @@ # 1. HeatGuatemalaDataset.tabular_dim reads len(feat_names) from the CSV. # 2. BaseDataModule.tabular_dim delegates to the train dataset. # 3. PredictiveRegressionModel.setup() calls -# self.eo_encoder.build_tabular_branch(self.trainer.datamodule.tabular_dim) +# self.geo_encoder.build_tabular_branch(self.trainer.datamodule.tabular_dim) _target_: src.models.predictive_model.PredictiveModel metrics: ${metrics} -eo_encoder: - _target_: src.models.components.eo_encoders.multimodal_encoder.MultiModalEncoder - use_coords: false - use_tabular: true - tab_embed_dim: 64 +geo_encoder: + _target_: src.models.components.geo_encoders.tabular_encoder.TabularEncoder + output_dim: 64 + geo_data_name: tabular prediction_head: _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead @@ -28,7 +27,7 @@ prediction_head: hidden_dim: 256 # Both encoder and head have trainable parameters. 
-trainable_modules: [eo_encoder, prediction_head] +trainable_modules: [geo_encoder, prediction_head] optimizer: _target_: torch.optim.Adam diff --git a/configs/model/predictive_cnn_s2.yaml b/configs/model/predictive_cnn_s2.yaml index 628ba66..52c46d9 100644 --- a/configs/model/predictive_cnn_s2.yaml +++ b/configs/model/predictive_cnn_s2.yaml @@ -1,15 +1,15 @@ _target_: src.models.predictive_model.PredictiveModel -eo_encoder: - _target_: src.models.components.eo_encoders.cnn_encoder.CNNEncoder +geo_encoder: + _target_: src.models.components.geo_encoders.cnn_encoder.CNNEncoder resnet_version: 18 freezing_strategy: none - eo_data_name: s2 + geo_data_name: s2 prediction_head: _target_: src.models.components.pred_heads.mlp_pred_head.MLPPredictionHead -trainable_modules: [eo_encoder, prediction_head] +trainable_modules: [geo_encoder, prediction_head] metrics: ${metrics} diff --git a/configs/model/predictive_geoclip.yaml b/configs/model/predictive_geoclip.yaml index ca7390f..7d9c0c5 100644 --- a/configs/model/predictive_geoclip.yaml +++ b/configs/model/predictive_geoclip.yaml @@ -1,7 +1,7 @@ _target_: src.models.predictive_model.PredictiveModel -eo_encoder: - _target_: src.models.components.eo_encoders.geoclip.GeoClipCoordinateEncoder +geo_encoder: + _target_: src.models.components.geo_encoders.geoclip.GeoClipCoordinateEncoder prediction_head: _target_: src.models.components.pred_heads.mlp_pred_head.MLPPredictionHead diff --git a/configs/model/yield_fusion_reg.yaml b/configs/model/yield_fusion_reg.yaml new file mode 100644 index 0000000..05b680f --- /dev/null +++ b/configs/model/yield_fusion_reg.yaml @@ -0,0 +1,44 @@ +_target_: src.models.predictive_model.PredictiveModel + +geo_encoder: + _target_: src.models.components.geo_encoders.encoder_wrapper.EncoderWrapper + + encoder_branches: + - encoder: + _target_: src.models.components.geo_encoders.geoclip.GeoClipCoordinateEncoder + - encoder: + _target_: src.models.components.geo_encoders.tabular_encoder.TabularEncoder + 
output_dim: 256 + dropout_prob: 0.2 + geo_data_name: tabular + + fusion_strategy: "concat" + +prediction_head: + _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead + nn_layers: 2 + hidden_dim: 256 + dropout: 0.2 + +# GeoClip frozen; tabular projection + head are trained. +trainable_modules: [geo_encoder, prediction_head] +# Disable L2 normalization before MLP regression to keep magnitude of features +normalize_features: false + +metrics: ${metrics} + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.001 + weight_decay: 1e-4 + +scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + mode: min + factor: 0.1 + patience: 10 + +loss_fn: + _target_: src.models.components.loss_fns.huber_loss.HuberLoss diff --git a/configs/model/yield_geoclip_reg.yaml b/configs/model/yield_geoclip_reg.yaml new file mode 100644 index 0000000..20978e6 --- /dev/null +++ b/configs/model/yield_geoclip_reg.yaml @@ -0,0 +1,33 @@ +_target_: src.models.predictive_model.PredictiveModel + +geo_encoder: + _target_: src.models.components.geo_encoders.geoclip.GeoClipCoordinateEncoder + +prediction_head: + _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead + nn_layers: 2 + hidden_dim: 256 + dropout: 0.2 + +# Only the prediction head is trained; GeoClip encoder is frozen. 
+trainable_modules: [prediction_head] +# Disable L2 normalization before MLP regression to keep magnitude of features +normalize_features: false + +metrics: ${metrics} + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.001 + weight_decay: 1e-4 + +scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + mode: min + factor: 0.1 + patience: 10 + +loss_fn: + _target_: src.models.components.loss_fns.huber_loss.HuberLoss diff --git a/configs/model/yield_tabular_reg.yaml b/configs/model/yield_tabular_reg.yaml new file mode 100644 index 0000000..22d431a --- /dev/null +++ b/configs/model/yield_tabular_reg.yaml @@ -0,0 +1,36 @@ +_target_: src.models.predictive_model.PredictiveModel + +geo_encoder: + _target_: src.models.components.geo_encoders.tabular_encoder.TabularEncoder + output_dim: 256 + dropout_prob: 0.2 + geo_data_name: tabular + +prediction_head: + _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead + nn_layers: 2 + hidden_dim: 256 + dropout: 0.2 + +# Both encoder and head have trainable parameters. +trainable_modules: [geo_encoder, prediction_head] +# Disable L2 normalization before MLP regression to keep magnitude of features +normalize_features: false + +metrics: ${metrics} + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.001 + weight_decay: 1e-4 + +scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + mode: min + factor: 0.1 + patience: 10 + +loss_fn: + _target_: src.models.components.loss_fns.huber_loss.HuberLoss diff --git a/configs/model/yield_tessera_fusion_reg.yaml b/configs/model/yield_tessera_fusion_reg.yaml new file mode 100644 index 0000000..1e8c175 --- /dev/null +++ b/configs/model/yield_tessera_fusion_reg.yaml @@ -0,0 +1,54 @@ +_target_: src.models.predictive_model.PredictiveModel + +# MultiModalEncoder with a pluggable geo encoder. 
+# The geo_encoder_cfg replaces the hardcoded GeoClipCoordinateEncoder with +# AverageEncoder(tessera), so the spatial branch uses inter-annual phenology +# instead of static coordinate embeddings. + +geo_encoder: + _target_: src.models.components.geo_encoders.encoder_wrapper.EncoderWrapper + + encoder_branches: + - encoder: + _target_: src.models.components.geo_encoders.average_encoder.AverageEncoder + geo_data_name: tessera + projector: + _target_: src.models.components.geo_encoders.mlp_projector.MLPProjector + nn_layers: 1 + output_dim: 256 + - encoder: + _target_: src.models.components.geo_encoders.tabular_encoder.TabularEncoder + output_dim: 256 + dropout_prob: 0.2 + geo_data_name: tabular + + fusion_strategy: "concat" + +prediction_head: + _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead + nn_layers: 2 + hidden_dim: 256 + dropout: 0.2 + +# geo_encoder includes the tessera AverageEncoder + tabular branch; head is always trained. +trainable_modules: [geo_encoder, prediction_head] +# Disable L2 normalization before MLP regression to keep magnitude of features +normalize_features: false + +metrics: ${metrics} + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.001 + weight_decay: 1e-4 + +scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + mode: min + factor: 0.1 + patience: 10 + +loss_fn: + _target_: src.models.components.loss_fns.huber_loss.HuberLoss diff --git a/configs/model/yield_tessera_reg.yaml b/configs/model/yield_tessera_reg.yaml new file mode 100644 index 0000000..b5ef731 --- /dev/null +++ b/configs/model/yield_tessera_reg.yaml @@ -0,0 +1,34 @@ +_target_: src.models.predictive_model.PredictiveModel + +geo_encoder: + _target_: src.models.components.geo_encoders.average_encoder.AverageEncoder + geo_data_name: tessera + # output_dim defaults to 128 (the native tessera channel count); no projection. 
+ +prediction_head: + _target_: src.models.components.pred_heads.mlp_regression_head.MLPRegressionPredictionHead + nn_layers: 2 + hidden_dim: 256 + dropout: 0.2 + +trainable_modules: [geo_encoder, prediction_head] +# Disable L2 normalization before MLP regression to keep magnitude of features +normalize_features: false + +metrics: ${metrics} + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.001 + weight_decay: 1e-4 + +scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + mode: min + factor: 0.1 + patience: 10 + +loss_fn: + _target_: src.models.components.loss_fns.huber_loss.HuberLoss diff --git a/configs/paths/shared.yaml b/configs/paths/shared.yaml index 76e6046..a4702b0 100644 --- a/configs/paths/shared.yaml +++ b/configs/paths/shared.yaml @@ -5,8 +5,8 @@ root_dir: ${oc.env:PROJECT_ROOT,./} # path to data directory -data_dir: ${oc.env:DATA_DIR,oc.env:SHARED_ROOT/data/,${paths.root_dir}/data/} -cache_dir: ${oc.env:CACHE_DIR,${paths.data_dir}/cache} +data_dir: ${oc.env:DATA_DIR,${oc.env:SHARED_ROOT,${paths.root_dir}}/data} +cache_dir: ${oc.env:SHARED_CACHE,${paths.data_dir}/cache} # path to logging directory log_dir: ${oc.env:SHARED_ROOT}/logs/ diff --git a/data/s2bms/concept_captions/v1.json b/data/s2bms/concept_captions/v1.json new file mode 100644 index 0000000..323089f --- /dev/null +++ b/data/s2bms/concept_captions/v1.json @@ -0,0 +1,14 @@ +[ + { + "concept_caption": "Forested area", + "is_max": true, + "theta_k": 0.5, + "col": "aux_corine_frac_3" + }, + { + "concept_caption": "Sparsely populated area", + "is_max": false, + "theta_k": 0.2, + "col": "aux_corine_frac_11" + } +] diff --git a/data/s2bms/concept_captions/v2.json b/data/s2bms/concept_captions/v2.json new file mode 100644 index 0000000..9b2bdd5 --- /dev/null +++ b/data/s2bms/concept_captions/v2.json @@ -0,0 +1,103 @@ +[ + { + "concept_caption": "Densely populated area with many houses", + "is_max": true, + "theta_k": 0.3, + "col": 
"aux_corine_frac_11" + }, + { + "concept_caption": "Very sparsely populated area with few houses", + "is_max": false, + "theta_k": 0.05, + "col": "aux_corine_frac_11" + },{ + "concept_caption": "Area with infrastructure such as roads, railways, airport, ports and heavy industry.", + "is_max": true, + "theta_k": 0.1, + "col": "aux_corine_frac_12" + }, + { + "concept_caption": "Arable land with crops for agriculture", + "is_max": true, + "theta_k": 0.65, + "col": "aux_corine_frac_21" + }, + { + "concept_caption": "Pasture fields with grass for grazing animals", + "is_max": true, + "theta_k": 0.6, + "col": "aux_corine_frac_231" + }, + { + "concept_caption": "Agricultural land used for crops, pasture or mixed farming", + "is_max": true, + "theta_k": 0.05, + "col": "aux_corine_frac_24" + }, + { + "concept_caption": "Forested area with many trees", + "is_max": true, + "theta_k": 0.25, + "col": "aux_corine_frac_31" + }, + { + "concept_caption": "Scrub area with trees, shrub, moors.", + "is_max": true, + "theta_k": 0.2, + "col": "aux_corine_frac_32" + }, + { + "concept_caption": "Moorlands and heathlands with low vegetation", + "is_max": true, + "theta_k": 0.2, + "col": "aux_corine_frac_322" + }, + { + "concept_caption": "Wetlands such as marshes, swamps, mudflats and bogs.", + "is_max": true, + "theta_k": 0.2, + "col": "aux_corine_frac_4" + }, + { + "concept_caption": "Peat bogs", + "is_max": true, + "theta_k": 0.2, + "col": "aux_corine_frac_412" + }, + { + "concept_caption": "Water bodies such as lakes, rivers and sea", + "is_max": true, + "theta_k": 0.2, + "col": "aux_corine_frac_5" + }, + { + "concept_caption": "Warm area with high summer temperatures", + "is_max": true, + "theta_k": 22, + "col": "aux_bioclim_05" + }, + { + "concept_caption": "Cold area with low winter temperatures", + "is_max": false, + "theta_k": 0, + "col": "aux_bioclim_06" + }, + { + "concept_caption": "Wet area with a lot of rainfall", + "is_max": true, + "theta_k": 950, + "col": "aux_bioclim_12" 
+ }, + { + "concept_caption": "Remote area far from roads and infrastructure", + "is_max": true, + "theta_k": 1500, + "col": "aux_meandist_road" + }, + { + "concept_caption": "Densely populated area with many houses", + "is_max": true, + "theta_k": 1500, + "col": "aux_pop_density" + } +] diff --git a/data/s2bms/caption_templates/v1.json b/data/s2bms/location_caption_templates/v1.json similarity index 100% rename from data/s2bms/caption_templates/v1.json rename to data/s2bms/location_caption_templates/v1.json diff --git a/data/s2bms/caption_templates/v3.json b/data/s2bms/location_caption_templates/v3.json similarity index 100% rename from data/s2bms/caption_templates/v3.json rename to data/s2bms/location_caption_templates/v3.json diff --git a/data/s2bms/caption_templates/v4.json b/data/s2bms/location_caption_templates/v4.json similarity index 100% rename from data/s2bms/caption_templates/v4.json rename to data/s2bms/location_caption_templates/v4.json diff --git a/data/s2bms/caption_templates/v5.json b/data/s2bms/location_caption_templates/v5.json similarity index 100% rename from data/s2bms/caption_templates/v5.json rename to data/s2bms/location_caption_templates/v5.json diff --git a/notebooks/04-TvdP_generate-caption-templates.ipynb b/notebooks/04-TvdP_generate-caption-templates.ipynb index b5c90f1..39a6ef9 100644 --- a/notebooks/04-TvdP_generate-caption-templates.ipynb +++ b/notebooks/04-TvdP_generate-caption-templates.ipynb @@ -44,11 +44,11 @@ "metadata": {}, "outputs": [], "source": [ - "# tmp = cg.generate_captions(n=20, seed=0, save_path=os.path.join(os.environ['DATA_DIR'], 's2bms/caption_templates'))\n", + "# tmp = cg.generate_captions(n=20, seed=0, save_path=os.path.join(os.environ['DATA_DIR'], 's2bms/location_caption_templates'))\n", "tmp = cg.generate_captions(\n", " n=50,\n", " seed=0,\n", - " save_path=os.path.join(os.environ[\"PROJECT_ROOT\"], \"data/s2bms/caption_templates\"),\n", + " save_path=os.path.join(os.environ[\"PROJECT_ROOT\"], 
\"data/s2bms/location_caption_templates\"),\n", ")" ] } diff --git a/notebooks/06-TvdP-inference-language-alignment.ipynb b/notebooks/06-TvdP-inference-language-alignment.ipynb index 4c4ab73..227578b 100644 --- a/notebooks/06-TvdP-inference-language-alignment.ipynb +++ b/notebooks/06-TvdP-inference-language-alignment.ipynb @@ -119,8 +119,8 @@ "metadata": {}, "outputs": [], "source": [ - "from src.models.components.eo_encoders.cnn_encoder import CNNEncoder\n", - "from src.models.components.eo_encoders.geoclip import GeoClipCoordinateEncoder\n", + "from src.models.components.geo_encoders.cnn_encoder import CNNEncoder\n", + "from src.models.components.geo_encoders.geoclip import GeoClipCoordinateEncoder\n", "from src.models.components.loss_fns.bce_loss import BCELoss\n", "from src.models.components.loss_fns.clip_loss import ClipLoss\n", "from src.models.components.pred_heads.mlp_pred_head import MLPPredictionHead\n", diff --git a/pyproject.toml b/pyproject.toml index 4c51d60..f49df39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "peft>=0.18.1", "llm2vec", "setuptools<81", + "geotessera>=0.7.3", ] [project.optional-dependencies] @@ -40,9 +41,6 @@ create-data = [ "geemap>=0.36.6", "pipreqs>=0.5.0", ] -geotessera = [ - "geotessera>=0.7.3", -] [tool.pytest.ini_options] addopts = [ diff --git a/scripts/schedule.sh b/scripts/schedule.sh index 8054b63..8bab438 100644 --- a/scripts/schedule.sh +++ b/scripts/schedule.sh @@ -1,23 +1,27 @@ #!/bin/bash -#SBATCH--cpus-per-task=8 -#SBATCH--partition=gpu -#SBATCH--gpus=1 -#SBATCH--job-name=aether -#SBATCH--mem=100G -#SBATCH--time=100 +#SBATCH --cpus-per-task=8 +#SBATCH --partition=gpu +#SBATCH --gpus=1 +#SBATCH --job-name=aether +#SBATCH --mem=100G +#SBATCH --time=100 +#SBATCH --output=logs/out_%j.out +#SBATCH --error=logs/err_%j.err # Schedule execution of many runs # Run from root folder with: bash scripts/schedule.sh # Variables +# shellcheck disable=SC1091 source .env -# Environment 
+# Environment +# shellcheck disable=SC1091 source .venv/bin/activate # Runs #srun python src/train.py experiment=alignment trainer=$TRAINER_PROFILE logger=$LOGGER -#srun python src/train.py experiment=prediction logger=wandb +srun python -u src/train.py experiment=alignment_v1 # example runs with overwritten configs #srun python src/train.py experiment=alignment trainer=ddp_sim trainer.max_epochs=10 data.pin_memory=false diff --git a/src/data/base_caption_builder.py b/src/data/base_caption_builder.py index 3ca4b77..54a020f 100644 --- a/src/data/base_caption_builder.py +++ b/src/data/base_caption_builder.py @@ -11,7 +11,9 @@ class BaseCaptionBuilder(ABC): - def __init__(self, templates_fname: str, data_dir: str, seed: int) -> None: + def __init__( + self, templates_fname: str, concepts_fname: str, data_dir: str, seed: int + ) -> None: """Interface of caption builder class for converting numerical auxiliary data values into textual descriptions from provided caption templates. @@ -21,10 +23,13 @@ def __init__(self, templates_fname: str, data_dir: str, seed: int) -> None: """ self.data_dir = data_dir - templates_path = os.path.join(self.data_dir, "caption_templates", templates_fname) + templates_path = os.path.join(self.data_dir, "location_caption_templates", templates_fname) self.templates = json.load(open(templates_path)) self.tokens_in_template = [self._extract_tokens(t) for t in self.templates] + concepts_path = os.path.join(self.data_dir, "concept_captions", concepts_fname) + self.concepts = json.load(open(concepts_path)) + self.column_to_metadata_map: Dict[str] | None = None self.seed = seed random.seed(self.seed) @@ -42,7 +47,9 @@ def sync_with_dataset(self, dataset: BaseDataset) -> None: @staticmethod def _extract_tokens(template: str) -> List[str]: """Extract tokens in template and return a list of tokens.""" - return re.findall(r"<([^<>]+)>", template) + tokens = re.findall(r"<([^<>]+)>", template) + # TODO: check if those columns exist in the dataset maps + 
return tokens @staticmethod def _fill(template: str, fillers: Dict[str, str]) -> str: @@ -96,15 +103,18 @@ def all(self, aux_values) -> List[str]: return formatted_rows - def build_concepts(self, aux_values) -> List[str]: - pass + def sync_concepts(self) -> None: + for concept in self.concepts: + concept["id"] = self.column_to_metadata_map["aux"][concept["col"]]["id"] class DummyCaptionBuilder(BaseCaptionBuilder): """Dummy caption builder for testing purposes.""" - def __init__(self, templates_fname: str, data_dir: str, seed: int) -> None: - super().__init__(templates_fname, data_dir, seed) + def __init__( + self, templates_fname: str, concepts_fname: str, data_dir: str, seed: int + ) -> None: + super().__init__(templates_fname, concepts_fname, data_dir, seed) def sync_with_dataset(self, dataset) -> None: pass diff --git a/src/data/base_datamodule.py b/src/data/base_datamodule.py index 3b639ec..1ef0a48 100644 --- a/src/data/base_datamodule.py +++ b/src/data/base_datamodule.py @@ -1,12 +1,12 @@ import copy import os +import time from functools import partial from typing import Any, Dict, List, Tuple import numpy as np import pandas as pd import torch -from geopy.distance import distance as geodist # avoid naming confusion from lightning import LightningDataModule from sklearn.cluster import DBSCAN from sklearn.model_selection import GroupShuffleSplit @@ -63,6 +63,7 @@ def __init__( assert caption_builder is not None, "Caption_builder cannot be None" self.caption_builder = caption_builder self.caption_builder.sync_with_dataset(self.dataset) + self.concept_configs = caption_builder.concepts self.split_data() @@ -120,20 +121,27 @@ def split_data(self) -> None: } elif self.hparams.split_mode == "spatial_clusters": - print("Splitting dataset using spatial clusters. This can take a while...") - coords = np.array([self.dataset.df.lat, self.dataset.df.lon]).T - if len(coords) > 2000: - print( - "Warning: DBSCAN clustering on more than 2000 samples may be slow. 
Maybe set n_jobs in DBScan?" - ) - # 4000 m distance between points. Use geodist to calculate true distance. min_dist = self.hparams.spatial_split_distance_m + coords = np.array([self.dataset.df.lat, self.dataset.df.lon]).T + n = len(coords) + print( + f"Splitting {n} samples into spatial clusters " + f"(eps={min_dist / 1000:.1f} km, haversine, n_jobs=-1)..." + ) + # Convert (lat, lon) degrees to radians for sklearn's haversine metric. + # haversine returns arc length on the unit sphere, so eps must be in radians. + _EARTH_RADIUS_M = 6_371_000 + coords_rad = np.radians(coords) + eps_rad = min_dist / _EARTH_RADIUS_M + t0 = time.time() clustering = DBSCAN( - eps=min_dist, - metric=lambda u, v: geodist(u, v).meters, + eps=eps_rad, + metric="haversine", + algorithm="ball_tree", min_samples=2, - ).fit(coords) - print("Clustering done. Creating splits and saving.") + n_jobs=-1, + ).fit(coords_rad) + print(f"DBSCAN done in {time.time() - t0:.1f}s. Creating splits...") # Non-clustered points are labeled -1. Change to new cluster label. 
clusters = copy.deepcopy(clustering.labels_) new_cl = np.max(clusters) + 1 diff --git a/src/data/butterfly_caption_builder.py b/src/data/butterfly_caption_builder.py index 66b60a3..d6dee33 100644 --- a/src/data/butterfly_caption_builder.py +++ b/src/data/butterfly_caption_builder.py @@ -16,8 +16,10 @@ class ButterflyCaptionBuilder(BaseCaptionBuilder): - def __init__(self, templates_fname: str, data_dir: str, seed: int): - super().__init__(templates_fname, data_dir, seed) + def __init__( + self, templates_fname: str, concepts_fname: str, data_dir: str, seed: int + ) -> None: + super().__init__(templates_fname, concepts_fname, data_dir, seed) @override def sync_with_dataset(self, dataset: BaseDataset) -> None: @@ -42,6 +44,8 @@ def sync_with_dataset(self, dataset: BaseDataset) -> None: "units": units, } + self.sync_concepts() + def get_corine_column_keys(self): """Returns metadata for corine columns.""" if not os.path.isfile(os.path.join(self.data_dir, "corine_classes.csv")): diff --git a/src/data/butterfly_dataset.py b/src/data/butterfly_dataset.py index 422f8c2..5ecacf7 100644 --- a/src/data/butterfly_dataset.py +++ b/src/data/butterfly_dataset.py @@ -17,7 +17,7 @@ def __init__( data_dir: str, modalities: dict, use_target_data: bool = True, - use_aux_data: Dict[str, List[str] | str] | None = None, + use_aux_data: Any = None, seed: int = 12345, cache_dir: str = None, mock: bool = False, @@ -28,7 +28,7 @@ def __init__( :param modalities: a dict of modalities needed as EO data (for EO encoder) (e.g., {"coords": None, "s2": {"channels": "rgb", "preprocessing": "zscored"}}) :param use_target_data: if target values should be returned - :param use_aux_data: if auxiliary values should be returned + :param use_aux_data: which (if any) auxiliary values should be returned :param seed: random seed :param cache_dir: path to cache dir :param mock: whether to mock csv file diff --git a/src/data/collate_fns.py b/src/data/collate_fns.py index 4e37844..2c001b4 100644 --- 
a/src/data/collate_fns.py +++ b/src/data/collate_fns.py @@ -39,10 +39,7 @@ def collate_fn( # convert aux into captions if mode == "train": batch_collected["text"] = caption_builder.random(batch_collected["aux"]) - elif mode == "val": - batch_collected["text"] = caption_builder.all(batch_collected["aux"]) else: batch_collected["text"] = caption_builder.all(batch_collected["aux"]) - # batch_collected['concepts'] = caption_builder.build_concepts(batch_collected["aux"]) return batch_collected diff --git a/src/data/heat_guatemala_dataset.py b/src/data/heat_guatemala_dataset.py index c20b686..f4fabac 100644 --- a/src/data/heat_guatemala_dataset.py +++ b/src/data/heat_guatemala_dataset.py @@ -56,7 +56,7 @@ def __init__( dataset_name="heat_guatemala", seed=seed, cache_dir=cache_dir, - implemented_mod={"coords"}, + implemented_mod={"coords", "tessera"}, mock=mock, use_features=use_features, ) @@ -67,6 +67,14 @@ def __init__( def setup(self) -> None: """No files to download / prepare for this dataset.""" + # Set up each requested modality + for mod in self.modalities.keys(): + if mod == "coords" and len(self.modalities.keys()) == 1: + return + elif mod == "tessera": + self.setup_tessera() + # elif mod == "aef": + # self.setup_aef() return @override diff --git a/src/data/yield_africa_dataset.py b/src/data/yield_africa_dataset.py new file mode 100644 index 0000000..5456bf2 --- /dev/null +++ b/src/data/yield_africa_dataset.py @@ -0,0 +1,238 @@ +"""Yield Africa dataset. + +Location: src/data/yield_africa_dataset.py + +Crop yield regression use case for East/Southern Africa. +Tabular features (soil, climate, etc.) live in the model-ready CSV as feat_* +columns and are picked up automatically by BaseDataset.get_records(). +They do NOT need to be listed in `modalities`. 
+""" + +import logging +import os +from typing import Any, Dict, List, override + +import numpy as np +import pandas as pd +import torch + +from src.data.base_dataset import BaseDataset + +# Number of channels in a TESSERA embedding tile (fixed by the geotessera model). +_TESSERA_CHANNELS = 128 + +log = logging.getLogger(__name__) + +# Fixed ordered list of all countries in the full dataset. +# Used to produce a consistent one-hot encoding regardless of which +# countries are present after filtering. +_ALL_COUNTRIES = ["BF", "BUR", "ETH", "KEN", "MAL", "RWA", "TAN", "ZAM"] + +# Study-area bounds used to normalise coordinates before computing Fourier +# harmonics. Normalising to the actual data extent (rather than ±90°/±180°) +# makes the harmonics maximally discriminative within the dataset. +# Latitude : 30°S – 15°N → centre −7.5°, half-range 22.5° +# Longitude : 10°E – 45°E → centre 27.5°, half-range 17.5° +_LAT_CENTER = -7.5 +_LAT_HALF_RANGE = 22.5 +_LON_CENTER = 27.5 +_LON_HALF_RANGE = 17.5 + + +class YieldAfricaDataset(BaseDataset): + """Dataset for the crop yield regression use case (East/Southern Africa). + + CSV layout expected: + - name_loc : unique location identifier + - lat, lon : WGS84 coordinates + - target_* : crop yield target(s) [t/ha] + - feat_* : tabular features (soil properties, climate indices, etc.) + - aux_* : auxiliary data columns (optional) + - country, year : metadata columns used for optional filtering + + Modality design note + -------------------- + `implemented_mod = {"coords", "tessera"}`; tabular features live directly in + the model-ready CSV and are picked up via the `feat_` column prefix. + They do NOT need to be listed in `modalities`. 
+ + In addition to the CSV feat_* columns, the following features are injected: + - ``feat_year`` : normalised year (zero-mean, unit-std) + - ``feat_country_{CODE}`` : one-hot country encoding (always 8 columns, + stable across country filters) + - ``feat_lat_sin1/cos1`` : fundamental latitude harmonic, normalised to + the study-area extent (30°S–15°N) + - ``feat_lat_sin2/cos2`` : second latitude harmonic (captures bimodal vs. + unimodal rainfall boundary near the equator) + - ``feat_lon_sin1/cos1`` : fundamental longitude harmonic, normalised to + the study-area extent (10°E–45°E) + + The Fourier harmonics encode the ITCZ-driven latitudinal climate gradient at + interpretable frequencies, complementing GeoCLIP's photo-derived coordinate + embedding and enabling richer text captions for the explainability component. + """ + + def __init__( + self, + data_dir: str, + modalities: dict, + use_target_data: bool = True, + use_aux_data: Dict[str, Any] | str = None, + seed: int = 12345, + cache_dir: str = None, + mock: bool = False, + use_features: bool = True, + countries: List[str] | None = None, + years: List[int] | None = None, + exclude_countries: List[str] | None = None, + exclude_years: List[int] | None = None, + ) -> None: + super().__init__( + data_dir=data_dir, + modalities=modalities, + use_target_data=use_target_data, + use_aux_data=use_aux_data, + dataset_name="yield_africa", + seed=seed, + cache_dir=cache_dir, + implemented_mod={"coords", "tessera"}, + mock=mock, + use_features=use_features, + ) + + # Inject year and country one-hot columns as feat_* so that + # get_records() picks them up automatically. Build all new columns in + # one concat to avoid pandas PerformanceWarning from repeated assignment. + if use_features and "year" in self.df.columns and "country" in self.df.columns: + # Normalise feat_year to the same scale as the pre-scaled CSV feat_* columns + # (roughly zero-mean, unit-std) so it doesn't dominate LayerNorm. 
+ _YEAR_MEAN = 2018.0 + _YEAR_STD = 2.0 + new_cols: Dict[str, Any] = { + "feat_year": (self.df["year"].astype(float) - _YEAR_MEAN) / _YEAR_STD + } + for code in _ALL_COUNTRIES: + new_cols[f"feat_country_{code}"] = (self.df["country"] == code).astype(float) + + # Fourier harmonics of coordinates, normalised to the study-area extent. + # + # Africa's agricultural patterns follow the ITCZ-driven latitudinal climate + # gradient: rainfall regime (uni- vs. bimodal), growing-season length, and + # temperature vary sinusoidally with latitude. Explicit harmonics give the + # model these signals directly and at interpretable frequencies, complementing + # GeoCLIP's learned (but photo-derived) coordinate embedding. + # + # lat_norm / lon_norm ∈ [-1, 1] within the study area; π * norm ∈ [-π, π]. + # Two harmonics for latitude (captures both the broad N-S gradient and the + # equatorial-bimodal / southern-unimodal boundary); one for longitude + # (east-west Indian Ocean moisture gradient). + lat_norm = (self.df["lat"].astype(float) - _LAT_CENTER) / _LAT_HALF_RANGE + lon_norm = (self.df["lon"].astype(float) - _LON_CENTER) / _LON_HALF_RANGE + new_cols["feat_lat_sin1"] = np.sin(np.pi * lat_norm) + new_cols["feat_lat_cos1"] = np.cos(np.pi * lat_norm) + new_cols["feat_lat_sin2"] = np.sin(2.0 * np.pi * lat_norm) + new_cols["feat_lat_cos2"] = np.cos(2.0 * np.pi * lat_norm) + new_cols["feat_lon_sin1"] = np.sin(np.pi * lon_norm) + new_cols["feat_lon_cos1"] = np.cos(np.pi * lon_norm) + + self.df = pd.concat([self.df, pd.DataFrame(new_cols, index=self.df.index)], axis=1) + + # Apply country/year filters to self.df and rebuild records. + # BaseDataset.__init__ has already loaded the CSV; filtering here avoids + # touching BaseDataset and keeps the logic use-case specific. 
+ n_before = len(self.df) + if countries is not None and "country" in self.df.columns: + self.df = self.df[self.df["country"].isin(countries)].reset_index(drop=True) + if years is not None and "year" in self.df.columns: + self.df = self.df[self.df["year"].isin(years)].reset_index(drop=True) + if exclude_countries is not None and "country" in self.df.columns: + self.df = self.df[~self.df["country"].isin(exclude_countries)].reset_index(drop=True) + if exclude_years is not None and "year" in self.df.columns: + self.df = self.df[~self.df["year"].isin(exclude_years)].reset_index(drop=True) + + n_after = len(self.df) + if n_after != n_before: + log.info( + f"Country/year filter: {n_before} → {n_after} records ({n_before - n_after} excluded)" + ) + + # get_records() mutates self.use_aux_data in place (replacing pattern + # dicts with resolved column-name lists), so reset it from the original + # parameter before calling it a second time. + if use_aux_data is None or use_aux_data == "all": + self.use_aux_data = { + "aux": {"pattern": "^aux_(?!.*top).*"}, + "top": {"pattern": "^aux_.*top.*"}, + } + elif isinstance(use_aux_data, dict): + self.use_aux_data = use_aux_data + else: + self.use_aux_data = None + + # Always rebuild so feat_year / feat_country_* are reflected in + # self.feat_names and self.tabular_dim. + self.records = self.get_records() + + def setup(self) -> None: + """Check for requested modality data; warn if TESSERA tiles are absent. + + Unlike other datasets, TESSERA tiles for yield_africa vary per record + year and must be pre-fetched with the preprocessing script: + python src/data_preprocessing/yield_africa_tessera_preprocess.py + + setup_tessera() is intentionally not called here because it uses a + single fixed year for bulk download, which is incompatible with the + multi-year nature of this dataset. 
+ """ + if "tessera" in self.modalities: + tessera_dir = os.path.join(self.data_dir, "eo", "tessera") + if not os.path.exists(tessera_dir) or len(os.listdir(tessera_dir)) == 0: + log.warning( + "TESSERA tiles not found at %s. " + "Run src/data_preprocessing/yield_africa_tessera_preprocess.py " + "to pre-fetch tiles. Missing tiles will fall back to zero tensors.", + tessera_dir, + ) + + @override + def __getitem__(self, idx: int) -> Dict[str, Any]: + row = self.records[idx] + sample: Dict[str, Any] = {"eo": {}} + + for modality in self.modalities: + if modality == "coords": + sample["eo"]["coords"] = torch.tensor( + [row["lat"], row["lon"]], dtype=torch.float32 + ) + elif modality == "tessera": + tile_path = row["tessera_path"] + if os.path.exists(tile_path): + sample["eo"]["tessera"] = self.load_npy(tile_path) + else: + size = self.modalities["tessera"].get("size", 9) + log.debug("TESSERA tile missing: %s — using zero fallback.", tile_path) + sample["eo"]["tessera"] = torch.zeros( + _TESSERA_CHANNELS, size, size, dtype=torch.float32 + ) + + if self.use_features and self.feat_names: + sample["eo"]["tabular"] = torch.tensor( + [row[k] for k in self.feat_names], dtype=torch.float32 + ) + + if self.use_target_data: + sample["target"] = torch.tensor( + [row[k] for k in self.target_names], dtype=torch.float32 + ) + + if self.use_aux_data: + sample["aux"] = {} + for aux_cat, vals in self.use_aux_data.items(): + if aux_cat == "aux": + sample["aux"][aux_cat] = torch.tensor( + [row[v] for v in vals], dtype=torch.float32 + ) + else: + sample["aux"][aux_cat] = [row[v] for v in vals] + + return sample diff --git a/src/data_preprocessing/gee_utils.py b/src/data_preprocessing/gee_utils.py index def17a9..f1f73bf 100644 --- a/src/data_preprocessing/gee_utils.py +++ b/src/data_preprocessing/gee_utils.py @@ -250,8 +250,8 @@ def get_distance_to_road_within_aoi(aoi, cell_size=30, radius_max=5000): reducer=ee.Reducer.mean(), geometry=aoi, scale=cell_size, maxPixels=1e9 ) return { - 
"maxdist_road": int(max_distance.get("distance").getInfo()), - "meandist_road": int(mean_distance.get("distance").getInfo()), + "maxdist_road": int(max_distance.get("distance").getInfo() or radius_max), + "meandist_road": int(mean_distance.get("distance").getInfo() or radius_max), } diff --git a/src/data_preprocessing/make_model_ready_yield_africa.py b/src/data_preprocessing/make_model_ready_yield_africa.py new file mode 100644 index 0000000..f2b6dd0 --- /dev/null +++ b/src/data_preprocessing/make_model_ready_yield_africa.py @@ -0,0 +1,732 @@ +"""Build model-ready CSV/Parquet for the crop yield Africa use case +(data/yield_africa/model_ready_yield-africa.csv). + +Features: +- Load raw dataset (CSV or Parquet) +- Compute derived features (CN_ratio, layer deltas, WHC proxy, aridity index) +- Apply log transforms to skewed features +- Fit StandardScaler on train split only +- Encode categorical features as integer indices +- Remove yield outliers beyond 3 IQR +- Preserve metadata columns +- Save fitted transformers for inference-time reuse +- Calculate and save spatial cross-validation splits +""" + +import argparse +import json +import logging +import warnings +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import joblib +import numpy as np +import pandas as pd +import torch +from sklearn.preprocessing import LabelEncoder, StandardScaler + +log = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +MODEL_READY_DATA_NAME = "yield_africa" + +TRAIN_COUNTRIES = ["BUR", "ETH", "KEN", "MAL", "RWA", "TAN", "ZAM"] + +SPATIAL_SPLIT_BLOCK_SIZE_KM = 50.0 +SPATIAL_SPLIT_N_SPLITS = 7 + +CONTINUOUS_FEATURES = [ + # Soil features + "C_0_20", + "C_20_50", + "N_0_20", + "N_20_50", + "P_0_20", + "P_20_50", + "MA_0_20", + "MA_20_50", + "PO_0_20", + "PO_20_50", + "pH_0_20", + "pH_20_50", + "BD_0_20", + "BD_20_50", 
+ "ECX_0_20", + "ECX_20_50", + "CA_0_20", + "CA_20_50", + # Climate features + "PrecJJA", + "PrecMAM", + "PrecSON", + "PrecDJF", + "TaveJJA", + "TaveMAM", + "TaveSON", + "TaveDJF", + "TmaxJJA", + "TmaxMAM", + "TmaxSON", + "TmaxDJF", + "TminJJA", + "TminMAM", + "TminSON", + "TminDJF", + "CMD", + "Eref", + "MAP", + "MAT", + "TD", + "MWMT", + "MCMT", + "DD_above_5", + "DD_above_18", + "DD_below_18", + # Terrain features + "DEM", + "Slope", + "Aspect", + "CHILI", + "Top_div", + # Land cover / context + "Tree_c", + "Dist_water", + "Paved", + "Unpaved", + "Pop_10km", + # Derived features (computed automatically) + "CN_ratio", + "C_layer_delta", + "BD_layer_delta", + "WHC_proxy", + "aridity_index", +] + +# Categorical columns that are actual tabular inputs to the regression model (feat_ prefix). +# TX_*_cl are soil texture classes — they are not derived from a paired numerical column. +TABULAR_CATEGORICAL_FEATURES = [ + "TX_0_20_cl", + "TX_20_50_cl", +] + +# Categorical columns derived from their paired numerical columns (same name without _cl). +# Used for caption generation (aux_ prefix). 
AUX_FEATURES = [
    # target (classified)
    "Yld_ton_ha_cl",
    # soil features
    "C_0_20_cl",
    "C_20_50_cl",
    "N_0_20_cl",
    "N_20_50_cl",
    "P_0_20_cl",
    "P_20_50_cl",
    "MA_0_20_cl",
    "MA_20_50_cl",
    "PO_0_20_cl",
    "PO_20_50_cl",
    "pH_0_20_cl",
    "pH_20_50_cl",
    "BD_0_20_cl",
    "BD_20_50_cl",
    "ECX_0_20_cl",
    "ECX_20_50_cl",
    "CA_0_20_cl",
    "CA_20_50_cl",
    # climate features
    "PrecJJA_cl",
    "PrecMAM_cl",
    "PrecSON_cl",
    "PrecDJF_cl",
    "TaveJJA_cl",
    "TaveMAM_cl",
    "TaveSON_cl",
    "TaveDJF_cl",
    "TmaxJJA_cl",
    "TmaxMAM_cl",
    "TmaxSON_cl",
    "TmaxDJF_cl",
    "TminJJA_cl",
    "TminMAM_cl",
    "TminSON_cl",
    "TminDJF_cl",
    "CMD_cl",
    "Eref_cl",
    "MAP_cl",
    "MAT_cl",
    "TD_cl",
    "MCMT_cl",
    "MWMT_cl",
    "DD_above_5_cl",
    "DD_above_18_cl",
    "DD_below_18_cl",
    # terrain features
    "DEM_cl",
    "Slope_cl",
    "Aspect_cl",
    "Landform_cl",
    "CHILI_cl",
    "Top_div_cl",
    # land cover / context
    "GLAD_cl",
    "Tree_c_cl",
    "Dist_water_cl",
    "Paved_cl",
    "Unpaved_cl",
    "Pop_10km_cl",
]

# Right-skewed count/distance-like columns that receive log1p in apply_log_transforms().
LOG_TRANSFORM_FEATURES = ["Dist_water", "Paved", "Unpaved", "Pop_10km"]

# Regression target column(s) in the raw CSV (renamed to target_* downstream).
TARGET_COLUMNS = ["Yld_ton_ha"]

# Raw-CSV column holding the per-record location identifier (renamed to "name_loc").
NAME_LOC_COLUMN = "ID"

# Columns kept as plain (lowercased) metadata — no feat_/aux_/target_ prefix.
METADATA_COLUMNS = ["Lat", "Lon", "Country", "Year", "Location_accuracy"]

# Saxton & Rawls (2006) Table 3-derived AWC (FC - WP)
# Conditions: ~2.5% OM, no salinity/gravel/density adjustment
# "Plant avail." (%v) converted to mm/m via mm/m = (%v) * 10
# (because 1% v/v = 0.01 m³/m³ = 10 mm per m soil)
# Values are in mm/m (approximate field capacity - wilting point)
WHC_LOOKUP_SAXTON_RAWLS_2006_OM2P5 = {
    "Sand": 50,
    "Loamy sand": 70,
    "Sandy loam": 100,
    "Loam": 140,
    "Silt loam": 200,
    "Silt": 250,
    "Sandy clay loam": 100,
    "Clay loam": 140,
    "Silty clay loam": 170,
    "Silty clay": 140,
    "Sandy clay": 110,
    "Clay": 120,
}

# ---------------------------------------------------------------------------
# Preprocessing functions
# ---------------------------------------------------------------------------


def build_column_rename_map(
    continuous_features: List[str],
    tabular_categorical_features: List[str],
    aux_features: List[str],
    target_columns: List[str],
    name_loc_column: str | None,
    metadata_columns: List[str],
) -> Dict[str, str]:
    """Build a column rename mapping that standardises predictor and target names.

    Convention:
        - Numerical predictors and tabular categorical features: ``feat_{original.lower()}``
        - Aux/caption features (derived categorical classes): ``aux_{original.lower()}``
        - Target columns: ``target_{original.lower()}``
        - Name-location column: ``name_loc`` (skipped when ``name_loc_column`` is None)
        - Metadata columns: ``{original.lower()}``

    Columns listed in more than one group are assigned the prefix of the last
    group processed (targets/metadata/name_loc override earlier entries).
    """
    rename: Dict[str, str] = {}
    for col in continuous_features:
        rename[col] = f"feat_{col.lower()}"
    for col in tabular_categorical_features:
        rename[col] = f"feat_{col.lower()}"
    for col in aux_features:
        rename[col] = f"aux_{col.lower()}"
    for col in target_columns:
        rename[col] = f"target_{col.lower()}"
    for col in metadata_columns:
        rename[col] = col.lower()
    if name_loc_column is not None:
        rename[name_loc_column] = "name_loc"
    return rename


def compute_derived_features(df: pd.DataFrame) -> pd.DataFrame:
    """Compute derived features from raw measurements.

    Adds: ``CN_ratio``, ``C_layer_delta``, ``BD_layer_delta``, ``WHC_proxy``,
    ``aridity_index``. Division-based features are guarded against zero
    denominators (NaN is produced instead). The input frame is not mutated.
    """
    df = df.copy()

    # C:N ratio (guard against division by zero)
    df["CN_ratio"] = np.where(
        df["N_0_20"] > 0,
        df["C_0_20"] / df["N_0_20"],
        np.nan,
    )

    # Layer deltas (stratification indicators)
    df["C_layer_delta"] = df["C_0_20"] - df["C_20_50"]
    df["BD_layer_delta"] = df["BD_0_20"] - df["BD_20_50"]

    # Water Holding Capacity proxy from texture lookup, adjusted by bulk density
    # Lookup is case-insensitive; unknown texture classes fall back to "Sandy loam".
    _whc_lookup_lower = {k.lower(): v for k, v in WHC_LOOKUP_SAXTON_RAWLS_2006_OM2P5.items()}
    df["WHC_proxy"] = (
        df["TX_0_20_cl"]
        .astype(str)
        .str.lower()
        .map(_whc_lookup_lower)
        .fillna(WHC_LOOKUP_SAXTON_RAWLS_2006_OM2P5["Sandy loam"])
    )

    # Adjust WHC by bulk density (inverse relationship, reference BD = 1.3 g/cm³)
    bd_factor = np.where(df["BD_0_20"] > 0, 1.3 / df["BD_0_20"], 1.0)
    df["WHC_proxy"] = df["WHC_proxy"] * bd_factor

    # Aridity index (guard against MAP=0)
    df["aridity_index"] = np.where(df["MAP"] > 0, df["CMD"] / df["MAP"], np.nan)

    return df


def apply_log_transforms(df: pd.DataFrame, log_transform_features: List[str]) -> pd.DataFrame:
    """Apply log(x + 1) transform to skewed features.

    Values are clipped at 0 before the transform so negative inputs cannot
    produce NaN. Missing columns are silently skipped; the input frame is
    not mutated.
    """
    df = df.copy()
    for col in log_transform_features:
        if col in df.columns:
            df[col] = np.log1p(np.maximum(df[col], 0))
    return df


def remove_yield_outliers(
    df: pd.DataFrame,
    target_col: str = "Yld_ton_ha",
    iqr_multiplier: float = 3.0,
) -> Tuple[pd.DataFrame, pd.Series]:
    """Remove yield outliers beyond IQR threshold.

    :param df: input frame; returned frame is a copy with outlier rows dropped
    :param target_col: column to screen; if absent, a warning is issued and the
        frame is returned unchanged with an all-False mask
    :param iqr_multiplier: rows outside [Q1 - m*IQR, Q3 + m*IQR] are removed
    :return: (filtered frame, boolean outlier mask aligned to the input index)
    """
    if target_col not in df.columns:
        warnings.warn(f"Target column '{target_col}' not found; skipping outlier removal")
        return df, pd.Series([False] * len(df), index=df.index)

    q1 = df[target_col].quantile(0.25)
    q3 = df[target_col].quantile(0.75)
    iqr = q3 - q1
    lower_bound = q1 - iqr_multiplier * iqr
    upper_bound = q3 + iqr_multiplier * iqr
    outlier_mask = (df[target_col] < lower_bound) | (df[target_col] > upper_bound)

    n_outliers = outlier_mask.sum()
    if n_outliers > 0:
        log.info(
            f"Removing {n_outliers} yield outliers (< {lower_bound:.2f} or > {upper_bound:.2f} t/ha)"
        )

    return df[~outlier_mask].copy(), outlier_mask


def fit_scaler(df: pd.DataFrame, continuous_features: List[str]) -> StandardScaler:
    """Fit StandardScaler on continuous features.

    Columns listed in ``continuous_features`` but absent from ``df`` trigger a
    warning and are excluded from the fit.
    """
    available_features = [f for f in continuous_features if f in df.columns]
    if len(available_features) < len(continuous_features):
        missing = set(continuous_features) - set(available_features)
        warnings.warn(f"Missing continuous features: {missing}")
    scaler = StandardScaler()
    scaler.fit(df[available_features])
    return scaler


def apply_scaler(
    df: pd.DataFrame,
    scaler: StandardScaler,
    continuous_features: List[str],
) -> pd.DataFrame:
    """Apply fitted StandardScaler to continuous features.

    Only columns present in ``df`` are transformed; the input frame is not mutated.
    """
    df = df.copy()
    available_features = [f for f in continuous_features if f in df.columns]
    df[available_features] = scaler.transform(df[available_features])
    return df


def fit_label_encoders(
    df: pd.DataFrame,
    categorical_features: List[str],
) -> Dict[str, LabelEncoder]:
    """Fit LabelEncoders for categorical features.

    NaN values are dropped before fitting, so they are never part of an
    encoder's class set. Missing columns are warned about and skipped.
    """
    encoders = {}
    for col in categorical_features:
        if col not in df.columns:
            warnings.warn(f"Categorical feature '{col}' not found; skipping")
            continue
        encoder = LabelEncoder()
        encoder.fit(df[col].dropna())
        encoders[col] = encoder
        log.info(f" {col}: {len(encoder.classes_)} classes")
    return encoders


def apply_label_encoders(
    df: pd.DataFrame,
    encoders: Dict[str, LabelEncoder],
    categorical_features: List[str],
) -> pd.DataFrame:
    """Apply fitted LabelEncoders to categorical features.

    Values unseen at fit time (including NaN) are mapped to -1 rather than
    raising; the input frame is not mutated.
    """
    df = df.copy()
    for col in categorical_features:
        if col not in df.columns:
            continue
        if col not in encoders:
            warnings.warn(f"No encoder found for '{col}'; skipping")
            continue
        encoder = encoders[col]
        df[col] = df[col].apply(
            lambda x: encoder.transform([x])[0] if x in encoder.classes_ else -1
        )
    return df


def calculate_spatial_splits(
    df: pd.DataFrame,
    block_size_km: float = 50.0,
    n_splits: int = 7,
    lat_col: str = "lat",
    lon_col: str = "lon",
    name_loc_col: str = "name_loc",
    save_path: str | Path | None = None,
) -> Dict[str, Any]:
    """Calculate spatial cross-validation splits using a grid-based blocking approach.

    Locations are binned into square lat/lon blocks of ~``block_size_km`` and
    whole blocks are distributed over ``n_splits`` folds via greedy bin packing
    so fold sizes stay balanced.

    :param df: dataframe with latitude/longitude and location-name columns
    :param block_size_km: edge length of a spatial block in kilometres
    :param n_splits: number of cross-validation folds
    :param lat_col: latitude column name
    :param lon_col: longitude column name
    :param name_loc_col: location-identifier column name
    :param save_path: optional path; if set, the splits dict is torch.save()d there
    :return: dict mapping "fold_{i}" to {"train": [...names...], "val": [...names...]}
    """
    log.info(f"Calculating spatial splits with block size {block_size_km}km and {n_splits} folds")

    # If the frame already carries a train/val/test assignment, only the
    # training portion participates in the CV folds.
    if "split" in df.columns:
        train_df = df[df["split"] == "train"].copy()
        log.info(f"Filtered to training split: {len(train_df)} samples")
    else:
        train_df = df.copy()
        log.info(f"No 'split' column found; using all {len(train_df)} samples for spatial splits")

    # Approx conversion: 1 deg ~ 111 km
    block_size_deg = block_size_km / 111.0
    train_df["lat_grid"] = np.floor(train_df[lat_col] / block_size_deg)
    train_df["lon_grid"] = np.floor(train_df[lon_col] / block_size_deg)
    train_df["block_id"] = (
        train_df["lat_grid"].astype(str) + "_" + train_df["lon_grid"].astype(str)
    )

    unique_blocks = train_df["block_id"].unique()
    log.info(f"Created {len(unique_blocks)} spatial blocks")

    # Greedy bin packing: assign blocks largest-first to the smallest fold
    block_counts = train_df["block_id"].value_counts().sort_values(ascending=False)
    fold_samples = [0] * n_splits
    fold_block_ids: List[List[str]] = [[] for _ in range(n_splits)]

    for block_id, count in block_counts.items():
        smallest_fold = int(np.argmin(fold_samples))
        fold_samples[smallest_fold] += count
        fold_block_ids[smallest_fold].append(block_id)

    block_to_names = train_df.groupby("block_id")[name_loc_col].unique().to_dict()

    splits: Dict[str, Any] = {}
    for fold in range(n_splits):
        val_names = [name for bid in fold_block_ids[fold] for name in block_to_names[bid].tolist()]
        train_names = [
            name
            for f in range(n_splits)
            if f != fold
            for bid in fold_block_ids[f]
            for name in block_to_names[bid].tolist()
        ]
        splits[f"fold_{fold}"] = {"train": train_names, "val": val_names}
        log.info(
            f" Fold {fold}: {len(train_names)} train locations, {len(val_names)} val locations "
            f"({fold_samples[fold]} samples)"
        )

    if save_path:
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(splits, save_path)
        log.info(f"Saved spatial splits to {save_path}")

    return splits


# ---------------------------------------------------------------------------
# Main pipeline
# ---------------------------------------------------------------------------


def main(
    source_csv: str,
    out_csv: str,
    out_parquet: str | None = None,
    spatial_splits: bool = False,
    countries: List[str] | None = None,
    years: List[int] | None = None,
    exclude_countries: List[str] | None = None,
    exclude_years: List[int] | None = None,
) -> pd.DataFrame:
    """Preprocessing workflow for the crop yield Africa dataset.

    Loads the raw CSV, applies country/year filters, computes derived features,
    removes yield outliers, fits scaler/encoders on the train countries only,
    transforms the full frame, renames columns to the feat_/aux_/target_
    convention, and writes the model-ready CSV (and optionally Parquet).

    :param source_csv: path to the raw input CSV
    :param out_csv: path for the model-ready CSV; transformers are saved next to it
    :param out_parquet: optional path for an additional Parquet output
    :param spatial_splits: if True, also compute and save spatial CV splits
    :param countries: country codes to keep (None = all)
    :param years: years to keep (None = all)
    :param exclude_countries: country codes to drop (None = none)
    :param exclude_years: years to drop (None = none)
    :return: the fully processed dataframe
    """
    data_path = Path(source_csv)
    out_csv_path = Path(out_csv)
    data_dir = out_csv_path.parent

    scaler_path = data_dir / f"fitted_scaler_{MODEL_READY_DATA_NAME}.pkl"
    encoders_path = data_dir / f"label_encoders_{MODEL_READY_DATA_NAME}.pkl"
    spatial_split_path = (
        data_dir
        / "splits"
        / (
            f"split_spatial_{SPATIAL_SPLIT_N_SPLITS}_folds_with_"
            f"{SPATIAL_SPLIT_BLOCK_SIZE_KM}km_blocks_{MODEL_READY_DATA_NAME}.pth"
        )
    )

    log.info("Starting preprocessing pipeline...")
    log.info(f"Input: {data_path}")

    # Load raw data
    raw_df = pd.read_csv(data_path)
    n_raw = len(raw_df)
    log.info(f"Loaded {n_raw} rows from {data_path}")

    # Filter by country and year before any other processing
    df = raw_df.copy()
    if countries is not None:
        before = len(df)
        df = df[df["Country"].isin(countries)].copy()
        log.info(
            f"Country filter ({', '.join(sorted(countries))}): "
            f"kept {len(df)}, excluded {before - len(df)}"
        )
    if years is not None:
        before = len(df)
        df = df[df["Year"].isin(years)].copy()
        log.info(
            f"Year filter ({', '.join(str(y) for y in sorted(years))}): "
            f"kept {len(df)}, excluded {before - len(df)}"
        )
    if exclude_countries is not None:
        before = len(df)
        df = df[~df["Country"].isin(exclude_countries)].copy()
        log.info(
            f"Exclude countries ({', '.join(sorted(exclude_countries))}): "
            f"kept {len(df)}, excluded {before - len(df)}"
        )
    if exclude_years is not None:
        before = len(df)
        df = df[~df["Year"].isin(exclude_years)].copy()
        log.info(
            f"Exclude years ({', '.join(str(y) for y in sorted(exclude_years))}): "
            f"kept {len(df)}, excluded {before - len(df)}"
        )
    n_after_filter = len(df)

    # Determine train mask on the filtered data.
    # TRAIN_COUNTRIES is a module-level constant defined earlier in this file.
    train_mask = df["Country"].isin(TRAIN_COUNTRIES)
    log.info(
        f"Train mask: {train_mask.sum()} rows ({', '.join(TRAIN_COUNTRIES)}) out of {n_after_filter} total"
    )

    # Compute derived features
    log.info("Computing derived features...")
    df = compute_derived_features(df)

    # Remove yield outliers
    log.info("Removing yield outliers...")
    df, outlier_mask = remove_yield_outliers(df, TARGET_COLUMNS[0], iqr_multiplier=3.0)
    # Re-align train_mask after outlier removal (df now has a subset of the
    # original index; label-based selection keeps the mask in sync).
    train_mask = train_mask[df.index]

    # Apply log transforms
    log.info("Applying log transforms...")
    df = apply_log_transforms(df, LOG_TRANSFORM_FEATURES)

    train_df = df[train_mask]
    log.info(f"Training set size: {len(train_df)} samples")

    # Fit transformers on train split only (avoids leakage into val/test).
    log.info("Fitting StandardScaler on train split...")
    scaler = fit_scaler(train_df, CONTINUOUS_FEATURES)

    log.info("Fitting LabelEncoders on train split...")
    all_categorical = TABULAR_CATEGORICAL_FEATURES + AUX_FEATURES
    encoders = fit_label_encoders(train_df, all_categorical)

    # Apply transformations to full dataset
    log.info("Applying transformations to full dataset...")
    df = apply_scaler(df, scaler, CONTINUOUS_FEATURES)
    df = apply_label_encoders(df, encoders, all_categorical)

    # Save transformers
    scaler_path.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(scaler, scaler_path)
    log.info(f"Saved scaler to {scaler_path}")

    encoders_path.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(encoders, encoders_path)
    log.info(f"Saved encoders to {encoders_path}")

    # Validation checks
    log.info("Validation checks:")
    derived_cols = ["CN_ratio", "C_layer_delta", "BD_layer_delta", "WHC_proxy", "aridity_index"]
    for col in derived_cols:
        if col in df.columns:
            n_nan = df[col].isna().sum()
            n_inf = np.isinf(df[col]).sum()
            if n_nan > 0 or n_inf > 0:
                warnings.warn(f" {col}: {n_nan} NaN, {n_inf} Inf values")
            else:
                log.info(f" {col}: no NaN or Inf")

    # Encoded categoricals must lie in [0, n_classes); -1 marks unseen values
    # and is reported here as out-of-range.
    for col in all_categorical:
        if col in df.columns and col in encoders:
            n_classes = len(encoders[col].classes_)
            min_val = df[col].min()
            max_val = df[col].max()
            if min_val < 0 or max_val >= n_classes:
                warnings.warn(
                    f" {col}: indices [{min_val}, {max_val}] out of range [0, {n_classes-1}]"
                )

    # Apply canonical column rename (feat_/aux_/target_ prefixes, lowercase meta)
    rename_map = build_column_rename_map(
        continuous_features=CONTINUOUS_FEATURES,
        tabular_categorical_features=TABULAR_CATEGORICAL_FEATURES,
        aux_features=AUX_FEATURES,
        target_columns=TARGET_COLUMNS,
        name_loc_column=NAME_LOC_COLUMN,
        metadata_columns=METADATA_COLUMNS,
    )
    df = df.rename(columns=rename_map)

    # Prefix name_loc IDs with country name (disambiguates IDs across countries)
    if "name_loc" in df.columns and "country" in df.columns:
        df["name_loc"] = df["country"].astype(str).str.upper() + "_" + df["name_loc"].astype(str)
        log.info("Prefixed name_loc IDs with country names")
        n_duplicates = df["name_loc"].duplicated().sum()
        if n_duplicates > 0:
            warnings.warn(f"Found {n_duplicates} duplicate name_loc IDs after prefixing")
        else:
            log.info(f" No duplicates in name_loc ({df['name_loc'].nunique()} unique IDs)")

    # Convert location_accuracy from text to numerical values
    # (unmapped strings become NaN via .map)
    if "location_accuracy" in df.columns:
        accuracy_map = {
            "high location accuracy": 1,
            "medium location accuracy": 2,
            "low location accuracy": 3,
        }
        df["location_accuracy"] = df["location_accuracy"].str.lower().map(accuracy_map)

    # Keep scaler metadata in sync with new column names
    if hasattr(scaler, "feature_names_in_") and scaler.feature_names_in_ is not None:
        scaler.feature_names_in_ = np.array(
            [rename_map.get(n, n) for n in scaler.feature_names_in_]
        )
    encoders = {rename_map.get(k, k): v for k, v in encoders.items()}

    # Calculate and save spatial splits (optional)
    if spatial_splits:
        calculate_spatial_splits(
            df=df,
            block_size_km=SPATIAL_SPLIT_BLOCK_SIZE_KM,
            n_splits=SPATIAL_SPLIT_N_SPLITS,
            save_path=spatial_split_path,
        )
    else:
        log.info("Skipping spatial split calculation (use --spatial_splits to enable)")

    # Save outputs
    out_csv_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(out_csv_path, index=False, float_format="%.7f")
    log.info(f"Saved CSV to {out_csv_path}")

    if out_parquet is not None:
        out_parquet_path = Path(out_parquet)
        out_parquet_path.parent.mkdir(parents=True, exist_ok=True)
        df.to_parquet(out_parquet_path, index=False)
        log.info(f"Saved Parquet to {out_parquet_path}")

    log.info("=== Summary ===")
    log.info(f" Raw rows loaded: {n_raw}")
    log.info(f" Rows excluded by country/year filter: {n_raw - n_after_filter}")
    log.info(f" Rows in output: {len(df)}")
    log.info(f" Continuous features: {len(CONTINUOUS_FEATURES)}")
    log.info(f" Tabular categorical features (feat_): {len(TABULAR_CATEGORICAL_FEATURES)}")
    log.info(f" Aux/caption features (aux_): {len(AUX_FEATURES)}")
    log.info(
        f" Yield range: {df['target_yld_ton_ha'].min():.2f} - {df['target_yld_ton_ha'].max():.2f} t/ha"
    )
    log.info(f" Mean yield: {df['target_yld_ton_ha'].mean():.2f} t/ha")
    log.info(" Records per country and year:")
    counts = df.groupby(["country", "year"]).size().unstack(fill_value=0)
    for country, row in counts.iterrows():
        year_counts = [f"{year}: {count}" for year, count in row.items() if count > 0]
        log.info(f" {country}: {', '.join(year_counts)}")

    return df


if __name__ == "__main__":
    logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + + ap = argparse.ArgumentParser( + description="Build model-ready CSV/Parquet for the crop yield Africa use case." + ) + ap.add_argument( + "--source_csv", + required=True, + help="Path to the raw input CSV (e.g. data/yield_africa/yield_africa_v20260218.csv)", + ) + ap.add_argument( + "--out_csv", + required=True, + help="Path for the output model-ready CSV (e.g. data/yield_africa/model_ready_yield_africa.csv)", + ) + ap.add_argument( + "--out_parquet", + default=None, + help="Optional path for an additional Parquet output.", + ) + ap.add_argument( + "--spatial_splits", + action="store_true", + default=False, + help="Calculate and save spatial cross-validation splits (default: off).", + ) + ap.add_argument( + "--countries", + nargs="+", + default=None, + metavar="CODE", + help="Country codes to include (e.g. --countries ETH KEN TAN). Default: all countries.", + ) + ap.add_argument( + "--years", + nargs="+", + type=int, + default=None, + metavar="YEAR", + help="Years to include (e.g. --years 2018 2019 2020). Default: all years.", + ) + ap.add_argument( + "--exclude_countries", + nargs="+", + default=None, + metavar="CODE", + help="Country codes to exclude (e.g. --exclude_countries MOZ ZIM).", + ) + ap.add_argument( + "--exclude_years", + nargs="+", + type=int, + default=None, + metavar="YEAR", + help="Years to exclude (e.g. 
--exclude_years 2015 2016).", + ) + args = ap.parse_args() + main( + args.source_csv, + args.out_csv, + args.out_parquet, + args.spatial_splits, + args.countries, + args.years, + args.exclude_countries, + args.exclude_years, + ) diff --git a/src/data_preprocessing/tessera_embeds.py b/src/data_preprocessing/tessera_embeds.py index b4e6852..c2414b8 100644 --- a/src/data_preprocessing/tessera_embeds.py +++ b/src/data_preprocessing/tessera_embeds.py @@ -1,7 +1,11 @@ import math import os +import threading import numpy as np + +# Serialises concurrent reads/writes to the per-directory meta.csv log file. +_meta_csv_lock = threading.Lock() import pandas as pd import rasterio from geotessera import GeoTessera @@ -122,6 +126,14 @@ def get_tessera_embeds( if reproject_memfile: memfiles.append(reproject_memfile) + if not tiles: + print( + f"No TESSERA tiles found for {name_loc} at ({lon:.4f}, {lat:.4f}) year={year}. Skipping." + ) + for mf in memfiles: + mf.close() + return + mosaic, mosaic_transform = merge(tiles) mosaic = mosaic.transpose(1, 2, 0) @@ -134,11 +146,18 @@ def get_tessera_embeds( col, row = crs_to_pixel_coords(lon_utm, lat_utm, mosaic_transform) half = tile_size // 2 row_min = row - half - row_max = row + half + row_max = row + tile_size - half # tile_size - half ensures correct size for odd tile_size col_min = col - half - col_max = col + half + col_max = col + tile_size - half crop = mosaic[row_min:row_max, col_min:col_max, :] + if crop.shape[0] != tile_size or crop.shape[1] != tile_size: + print( + f"Unexpected crop shape {crop.shape} for {name_loc} " + f"(expected {tile_size}x{tile_size}). Skipping." 
+ ) + return + # Save array os.makedirs(save_dir, exist_ok=True) np.save(embed_tile_name, crop) @@ -151,11 +170,14 @@ def get_tessera_embeds( meta_file = f"{save_dir}/meta.csv" - if os.path.exists(meta_file): - meta_df = pd.concat([meta_df, pd.read_csv(meta_file)], ignore_index=True) - - meta_df.to_csv(meta_file, index=False) - print(f"Meta data logged to {meta_file}") + with _meta_csv_lock: + try: + if os.path.exists(meta_file): + meta_df = pd.concat([meta_df, pd.read_csv(meta_file)], ignore_index=True) + meta_df.to_csv(meta_file, index=False) + print(f"Meta data logged to {meta_file}") + except Exception as e: + print(f"Warning: could not update meta.csv ({e}). Tile was saved successfully.") def tessera_from_df( @@ -177,7 +199,7 @@ def tessera_from_df( # Tessera connection cache_dir = os.path.join(cache_dir, "tessera") - gt = GeoTessera(cache_dir=cache_dir) + gt = GeoTessera(cache_dir=cache_dir, embeddings_dir=cache_dir, dataset_version="v1") # Iter each coord n = len(model_ready_df) @@ -245,3 +267,13 @@ def inspect_np_arr_as_tiff( dst.write(arr_to_write[i], i + 1) print(f"Tiff version of np array saved to {file_path}") + + +if __name__ == "__main__": + os.chdir("../..") + + df = pd.read_csv("data/heat_guatemala/model_ready_heat_guatemala.csv") + + tessera_from_df( + df, "data/heat_guatemala/eo/tessera_2024", year=2024, tile_size=10, cache_dir="data/cache" + ) diff --git a/src/data_preprocessing/yield_africa_loco_splits.py b/src/data_preprocessing/yield_africa_loco_splits.py new file mode 100644 index 0000000..cb563b8 --- /dev/null +++ b/src/data_preprocessing/yield_africa_loco_splits.py @@ -0,0 +1,169 @@ +"""Generate leave-one-country-out (LOCO) split files for the yield_africa dataset. + +Location: src/data_preprocessing/yield_africa_loco_splits.py + +For each held-out country one `.pth` file is written to +`{data_dir}/yield_africa/splits/split_loco_{COUNTRY}.pth`. 

Split layout
------------
- test  : all records from the held-out country
- train : 80 % of records from the remaining countries (random, seeded)
- val   : 20 % of records from the remaining countries (random, seeded)

The files are consumed by BaseDataModule when `split_mode: from_file` and
`saved_split_file_name: split_loco_{COUNTRY}.pth`.

Usage
-----
    python src/data_preprocessing/yield_africa_loco_splits.py --data_dir data/
    python src/data_preprocessing/yield_africa_loco_splits.py --data_dir data/ --country KEN
"""

import argparse
import logging
from pathlib import Path

import numpy as np
import pandas as pd
import torch

log = logging.getLogger(__name__)

# All countries present in the full dataset (must match _ALL_COUNTRIES in
# yield_africa_dataset.py so that the feature encoding is consistent).
ALL_COUNTRIES = ["BF", "BUR", "ETH", "KEN", "MAL", "RWA", "TAN", "ZAM"]

DATASET_NAME = "yield_africa"
MODEL_READY_CSV = f"model_ready_{DATASET_NAME}.csv"


def make_loco_split(
    df: pd.DataFrame,
    test_country: str,
    val_fraction: float = 0.2,
    seed: int = 12345,
) -> dict:
    """Return a split-indices dict for one held-out country.

    :param df: full model-ready dataframe (must contain 'country' and 'name_loc')
    :param test_country: country code to hold out as the test set
    :param val_fraction: fraction of the non-test pool to use for validation
    :param seed: random seed for the train/val shuffle
    :return: dict with 'train_indices', 'val_indices', 'test_indices' as pd.Series of name_locs
    """
    test_mask = df["country"] == test_country
    test_locs = df.loc[test_mask, "name_loc"].reset_index(drop=True)

    remaining = df.loc[~test_mask, "name_loc"].reset_index(drop=True)
    # NOTE(review): rng is unused — the shuffle below uses pandas' random_state,
    # not this generator. Safe to remove in a code change.
    rng = np.random.default_rng(seed)
    shuffled = remaining.sample(frac=1, random_state=seed).reset_index(drop=True)
    n_val = int(len(shuffled) * val_fraction)
    val_locs = shuffled.iloc[:n_val]
    train_locs = shuffled.iloc[n_val:]

    return {
        "train_indices": train_locs,
        "val_indices": val_locs,
        "test_indices": test_locs,
    }


def generate_splits(
    data_dir: str,
    countries: list[str] | None = None,
    val_fraction: float = 0.2,
    seed: int = 12345,
) -> None:
    """Generate and save LOCO split files for the requested countries.

    :param data_dir: root data directory (same as `paths.data_dir` in configs)
    :param countries: list of country codes to generate splits for; None means all
    :param val_fraction: fraction of non-test data to use for validation
    :param seed: random seed
    """
    dataset_dir = Path(data_dir) / DATASET_NAME
    csv_path = dataset_dir / MODEL_READY_CSV
    splits_dir = dataset_dir / "splits"

    if not csv_path.exists():
        raise FileNotFoundError(f"Model-ready CSV not found: {csv_path}")

    splits_dir.mkdir(parents=True, exist_ok=True)

    df = pd.read_csv(csv_path)
    if "country" not in df.columns or "name_loc" not in df.columns:
        raise ValueError("CSV must contain 'country' and 'name_loc' columns")

    # Only generate splits for countries that actually occur in the CSV.
    available = sorted(df["country"].unique().tolist())
    targets = countries if countries is not None else available

    for country in targets:
        if country not in available:
            log.warning(f"Country '{country}' not found in CSV (available: {available}), skipping")
            continue

        split = make_loco_split(df, country, val_fraction=val_fraction, seed=seed)
        n_train = len(split["train_indices"])
        n_val = len(split["val_indices"])
        n_test = len(split["test_indices"])

        out_path = splits_dir / f"split_loco_{country}.pth"
        torch.save(split, out_path)

        log.info(
            f" {country}: train={n_train}, val={n_val}, test={n_test} " f"-> {out_path.name}"
        )
        print(
            f" Saved split_loco_{country}.pth " f"(train={n_train}, val={n_val}, test={n_test})"
        )


def main() -> None:
    """CLI entry point: parse arguments and generate the requested LOCO splits."""
    logging.basicConfig(level=logging.INFO, format="%(message)s")

    parser = argparse.ArgumentParser(
        description="Generate leave-one-country-out split files for yield_africa."
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default="data/",
        help="Root data directory (same as paths.data_dir in configs). Default: data/",
    )
    parser.add_argument(
        "--country",
        type=str,
        default=None,
        help="Single country code to generate a split for. Omit to generate all.",
    )
    parser.add_argument(
        "--val_fraction",
        type=float,
        default=0.2,
        help="Fraction of non-test records used for validation. Default: 0.2",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=12345,
        help="Random seed for the train/val shuffle. Default: 12345",
    )
    args = parser.parse_args()

    countries = [args.country] if args.country else None
    print(
        f"Generating LOCO splits data_dir={args.data_dir} "
        f"countries={countries or 'all'} val_fraction={args.val_fraction} seed={args.seed}"
    )
    generate_splits(
        data_dir=args.data_dir,
        countries=countries,
        val_fraction=args.val_fraction,
        seed=args.seed,
    )
    print("Done.")


if __name__ == "__main__":
    main()
diff --git a/src/data_preprocessing/yield_africa_spatial_splits.py b/src/data_preprocessing/yield_africa_spatial_splits.py
new file mode 100644
index 0000000..4c9aa98
--- /dev/null
+++ b/src/data_preprocessing/yield_africa_spatial_splits.py
@@ -0,0 +1,326 @@
+"""Generate spatial-cluster split files for the yield_africa dataset.
+
+Location: src/data_preprocessing/yield_africa_spatial_splits.py
+
+Uses DBSCAN with a haversine distance metric to group nearby field locations
+into clusters, then assigns whole clusters to train/val/test so that no
+geographically close points straddle a split boundary.
+
+One `.pth` file is written per distance threshold to
+`{data_dir}/yield_africa/splits/split_spatial_{distance_km}km.pth`.
+
+Split layout
+------------
+- train : ~70 % of records (cluster-aligned)
+- val   : ~15 % of records (cluster-aligned)
+- test  : ~15 % of records (cluster-aligned)
+
+Proportions are approximate because whole clusters are kept intact.
+
+The files are consumed by BaseDataModule when `split_mode: from_file` and
+`saved_split_file_name: split_spatial_{distance_km}km.pth`.

Usage
-----
    # Generate the default set of splits (10 km, 25 km, 50 km)
    python src/data_preprocessing/yield_africa_spatial_splits.py --data_dir data/

    # Generate a single split at a specific distance
    python src/data_preprocessing/yield_africa_spatial_splits.py --data_dir data/ --distance_km 25

    # Generate multiple distances in one run
    python src/data_preprocessing/yield_africa_spatial_splits.py --data_dir data/ --distance_km 10 25 50

Notes
-----
- DBSCAN uses sklearn's built-in haversine metric with a BallTree spatial index
  and n_jobs=-1, which is significantly faster than a Python geodesic lambda.
  Haversine vs. true geodesic error is < 0.1% at distances up to ~100 km.
- `min_samples=2` means a pair of fields within `distance_km` of each other
  forms a cluster; isolated fields each become their own singleton cluster.
- All clusters are kept intact across the split boundary, so the test set
  contains no locations geographically close to any training location.
"""

import argparse
import copy
import logging
import time
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from sklearn.cluster import DBSCAN

log = logging.getLogger(__name__)

DATASET_NAME = "yield_africa"
MODEL_READY_CSV = f"model_ready_{DATASET_NAME}.csv"

# Default distances to generate when no --distance_km is supplied.
DEFAULT_DISTANCES_KM = [10, 25, 50]

# Split proportions (must sum to 1.0).
TRAIN_FRAC = 0.70
VAL_FRAC = 0.15
TEST_FRAC = 0.15

# Fixed random seed for the cluster shuffle / greedy split assignment.
SEED = 12345


def make_spatial_split(
    df: pd.DataFrame,
    distance_m: int,
    train_val_test_split: tuple[float, float, float] = (TRAIN_FRAC, VAL_FRAC, TEST_FRAC),
    seed: int = SEED,
) -> dict:
    """Return a split-indices dict using DBSCAN spatial clustering.

    :param df: full model-ready dataframe (must contain 'lat', 'lon', 'name_loc')
    :param distance_m: DBSCAN eps in metres — pairs of fields closer than this
        value are assigned to the same cluster
    :param train_val_test_split: (train, val, test) proportions, must sum to 1.0
    :param seed: random seed for the cluster shuffle used by the greedy assignment
    :return: dict with 'train_indices', 'val_indices', 'test_indices' as
        pd.Series of name_loc strings, plus 'clusters' as a numpy array of
        cluster labels (same length as the deduplicated unique locations)
    """
    # Deduplicate to unique (lat, lon) locations before clustering.
    # yield_africa has ~9 rows per location (one per year); running DBSCAN on all
    # rows produces giant clusters whose row counts are unequal, causing a
    # count-based splitter to produce badly skewed train/val/test proportions.
    # Clustering unique locations and propagating the split back to all rows
    # fixes this.
    unique_locs = df.drop_duplicates(subset=["lat", "lon"]).reset_index(drop=True)
    n_unique = len(unique_locs)
    n_total = len(df)
    if n_unique < n_total:
        print(
            f" Deduplicating: {n_unique} unique locations from {n_total} rows "
            f"(~{n_total / n_unique:.1f} rows/location)."
        )

    # Convert (lat, lon) degrees to radians for sklearn's haversine metric.
    # haversine returns arc length on the unit sphere, so eps must be in radians.
    # Error vs. true geodesic is < 0.1% at distances up to ~100 km.
    _EARTH_RADIUS_M = 6_371_000
    coords_rad = np.radians(np.array([unique_locs["lat"].values, unique_locs["lon"].values]).T)
    eps_rad = distance_m / _EARTH_RADIUS_M

    print(
        f" Running DBSCAN (eps={distance_m / 1000:.1f} km, haversine, "
        f"n={n_unique} locations, n_jobs=-1)..."
    )
    t0 = time.time()
    clustering = DBSCAN(
        eps=eps_rad,
        metric="haversine",
        algorithm="ball_tree",
        min_samples=2,
        n_jobs=-1,
    ).fit(coords_rad)
    print(f" DBSCAN done in {time.time() - t0:.1f}s.")

    # Noise points (label -1) each become their own unique cluster so the
    # greedy assignment below can place them individually in a split partition.
    clusters = copy.deepcopy(clustering.labels_)
    next_label = int(np.max(clusters)) + 1
    for i, label in enumerate(clusters):
        if label == -1:
            clusters[i] = next_label
            next_label += 1

    n_clusters = len(np.unique(clusters))
    n_noise = int(np.sum(clustering.labels_ == -1))
    print(f" Clustering done: {n_clusters} location clusters ({n_noise} singleton noise points).")

    train_prop, val_prop, test_prop = train_val_test_split

    # Greedy size-aware cluster assignment.
    #
    # A splitter that works by cluster *count*, not by sample count, would
    # produce badly imbalanced splits when the cluster size distribution is
    # heavily skewed (a few mega-clusters + many tiny 2-location clusters).
    #
    # Instead: shuffle clusters for randomness, sort by size descending, then
    # assign each cluster to whichever split is furthest below its sample-count
    # target. Each cluster goes to exactly one split, so there is no overlap.
    rng = np.random.default_rng(seed)
    unique_clusters, cluster_sizes = np.unique(clusters, return_counts=True)

    # Shuffle first so ties are broken randomly, then sort by descending size.
    shuffle_order = rng.permutation(len(unique_clusters))
    unique_clusters = unique_clusters[shuffle_order]
    cluster_sizes = cluster_sizes[shuffle_order]
    size_order = np.argsort(-cluster_sizes)
    unique_clusters = unique_clusters[size_order]
    cluster_sizes = cluster_sizes[size_order]

    target_train = n_unique * train_prop
    target_val = n_unique * val_prop
    target_test = n_unique * test_prop
    train_clusters, val_clusters, test_clusters = [], [], []
    count_train, count_val, count_test = 0, 0, 0

    for cluster_id, size in zip(unique_clusters, cluster_sizes):
        # Assign to the split with the largest remaining deficit vs. its target.
        deficit_train = target_train - count_train
        deficit_val = target_val - count_val
        deficit_test = target_test - count_test
        if deficit_train >= deficit_val and deficit_train >= deficit_test:
            train_clusters.append(cluster_id)
            count_train += size
        elif deficit_val >= deficit_test:
            val_clusters.append(cluster_id)
            count_val += size
        else:
            test_clusters.append(cluster_id)
            count_test += size

    train_loc_mask = np.isin(clusters, train_clusters)
    val_loc_mask = np.isin(clusters, val_clusters)
    test_loc_mask = np.isin(clusters, test_clusters)

    # Sanity checks: every location assigned, no cluster in multiple splits.
    assert train_loc_mask.sum() + val_loc_mask.sum() + test_loc_mask.sum() == n_unique
    assert len(set(train_clusters) & set(val_clusters)) == 0
    assert len(set(train_clusters) & set(test_clusters)) == 0
    assert len(set(val_clusters) & set(test_clusters)) == 0

    print(
        f" Split (locations): train={train_loc_mask.sum()}, "
        f"val={val_loc_mask.sum()}, test={test_loc_mask.sum()}"
    )

    # Propagate location-level split assignments back to all rows by (lat, lon).
    train_latlon = set(
        zip(unique_locs.loc[train_loc_mask, "lat"], unique_locs.loc[train_loc_mask, "lon"])
    )
    val_latlon = set(
        zip(unique_locs.loc[val_loc_mask, "lat"], unique_locs.loc[val_loc_mask, "lon"])
    )
    test_latlon = set(
        zip(unique_locs.loc[test_loc_mask, "lat"], unique_locs.loc[test_loc_mask, "lon"])
    )
    row_latlon = list(zip(df["lat"], df["lon"]))
    train_mask = np.array([ll in train_latlon for ll in row_latlon])
    val_mask = np.array([ll in val_latlon for ll in row_latlon])
    test_mask = np.array([ll in test_latlon for ll in row_latlon])

    assert train_mask.sum() + val_mask.sum() + test_mask.sum() == n_total, (
        "Not all rows were assigned to a split — check for (lat, lon) values that "
        "don't match any unique location after deduplication."
    )

    name_locs = df["name_loc"].reset_index(drop=True)
    return {
        "train_indices": name_locs[train_mask].reset_index(drop=True),
        "val_indices": name_locs[val_mask].reset_index(drop=True),
        "test_indices": name_locs[test_mask].reset_index(drop=True),
        "clusters": clusters,
    }


def generate_splits(
    data_dir: str,
    distances_km: list[int] | None = None,
    seed: int = SEED,
) -> None:
    """Generate and save spatial-cluster split files for the requested distances.
+ + :param data_dir: root data directory (same as `paths.data_dir` in configs) + :param distances_km: list of DBSCAN cluster distances in kilometres; None + uses DEFAULT_DISTANCES_KM + :param seed: random seed for GroupShuffleSplit + """ + if distances_km is None: + distances_km = DEFAULT_DISTANCES_KM + + dataset_dir = Path(data_dir) / DATASET_NAME + csv_path = dataset_dir / MODEL_READY_CSV + splits_dir = dataset_dir / "splits" + + if not csv_path.exists(): + raise FileNotFoundError(f"Model-ready CSV not found: {csv_path}") + + splits_dir.mkdir(parents=True, exist_ok=True) + + df = pd.read_csv(csv_path) + for col in ("lat", "lon", "name_loc"): + if col not in df.columns: + raise ValueError(f"CSV must contain a '{col}' column") + + print(f"Loaded {len(df)} rows from {csv_path}") + + for dist_km in distances_km: + dist_m = dist_km * 1000 + print(f"\nGenerating spatial split at {dist_km} km ({dist_m} m)...") + + split = make_spatial_split(df, distance_m=dist_m, seed=seed) + n_train = len(split["train_indices"]) + n_val = len(split["val_indices"]) + n_test = len(split["test_indices"]) + + out_name = f"split_spatial_{dist_km}km.pth" + out_path = splits_dir / out_name + torch.save(split, out_path) + + print( + f" Saved {out_name} " + f"(train={n_train}, val={n_val}, test={n_test}, " + f"total={n_train + n_val + n_test}/{len(df)})" + ) + log.info( + f" {dist_km}km: train={n_train}, val={n_val}, test={n_test} -> {out_name}" + ) + + +def main() -> None: + logging.basicConfig(level=logging.INFO, format="%(message)s") + + parser = argparse.ArgumentParser( + description="Generate spatial-cluster split files for yield_africa.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument( + "--data_dir", + type=str, + default="data/", + help="Root data directory (same as paths.data_dir in configs). 
Default: data/", + ) + parser.add_argument( + "--distance_km", + type=int, + nargs="+", + default=None, + metavar="KM", + help=( + "Cluster distance threshold(s) in km. " + f"Omit to generate the default set: {DEFAULT_DISTANCES_KM} km." + ), + ) + parser.add_argument( + "--seed", + type=int, + default=SEED, + help=f"Random seed for GroupShuffleSplit. Default: {SEED}", + ) + args = parser.parse_args() + + distances = args.distance_km # None means use defaults + print( + f"Generating spatial splits data_dir={args.data_dir} " + f"distances_km={distances or DEFAULT_DISTANCES_KM} seed={args.seed}" + ) + generate_splits( + data_dir=args.data_dir, + distances_km=distances, + seed=args.seed, + ) + print("\nDone.") + + +if __name__ == "__main__": + main() diff --git a/src/data_preprocessing/yield_africa_tessera_preprocess.py b/src/data_preprocessing/yield_africa_tessera_preprocess.py new file mode 100644 index 0000000..62ae298 --- /dev/null +++ b/src/data_preprocessing/yield_africa_tessera_preprocess.py @@ -0,0 +1,287 @@ +"""Fetch and cache TESSERA embedding tiles for the yield_africa dataset. + +Location: src/data_preprocessing/yield_africa_tessera_preprocess.py + +Tiles are saved as NumPy arrays to: + {data_dir}/yield_africa/eo/tessera/tessera_{name_loc}.npy + +This matches the path built by BaseDataset.add_modality_paths_to_df() and +loaded by BaseDataset.setup_tessera() at training time. + +Unlike tessera_from_df() (which takes a single fixed year), this script +uses each record's own `year` column so that per-record inter-annual +phenology is captured — the key signal missing from the static tabular +features. + +The script is resumable: get_tessera_embeds() skips files that already +exist, so interrupted runs can be continued safely. 
+ +Usage +----- + # All records + python src/data_preprocessing/yield_africa_tessera_preprocess.py \\ + --data_dir data/ + + # Single country, useful for incremental fetching + python src/data_preprocessing/yield_africa_tessera_preprocess.py \\ + --data_dir data/ --countries KEN RWA + + # Smaller tile size (faster, less context) + python src/data_preprocessing/yield_africa_tessera_preprocess.py \\ + --data_dir data/ --tile_size 5 +""" + +import argparse +import logging +import os +import socket +import sys +import threading +from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait +from pathlib import Path + +# Ensure the project root is on sys.path when the script is run directly. +sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +import pandas as pd +from geotessera import GeoTessera + +from src.data_preprocessing.tessera_embeds import get_tessera_embeds + +log = logging.getLogger(__name__) + +DATASET_NAME = "yield_africa" +MODEL_READY_CSV = f"model_ready_{DATASET_NAME}.csv" + +# Tile size in pixels. A small tile (e.g. 9) captures local context around +# each plot point without pulling in large surrounding areas. Consistent +# with the typical smallholder farm size in the region. +DEFAULT_TILE_SIZE = 9 + + +def fetch_tessera_tiles( + data_dir: str, + tile_size: int = DEFAULT_TILE_SIZE, + countries: list[str] | None = None, + years: list[int] | None = None, + cache_dir: str | None = None, + workers: int = 2, +) -> None: + """Fetch TESSERA tiles for every record in the yield_africa CSV. + + :param data_dir: root data directory (same as ``paths.data_dir`` in configs) + :param tile_size: spatial extent of each tile in pixels + :param countries: optional list of country codes to restrict fetching + :param years: optional list of years to restrict fetching + :param cache_dir: base directory for all TESSERA cache files. 
GeoTessera's + internal registry is stored here; the large raw downloaded source tiles + (``global_0.1_degree_representation/`` etc.) are kept in the ``raw/`` + subfolder. Defaults to the ``TESSERA_EMBEDDINGS_DIR`` env var when set, + otherwise ``{data_dir}/cache/tessera``. Point this at an external drive + when disk space is limited. + :param workers: number of parallel download threads. Each thread keeps its + own GeoTessera instance to avoid shared state. Default: 2 (external + drive I/O is usually the bottleneck; more workers add contention). + """ + dataset_dir = Path(data_dir) / DATASET_NAME + csv_path = dataset_dir / MODEL_READY_CSV + save_dir = dataset_dir / "eo" / "tessera" + + if not csv_path.exists(): + raise FileNotFoundError(f"Model-ready CSV not found: {csv_path}") + + save_dir.mkdir(parents=True, exist_ok=True) + + if cache_dir is None: + cache_dir = os.environ.get("TESSERA_EMBEDDINGS_DIR") or str( + Path(data_dir) / "cache" / "tessera" + ) + + embeddings_dir = str(Path(cache_dir) / "raw") + + df = pd.read_csv(csv_path) + + # Optional filters (consistent with YieldAfricaDataset filter params) + if countries is not None: + df = df[df["country"].isin(countries)] + log.info(f"Filtered to countries {countries}: {len(df)} records") + if years is not None: + df = df[df["year"].isin(years)] + log.info(f"Filtered to years {years}: {len(df)} records") + + n_total = len(df) + n_existing = sum( + 1 for _, row in df.iterrows() if (save_dir / f"tessera_{row.name_loc}.npy").exists() + ) + n_to_fetch = n_total - n_existing + + print( + f"Records: {n_total} total, {n_existing} already cached, " + f"{n_to_fetch} to fetch (tile_size={tile_size}, workers={workers})\n" + f" cache_dir : {cache_dir}\n" + f" embeddings_dir: {embeddings_dir}" + ) + + # Build GeoTessera constructor kwargs shared across all threads. + # Each thread creates its own instance (thread-local) to avoid sharing + # internal state such as open file handles and rasterio MemoryFiles. 
+ _default_registry_dir = Path.home() / ".cache" / "geotessera" + _use_local_registry = (_default_registry_dir / "registry.parquet").exists() + _gt_kwargs: dict = { + # Skip SHA-256 hash verification after each tile download. Verification + # reads the entire (potentially large) file again after download, adding + # noticeable CPU time per tile and making progress look stalled. + "verify_hashes": False, + } + _gt_kwargs["embeddings_dir"] = embeddings_dir + + _thread_local = threading.local() + + def _get_gt() -> GeoTessera: + """Return a thread-local GeoTessera instance, creating it on first use.""" + if not hasattr(_thread_local, "gt"): + if _use_local_registry: + _thread_local.gt = GeoTessera(registry_dir=_default_registry_dir, **_gt_kwargs) + else: + _thread_local.gt = GeoTessera(cache_dir=cache_dir, **_gt_kwargs) + return _thread_local.gt + + def _fetch_one(row) -> str: + get_tessera_embeds( + lon=row.lon, + lat=row.lat, + name_loc=row.name_loc, + year=int(row.year), + save_dir=str(save_dir), + tile_size=tile_size, + tessera_con=_get_gt(), + ) + return row.name_loc + + # Bound all socket operations (urllib HTTP requests inside geotessera). + # Without this, a stalled connection blocks the thread until the OS TCP + # keepalive fires, which can take many minutes. + SOCKET_TIMEOUT = 60 # seconds per socket operation + HEARTBEAT = 30 # print a heartbeat when no future completes this fast + TILE_TIMEOUT = 600 # give up warning after 10 min of complete silence + socket.setdefaulttimeout(SOCKET_TIMEOUT) + + rows = [row for _, row in df.iterrows()] + done = 0 + pending: set = set() + silent_seconds = 0 + + try: + with ThreadPoolExecutor(max_workers=workers) as pool: + # Submit all jobs up-front; the pool queues them internally. 
+ futures = {pool.submit(_fetch_one, row): row.name_loc for row in rows} + pending = set(futures) + + while pending: + finished, pending = wait(pending, timeout=HEARTBEAT, return_when=FIRST_COMPLETED) + + if not finished: + silent_seconds += HEARTBEAT + print( + f" ... still working — {done}/{n_total} done, " + f"{len(pending)} pending, {silent_seconds}s since last completion" + ) + if silent_seconds >= TILE_TIMEOUT: + print( + f" WARNING: no progress in {TILE_TIMEOUT}s, something may be stuck." + ) + continue + + silent_seconds = 0 + for fut in finished: + done += 1 + if done % 100 == 0 or done == n_total: + print(f" {done}/{n_total}") + try: + fut.result() + except Exception as exc: + print(f" ERROR fetching {futures[fut]}: {exc}") + + except KeyboardInterrupt: + print("\nInterrupted — cancelling queued futures (in-flight downloads will finish).") + for fut in pending: + fut.cancel() + + print(f"Done. Tiles saved to: {save_dir}") + + +def main() -> None: + logging.basicConfig(level=logging.INFO, format="%(message)s") + + parser = argparse.ArgumentParser( + description="Fetch TESSERA embedding tiles for the yield_africa dataset." + ) + parser.add_argument( + "--data_dir", + type=str, + default="data/", + help="Root data directory (same as paths.data_dir in configs). Default: data/", + ) + parser.add_argument( + "--tile_size", + type=int, + default=DEFAULT_TILE_SIZE, + help=f"Tile size in pixels. Default: {DEFAULT_TILE_SIZE}", + ) + parser.add_argument( + "--countries", + nargs="+", + default=None, + metavar="CODE", + help="Country codes to restrict fetching (e.g. KEN RWA). Default: all", + ) + parser.add_argument( + "--years", + nargs="+", + type=int, + default=None, + metavar="YEAR", + help="Years to restrict fetching (e.g. 2019 2020). Default: all", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help=( + "Base directory for all TESSERA cache files. 
" + "GeoTessera's registry is stored here; large raw source tiles go in " + "the raw/ subfolder. " + "Falls back to the TESSERA_EMBEDDINGS_DIR env var, then " + "{data_dir}/cache/tessera. Set TESSERA_EMBEDDINGS_DIR in .env to " + "avoid passing this flag every run." + ), + ) + parser.add_argument( + "--workers", + type=int, + default=2, + help=( + "Number of parallel download threads. Default: 2. " + "When writing to an external drive too many workers can cause I/O " + "bottlenecks. Reduce the number of workers to improve throughput." + ), + ) + args = parser.parse_args() + + print( + f"Fetching TESSERA tiles data_dir={args.data_dir} " + f"tile_size={args.tile_size} countries={args.countries or 'all'} " + f"years={args.years or 'all'}" + ) + fetch_tessera_tiles( + data_dir=args.data_dir, + tile_size=args.tile_size, + countries=args.countries, + years=args.years, + cache_dir=args.cache_dir, + workers=args.workers, + ) + + +if __name__ == "__main__": + main() diff --git a/src/models/base_model.py b/src/models/base_model.py index 407d08a..89fbaf6 100644 --- a/src/models/base_model.py +++ b/src/models/base_model.py @@ -16,8 +16,6 @@ def __init__( scheduler: torch.optim.lr_scheduler, loss_fn: BaseLossFn, metrics: MetricsWrapper, - num_classes: int | None = None, - tabular_dim: int | None = None, ) -> None: """Interface for any model. 
@@ -26,23 +24,28 @@ def __init__(
         :param scheduler: scheduler for the model weight update
         :param loss_fn: loss function
         :param metrics: metrics to track for model performance estimation
-        :param num_classes: number of classes to predict
         """
         super().__init__()
         self.save_hyperparameters(
-            ignore=["loss_fn", "eo_encoder", "prediction_head", "text_encoder", "metrics"]
+            ignore=["loss_fn", "geo_encoder", "prediction_head", "text_encoder", "metrics"]
         )
         self.trainable_modules = trainable_modules
-        self.num_classes = num_classes
-        self.tabular_dim = tabular_dim
+        self.num_classes: int | None = None
+        self.tabular_dim: int | None = None
         self.loss_fn = loss_fn
         self.metrics = metrics
 
+    @abstractmethod
+    def setup(self, stage: str) -> None:
+        """Updates the model based on data-bound configurations (through the datamodule). This method
+        is called after the trainer is initialized and the datamodule is available."""
+        pass
+
     @final
     def freezer(self) -> None:
         """Freezes modules based on provided trainable modules."""
-        self.trainable_modules = tuple(self.trainable_modules) or tuple()
+        trainable_modules = tuple(self.trainable_modules) or tuple()
 
         # Store higher level module names for printing of trainable parts
         trainable = set()
@@ -50,17 +53,30 @@ def freezer(self) -> None:
         # Freeze modules
         for name, param in self.named_parameters():
             # Enable exceptions
-            if name.startswith(self.trainable_modules):
+            if name.startswith(trainable_modules):
                 param.requires_grad = True
-                top_name = name.split(".", 2)[:2]
-                trainable.add(".".join(top_name))
+                trainable.add(name)
             else:  # Freeze the rest
                 param.requires_grad = False
 
-        # Set module modes correctly
+        # Set module modes correctly.
+ # A module should be in train() if: + # - it IS a trainable module (name == t), or + # - it is a CHILD of a trainable module (name starts with t + "."), or + # - it is an ANCESTOR of a trainable module (t starts with name + "."), + # so that container modules reflect the correct mode, or + # - it is the root module (""), which must be train when any child is. + def _in_train_scope(name: str) -> bool: + if not name: # root module + return bool(trainable_modules) + for t in trainable_modules: + if name == t or name.startswith(t + ".") or t.startswith(name + "."): + return True + return False + for name, module in self.named_modules(): - if any(t.startswith(name) for t in self.trainable_modules): + if _in_train_scope(name): module.train() else: module.eval() diff --git a/src/models/components/eo_encoders/average_encoder.py b/src/models/components/eo_encoders/average_encoder.py deleted file mode 100644 index ccec36e..0000000 --- a/src/models/components/eo_encoders/average_encoder.py +++ /dev/null @@ -1,72 +0,0 @@ -from typing import Dict, override - -import torch -import torch.nn.functional as F -from torch import nn - -from src.models.components.eo_encoders.base_eo_encoder import BaseEOEncoder - - -class AverageEncoder(BaseEOEncoder): - def __init__( - self, - output_dim: int | None = None, - eo_data_name="aef", - ) -> None: - super().__init__() - - dict_n_bands_default = {"s2": 4, "aef": 64, "tessera": 128} - self.allowed_eo_data_names: list[str] = list(dict_n_bands_default.keys()) - - assert ( - eo_data_name in dict_n_bands_default - ), f"eo_data_name must be one of {self.allowed_eo_data_names}, got {eo_data_name}" - self.eo_data_name = eo_data_name - - if output_dim is None or output_dim == dict_n_bands_default[eo_data_name]: - self.output_dim = dict_n_bands_default[eo_data_name] - self.extra_projector = None - self.eo_encoder = self._average - else: - assert ( - type(output_dim) is int and output_dim > 0 - ), f"output_dim must be positive int, got {output_dim}" 
- self.output_dim = output_dim - self.extra_projector = nn.Linear(dict_n_bands_default[eo_data_name], output_dim) - self.eo_encoder = self._average_and_project - - def _average(self, x: torch.Tensor) -> torch.Tensor: - """Averages the input tensor over spatial dimensions. - - :param x: input tensor of shape (B, C, H, W) - :return: averaged tensor of shape (B, C) - """ - return x.mean(dim=(-2, -1)) - - def _average_and_project(self, x: torch.Tensor) -> torch.Tensor: - """Averages the input tensor over spatial dimensions and projects to output_dim. - - :param x: input tensor of shape (B, C, H, W) - :return: projected tensor of shape (B, output_dim) - """ - x_avg = x.mean(dim=(-2, -1)) - x_proj = self.extra_projector(x_avg) - return x_proj - - @override - def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: - eo_data = batch.get("eo", {}) - dtype = self.dtype - if eo_data.dtype != dtype: - eo_data = eo_data.to(dtype) - feats = self.eo_encoder(eo_data[self.eo_data_name]) - # n_nans = torch.sum(torch.isnan(feats)).item() - # assert ( - # n_nans == 0 - # ), f"AverageEncoder output contains {n_nans}/{feats.numel()} NaNs PRIOR to normalization with data min {eo_data[self.eo_data_name].min()} and max {eo_data[self.eo_data_name].max()}." 
- - return feats.to(dtype) - - -if __name__ == "__main__": - _ = AverageEncoder(None, None) diff --git a/src/models/components/eo_encoders/base_eo_encoder.py b/src/models/components/eo_encoders/base_eo_encoder.py deleted file mode 100644 index 643250c..0000000 --- a/src/models/components/eo_encoders/base_eo_encoder.py +++ /dev/null @@ -1,38 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Dict - -import torch -from torch import nn - - -class BaseEOEncoder(nn.Module, ABC): - def __init__(self) -> None: - super().__init__() - self.eo_encoder: nn.Module | None = None - self.output_dim: int | None = None - - # placeholders - self.allowed_eo_data_names: list[str] | None = None - self.eo_data_name: str | None = None - - @abstractmethod - def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: - pass - - @property - def device(self) -> torch.device: - devices = {p.device for p in self.parameters()} - if len(devices) != 1: - raise RuntimeError("EO encoder is on multiple devices") - return devices.pop() - - @property - def dtype(self) -> torch.dtype: - dtypes = {p.dtype for p in self.parameters()} - if len(dtypes) != 1: - raise RuntimeError("EO encoder has multiple dtypes") - return dtypes.pop() - - -if __name__ == "__main__": - _ = BaseEOEncoder(None) diff --git a/src/models/components/eo_encoders/geoclip.py b/src/models/components/eo_encoders/geoclip.py deleted file mode 100644 index e9d1834..0000000 --- a/src/models/components/eo_encoders/geoclip.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import Dict, override - -import torch -from geoclip import LocationEncoder -from torch.nn import functional as F - -from src.models.components.eo_encoders.base_eo_encoder import BaseEOEncoder - - -class GeoClipCoordinateEncoder(BaseEOEncoder): - def __init__( - self, - eo_data_name="coords", - ) -> None: - super().__init__() - self.eo_encoder = LocationEncoder() - self.output_dim = self.eo_encoder.LocEnc0.head[0].out_features - self.allowed_eo_data_names = 
["coords"] - assert ( - eo_data_name in self.allowed_eo_data_names - ), f"eo_data_name must be one of {self.allowed_eo_data_names}, got {eo_data_name}" - self.eo_data_name = eo_data_name - - @override - def forward( - self, - batch: Dict[str, torch.Tensor], - ) -> torch.Tensor: - - coords = batch.get("eo", {}).get("coords") - - dtype = self.dtype - if coords.dtype != dtype: - coords = coords.to(dtype) - feats = self.eo_encoder(coords) - - return feats.to(dtype) - - -if __name__ == "__main__": - _ = GeoClipCoordinateEncoder(None) diff --git a/src/models/components/eo_encoders/multimodal_encoder.py b/src/models/components/eo_encoders/multimodal_encoder.py deleted file mode 100644 index 7dbf61d..0000000 --- a/src/models/components/eo_encoders/multimodal_encoder.py +++ /dev/null @@ -1,93 +0,0 @@ -"""Unified multimodal encoder for EO data. - -Controlled entirely via constructor flags: - - use_coords: encode lat/lon with GeoClip - - use_tabular: encode feat_* tabular columns -""" - -from typing import Dict, override - -import torch -from torch import nn - -from src.models.components.eo_encoders.base_eo_encoder import BaseEOEncoder -from src.models.components.eo_encoders.geoclip import GeoClipCoordinateEncoder - - -class MultiModalEncoder(BaseEOEncoder): - """ - - coords only (use_coords=True, use_tabular=False) - - tabular only (use_coords=False, use_tabular=True) - - coords + tabular (use_coords=True, use_tabular=True) - """ - - def __init__( - self, - use_coords: bool = True, - use_tabular: bool = False, - tab_embed_dim: int = 64, - tabular_dim: int = None, - ) -> None: - super().__init__() - - assert use_coords or use_tabular, "At least one of use_coords or use_tabular must be True." 
- - self.use_coords = use_coords - self.use_tabular = use_tabular - self.tab_embed_dim = tab_embed_dim - self._tabular_ready = False - - coords_dim = 0 - if use_coords: - self.coords_encoder = GeoClipCoordinateEncoder() - coords_dim = self.coords_encoder.output_dim # 512 - - self._coords_dim = coords_dim - - # Built only if dim is already known - if use_tabular and tabular_dim is not None: - self.build_tabular_branch(tabular_dim) - elif use_tabular: - self.tabular_proj = None - else: - self.output_dim = coords_dim - - # ------------------------------------------------------------------ - # Public API - # ------------------------------------------------------------------ - - def build_tabular_branch(self, tabular_dim: int) -> None: - """Build (or rebuild) the tabular projection layer.""" - if self._tabular_ready and hasattr(self, "_last_tabular_dim"): - if self._last_tabular_dim == tabular_dim: - return # already built with correct dim - - self.tabular_proj = nn.Sequential( - nn.LayerNorm(tabular_dim), - nn.Linear(tabular_dim, self.tab_embed_dim), - nn.ReLU(), - ) - self._last_tabular_dim = tabular_dim - self._tabular_ready = True - self.output_dim = self._coords_dim + self.tab_embed_dim - - # ------------------------------------------------------------------ - # Forward - # ------------------------------------------------------------------ - - @override - def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: - parts = [] - - if self.use_coords: - parts.append(self.coords_encoder(batch)) # (B, 512) - - if self.use_tabular: - assert self._tabular_ready, ( - "Tabular branch not built yet. Call build_tabular_branch(tabular_dim) first, " - "or pass tabular_dim to the constructor." 
-            )
-            tab = batch["eo"]["tabular"].float()  # (B, tabular_dim)
-            parts.append(self.tabular_proj(tab))  # (B, tab_embed_dim)
-
-        return torch.cat(parts, dim=-1)
diff --git a/src/models/components/eo_encoders/__init__.py b/src/models/components/geo_encoders/__init__.py
similarity index 100%
rename from src/models/components/eo_encoders/__init__.py
rename to src/models/components/geo_encoders/__init__.py
diff --git a/src/models/components/geo_encoders/average_encoder.py b/src/models/components/geo_encoders/average_encoder.py
new file mode 100644
index 0000000..0c0eaf6
--- /dev/null
+++ b/src/models/components/geo_encoders/average_encoder.py
@@ -0,0 +1,46 @@
+from typing import Dict, List, override
+
+import torch
+from torch import nn
+
+from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder
+
+
+class AverageEncoder(BaseGeoEncoder):
+    def __init__(
+        self,
+        geo_data_name="aef",
+    ) -> None:
+        """Encoder to average tile values into a 1D vector.
+
+        :param geo_data_name: modality name
+        """
+        super().__init__()
+
+        self.dict_n_bands_default = {"s2": 4, "aef": 64, "tessera": 128}
+        self.allowed_geo_data_names: list[str] = list(self.dict_n_bands_default.keys())
+
+        assert (
+            geo_data_name in self.allowed_geo_data_names
+        ), f"geo_data_name must be one of {self.allowed_geo_data_names}, got {geo_data_name}"
+        self.geo_data_name = geo_data_name
+
+    @override
+    def setup(self) -> List[str]:
+        """Configures networks, data-dependent parts.
+
+        Gets called in model.setup() method. Returns names of any new module configured to be added
+        to the trainable modules list.
+ """ + self.output_dim = self.dict_n_bands_default[self.geo_data_name] + self.geo_encoder = nn.Identity() + print(f"Model set up with average geo-encoder for {self.geo_data_name}") + + return [] + + @override + def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + """Data forward pass through the encoder.""" + tile = batch.get("eo", {}).get(self.geo_data_name) + feats = self.geo_encoder(tile.mean(dim=(-2, -1))) + return feats diff --git a/src/models/components/geo_encoders/base_geo_encoder.py b/src/models/components/geo_encoders/base_geo_encoder.py new file mode 100644 index 0000000..f79f3f4 --- /dev/null +++ b/src/models/components/geo_encoders/base_geo_encoder.py @@ -0,0 +1,58 @@ +from abc import ABC, abstractmethod +from typing import Dict + +import torch +from torch import nn + + +class BaseGeoEncoder(nn.Module, ABC): + def __init__(self) -> None: + super().__init__() + self.geo_encoder: nn.Module | None = None + self.output_dim: int | None = None + + # placeholders + self.allowed_geo_data_names: list[str] | None = None + self.geo_data_name: str | None = None + + @abstractmethod + def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + pass + + @property + def device(self) -> torch.device | None: + devices = {p.device for p in self.parameters()} + if len(devices) > 1: + raise RuntimeError("GEO encoder is on multiple devices") + elif len(devices) == 0: + return None + return devices.pop() + + @property + def dtype(self) -> torch.dtype | None: + dtypes = {p.dtype for p in self.parameters()} + if len(dtypes) > 1: + raise RuntimeError("GEO encoder has multiple dtypes") + elif len(dtypes) == 0: + return None + return dtypes.pop() + + @abstractmethod + def setup(self) -> list[str]: + """Configures networks, data-dependent parts. + + Gets called in model.setup() method. Returns names of any new module configured to be added + to the trainable modules list. 
+ """ + pass + + def add_projector(self, projected_dim: int) -> None: + """Adds an extra linear projection layer to the geo encoder. + + NB: is not used by default, needs to be called explicitly in forward(). + """ + self.extra_projector = nn.Linear(self.output_dim, projected_dim, dtype=self.dtype) + print( + f"Extra linear projection layer added with mapping dimension {self.output_dim} to {projected_dim}" + ) + self.output_dim = projected_dim diff --git a/src/models/components/eo_encoders/cnn_encoder.py b/src/models/components/geo_encoders/cnn_encoder.py similarity index 81% rename from src/models/components/eo_encoders/cnn_encoder.py rename to src/models/components/geo_encoders/cnn_encoder.py index d939a38..34f8f48 100644 --- a/src/models/components/eo_encoders/cnn_encoder.py +++ b/src/models/components/geo_encoders/cnn_encoder.py @@ -1,14 +1,13 @@ -from typing import Dict, override +from typing import Dict, List, override import torch import torchvision.models as models from torch import nn -from torch.nn import functional as F -from src.models.components.eo_encoders.base_eo_encoder import BaseEOEncoder +from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder -class CNNEncoder(BaseEOEncoder): +class CNNEncoder(BaseGeoEncoder): """Convolutional neural network EO encoder. Adapted from PECL. 
:param backbone: backbone model to use (resnet) @@ -25,7 +24,7 @@ def __init__( pretrained_cnn="imagenet", resnet_version=18, freezing_strategy="all", - eo_data_name="s2", + geo_data_name="s2", input_n_bands: int | None = None, output_dim=512, ) -> None: @@ -36,9 +35,9 @@ def __init__( self.resnet_version = resnet_version self.freezing_strategy = freezing_strategy - self.allowed_eo_data_names = ["s2", "aef", "tessera"] - assert eo_data_name in self.allowed_eo_data_names - self.eo_data_name = eo_data_name + self.allowed_geo_data_names = ["s2", "aef", "tessera"] + assert geo_data_name in self.allowed_geo_data_names + self.geo_data_name = geo_data_name self.set_n_input_bands(input_n_bands) assert ( @@ -46,23 +45,23 @@ def __init__( ), f"input_n_bands must be int >=3, got {self.input_n_bands}" self.output_dim = output_dim - self.eo_encoder = self.get_backbone() + self.geo_encoder = self.get_backbone() def set_n_input_bands(self, n_bands: int | None = None) -> None: - """Sets number of input bands based on eo_data_name if n_bands is None. + """Sets number of input bands based on geo_data_name if n_bands is None. 
:param n_bands: number of input bands :return: None """ - if n_bands is None: # infer from eo_data_name - if self.eo_data_name == "s2": + if n_bands is None: # infer from geo_data_name + if self.geo_data_name == "s2": self.input_n_bands = 4 - elif self.eo_data_name == "aef": + elif self.geo_data_name == "aef": self.input_n_bands = 64 - elif self.eo_data_name == "tessera": + elif self.geo_data_name == "tessera": self.input_n_bands = 128 print( - f"[CNNEncoder] Inferred input_n_bands={self.input_n_bands} for eo_data_name='{self.eo_data_name}'" + f"[CNNEncoder] Inferred input_n_bands={self.input_n_bands} for geo_data_name='{self.geo_data_name}'" ) else: self.input_n_bands = n_bands @@ -132,6 +131,13 @@ def get_backbone(self): else: raise ValueError(f"Unsupported backbone: {self.backbone}") + @override + def setup(self) -> List[str]: + # TODO: could you make sure new layers are returned here to be added to trainable parts? + # Maybe move the get_backbone method in here? + print(f"Model setup with cnn geo-encoder for {self.geo_data_name}") + return [] + @override def forward( self, @@ -147,11 +153,11 @@ def forward( dtype = self.dtype if eo_data.dtype != dtype: eo_data = eo_data.to(dtype) - feats = self.eo_encoder(eo_data[self.eo_data_name]) + feats = self.geo_encoder(eo_data[self.geo_data_name]) # n_nans = torch.sum(torch.isnan(feats)).item() # assert ( # n_nans == 0 - # ), f"CNNEncoder output contains {n_nans}/{feats.numel()} NaNs PRIOR to normalization with data min {eo_data[self.eo_data_name].min()} and max {eo_data[self.eo_data_name].max()}." + # ), f"CNNEncoder output contains {n_nans}/{feats.numel()} NaNs PRIOR to normalization with data min {eo_data[self.geo_data_name].min()} and max {eo_data[self.geo_data_name].max()}." 
class EncoderWrapper(BaseGeoEncoder):
    """Wrapper class for multi-modal encoders.

    Holds a list of encoder branches (each an encoder plus an optional
    projector) and fuses their outputs according to ``fusion_strategy``.
    """

    def __init__(
        self,
        encoder_branches: List[Dict[str, Any]],
        fusion_strategy: str = "concat",
    ):
        """:param encoder_branches: list of dicts, each with an "encoder" entry and an
            optional "projector" entry (both nn.Module instances)
        :param fusion_strategy: how branch outputs are combined; one of "mean",
            "concat" or "none"
        """
        super().__init__()

        # Filled in by set_output_dim() once all branches are set up.
        self.output_dim = None

        self._reformat_set_branches(encoder_branches)

        assert fusion_strategy in ["mean", "concat", "none"], ValueError(
            f'Unsupported fusion strategy "{fusion_strategy}"'
        )
        self.fusion_strategy = fusion_strategy

    def _reformat_set_branches(self, encoder_branches: List[Dict[str, Any]]):
        """Reformatting to allow registering modules properly."""
        self.encoder_branches = nn.ModuleList()

        for branch in encoder_branches:
            module_dict = nn.ModuleDict({"encoder": branch["encoder"]})

            if branch.get("projector") is not None:
                module_dict["projector"] = branch["projector"]

            self.encoder_branches.append(module_dict)

    @override
    def setup(self) -> List[str]:
        """Configure data-dependent parts of every branch.

        :return: fully-qualified names of newly created sub-modules, to be added
            to the trainable-modules list by the caller
        """
        new_modules = []

        # Configure/initialise missing/conditional parts
        for i, branch in enumerate(self.encoder_branches):
            encoder = branch["encoder"]

            # Tabular encoders need the (data-dependent) input width first.
            # BUG FIX: use getattr — set_tabular_input_dim() may never have been
            # called, and the original then raised AttributeError instead of the
            # intended ValueError.
            if isinstance(encoder, TabularEncoder):
                if getattr(self, "tabular_dim", None) is None:
                    raise ValueError("TabularEncoder requires tabular_dim")
                encoder.set_tabular_input_dim(self.tabular_dim)

            new_modules.extend(
                f"encoder_branches.{i}.encoder.{p}" for p in encoder.setup()
            )

            # Configure adapter/projector if requested
            if "projector" in branch:
                projector = branch["projector"]
                projector.set_input_dim(input_dim=encoder.output_dim)
                new_modules.extend(
                    f"encoder_branches.{i}.projector.{p}" for p in projector.setup()
                )

        self.set_output_dim()
        return new_modules

    def set_tabular_input_dim(self, tabular_dim=None):
        """Set tabular dimension if there is a tabular encoder branch."""
        self.tabular_dim = None

        for branch in self.encoder_branches:
            encoder = branch["encoder"]
            if isinstance(encoder, TabularEncoder):
                self.tabular_dim = tabular_dim
                return

    def set_output_dim(self):
        """Calculates the output dimension of the fused representation."""

        # Collect all output dimensions (projector overrides encoder width).
        output_dims = []
        for branch in self.encoder_branches:
            branch_out_dim = branch["encoder"].output_dim

            if "projector" in branch:
                branch_out_dim = branch["projector"].output_dim

            output_dims.append(branch_out_dim)

        # Combine output dimensions
        if self.fusion_strategy == "concat":
            self.output_dim = sum(output_dims)
        elif self.fusion_strategy == "mean":
            # BUG FIX: the original compared the *set* itself to the int 1
            # (`set(output_dims) != 1`), which is always True, so "mean" always
            # raised. Compare the number of distinct dimensions instead.
            if len(set(output_dims)) != 1:
                raise ValueError(
                    f"Encoder branches produces different output dimensions {output_dims} and cannot be averaged."
                )
            self.output_dim = output_dims[0]
        # NOTE(review): fusion_strategy == "none" leaves output_dim as None —
        # confirm downstream consumers handle this.

    @override
    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        branch_feats = []
        for branch in self.encoder_branches:
            # Each encoder knows which modality it needs from the batch.
            feats = branch["encoder"].forward(batch)

            if "projector" in branch:
                feats = branch["projector"].forward(feats)

            branch_feats.append(feats)

        if self.fusion_strategy == "concat":
            return torch.cat(branch_feats, dim=1)
        # BUG FIX: torch.mean() does not accept a Python list of tensors; stack
        # the per-branch (B, D) outputs along a new leading axis, then average
        # across branches.
        return torch.stack(branch_feats, dim=0).mean(dim=0)

    @property
    def device(self):
        """Common device of all branch parameters; raises if they disagree."""
        devices = set()
        for branch in self.encoder_branches:
            devices.update(p.device for p in branch["encoder"].parameters())
            if "projector" in branch:
                devices.update(p.device for p in branch["projector"].parameters())

        if len(devices) != 1:
            raise RuntimeError("GEO encoder is on multiple devices")
        return devices.pop()

    @property
    def dtype(self) -> torch.dtype:
        """Common dtype of all branch parameters; raises if they disagree."""
        dtypes = set()
        for branch in self.encoder_branches:
            dtypes.update(p.dtype for p in branch["encoder"].parameters())
            if "projector" in branch:
                dtypes.update(p.dtype for p in branch["projector"].parameters())

        if len(dtypes) != 1:
            # BUG FIX: the original message was copy-pasted from the device
            # property and wrongly referred to devices.
            raise RuntimeError("GEO encoder has multiple dtypes")
        return dtypes.pop()
class GeoClipCoordinateEncoder(BaseGeoEncoder):
    """Coordinate encoder backed by GeoCLIP's pretrained LocationEncoder."""

    def __init__(
        self,
        geo_data_name="coords",
    ) -> None:
        """:param geo_data_name: key of the modality read from the batch;
            only "coords" is supported
        """
        super().__init__()

        self.allowed_geo_data_names = ["coords"]
        assert (
            geo_data_name in self.allowed_geo_data_names
        ), f"geo_data_name must be one of {self.allowed_geo_data_names}, got {geo_data_name}"
        self.geo_data_name = geo_data_name

    @override
    def setup(self) -> List[str]:
        """Instantiate the pretrained location encoder.

        :return: empty list — the pretrained encoder adds no new trainable parts
        """
        self.geo_encoder = LocationEncoder()
        # Output width read from the first head layer of the location encoder.
        self.output_dim = self.geo_encoder.LocEnc0.head[0].out_features
        print("Model setup with GeoClip coordinate encoder")
        return []

    @override
    def forward(
        self,
        batch: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Encode coordinates from batch["eo"]["coords"]."""

        coords = batch.get("eo", {}).get("coords")
        # NOTE(review): coords is None when the "coords" modality is missing and
        # the dtype access below would then raise AttributeError — confirm
        # upstream batches always provide it.

        dtype = self.dtype
        if coords.dtype != dtype:
            coords = coords.to(dtype)
        feats = self.geo_encoder(coords)

        return feats.to(dtype)


if __name__ == "__main__":
    # BUG FIX: the original passed None as geo_data_name, which fails the
    # constructor assertion; use the default modality for the smoke test.
    _ = GeoClipCoordinateEncoder()
class MLPProjector(BaseGeoEncoder):
    """Small MLP that projects encoder features to a target dimension."""

    def __init__(
        self,
        output_dim: int,
        input_dim: int | None = None,
        nn_layers: int = 2,
        hidden_dim: int = 256,
    ) -> None:
        """:param output_dim: size of the projected output
        :param input_dim: size of the incoming features; may instead be supplied
            later via set_input_dim(), before setup() is called
        :param nn_layers: total number of linear layers
        :param hidden_dim: width of the hidden layers
        """
        super().__init__()

        self.output_dim = output_dim
        self.input_dim = input_dim
        self.nn_layers = nn_layers
        self.hidden_dim = hidden_dim

        # Built lazily in setup(), once input_dim is known.
        self.net: nn.Module | None = None

    @override
    def setup(self) -> List[str]:
        """Build the network and report "net" as a newly created trainable part."""
        self.configure_nn()
        print("Model setup with MLP projector")
        return ["net"]

    def set_input_dim(self, input_dim: int) -> None:
        """Record the input width (typically the upstream encoder's output_dim)."""
        self.input_dim = input_dim

    def configure_nn(self) -> None:
        """Configure the MLP network."""
        assert self.input_dim is not None, "input_dim must be defined"
        assert self.output_dim is not None, "output_dim must be defined"

        # Hidden stack: (nn_layers - 1) Linear+ReLU pairs, then the output layer.
        widths = [self.input_dim] + [self.hidden_dim] * (self.nn_layers - 1)
        modules = []
        for src, dst in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(src, dst))
            modules.append(nn.ReLU())

        modules.append(nn.Linear(widths[-1], self.output_dim))
        self.net = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project features through the MLP."""
        return self.net(x)
class TabularEncoder(BaseGeoEncoder):
    """Tabular data encoder."""

    def __init__(
        self,
        output_dim: int,
        input_dim: int | None = None,
        hidden_dim: int | None = None,
        dropout_prob: float = 0.0,
        geo_data_name: str = "tabular",
    ):
        """:param output_dim: size of the encoded output
        :param input_dim: number of tabular features; may instead be supplied
            later via set_tabular_input_dim(), before setup() is called
        :param hidden_dim: width of the first hidden layer; defaults to
            max(2 * input_dim, 128) when None
        :param dropout_prob: dropout probability applied after each hidden layer
        :param geo_data_name: key of the modality read from the batch; only
            "tabular" is supported
        """
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.dropout_prob = dropout_prob

        # Built in setup()/configure_nn() once input_dim is known.
        self.geo_encoder: nn.Module | None = None

        self.allowed_geo_data_names = ["tabular"]
        assert (
            geo_data_name in self.allowed_geo_data_names
        ), f"geo_data_name must be one of {self.allowed_geo_data_names}, got {geo_data_name}"
        self.geo_data_name = geo_data_name

    @override
    def setup(self, input_dim: int = None) -> list[str]:
        """Build the MLP and report the newly created trainable module.

        :param input_dim: optional override for the tabular input width
        :return: names of newly created sub-modules
        """
        self.configure_nn(input_dim)
        print("Model setup with Tabular geo-encoder")
        # BUG FIX: the module attribute is "geo_encoder", not "tabular_encoder";
        # the wrong name broke the trainable-modules bookkeeping of callers that
        # prefix these names with the attribute path.
        return ["geo_encoder"]

    def set_tabular_input_dim(self, input_dim: int) -> None:
        """Record the tabular input width (known only after data setup)."""
        self.input_dim = input_dim

    def configure_nn(self, input_dim: int = None) -> None:
        """Build the encoder MLP.

        BUG FIX: the original resolved ``input_dim`` into a local variable but
        then sized every layer from ``self.input_dim`` (possibly still None),
        crashing whenever the width was only passed as an argument. The resolved
        value is now stored back on the instance before use.
        """
        self.input_dim = input_dim or self.input_dim
        assert self.input_dim is not None, "input_dim must be defined"

        if self.hidden_dim is None:
            self.hidden_dim = max(self.input_dim * 2, 128)

        self.geo_encoder = nn.Sequential(
            nn.LayerNorm(self.input_dim),
            nn.Linear(self.input_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Dropout(self.dropout_prob),
            nn.Linear(self.hidden_dim, self.hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(self.dropout_prob),
            nn.Linear(self.hidden_dim // 2, self.output_dim),
        )

    @override
    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Encode tabular features from batch["eo"]["tabular"]."""
        tab_data = batch.get("eo", {}).get("tabular")
        # NOTE(review): tab_data is None when the "tabular" modality is missing —
        # confirm upstream batches always provide it.

        dtype = self.dtype
        if tab_data.dtype != dtype:
            tab_data = tab_data.to(dtype)
        feats = self.geo_encoder(tab_data)

        return feats.to(dtype)
class RRMSELoss(BaseLossFn):
    """Relative Root Mean Squared Error (RRMSE).

    RRMSE = RMSE / mean(|labels|)

    Normalises RMSE by the mean absolute value of the target, giving a
    unit-free percentage error. This makes results comparable across crops
    and regions with different absolute yield scales (e.g. t/ha ranges
    differ significantly between maize in Zambia and rice in Rwanda).

    Returns a fraction (e.g. 0.15 = 15 % error). Multiply by 100 for
    percentage when reporting.
    """

    def __init__(self) -> None:
        super().__init__()
        self.criterion = torch.nn.MSELoss()
        self.name = "rrmse_loss"

    @override
    def forward(
        self,
        pred: torch.Tensor,
        labels: torch.Tensor | None = None,
        batch: Dict[str, torch.Tensor] | None = None,
        **kwargs,
    ) -> torch.Tensor | Dict[str, torch.Tensor]:
        """Compute RRMSE between predictions and labels (or batch["target"])."""

        if labels is None:
            labels = batch.get("target")

        # RMSE divided by mean |target|; the epsilon guards an all-zero target.
        rmse = torch.sqrt(self.criterion(pred, labels))
        rel_loss = rmse / (labels.abs().mean() + 1e-8)

        if "return_label" in kwargs:
            return {self.name: rel_loss}
        return rel_loss
class RetrievalContrastiveValidation(BaseMetrics):
    def __init__(self, ks: List[Any], concept_configs: List[Any]) -> None:
        """Evaluates how many eo embeddings are retrieved in top-k metrics based the GT labels.

        :param ks: k values for top-k metrics
        :param concept_configs: concept configurations containing details about min/max mode,
            which aux_col to use as GT and, optionally, a "theta_k" threshold defining a
            per-concept dynamic k.
        """
        super().__init__()

        self.concept_configs = concept_configs

        # BUG FIX: copy the list — the original aliased the caller's list and
        # appended "dynamic_k" to it, mutating shared (e.g. config-owned) state.
        self.ks = list(ks)
        if any("theta_k" in c for c in self.concept_configs):
            self.ks.append("dynamic_k")

    @override
    def forward(
        self,
        similarity_matrix: torch.Tensor,
        aux_values: torch.Tensor,
        **kwargs,
    ) -> Dict[int, Dict[Any, float]]:
        """Calculates top-k metrics based the GT (aux-derived) labels.

        :param similarity_matrix: per-concept similarity scores over candidates
        :param aux_values: (num_candidates, num_aux_cols) ground-truth auxiliary values
        :return: per-concept dict of {k: top-k agreement percentage}
        """

        aux_vals = aux_values.T

        concept_scores = {}
        for i, configs in enumerate(self.concept_configs):
            idx = configs["id"]
            is_max = configs["is_max"]
            k_threshold = configs.get("theta_k")
            aux_val = aux_vals[idx]

            if k_threshold is not None:
                # Dynamic k = number of candidates past the concept threshold.
                dynamic_k = (
                    (aux_val >= k_threshold).sum().item()
                    if is_max
                    else (aux_val <= k_threshold).sum().item()
                )
            else:
                dynamic_k = None

            sim_val = similarity_matrix[i]
            scores = self.topk_rank_agreement(aux_val, sim_val, self.ks, is_max, dynamic_k)

            concept_scores[i] = scores

        return concept_scores

    @staticmethod
    def topk_rank_agreement(gt_vals, pred_vals, ks, is_max=True, dynamic_k=None):
        """Get how much of top-k concept retrievals are predicted correctly."""
        num_candidates = len(gt_vals)

        gt_order = torch.argsort(gt_vals, descending=True)
        pred_order = torch.argsort(pred_vals, descending=True)

        # Invert the orderings into rank positions (0 = largest value).
        gt_rank_pos = torch.empty_like(gt_order)
        gt_rank_pos[gt_order] = torch.arange(num_candidates, device=gt_order.device)

        pred_rank_pos = torch.empty_like(pred_order)
        pred_rank_pos[pred_order] = torch.arange(num_candidates, device=pred_order.device)

        results = {}

        for k in ks:
            k_key = k
            if k == "dynamic_k":
                # BUG FIX: dynamic_k is None for concepts without "theta_k"; the
                # original only skipped 0 and then compared ranks against None,
                # raising a TypeError. Skip both None and 0 here.
                if not dynamic_k:
                    continue
                k = dynamic_k

            if is_max:
                gt_mask = gt_rank_pos < k
                pred_mask = pred_rank_pos < k
            else:
                # For "min" concepts the relevant items sit at the bottom of the
                # descending ordering.
                k_inverted = num_candidates - k
                gt_mask = gt_rank_pos >= k_inverted
                pred_mask = pred_rank_pos >= k_inverted
            results[k_key] = (gt_mask & pred_mask).sum().item() / k * 100

        return results
+ """ + def __init__(self) -> None: super().__init__() self.name = "r2" + # Keys are prefixed to avoid clashing with nn.Module attribute names + # (e.g. "train" conflicts with nn.Module.train()). + self._r2 = nn.ModuleDict({f"mode_{m}": R2Score() for m in _MODES}) @override def forward( @@ -17,12 +32,10 @@ def forward( labels: torch.Tensor | None = None, batch: Dict[str, torch.Tensor] | None = None, **kwargs, - ) -> torch.Tensor | Dict[str, torch.Tensor]: - + ) -> Dict[str, torch.Tensor]: labels = labels if labels is not None else batch.get("target") + mode = kwargs.get("mode", "train") - ss_res = torch.sum((labels - pred) ** 2) - ss_tot = torch.sum((labels - torch.mean(labels)) ** 2) + 1e-12 - r2 = 1.0 - ss_res / ss_tot - - return {self.name: r2} + metric = self._r2[f"mode_{mode}"] + metric.update(pred.squeeze(-1), labels.squeeze(-1)) + return {self.name: metric} diff --git a/src/models/components/pred_heads/base_pred_head.py b/src/models/components/pred_heads/base_pred_head.py index 7a4f7e1..5ed63c4 100644 --- a/src/models/components/pred_heads/base_pred_head.py +++ b/src/models/components/pred_heads/base_pred_head.py @@ -7,6 +7,7 @@ class BasePredictionHead(nn.Module, ABC): def __init__(self) -> None: + """Base prediction head interface class.""" super().__init__() self.net: nn.Module | None = None self.input_dim: int | None = None @@ -14,16 +15,30 @@ def __init__(self) -> None: @abstractmethod def forward(self, feats: torch.Tensor) -> torch.Tensor: + """Forward pass through the prediction head.""" pass @final def set_dim(self, input_dim: int, output_dim: int) -> None: + """Set dimensions for the prediction head configuration. 
+ + :param input_dim: input dimension + :param output_dim: output dimension + """ + assert isinstance(input_dim, int), TypeError( + "Input dimension must be specified as integer" + ) + assert isinstance(output_dim, int), TypeError( + "Output dimension must be specified as integer" + ) self.input_dim = input_dim self.output_dim = output_dim - assert type(self.input_dim) is int, self.input_dim - if output_dim is not None: - assert type(self.output_dim) is int, self.output_dim @abstractmethod - def configure_nn(self) -> None: + def setup(self) -> None: + """Configures networks, data-dependent parts. + + Gets called in model.setup() method. Returns names of any new module configured to be added + to the trainable modules list. + """ pass diff --git a/src/models/components/pred_heads/linear_pred_head.py b/src/models/components/pred_heads/linear_pred_head.py index aa8ee55..94338a7 100644 --- a/src/models/components/pred_heads/linear_pred_head.py +++ b/src/models/components/pred_heads/linear_pred_head.py @@ -7,15 +7,33 @@ class LinearPredictionHead(BasePredictionHead): - def __init__(self): + def __init__( + self, + input_dim: int | None = None, + output_dim: int | None = None, + ) -> None: + """Linear prediction head for classification. + + :param input_dim: the size of input dimension + :param output_dim: the size of output dimension + """ super().__init__() + if input_dim and output_dim: + self.set_dim(input_dim, output_dim) @override def forward(self, feats: torch.Tensor) -> torch.Tensor: + """Forward pass through the prediction head.""" + return torch.sigmoid(self.net(feats)) @override - def configure_nn(self) -> None: + def setup(self) -> None: + """Configures networks, data-dependent parts. + + Gets called in model.setup() method. Returns names of any new module configured to be added + to the trainable modules list. 
+ """ assert type(self.input_dim) is int, self.input_dim assert type(self.output_dim) is int, self.output_dim self.net = nn.Linear(self.input_dim, self.output_dim) diff --git a/src/models/components/pred_heads/mlp_pred_head.py b/src/models/components/pred_heads/mlp_pred_head.py index 28db550..6d4c124 100644 --- a/src/models/components/pred_heads/mlp_pred_head.py +++ b/src/models/components/pred_heads/mlp_pred_head.py @@ -7,17 +7,39 @@ class MLPPredictionHead(BasePredictionHead): - def __init__(self, nn_layers: int = 2, hidden_dim: int = 256) -> None: + def __init__( + self, + nn_layers: int = 2, + hidden_dim: int = 256, + input_dim: int | None = None, + output_dim: int | None = None, + ) -> None: + """MLP prediction head for classification. + + :param nn_layers: number of layers in MLP + :param hidden_dim: the size of hidden dimensions + :param input_dim: the size of input dimension + :param output_dim: the size of output dimension + """ super().__init__() self.nn_layers = nn_layers self.hidden_dim = hidden_dim + if input_dim and output_dim: + self.set_dim(input_dim, output_dim) + @override def forward(self, feats: torch.Tensor) -> torch.Tensor: + """Forward pass through the prediction head.""" return torch.sigmoid(self.net(feats)) @override - def configure_nn(self) -> None: + def setup(self) -> None: + """Configures networks, data-dependent parts. + + Gets called in model.setup() method. Returns names of any new module configured to be added + to the trainable modules list. 
+ """ assert type(self.input_dim) is int, self.input_dim assert type(self.output_dim) is int, self.output_dim layers = [] diff --git a/src/models/components/pred_heads/mlp_regression_head.py b/src/models/components/pred_heads/mlp_regression_head.py index c52509a..d9553ec 100644 --- a/src/models/components/pred_heads/mlp_regression_head.py +++ b/src/models/components/pred_heads/mlp_regression_head.py @@ -19,17 +19,42 @@ class MLPRegressionPredictionHead(BasePredictionHead): """MLP prediction head for regression tasks (outputs a continuous value).""" - def __init__(self, nn_layers: int = 2, hidden_dim: int = 256) -> None: + def __init__( + self, + nn_layers: int = 2, + hidden_dim: int = 256, + dropout: float = 0.0, + input_dim: int | None = None, + output_dim: int | None = None, + ) -> None: + """MLP prediction head for regression tasks. + + :param nn_layers: number of layers in MLP + :param hidden_dim: the size of hidden dimensions + :param dropout: the dropout rate + :param input_dim: the size of input dimension + :param output_dim: the size of output dimension + """ super().__init__() self.nn_layers = nn_layers self.hidden_dim = hidden_dim + self.dropout = dropout + + if input_dim and output_dim: + self.set_dim(input_dim, output_dim) @override def forward(self, feats: torch.Tensor) -> torch.Tensor: + """Forward pass through the prediction head.""" return self.net(feats) @override - def configure_nn(self) -> None: + def setup(self) -> None: + """Configures networks, data-dependent parts. + + Gets called in model.setup() method. Returns names of any new module configured to be added + to the trainable modules list. 
+ """ assert isinstance(self.input_dim, int), self.input_dim assert isinstance(self.output_dim, int), self.output_dim @@ -39,6 +64,8 @@ def configure_nn(self) -> None: for _ in range(self.nn_layers - 1): layers.append(nn.Linear(in_dim, self.hidden_dim)) layers.append(nn.ReLU()) + if self.dropout > 0.0: + layers.append(nn.Dropout(self.dropout)) in_dim = self.hidden_dim layers.append(nn.Linear(in_dim, self.output_dim)) diff --git a/src/models/components/text_encoders/clip_text_encoder.py b/src/models/components/text_encoders/clip_text_encoder.py index 027c503..075a390 100644 --- a/src/models/components/text_encoders/clip_text_encoder.py +++ b/src/models/components/text_encoders/clip_text_encoder.py @@ -2,7 +2,6 @@ import torch from geoclip import GeoCLIP -from torch.nn import functional as F from transformers import CLIPModel, CLIPProcessor from src.models.components.text_encoders.base_text_encoder import ( @@ -25,8 +24,13 @@ def __init__(self, hf_cache_dir: str = "../.cache", output_normalization="l2") - self.projector = GeoCLIP().image_encoder.mlp + self.model.vision_model = None + self.model.visual_projection = None + self.output_dim = 512 + print("Model set up with CLIP text encoder") + @override def forward(self, batch: Dict[str, torch.Tensor], mode: str) -> torch.Tensor: # Get text inputs @@ -38,7 +42,13 @@ def forward(self, batch: Dict[str, torch.Tensor], mode: str) -> torch.Tensor: avr_embeds = [] for captions_per_row in text_input: # Tokenize and embed - text_tokens = self.processor(text=captions_per_row, return_tensors="pt", padding=True) + text_tokens = self.processor( + text=captions_per_row, + return_tensors="pt", + padding=True, + truncation=True, + max_length=77, + ) device = next(self.model.parameters()).device text_tokens = {k: v.to(device) for k, v in text_tokens.items()} diff --git a/src/models/predictive_model.py b/src/models/predictive_model.py index f46a4dc..7b951c2 100644 --- a/src/models/predictive_model.py +++ b/src/models/predictive_model.py 
@@ -4,8 +4,9 @@ import torch.nn.functional as F from src.models.base_model import BaseModel -from src.models.components.eo_encoders.base_eo_encoder import BaseEOEncoder -from src.models.components.eo_encoders.multimodal_encoder import MultiModalEncoder +from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder +from src.models.components.geo_encoders.encoder_wrapper import EncoderWrapper +from src.models.components.geo_encoders.tabular_encoder import TabularEncoder from src.models.components.loss_fns.base_loss_fn import BaseLossFn from src.models.components.metrics.metrics_wrapper import MetricsWrapper from src.models.components.pred_heads.linear_pred_head import ( @@ -16,19 +17,19 @@ class PredictiveModel(BaseModel): def __init__( self, - eo_encoder: BaseEOEncoder, + geo_encoder: BaseGeoEncoder, prediction_head: BasePredictionHead, trainable_modules: list[str], optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler, loss_fn: BaseLossFn, metrics: MetricsWrapper, - num_classes: int | None = None, - tabular_dim: int | None = None, + normalize_features: bool = True, ) -> None: - """Implementation of the predictive model with replaceable EO encoder, and prediction head. + """Implementation of the predictive model with replaceable GEO encoder, and prediction + head. 
- :param eo_encoder: eo encoder module (replaceable) + :param geo_encoder: geo encoder module (replaceable) :param prediction_head: prediction head module (replaceable) :param trainable_modules: list of modules to train (parts/modules or modules, modules) :param optimizer: optimizer to use for training @@ -37,52 +38,81 @@ def __init__( :param metrics: metrics to use for model performance evaluation :param num_classes: number of target classes :param tabular_dim: number of tabular features + :param normalize_features: if True, apply L2 normalisation to encoder output before the + prediction head (default: True) """ - super().__init__( - trainable_modules, optimizer, scheduler, loss_fn, metrics, num_classes, tabular_dim - ) - - # EO encoder configuration - self.eo_encoder = eo_encoder + super().__init__(trainable_modules, optimizer, scheduler, loss_fn, metrics) - # TODO: move to multi-modal eo encoder - if ( - isinstance(self.eo_encoder, MultiModalEncoder) - and self.eo_encoder.use_tabular - and not self.eo_encoder._tabular_ready - ): - self.eo_encoder.build_tabular_branch(tabular_dim) + # Geo encoder configuration + self.geo_encoder = geo_encoder # Prediction head self.prediction_head = prediction_head - self.prediction_head.set_dim(input_dim=self.eo_encoder.output_dim, output_dim=num_classes) - self.prediction_head.configure_nn() - if "prediction_head" not in self.trainable_modules: - self.trainable_modules.append("prediction_head") + + # Normalise features boolean + self.normalize_features = normalize_features + + @override + def setup(self, stage: str) -> None: + self.num_classes = self.trainer.datamodule.num_classes + self.tabular_dim = self.trainer.datamodule.tabular_dim + + if stage != "fit": + if isinstance(self.trainable_modules, tuple): + self.trainable_modules = list(self.trainable_modules) + + print("-------Model------------") + self.setup_encoders_adapters() + print("------------------------") # Freezing requested parts self.freezer() + def 
setup_encoders_adapters(self): + """Set up encoders and missing adapters/projectors.""" + # TODO: move to multi-modal eo encoder + + # If tabular encoder used, we need to specify tabular dim + if isinstance(self.geo_encoder, TabularEncoder) or isinstance( + self.geo_encoder, EncoderWrapper + ): + self.geo_encoder.set_tabular_input_dim(self.tabular_dim) + + # Setup encoders that need data-depended configurations + new_modules = [f"geo_encoder.{i}]" for i in self.geo_encoder.setup()] + self.trainable_modules.extend(new_modules) + + # Configure prediction head based on geo-encoder output_dim + self.prediction_head.set_dim( + input_dim=self.geo_encoder.output_dim, output_dim=self.num_classes + ) + self.prediction_head.setup() + if "prediction_head" not in self.trainable_modules: + self.trainable_modules.append("prediction_head") + @override def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: - feats = self.eo_encoder(batch) - feats = F.normalize(feats, dim=-1) + feats = self.geo_encoder(batch) + if self.normalize_features: + feats = F.normalize(feats, dim=-1) return self.prediction_head(feats) @override def _step(self, batch: Dict[str, torch.Tensor], mode: str = "train") -> torch.Tensor: - feats = self.forward(batch) + preds = self.forward(batch) log_kwargs = dict( - on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=feats.size(0) + on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=preds.size(0) ) - loss = self.loss_fn(feats, batch.get("target")) + loss = self.loss_fn(preds, batch.get("target")) self.log(f"{mode}_loss", loss, **log_kwargs) - metrics = self.metrics(pred=feats, batch=batch, mode=mode) + metrics = self.metrics(pred=preds, batch=batch, mode=mode) self.log_dict(metrics, **log_kwargs) + return loss + if __name__ == "__main__": _ = PredictiveModel(None, None, None, None, None, None, None) diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index 21b476a..49b8a91 100644 
--- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -1,12 +1,15 @@ +from io import text_encoding from typing import Dict, Tuple, override import torch import torch.nn.functional as F from src.models.base_model import BaseModel -from src.models.components.eo_encoders.base_eo_encoder import BaseEOEncoder -from src.models.components.eo_encoders.multimodal_encoder import MultiModalEncoder +from src.models.components.geo_encoders.base_geo_encoder import BaseGeoEncoder from src.models.components.loss_fns.base_loss_fn import BaseLossFn +from src.models.components.metrics.contrastive_validation import ( + RetrievalContrastiveValidation, +) from src.models.components.metrics.metrics_wrapper import MetricsWrapper from src.models.components.pred_heads.linear_pred_head import ( BasePredictionHead, @@ -19,20 +22,20 @@ class TextAlignmentModel(BaseModel): def __init__( self, - eo_encoder: BaseEOEncoder, + geo_encoder: BaseGeoEncoder, text_encoder: BaseTextEncoder, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler, loss_fn: BaseLossFn, trainable_modules: list[str], metrics: MetricsWrapper, - num_classes: int | None = None, - tabular_dim: int | None = None, prediction_head: BasePredictionHead | None = None, + ks: list[int] | None = [5, 10, 15], + match_to_geo: bool = True, ) -> None: """Implementation of contrastive text-eo modality alignment model. 
- :param eo_encoder: eo encoder module (replaceable) + :param geo_encoder: geo encoder module (replaceable) :param text_encoder: text encoder module (replaceable) :param optimizer: optimizer to use for training :param scheduler: scheduler to use for training @@ -42,44 +45,90 @@ def __init__( :param num_classes: number of target classes :param tabular_dim: number of tabular features :param prediction_head: prediction head + :param ks: list of ks + :param match_to_geo: whether to match dimensions of text encoder to geo_encoder or visa- + versa """ - super().__init__( - trainable_modules, optimizer, scheduler, loss_fn, metrics, num_classes, tabular_dim - ) + super().__init__(trainable_modules, optimizer, scheduler, loss_fn, metrics) - # Encoders configuration - self.eo_encoder = eo_encoder - # TODO: move to multi-modal eo encoder - if ( - isinstance(self.eo_encoder, MultiModalEncoder) - and self.eo_encoder.use_tabular - and not self.eo_encoder._tabular_ready - ): - self.eo_encoder.build_tabular_branch(tabular_dim) + # Metrics + self.ks = ks + self.log_kwargs = dict(on_step=False, on_epoch=True, prog_bar=True, sync_dist=True) + # Encoders configuration + self.geo_encoder = geo_encoder self.text_encoder = text_encoder - # TODO: if eo==geoclip_img pass on shared mlp - - # Extra projector for text encoder if eo and text dim not match - if self.eo_encoder.output_dim != self.text_encoder.output_dim: - self.text_encoder.add_projector(projected_dim=self.eo_encoder.output_dim) - self.trainable_modules.append("text_encoder.extra_projector") + self.match_to_geo = match_to_geo # Prediction head self.prediction_head = prediction_head + + @override + def setup(self, stage: str) -> None: + self.num_classes = self.trainer.datamodule.num_classes + self.tabular_dim = self.trainer.datamodule.tabular_dim + + # Set up encoders and missing adapters/projectors + print("-------Model------------") + self.setup_encoders_adapters() + print("------------------------") + + # Freeze requested 
parts + self.freezer() + + # Configure contrastive retrieval evaluation + self.setup_retrieval_evaluation() + + def setup_encoders_adapters(self): + """Set up encoders and missing adapters/projectors.""" + # We don't use tabular encoders for wrapping + # if ( + # isinstance(self.geo_encoder, MultiModalEncoder) + # and self.geo_encoder.use_tabular + # and not self.geo_encoder._tabular_ready + # ): + # self.geo_encoder.build_tabular_branch(self.tabular_dim) + + # Setup encoders that need data-depended configurations + new_modules = [f"geo_encoder.{i}" for i in self.geo_encoder.setup()] + self.trainable_modules.extend(new_modules) + + # Extra projector for text encoder if eo and text dim not match + if self.geo_encoder.output_dim != self.text_encoder.output_dim: + if self.match_to_geo: + self.text_encoder.add_projector(projected_dim=self.geo_encoder.output_dim) + self.trainable_modules.append("text_encoder.extra_projector") + else: + self.geo_encoder.add_projector(projected_dim=self.text_encoder.output_dim) + self.trainable_modules.append("geo_encoder.extra_projector") + + # Configure prediction head based on geo-encoder output_dim if self.prediction_head is not None: self.prediction_head.set_dim( - input_dim=self.eo_encoder.output_dim, output_dim=num_classes + input_dim=self.geo_encoder.output_dim, output_dim=self.num_classes ) - self.prediction_head.configure_nn() + self.prediction_head.setup() - # Unify dtypes - if self.eo_encoder.dtype != self.text_encoder.dtype: - self.eo_encoder = self.eo_encoder.to(self.text_encoder.dtype) - print(f"Eo encoder dtype changed to {self.eo_encoder.dtype}") + # # Unify dtypes -> moving to data part, rather than changing parameter type + # if self.geo_encoder.dtype != self.text_encoder.dtype: + # self.geo_encoder = self.geo_encoder.to(self.text_encoder.dtype) + # print(f"Geo encoder dtype changed to {self.geo_encoder.dtype}") - # Freezing requested parts - self.freezer() + def setup_retrieval_evaluation(self): + self.concept_configs 
= self.trainer.datamodule.concept_configs + self.concepts = [c["concept_caption"] for c in self.concept_configs] + + self.contrastive_val = RetrievalContrastiveValidation(self.ks, self.concept_configs) + self.outputs_epoch_memory = [] + + for trainable_module in self.trainable_modules: + if "text" in trainable_module: + self.concept_embeds = None + return + + # Encode concepts if text branch is frozen + with torch.no_grad(): + self.concept_embeds = self.text_encoder({"text": self.concepts}, mode="train") @override def forward( @@ -90,50 +139,118 @@ def forward( """Model forward logic.""" # Embed modalities - eo_feats = self.eo_encoder(batch) + geo_feats = self.geo_encoder(batch) text_feats = self.text_encoder(batch, mode) - return eo_feats, text_feats + + # Change dtype of geo data if it does not match text dtype + if geo_feats.dtype != text_feats.dtype: + geo_feats = geo_feats.to(text_feats.dtype) + return geo_feats, text_feats @override - def _step(self, batch: Dict[str, torch.Tensor], mode: str = "train") -> torch.Tensor: + def _step(self, batch: Dict[str, torch.Tensor], mode: str = "train"): """Model step logic.""" # Embed - eo_feats, text_feats = self.forward(batch, mode) - local_batch_size = eo_feats.size(0) + geo_feats, text_feats = self.forward(batch, mode) + local_batch_size = geo_feats.size(0) # batch recomposing in ddp if self.trainer.world_size > 1: - feats = torch.stack([eo_feats, text_feats], dim=0) + feats = torch.stack([geo_feats, text_feats], dim=0) feats = self.all_gather(feats) feats = feats.reshape(2, -1, feats.size(-1)) - eo_feats, text_feats = feats[0], feats[1] + geo_feats, text_feats = feats[0], feats[1] # Get loss - loss = self.loss_fn(eo_feats, text_feats) + loss = self.loss_fn(geo_feats, text_feats) # Get similarities with torch.no_grad(): metrics = self.metrics( mode=mode, - eo_feats=eo_feats, + geo_feats=geo_feats, text_feats=text_feats, local_batch_size=local_batch_size, ) # Logging - log_kwargs = dict( - on_step=False, - 
on_epoch=True, - prog_bar=True, - sync_dist=True, - batch_size=local_batch_size, - ) - self.log(f"{mode}_loss", loss, **log_kwargs) + self.log(f"{mode}_loss", loss, batch_size=local_batch_size, **self.log_kwargs) if self.loss_fn.__getattr__("log_temp") and mode == "train": - self.log("temp", self.loss_fn.__getattr__("log_temp").exp(), **log_kwargs) + self.log( + "temp", + self.loss_fn.__getattr__("log_temp").exp(), + batch_size=local_batch_size, + **self.log_kwargs, + ) + + self.log_dict(metrics, batch_size=local_batch_size, **self.log_kwargs) - self.log_dict(metrics, **log_kwargs) + if mode in ["val", "test"]: + self.outputs_epoch_memory.append( + { + "geo_feats": geo_feats.detach(), + "aux_vals": batch.get("aux", {}).get("aux").detach(), + } + ) return loss + + def _on_epoch_end(self, mode: str): + + # Combine batches + geo_feats = torch.cat([x["geo_feats"] for x in self.outputs_epoch_memory], dim=0) + + aux_vals = torch.cat([x["aux_vals"] for x in self.outputs_epoch_memory], dim=0) + + # Rank on similarity + similarity = self.concept_similarities(geo_feats) + + concept_scores = self.contrastive_val(similarity, aux_values=aux_vals) + # TODO pearson + + avr_scores = {f"{mode}_avr_top-{k}": [] for k in self.ks} + for i, result in concept_scores.items(): + print(f'\nConcept "{self.concepts[i]}" average top-k accuracies in {mode} split:') + for k, v in result.items(): + print(f"Top-{k}: {v:.1f}%") + avr_scores[f"{mode}_avr_top-{k}"].append(v) + + for k, v in avr_scores.items(): + avr_scores[k] = sum(v) / len(v) + + self.log_dict(avr_scores) + + # Reset memory + self.outputs_epoch_memory.clear() + + @override + def on_validation_epoch_end(self): + return self._on_epoch_end("val") + + @override + def on_test_epoch_end(self): + return self._on_epoch_end("test") + + def concept_similarities(self, geo_embeds, concept=None) -> torch.Tensor: + # Get concept embeddings + if concept is not None: + # If only one concept is provided + if isinstance(concept, str): + concept = 
[concept] + with torch.no_grad(): + concept_embeds = self.text_encoder({"text": concept}, mode="train") + + elif self.concept_embeds is not None: + concept_embeds = self.concept_embeds + else: + with torch.no_grad(): + concept_embeds = self.text_encoder({"text": self.concepts}, mode="train") + + # Similarity + geo_embeds = F.normalize(geo_embeds, dim=1) + concept_embeds = F.normalize(concept_embeds, dim=1) + similarity_matrix = concept_embeds @ geo_embeds.T + + return similarity_matrix diff --git a/src/train.py b/src/train.py index 34f4a99..8348fbb 100644 --- a/src/train.py +++ b/src/train.py @@ -51,9 +51,7 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: datamodule: BaseDataModule = hydra.utils.instantiate(cfg.data) log.info(f"Instantiating model <{cfg.model._target_}>") - model: LightningModule = hydra.utils.instantiate( - cfg.model, num_classes=datamodule.num_classes, tabular_dim=datamodule.tabular_dim - ) + model: LightningModule = hydra.utils.instantiate(cfg.model) log.info("Instantiating callbacks...") callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks")) diff --git a/tests/conftest.py b/tests/conftest.py index e4e9b2d..e4dcbf4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,12 +7,11 @@ import pandas as pd import pytest import rootutils -import torch from hydra import compose, initialize from hydra.core.global_hydra import GlobalHydra from omegaconf import DictConfig, open_dict -from src.data.base_caption_builder import BaseCaptionBuilder, DummyCaptionBuilder +from src.data.base_caption_builder import DummyCaptionBuilder from src.data.base_datamodule import BaseDataModule from src.data.butterfly_dataset import ButterflyDataset @@ -165,13 +164,28 @@ def create_butterfly_dataset(request, sample_csv, tmp_path): mock=use_mock, ) - templates_path = tmp_path / "caption_templates" / "v1.json" - os.makedirs(str(tmp_path / "caption_templates"), exist_ok=True) + templates_path = tmp_path / 
"location_caption_templates" / "v1.json" + os.makedirs(str(tmp_path / "location_caption_templates"), exist_ok=True) print(f"Mock captions written to {templates_path}") templates_path.write_text(json.dumps([" text"])) + concepts_path = tmp_path / "concept_captions" / "v1.json" + os.makedirs(str(tmp_path / "concept_captions"), exist_ok=True) + print(f"Concept captions written to {concepts_path}") + concepts_path.write_text( + json.dumps( + [{ + "concept_caption": "Forested area", + "is_max": True, + "theta_k": 0.5, + "col": "aux_corine_frac_311", + }] + ) + ) + caption_builder = DummyCaptionBuilder( templates_fname="v1.json", + concepts_fname="v1.json", data_dir=str(tmp_path), seed=0, ) diff --git a/tests/test_captions.py b/tests/test_captions.py index 66b132f..85951dd 100644 --- a/tests/test_captions.py +++ b/tests/test_captions.py @@ -1,10 +1,7 @@ import json import os -import pandas as pd -import pytest - -from src.data.base_caption_builder import BaseCaptionBuilder, DummyCaptionBuilder +from src.data.base_caption_builder import DummyCaptionBuilder from src.data.base_datamodule import BaseDataModule from src.data.butterfly_caption_builder import ButterflyCaptionBuilder from src.data.butterfly_dataset import ButterflyDataset @@ -12,12 +9,26 @@ def test_datamodule_uses_collate_when_aux_data(request, sample_csv, tmp_path): use_mock = request.config.getoption("--use-mock") - templates_path = tmp_path / "caption_templates" / "v1.json" - os.makedirs(str(tmp_path / "caption_templates"), exist_ok=True) + templates_path = tmp_path / "location_caption_templates" / "v1.json" + os.makedirs(str(tmp_path / "location_caption_templates"), exist_ok=True) print(f"Mock captions written to {templates_path}") templates_path.write_text(json.dumps([" text"])) - caption_builder = DummyCaptionBuilder("v1.json", data_dir=str(tmp_path), seed=0) + concepts_path = tmp_path / "concept_captions" / "v1.json" + os.makedirs(str(tmp_path / "concept_captions"), exist_ok=True) + print(f"Concept 
captions written to {concepts_path}") + concepts_path.write_text( + json.dumps( + [{ + "concept_caption": "Forested area", + "is_max": True, + "theta_k": 0.5, + "col": "aux_corine_frac_311", + }] + ) + ) + + caption_builder = DummyCaptionBuilder("v1.json", "v1.json", data_dir=str(tmp_path), seed=0) dataset = ButterflyDataset( data_dir=sample_csv, @@ -49,6 +60,7 @@ dict_caption_builders = {"butterfly": ButterflyCaptionBuilder, "dummy": DummyCaptionBuilder} templates_fname = "v1.json" + concepts_fname = "v1.json" for name_cb, cb_class in dict_caption_builders.items(): # There is no data on git anymore @@ -56,21 +68,38 @@ # repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) # templates_path = os.path.join(repo_root, "data", "s2bms") # else: - templates_path = tmp_path - templates_fpath = templates_path / "caption_templates" / templates_fname - os.makedirs(str(templates_path / "caption_templates"), exist_ok=True) - templates_fpath.write_text(json.dumps([" text"])) - print(f"written to {templates_path}") + templates_path = tmp_path / "location_caption_templates" / templates_fname + os.makedirs(str(tmp_path / "location_caption_templates"), exist_ok=True) + print(f"Mock captions written to {templates_path}") + templates_path.write_text(json.dumps([" text"])) + + concepts_path = tmp_path / "concept_captions" / concepts_fname + os.makedirs(str(tmp_path / "concept_captions"), exist_ok=True) + print(f"Concept captions written to {concepts_path}") + concepts_path.write_text( + json.dumps( + [{ + "concept_caption": "Forested area", + "is_max": True, + "theta_k": 0.5, + "col": "aux_corine_frac_311", + }] + ) + ) caption_builder = cb_class( templates_fname=templates_fname, - data_dir=templates_path, + concepts_fname=concepts_fname, + data_dir=tmp_path, seed=0, ) assert hasattr( caption_builder, "templates" ), f"'templates' attribute missing 
in {cb_class.__name__}." + assert hasattr( + caption_builder, "concepts" + ), f"'concepts' attribute missing in {cb_class.__name__}." assert hasattr( caption_builder, "data_dir" ), f"'data_dir' attribute missing in {cb_class.__name__}." diff --git a/tests/test_configs.py b/tests/test_configs.py index 9bb9c21..e34abf6 100644 --- a/tests/test_configs.py +++ b/tests/test_configs.py @@ -15,8 +15,7 @@ def test_train_config(cfg_train: DictConfig) -> None: HydraConfig().set_config(cfg_train) - datamodule = hydra.utils.instantiate(cfg_train.data) - hydra.utils.instantiate(cfg_train.model, num_classes=datamodule.num_classes) + hydra.utils.instantiate(cfg_train.model) hydra.utils.instantiate(cfg_train.trainer) diff --git a/tests/test_datasets_and_datamodules.py b/tests/test_datasets_and_datamodules.py index 2141987..20c94be 100644 --- a/tests/test_datasets_and_datamodules.py +++ b/tests/test_datasets_and_datamodules.py @@ -2,11 +2,12 @@ from src.data.butterfly_dataset import ButterflyDataset from src.data.heat_guatemala_dataset import HeatGuatemalaDataset from src.data.satbird_dataset import SatBirdDataset +from src.data.yield_africa_dataset import YieldAfricaDataset def test_datasets_generic_properties(request, tmp_path, sample_csv): """This test checks that all datasets implement the basic properties and methods.""" - list_datasets = [ButterflyDataset, SatBirdDataset, HeatGuatemalaDataset] + list_datasets = [ButterflyDataset, SatBirdDataset, HeatGuatemalaDataset, YieldAfricaDataset] use_mock = request.config.getoption("--use-mock") if use_mock: csv_dir = sample_csv diff --git a/tests/test_eo_encoders.py b/tests/test_eo_encoders.py deleted file mode 100644 index 635919d..0000000 --- a/tests/test_eo_encoders.py +++ /dev/null @@ -1,51 +0,0 @@ -import json -import os - -import pandas as pd -import pytest -import torch - -from src.models.components.eo_encoders.average_encoder import AverageEncoder -from src.models.components.eo_encoders.base_eo_encoder import BaseEOEncoder 
-from src.models.components.eo_encoders.cnn_encoder import CNNEncoder -from src.models.components.eo_encoders.geoclip import GeoClipCoordinateEncoder -from src.models.components.eo_encoders.multimodal_encoder import MultiModalEncoder - - -# @pytest.mark.slow -def test_eo_encoder_generic_properties(create_butterfly_dataset): - """This test checks that all EO encoders implement the basic properties and methods.""" - dict_eo_encoders = { - "geoclip_coords": GeoClipCoordinateEncoder, - "cnn": CNNEncoder, - "average": AverageEncoder, - "multimodal_coords": MultiModalEncoder, - } - ds, dm = create_butterfly_dataset - batch = next(iter(dm.train_dataloader())) - - for eo_encoder_name, eo_encoder_class in dict_eo_encoders.items(): - eo_encoder = eo_encoder_class() - - assert hasattr( - eo_encoder, "eo_encoder" - ), f"'eo_encoder' attribute missing in {eo_encoder_class.__name__}." - assert hasattr( - eo_encoder, "output_dim" - ), f"'output_dim' attribute missing in {eo_encoder_class.__name__}." - assert hasattr( - eo_encoder, "forward" - ), f"'forward' method missing in {eo_encoder_class.__name__}." - assert callable( - getattr(eo_encoder, "forward") - ), f"'forward' is not callable in {eo_encoder_class.__name__}." - - if eo_encoder_name == "geoclip_coords": - # TODO: try more EO encoders when (mock) test data also includes images. - feats = eo_encoder.forward(batch) - assert isinstance( - feats, torch.Tensor - ), f"'forward' method of {eo_encoder_class.__name__} does not return a torch.Tensor." - assert ( - feats.shape[0] == dm.batch_size_per_device - ), f"Output batch size mismatch in {eo_encoder_class.__name__}." 
diff --git a/tests/test_geo_encoders.py b/tests/test_geo_encoders.py new file mode 100644 index 0000000..57ff921 --- /dev/null +++ b/tests/test_geo_encoders.py @@ -0,0 +1,59 @@ +import json +import os + +import pandas as pd +import pytest +import torch + +from src.models.components.geo_encoders.average_encoder import AverageEncoder +from src.models.components.geo_encoders.cnn_encoder import CNNEncoder +from src.models.components.geo_encoders.geoclip import GeoClipCoordinateEncoder +from src.models.components.geo_encoders.mlp_projector import MLPProjector +from src.models.components.geo_encoders.tabular_encoder import TabularEncoder + + +# @pytest.mark.slow +def test_geo_encoder_generic_properties(create_butterfly_dataset): + """This test checks that all GEO encoders implement the basic properties and methods.""" + dict_geo_encoders = { + "geoclip_coords": GeoClipCoordinateEncoder, + "cnn": CNNEncoder, + "average": AverageEncoder, + "tabular": TabularEncoder, + "mlp_projector": MLPProjector, + } + ds, dm = create_butterfly_dataset + batch = next(iter(dm.train_dataloader())) + + for geo_encoder_name, geo_encoder_class in dict_geo_encoders.items(): + if geo_encoder_class is MLPProjector: + geo_encoder = geo_encoder_class(output_dim=64, input_dim=128) + elif geo_encoder_class is TabularEncoder: + geo_encoder = geo_encoder_class(output_dim=64, input_dim=128, hidden_dim=128) + else: + geo_encoder = geo_encoder_class() + + geo_encoder.setup() + + assert hasattr( + geo_encoder, "geo_encoder" + ), f"'geo_encoder' attribute missing in {geo_encoder_class.__name__}." + assert hasattr( + geo_encoder, "output_dim" + ), f"'output_dim' attribute missing in {geo_encoder_class.__name__}." + assert hasattr( + geo_encoder, "forward" + ), f"'forward' method missing in {geo_encoder_class.__name__}." + assert callable( + getattr(geo_encoder, "forward") + ), f"'forward' is not callable in {geo_encoder_class.__name__}." 
+ + if geo_encoder_name == "geoclip_coords": + # TODO: try more GEO encoders when (mock) test data also includes images. + feats = geo_encoder.forward(batch) + assert isinstance( + feats, torch.Tensor + ), f"'forward' method of {geo_encoder_class.__name__} does not return a torch.Tensor." + assert ( + feats.shape[0] == dm.batch_size_per_device + ), f"Output batch size mismatch in {geo_encoder_class.__name__}." diff --git a/tests/test_pred_heads.py b/tests/test_pred_heads.py index 09f8a9a..abc99b4 100644 --- a/tests/test_pred_heads.py +++ b/tests/test_pred_heads.py @@ -5,7 +5,7 @@ import pytest import torch -from src.models.components.eo_encoders.geoclip import GeoClipCoordinateEncoder +from src.models.components.geo_encoders.geoclip import GeoClipCoordinateEncoder from src.models.components.pred_heads.base_pred_head import BasePredictionHead from src.models.components.pred_heads.linear_pred_head import LinearPredictionHead from src.models.components.pred_heads.mlp_pred_head import MLPPredictionHead @@ -19,11 +19,13 @@ def test_pred_head_generic_properties(create_butterfly_dataset): ds, dm = create_butterfly_dataset batch = next(iter(dm.train_dataloader())) eo_encoder = GeoClipCoordinateEncoder() + eo_encoder.setup() feats = eo_encoder.forward(batch) list_pred_heads = [LinearPredictionHead, MLPPredictionHead, MLPRegressionPredictionHead] for pred_head_class in list_pred_heads: - pred_head = pred_head_class() + pred_head = pred_head_class(input_dim=64, output_dim=64) + pred_head.setup() assert hasattr( pred_head, "set_dim" ), f"'set_dim' method missing in {pred_head_class.__name__}." @@ -38,12 +40,12 @@ def test_pred_head_generic_properties(create_butterfly_dataset): pred_head, "output_dim" ), f"'output_dim' attribute missing in {pred_head_class.__name__}." assert hasattr( - pred_head, "configure_nn" - ), f"'configure_nn' method missing in {pred_head_class.__name__}." + pred_head, "setup" + ), f"'setup' method missing in {pred_head_class.__name__}." 
assert callable( - getattr(pred_head, "configure_nn") - ), f"'configure_nn' is not callable in {pred_head_class.__name__}." - pred_head.configure_nn() + getattr(pred_head, "setup") + ), f"'setup' is not callable in {pred_head_class.__name__}." + pred_head.setup() assert hasattr(pred_head, "net"), f"'net' attribute missing in {pred_head_class.__name__}." assert hasattr( pred_head, "forward" diff --git a/tests/test_yield_africa.py b/tests/test_yield_africa.py new file mode 100644 index 0000000..9104c1b --- /dev/null +++ b/tests/test_yield_africa.py @@ -0,0 +1,327 @@ +"""Tests for the yield_africa use case. + +The mock CSV mirrors the schema produced by make_model_ready_yield_africa.py: + - name_loc, lat, lon + - target_yld_ton_ha + - feat_* (continuous soil/climate/terrain features + tabular categorical soil texture) + - aux_* (derived classification columns, used for caption generation) + - metadata: country, year, location_accuracy +""" + +import hydra +import pandas as pd +import pytest +import torch +from hydra import compose, initialize +from hydra.core.global_hydra import GlobalHydra + +from src.data.base_datamodule import BaseDataModule +from src.data.yield_africa_dataset import YieldAfricaDataset + +# --------------------------------------------------------------------------- +# Representative column sets that match the real model_ready_yield-africa.csv +# --------------------------------------------------------------------------- + +MOCK_FEAT_COLS = { + # continuous soil features + "feat_c_0_20": [1.2, 0.9, 1.5, 1.1, 0.8, 1.4, 1.6, 1.0, 1.3, 1.1], + "feat_n_0_20": [0.12, 0.09, 0.15, 0.11, 0.08, 0.14, 0.16, 0.10, 0.13, 0.11], + "feat_ph_0_20": [6.1, 5.8, 6.5, 6.0, 5.5, 6.3, 6.8, 5.9, 6.2, 6.4], + # continuous climate features + "feat_map": [820, 750, 910, 860, 700, 880, 930, 770, 815, 875], + "feat_mat": [22.1, 21.5, 23.0, 22.5, 21.0, 22.8, 23.3, 21.9, 22.2, 22.7], + # continuous terrain feature + "feat_dem": [450, 380, 510, 470, 360, 490, 530, 400, 460, 
500], + # tabular categorical: soil texture class (real columns, not derived) + "feat_tx_0_20_cl": [2, 3, 1, 2, 4, 1, 3, 2, 1, 3], + "feat_tx_20_50_cl": [2, 2, 1, 3, 4, 1, 2, 3, 1, 2], +} + +MOCK_AUX_COLS = { + # derived classification columns (paired with the continuous feat_* above) + "aux_yld_ton_ha_cl": [1, 0, 2, 1, 0, 2, 2, 0, 1, 1], + "aux_c_0_20_cl": [1, 0, 2, 1, 0, 2, 2, 0, 1, 1], + "aux_ph_0_20_cl": [2, 1, 2, 2, 0, 2, 3, 1, 2, 2], + "aux_map_cl": [1, 0, 2, 1, 0, 2, 2, 0, 1, 2], +} + +MOCK_N_ROWS = 10 +# YieldAfricaDataset injects extra feat_* columns when country and year columns +# are present: feat_year (1) + feat_country_{code} (8) + Fourier harmonics (6). +from src.data.yield_africa_dataset import _ALL_COUNTRIES + +MOCK_INJECTED_FEAT_NAMES = ( + {"feat_year"} + | {f"feat_country_{c}" for c in _ALL_COUNTRIES} + | { + "feat_lat_sin1", + "feat_lat_cos1", + "feat_lat_sin2", + "feat_lat_cos2", + "feat_lon_sin1", + "feat_lon_cos1", + } +) +MOCK_TABULAR_DIM = len(MOCK_FEAT_COLS) + len(MOCK_INJECTED_FEAT_NAMES) # 8 + 15 = 23 +MOCK_N_AUX = len(MOCK_AUX_COLS) # 4 + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def yield_africa_csv(tmp_path) -> str: + """Mock CSV with column names matching the real model_ready_yield-africa.csv.""" + data = { + "name_loc": [f"ETH_{i:04d}" for i in range(MOCK_N_ROWS)], + "lat": [5.0 + i * 0.5 for i in range(MOCK_N_ROWS)], + "lon": [30.0 + i * 0.5 for i in range(MOCK_N_ROWS)], + "target_yld_ton_ha": [2.1, 1.8, 3.0, 2.5, 1.2, 2.8, 3.3, 1.9, 2.0, 2.7], + "country": ["ETH"] * MOCK_N_ROWS, + "year": [2019] * MOCK_N_ROWS, + "location_accuracy": [1] * MOCK_N_ROWS, + } + data.update(MOCK_FEAT_COLS) + data.update(MOCK_AUX_COLS) + + mock_dir = tmp_path / "mock" + mock_dir.mkdir(parents=True, exist_ok=True) + pd.DataFrame(data).to_csv(mock_dir / "model_ready_mock.csv", index=False) + return 
str(tmp_path) + + +@pytest.fixture +def yield_africa_dataset(yield_africa_csv, tmp_path): + """YieldAfricaDataset backed by mock data, features enabled, no aux.""" + return YieldAfricaDataset( + data_dir=yield_africa_csv, + cache_dir=str(tmp_path / "cache"), + modalities={"coords": {}}, + use_target_data=True, + use_aux_data="none", + seed=42, + mock=True, + use_features=True, + ) + + +@pytest.fixture +def yield_africa_dataset_with_aux(yield_africa_csv, tmp_path): + """YieldAfricaDataset backed by mock data, features enabled, aux enabled.""" + return YieldAfricaDataset( + data_dir=yield_africa_csv, + cache_dir=str(tmp_path / "cache"), + modalities={"coords": {}}, + use_target_data=True, + use_aux_data="all", + seed=42, + mock=True, + use_features=True, + ) + + +@pytest.fixture +def yield_africa_datamodule(yield_africa_csv, tmp_path): + """BaseDataModule wrapping YieldAfricaDataset with a random split.""" + dataset = YieldAfricaDataset( + data_dir=yield_africa_csv, + cache_dir=str(tmp_path / "cache"), + modalities={"coords": {}}, + use_target_data=True, + use_aux_data="none", + seed=42, + mock=True, + use_features=True, + ) + return BaseDataModule( + dataset=dataset, + batch_size=4, + train_val_test_split=(7, 2, 1), + num_workers=0, + pin_memory=False, + split_mode="random", + save_split=False, + seed=42, + ) + + +# --------------------------------------------------------------------------- +# Dataset tests +# --------------------------------------------------------------------------- + + +def test_yield_africa_dataset_length(yield_africa_dataset): + assert len(yield_africa_dataset) == MOCK_N_ROWS + + +def test_yield_africa_dataset_sample_keys(yield_africa_dataset): + sample = yield_africa_dataset[0] + assert "eo" in sample + assert "coords" in sample["eo"] + assert "tabular" in sample["eo"] + assert "target" in sample + + +def test_yield_africa_dataset_sample_shapes(yield_africa_dataset): + sample = yield_africa_dataset[0] + assert sample["eo"]["coords"].shape == 
(2,) + assert sample["eo"]["tabular"].shape == (MOCK_TABULAR_DIM,) + assert sample["target"].shape == (1,) + + +def test_yield_africa_dataset_sample_dtypes(yield_africa_dataset): + sample = yield_africa_dataset[0] + assert sample["eo"]["coords"].dtype == torch.float32 + assert sample["eo"]["tabular"].dtype == torch.float32 + assert sample["target"].dtype == torch.float32 + + +def test_yield_africa_dataset_target_name(yield_africa_dataset): + assert yield_africa_dataset.target_names == ["target_yld_ton_ha"] + + +def test_yield_africa_dataset_attributes(yield_africa_dataset): + assert yield_africa_dataset.num_classes == 1 + assert yield_africa_dataset.tabular_dim == MOCK_TABULAR_DIM + expected_feat_names = set(MOCK_FEAT_COLS.keys()) | MOCK_INJECTED_FEAT_NAMES + assert set(yield_africa_dataset.feat_names) == expected_feat_names + + +def test_yield_africa_dataset_feat_prefix(yield_africa_dataset): + """All tabular features must carry the feat prefix.""" + for name in yield_africa_dataset.feat_names: + assert name.startswith("feat_"), f"Unexpected feature name: {name}" + + +def test_yield_africa_dataset_coords_values(yield_africa_dataset): + """Coordinates returned must match the CSV values.""" + sample = yield_africa_dataset[0] + coords = sample["eo"]["coords"] + assert coords[0].item() == pytest.approx(5.0) # lat of row 0 + assert coords[1].item() == pytest.approx(30.0) # lon of row 0 + + +def test_yield_africa_dataset_target_values(yield_africa_dataset): + """Target values returned must match the CSV values.""" + expected = [2.1, 1.8, 3.0, 2.5, 1.2, 2.8, 3.3, 1.9, 2.0, 2.7] + for idx, exp in enumerate(expected): + sample = yield_africa_dataset[idx] + assert sample["target"][0].item() == pytest.approx(exp, rel=1e-5) + + +def test_yield_africa_dataset_no_features(yield_africa_csv, tmp_path): + """With use_features=False, tabular is absent and tabular_dim is None.""" + ds = YieldAfricaDataset( + data_dir=yield_africa_csv, + cache_dir=str(tmp_path / "cache"), + 
modalities={"coords": {}}, + use_target_data=True, + use_aux_data="none", + seed=0, + mock=True, + use_features=False, + ) + sample = ds[0] + assert "tabular" not in sample["eo"] + assert ds.tabular_dim is None + + +def test_yield_africa_dataset_aux_keys(yield_africa_dataset_with_aux): + """When aux is enabled, sample must contain an 'aux' dict.""" + sample = yield_africa_dataset_with_aux[0] + assert "aux" in sample + assert "aux" in sample["aux"] + + +def test_yield_africa_dataset_aux_columns(yield_africa_dataset_with_aux): + """Aux columns picked up must match the aux_* columns in the mock CSV.""" + resolved_aux = yield_africa_dataset_with_aux.use_aux_data["aux"] + assert set(resolved_aux) == set(MOCK_AUX_COLS.keys()) + + +def test_yield_africa_dataset_aux_shape(yield_africa_dataset_with_aux): + """Aux tensor shape must equal the number of aux_* columns.""" + sample = yield_africa_dataset_with_aux[0] + assert sample["aux"]["aux"].shape == (MOCK_N_AUX,) + + +# --------------------------------------------------------------------------- +# Datamodule tests +# --------------------------------------------------------------------------- + + +def test_yield_africa_datamodule_split_sizes(yield_africa_datamodule): + dm = yield_africa_datamodule + assert len(dm.data_train) == 7 + assert len(dm.data_val) == 2 + assert len(dm.data_test) == 1 + + +def test_yield_africa_datamodule_train_loader(yield_africa_datamodule): + dm = yield_africa_datamodule + dm.setup() + batch = next(iter(dm.train_dataloader())) + assert "eo" in batch + assert "coords" in batch["eo"] + assert "tabular" in batch["eo"] + assert batch["eo"]["coords"].shape == (4, 2) + assert batch["eo"]["tabular"].shape == (4, MOCK_TABULAR_DIM) + assert batch["target"].shape == (4, 1) + + +def test_yield_africa_datamodule_split_deterministic(yield_africa_csv, tmp_path): + def make_dm(): + dataset = YieldAfricaDataset( + data_dir=yield_africa_csv, + cache_dir=str(tmp_path / "cache"), + modalities={"coords": {}}, + 
use_target_data=True, + use_aux_data="none", + seed=42, + mock=True, + ) + return BaseDataModule( + dataset=dataset, + batch_size=4, + train_val_test_split=(7, 2, 1), + num_workers=0, + split_mode="random", + save_split=False, + seed=42, + ) + + dm1, dm2 = make_dm(), make_dm() + assert dm1.data_train.indices == dm2.data_train.indices + assert dm1.data_val.indices == dm2.data_val.indices + + +# --------------------------------------------------------------------------- +# Config tests +# --------------------------------------------------------------------------- + + +def test_yield_africa_config_loads(): + GlobalHydra.instance().clear() + with initialize(version_base="1.3", config_path="../configs"): + cfg = compose( + config_name="train.yaml", + overrides=["experiment=yield_africa_tabular_reg", "hydra.job.chdir=false"], + ) + assert cfg.data._target_ == "src.data.base_datamodule.BaseDataModule" + assert cfg.data.dataset._target_ == "src.data.yield_africa_dataset.YieldAfricaDataset" + assert cfg.model._target_ == "src.models.predictive_model.PredictiveModel" + GlobalHydra.instance().clear() + + +def test_yield_africa_model_instantiates(): + GlobalHydra.instance().clear() + with initialize(version_base="1.3", config_path="../configs"): + cfg = compose( + config_name="train.yaml", + overrides=["experiment=yield_africa_tabular_reg", "hydra.job.chdir=false"], + ) + model = hydra.utils.instantiate(cfg.model) + assert model is not None + GlobalHydra.instance().clear()