diff --git a/data/helios/mapbiomas/dense_config/model.yaml b/data/helios/mapbiomas/dense_config/model.yaml new file mode 100644 index 000000000..5551407fb --- /dev/null +++ b/data/helios/mapbiomas/dense_config/model.yaml @@ -0,0 +1,131 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth + init_args: + model_id: OLMOEARTH_V1_1_BASE + patch_size: 3 + decoders: + segment: + - class_path: rslearn.models.upsample.Upsample + init_args: + scale_factor: 3 + - class_path: rslearn.models.conv.Conv + init_args: + in_channels: 768 + out_channels: 28 + kernel_size: 1 + activation: + class_path: torch.nn.Identity + - class_path: rslearn.train.tasks.segmentation.SegmentationHead + lr: 0.0001 + scheduler: + class_path: rslearn.train.scheduler.PlateauScheduler + init_args: + factor: 0.2 + patience: 2 + min_lr: 0 + cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: /weka/dfive-default/rslearn-eai/datasets/mapbiomas_3k + inputs: + sentinel2_l2a: + data_type: "raster" + layers: ["sentinel2_l2a_feb", "sentinel2_l2a_may", "sentinel2_l2a_aug", "sentinel2_l2a_nov"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + load_all_layers: true + targets: + data_type: "raster" + layers: ["label_raster"] + bands: ["label"] + dtype: INT32 + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + segment: + class_path: rslearn.train.tasks.segmentation.SegmentationTask + init_args: + num_classes: 28 + class_id_mapping: + 3: 1 + 4: 2 + 5: 3 + 6: 4 + 9: 5 + 11: 6 + 12: 7 + 15: 8 + 20: 9 + 21: 10 + 23: 11 + 24: 12 + 25: 13 + 29: 14 + 30: 15 + 32: 16 + 33: 17 + 35: 18 + 39: 19 + 40: 20 + 41: 21 + 46: 22 + 47: 23 + 48: 24 + 49: 25 + 50: 26 + 62: 27 + nodata_value: 0 + enable_f1_metric: true + metric_kwargs: + average: "macro" + input_mapping: + segment: + targets: "targets" + batch_size: 32 + num_workers: 16 + default_config: + transforms: + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + train_config: + groups: ["mapbiomas_dense_raster"] + tags: + split: "train" + val_config: + groups: ["mapbiomas_dense_raster"] + tags: + split: "val" + test_config: + groups: ["mapbiomas_dense_raster"] + tags: + split: "val" +trainer: + max_epochs: 50 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.callbacks.checkpointing.ManagedBestLastCheckpoint + init_args: + monitor: val_segment/F1 + mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0] + unfreeze_at_epoch: 50 + unfreeze_lr_factor: 10 +management_dir: ${RSLP_PREFIX}/projects +project_name: 20260605_mapbiomas_dense +run_name: placeholder diff --git a/data/helios/mapbiomas/sparse_config/model.yaml b/data/helios/mapbiomas/sparse_config/model.yaml new file mode 100644 index 000000000..1296f5850 --- /dev/null +++ b/data/helios/mapbiomas/sparse_config/model.yaml @@ -0,0 +1,118 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth + init_args: + model_id: OLMOEARTH_V1_1_BASE + patch_size: 3 + decoders: + segment: + - class_path: rslearn.models.conv.Conv + init_args: + in_channels: 768 + out_channels: 12 + kernel_size: 1 + activation: + class_path: torch.nn.Identity + - class_path: rslearn.train.tasks.segmentation.SegmentationHead + lr: 0.0001 + scheduler: + class_path: rslearn.train.scheduler.PlateauScheduler + init_args: + factor: 0.2 + patience: 2 + min_lr: 0 + cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: /weka/dfive-default/rslearn-eai/datasets/mapbiomas_3k + inputs: + sentinel2_l2a: + data_type: "raster" + layers: ["sentinel2_l2a_feb", "sentinel2_l2a_may", "sentinel2_l2a_aug", "sentinel2_l2a_nov"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + load_all_layers: true + targets: + data_type: "raster" + layers: ["label_raster"] + bands: ["label"] + dtype: INT32 + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + segment: + class_path: rslearn.train.tasks.segmentation.SegmentationTask + init_args: + num_classes: 12 + class_id_mapping: + 3: 1 + 4: 2 + 9: 3 + 11: 4 + 12: 5 + 15: 6 + 19: 7 + 20: 8 + 24: 9 + 33: 10 + 36: 11 + nodata_value: 0 + enable_f1_metric: true + metric_kwargs: + average: "macro" + input_mapping: + segment: + targets: "targets" + batch_size: 32 + num_workers: 16 + default_config: + crop_size: 48 + transforms: + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + - class_path: rslearn.train.transforms.resize.Resize + init_args: + target_size: [16, 16] + selectors: ["target/segment/classes", "target/segment/valid"] + interpolation: "nearest" + train_config: + groups: ["mapbiomas_expert_sparse"] + tags: + split: "train" + val_config: + groups: ["mapbiomas_expert_sparse"] + tags: + split: "val" + test_config: + groups: ["mapbiomas_expert_sparse"] + tags: + split: "val" +trainer: + max_epochs: 50 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.callbacks.checkpointing.ManagedBestLastCheckpoint + init_args: + monitor: val_segment/F1 + mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0] + unfreeze_at_epoch: 50 + unfreeze_lr_factor: 10 +management_dir: ${RSLP_PREFIX}/projects +project_name: 20260605_mapbiomas +run_name: placeholder diff --git a/rslp/mapbiomas/create_windows_dense_raster.py b/rslp/mapbiomas/create_windows_dense_raster.py new file mode 100644 index 000000000..0ee7f19c3 --- /dev/null +++ b/rslp/mapbiomas/create_windows_dense_raster.py @@ -0,0 +1,258 @@ +"""Create rslearn windows with dense target labels from MapBiomas coverage rasters. + +Each window is target_size x target_size pixels at 10m resolution. A patch of +(target_size / 3) pixels is read from the corresponding year's MapBiomas raster +(~30m) and upscaled 3x via nearest-neighbor to fill the target. Omitted classes +are remapped to no-data (0); windows where all pixels are no-data after +remapping are skipped. + +Usage:: + + python -m rslp.mapbiomas.create_windows_dense_raster \ + --ds-name mapbiomas_3k \ + --omit-classes 31 0 +""" + +from __future__ import annotations + +import argparse +import multiprocessing +import os +from datetime import datetime, timezone +from pathlib import Path + +import numpy as np +import pandas as pd +import rasterio +import rasterio.windows +import shapely +import tqdm +from rslearn.const import WGS84_PROJECTION +from rslearn.dataset import Dataset, Window +from rslearn.utils import Projection, STGeometry, get_utm_ups_crs +from rslearn.utils.mp import star_imap_unordered +from rslearn.utils.raster_array import RasterArray +from rslearn.utils.raster_format import GeotiffRasterFormat +from upath import UPath + +from rslp.utils.windows import calculate_bounds + +MY_ROOT = Path(os.environ.get("MY_ROOT", ".")) + +WINDOW_RESOLUTION = 10 +LABEL_LAYER = "label_raster" +BAND_NAME = "label" +NODATA_VALUE = 0 +GROUP_NAME = "mapbiomas_dense_raster" +UPSCALE_FACTOR = 3 + +DEFAULT_CSV = ( + MY_ROOT / "rslearn_projects/rslp/mapbiomas/subsampling/sample_dense_raster_4k.csv" +) +DEFAULT_RASTER_DIR = MY_ROOT / "datasets/mapbiomas/data" +DEFAULT_DS_NAME = "mapbiomas_3k" + + +def create_window( + csv_row: pd.Series, + ds_path: UPath, + target_size: int, + omit_classes: set[int], + raster_dir: Path, +) -> str: + """Create one dense-label window from a MapBiomas coverage raster patch. + + Returns a status string: "created", "omit_class", or "omit_nodata". + """ + target_id = int(csv_row["TARGETID"]) + year = int(csv_row["YEAR"]) + split = csv_row["split"] + window_name = f"{target_id}_{year}" + + # LON/LAT columns are swapped in the source CSV + longitude = csv_row["LAT"] + latitude = csv_row["LON"] + + raster_path = raster_dir / f"brazil_coverage_{year}.tif" + if not raster_path.exists(): + raise FileNotFoundError(f"Raster not found: {raster_path}") + + raster_patch_size = target_size // UPSCALE_FACTOR + half_patch = raster_patch_size // 2 + + with rasterio.open(raster_path) as src: + r, c = src.index(longitude, latitude) + rio_window = rasterio.windows.Window( + col_off=c - half_patch, + row_off=r - half_patch, + width=raster_patch_size, + height=raster_patch_size, + ) + patch = src.read(1, window=rio_window, boundless=True, fill_value=NODATA_VALUE) + + # Remap omitted classes to no-data + if omit_classes: + for cls in omit_classes: + patch[patch == cls] = NODATA_VALUE + + if np.all(patch == NODATA_VALUE): + print( + f"Omitted window {window_name}: all pixels are no-data after class omission" + ) + return "omit_nodata" + + # Upscale via nearest-neighbor (repeat each pixel 3x3) + upscaled = np.repeat( + np.repeat(patch, UPSCALE_FACTOR, axis=0), UPSCALE_FACTOR, axis=1 + ) + + # Trim to exact target_size in case of rounding (e.g. target_size not divisible by 3) + upscaled = upscaled[:target_size, :target_size] + + # Project point to UTM and build rslearn window + src_point = shapely.Point(longitude, latitude) + src_geometry = STGeometry(WGS84_PROJECTION, src_point, None) + dst_crs = get_utm_ups_crs(longitude, latitude) + dst_projection = Projection(dst_crs, WINDOW_RESOLUTION, -WINDOW_RESOLUTION) + dst_geometry = src_geometry.to_projection(dst_projection) + bounds = calculate_bounds(dst_geometry, target_size) + + start_time = datetime(year, 1, 1, tzinfo=timezone.utc) + end_time = datetime(year, 12, 31, tzinfo=timezone.utc) + + dataset = Dataset(ds_path) + window = Window( + storage=dataset.storage, + group=GROUP_NAME, + name=window_name, + projection=dst_projection, + bounds=bounds, + time_range=(start_time, end_time), + options={"split": split}, + ) + window.save() + + raster = upscaled[np.newaxis, :, :].astype(np.uint8) + + raster_out_dir = window.get_raster_dir(LABEL_LAYER, [BAND_NAME]) + GeotiffRasterFormat().encode_raster( + raster_out_dir, + window.projection, + window.bounds, + RasterArray(chw_array=raster), + ) + window.mark_layer_completed(LABEL_LAYER) + return "created" + + +def create_windows_from_csv( + csv_path: UPath, + ds_name: str, + target_size: int, + omit_classes: set[int], + raster_dir: Path, + workers: int, +) -> None: + """Create dense-label windows for all rows in the subsample CSV.""" + rslearn_root = os.environ.get("RSLEARN_EAI_ROOT") + if not rslearn_root: + raise RuntimeError("RSLEARN_EAI_ROOT environment variable is not set") + ds_path = UPath(rslearn_root) / ds_name + + raster_patch_size = target_size // UPSCALE_FACTOR + + df = pd.read_csv(csv_path) + csv_rows = [row for _, row in df.iterrows()] + + print(f"Loaded {len(csv_rows)} rows from {csv_path}") + print(f"Dataset path: {ds_path}") + print(f"Target window size: {target_size}x{target_size} at 10m") + print( + f"Raster patch size: {raster_patch_size}x{raster_patch_size} at ~30m (upscale {UPSCALE_FACTOR}x)" + ) + if omit_classes: + print(f"Omitting classes: {sorted(omit_classes)}") + + jobs = [ + dict( + csv_row=row, + ds_path=ds_path, + target_size=target_size, + omit_classes=omit_classes, + raster_dir=raster_dir, + ) + for row in csv_rows + ] + + counts = {"created": 0, "omit_class": 0, "omit_nodata": 0} + p = multiprocessing.Pool(workers) + outputs = star_imap_unordered(p, create_window, jobs) + for status in tqdm.tqdm(outputs, total=len(jobs)): + counts[status] += 1 + p.close() + + total_omitted = counts["omit_class"] + counts["omit_nodata"] + print("\n" + "=" * 60) + print("WINDOW CREATION SUMMARY (dense raster)") + print("=" * 60) + print(f" Created: {counts['created']}") + print(f" Omitted (class): {counts['omit_class']}") + print(f" Omitted (all nodata): {counts['omit_nodata']}") + print(f" Total omitted: {total_omitted}") + print(f" Total processed: {len(jobs)}") + print("=" * 60) + + +if __name__ == "__main__": + multiprocessing.set_start_method("forkserver") + parser = argparse.ArgumentParser( + description="Create rslearn windows with dense raster labels for MapBiomas.", + ) + parser.add_argument( + "--csv-path", + type=str, + default=str(DEFAULT_CSV), + help="Path to the dense raster subsample CSV.", + ) + parser.add_argument( + "--ds-name", + type=str, + default=DEFAULT_DS_NAME, + help="Dataset name, appended to RSLEARN_EAI_ROOT (default: %(default)s).", + ) + parser.add_argument( + "--raster-dir", + type=str, + default=str(DEFAULT_RASTER_DIR), + help="Directory containing brazil_coverage_{year}.tif rasters.", + ) + parser.add_argument( + "--omit-classes", + type=int, + nargs="*", + default=[], + help="Class IDs to remap to no-data (windows left all-nodata are skipped).", + ) + parser.add_argument( + "--target-size", + type=int, + default=48, + help="Window size in 10m pixels (default: 48). " + "A patch of (target_size / 3) pixels is read from the ~30m raster.", + ) + parser.add_argument( + "--workers", + type=int, + default=16, + help="Number of multiprocessing workers (default: 32).", + ) + args = parser.parse_args() + + create_windows_from_csv( + csv_path=UPath(args.csv_path), + ds_name=args.ds_name, + target_size=args.target_size, + omit_classes=set(args.omit_classes), + raster_dir=Path(args.raster_dir), + workers=args.workers, + ) diff --git a/rslp/mapbiomas/create_windows_expert_sparse.py b/rslp/mapbiomas/create_windows_expert_sparse.py new file mode 100644 index 000000000..e9d8ceca4 --- /dev/null +++ b/rslp/mapbiomas/create_windows_expert_sparse.py @@ -0,0 +1,231 @@ +"""Create rslearn windows with sparse target labels from expert validation points. + +Each window is an *extended* window of size (2 * target_size - 3) pixels at 10m +resolution, centered on the expert point. The expert label is placed as a 3x3 +block at the window center (representing one ~30m validation pixel). All other +pixels are set to the no-data value (0). + +The extended sizing guarantees that any random target_size x target_size crop +from the window will always fully contain the 3x3 labeled block. + +Usage:: + + python -m rslp.mapbiomas.create_windows_expert_sparse \ + --ds-name mapbiomas_3k \ + --omit-classes 29 5 25 +""" + +from __future__ import annotations + +import argparse +import multiprocessing +import os +from datetime import datetime, timezone +from pathlib import Path + +import numpy as np +import pandas as pd +import shapely +import tqdm +from rslearn.const import WGS84_PROJECTION +from rslearn.dataset import Dataset, Window +from rslearn.utils import Projection, STGeometry, get_utm_ups_crs +from rslearn.utils.mp import star_imap_unordered +from rslearn.utils.raster_array import RasterArray +from rslearn.utils.raster_format import GeotiffRasterFormat +from upath import UPath + +from rslp.utils.windows import calculate_bounds + +MY_ROOT = Path(os.environ.get("MY_ROOT", ".")) + +WINDOW_RESOLUTION = 10 +LABEL_LAYER = "label_raster" +BAND_NAME = "label" +NODATA_VALUE = 0 +GROUP_NAME = "mapbiomas_expert_sparse" +LABEL_BLOCK_SIZE = 3 + +DEFAULT_CSV = ( + MY_ROOT / "rslearn_projects/rslp/mapbiomas/subsampling/sample_expert_points_4k.csv" +) +DEFAULT_DS_NAME = "mapbiomas_3k" + + +def compute_extended_window_size(target_size: int) -> int: + """Compute extended window size so every target_size crop contains the 3x3 center.""" + return 2 * target_size - LABEL_BLOCK_SIZE + + +def create_window( + csv_row: pd.Series, + ds_path: UPath, + window_size: int, + omit_classes: set[int], +) -> str: + """Create one sparse-label window from an expert validation point. + + Returns a status string: "created", "omit_class", or "omit_nodata". + """ + class_id = int(csv_row["CLASS"]) + target_id = int(csv_row["TARGETID"]) + year = int(csv_row["YEAR"]) + split = csv_row["split"] + window_name = f"{target_id}_{year}" + + if class_id in omit_classes: + print(f"Omitted window {window_name}: CLASS {class_id} in omit list") + return "omit_class" + + # LON/LAT columns are swapped in the source CSV + longitude = csv_row["LAT"] + latitude = csv_row["LON"] + + src_point = shapely.Point(longitude, latitude) + src_geometry = STGeometry(WGS84_PROJECTION, src_point, None) + dst_crs = get_utm_ups_crs(longitude, latitude) + dst_projection = Projection(dst_crs, WINDOW_RESOLUTION, -WINDOW_RESOLUTION) + dst_geometry = src_geometry.to_projection(dst_projection) + bounds = calculate_bounds(dst_geometry, window_size) + + start_time = datetime(year, 1, 1, tzinfo=timezone.utc) + end_time = datetime(year, 12, 31, tzinfo=timezone.utc) + + dataset = Dataset(ds_path) + window = Window( + storage=dataset.storage, + group=GROUP_NAME, + name=window_name, + projection=dst_projection, + bounds=bounds, + time_range=(start_time, end_time), + options={"split": split}, + ) + window.save() + + raster_h = bounds[3] - bounds[1] + raster_w = bounds[2] - bounds[0] + raster = np.full((1, raster_h, raster_w), NODATA_VALUE, dtype=np.uint8) + + cy, cx = raster_h // 2, raster_w // 2 + half = LABEL_BLOCK_SIZE // 2 + y_lo = max(cy - half, 0) + y_hi = min(cy + half + 1, raster_h) + x_lo = max(cx - half, 0) + x_hi = min(cx + half + 1, raster_w) + raster[0, y_lo:y_hi, x_lo:x_hi] = class_id + + raster_dir = window.get_raster_dir(LABEL_LAYER, [BAND_NAME]) + GeotiffRasterFormat().encode_raster( + raster_dir, + window.projection, + window.bounds, + RasterArray(chw_array=raster), + ) + window.mark_layer_completed(LABEL_LAYER) + return "created" + + +def create_windows_from_csv( + csv_path: UPath, + ds_name: str, + target_size: int, + omit_classes: set[int], + workers: int, +) -> None: + """Create sparse-label windows for all rows in the expert subsample CSV.""" + rslearn_root = os.environ.get("RSLEARN_EAI_ROOT") + if not rslearn_root: + raise RuntimeError("RSLEARN_EAI_ROOT environment variable is not set") + ds_path = UPath(rslearn_root) / ds_name + + window_size = compute_extended_window_size(target_size) + + df = pd.read_csv(csv_path) + csv_rows = [row for _, row in df.iterrows()] + + print(f"Loaded {len(csv_rows)} rows from {csv_path}") + print(f"Dataset path: {ds_path}") + print(f"Target crop size: {target_size}x{target_size}") + print( + f"Extended window size: {window_size}x{window_size} (2*{target_size} - {LABEL_BLOCK_SIZE})" + ) + if omit_classes: + print(f"Omitting classes: {sorted(omit_classes)}") + + jobs = [ + dict( + csv_row=row, + ds_path=ds_path, + window_size=window_size, + omit_classes=omit_classes, + ) + for row in csv_rows + ] + + counts = {"created": 0, "omit_class": 0, "omit_nodata": 0} + p = multiprocessing.Pool(workers) + outputs = star_imap_unordered(p, create_window, jobs) + for status in tqdm.tqdm(outputs, total=len(jobs)): + counts[status] += 1 + p.close() + + total_omitted = counts["omit_class"] + counts["omit_nodata"] + print("\n" + "=" * 60) + print("WINDOW CREATION SUMMARY (expert sparse)") + print("=" * 60) + print(f" Created: {counts['created']}") + print(f" Omitted (class): {counts['omit_class']}") + print(f" Omitted (all nodata): {counts['omit_nodata']}") + print(f" Total omitted: {total_omitted}") + print(f" Total processed: {len(jobs)}") + print("=" * 60) + + +if __name__ == "__main__": + multiprocessing.set_start_method("forkserver") + parser = argparse.ArgumentParser( + description="Create rslearn windows with sparse expert labels for MapBiomas.", + ) + parser.add_argument( + "--csv-path", + type=str, + default=str(DEFAULT_CSV), + help="Path to the expert subsample CSV.", + ) + parser.add_argument( + "--ds-name", + type=str, + default=DEFAULT_DS_NAME, + help="Dataset name, appended to RSLEARN_EAI_ROOT (default: %(default)s).", + ) + parser.add_argument( + "--omit-classes", + type=int, + nargs="*", + default=[], + help="Class IDs to treat as no-data (windows with only omitted classes are skipped).", + ) + parser.add_argument( + "--target-size", + type=int, + default=48, + help="Target crop size in 10m pixels (default: 48). " + "The actual window will be (2*target_size - 3) to ensure any " + "random crop always includes the 3x3 labeled center.", + ) + parser.add_argument( + "--workers", + type=int, + default=16, + help="Number of multiprocessing workers (default: 32).", + ) + args = parser.parse_args() + + create_windows_from_csv( + csv_path=UPath(args.csv_path), + ds_name=args.ds_name, + target_size=args.target_size, + omit_classes=set(args.omit_classes), + workers=args.workers, + ) diff --git a/rslp/mapbiomas/sanity_check.py b/rslp/mapbiomas/sanity_check.py new file mode 100644 index 000000000..a8a12ddd2 --- /dev/null +++ b/rslp/mapbiomas/sanity_check.py @@ -0,0 +1,691 @@ +"""Sanity-check the created rslearn label windows against reference data. + +Validates that: + A) Every window maps to a row in the subsample CSV. + B) TARGETID, CLASS (and LON/LAT for sparse) match the 85k validation shapefile. + C) (dense only) The 48x48 rslearn raster perfectly matches the 16x16 MapBiomas + raster patch upscaled 3x. + D) Produces a 5x5 visualization grid of randomly sampled windows. + +Usage:: + + # Sparse expert labels + python -m rslp.mapbiomas.sanity_check --mode sparse --ds-name mapbiomas_3k + + # Dense raster labels + python -m rslp.mapbiomas.sanity_check --mode dense --ds-name mapbiomas_3k +""" + +from __future__ import annotations + +import argparse +import os +import sys +from pathlib import Path + +import geopandas as gpd +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import rasterio +import rasterio.windows +from pyproj import Transformer + +MY_ROOT = Path(os.environ.get("MY_ROOT", ".")) + +DEFAULT_DS_NAME = "mapbiomas_3k" +DEFAULT_RASTER_DIR = MY_ROOT / "datasets/mapbiomas/data" +DEFAULT_SHP_PATH = ( + MY_ROOT / "datasets/mapbiomas/metadata/mapbiomas_85k_points_validation.shp" +) +DEFAULT_SPARSE_CSV = ( + MY_ROOT / "rslearn_projects/rslp/mapbiomas/subsampling/sample_expert_points_4k.csv" +) +DEFAULT_DENSE_CSV = ( + MY_ROOT / "rslearn_projects/rslp/mapbiomas/subsampling/sample_dense_raster_4k.csv" +) + +NODATA_VALUE = 0 +UPSCALE_FACTOR = 3 +LABEL_BLOCK_SIZE = 3 + +YEAR_RANGE = range(2016, 2023) + +MAPBIOMAS_COLORS = { + 0: "#ffffff", + 1: "#1f8d49", + 3: "#006400", + 4: "#00ff00", + 5: "#687537", + 6: "#76a5af", + 9: "#02d659", + 11: "#519799", + 12: "#d6bc74", + 13: "#d89f5c", + 15: "#edde8e", + 18: "#E974ED", + 19: "#C27BA0", + 20: "#db4d4f", + 21: "#ffefc3", + 23: "#d68fe2", + 24: "#d4271e", + 25: "#9c0027", + 26: "#2532e4", + 29: "#ffaa5f", + 30: "#9065d0", + 31: "#091077", + 32: "#fc8114", + 33: "#2532e4", + 35: "#9065d0", + 36: "#e787f8", + 39: "#f5b3c8", + 40: "#c71585", + 41: "#f54ca9", + 46: "#d68fe2", + 47: "#6b9c32", + 48: "#6b9c32", + 49: "#a1d99b", + 50: "#a1d99b", + 62: "#e6ccff", +} + + +def _class_cmap(data: np.ndarray) -> np.ndarray: + """Map class IDs to RGBA image for visualization.""" + from matplotlib.colors import to_rgba + + h, w = data.shape + rgba = np.ones((h, w, 4), dtype=np.float32) + rgba[:, :, :3] = 0.8 # default grey for unknown classes + + for cls_id, hex_color in MAPBIOMAS_COLORS.items(): + mask = data == cls_id + if not mask.any(): + continue + r, g, b, a = to_rgba(hex_color) + rgba[mask] = [r, g, b, a] + + # Make nodata transparent-ish (white with low alpha) + nodata_mask = data == NODATA_VALUE + rgba[nodata_mask] = [1.0, 1.0, 1.0, 0.3] + + return rgba + + +# --------------------------------------------------------------------------- +# Load reference data +# --------------------------------------------------------------------------- + + +def load_shapefile_reference(shp_path: Path) -> pd.DataFrame: + """Load the 85k validation shapefile and melt to long format. + + Returns DataFrame with columns: TARGETID, LON, LAT, YEAR, CLASS. + """ + gdf = gpd.read_file(shp_path) + static_cols = ["TARGETID", "LON", "LAT"] + frames: list[pd.DataFrame] = [] + for year in YEAR_RANGE: + cls_col = f"CLASS_{year}" + if cls_col not in gdf.columns: + continue + sub = gdf[static_cols + [cls_col]].copy() + sub.columns = [*static_cols, "CLASS"] + sub["YEAR"] = year + sub["CLASS"] = pd.to_numeric(sub["CLASS"], errors="coerce") + frames.append(sub) + long = pd.concat(frames, ignore_index=True) + long = long[long["CLASS"].notna()].copy() + long["CLASS"] = long["CLASS"].astype(int) + long["TARGETID"] = long["TARGETID"].astype(int) + return long + + +# --------------------------------------------------------------------------- +# Check A: every window has a corresponding CSV row +# --------------------------------------------------------------------------- + + +def check_a_csv_lookup( + window_names: list[str], csv_df: pd.DataFrame, mode: str +) -> dict[str, pd.Series]: + """Verify every window name maps to a CSV row. Returns {window_name: row}.""" + csv_df = csv_df.copy() + csv_df["TARGETID"] = csv_df["TARGETID"].astype(int) + csv_df["YEAR"] = csv_df["YEAR"].astype(int) + csv_df["_key"] = csv_df["TARGETID"].astype(str) + "_" + csv_df["YEAR"].astype(str) + csv_lookup = csv_df.set_index("_key") + + ok, fail = 0, 0 + result: dict[str, pd.Series] = {} + for wname in window_names: + if wname in csv_lookup.index: + result[wname] = csv_lookup.loc[wname] + ok += 1 + else: + print(f" [FAIL] Window {wname} not found in {mode} CSV") + fail += 1 + + print( + f" Check A ({mode}): {ok}/{ok + fail} windows found in CSV" + f" ({fail} missing)" + ) + return result + + +# --------------------------------------------------------------------------- +# Check B: cross-reference with shapefile +# --------------------------------------------------------------------------- + + +def check_b_shapefile_reference( + lookup: dict[str, pd.Series], + ref_df: pd.DataFrame, + mode: str, +) -> None: + """Verify TARGETID, CLASS, and (for sparse) LON/LAT match the shapefile.""" + ref_df = ref_df.copy() + ref_df["_key"] = ref_df["TARGETID"].astype(str) + "_" + ref_df["YEAR"].astype(str) + ref_lookup = ref_df.set_index("_key") + + ok, fail_missing, fail_class, fail_lonlat = 0, 0, 0, 0 + for wname, csv_row in lookup.items(): + if wname not in ref_lookup.index: + print(f" [FAIL] Window {wname}: TARGETID/YEAR not in shapefile") + fail_missing += 1 + continue + + ref_row = ref_lookup.loc[wname] + csv_class = int(csv_row["CLASS"]) + ref_class = int(ref_row["CLASS"]) + if csv_class != ref_class: + print( + f" [FAIL] Window {wname}: CLASS mismatch " + f"(csv={csv_class}, shp={ref_class})" + ) + fail_class += 1 + continue + + if mode == "sparse": + # LON/LAT columns are swapped in CSV relative to standard meaning, + # but the shapefile has the same swap, so compare directly. + csv_lon, csv_lat = float(csv_row["LON"]), float(csv_row["LAT"]) + ref_lon, ref_lat = float(ref_row["LON"]), float(ref_row["LAT"]) + if not ( + np.isclose(csv_lon, ref_lon, atol=1e-4) + and np.isclose(csv_lat, ref_lat, atol=1e-4) + ): + print( + f" [FAIL] Window {wname}: LON/LAT mismatch " + f"(csv=({csv_lon:.6f},{csv_lat:.6f}), " + f"shp=({ref_lon:.6f},{ref_lat:.6f}))" + ) + fail_lonlat += 1 + continue + + ok += 1 + + total = ok + fail_missing + fail_class + fail_lonlat + parts = [] + if fail_missing: + parts.append(f"{fail_missing} missing") + if fail_class: + parts.append(f"{fail_class} class mismatch") + if fail_lonlat: + parts.append(f"{fail_lonlat} lon/lat mismatch") + detail = f" ({', '.join(parts)})" if parts else "" + print(f" Check B ({mode}): {ok}/{total} passed{detail}") + + +# --------------------------------------------------------------------------- +# Check C (dense only): raster alignment +# --------------------------------------------------------------------------- + + +def check_c_dense_raster_alignment( + lookup: dict[str, pd.Series], + ds_path: Path, + raster_dir: Path, + target_size: int, +) -> list[str]: + """Verify the rslearn 48x48 label matches the MapBiomas 16x16 patch x3. + + Returns list of window names that failed the check. + """ + raster_patch_size = target_size // UPSCALE_FACTOR + half_patch = raster_patch_size // 2 + group_name = "mapbiomas_dense_raster" + + ok, fail_shape, fail_align = 0, 0, 0 + failed_windows: list[str] = [] + raster_cache: dict[int, rasterio.DatasetReader] = {} + + try: + for wname, csv_row in lookup.items(): + year = int(csv_row["YEAR"]) + # LON/LAT swapped in CSV + longitude = float(csv_row["LAT"]) + latitude = float(csv_row["LON"]) + + # Read reference patch from MapBiomas raster + if year not in raster_cache: + raster_path = raster_dir / f"brazil_coverage_{year}.tif" + raster_cache[year] = rasterio.open(raster_path) + src = raster_cache[year] + r, c = src.index(longitude, latitude) + rio_win = rasterio.windows.Window( + col_off=c - half_patch, + row_off=r - half_patch, + width=raster_patch_size, + height=raster_patch_size, + ) + ref_patch = src.read(1, window=rio_win, boundless=True, fill_value=0) + + # Read rslearn label raster + tif_path = ( + ds_path + / "windows" + / group_name + / wname + / "layers" + / "label_raster" + / "label" + / "geotiff.tif" + ) + if not tif_path.exists(): + print(f" [FAIL] Window {wname}: rslearn geotiff not found") + fail_shape += 1 + failed_windows.append(wname) + continue + + with rasterio.open(tif_path) as rsrc: + rslearn_raster = rsrc.read(1) + + if rslearn_raster.shape != (target_size, target_size): + print( + f" [FAIL] Window {wname}: rslearn raster shape " + f"{rslearn_raster.shape} != ({target_size},{target_size})" + ) + fail_shape += 1 + failed_windows.append(wname) + continue + + # Downsample rslearn raster back to raster_patch_size and compare + downsampled = rslearn_raster[::UPSCALE_FACTOR, ::UPSCALE_FACTOR] + if downsampled.shape != ref_patch.shape: + print( + f" [FAIL] Window {wname}: downsampled shape " + f"{downsampled.shape} != ref {ref_patch.shape}" + ) + fail_shape += 1 + failed_windows.append(wname) + continue + + # Only compare pixels where both sides have data. The rslearn + # raster may have nodata where omitted classes (e.g. 31) were + # zeroed out during window creation. + both_valid = (downsampled != NODATA_VALUE) & (ref_patch != NODATA_VALUE) + mismatch = (downsampled != ref_patch) & both_valid + if mismatch.any(): + n_diff = int(mismatch.sum()) + print( + f" [FAIL] Window {wname}: {n_diff}/{int(both_valid.sum())} " + f"valid pixels differ" + ) + fail_align += 1 + failed_windows.append(wname) + continue + + ok += 1 + finally: + for src in raster_cache.values(): + src.close() + + total = ok + fail_shape + fail_align + parts = [] + if fail_shape: + parts.append(f"{fail_shape} shape/missing") + if fail_align: + parts.append(f"{fail_align} pixel mismatch") + detail = f" ({', '.join(parts)})" if parts else "" + print(f" Check C (dense): {ok}/{total} passed{detail}") + return failed_windows + + +# --------------------------------------------------------------------------- +# Check D: visualization +# --------------------------------------------------------------------------- + + +def visualize_sparse( + lookup: dict[str, pd.Series], + ds_path: Path, + out_path: Path, + n: int = 25, + seed: int = 42, +) -> None: + """5x5 grid of sparse windows with labeled pixels highlighted.""" + group_name = "mapbiomas_expert_sparse" + rng = np.random.default_rng(seed) + keys = list(lookup.keys()) + chosen = rng.choice(keys, size=min(n, len(keys)), replace=False) + + rows, cols = 5, 5 + fig, axes = plt.subplots(rows, cols, figsize=(18, 18)) + fig.suptitle("Sparse expert label windows (labeled pixels colored)", fontsize=14) + + for idx, ax in enumerate(axes.flat): + if idx >= len(chosen): + ax.axis("off") + continue + + wname = chosen[idx] + csv_row = lookup[wname] + ref_class = int(csv_row["CLASS"]) + + tif_path = ( + ds_path + / "windows" + / group_name + / wname + / "layers" + / "label_raster" + / "label" + / "geotiff.tif" + ) + if not tif_path.exists(): + ax.set_title(f"{wname}\nMISSING", fontsize=7) + ax.axis("off") + continue + + with rasterio.open(tif_path) as src: + raster = src.read(1) + raster_crs = src.crs + raster_transform = src.transform + + rgba = _class_cmap(raster) + ax.imshow(rgba, interpolation="nearest") + + # Overlay the reference lat/lon as a cyan dot (projected to pixel coords) + ref_lon = float(csv_row["LAT"]) # LON/LAT swapped in CSV + ref_lat = float(csv_row["LON"]) + transformer = Transformer.from_crs("EPSG:4326", raster_crs, always_xy=True) + ref_x, ref_y = transformer.transform(ref_lon, ref_lat) + inv_transform = ~raster_transform + ref_col_px, ref_row_px = inv_transform * (ref_x, ref_y) + ax.plot( + ref_col_px, + ref_row_px, + "o", + color="cyan", + markersize=6, + markeredgecolor="black", + markeredgewidth=0.8, + zorder=5, + ) + + label_vals = np.unique(raster[raster != NODATA_VALUE]) + label_str = ",".join(str(v) for v in label_vals) + ax.set_title( + f"{wname}\nlabel={label_str} ref={ref_class}", + fontsize=7, + ) + ax.set_xticks([]) + ax.set_yticks([]) + + plt.tight_layout() + fig.savefig(out_path, dpi=150, bbox_inches="tight") + print(f" Saved sparse visualization to {out_path}") + plt.close(fig) + + +def visualize_dense( + lookup: dict[str, pd.Series], + ds_path: Path, + raster_dir: Path, + target_size: int, + out_path: Path, + n: int = 25, + seed: int = 42, + priority_windows: list[str] | None = None, +) -> None: + """5x5 grid: each cell shows rslearn 48x48 and reference 16x16 side by side. + + Windows in *priority_windows* (e.g. those that failed Check C) are shown + first; remaining slots are filled by random sampling. + """ + group_name = "mapbiomas_dense_raster" + raster_patch_size = target_size // UPSCALE_FACTOR + half_patch = raster_patch_size // 2 + rng = np.random.default_rng(seed) + keys = list(lookup.keys()) + + priority = [] + if priority_windows: + priority = [w for w in priority_windows if w in lookup][:n] + remaining_n = n - len(priority) + remaining_pool = [k for k in keys if k not in set(priority)] + random_pick = ( + list( + rng.choice( + remaining_pool, + size=min(remaining_n, len(remaining_pool)), + replace=False, + ) + ) + if remaining_n > 0 and remaining_pool + else [] + ) + chosen = priority + random_pick + + rows, cols = 5, 5 + fig, axes = plt.subplots(rows, cols * 2, figsize=(24, 18)) + fig.suptitle( + "Dense raster label windows (left: rslearn 48x48, right: reference 16x16)", + fontsize=14, + ) + + raster_cache: dict[int, rasterio.DatasetReader] = {} + + try: + for idx in range(rows * cols): + ax_left = axes[idx // cols, (idx % cols) * 2] + ax_right = axes[idx // cols, (idx % cols) * 2 + 1] + + if idx >= len(chosen): + ax_left.axis("off") + ax_right.axis("off") + continue + + wname = chosen[idx] + csv_row = lookup[wname] + year = int(csv_row["YEAR"]) + longitude = float(csv_row["LAT"]) + latitude = float(csv_row["LON"]) + + # rslearn raster + tif_path = ( + ds_path + / "windows" + / group_name + / wname + / "layers" + / "label_raster" + / "label" + / "geotiff.tif" + ) + if not tif_path.exists(): + ax_left.set_title(f"{wname}\nMISSING", fontsize=6) + ax_left.axis("off") + ax_right.axis("off") + continue + + with rasterio.open(tif_path) as src: + rslearn_raster = src.read(1) + + # Reference raster + if year not in raster_cache: + rp = raster_dir / f"brazil_coverage_{year}.tif" + raster_cache[year] = rasterio.open(rp) + src_ref = raster_cache[year] + r, c = src_ref.index(longitude, latitude) + rio_win = rasterio.windows.Window( + col_off=c - half_patch, + row_off=r - half_patch, + width=raster_patch_size, + height=raster_patch_size, + ) + ref_patch = src_ref.read(1, window=rio_win, boundless=True, fill_value=0) + + is_failed = priority_windows and wname in set(priority_windows) + title_prefix = "[FAIL C] " if is_failed else "" + title_color = "red" if is_failed else "black" + + ax_left.imshow(_class_cmap(rslearn_raster), interpolation="nearest") + ax_left.set_title( + f"{title_prefix}{wname}\nrslearn {rslearn_raster.shape}", + fontsize=6, + color=title_color, + ) + ax_left.set_xticks([]) + ax_left.set_yticks([]) + + ax_right.imshow(_class_cmap(ref_patch), interpolation="nearest") + ax_right.set_title(f"ref {ref_patch.shape}", fontsize=6) + ax_right.set_xticks([]) + ax_right.set_yticks([]) + finally: + for src in raster_cache.values(): + src.close() + + plt.tight_layout() + fig.savefig(out_path, dpi=150, bbox_inches="tight") + print(f" Saved dense visualization to {out_path}") + plt.close(fig) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main() -> None: + """Sanity-check MapBiomas rslearn label windows.""" + parser = argparse.ArgumentParser( + description="Sanity-check MapBiomas rslearn label windows.", + ) + parser.add_argument( + "--mode", + type=str, + required=True, + choices=["sparse", "dense"], + help="Which window set to check.", + ) + parser.add_argument( + "--ds-name", + type=str, + default=DEFAULT_DS_NAME, + help="Dataset name under RSLEARN_EAI_ROOT.", + ) + parser.add_argument( + "--shp-path", + type=str, + default=str(DEFAULT_SHP_PATH), + help="Path to the 85k validation shapefile.", + ) + parser.add_argument( + "--raster-dir", + type=str, + default=str(DEFAULT_RASTER_DIR), + help="Directory containing brazil_coverage_{year}.tif rasters (dense only).", + ) + parser.add_argument( + "--target-size", + type=int, + default=48, + help="Target crop size in 10m pixels (default: 48).", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for visualization sampling.", + ) + parser.add_argument( + "--out-dir", + type=str, + default=str(MY_ROOT / "rslearn_projects/rslp/mapbiomas"), + help="Directory for output plots.", + ) + args = parser.parse_args() + + rslearn_root = os.environ.get("RSLEARN_EAI_ROOT") + if not rslearn_root: + raise RuntimeError("RSLEARN_EAI_ROOT environment variable is not set") + ds_path = Path(rslearn_root) / args.ds_name + + mode = args.mode + raster_dir = Path(args.raster_dir) + shp_path = Path(args.shp_path) + out_dir = Path(args.out_dir) + target_size = args.target_size + + if mode == "sparse": + group_name = "mapbiomas_expert_sparse" + csv_path = DEFAULT_SPARSE_CSV + else: + group_name = "mapbiomas_dense_raster" + csv_path = DEFAULT_DENSE_CSV + + windows_dir = ds_path / "windows" / group_name + if not windows_dir.exists(): + print(f"ERROR: windows directory not found: {windows_dir}") + sys.exit(1) + + window_names = sorted([p.name for p in windows_dir.iterdir() if p.is_dir()]) + print(f"Found {len(window_names)} windows in {windows_dir}") + + # Load reference data + print(f"Loading subsample CSV: {csv_path}") + csv_df = pd.read_csv(csv_path) + print(f"Loading shapefile: {shp_path}") + ref_df = load_shapefile_reference(shp_path) + + # --- Check A --- + print("\n--- Check A: CSV lookup ---") + lookup = check_a_csv_lookup(window_names, csv_df, mode) + + # --- Check B --- + print("\n--- Check B: Shapefile cross-reference ---") + check_b_shapefile_reference(lookup, ref_df, mode) + + # --- Check C (dense only) --- + failed_c: list[str] = [] + if mode == "dense": + print("\n--- Check C: Dense raster alignment ---") + failed_c = check_c_dense_raster_alignment( + lookup, ds_path, raster_dir, target_size + ) + + # --- Check D: Visualization --- + print(f"\n--- Check D: Visualization ({mode}) ---") + out_dir.mkdir(parents=True, exist_ok=True) + plot_path = out_dir / f"sanity_check_{mode}.png" + if mode == "sparse": + visualize_sparse(lookup, ds_path, plot_path, seed=args.seed) + else: + visualize_dense( + lookup, + ds_path, + raster_dir, + target_size, + plot_path, + seed=args.seed, + priority_windows=failed_c, + ) + + print("\nDone.") + + +if __name__ == "__main__": + main() diff --git a/rslp/mapbiomas/subsampling/__init__.py b/rslp/mapbiomas/subsampling/__init__.py new file mode 100644 index 000000000..be217b121 --- /dev/null +++ b/rslp/mapbiomas/subsampling/__init__.py @@ -0,0 +1 @@ +"""MapBiomas balanced subsampling utilities.""" diff --git a/rslp/mapbiomas/subsampling/sample_dense_raster.py b/rslp/mapbiomas/subsampling/sample_dense_raster.py new file mode 100644 index 000000000..7ffa3bcd6 --- /dev/null +++ b/rslp/mapbiomas/subsampling/sample_dense_raster.py @@ -0,0 +1,563 @@ +"""Sample dense raster windows around expert-point locations. + +For each point in the subsample CSV produced by sample_expert_points.py, +reads a search-size patch from the matching year's MapBiomas raster, samples +n-candidates random window-size sub-windows, scores each by weighted +minority-class pixel count, and keeps the best one. Outputs an updated +subsample CSV whose LON/LAT columns reflect the selected window centre, +plus per-class/per-split statistics. + +Setting --search-size equal to --window-size with --n-candidates 1 reproduces +a simple centred-window extraction (no optimisation). +""" + +from __future__ import annotations + +import argparse +import os +from pathlib import Path + +import numpy as np +import pandas as pd +import rasterio +from rasterio.windows import Window + +MY_ROOT = Path(os.environ.get("MY_ROOT", ".")) + +DEFAULT_CSV = ( + MY_ROOT / "rslearn_projects/rslp/mapbiomas/subsampling/sample_expert_points_4k.csv" +) +DEFAULT_RASTER_DIR = MY_ROOT / "datasets/mapbiomas/data" +DEFAULT_OUT_DIR = MY_ROOT / "rslearn_projects/rslp/mapbiomas/subsampling" +DEFAULT_LEGEND = ( + MY_ROOT / "datasets/mapbiomas/metadata/Codigos-da-legenda-colecao-10.csv" +) +DEFAULT_HIERARCHY = MY_ROOT / "datasets/mapbiomas/metadata/hierarchy.csv" + +MINORITY_CLASSES: set[int] = { + 6, + 41, + 5, + 29, + 46, + 48, + 40, + 47, + 25, + 50, + 49, + 35, + 32, + 23, + 30, + 62, +} + + +# --------------------------------------------------------------------------- +# Legend / summary utilities +# --------------------------------------------------------------------------- + + +def load_legend(path: Path) -> dict[int, str]: + """Return {class_id: english_description} from the MapBiomas legend CSV.""" + legend = pd.read_csv(path, sep="\t") + legend["Description"] = legend["Description"].str.strip() + return dict(zip(legend["Class_ID"].astype(int), legend["Description"])) + + +def load_hierarchy(path: Path) -> dict[int, dict]: + """Parse the hierarchy CSV and return per-class hierarchy info. + + Returns ``{class_id: {"leaf_level": int, + "parent_class_id": int | None, + "parent_leaf_level": int | None, + "parent_class_desc": str | None}}``. + """ + h = pd.read_csv(path) + result: dict[int, dict] = {} + for cid, grp in h.groupby("Class_ID"): + cid = int(cid) + leaf_level = int(grp["Leaf_Level"].iloc[0]) + + if leaf_level > 1: + parent_row = grp[grp["Hierarchy_Level"] == leaf_level - 1].iloc[0] + parent_class_id = int(parent_row["Level_Class_ID"]) + parent_leaf_level = leaf_level - 1 + parent_class_desc = str(parent_row["Level_Description"]) + else: + parent_class_id = None + parent_leaf_level = None + parent_class_desc = None + + result[cid] = { + "leaf_level": leaf_level, + "parent_class_id": parent_class_id, + "parent_leaf_level": parent_leaf_level, + "parent_class_desc": parent_class_desc, + } + return result + + +def build_summary( + per_window: pd.DataFrame, + legend: dict[int, str], + window_size: int, + hierarchy: dict[int, dict] | None = None, +) -> pd.DataFrame: + """Aggregate per-window counts into a global summary with train/val breakdown.""" + class_cols = [c for c in per_window.columns if c.startswith("class_")] + class_ids = [int(c.split("_")[1]) for c in class_cols] + + pixels_per_window = window_size * window_size + + rows: list[dict] = [] + for cid, col in zip(class_ids, class_cols): + vals_all = per_window[col].values + total_all = int(vals_all.sum()) + + rec: dict = { + "class_id": cid, + "class_name": legend.get(cid, "unknown"), + "leaf_level": None, + "parent_class_id": None, + "parent_leaf_level": None, + "parent_class_desc": None, + "total_pixels": total_all, + "frac_of_all": total_all / (len(per_window) * pixels_per_window), + "mean_per_window": float(vals_all.mean()), + "std_per_window": float(vals_all.std()), + } + + if hierarchy and cid in hierarchy: + hi = hierarchy[cid] + rec["leaf_level"] = hi["leaf_level"] + rec["parent_class_id"] = hi["parent_class_id"] + rec["parent_leaf_level"] = hi["parent_leaf_level"] + rec["parent_class_desc"] = hi["parent_class_desc"] + + for split in ("train", "val"): + mask = per_window["split"] == split + vals = per_window.loc[mask, col].values + n_windows = int(mask.sum()) + total = int(vals.sum()) + rec[f"{split}_total_pixels"] = total + rec[f"{split}_frac"] = ( + total / (n_windows * pixels_per_window) if n_windows else 0.0 + ) + rec[f"{split}_mean_per_window"] = float(vals.mean()) if n_windows else 0.0 + rec[f"{split}_std_per_window"] = float(vals.std()) if n_windows else 0.0 + + rows.append(rec) + + summary = ( + pd.DataFrame(rows) + .sort_values("total_pixels", ascending=False) + .reset_index(drop=True) + ) + return summary + + +def print_summary( + summary: pd.DataFrame, per_window: pd.DataFrame, window_size: int +) -> None: + """Print a human-readable summary to stdout.""" + pixels_per_window = window_size * window_size + n_total = len(per_window) + n_train = (per_window["split"] == "train").sum() + n_val = (per_window["split"] == "val").sum() + + print("\n" + "=" * 140) + print("WINDOW CLASS STATISTICS SUMMARY") + print( + f" Window size: {window_size}x{window_size} ({pixels_per_window} pixels/window)" + ) + print(f" Samples: {n_total} total ({n_train} train, {n_val} val)") + print(f" Total pixels across all windows: {n_total * pixels_per_window:,}") + print("=" * 140) + + has_hierarchy = ( + "leaf_level" in summary.columns and summary["leaf_level"].notna().any() + ) + + if has_hierarchy: + fmt = ( + "{:<6s} {:<35s} {:<5s} {:<6s} {:<25s}" + " {:>12s} {:>8s} {:>12s} {:>8s} {:>12s} {:>8s}" + ) + print( + fmt.format( + "ID", + "Class", + "Lvl", + "ParID", + "Parent", + "All pixels", + "All %", + "Train px", + "Train %", + "Val px", + "Val %", + ) + ) + else: + fmt = "{:<6s} {:<35s} {:>12s} {:>8s} {:>12s} {:>8s} {:>12s} {:>8s}" + print( + fmt.format( + "ID", + "Class", + "All pixels", + "All %", + "Train px", + "Train %", + "Val px", + "Val %", + ) + ) + print("-" * 140) + + for _, r in summary.iterrows(): + if has_hierarchy: + lvl = str(int(r["leaf_level"])) if pd.notna(r["leaf_level"]) else "" + par_id = ( + str(int(r["parent_class_id"])) + if pd.notna(r["parent_class_id"]) + else "-" + ) + par_desc = ( + str(r["parent_class_desc"])[:25] + if pd.notna(r["parent_class_desc"]) + else "-" + ) + print( + fmt.format( + str(int(r["class_id"])), + r["class_name"][:35], + lvl, + par_id, + par_desc, + f"{int(r['total_pixels']):,}", + f"{r['frac_of_all']:.3%}", + f"{int(r['train_total_pixels']):,}", + f"{r['train_frac']:.3%}", + f"{int(r['val_total_pixels']):,}", + f"{r['val_frac']:.3%}", + ) + ) + else: + print( + ( + "{:<6s} {:<35s} {:>12s} {:>8s} {:>12s} {:>8s} {:>12s} {:>8s}" + ).format( + str(int(r["class_id"])), + r["class_name"][:35], + f"{int(r['total_pixels']):,}", + f"{r['frac_of_all']:.3%}", + f"{int(r['train_total_pixels']):,}", + f"{r['train_frac']:.3%}", + f"{int(r['val_total_pixels']):,}", + f"{r['val_frac']:.3%}", + ) + ) + print("=" * 140 + "\n") + + +# Weights = 1 / observed_fraction from the baseline window_class_stats run. +MINORITY_WEIGHTS: dict[int, float] = { + 6: 1.0 / 0.01899, + 41: 1.0 / 0.01792, + 5: 1.0 / 0.00772, + 29: 1.0 / 0.00634, + 46: 1.0 / 0.00583, + 48: 1.0 / 0.00474, + 40: 1.0 / 0.00246, + 47: 1.0 / 0.00235, + 25: 1.0 / 0.00183, + 50: 1.0 / 0.00081, + 49: 1.0 / 0.00066, + 35: 1.0 / 0.00059, + 32: 1.0 / 0.00045, + 23: 1.0 / 0.00041, + 30: 1.0 / 0.00034, + 62: 1.0 / 0.00033, +} + +# Pre-build a uint8-indexed lookup table for fast vectorised scoring. +_WEIGHT_LUT = np.zeros(256, dtype=np.float64) +for _cls, _w in MINORITY_WEIGHTS.items(): + _WEIGHT_LUT[_cls] = _w + + +def _score_subwindow(patch: np.ndarray, row_off: int, col_off: int, size: int) -> float: + """Sum minority weights for all pixels in the sub-window.""" + sub = patch[row_off : row_off + size, col_off : col_off + size] + return float(_WEIGHT_LUT[sub].sum()) + + +def optimize_windows( + df: pd.DataFrame, + raster_dir: Path, + window_size: int, + search_size: int, + n_candidates: int, + rng: np.random.Generator, +) -> tuple[pd.DataFrame, pd.DataFrame]: + """Find the best 32x32 sub-window per sample and return updated CSV + per-window counts. + + Returns: + ------- + optimized_df : pd.DataFrame + Same schema as the input subsample CSV but with LON/LAT shifted to the + selected window centre. + per_window : pd.DataFrame + Per-window class counts (same format as window_class_stats output). + """ + half_search = search_size // 2 + half_win = window_size // 2 + max_offset = ( + search_size - window_size + ) # valid sub-window origin range [0, max_offset] + + updated_rows: list[dict] = [] + count_records: list[dict] = [] + + for year, group in df.groupby("YEAR"): + raster_path = raster_dir / f"brazil_coverage_{year}.tif" + if not raster_path.exists(): + raise FileNotFoundError(f"Raster not found: {raster_path}") + + print(f" Year {year}: {len(group)} samples from {raster_path.name}") + + with rasterio.open(raster_path) as src: + for _, row in group.iterrows(): + actual_lon = row["LAT"] + actual_lat = row["LON"] + r, c = src.index(actual_lon, actual_lat) + + big_win = Window( + col_off=c - half_search, + row_off=r - half_search, + width=search_size, + height=search_size, + ) + patch = src.read(1, window=big_win, boundless=True, fill_value=0) + + offsets = np.column_stack( + [ + rng.integers(0, max_offset + 1, size=n_candidates), + rng.integers(0, max_offset + 1, size=n_candidates), + ] + ) + + best_score = -1.0 + best_idx = 0 + for i in range(n_candidates): + ro, co = int(offsets[i, 0]), int(offsets[i, 1]) + score = _score_subwindow(patch, ro, co, window_size) + if score > best_score: + best_score = score + best_idx = i + + best_ro = int(offsets[best_idx, 0]) + best_co = int(offsets[best_idx, 1]) + + # Raster row/col of the selected sub-window centre + new_r = (r - half_search) + best_ro + half_win + new_c = (c - half_search) + best_co + half_win + new_lon, new_lat = src.xy(new_r, new_c) + + # Build updated CSV row (LON/LAT columns are swapped in source) + updated_rows.append( + { + "TARGETID": int(row["TARGETID"]), + "LON": new_lat, + "LAT": new_lon, + "YEAR": int(row["YEAR"]), + "CLASS": int(row["CLASS"]), + "BORDA": int(row["BORDA"]), + "COUNT": int(row["COUNT"]), + "CARTA_2": row["CARTA_2"], + "DECLIVIDAD": row["DECLIVIDAD"], + "split": row["split"], + } + ) + + # Per-window class counts for the selected sub-window + sub = patch[ + best_ro : best_ro + window_size, best_co : best_co + window_size + ] + codes, counts = np.unique(sub, return_counts=True) + rec: dict = { + "TARGETID": int(row["TARGETID"]), + "YEAR": int(row["YEAR"]), + "center_class": int(row["CLASS"]), + "split": row["split"], + } + for code, cnt in zip(codes.tolist(), counts.tolist()): + rec[f"class_{code}"] = cnt + count_records.append(rec) + + optimized_df = pd.DataFrame(updated_rows) + out_cols = [ + "TARGETID", + "LON", + "LAT", + "YEAR", + "CLASS", + "BORDA", + "COUNT", + "CARTA_2", + "DECLIVIDAD", + "split", + ] + optimized_df = ( + optimized_df[out_cols] + .sort_values( + ["split", "CLASS", "YEAR"], + ) + .reset_index(drop=True) + ) + + per_window = pd.DataFrame(count_records).fillna(0) + class_cols = sorted( + [c for c in per_window.columns if c.startswith("class_")], + key=lambda c: int(c.split("_")[1]), + ) + meta_cols = ["TARGETID", "YEAR", "center_class", "split"] + per_window = per_window[meta_cols + class_cols] + for col in class_cols: + per_window[col] = per_window[col].astype(int) + + return optimized_df, per_window + + +def main() -> None: + """Select minority-boosting 32x32 windows for each subsample point.""" + parser = argparse.ArgumentParser( + description="Select minority-boosting 32x32 windows for each subsample point.", + ) + parser.add_argument( + "--csv-path", + type=str, + default=str(DEFAULT_CSV), + help="Path to the subsample CSV (output of sample_expert_points.py).", + ) + parser.add_argument( + "--raster-dir", + type=str, + default=str(DEFAULT_RASTER_DIR), + help="Directory containing brazil_coverage_{year}.tif rasters.", + ) + parser.add_argument( + "--window-size", + type=int, + default=16, + help="Side length of the target window in pixels (default: 16).", + ) + parser.add_argument( + "--search-size", + type=int, + default=512, + help="Side length of the search neighbourhood in pixels (default: 512).", + ) + parser.add_argument( + "--n-candidates", + type=int, + default=64, + help="Number of random sub-window candidates to evaluate (default: 64).", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for reproducibility.", + ) + parser.add_argument( + "--out-dir", + type=str, + default=str(DEFAULT_OUT_DIR), + help="Directory for output files.", + ) + parser.add_argument( + "--legend-path", + type=str, + default=str(DEFAULT_LEGEND), + help="Path to the MapBiomas legend CSV (tab-separated).", + ) + parser.add_argument( + "--hierarchy-path", + type=str, + default=str(DEFAULT_HIERARCHY), + help="Path to the MapBiomas hierarchy CSV.", + ) + args = parser.parse_args() + + csv_path = Path(args.csv_path) + raster_dir = Path(args.raster_dir) + out_dir = Path(args.out_dir) + legend_path = Path(args.legend_path) + hierarchy_path = Path(args.hierarchy_path) + window_size = args.window_size + search_size = args.search_size + + legend = load_legend(legend_path) + hierarchy = load_hierarchy(hierarchy_path) if hierarchy_path.exists() else None + minority_names = { + cid: legend.get(cid, str(cid)) for cid in sorted(MINORITY_CLASSES) + } + + print("=" * 70) + print("MINORITY-BOOSTING WINDOW SELECTION") + print("=" * 70) + print( + f" Search window: {search_size}x{search_size} pixels centred on each point" + ) + print(f" Target window: {window_size}x{window_size} pixels") + print(f" Candidates: {args.n_candidates} random sub-windows per point") + print(f" Seed: {args.seed}") + print(f" Minority classes ({len(MINORITY_CLASSES)}):") + for cid in sorted(MINORITY_CLASSES): + print( + f" {cid:>3d} {minority_names[cid]:<35s} weight={MINORITY_WEIGHTS[cid]:,.1f}" + ) + print("\n Scoring: for each candidate sub-window, sum the weight of every") + print(" minority-class pixel. Keep the sub-window with the highest score.") + print("=" * 70) + + print(f"\nLoading subsample CSV: {csv_path}") + df = pd.read_csv(csv_path) + print( + f" {len(df)} samples ({(df['split'] == 'train').sum()} train, " + f"{(df['split'] == 'val').sum()} val)" + ) + + rng = np.random.default_rng(args.seed) + + print( + f"\nSearching {search_size}x{search_size} neighbourhood, " + f"{args.n_candidates} candidates per point …" + ) + optimized_df, per_window = optimize_windows( + df, + raster_dir, + window_size, + search_size, + args.n_candidates, + rng, + ) + + out_dir.mkdir(parents=True, exist_ok=True) + csv_out = out_dir / "sample_dense_raster_4k.csv" + optimized_df.to_csv(csv_out, index=False) + print(f"\nWrote optimized subsample: {csv_out} ({len(optimized_df)} rows)") + + summary = build_summary(per_window, legend, window_size, hierarchy) + summary_path = out_dir / "sample_dense_raster_window_stats_summary.csv" + summary.to_csv(summary_path, index=False) + print(f"Wrote summary stats: {summary_path} ({len(summary)} rows)") + + print_summary(summary, per_window, window_size) + + +if __name__ == "__main__": + main() diff --git a/rslp/mapbiomas/subsampling/sample_expert_points.py b/rslp/mapbiomas/subsampling/sample_expert_points.py new file mode 100644 index 000000000..d94005cd8 --- /dev/null +++ b/rslp/mapbiomas/subsampling/sample_expert_points.py @@ -0,0 +1,631 @@ +"""Build a balanced 4k subsample (3k train + 1k val) from MapBiomas validation points. + +Criteria: +- Balanced class representation across as many classes as possible (water-filling). +- Uniform sampling across CARTA_2 tiles. +- Best-effort DECLIVIDAD representation and year uniformity (2016-2022). +- Only best data points (COUNT_year == 3). +- 25% edge pixels (BORDA_year == 1), 75% interior (BORDA_year == 0) best-effort. +- Each spatial pixel (TARGETID) used at most once across all years. +- Excluded classes: 13, 23, 27, 30, 31, 32, 50. +""" + +from __future__ import annotations + +import argparse +import os +from collections import Counter +from pathlib import Path + +import geopandas as gpd +import numpy as np +import pandas as pd + +MY_ROOT = Path(os.environ.get("MY_ROOT", ".")) + +DEFAULT_LEGEND = ( + MY_ROOT / "datasets/mapbiomas/metadata/Codigos-da-legenda-colecao-10.csv" +) +DEFAULT_HIERARCHY = MY_ROOT / "datasets/mapbiomas/metadata/hierarchy.csv" + +EXCLUDED_CLASSES: set[int] = { + 13, + 23, + 27, + 30, + 31, + 32, + 50, +} # less than 100 points in data +YEAR_RANGE: range = range(2016, 2023) # 2016-2022 inclusive + + +# --------------------------------------------------------------------------- +# Step 1: load & melt +# --------------------------------------------------------------------------- + + +def load_and_melt(shp_path: str | Path) -> pd.DataFrame: + """Load the shapefile and melt year-specific columns into long format. + + Returns a DataFrame with columns: + TARGETID, LON, LAT, YEAR, CLASS, BORDA, COUNT, CARTA_2, DECLIVIDAD + filtered to COUNT == 3 and classes not in EXCLUDED_CLASSES. + """ + gdf = gpd.read_file(shp_path) + + static_cols = ["TARGETID", "LON", "LAT", "CARTA_2", "DECLIVIDAD"] + frames: list[pd.DataFrame] = [] + for year in YEAR_RANGE: + cls_col = f"CLASS_{year}" + cnt_col = f"COUNT_{year}" + brd_col = f"BORDA_{year}" + if cls_col not in gdf.columns: + continue + sub = gdf[static_cols + [cls_col, cnt_col, brd_col]].copy() + sub.columns = [*static_cols, "CLASS", "COUNT", "BORDA"] + sub["YEAR"] = year + sub["COUNT"] = pd.to_numeric(sub["COUNT"], errors="coerce") + sub["CLASS"] = pd.to_numeric(sub["CLASS"], errors="coerce") + sub["BORDA"] = pd.to_numeric(sub["BORDA"], errors="coerce") + frames.append(sub) + + long = pd.concat(frames, ignore_index=True) + + # Filter: best quality only, drop excluded classes and NaN class + long = long[long["COUNT"] == 3].copy() + long = long[long["CLASS"].notna()].copy() + long["CLASS"] = long["CLASS"].astype(int) + long = long[~long["CLASS"].isin(EXCLUDED_CLASSES)].copy() + + long = long.reset_index(drop=True) + return long + + +# --------------------------------------------------------------------------- +# Step 2: water-filling quota +# --------------------------------------------------------------------------- + + +def compute_quotas( + long: pd.DataFrame, + total: int, +) -> dict[int, int]: + """Water-fill per-class quotas so sum == total, capped by unique-pixel availability.""" + avail = long.groupby("CLASS")["TARGETID"].nunique().to_dict() + classes = sorted(avail.keys()) + + sorted_caps = sorted(avail[c] for c in classes) + n = len(classes) + level = 0.0 + remaining = total + + for i, cap in enumerate(sorted_caps): + slots = n - i + if (cap - level) * slots >= remaining: + level += remaining / slots + break + remaining -= int(cap - level) * slots + level = cap + else: + level = sorted_caps[-1] + + base = int(np.floor(level)) + quotas = {c: min(avail[c], base) for c in classes} + deficit = total - sum(quotas.values()) + + # distribute remainder one-at-a-time to classes that still have room + frac_parts: list[tuple[float, int]] = [] + for c in classes: + headroom = avail[c] - quotas[c] + if headroom > 0: + frac_parts.append((level - int(np.floor(level)), c)) + frac_parts.sort(key=lambda x: -x[0]) + + idx = 0 + while deficit > 0: + for c in classes: + if deficit <= 0: + break + if avail[c] - quotas[c] > 0: + quotas[c] += 1 + deficit -= 1 + idx += 1 + if idx > len(classes): + break + + return quotas + + +# --------------------------------------------------------------------------- +# Step 3: stratified selection (rarest class first) +# --------------------------------------------------------------------------- + + +def select_pixels( + long: pd.DataFrame, + quotas: dict[int, int], + edge_frac: float, + rng: np.random.Generator, +) -> pd.DataFrame: + """Select pixels per class, rarest first, with CARTA_2/DECLIVIDAD/YEAR balance. + + Each TARGETID is used at most once globally. Edge/interior mix is best-effort. + """ + used_ids: set[int] = set() + selected_rows: list[pd.DataFrame] = [] + + class_order = sorted(quotas.keys(), key=lambda c: quotas[c]) + active_classes = [c for c in class_order if quotas[c] > 0] + total_quota = sum(quotas[c] for c in active_classes) + cumulative = 0 + + for i, cls in enumerate(active_classes, 1): + quota = quotas[cls] + + pool = long[(long["CLASS"] == cls) & (~long["TARGETID"].isin(used_ids))].copy() + if pool.empty: + print( + f" [{i}/{len(active_classes)}] class {cls:>2d}: " + f"quota={quota}, pool empty — skipped" + ) + continue + + n_edge_target = int(round(edge_frac * quota)) + n_interior_target = quota - n_edge_target + + edge_pool = pool[pool["BORDA"] == 1] + interior_pool = pool[pool["BORDA"] == 0] + + chosen_edge = _stratified_pick( + edge_pool, + n_edge_target, + used_ids, + rng, + label=f"class {cls} edge", + ) + newly_used = set(chosen_edge["TARGETID"]) + used_ids.update(newly_used) + + interior_pool = interior_pool[~interior_pool["TARGETID"].isin(used_ids)] + chosen_interior = _stratified_pick( + interior_pool, + n_interior_target, + used_ids, + rng, + label=f"class {cls} interior", + ) + used_ids.update(chosen_interior["TARGETID"]) + + # Backfill shortfall from the other pool + got = len(chosen_edge) + len(chosen_interior) + shortfall = quota - got + if shortfall > 0: + remaining = pool[~pool["TARGETID"].isin(used_ids)] + backfill = _stratified_pick( + remaining, + shortfall, + used_ids, + rng, + label=f"class {cls} backfill", + ) + used_ids.update(backfill["TARGETID"]) + selected_rows.extend([chosen_edge, chosen_interior, backfill]) + got += len(backfill) + else: + selected_rows.extend([chosen_edge, chosen_interior]) + + cumulative += got + print( + f" [{i}/{len(active_classes)}] class {cls:>2d}: " + f"picked {got}/{quota} — cumulative {cumulative}/{total_quota}" + ) + + result = pd.concat(selected_rows, ignore_index=True) + return result + + +def _stratified_pick( + pool: pd.DataFrame, + n: int, + used_ids: set[int], + rng: np.random.Generator, + label: str = "", +) -> pd.DataFrame: + """Pick up to *n* rows from *pool* spread across CARTA_2 / DECLIVIDAD / YEAR. + + Within each tile, prefer rows that fill under-represented DECLIVIDAD and YEAR + buckets. Each picked TARGETID is immediately added to used_ids so no pixel + repeats. + """ + if n <= 0 or pool.empty: + return pool.iloc[:0] + + pool = pool[~pool["TARGETID"].isin(used_ids)].copy() + if pool.empty: + return pool + + pool = pool.sample(frac=1, random_state=int(rng.integers(2**31))).reset_index( + drop=True + ) + + tiles = sorted(pool["CARTA_2"].unique()) + tile_pools: dict[str, pd.DataFrame] = { + t: pool[pool["CARTA_2"] == t].copy() for t in tiles + } + + year_counts: Counter[int] = Counter() + decliv_counts: Counter[str] = Counter() + picked_indices: list[int] = [] + local_used: set[int] = set() + + log_interval = max(1, n // 5) + tag = f" {label}: " if label else " pick: " + + tile_idx = 0 + stall_counter = 0 + while len(picked_indices) < n and stall_counter < len(tiles) + 1: + tile = tiles[tile_idx % len(tiles)] + tile_idx += 1 + tp = tile_pools[tile] + tp = tp[~tp["TARGETID"].isin(local_used)] + tile_pools[tile] = tp + if tp.empty: + stall_counter += 1 + continue + stall_counter = 0 + + scores = np.zeros(len(tp)) + for i, (_, row) in enumerate(tp.iterrows()): + yr = row["YEAR"] + dc = row["DECLIVIDAD"] + scores[i] = -(year_counts.get(yr, 0) + decliv_counts.get(dc, 0)) + + best_idx = int(np.argmax(scores)) + chosen_row = tp.iloc[best_idx] + picked_indices.append(tp.index[best_idx]) + tid = chosen_row["TARGETID"] + local_used.add(tid) + used_ids.add(tid) + year_counts[chosen_row["YEAR"]] += 1 + decliv_counts[chosen_row["DECLIVIDAD"]] += 1 + + for t2 in tiles: + tp2 = tile_pools[t2] + if not tp2.empty: + tile_pools[t2] = tp2[tp2["TARGETID"] != tid] + + picked = len(picked_indices) + if picked % log_interval == 0 or picked == n: + print(f"{tag}{picked}/{n}", flush=True) + + if not picked_indices: + return pool.iloc[:0] + return pool.loc[pool.index.isin(picked_indices)].copy() + + +# --------------------------------------------------------------------------- +# Step 4: train / val split +# --------------------------------------------------------------------------- + + +def train_val_split( + selected: pd.DataFrame, + n_train: int, + n_val: int, + rng: np.random.Generator, +) -> pd.DataFrame: + """Stratified per-class 75/25 train/val split.""" + total = n_train + n_val + train_frac = n_train / total + + parts: list[pd.DataFrame] = [] + for cls, grp in selected.groupby("CLASS"): + grp = grp.sample(frac=1, random_state=int(rng.integers(2**31))) + n_tr = int(round(len(grp) * train_frac)) + grp = grp.copy() + grp["split"] = "val" + grp.iloc[:n_tr, grp.columns.get_loc("split")] = "train" + parts.append(grp) + + result = pd.concat(parts, ignore_index=True) + + # Adjust global counts to hit exact targets + n_train_actual = (result["split"] == "train").sum() + diff = n_train_actual - n_train + if diff > 0: + train_idx = ( + result[result["split"] == "train"] + .sample(n=diff, random_state=int(rng.integers(2**31))) + .index + ) + result.loc[train_idx, "split"] = "val" + elif diff < 0: + val_idx = ( + result[result["split"] == "val"] + .sample(n=-diff, random_state=int(rng.integers(2**31))) + .index + ) + result.loc[val_idx, "split"] = "train" + + return result + + +# --------------------------------------------------------------------------- +# Step 5: summary report +# --------------------------------------------------------------------------- + + +def load_legend(path: Path) -> dict[int, str]: + """Return {class_id: english_description} from the MapBiomas legend CSV.""" + legend = pd.read_csv(path, sep="\t") + legend["Description"] = legend["Description"].str.strip() + return dict(zip(legend["Class_ID"].astype(int), legend["Description"])) + + +def load_hierarchy(path: Path) -> dict[int, dict]: + """Parse the hierarchy CSV and return per-class hierarchy info. + + Returns ``{class_id: {"leaf_level": int, + "parent_class_id": int | None, + "parent_leaf_level": int | None, + "parent_class_desc": str | None}}``. + """ + h = pd.read_csv(path) + result: dict[int, dict] = {} + for cid, grp in h.groupby("Class_ID"): + cid = int(cid) + leaf_level = int(grp["Leaf_Level"].iloc[0]) + + if leaf_level > 1: + parent_row = grp[grp["Hierarchy_Level"] == leaf_level - 1].iloc[0] + parent_class_id = int(parent_row["Level_Class_ID"]) + parent_leaf_level = leaf_level - 1 + parent_class_desc = str(parent_row["Level_Description"]) + else: + parent_class_id = None + parent_leaf_level = None + parent_class_desc = None + + result[cid] = { + "leaf_level": leaf_level, + "parent_class_id": parent_class_id, + "parent_leaf_level": parent_leaf_level, + "parent_class_desc": parent_class_desc, + } + return result + + +def build_summary( + df: pd.DataFrame, + legend: dict[int, str], + hierarchy: dict[int, dict] | None = None, +) -> pd.DataFrame: + """Build a class-level summary with train/val breakdown and diversity metrics.""" + n_total = len(df) + rows: list[dict] = [] + + for cls, grp in df.groupby("CLASS"): + n = len(grp) + cid = int(cls) + train_mask = grp["split"] == "train" + n_train = int(train_mask.sum()) + n_val = n - n_train + + rec: dict = { + "class_id": cid, + "class_name": legend.get(cid, "unknown"), + "leaf_level": None, + "parent_class_id": None, + "parent_leaf_level": None, + "parent_class_desc": None, + "total_points": n, + "frac_of_all": n / n_total, + "train_points": n_train, + "train_frac": n_train / n if n else 0.0, + "val_points": n_val, + "val_frac": n_val / n if n else 0.0, + "edge_frac": float((grp["BORDA"] == 1).mean()), + "n_tiles": int(grp["CARTA_2"].nunique()), + "n_years": int(grp["YEAR"].nunique()), + } + + if hierarchy and cid in hierarchy: + hi = hierarchy[cid] + rec["leaf_level"] = hi["leaf_level"] + rec["parent_class_id"] = hi["parent_class_id"] + rec["parent_leaf_level"] = hi["parent_leaf_level"] + rec["parent_class_desc"] = hi["parent_class_desc"] + + rows.append(rec) + + summary = ( + pd.DataFrame(rows) + .sort_values("total_points", ascending=False) + .reset_index(drop=True) + ) + return summary + + +def print_summary( + df: pd.DataFrame, + summary: pd.DataFrame | None = None, +) -> None: + """Print a diagnostic summary of the subsample.""" + print("\n" + "=" * 70) + print("SUBSAMPLE SUMMARY") + print("=" * 70) + + print(f"\nTotal rows: {len(df)}") + print(f" Train: {(df['split'] == 'train').sum()}") + print(f" Val: {(df['split'] == 'val').sum()}") + print(f" Unique TARGETIDs: {df['TARGETID'].nunique()} (should == {len(df)})") + + print("\n--- Per-class counts ---") + cls_split = df.groupby(["CLASS", "split"]).size().unstack(fill_value=0) + cls_split["total"] = cls_split.sum(axis=1) + print(cls_split.to_string()) + + if ( + summary is not None + and "leaf_level" in summary.columns + and summary["leaf_level"].notna().any() + ): + print("\n--- Class hierarchy ---") + fmt = "{:<6s} {:<35s} {:<5s} {:<6s} {:<5s} {:<25s}" + print(fmt.format("ID", "Class", "Lvl", "ParID", "ParLv", "Parent")) + print("-" * 90) + for _, r in summary.sort_values("class_id").iterrows(): + lvl = str(int(r["leaf_level"])) if pd.notna(r["leaf_level"]) else "" + par_id = ( + str(int(r["parent_class_id"])) + if pd.notna(r["parent_class_id"]) + else "-" + ) + par_lv = ( + str(int(r["parent_leaf_level"])) + if pd.notna(r["parent_leaf_level"]) + else "-" + ) + par_desc = ( + str(r["parent_class_desc"])[:25] + if pd.notna(r["parent_class_desc"]) + else "-" + ) + print( + fmt.format( + str(int(r["class_id"])), + str(r["class_name"])[:35], + lvl, + par_id, + par_lv, + par_desc, + ) + ) + + print("\n--- BORDA distribution ---") + borda_split = df.groupby(["BORDA", "split"]).size().unstack(fill_value=0) + borda_split["total"] = borda_split.sum(axis=1) + print(borda_split.to_string()) + n_edge = (df["BORDA"] == 1).sum() + print(f" Edge fraction: {n_edge / len(df):.3f} (target 0.250)") + + print("\n--- YEAR distribution ---") + year_split = df.groupby(["YEAR", "split"]).size().unstack(fill_value=0) + year_split["total"] = year_split.sum(axis=1) + print(year_split.to_string()) + + print("\n--- DECLIVIDAD distribution ---") + dec_split = df.groupby(["DECLIVIDAD", "split"]).size().unstack(fill_value=0) + dec_split["total"] = dec_split.sum(axis=1) + print(dec_split.to_string()) + + print("\n--- CARTA_2 tile coverage ---") + print(f" Tiles represented: {df['CARTA_2'].nunique()}") + + print("=" * 70 + "\n") + + +# --------------------------------------------------------------------------- +# main +# --------------------------------------------------------------------------- + + +def main() -> None: + """Build a balanced MapBiomas subsample (3k train + 1k val).""" + parser = argparse.ArgumentParser( + description="Build a balanced MapBiomas subsample (3k train + 1k val)." + ) + parser.add_argument( + "--shp-path", + type=str, + default=str( + MY_ROOT / "datasets/mapbiomas/metadata/mapbiomas_85k_points_validation.shp" + ), + help="Path to the 85k validation shapefile.", + ) + parser.add_argument( + "--out-path", + type=str, + default=str( + MY_ROOT + / "rslearn_projects/rslp/mapbiomas/subsampling/sample_expert_points_4k.csv" + ), + help="Output CSV path.", + ) + parser.add_argument("--n-train", type=int, default=3000) + parser.add_argument("--n-val", type=int, default=1000) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--edge-frac", type=float, default=0.25) + parser.add_argument( + "--legend-path", + type=str, + default=str(DEFAULT_LEGEND), + help="Path to the MapBiomas legend CSV (tab-separated).", + ) + parser.add_argument( + "--hierarchy-path", + type=str, + default=str(DEFAULT_HIERARCHY), + help="Path to the MapBiomas hierarchy CSV.", + ) + args = parser.parse_args() + + total = args.n_train + args.n_val + rng = np.random.default_rng(args.seed) + + print(f"Loading shapefile: {args.shp_path}") + long = load_and_melt(args.shp_path) + n_unique = long["TARGETID"].nunique() + n_classes = long["CLASS"].nunique() + print(f" Long-format rows (COUNT==3, kept classes): {len(long)}") + print(f" Unique pixels: {n_unique}, Classes: {n_classes}") + + quotas = compute_quotas(long, total) + print(f"\nPer-class quotas (total {sum(quotas.values())}):") + for cls in sorted(quotas, key=lambda c: -quotas[c]): + avail = long[long["CLASS"] == cls]["TARGETID"].nunique() + print(f" class {cls:>2d}: quota={quotas[cls]:>4d} (avail={avail})") + + print("\nSelecting pixels …") + selected = select_pixels(long, quotas, args.edge_frac, rng) + print(f" Selected: {len(selected)} rows") + + selected = train_val_split(selected, args.n_train, args.n_val, rng) + + out_cols = [ + "TARGETID", + "LON", + "LAT", + "YEAR", + "CLASS", + "BORDA", + "COUNT", + "CARTA_2", + "DECLIVIDAD", + "split", + ] + selected = ( + selected[out_cols] + .sort_values(["split", "CLASS", "YEAR"]) + .reset_index(drop=True) + ) + + out_path = Path(args.out_path) + out_path.parent.mkdir(parents=True, exist_ok=True) + selected.to_csv(out_path, index=False) + print(f"\nWrote {len(selected)} rows to {out_path}") + + legend = load_legend(Path(args.legend_path)) + hierarchy_path = Path(args.hierarchy_path) + hierarchy = load_hierarchy(hierarchy_path) if hierarchy_path.exists() else None + summary = build_summary(selected, legend, hierarchy) + summary_path = out_path.parent / "sample_expert_points_summary.csv" + summary.to_csv(summary_path, index=False) + print(f"Wrote summary stats: {summary_path} ({len(summary)} rows)") + + print_summary(selected, summary) + + +if __name__ == "__main__": + main() diff --git a/rslp/mapbiomas/subsampling/visualize_sample_expert_points.py b/rslp/mapbiomas/subsampling/visualize_sample_expert_points.py new file mode 100644 index 000000000..66d55f259 --- /dev/null +++ b/rslp/mapbiomas/subsampling/visualize_sample_expert_points.py @@ -0,0 +1,191 @@ +"""Visualize the MapBiomas subsample produced by sample_expert_points.py. + +Generates a 2x3 figure: + (0,0) Geographic scatter – train only + (0,1) Geographic scatter – val only + (0,2) CARTA_2 histogram (overlaid train/val) + (1,0) Class distribution bar chart + (1,1) Year distribution bar chart + (1,2) DECLIVIDAD distribution bar chart +""" + +from __future__ import annotations + +import argparse +import os +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +MY_ROOT = Path(os.environ.get("MY_ROOT", ".")) + +SPLIT_COLORS = {"train": "#1f77b4", "val": "#ff7f0e"} + + +def main() -> None: + """Visualize the MapBiomas 4k subsample.""" + parser = argparse.ArgumentParser( + description="Visualize the MapBiomas 4k subsample." + ) + parser.add_argument( + "--csv-path", + type=str, + default=str( + MY_ROOT + / "rslearn_projects/rslp/mapbiomas/subsampling/sample_expert_points_4k.csv" + ), + help="Path to the subsample CSV produced by sample_expert_points.py.", + ) + parser.add_argument( + "--out-path", + type=str, + default=None, + help="If set, save the figure to this path instead of showing it.", + ) + args = parser.parse_args() + + df = pd.read_csv(args.csv_path) + train = df[df["split"] == "train"] + val = df[df["split"] == "val"] + + fig, axes = plt.subplots(2, 3, figsize=(22, 12)) + fig.suptitle("MapBiomas Subsample Overview", fontsize=16, fontweight="bold") + + # --- (0,0) Geographic scatter – train --- + ax = axes[0, 0] + ax.scatter( + train["LAT"], + train["LON"], + s=4, + alpha=0.4, + color=SPLIT_COLORS["train"], + rasterized=True, + ) + ax.set_xlabel("Longitude") + ax.set_ylabel("Latitude") + ax.set_title(f"Train ({len(train)} pts)") + + # --- (0,1) Geographic scatter – val --- + ax = axes[0, 1] + ax.scatter( + val["LAT"], + val["LON"], + s=4, + alpha=0.4, + color=SPLIT_COLORS["val"], + rasterized=True, + ) + ax.set_xlabel("Longitude") + ax.set_ylabel("Latitude") + ax.set_title(f"Val ({len(val)} pts)") + + # Share axis limits between the two geo plots + all_lon = df["LAT"] + all_lat = df["LON"] + lon_pad = (all_lon.max() - all_lon.min()) * 0.03 + lat_pad = (all_lat.max() - all_lat.min()) * 0.03 + for a in axes[0, :2]: + a.set_xlim(all_lon.min() - lon_pad, all_lon.max() + lon_pad) + a.set_ylim(all_lat.min() - lat_pad, all_lat.max() + lat_pad) + + # --- (0,2) CARTA_2 histogram --- + ax = axes[0, 2] + tiles_sorted = sorted(df["CARTA_2"].unique()) + tile_to_idx = {t: i for i, t in enumerate(tiles_sorted)} + train_idx = train["CARTA_2"].map(tile_to_idx) + val_idx = val["CARTA_2"].map(tile_to_idx) + bins = np.linspace(-0.5, len(tiles_sorted) - 0.5, min(50, len(tiles_sorted)) + 1) + ax.hist(train_idx, bins=bins, alpha=0.5, label="train", color=SPLIT_COLORS["train"]) + ax.hist(val_idx, bins=bins, alpha=0.5, label="val", color=SPLIT_COLORS["val"]) + ax.set_xlabel("CARTA_2 tile (index)") + ax.set_ylabel("Count") + ax.set_title(f"CARTA_2 Distribution ({len(tiles_sorted)} tiles)") + ax.legend() + + # --- (1,0) Class distribution --- + _grouped_bar( + axes[1, 0], + df, + group_col="CLASS", + title="Class Distribution", + xlabel="Class", + ylabel="Count", + ) + + # --- (1,1) Year distribution --- + _grouped_bar( + axes[1, 1], + df, + group_col="YEAR", + title="Year Distribution", + xlabel="Year", + ylabel="Count", + ) + + # --- (1,2) DECLIVIDAD distribution --- + _grouped_bar( + axes[1, 2], + df, + group_col="DECLIVIDAD", + title="DECLIVIDAD Distribution", + xlabel="DECLIVIDAD", + ylabel="Count", + ) + + fig.tight_layout(rect=[0, 0, 1, 0.96]) + + if args.out_path: + out = Path(args.out_path) + out.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(out, dpi=150, bbox_inches="tight") + print(f"Saved figure to {out}") + else: + plt.show() + + +def _grouped_bar( + ax: plt.Axes, + df: pd.DataFrame, + group_col: str, + title: str, + xlabel: str, + ylabel: str, +) -> None: + """Draw a side-by-side bar chart of train vs val counts for *group_col*.""" + ct = df.groupby([group_col, "split"]).size().unstack(fill_value=0) + for split in ["train", "val"]: + if split not in ct.columns: + ct[split] = 0 + ct = ct[["train", "val"]].sort_index() + + labels = [str(v) for v in ct.index] + x = range(len(labels)) + bar_width = 0.4 + + ax.bar( + [i - bar_width / 2 for i in x], + ct["train"], + width=bar_width, + label="train", + color=SPLIT_COLORS["train"], + ) + ax.bar( + [i + bar_width / 2 for i in x], + ct["val"], + width=bar_width, + label="val", + color=SPLIT_COLORS["val"], + ) + + ax.set_xticks(list(x)) + ax.set_xticklabels(labels, rotation=45, ha="right") + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + ax.set_title(title) + ax.legend() + + +if __name__ == "__main__": + main() diff --git a/rslp/utils/beaker.py b/rslp/utils/beaker.py index fe5ce52b4..9e94be03e 100644 --- a/rslp/utils/beaker.py +++ b/rslp/utils/beaker.py @@ -7,7 +7,7 @@ from beaker.client import Beaker DEFAULT_WORKSPACE = "ai2/earth-systems" -DEFAULT_BUDGET = "ai2/es-platform" +DEFAULT_BUDGET = "ai2/atec-olmoearth" @dataclass