From 95efb9215738b0baabc3badbaa44d5a058b93ec7 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 19 May 2026 18:57:44 +0000 Subject: [PATCH 1/6] adding oep subsampling strategy for olmoearth_pretrain eval task --- rslp/lfmc/subsample_oep_eval.py | 474 ++++++++++++++++++++++++++++++++ rslp/lfmc/visualize_oep_eval.py | 176 ++++++++++++ 2 files changed, 650 insertions(+) create mode 100644 rslp/lfmc/subsample_oep_eval.py create mode 100644 rslp/lfmc/visualize_oep_eval.py diff --git a/rslp/lfmc/subsample_oep_eval.py b/rslp/lfmc/subsample_oep_eval.py new file mode 100644 index 000000000..16e067ae4 --- /dev/null +++ b/rslp/lfmc/subsample_oep_eval.py @@ -0,0 +1,474 @@ +"""Tag a spatially-balanced subset of LFMC windows with the oep_eval tag. + +For train: selects ~1000 windows distributed evenly across US states, operating +at the location level. Each selected location contributes at most 1 year of +samples (the densest 365-day window), maximizing geographic coverage. +For val/test: tags ALL windows. + +Usage: + python subsample_oep_eval.py --dataset_path /path/to/dataset --target_train 1000 --seed 42 +""" + +import argparse +import json +import os +from collections import defaultdict +from datetime import datetime, timedelta +from typing import Any + +import geopandas as gpd +import numpy as np +import pandas as pd +from pyproj import Transformer +from shapely.geometry import Point + +LocKey = tuple[str, tuple[Any, ...]] + +US_STATES_SHP = ( + "/weka/dfive-default/hadriens/datasets/Misc/Us states/cb_2018_us_state_20m.shp" +) + + +def load_windows(dataset_path: str) -> list[dict]: + """Load all window metadata from the dataset.""" + windows_dir = os.path.join(dataset_path, "windows", "spatial_split") + names = os.listdir(windows_dir) + print(f" Found {len(names)} window directories, loading metadata...") + windows = [] + for i, name in enumerate(names): + if i % 5000 == 0 and i > 0: + print(f" ... loaded {i}/{len(names)}") + meta_path = os.path.join(windows_dir, name, "metadata.json") + if not os.path.isfile(meta_path): + continue + with open(meta_path) as f: + meta = json.load(f) + meta["_dir"] = os.path.join(windows_dir, name) + meta["_name"] = name + windows.append(meta) + return windows + + +def compute_location_centroid_wgs84( + crs: str, bounds: list, x_resolution: float, y_resolution: float +) -> tuple[float, float]: + """Convert the center of pixel bounds to WGS84 lon/lat. + + Bounds are stored in pixel coordinates; multiply by resolution to get + real-world UTM coordinates. + """ + cx_pixel = (bounds[0] + bounds[2]) / 2.0 + cy_pixel = (bounds[1] + bounds[3]) / 2.0 + cx_utm = cx_pixel * x_resolution + cy_utm = cy_pixel * y_resolution + transformer = Transformer.from_crs(crs, "EPSG:4326", always_xy=True) + lon, lat = transformer.transform(cx_utm, cy_utm) + return lon, lat + + +def get_window_date(window: dict) -> datetime: + """Extract the sample date from a window's time_range.""" + return datetime.fromisoformat(window["time_range"][0]) + + +def best_one_year_subset(windows: list[dict]) -> list[dict]: + """Find the densest 365-day sliding window and return only those samples. + + Slides across all sample dates and picks the 365-day interval containing + the most samples. + """ + if len(windows) <= 1: + return windows + + sorted_windows = sorted(windows, key=get_window_date) + dates = [get_window_date(w) for w in sorted_windows] + year_delta = timedelta(days=365) + + best_start = 0 + best_end = 0 + best_count = 0 + + for i, start_date in enumerate(dates): + end_date = start_date + year_delta + # Find how many samples fall within [start_date, start_date + 365 days] + j = i + while j < len(dates) and dates[j] <= end_date: + j += 1 + count = j - i + if count > best_count: + best_count = count + best_start = i + best_end = j + + return sorted_windows[best_start:best_end] + + +def assign_locations_to_states( + locations: dict[LocKey, list[dict]], states_gdf: gpd.GeoDataFrame +) -> dict[LocKey, str]: + """Assign each location key to a US state via point-in-polygon.""" + loc_keys = list(locations.keys()) + points = [] + for key in loc_keys: + crs, bounds = key[0], list(key[1]) + first_window = locations[key][0] + x_res = first_window["projection"]["x_resolution"] + y_res = first_window["projection"]["y_resolution"] + lon, lat = compute_location_centroid_wgs84(crs, bounds, x_res, y_res) + points.append(Point(lon, lat)) + + points_gdf = gpd.GeoDataFrame( + {"loc_key": loc_keys}, geometry=points, crs="EPSG:4326" + ) + states_gdf_4326 = states_gdf.to_crs("EPSG:4326") + joined = gpd.sjoin(points_gdf, states_gdf_4326[["NAME", "geometry"]], how="left") + + loc_to_state: dict[LocKey, str] = {} + for _, row in joined.iterrows(): + state = row.get("NAME") + if isinstance(state, str) and not pd.isna(state): + loc_to_state[row["loc_key"]] = state + else: + loc_to_state[row["loc_key"]] = "Unknown" + return loc_to_state + + +def select_locations_balanced( + locations: dict[LocKey, list[dict]], + loc_to_state: dict[LocKey, str], + loc_one_year_count: dict[LocKey, int], + target: int, + seed: int, +) -> list[LocKey]: + """Select locations with even distribution across states. + + Uses the 1-year capped sample count for each location when computing quotas, + so that more locations can be selected for better geographic coverage. + Returns list of selected location keys. + """ + rng = np.random.default_rng(seed) + + state_to_locs: dict[str, list[LocKey]] = defaultdict(list) + for loc_key, state in loc_to_state.items(): + state_to_locs[state].append(loc_key) + + states_with_data = [s for s in state_to_locs if s != "Unknown"] + if not states_with_data: + states_with_data = list(state_to_locs.keys()) + + selected: list[LocKey] = [] + remaining_target = target + + # Iteratively allocate: states with fewer capped samples than quota get all + settled_states = set() + for _ in range(20): + unsettled = [s for s in states_with_data if s not in settled_states] + if not unsettled: + break + + per_state_quota = remaining_target / len(unsettled) if unsettled else 0 + newly_settled = [] + + for state in unsettled: + locs = state_to_locs[state] + total_capped = sum(loc_one_year_count[lk] for lk in locs) + if total_capped <= per_state_quota: + selected.extend(locs) + remaining_target -= total_capped + newly_settled.append(state) + + if not newly_settled: + break + settled_states.update(newly_settled) + + # For remaining states, select locations to fill quota + unsettled = [s for s in states_with_data if s not in settled_states] + if unsettled: + per_state_quota = remaining_target / len(unsettled) if unsettled else 0 + for state in unsettled: + locs = state_to_locs[state] + locs_with_count = [(lk, loc_one_year_count[lk]) for lk in locs] + + # Shuffle for randomness, then sort by count (prefer dense locations) + shuffled = list(locs_with_count) + rng.shuffle(shuffled) + shuffled.sort(key=lambda x: x[1], reverse=True) + + state_selected: list[LocKey] = [] + state_total = 0 + for lk, count in shuffled: + if state_total + count > per_state_quota and state_selected: + break + state_selected.append(lk) + state_total += count + + selected.extend(state_selected) + remaining_target -= state_total + + # Include "Unknown" state locations if any remain and we're under target + if "Unknown" in state_to_locs and remaining_target > 0: + unknown_locs = state_to_locs["Unknown"] + unknown_with_count = [(lk, loc_one_year_count[lk]) for lk in unknown_locs] + rng.shuffle(unknown_with_count) + unknown_with_count.sort(key=lambda x: x[1], reverse=True) + for lk, count in unknown_with_count: + if remaining_target <= 0: + break + selected.append(lk) + remaining_target -= count + + return selected + + +def tag_windows(windows: list[dict]) -> int: + """Add oep_eval tag to the given windows' metadata.json files.""" + tagged = 0 + for i, w in enumerate(windows): + if i % 2000 == 0 and i > 0: + print(f" ... tagged {i}/{len(windows)}") + meta_path = os.path.join(w["_dir"], "metadata.json") + with open(meta_path) as f: + meta = json.load(f) + if "oep_eval" not in meta.get("options", {}): + meta.setdefault("options", {})["oep_eval"] = "" + with open(meta_path, "w") as f: + json.dump(meta, f) + tagged += 1 + else: + tagged += 1 + return tagged + + +VAL_MAX_SAMPLES = 800 +TEST_MAX_SAMPLES = 500 + + +def subsample_split( + windows: list[dict], + target: int, + states_gdf: gpd.GeoDataFrame, + seed: int, + split_name: str, +) -> tuple[list[dict], dict]: + """Subsample a split using spatially-balanced selection with 1-year cap. + + Returns (selected_windows, split_stats_dict). + """ + # Group by location + locations: dict[LocKey, list[dict]] = defaultdict(list) + for w in windows: + key: LocKey = (w["projection"]["crs"], tuple(w["bounds"])) + locations[key].append(w) + print(f" Unique {split_name} locations: {len(locations)}") + + # Precompute best 1-year subset for each location + print(f" Computing best 1-year window per {split_name} location...") + loc_one_year: dict[LocKey, list[dict]] = {} + loc_one_year_count: dict[LocKey, int] = {} + for loc_key, loc_windows in locations.items(): + subset = best_one_year_subset(loc_windows) + loc_one_year[loc_key] = subset + loc_one_year_count[loc_key] = len(subset) + avg_capped = np.mean(list(loc_one_year_count.values())) + print(f" Avg samples per location (1-year cap): {avg_capped:.1f}") + + # Assign locations to states + print(f" Assigning {split_name} locations to states...") + loc_to_state = assign_locations_to_states(locations, states_gdf) + + # State distribution + state_total_samples: dict[str, int] = defaultdict(int) + state_total_locations: dict[str, int] = defaultdict(int) + for loc_key, state in loc_to_state.items(): + state_total_samples[state] += len(locations[loc_key]) + state_total_locations[state] += 1 + + # Select locations + print(f" Selecting locations for ~{target} {split_name} samples...") + selected_locs = select_locations_balanced( + locations, loc_to_state, loc_one_year_count, target, seed + ) + + # Collect the 1-year subset + selected_windows = [] + for loc_key in selected_locs: + selected_windows.extend(loc_one_year[loc_key]) + print( + f" Selected {len(selected_locs)} locations -> {len(selected_windows)} {split_name} samples" + ) + + # Per-state reporting + selected_state_samples: dict[str, int] = defaultdict(int) + selected_state_locations: dict[str, int] = defaultdict(int) + for loc_key in selected_locs: + state = loc_to_state[loc_key] + selected_state_samples[state] += loc_one_year_count[loc_key] + selected_state_locations[state] += 1 + print(f"\n Per-state {split_name} selection:") + for state in sorted(selected_state_samples.keys()): + total_locs = state_total_locations[state] + sel_locs = selected_state_locations[state] + total_samp = state_total_samples[state] + sel_samp = selected_state_samples[state] + print( + f" {state}: {sel_locs}/{total_locs} locations, {sel_samp}/{total_samp} samples" + ) + + # Build stats for this split + all_states = sorted( + set(list(selected_state_samples.keys()) + list(state_total_locations.keys())) + ) + split_stats = { + "target_samples": target, + "total_locations": len(locations), + "selected_locations": len(selected_locs), + "total_samples": len(windows), + "selected_samples": len(selected_windows), + "avg_samples_per_location_1yr_cap": float(avg_capped), + "per_state": { + state: { + "total_locations": state_total_locations[state], + "selected_locations": selected_state_locations.get(state, 0), + "total_samples": state_total_samples[state], + "selected_samples": selected_state_samples.get(state, 0), + } + for state in all_states + }, + } + + return selected_windows, split_stats + + +def main() -> None: + """Tag a spatially-balanced oep_eval subset for LFMC datasets.""" + parser = argparse.ArgumentParser(description="Tag oep_eval subset for LFMC") + parser.add_argument("--dataset_path", type=str, required=True) + parser.add_argument("--target_train", type=int, default=1000) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument( + "--states_shp", + type=str, + default=US_STATES_SHP, + help="Path to US states shapefile", + ) + parser.add_argument( + "--dry_run", action="store_true", help="Print stats without tagging" + ) + args = parser.parse_args() + + print(f"Loading windows from {args.dataset_path}...") + all_windows = load_windows(args.dataset_path) + print(f" Total windows: {len(all_windows)}") + + # Split by train/val/test + train_windows = [ + w for w in all_windows if w.get("options", {}).get("split") == "train" + ] + val_windows = [w for w in all_windows if w.get("options", {}).get("split") == "val"] + test_windows = [ + w for w in all_windows if w.get("options", {}).get("split") == "test" + ] + print( + f" Train: {len(train_windows)}, Val: {len(val_windows)}, Test: {len(test_windows)}" + ) + + # Load states shapefile + print("Loading US states shapefile...") + states_gdf = gpd.read_file(args.states_shp) + + # --- Train subsampling --- + print(f"\n--- TRAIN (target: {args.target_train}) ---") + selected_train, train_stats = subsample_split( + train_windows, args.target_train, states_gdf, args.seed, "train" + ) + + # --- Val: subsample if over threshold, else use all --- + if len(val_windows) > VAL_MAX_SAMPLES: + print( + f"\n--- VAL (target: {VAL_MAX_SAMPLES}, total {len(val_windows)} > {VAL_MAX_SAMPLES}) ---" + ) + selected_val, val_stats = subsample_split( + val_windows, VAL_MAX_SAMPLES, states_gdf, args.seed + 1, "val" + ) + else: + print( + f"\n--- VAL: using all {len(val_windows)} samples (below {VAL_MAX_SAMPLES} threshold) ---" + ) + selected_val = val_windows + val_stats = { + "total_samples": len(val_windows), + "selected_samples": len(val_windows), + } + + # --- Test: subsample if over threshold, else use all --- + if len(test_windows) > TEST_MAX_SAMPLES: + print( + f"\n--- TEST (target: {TEST_MAX_SAMPLES}, total {len(test_windows)} > {TEST_MAX_SAMPLES}) ---" + ) + selected_test, test_stats = subsample_split( + test_windows, TEST_MAX_SAMPLES, states_gdf, args.seed + 2, "test" + ) + else: + print( + f"\n--- TEST: using all {len(test_windows)} samples (below {TEST_MAX_SAMPLES} threshold) ---" + ) + selected_test = test_windows + test_stats = { + "total_samples": len(test_windows), + "selected_samples": len(test_windows), + } + + if args.dry_run: + print("\n[DRY RUN] No tagging performed.") + return + + # Tag windows + print("\nTagging selected train samples...") + train_tagged = tag_windows(selected_train) + print(f" Tagged {train_tagged} train samples") + + print("Tagging selected val samples...") + val_tagged = tag_windows(selected_val) + print(f" Tagged {val_tagged} val samples") + + print("Tagging selected test samples...") + test_tagged = tag_windows(selected_test) + print(f" Tagged {test_tagged} test samples") + + # Write manifest + manifest = { + "train": [w["_name"] for w in selected_train], + "val": [w["_name"] for w in selected_val], + "test": [w["_name"] for w in selected_test], + } + manifest_path = os.path.join(args.dataset_path, "oep_eval_manifest.json") + with open(manifest_path, "w") as f: + json.dump(manifest, f, indent=2) + print(f"\nWrote manifest to {manifest_path}") + + # Write stats + stats = { + "seed": args.seed, + "one_year_cap_days": 365, + "val_max_samples_threshold": VAL_MAX_SAMPLES, + "test_max_samples_threshold": TEST_MAX_SAMPLES, + "total_samples": len(all_windows), + "tagged_samples": { + "train": len(selected_train), + "val": len(selected_val), + "test": len(selected_test), + }, + "train": train_stats, + "val": val_stats, + "test": test_stats, + } + stats_path = os.path.join(args.dataset_path, "oep_eval_stats.json") + with open(stats_path, "w") as f: + json.dump(stats, f, indent=2) + print(f"Wrote stats to {stats_path}") + + print("\nDone!") + + +if __name__ == "__main__": + main() diff --git a/rslp/lfmc/visualize_oep_eval.py b/rslp/lfmc/visualize_oep_eval.py new file mode 100644 index 000000000..12a7a3017 --- /dev/null +++ b/rslp/lfmc/visualize_oep_eval.py @@ -0,0 +1,176 @@ +"""Visualize oep_eval subsampled locations vs all original locations on a US map. + +Produces a 2x2 figure: + - Row 1: Woody dataset + - Row 2: Herbaceous dataset + - Left column: all original locations + - Right column: oep_eval-tagged subset only +Points are colored by split (train=blue, val=orange, test=green). + +Usage: + python visualize_oep_eval.py \ + --woody_path /path/to/woody/dataset \ + --herbaceous_path /path/to/herbaceous/dataset \ + --output oep_eval_map.png +""" + +import argparse +import json +import os +from collections import defaultdict + +import geopandas as gpd +import matplotlib.pyplot as plt +from pyproj import Transformer + +US_STATES_SHP = ( + "/weka/dfive-default/hadriens/datasets/Misc/Us states/cb_2018_us_state_20m.shp" +) + +SPLIT_COLORS = {"train": "#2176AE", "val": "#F77F00", "test": "#06D6A0"} + + +def load_locations(dataset_path: str) -> dict[str, list[tuple[float, float]]]: + """Load unique locations per split, return {split: [(lon, lat), ...]}.""" + windows_dir = os.path.join(dataset_path, "windows", "spatial_split") + seen: dict[str, set[tuple]] = defaultdict(set) + locations: dict[str, list[tuple[float, float]]] = defaultdict(list) + + for name in os.listdir(windows_dir): + meta_path = os.path.join(windows_dir, name, "metadata.json") + if not os.path.isfile(meta_path): + continue + with open(meta_path) as f: + meta = json.load(f) + + split = meta.get("options", {}).get("split", "unknown") + crs = meta["projection"]["crs"] + bounds = meta["bounds"] + loc_key = (crs, tuple(bounds)) + + if loc_key in seen[split]: + continue + seen[split].add(loc_key) + + x_res = meta["projection"]["x_resolution"] + y_res = meta["projection"]["y_resolution"] + cx = (bounds[0] + bounds[2]) / 2.0 * x_res + cy = (bounds[1] + bounds[3]) / 2.0 * y_res + transformer = Transformer.from_crs(crs, "EPSG:4326", always_xy=True) + lon, lat = transformer.transform(cx, cy) + locations[split].append((lon, lat)) + + return dict(locations) + + +def load_tagged_locations(dataset_path: str) -> dict[str, list[tuple[float, float]]]: + """Load unique locations that have the oep_eval tag, per split.""" + windows_dir = os.path.join(dataset_path, "windows", "spatial_split") + seen: dict[str, set[tuple]] = defaultdict(set) + locations: dict[str, list[tuple[float, float]]] = defaultdict(list) + + for name in os.listdir(windows_dir): + meta_path = os.path.join(windows_dir, name, "metadata.json") + if not os.path.isfile(meta_path): + continue + with open(meta_path) as f: + meta = json.load(f) + + if "oep_eval" not in meta.get("options", {}): + continue + + split = meta.get("options", {}).get("split", "unknown") + crs = meta["projection"]["crs"] + bounds = meta["bounds"] + loc_key = (crs, tuple(bounds)) + + if loc_key in seen[split]: + continue + seen[split].add(loc_key) + + x_res = meta["projection"]["x_resolution"] + y_res = meta["projection"]["y_resolution"] + cx = (bounds[0] + bounds[2]) / 2.0 * x_res + cy = (bounds[1] + bounds[3]) / 2.0 * y_res + transformer = Transformer.from_crs(crs, "EPSG:4326", always_xy=True) + lon, lat = transformer.transform(cx, cy) + locations[split].append((lon, lat)) + + return dict(locations) + + +def plot_locations_on_ax( + ax: plt.Axes, + locations: dict[str, list[tuple[float, float]]], + states_gdf: gpd.GeoDataFrame, + title: str, +) -> None: + """Plot location points on a US map axis.""" + states_gdf.boundary.plot(ax=ax, linewidth=0.5, color="gray") + + for split in ["test", "val", "train"]: + if split not in locations: + continue + pts = locations[split] + if not pts: + continue + lons, lats = zip(*pts) + ax.scatter( + lons, + lats, + c=SPLIT_COLORS[split], + s=15, + alpha=0.7, + edgecolors="none", + label=f"{split} ({len(pts)} locs)", + zorder=3, + ) + + ax.set_xlim(-130, -65) + ax.set_ylim(24, 50) + ax.set_title(title, fontsize=11, fontweight="bold") + ax.set_aspect("equal") + ax.tick_params(labelsize=7) + ax.legend(loc="lower left", fontsize=7, framealpha=0.8) + + +def main() -> None: + """Visualize oep_eval subsampled vs original locations on a US map.""" + parser = argparse.ArgumentParser(description="Visualize oep_eval locations") + parser.add_argument("--woody_path", type=str, required=True) + parser.add_argument("--herbaceous_path", type=str, required=True) + parser.add_argument("--states_shp", type=str, default=US_STATES_SHP) + parser.add_argument("--output", type=str, default="oep_eval_map.png") + args = parser.parse_args() + + print("Loading US states shapefile...") + states_gdf = gpd.read_file(args.states_shp) + states_gdf = states_gdf[~states_gdf["NAME"].isin(["Alaska", "Hawaii"])] + states_gdf = states_gdf.to_crs("EPSG:4326") + + print("Loading woody locations...") + woody_all = load_locations(args.woody_path) + woody_tagged = load_tagged_locations(args.woody_path) + + print("Loading herbaceous locations...") + herb_all = load_locations(args.herbaceous_path) + herb_tagged = load_tagged_locations(args.herbaceous_path) + + fig, axes = plt.subplots(2, 2, figsize=(16, 10)) + + plot_locations_on_ax(axes[0, 0], woody_all, states_gdf, "Woody — All Locations") + plot_locations_on_ax( + axes[0, 1], woody_tagged, states_gdf, "Woody — oep_eval Subset" + ) + plot_locations_on_ax(axes[1, 0], herb_all, states_gdf, "Herbaceous — All Locations") + plot_locations_on_ax( + axes[1, 1], herb_tagged, states_gdf, "Herbaceous — oep_eval Subset" + ) + + plt.tight_layout() + plt.savefig(args.output, dpi=150, bbox_inches="tight") + print(f"Saved figure to {args.output}") + + +if __name__ == "__main__": + main() From a92974b92ca3a19cf711524f9a30f0e91047db32 Mon Sep 17 00:00:00 2001 From: Hadrien Sablon Date: Tue, 19 May 2026 17:14:47 -0700 Subject: [PATCH 2/6] updates --- rslp/lfmc/subsample_oep_eval.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/rslp/lfmc/subsample_oep_eval.py b/rslp/lfmc/subsample_oep_eval.py index 16e067ae4..530a1edd6 100644 --- a/rslp/lfmc/subsample_oep_eval.py +++ b/rslp/lfmc/subsample_oep_eval.py @@ -1,5 +1,9 @@ """Tag a spatially-balanced subset of LFMC windows with the oep_eval tag. +WARNING: This script is intended to be run exactly once per dataset. It will +refuse to run if a manifest file (oep_eval_manifest.json) from a previous run +is found. + For train: selects ~1000 windows distributed evenly across US states, operating at the location level. Each selected location contributes at most 1 year of samples (the densest 365-day window), maximizing geographic coverage. @@ -234,8 +238,6 @@ def tag_windows(windows: list[dict]) -> int: with open(meta_path, "w") as f: json.dump(meta, f) tagged += 1 - else: - tagged += 1 return tagged @@ -356,6 +358,14 @@ def main() -> None: ) args = parser.parse_args() + manifest_path = os.path.join(args.dataset_path, "oep_eval_manifest.json") + if os.path.exists(manifest_path): + raise RuntimeError( + f"Manifest file already exists at {manifest_path}. " + "This script is intended to be run exactly once per dataset. " + "Remove the manifest file manually if you need to re-run." + ) + print(f"Loading windows from {args.dataset_path}...") all_windows = load_windows(args.dataset_path) print(f" Total windows: {len(all_windows)}") @@ -441,7 +451,6 @@ def main() -> None: "val": [w["_name"] for w in selected_val], "test": [w["_name"] for w in selected_test], } - manifest_path = os.path.join(args.dataset_path, "oep_eval_manifest.json") with open(manifest_path, "w") as f: json.dump(manifest, f, indent=2) print(f"\nWrote manifest to {manifest_path}") From 1cd1e28244f149240434e2a09fb410310f6fc344 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 28 May 2026 16:21:02 +0000 Subject: [PATCH 3/6] wip --- rslp/lfmc/subsample_oep_eval.py | 34 +++++++++++++++++---------------- rslp/lfmc/visualize_oep_eval.py | 28 ++++++++++++++++----------- 2 files changed, 35 insertions(+), 27 deletions(-) diff --git a/rslp/lfmc/subsample_oep_eval.py b/rslp/lfmc/subsample_oep_eval.py index 16e067ae4..a72e42883 100644 --- a/rslp/lfmc/subsample_oep_eval.py +++ b/rslp/lfmc/subsample_oep_eval.py @@ -1,4 +1,4 @@ -"""Tag a spatially-balanced subset of LFMC windows with the oep_eval tag. +"""Tag a spatially-balanced subset of LFMC windows with a configurable tag. For train: selects ~1000 windows distributed evenly across US states, operating at the location level. Each selected location contributes at most 1 year of @@ -6,7 +6,7 @@ For val/test: tags ALL windows. Usage: - python subsample_oep_eval.py --dataset_path /path/to/dataset --target_train 1000 --seed 42 + python subsample_oep_eval.py --dataset_path /path/to/dataset --tag oep_eval --target_train 1000 --seed 42 """ import argparse @@ -220,8 +220,8 @@ def select_locations_balanced( return selected -def tag_windows(windows: list[dict]) -> int: - """Add oep_eval tag to the given windows' metadata.json files.""" +def tag_windows(windows: list[dict], tag: str) -> int: + """Add the given tag to the windows' metadata.json files.""" tagged = 0 for i, w in enumerate(windows): if i % 2000 == 0 and i > 0: @@ -229,8 +229,8 @@ def tag_windows(windows: list[dict]) -> int: meta_path = os.path.join(w["_dir"], "metadata.json") with open(meta_path) as f: meta = json.load(f) - if "oep_eval" not in meta.get("options", {}): - meta.setdefault("options", {})["oep_eval"] = "" + if tag not in meta.get("options", {}): + meta.setdefault("options", {})[tag] = "" with open(meta_path, "w") as f: json.dump(meta, f) tagged += 1 @@ -340,9 +340,10 @@ def subsample_split( def main() -> None: - """Tag a spatially-balanced oep_eval subset for LFMC datasets.""" - parser = argparse.ArgumentParser(description="Tag oep_eval subset for LFMC") + """Tag a spatially-balanced subset for LFMC datasets.""" + parser = argparse.ArgumentParser(description="Tag a subset for LFMC") parser.add_argument("--dataset_path", type=str, required=True) + parser.add_argument("--tag", type=str, default="oep_eval", help="Tag name to apply") parser.add_argument("--target_train", type=int, default=1000) parser.add_argument("--seed", type=int, default=42) parser.add_argument( @@ -423,16 +424,16 @@ def main() -> None: return # Tag windows - print("\nTagging selected train samples...") - train_tagged = tag_windows(selected_train) + print(f"\nTagging selected train samples with '{args.tag}'...") + train_tagged = tag_windows(selected_train, args.tag) print(f" Tagged {train_tagged} train samples") - print("Tagging selected val samples...") - val_tagged = tag_windows(selected_val) + print(f"Tagging selected val samples with '{args.tag}'...") + val_tagged = tag_windows(selected_val, args.tag) print(f" Tagged {val_tagged} val samples") - print("Tagging selected test samples...") - test_tagged = tag_windows(selected_test) + print(f"Tagging selected test samples with '{args.tag}'...") + test_tagged = tag_windows(selected_test, args.tag) print(f" Tagged {test_tagged} test samples") # Write manifest @@ -441,13 +442,14 @@ def main() -> None: "val": [w["_name"] for w in selected_val], "test": [w["_name"] for w in selected_test], } - manifest_path = os.path.join(args.dataset_path, "oep_eval_manifest.json") + manifest_path = os.path.join(args.dataset_path, f"{args.tag}_manifest.json") with open(manifest_path, "w") as f: json.dump(manifest, f, indent=2) print(f"\nWrote manifest to {manifest_path}") # Write stats stats = { + "tag": args.tag, "seed": args.seed, "one_year_cap_days": 365, "val_max_samples_threshold": VAL_MAX_SAMPLES, @@ -462,7 +464,7 @@ def main() -> None: "val": val_stats, "test": test_stats, } - stats_path = os.path.join(args.dataset_path, "oep_eval_stats.json") + stats_path = os.path.join(args.dataset_path, f"{args.tag}_stats.json") with open(stats_path, "w") as f: json.dump(stats, f, indent=2) print(f"Wrote stats to {stats_path}") diff --git a/rslp/lfmc/visualize_oep_eval.py b/rslp/lfmc/visualize_oep_eval.py index 12a7a3017..90c9f0d14 100644 --- a/rslp/lfmc/visualize_oep_eval.py +++ b/rslp/lfmc/visualize_oep_eval.py @@ -1,16 +1,17 @@ -"""Visualize oep_eval subsampled locations vs all original locations on a US map. +"""Visualize tagged subsampled locations vs all original locations on a US map. Produces a 2x2 figure: - Row 1: Woody dataset - Row 2: Herbaceous dataset - Left column: all original locations - - Right column: oep_eval-tagged subset only + - Right column: tagged subset only Points are colored by split (train=blue, val=orange, test=green). Usage: python visualize_oep_eval.py \ --woody_path /path/to/woody/dataset \ --herbaceous_path /path/to/herbaceous/dataset \ + --tag oep_eval \ --output oep_eval_map.png """ @@ -63,8 +64,10 @@ def load_locations(dataset_path: str) -> dict[str, list[tuple[float, float]]]: return dict(locations) -def load_tagged_locations(dataset_path: str) -> dict[str, list[tuple[float, float]]]: - """Load unique locations that have the oep_eval tag, per split.""" +def load_tagged_locations( + dataset_path: str, tag: str +) -> dict[str, list[tuple[float, float]]]: + """Load unique locations that have the given tag, per split.""" windows_dir = os.path.join(dataset_path, "windows", "spatial_split") seen: dict[str, set[tuple]] = defaultdict(set) locations: dict[str, list[tuple[float, float]]] = defaultdict(list) @@ -76,7 +79,7 @@ def load_tagged_locations(dataset_path: str) -> dict[str, list[tuple[float, floa with open(meta_path) as f: meta = json.load(f) - if "oep_eval" not in meta.get("options", {}): + if tag not in meta.get("options", {}): continue split = meta.get("options", {}).get("split", "unknown") @@ -135,10 +138,13 @@ def plot_locations_on_ax( def main() -> None: - """Visualize oep_eval subsampled vs original locations on a US map.""" - parser = argparse.ArgumentParser(description="Visualize oep_eval locations") + """Visualize tagged subsampled vs original locations on a US map.""" + parser = argparse.ArgumentParser(description="Visualize tagged locations") parser.add_argument("--woody_path", type=str, required=True) parser.add_argument("--herbaceous_path", type=str, required=True) + parser.add_argument( + "--tag", type=str, default="oep_eval", help="Tag name to filter on" + ) parser.add_argument("--states_shp", type=str, default=US_STATES_SHP) parser.add_argument("--output", type=str, default="oep_eval_map.png") args = parser.parse_args() @@ -150,21 +156,21 @@ def main() -> None: print("Loading woody locations...") woody_all = load_locations(args.woody_path) - woody_tagged = load_tagged_locations(args.woody_path) + woody_tagged = load_tagged_locations(args.woody_path, args.tag) print("Loading herbaceous locations...") herb_all = load_locations(args.herbaceous_path) - herb_tagged = load_tagged_locations(args.herbaceous_path) + herb_tagged = load_tagged_locations(args.herbaceous_path, args.tag) fig, axes = plt.subplots(2, 2, figsize=(16, 10)) plot_locations_on_ax(axes[0, 0], woody_all, states_gdf, "Woody — All Locations") plot_locations_on_ax( - axes[0, 1], woody_tagged, states_gdf, "Woody — oep_eval Subset" + axes[0, 1], woody_tagged, states_gdf, f"Woody — {args.tag} Subset" ) plot_locations_on_ax(axes[1, 0], herb_all, states_gdf, "Herbaceous — All Locations") plot_locations_on_ax( - axes[1, 1], herb_tagged, states_gdf, "Herbaceous — oep_eval Subset" + axes[1, 1], herb_tagged, states_gdf, f"Herbaceous — {args.tag} Subset" ) plt.tight_layout() From 4deec953cde43a182ae65d57dc3068651a88d9ba Mon Sep 17 00:00:00 2001 From: root Date: Wed, 3 Jun 2026 20:43:13 +0000 Subject: [PATCH 4/6] wip --- .../v2_lfmc/subsample/woody3k_lprobe.yaml | 109 +++++++ .../v2_lfmc/subsample/woody3k_unet.yaml | 101 +++++++ rslp/lfmc/compute_target_stats.py | 285 ++++++++++++++++++ 3 files changed, 495 insertions(+) create mode 100644 data/helios/v2_lfmc/subsample/woody3k_lprobe.yaml create mode 100644 data/helios/v2_lfmc/subsample/woody3k_unet.yaml create mode 100644 rslp/lfmc/compute_target_stats.py diff --git a/data/helios/v2_lfmc/subsample/woody3k_lprobe.yaml b/data/helios/v2_lfmc/subsample/woody3k_lprobe.yaml new file mode 100644 index 000000000..d048e0e10 --- /dev/null +++ b/data/helios/v2_lfmc/subsample/woody3k_lprobe.yaml @@ -0,0 +1,109 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth + init_args: + model_id: OLMOEARTH_V1_1_BASE + patch_size: 4 + decoders: + lfmc_estimation: + - class_path: rslearn.models.upsample.Upsample + init_args: + scale_factor: 4 + - class_path: rslearn.models.conv.Conv + init_args: + in_channels: 768 + out_channels: 1 + kernel_size: 1 + activation: + class_path: torch.nn.Identity + - class_path: rslearn.train.tasks.per_pixel_regression.PerPixelRegressionHead + lr: 0.00002 + scheduler: + class_path: rslearn.train.scheduler.PlateauScheduler + init_args: + factor: 0.2 + patience: 2 + min_lr: 0 + cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: /weka/dfive-default/rslearn-eai/datasets/lfmc/20251023/woody/scratch/dataset + inputs: + sentinel2_l2a: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + load_all_item_groups: true + load_all_layers: true + targets: + data_type: "raster" + layers: ["labels"] + bands: ["value"] + dtype: FLOAT32 + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + lfmc_estimation: + class_path: rslearn.train.tasks.per_pixel_regression.PerPixelRegressionTask + init_args: + nodata_value: -1 + metrics: ["rmse", "r2"] + #scale_factor: 0.0024038 + target_mean: 108.389520 + target_std: 36.616698 + + input_mapping: + lfmc_estimation: + targets: "targets" + batch_size: 32 + num_workers: 16 + default_config: + crop_size: 32 + transforms: + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + train_config: + groups: ["spatial_split"] + tags: + split: "train" + oep_eval_big: "" + val_config: + groups: ["spatial_split"] + tags: + split: "val" + oep_eval_big: "" + test_config: + groups: ["spatial_split"] + tags: + split: "test" + oep_eval_big: "" +trainer: + max_epochs: 100 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.callbacks.checkpointing.ManagedBestLastCheckpoint + init_args: + monitor: val_loss + mode: min + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0] + unfreeze_at_epoch: 50 + unfreeze_lr_factor: 10 +management_dir: ${RSLP_PREFIX}/projects +project_name: 20260528_lfmc_finetune +run_name: placeholder diff --git a/data/helios/v2_lfmc/subsample/woody3k_unet.yaml b/data/helios/v2_lfmc/subsample/woody3k_unet.yaml new file mode 100644 index 000000000..a3c570f03 --- /dev/null +++ b/data/helios/v2_lfmc/subsample/woody3k_unet.yaml @@ -0,0 +1,101 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth + init_args: + model_id: OLMOEARTH_V1_1_BASE + patch_size: 4 + decoders: + lfmc_estimation: + - class_path: rslearn.models.unet.UNetDecoder + init_args: + in_channels: [[4, 768]] + out_channels: 1 + conv_layers_per_resolution: 2 + num_channels: {8: 512, 4: 512, 2: 256, 1: 128} + - class_path: rslearn.train.tasks.per_pixel_regression.PerPixelRegressionHead + lr: 0.00002 + scheduler: + class_path: rslearn.train.scheduler.PlateauScheduler + init_args: + factor: 0.2 + patience: 2 + min_lr: 0 + cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: /weka/dfive-default/rslearn-eai/datasets/lfmc/20251023/woody/scratch/dataset + inputs: + sentinel2_l2a: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + load_all_item_groups: true + load_all_layers: true + targets: + data_type: "raster" + layers: ["labels"] + bands: ["value"] + dtype: FLOAT32 + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + lfmc_estimation: + class_path: rslearn.train.tasks.per_pixel_regression.PerPixelRegressionTask + init_args: + nodata_value: -1 + metrics: ["rmse", "r2"] + input_mapping: + lfmc_estimation: + targets: "targets" + batch_size: 32 + num_workers: 16 + default_config: + crop_size: 32 + transforms: + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + train_config: + groups: ["spatial_split"] + tags: + split: "train" + oep_eval_big: "" + val_config: + groups: ["spatial_split"] + tags: + split: "val" + oep_eval_big: "" + test_config: + groups: ["spatial_split"] + tags: + split: "test" + oep_eval_big: "" +trainer: + max_epochs: 100 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.callbacks.checkpointing.ManagedBestLastCheckpoint + init_args: + monitor: val_loss + mode: min + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0] + unfreeze_at_epoch: 20 + unfreeze_lr_factor: 10 +management_dir: ${RSLP_PREFIX}/projects +project_name: 20260528_lfmc_finetune +run_name: placeholder diff --git a/rslp/lfmc/compute_target_stats.py b/rslp/lfmc/compute_target_stats.py new file mode 100644 index 000000000..96e51b745 --- /dev/null +++ b/rslp/lfmc/compute_target_stats.py @@ -0,0 +1,285 @@ +"""Compute mean/std of LFMC target values over a split (and optional tag). + +These statistics are intended to be plugged into the ``target_mean`` / ``target_std`` +arguments of ``rslearn.train.tasks.per_pixel_regression.PerPixelRegressionTask`` so that +the regression target is normalized during training while metrics are still reported in +the original units. + +The statistics are computed only over *valid* pixels: pixels equal to the ignore value +(``-1`` by default, matching the task's ``nodata_value``) and non-finite pixels are +excluded. Windows are filtered the same way the training data module filters them: by +group, by the ``split`` option, and (optionally) by requiring a tag option to be present. + +IMPORTANT: only compute these statistics on the train split to avoid leaking validation +or test information into the normalization. + +Usage: + python -m rslp.lfmc.compute_target_stats \ + --dataset_path /weka/dfive-default/rslearn-eai/datasets/lfmc/20251023/woody/scratch/dataset \ + --split train \ + --tag oep_eval_big +""" + +import argparse +import json +import math +import os +from collections.abc import Iterable +from concurrent.futures import ProcessPoolExecutor, as_completed +from dataclasses import dataclass + +import numpy as np +import rasterio + + +@dataclass +class PixelStats: + """Streaming accumulator for per-pixel statistics over valid pixels.""" + + count: int = 0 + total: float = 0.0 + total_sq: float = 0.0 + minimum: float = math.inf + maximum: float = -math.inf + + def merge(self, other: "PixelStats") -> None: + """Merge another accumulator into this one.""" + self.count += other.count + self.total += other.total + self.total_sq += other.total_sq + self.minimum = min(self.minimum, other.minimum) + self.maximum = max(self.maximum, other.maximum) + + @property + def mean(self) -> float: + """Mean over valid pixels.""" + if self.count == 0: + raise ValueError("no valid pixels found") + return self.total / self.count + + def std(self, ddof: int = 0) -> float: + """Standard deviation over valid pixels. + + Args: + ddof: delta degrees of freedom. Use 0 (population std) for normalization. + """ + if self.count - ddof <= 0: + raise ValueError("not enough valid pixels to compute std") + # Var = E[x^2] - E[x]^2, scaled to the requested ddof. Clamp tiny negatives that + # can arise from floating point error. + variance = (self.total_sq - self.total**2 / self.count) / (self.count - ddof) + return math.sqrt(max(variance, 0.0)) + + +def _matches_filters( + options: dict, + split: str | None, + tag: str | None, +) -> bool: + """Return whether a window's options pass the split/tag filters. + + Mirrors the tag filtering done by rslearn's ModelDataset: the tag key must be present + in the window options (an empty configured value just requires presence). + """ + if split is not None and options.get("split") != split: + return False + if tag is not None and tag not in options: + return False + return True + + +def _select_window_dirs( + dataset_path: str, + group: str, + split: str | None, + tag: str | None, +) -> list[str]: + """Find window directories matching the split/tag filters.""" + windows_dir = os.path.join(dataset_path, "windows", group) + names = sorted(os.listdir(windows_dir)) + selected = [] + for name in names: + window_dir = os.path.join(windows_dir, name) + meta_path = os.path.join(window_dir, "metadata.json") + if not os.path.isfile(meta_path): + continue + with open(meta_path) as f: + meta = json.load(f) + if _matches_filters(meta.get("options", {}), split, tag): + selected.append(window_dir) + return selected + + +# Globals set per worker process so the geotiff path can be built without re-passing +# constant arguments for every task. +_LAYER = "labels" +_IGNORE_VALUE = -1.0 + + +def _init_worker(layer: str, ignore_value: float) -> None: + global _LAYER, _IGNORE_VALUE + _LAYER = layer + _IGNORE_VALUE = ignore_value + + +def _window_stats(window_dir: str) -> PixelStats: + """Read one window's label raster and accumulate stats over valid pixels.""" + stats = PixelStats() + geotiff_path = os.path.join(window_dir, "layers", _LAYER, "value", "geotiff.tif") + if not os.path.isfile(geotiff_path): + return stats + with rasterio.open(geotiff_path) as ds: + array = ds.read().astype(np.float64) + valid = np.isfinite(array) & (array != _IGNORE_VALUE) + values = array[valid] + if values.size == 0: + return stats + stats.count = int(values.size) + stats.total = float(values.sum()) + stats.total_sq = float(np.square(values).sum()) + stats.minimum = float(values.min()) + stats.maximum = float(values.max()) + return stats + + +def compute_stats( + window_dirs: Iterable[str], + layer: str, + ignore_value: float, + workers: int, +) -> PixelStats: + """Compute aggregate pixel statistics across the given window directories.""" + window_dirs = list(window_dirs) + total = PixelStats() + if workers <= 1: + _init_worker(layer, ignore_value) + for i, window_dir in enumerate(window_dirs): + total.merge(_window_stats(window_dir)) + if (i + 1) % 5000 == 0: + print(f" ... processed {i + 1}/{len(window_dirs)} windows") + return total + + with ProcessPoolExecutor( + max_workers=workers, + initializer=_init_worker, + initargs=(layer, ignore_value), + ) as executor: + futures = { + executor.submit(_window_stats, window_dir): window_dir + for window_dir in window_dirs + } + for i, future in enumerate(as_completed(futures)): + total.merge(future.result()) + if (i + 1) % 5000 == 0: + print(f" ... processed {i + 1}/{len(window_dirs)} windows") + return total + + +def main() -> None: + """Compute and report LFMC target mean/std for a split.""" + parser = argparse.ArgumentParser( + description="Compute mean/std of LFMC target values over a split." + ) + parser.add_argument("--dataset_path", type=str, required=True) + parser.add_argument( + "--group", + type=str, + default="spatial_split", + help="Window group to read (default: spatial_split).", + ) + parser.add_argument( + "--split", + type=str, + default="train", + help="Only include windows whose 'split' option equals this. " + "Pass an empty string to disable the split filter.", + ) + parser.add_argument( + "--tag", + type=str, + default=None, + help="If set, only include windows that have this tag option (e.g. " + "oep_eval_big).", + ) + parser.add_argument( + "--layer", + type=str, + default="labels", + help="Raster layer holding the target (default: labels).", + ) + parser.add_argument( + "--ignore_value", + type=float, + default=-1.0, + help="Pixel value to treat as invalid/nodata (default: -1).", + ) + parser.add_argument( + "--workers", + type=int, + default=16, + help="Number of worker processes for reading rasters (default: 16).", + ) + parser.add_argument( + "--output", + type=str, + default=None, + help="Optional path to write the computed stats as JSON.", + ) + args = parser.parse_args() + + split = args.split if args.split else None + + print( + f"Selecting windows from group '{args.group}' " + f"(split={split}, tag={args.tag})..." + ) + window_dirs = _select_window_dirs(args.dataset_path, args.group, split, args.tag) + print(f" Selected {len(window_dirs)} windows") + if not window_dirs: + raise RuntimeError("no windows matched the given filters") + + print("Reading target rasters and accumulating statistics...") + stats = compute_stats(window_dirs, args.layer, args.ignore_value, args.workers) + + if stats.count == 0: + raise RuntimeError("no valid target pixels found") + + mean = stats.mean + std_pop = stats.std(ddof=0) + std_sample = stats.std(ddof=1) + + result = { + "dataset_path": args.dataset_path, + "group": args.group, + "split": split, + "tag": args.tag, + "layer": args.layer, + "ignore_value": args.ignore_value, + "num_windows": len(window_dirs), + "num_valid_pixels": stats.count, + "mean": mean, + "std": std_pop, + "std_sample": std_sample, + "min": stats.minimum, + "max": stats.maximum, + } + + print("\n=== LFMC target statistics (valid pixels only) ===") + print(f" windows: {len(window_dirs)}") + print(f" valid pixels: {stats.count}") + print(f" mean: {mean:.6f}") + print(f" std (pop): {std_pop:.6f}") + print(f" std (sample): {std_sample:.6f}") + print(f" min / max: {stats.minimum:.6f} / {stats.maximum:.6f}") + print("\nUse these in the PerPixelRegressionTask config:") + print(f" target_mean: {mean:.6f}") + print(f" target_std: {std_pop:.6f}") + + if args.output: + with open(args.output, "w") as f: + json.dump(result, f, indent=2) + print(f"\nWrote stats to {args.output}") + + +if __name__ == "__main__": + main() From 398f73cfd297e8cd33d9a7a3fa17e435862dad13 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 9 Jun 2026 20:52:10 +0000 Subject: [PATCH 5/6] adding subsampling strategy for mapbiomas + check + sampling arteficts/viz --- data/helios/mapbiomas/dense_config/model.yaml | 131 ++++ .../helios/mapbiomas/sparse_config/model.yaml | 118 +++ rslp/mapbiomas/create_windows_dense_raster.py | 258 +++++++ .../mapbiomas/create_windows_expert_sparse.py | 231 ++++++ rslp/mapbiomas/sanity_check.py | 691 ++++++++++++++++++ rslp/mapbiomas/subsampling/__init__.py | 1 + .../subsampling/sample_dense_raster.py | 563 ++++++++++++++ .../subsampling/sample_expert_points.py | 631 ++++++++++++++++ .../visualize_sample_expert_points.py | 191 +++++ rslp/utils/beaker.py | 2 +- 10 files changed, 2816 insertions(+), 1 deletion(-) create mode 100644 data/helios/mapbiomas/dense_config/model.yaml create mode 100644 data/helios/mapbiomas/sparse_config/model.yaml create mode 100644 rslp/mapbiomas/create_windows_dense_raster.py create mode 100644 rslp/mapbiomas/create_windows_expert_sparse.py create mode 100644 rslp/mapbiomas/sanity_check.py create mode 100644 rslp/mapbiomas/subsampling/__init__.py create mode 100644 rslp/mapbiomas/subsampling/sample_dense_raster.py create mode 100644 rslp/mapbiomas/subsampling/sample_expert_points.py create mode 100644 rslp/mapbiomas/subsampling/visualize_sample_expert_points.py diff --git a/data/helios/mapbiomas/dense_config/model.yaml b/data/helios/mapbiomas/dense_config/model.yaml new file mode 100644 index 000000000..5551407fb --- /dev/null +++ b/data/helios/mapbiomas/dense_config/model.yaml @@ -0,0 +1,131 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth + init_args: + model_id: OLMOEARTH_V1_1_BASE + patch_size: 3 + decoders: + segment: + - class_path: rslearn.models.upsample.Upsample + init_args: + scale_factor: 3 + - class_path: rslearn.models.conv.Conv + init_args: + in_channels: 768 + out_channels: 28 + kernel_size: 1 + activation: + class_path: torch.nn.Identity + - class_path: rslearn.train.tasks.segmentation.SegmentationHead + lr: 0.0001 + scheduler: + class_path: rslearn.train.scheduler.PlateauScheduler + init_args: + factor: 0.2 + patience: 2 + min_lr: 0 + cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: /weka/dfive-default/rslearn-eai/datasets/mapbiomas_3k + inputs: + sentinel2_l2a: + data_type: "raster" + layers: ["sentinel2_l2a_feb", "sentinel2_l2a_may", "sentinel2_l2a_aug", "sentinel2_l2a_nov"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + load_all_layers: true + targets: + data_type: "raster" + layers: ["label_raster"] + bands: ["label"] + dtype: INT32 + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + segment: + class_path: rslearn.train.tasks.segmentation.SegmentationTask + init_args: + num_classes: 28 + class_id_mapping: + 3: 1 + 4: 2 + 5: 3 + 6: 4 + 9: 5 + 11: 6 + 12: 7 + 15: 8 + 20: 9 + 21: 10 + 23: 11 + 24: 12 + 25: 13 + 29: 14 + 30: 15 + 32: 16 + 33: 17 + 35: 18 + 39: 19 + 40: 20 + 41: 21 + 46: 22 + 47: 23 + 48: 24 + 49: 25 + 50: 26 + 62: 27 + nodata_value: 0 + enable_f1_metric: true + metric_kwargs: + average: "macro" + input_mapping: + segment: + targets: "targets" + batch_size: 32 + num_workers: 16 + default_config: + transforms: + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + train_config: + groups: ["mapbiomas_dense_raster"] + tags: + split: "train" + val_config: + groups: ["mapbiomas_dense_raster"] + tags: + split: "val" + test_config: + groups: ["mapbiomas_dense_raster"] + tags: + split: "val" +trainer: + max_epochs: 50 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.callbacks.checkpointing.ManagedBestLastCheckpoint + init_args: + monitor: val_segment/F1 + mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0] + unfreeze_at_epoch: 50 + unfreeze_lr_factor: 10 +management_dir: ${RSLP_PREFIX}/projects +project_name: 20260605_mapbiomas_dense +run_name: placeholder diff --git a/data/helios/mapbiomas/sparse_config/model.yaml b/data/helios/mapbiomas/sparse_config/model.yaml new file mode 100644 index 000000000..1296f5850 --- /dev/null +++ b/data/helios/mapbiomas/sparse_config/model.yaml @@ -0,0 +1,118 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth + init_args: + model_id: OLMOEARTH_V1_1_BASE + patch_size: 3 + decoders: + segment: + - class_path: rslearn.models.conv.Conv + init_args: + in_channels: 768 + out_channels: 12 + kernel_size: 1 + activation: + class_path: torch.nn.Identity + - class_path: rslearn.train.tasks.segmentation.SegmentationHead + lr: 0.0001 + scheduler: + class_path: rslearn.train.scheduler.PlateauScheduler + init_args: + factor: 0.2 + patience: 2 + min_lr: 0 + cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: /weka/dfive-default/rslearn-eai/datasets/mapbiomas_3k + inputs: + sentinel2_l2a: + data_type: "raster" + layers: ["sentinel2_l2a_feb", "sentinel2_l2a_may", "sentinel2_l2a_aug", "sentinel2_l2a_nov"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + load_all_layers: true + targets: + data_type: "raster" + layers: ["label_raster"] + bands: ["label"] + dtype: INT32 + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + segment: + class_path: rslearn.train.tasks.segmentation.SegmentationTask + init_args: + num_classes: 12 + class_id_mapping: + 3: 1 + 4: 2 + 9: 3 + 11: 4 + 12: 5 + 15: 6 + 19: 7 + 20: 8 + 24: 9 + 33: 10 + 36: 11 + nodata_value: 0 + enable_f1_metric: true + metric_kwargs: + average: "macro" + input_mapping: + segment: + targets: "targets" + batch_size: 32 + num_workers: 16 + default_config: + crop_size: 48 + transforms: + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + - class_path: rslearn.train.transforms.resize.Resize + init_args: + target_size: [16, 16] + selectors: ["target/segment/classes", "target/segment/valid"] + interpolation: "nearest" + train_config: + groups: ["mapbiomas_expert_sparse"] + tags: + split: "train" + val_config: + groups: ["mapbiomas_expert_sparse"] + tags: + split: "val" + test_config: + groups: ["mapbiomas_expert_sparse"] + tags: + split: "val" +trainer: + max_epochs: 50 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.callbacks.checkpointing.ManagedBestLastCheckpoint + init_args: + monitor: val_segment/F1 + mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0] + unfreeze_at_epoch: 50 + unfreeze_lr_factor: 10 +management_dir: ${RSLP_PREFIX}/projects +project_name: 20260605_mapbiomas +run_name: placeholder diff --git a/rslp/mapbiomas/create_windows_dense_raster.py b/rslp/mapbiomas/create_windows_dense_raster.py new file mode 100644 index 000000000..0ee7f19c3 --- /dev/null +++ b/rslp/mapbiomas/create_windows_dense_raster.py @@ -0,0 +1,258 @@ +"""Create rslearn windows with dense target labels from MapBiomas coverage rasters. + +Each window is target_size x target_size pixels at 10m resolution. A patch of +(target_size / 3) pixels is read from the corresponding year's MapBiomas raster +(~30m) and upscaled 3x via nearest-neighbor to fill the target. Omitted classes +are remapped to no-data (0); windows where all pixels are no-data after +remapping are skipped. + +Usage:: + + python -m rslp.mapbiomas.create_windows_dense_raster \ + --ds-name mapbiomas_3k \ + --omit-classes 31 0 +""" + +from __future__ import annotations + +import argparse +import multiprocessing +import os +from datetime import datetime, timezone +from pathlib import Path + +import numpy as np +import pandas as pd +import rasterio +import rasterio.windows +import shapely +import tqdm +from rslearn.const import WGS84_PROJECTION +from rslearn.dataset import Dataset, Window +from rslearn.utils import Projection, STGeometry, get_utm_ups_crs +from rslearn.utils.mp import star_imap_unordered +from rslearn.utils.raster_array import RasterArray +from rslearn.utils.raster_format import GeotiffRasterFormat +from upath import UPath + +from rslp.utils.windows import calculate_bounds + +MY_ROOT = Path(os.environ.get("MY_ROOT", ".")) + +WINDOW_RESOLUTION = 10 +LABEL_LAYER = "label_raster" +BAND_NAME = "label" +NODATA_VALUE = 0 +GROUP_NAME = "mapbiomas_dense_raster" +UPSCALE_FACTOR = 3 + +DEFAULT_CSV = ( + MY_ROOT / "rslearn_projects/rslp/mapbiomas/subsampling/sample_dense_raster_4k.csv" +) +DEFAULT_RASTER_DIR = MY_ROOT / "datasets/mapbiomas/data" +DEFAULT_DS_NAME = "mapbiomas_3k" + + +def create_window( + csv_row: pd.Series, + ds_path: UPath, + target_size: int, + omit_classes: set[int], + raster_dir: Path, +) -> str: + """Create one dense-label window from a MapBiomas coverage raster patch. + + Returns a status string: "created", "omit_class", or "omit_nodata". + """ + target_id = int(csv_row["TARGETID"]) + year = int(csv_row["YEAR"]) + split = csv_row["split"] + window_name = f"{target_id}_{year}" + + # LON/LAT columns are swapped in the source CSV + longitude = csv_row["LAT"] + latitude = csv_row["LON"] + + raster_path = raster_dir / f"brazil_coverage_{year}.tif" + if not raster_path.exists(): + raise FileNotFoundError(f"Raster not found: {raster_path}") + + raster_patch_size = target_size // UPSCALE_FACTOR + half_patch = raster_patch_size // 2 + + with rasterio.open(raster_path) as src: + r, c = src.index(longitude, latitude) + rio_window = rasterio.windows.Window( + col_off=c - half_patch, + row_off=r - half_patch, + width=raster_patch_size, + height=raster_patch_size, + ) + patch = src.read(1, window=rio_window, boundless=True, fill_value=NODATA_VALUE) + + # Remap omitted classes to no-data + if omit_classes: + for cls in omit_classes: + patch[patch == cls] = NODATA_VALUE + + if np.all(patch == NODATA_VALUE): + print( + f"Omitted window {window_name}: all pixels are no-data after class omission" + ) + return "omit_nodata" + + # Upscale via nearest-neighbor (repeat each pixel 3x3) + upscaled = np.repeat( + np.repeat(patch, UPSCALE_FACTOR, axis=0), UPSCALE_FACTOR, axis=1 + ) + + # Trim to exact target_size in case of rounding (e.g. target_size not divisible by 3) + upscaled = upscaled[:target_size, :target_size] + + # Project point to UTM and build rslearn window + src_point = shapely.Point(longitude, latitude) + src_geometry = STGeometry(WGS84_PROJECTION, src_point, None) + dst_crs = get_utm_ups_crs(longitude, latitude) + dst_projection = Projection(dst_crs, WINDOW_RESOLUTION, -WINDOW_RESOLUTION) + dst_geometry = src_geometry.to_projection(dst_projection) + bounds = calculate_bounds(dst_geometry, target_size) + + start_time = datetime(year, 1, 1, tzinfo=timezone.utc) + end_time = datetime(year, 12, 31, tzinfo=timezone.utc) + + dataset = Dataset(ds_path) + window = Window( + storage=dataset.storage, + group=GROUP_NAME, + name=window_name, + projection=dst_projection, + bounds=bounds, + time_range=(start_time, end_time), + options={"split": split}, + ) + window.save() + + raster = upscaled[np.newaxis, :, :].astype(np.uint8) + + raster_out_dir = window.get_raster_dir(LABEL_LAYER, [BAND_NAME]) + GeotiffRasterFormat().encode_raster( + raster_out_dir, + window.projection, + window.bounds, + RasterArray(chw_array=raster), + ) + window.mark_layer_completed(LABEL_LAYER) + return "created" + + +def create_windows_from_csv( + csv_path: UPath, + ds_name: str, + target_size: int, + omit_classes: set[int], + raster_dir: Path, + workers: int, +) -> None: + """Create dense-label windows for all rows in the subsample CSV.""" + rslearn_root = os.environ.get("RSLEARN_EAI_ROOT") + if not rslearn_root: + raise RuntimeError("RSLEARN_EAI_ROOT environment variable is not set") + ds_path = UPath(rslearn_root) / ds_name + + raster_patch_size = target_size // UPSCALE_FACTOR + + df = pd.read_csv(csv_path) + csv_rows = [row for _, row in df.iterrows()] + + print(f"Loaded {len(csv_rows)} rows from {csv_path}") + print(f"Dataset path: {ds_path}") + print(f"Target window size: {target_size}x{target_size} at 10m") + print( + f"Raster patch size: {raster_patch_size}x{raster_patch_size} at ~30m (upscale {UPSCALE_FACTOR}x)" + ) + if omit_classes: + print(f"Omitting classes: {sorted(omit_classes)}") + + jobs = [ + dict( + csv_row=row, + ds_path=ds_path, + target_size=target_size, + omit_classes=omit_classes, + raster_dir=raster_dir, + ) + for row in csv_rows + ] + + counts = {"created": 0, "omit_class": 0, "omit_nodata": 0} + p = multiprocessing.Pool(workers) + outputs = star_imap_unordered(p, create_window, jobs) + for status in tqdm.tqdm(outputs, total=len(jobs)): + counts[status] += 1 + p.close() + + total_omitted = counts["omit_class"] + counts["omit_nodata"] + print("\n" + "=" * 60) + print("WINDOW CREATION SUMMARY (dense raster)") + print("=" * 60) + print(f" Created: {counts['created']}") + print(f" Omitted (class): {counts['omit_class']}") + print(f" Omitted (all nodata): {counts['omit_nodata']}") + print(f" Total omitted: {total_omitted}") + print(f" Total processed: {len(jobs)}") + print("=" * 60) + + +if __name__ == "__main__": + multiprocessing.set_start_method("forkserver") + parser = argparse.ArgumentParser( + description="Create rslearn windows with dense raster labels for MapBiomas.", + ) + parser.add_argument( + "--csv-path", + type=str, + default=str(DEFAULT_CSV), + help="Path to the dense raster subsample CSV.", + ) + parser.add_argument( + "--ds-name", + type=str, + default=DEFAULT_DS_NAME, + help="Dataset name, appended to RSLEARN_EAI_ROOT (default: %(default)s).", + ) + parser.add_argument( + "--raster-dir", + type=str, + default=str(DEFAULT_RASTER_DIR), + help="Directory containing brazil_coverage_{year}.tif rasters.", + ) + parser.add_argument( + "--omit-classes", + type=int, + nargs="*", + default=[], + help="Class IDs to remap to no-data (windows left all-nodata are skipped).", + ) + parser.add_argument( + "--target-size", + type=int, + default=48, + help="Window size in 10m pixels (default: 48). " + "A patch of (target_size / 3) pixels is read from the ~30m raster.", + ) + parser.add_argument( + "--workers", + type=int, + default=16, + help="Number of multiprocessing workers (default: 32).", + ) + args = parser.parse_args() + + create_windows_from_csv( + csv_path=UPath(args.csv_path), + ds_name=args.ds_name, + target_size=args.target_size, + omit_classes=set(args.omit_classes), + raster_dir=Path(args.raster_dir), + workers=args.workers, + ) diff --git a/rslp/mapbiomas/create_windows_expert_sparse.py b/rslp/mapbiomas/create_windows_expert_sparse.py new file mode 100644 index 000000000..e9d8ceca4 --- /dev/null +++ b/rslp/mapbiomas/create_windows_expert_sparse.py @@ -0,0 +1,231 @@ +"""Create rslearn windows with sparse target labels from expert validation points. + +Each window is an *extended* window of size (2 * target_size - 3) pixels at 10m +resolution, centered on the expert point. The expert label is placed as a 3x3 +block at the window center (representing one ~30m validation pixel). All other +pixels are set to the no-data value (0). + +The extended sizing guarantees that any random target_size x target_size crop +from the window will always fully contain the 3x3 labeled block. + +Usage:: + + python -m rslp.mapbiomas.create_windows_expert_sparse \ + --ds-name mapbiomas_3k \ + --omit-classes 29 5 25 +""" + +from __future__ import annotations + +import argparse +import multiprocessing +import os +from datetime import datetime, timezone +from pathlib import Path + +import numpy as np +import pandas as pd +import shapely +import tqdm +from rslearn.const import WGS84_PROJECTION +from rslearn.dataset import Dataset, Window +from rslearn.utils import Projection, STGeometry, get_utm_ups_crs +from rslearn.utils.mp import star_imap_unordered +from rslearn.utils.raster_array import RasterArray +from rslearn.utils.raster_format import GeotiffRasterFormat +from upath import UPath + +from rslp.utils.windows import calculate_bounds + +MY_ROOT = Path(os.environ.get("MY_ROOT", ".")) + +WINDOW_RESOLUTION = 10 +LABEL_LAYER = "label_raster" +BAND_NAME = "label" +NODATA_VALUE = 0 +GROUP_NAME = "mapbiomas_expert_sparse" +LABEL_BLOCK_SIZE = 3 + +DEFAULT_CSV = ( + MY_ROOT / "rslearn_projects/rslp/mapbiomas/subsampling/sample_expert_points_4k.csv" +) +DEFAULT_DS_NAME = "mapbiomas_3k" + + +def compute_extended_window_size(target_size: int) -> int: + """Compute extended window size so every target_size crop contains the 3x3 center.""" + return 2 * target_size - LABEL_BLOCK_SIZE + + +def create_window( + csv_row: pd.Series, + ds_path: UPath, + window_size: int, + omit_classes: set[int], +) -> str: + """Create one sparse-label window from an expert validation point. + + Returns a status string: "created", "omit_class", or "omit_nodata". + """ + class_id = int(csv_row["CLASS"]) + target_id = int(csv_row["TARGETID"]) + year = int(csv_row["YEAR"]) + split = csv_row["split"] + window_name = f"{target_id}_{year}" + + if class_id in omit_classes: + print(f"Omitted window {window_name}: CLASS {class_id} in omit list") + return "omit_class" + + # LON/LAT columns are swapped in the source CSV + longitude = csv_row["LAT"] + latitude = csv_row["LON"] + + src_point = shapely.Point(longitude, latitude) + src_geometry = STGeometry(WGS84_PROJECTION, src_point, None) + dst_crs = get_utm_ups_crs(longitude, latitude) + dst_projection = Projection(dst_crs, WINDOW_RESOLUTION, -WINDOW_RESOLUTION) + dst_geometry = src_geometry.to_projection(dst_projection) + bounds = calculate_bounds(dst_geometry, window_size) + + start_time = datetime(year, 1, 1, tzinfo=timezone.utc) + end_time = datetime(year, 12, 31, tzinfo=timezone.utc) + + dataset = Dataset(ds_path) + window = Window( + storage=dataset.storage, + group=GROUP_NAME, + name=window_name, + projection=dst_projection, + bounds=bounds, + time_range=(start_time, end_time), + options={"split": split}, + ) + window.save() + + raster_h = bounds[3] - bounds[1] + raster_w = bounds[2] - bounds[0] + raster = np.full((1, raster_h, raster_w), NODATA_VALUE, dtype=np.uint8) + + cy, cx = raster_h // 2, raster_w // 2 + half = LABEL_BLOCK_SIZE // 2 + y_lo = max(cy - half, 0) + y_hi = min(cy + half + 1, raster_h) + x_lo = max(cx - half, 0) + x_hi = min(cx + half + 1, raster_w) + raster[0, y_lo:y_hi, x_lo:x_hi] = class_id + + raster_dir = window.get_raster_dir(LABEL_LAYER, [BAND_NAME]) + GeotiffRasterFormat().encode_raster( + raster_dir, + window.projection, + window.bounds, + RasterArray(chw_array=raster), + ) + window.mark_layer_completed(LABEL_LAYER) + return "created" + + +def create_windows_from_csv( + csv_path: UPath, + ds_name: str, + target_size: int, + omit_classes: set[int], + workers: int, +) -> None: + """Create sparse-label windows for all rows in the expert subsample CSV.""" + rslearn_root = os.environ.get("RSLEARN_EAI_ROOT") + if not rslearn_root: + raise RuntimeError("RSLEARN_EAI_ROOT environment variable is not set") + ds_path = UPath(rslearn_root) / ds_name + + window_size = compute_extended_window_size(target_size) + + df = pd.read_csv(csv_path) + csv_rows = [row for _, row in df.iterrows()] + + print(f"Loaded {len(csv_rows)} rows from {csv_path}") + print(f"Dataset path: {ds_path}") + print(f"Target crop size: {target_size}x{target_size}") + print( + f"Extended window size: {window_size}x{window_size} (2*{target_size} - {LABEL_BLOCK_SIZE})" + ) + if omit_classes: + print(f"Omitting classes: {sorted(omit_classes)}") + + jobs = [ + dict( + csv_row=row, + ds_path=ds_path, + window_size=window_size, + omit_classes=omit_classes, + ) + for row in csv_rows + ] + + counts = {"created": 0, "omit_class": 0, "omit_nodata": 0} + p = multiprocessing.Pool(workers) + outputs = star_imap_unordered(p, create_window, jobs) + for status in tqdm.tqdm(outputs, total=len(jobs)): + counts[status] += 1 + p.close() + + total_omitted = counts["omit_class"] + counts["omit_nodata"] + print("\n" + "=" * 60) + print("WINDOW CREATION SUMMARY (expert sparse)") + print("=" * 60) + print(f" Created: {counts['created']}") + print(f" Omitted (class): {counts['omit_class']}") + print(f" Omitted (all nodata): {counts['omit_nodata']}") + print(f" Total omitted: {total_omitted}") + print(f" Total processed: {len(jobs)}") + print("=" * 60) + + +if __name__ == "__main__": + multiprocessing.set_start_method("forkserver") + parser = argparse.ArgumentParser( + description="Create rslearn windows with sparse expert labels for MapBiomas.", + ) + parser.add_argument( + "--csv-path", + type=str, + default=str(DEFAULT_CSV), + help="Path to the expert subsample CSV.", + ) + parser.add_argument( + "--ds-name", + type=str, + default=DEFAULT_DS_NAME, + help="Dataset name, appended to RSLEARN_EAI_ROOT (default: %(default)s).", + ) + parser.add_argument( + "--omit-classes", + type=int, + nargs="*", + default=[], + help="Class IDs to treat as no-data (windows with only omitted classes are skipped).", + ) + parser.add_argument( + "--target-size", + type=int, + default=48, + help="Target crop size in 10m pixels (default: 48). " + "The actual window will be (2*target_size - 3) to ensure any " + "random crop always includes the 3x3 labeled center.", + ) + parser.add_argument( + "--workers", + type=int, + default=16, + help="Number of multiprocessing workers (default: 32).", + ) + args = parser.parse_args() + + create_windows_from_csv( + csv_path=UPath(args.csv_path), + ds_name=args.ds_name, + target_size=args.target_size, + omit_classes=set(args.omit_classes), + workers=args.workers, + ) diff --git a/rslp/mapbiomas/sanity_check.py b/rslp/mapbiomas/sanity_check.py new file mode 100644 index 000000000..a8a12ddd2 --- /dev/null +++ b/rslp/mapbiomas/sanity_check.py @@ -0,0 +1,691 @@ +"""Sanity-check the created rslearn label windows against reference data. + +Validates that: + A) Every window maps to a row in the subsample CSV. + B) TARGETID, CLASS (and LON/LAT for sparse) match the 85k validation shapefile. + C) (dense only) The 48x48 rslearn raster perfectly matches the 16x16 MapBiomas + raster patch upscaled 3x. + D) Produces a 5x5 visualization grid of randomly sampled windows. + +Usage:: + + # Sparse expert labels + python -m rslp.mapbiomas.sanity_check --mode sparse --ds-name mapbiomas_3k + + # Dense raster labels + python -m rslp.mapbiomas.sanity_check --mode dense --ds-name mapbiomas_3k +""" + +from __future__ import annotations + +import argparse +import os +import sys +from pathlib import Path + +import geopandas as gpd +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import rasterio +import rasterio.windows +from pyproj import Transformer + +MY_ROOT = Path(os.environ.get("MY_ROOT", ".")) + +DEFAULT_DS_NAME = "mapbiomas_3k" +DEFAULT_RASTER_DIR = MY_ROOT / "datasets/mapbiomas/data" +DEFAULT_SHP_PATH = ( + MY_ROOT / "datasets/mapbiomas/metadata/mapbiomas_85k_points_validation.shp" +) +DEFAULT_SPARSE_CSV = ( + MY_ROOT / "rslearn_projects/rslp/mapbiomas/subsampling/sample_expert_points_4k.csv" +) +DEFAULT_DENSE_CSV = ( + MY_ROOT / "rslearn_projects/rslp/mapbiomas/subsampling/sample_dense_raster_4k.csv" +) + +NODATA_VALUE = 0 +UPSCALE_FACTOR = 3 +LABEL_BLOCK_SIZE = 3 + +YEAR_RANGE = range(2016, 2023) + +MAPBIOMAS_COLORS = { + 0: "#ffffff", + 1: "#1f8d49", + 3: "#006400", + 4: "#00ff00", + 5: "#687537", + 6: "#76a5af", + 9: "#02d659", + 11: "#519799", + 12: "#d6bc74", + 13: "#d89f5c", + 15: "#edde8e", + 18: "#E974ED", + 19: "#C27BA0", + 20: "#db4d4f", + 21: "#ffefc3", + 23: "#d68fe2", + 24: "#d4271e", + 25: "#9c0027", + 26: "#2532e4", + 29: "#ffaa5f", + 30: "#9065d0", + 31: "#091077", + 32: "#fc8114", + 33: "#2532e4", + 35: "#9065d0", + 36: "#e787f8", + 39: "#f5b3c8", + 40: "#c71585", + 41: "#f54ca9", + 46: "#d68fe2", + 47: "#6b9c32", + 48: "#6b9c32", + 49: "#a1d99b", + 50: "#a1d99b", + 62: "#e6ccff", +} + + +def _class_cmap(data: np.ndarray) -> np.ndarray: + """Map class IDs to RGBA image for visualization.""" + from matplotlib.colors import to_rgba + + h, w = data.shape + rgba = np.ones((h, w, 4), dtype=np.float32) + rgba[:, :, :3] = 0.8 # default grey for unknown classes + + for cls_id, hex_color in MAPBIOMAS_COLORS.items(): + mask = data == cls_id + if not mask.any(): + continue + r, g, b, a = to_rgba(hex_color) + rgba[mask] = [r, g, b, a] + + # Make nodata transparent-ish (white with low alpha) + nodata_mask = data == NODATA_VALUE + rgba[nodata_mask] = [1.0, 1.0, 1.0, 0.3] + + return rgba + + +# --------------------------------------------------------------------------- +# Load reference data +# --------------------------------------------------------------------------- + + +def load_shapefile_reference(shp_path: Path) -> pd.DataFrame: + """Load the 85k validation shapefile and melt to long format. + + Returns DataFrame with columns: TARGETID, LON, LAT, YEAR, CLASS. + """ + gdf = gpd.read_file(shp_path) + static_cols = ["TARGETID", "LON", "LAT"] + frames: list[pd.DataFrame] = [] + for year in YEAR_RANGE: + cls_col = f"CLASS_{year}" + if cls_col not in gdf.columns: + continue + sub = gdf[static_cols + [cls_col]].copy() + sub.columns = [*static_cols, "CLASS"] + sub["YEAR"] = year + sub["CLASS"] = pd.to_numeric(sub["CLASS"], errors="coerce") + frames.append(sub) + long = pd.concat(frames, ignore_index=True) + long = long[long["CLASS"].notna()].copy() + long["CLASS"] = long["CLASS"].astype(int) + long["TARGETID"] = long["TARGETID"].astype(int) + return long + + +# --------------------------------------------------------------------------- +# Check A: every window has a corresponding CSV row +# --------------------------------------------------------------------------- + + +def check_a_csv_lookup( + window_names: list[str], csv_df: pd.DataFrame, mode: str +) -> dict[str, pd.Series]: + """Verify every window name maps to a CSV row. Returns {window_name: row}.""" + csv_df = csv_df.copy() + csv_df["TARGETID"] = csv_df["TARGETID"].astype(int) + csv_df["YEAR"] = csv_df["YEAR"].astype(int) + csv_df["_key"] = csv_df["TARGETID"].astype(str) + "_" + csv_df["YEAR"].astype(str) + csv_lookup = csv_df.set_index("_key") + + ok, fail = 0, 0 + result: dict[str, pd.Series] = {} + for wname in window_names: + if wname in csv_lookup.index: + result[wname] = csv_lookup.loc[wname] + ok += 1 + else: + print(f" [FAIL] Window {wname} not found in {mode} CSV") + fail += 1 + + print( + f" Check A ({mode}): {ok}/{ok + fail} windows found in CSV" + f" ({fail} missing)" + ) + return result + + +# --------------------------------------------------------------------------- +# Check B: cross-reference with shapefile +# --------------------------------------------------------------------------- + + +def check_b_shapefile_reference( + lookup: dict[str, pd.Series], + ref_df: pd.DataFrame, + mode: str, +) -> None: + """Verify TARGETID, CLASS, and (for sparse) LON/LAT match the shapefile.""" + ref_df = ref_df.copy() + ref_df["_key"] = ref_df["TARGETID"].astype(str) + "_" + ref_df["YEAR"].astype(str) + ref_lookup = ref_df.set_index("_key") + + ok, fail_missing, fail_class, fail_lonlat = 0, 0, 0, 0 + for wname, csv_row in lookup.items(): + if wname not in ref_lookup.index: + print(f" [FAIL] Window {wname}: TARGETID/YEAR not in shapefile") + fail_missing += 1 + continue + + ref_row = ref_lookup.loc[wname] + csv_class = int(csv_row["CLASS"]) + ref_class = int(ref_row["CLASS"]) + if csv_class != ref_class: + print( + f" [FAIL] Window {wname}: CLASS mismatch " + f"(csv={csv_class}, shp={ref_class})" + ) + fail_class += 1 + continue + + if mode == "sparse": + # LON/LAT columns are swapped in CSV relative to standard meaning, + # but the shapefile has the same swap, so compare directly. + csv_lon, csv_lat = float(csv_row["LON"]), float(csv_row["LAT"]) + ref_lon, ref_lat = float(ref_row["LON"]), float(ref_row["LAT"]) + if not ( + np.isclose(csv_lon, ref_lon, atol=1e-4) + and np.isclose(csv_lat, ref_lat, atol=1e-4) + ): + print( + f" [FAIL] Window {wname}: LON/LAT mismatch " + f"(csv=({csv_lon:.6f},{csv_lat:.6f}), " + f"shp=({ref_lon:.6f},{ref_lat:.6f}))" + ) + fail_lonlat += 1 + continue + + ok += 1 + + total = ok + fail_missing + fail_class + fail_lonlat + parts = [] + if fail_missing: + parts.append(f"{fail_missing} missing") + if fail_class: + parts.append(f"{fail_class} class mismatch") + if fail_lonlat: + parts.append(f"{fail_lonlat} lon/lat mismatch") + detail = f" ({', '.join(parts)})" if parts else "" + print(f" Check B ({mode}): {ok}/{total} passed{detail}") + + +# --------------------------------------------------------------------------- +# Check C (dense only): raster alignment +# --------------------------------------------------------------------------- + + +def check_c_dense_raster_alignment( + lookup: dict[str, pd.Series], + ds_path: Path, + raster_dir: Path, + target_size: int, +) -> list[str]: + """Verify the rslearn 48x48 label matches the MapBiomas 16x16 patch x3. + + Returns list of window names that failed the check. + """ + raster_patch_size = target_size // UPSCALE_FACTOR + half_patch = raster_patch_size // 2 + group_name = "mapbiomas_dense_raster" + + ok, fail_shape, fail_align = 0, 0, 0 + failed_windows: list[str] = [] + raster_cache: dict[int, rasterio.DatasetReader] = {} + + try: + for wname, csv_row in lookup.items(): + year = int(csv_row["YEAR"]) + # LON/LAT swapped in CSV + longitude = float(csv_row["LAT"]) + latitude = float(csv_row["LON"]) + + # Read reference patch from MapBiomas raster + if year not in raster_cache: + raster_path = raster_dir / f"brazil_coverage_{year}.tif" + raster_cache[year] = rasterio.open(raster_path) + src = raster_cache[year] + r, c = src.index(longitude, latitude) + rio_win = rasterio.windows.Window( + col_off=c - half_patch, + row_off=r - half_patch, + width=raster_patch_size, + height=raster_patch_size, + ) + ref_patch = src.read(1, window=rio_win, boundless=True, fill_value=0) + + # Read rslearn label raster + tif_path = ( + ds_path + / "windows" + / group_name + / wname + / "layers" + / "label_raster" + / "label" + / "geotiff.tif" + ) + if not tif_path.exists(): + print(f" [FAIL] Window {wname}: rslearn geotiff not found") + fail_shape += 1 + failed_windows.append(wname) + continue + + with rasterio.open(tif_path) as rsrc: + rslearn_raster = rsrc.read(1) + + if rslearn_raster.shape != (target_size, target_size): + print( + f" [FAIL] Window {wname}: rslearn raster shape " + f"{rslearn_raster.shape} != ({target_size},{target_size})" + ) + fail_shape += 1 + failed_windows.append(wname) + continue + + # Downsample rslearn raster back to raster_patch_size and compare + downsampled = rslearn_raster[::UPSCALE_FACTOR, ::UPSCALE_FACTOR] + if downsampled.shape != ref_patch.shape: + print( + f" [FAIL] Window {wname}: downsampled shape " + f"{downsampled.shape} != ref {ref_patch.shape}" + ) + fail_shape += 1 + failed_windows.append(wname) + continue + + # Only compare pixels where both sides have data. The rslearn + # raster may have nodata where omitted classes (e.g. 31) were + # zeroed out during window creation. + both_valid = (downsampled != NODATA_VALUE) & (ref_patch != NODATA_VALUE) + mismatch = (downsampled != ref_patch) & both_valid + if mismatch.any(): + n_diff = int(mismatch.sum()) + print( + f" [FAIL] Window {wname}: {n_diff}/{int(both_valid.sum())} " + f"valid pixels differ" + ) + fail_align += 1 + failed_windows.append(wname) + continue + + ok += 1 + finally: + for src in raster_cache.values(): + src.close() + + total = ok + fail_shape + fail_align + parts = [] + if fail_shape: + parts.append(f"{fail_shape} shape/missing") + if fail_align: + parts.append(f"{fail_align} pixel mismatch") + detail = f" ({', '.join(parts)})" if parts else "" + print(f" Check C (dense): {ok}/{total} passed{detail}") + return failed_windows + + +# --------------------------------------------------------------------------- +# Check D: visualization +# --------------------------------------------------------------------------- + + +def visualize_sparse( + lookup: dict[str, pd.Series], + ds_path: Path, + out_path: Path, + n: int = 25, + seed: int = 42, +) -> None: + """5x5 grid of sparse windows with labeled pixels highlighted.""" + group_name = "mapbiomas_expert_sparse" + rng = np.random.default_rng(seed) + keys = list(lookup.keys()) + chosen = rng.choice(keys, size=min(n, len(keys)), replace=False) + + rows, cols = 5, 5 + fig, axes = plt.subplots(rows, cols, figsize=(18, 18)) + fig.suptitle("Sparse expert label windows (labeled pixels colored)", fontsize=14) + + for idx, ax in enumerate(axes.flat): + if idx >= len(chosen): + ax.axis("off") + continue + + wname = chosen[idx] + csv_row = lookup[wname] + ref_class = int(csv_row["CLASS"]) + + tif_path = ( + ds_path + / "windows" + / group_name + / wname + / "layers" + / "label_raster" + / "label" + / "geotiff.tif" + ) + if not tif_path.exists(): + ax.set_title(f"{wname}\nMISSING", fontsize=7) + ax.axis("off") + continue + + with rasterio.open(tif_path) as src: + raster = src.read(1) + raster_crs = src.crs + raster_transform = src.transform + + rgba = _class_cmap(raster) + ax.imshow(rgba, interpolation="nearest") + + # Overlay the reference lat/lon as a cyan dot (projected to pixel coords) + ref_lon = float(csv_row["LAT"]) # LON/LAT swapped in CSV + ref_lat = float(csv_row["LON"]) + transformer = Transformer.from_crs("EPSG:4326", raster_crs, always_xy=True) + ref_x, ref_y = transformer.transform(ref_lon, ref_lat) + inv_transform = ~raster_transform + ref_col_px, ref_row_px = inv_transform * (ref_x, ref_y) + ax.plot( + ref_col_px, + ref_row_px, + "o", + color="cyan", + markersize=6, + markeredgecolor="black", + markeredgewidth=0.8, + zorder=5, + ) + + label_vals = np.unique(raster[raster != NODATA_VALUE]) + label_str = ",".join(str(v) for v in label_vals) + ax.set_title( + f"{wname}\nlabel={label_str} ref={ref_class}", + fontsize=7, + ) + ax.set_xticks([]) + ax.set_yticks([]) + + plt.tight_layout() + fig.savefig(out_path, dpi=150, bbox_inches="tight") + print(f" Saved sparse visualization to {out_path}") + plt.close(fig) + + +def visualize_dense( + lookup: dict[str, pd.Series], + ds_path: Path, + raster_dir: Path, + target_size: int, + out_path: Path, + n: int = 25, + seed: int = 42, + priority_windows: list[str] | None = None, +) -> None: + """5x5 grid: each cell shows rslearn 48x48 and reference 16x16 side by side. + + Windows in *priority_windows* (e.g. those that failed Check C) are shown + first; remaining slots are filled by random sampling. + """ + group_name = "mapbiomas_dense_raster" + raster_patch_size = target_size // UPSCALE_FACTOR + half_patch = raster_patch_size // 2 + rng = np.random.default_rng(seed) + keys = list(lookup.keys()) + + priority = [] + if priority_windows: + priority = [w for w in priority_windows if w in lookup][:n] + remaining_n = n - len(priority) + remaining_pool = [k for k in keys if k not in set(priority)] + random_pick = ( + list( + rng.choice( + remaining_pool, + size=min(remaining_n, len(remaining_pool)), + replace=False, + ) + ) + if remaining_n > 0 and remaining_pool + else [] + ) + chosen = priority + random_pick + + rows, cols = 5, 5 + fig, axes = plt.subplots(rows, cols * 2, figsize=(24, 18)) + fig.suptitle( + "Dense raster label windows (left: rslearn 48x48, right: reference 16x16)", + fontsize=14, + ) + + raster_cache: dict[int, rasterio.DatasetReader] = {} + + try: + for idx in range(rows * cols): + ax_left = axes[idx // cols, (idx % cols) * 2] + ax_right = axes[idx // cols, (idx % cols) * 2 + 1] + + if idx >= len(chosen): + ax_left.axis("off") + ax_right.axis("off") + continue + + wname = chosen[idx] + csv_row = lookup[wname] + year = int(csv_row["YEAR"]) + longitude = float(csv_row["LAT"]) + latitude = float(csv_row["LON"]) + + # rslearn raster + tif_path = ( + ds_path + / "windows" + / group_name + / wname + / "layers" + / "label_raster" + / "label" + / "geotiff.tif" + ) + if not tif_path.exists(): + ax_left.set_title(f"{wname}\nMISSING", fontsize=6) + ax_left.axis("off") + ax_right.axis("off") + continue + + with rasterio.open(tif_path) as src: + rslearn_raster = src.read(1) + + # Reference raster + if year not in raster_cache: + rp = raster_dir / f"brazil_coverage_{year}.tif" + raster_cache[year] = rasterio.open(rp) + src_ref = raster_cache[year] + r, c = src_ref.index(longitude, latitude) + rio_win = rasterio.windows.Window( + col_off=c - half_patch, + row_off=r - half_patch, + width=raster_patch_size, + height=raster_patch_size, + ) + ref_patch = src_ref.read(1, window=rio_win, boundless=True, fill_value=0) + + is_failed = priority_windows and wname in set(priority_windows) + title_prefix = "[FAIL C] " if is_failed else "" + title_color = "red" if is_failed else "black" + + ax_left.imshow(_class_cmap(rslearn_raster), interpolation="nearest") + ax_left.set_title( + f"{title_prefix}{wname}\nrslearn {rslearn_raster.shape}", + fontsize=6, + color=title_color, + ) + ax_left.set_xticks([]) + ax_left.set_yticks([]) + + ax_right.imshow(_class_cmap(ref_patch), interpolation="nearest") + ax_right.set_title(f"ref {ref_patch.shape}", fontsize=6) + ax_right.set_xticks([]) + ax_right.set_yticks([]) + finally: + for src in raster_cache.values(): + src.close() + + plt.tight_layout() + fig.savefig(out_path, dpi=150, bbox_inches="tight") + print(f" Saved dense visualization to {out_path}") + plt.close(fig) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main() -> None: + """Sanity-check MapBiomas rslearn label windows.""" + parser = argparse.ArgumentParser( + description="Sanity-check MapBiomas rslearn label windows.", + ) + parser.add_argument( + "--mode", + type=str, + required=True, + choices=["sparse", "dense"], + help="Which window set to check.", + ) + parser.add_argument( + "--ds-name", + type=str, + default=DEFAULT_DS_NAME, + help="Dataset name under RSLEARN_EAI_ROOT.", + ) + parser.add_argument( + "--shp-path", + type=str, + default=str(DEFAULT_SHP_PATH), + help="Path to the 85k validation shapefile.", + ) + parser.add_argument( + "--raster-dir", + type=str, + default=str(DEFAULT_RASTER_DIR), + help="Directory containing brazil_coverage_{year}.tif rasters (dense only).", + ) + parser.add_argument( + "--target-size", + type=int, + default=48, + help="Target crop size in 10m pixels (default: 48).", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for visualization sampling.", + ) + parser.add_argument( + "--out-dir", + type=str, + default=str(MY_ROOT / "rslearn_projects/rslp/mapbiomas"), + help="Directory for output plots.", + ) + args = parser.parse_args() + + rslearn_root = os.environ.get("RSLEARN_EAI_ROOT") + if not rslearn_root: + raise RuntimeError("RSLEARN_EAI_ROOT environment variable is not set") + ds_path = Path(rslearn_root) / args.ds_name + + mode = args.mode + raster_dir = Path(args.raster_dir) + shp_path = Path(args.shp_path) + out_dir = Path(args.out_dir) + target_size = args.target_size + + if mode == "sparse": + group_name = "mapbiomas_expert_sparse" + csv_path = DEFAULT_SPARSE_CSV + else: + group_name = "mapbiomas_dense_raster" + csv_path = DEFAULT_DENSE_CSV + + windows_dir = ds_path / "windows" / group_name + if not windows_dir.exists(): + print(f"ERROR: windows directory not found: {windows_dir}") + sys.exit(1) + + window_names = sorted([p.name for p in windows_dir.iterdir() if p.is_dir()]) + print(f"Found {len(window_names)} windows in {windows_dir}") + + # Load reference data + print(f"Loading subsample CSV: {csv_path}") + csv_df = pd.read_csv(csv_path) + print(f"Loading shapefile: {shp_path}") + ref_df = load_shapefile_reference(shp_path) + + # --- Check A --- + print("\n--- Check A: CSV lookup ---") + lookup = check_a_csv_lookup(window_names, csv_df, mode) + + # --- Check B --- + print("\n--- Check B: Shapefile cross-reference ---") + check_b_shapefile_reference(lookup, ref_df, mode) + + # --- Check C (dense only) --- + failed_c: list[str] = [] + if mode == "dense": + print("\n--- Check C: Dense raster alignment ---") + failed_c = check_c_dense_raster_alignment( + lookup, ds_path, raster_dir, target_size + ) + + # --- Check D: Visualization --- + print(f"\n--- Check D: Visualization ({mode}) ---") + out_dir.mkdir(parents=True, exist_ok=True) + plot_path = out_dir / f"sanity_check_{mode}.png" + if mode == "sparse": + visualize_sparse(lookup, ds_path, plot_path, seed=args.seed) + else: + visualize_dense( + lookup, + ds_path, + raster_dir, + target_size, + plot_path, + seed=args.seed, + priority_windows=failed_c, + ) + + print("\nDone.") + + +if __name__ == "__main__": + main() diff --git a/rslp/mapbiomas/subsampling/__init__.py b/rslp/mapbiomas/subsampling/__init__.py new file mode 100644 index 000000000..be217b121 --- /dev/null +++ b/rslp/mapbiomas/subsampling/__init__.py @@ -0,0 +1 @@ +"""MapBiomas balanced subsampling utilities.""" diff --git a/rslp/mapbiomas/subsampling/sample_dense_raster.py b/rslp/mapbiomas/subsampling/sample_dense_raster.py new file mode 100644 index 000000000..7ffa3bcd6 --- /dev/null +++ b/rslp/mapbiomas/subsampling/sample_dense_raster.py @@ -0,0 +1,563 @@ +"""Sample dense raster windows around expert-point locations. + +For each point in the subsample CSV produced by sample_expert_points.py, +reads a search-size patch from the matching year's MapBiomas raster, samples +n-candidates random window-size sub-windows, scores each by weighted +minority-class pixel count, and keeps the best one. Outputs an updated +subsample CSV whose LON/LAT columns reflect the selected window centre, +plus per-class/per-split statistics. + +Setting --search-size equal to --window-size with --n-candidates 1 reproduces +a simple centred-window extraction (no optimisation). +""" + +from __future__ import annotations + +import argparse +import os +from pathlib import Path + +import numpy as np +import pandas as pd +import rasterio +from rasterio.windows import Window + +MY_ROOT = Path(os.environ.get("MY_ROOT", ".")) + +DEFAULT_CSV = ( + MY_ROOT / "rslearn_projects/rslp/mapbiomas/subsampling/sample_expert_points_4k.csv" +) +DEFAULT_RASTER_DIR = MY_ROOT / "datasets/mapbiomas/data" +DEFAULT_OUT_DIR = MY_ROOT / "rslearn_projects/rslp/mapbiomas/subsampling" +DEFAULT_LEGEND = ( + MY_ROOT / "datasets/mapbiomas/metadata/Codigos-da-legenda-colecao-10.csv" +) +DEFAULT_HIERARCHY = MY_ROOT / "datasets/mapbiomas/metadata/hierarchy.csv" + +MINORITY_CLASSES: set[int] = { + 6, + 41, + 5, + 29, + 46, + 48, + 40, + 47, + 25, + 50, + 49, + 35, + 32, + 23, + 30, + 62, +} + + +# --------------------------------------------------------------------------- +# Legend / summary utilities +# --------------------------------------------------------------------------- + + +def load_legend(path: Path) -> dict[int, str]: + """Return {class_id: english_description} from the MapBiomas legend CSV.""" + legend = pd.read_csv(path, sep="\t") + legend["Description"] = legend["Description"].str.strip() + return dict(zip(legend["Class_ID"].astype(int), legend["Description"])) + + +def load_hierarchy(path: Path) -> dict[int, dict]: + """Parse the hierarchy CSV and return per-class hierarchy info. + + Returns ``{class_id: {"leaf_level": int, + "parent_class_id": int | None, + "parent_leaf_level": int | None, + "parent_class_desc": str | None}}``. + """ + h = pd.read_csv(path) + result: dict[int, dict] = {} + for cid, grp in h.groupby("Class_ID"): + cid = int(cid) + leaf_level = int(grp["Leaf_Level"].iloc[0]) + + if leaf_level > 1: + parent_row = grp[grp["Hierarchy_Level"] == leaf_level - 1].iloc[0] + parent_class_id = int(parent_row["Level_Class_ID"]) + parent_leaf_level = leaf_level - 1 + parent_class_desc = str(parent_row["Level_Description"]) + else: + parent_class_id = None + parent_leaf_level = None + parent_class_desc = None + + result[cid] = { + "leaf_level": leaf_level, + "parent_class_id": parent_class_id, + "parent_leaf_level": parent_leaf_level, + "parent_class_desc": parent_class_desc, + } + return result + + +def build_summary( + per_window: pd.DataFrame, + legend: dict[int, str], + window_size: int, + hierarchy: dict[int, dict] | None = None, +) -> pd.DataFrame: + """Aggregate per-window counts into a global summary with train/val breakdown.""" + class_cols = [c for c in per_window.columns if c.startswith("class_")] + class_ids = [int(c.split("_")[1]) for c in class_cols] + + pixels_per_window = window_size * window_size + + rows: list[dict] = [] + for cid, col in zip(class_ids, class_cols): + vals_all = per_window[col].values + total_all = int(vals_all.sum()) + + rec: dict = { + "class_id": cid, + "class_name": legend.get(cid, "unknown"), + "leaf_level": None, + "parent_class_id": None, + "parent_leaf_level": None, + "parent_class_desc": None, + "total_pixels": total_all, + "frac_of_all": total_all / (len(per_window) * pixels_per_window), + "mean_per_window": float(vals_all.mean()), + "std_per_window": float(vals_all.std()), + } + + if hierarchy and cid in hierarchy: + hi = hierarchy[cid] + rec["leaf_level"] = hi["leaf_level"] + rec["parent_class_id"] = hi["parent_class_id"] + rec["parent_leaf_level"] = hi["parent_leaf_level"] + rec["parent_class_desc"] = hi["parent_class_desc"] + + for split in ("train", "val"): + mask = per_window["split"] == split + vals = per_window.loc[mask, col].values + n_windows = int(mask.sum()) + total = int(vals.sum()) + rec[f"{split}_total_pixels"] = total + rec[f"{split}_frac"] = ( + total / (n_windows * pixels_per_window) if n_windows else 0.0 + ) + rec[f"{split}_mean_per_window"] = float(vals.mean()) if n_windows else 0.0 + rec[f"{split}_std_per_window"] = float(vals.std()) if n_windows else 0.0 + + rows.append(rec) + + summary = ( + pd.DataFrame(rows) + .sort_values("total_pixels", ascending=False) + .reset_index(drop=True) + ) + return summary + + +def print_summary( + summary: pd.DataFrame, per_window: pd.DataFrame, window_size: int +) -> None: + """Print a human-readable summary to stdout.""" + pixels_per_window = window_size * window_size + n_total = len(per_window) + n_train = (per_window["split"] == "train").sum() + n_val = (per_window["split"] == "val").sum() + + print("\n" + "=" * 140) + print("WINDOW CLASS STATISTICS SUMMARY") + print( + f" Window size: {window_size}x{window_size} ({pixels_per_window} pixels/window)" + ) + print(f" Samples: {n_total} total ({n_train} train, {n_val} val)") + print(f" Total pixels across all windows: {n_total * pixels_per_window:,}") + print("=" * 140) + + has_hierarchy = ( + "leaf_level" in summary.columns and summary["leaf_level"].notna().any() + ) + + if has_hierarchy: + fmt = ( + "{:<6s} {:<35s} {:<5s} {:<6s} {:<25s}" + " {:>12s} {:>8s} {:>12s} {:>8s} {:>12s} {:>8s}" + ) + print( + fmt.format( + "ID", + "Class", + "Lvl", + "ParID", + "Parent", + "All pixels", + "All %", + "Train px", + "Train %", + "Val px", + "Val %", + ) + ) + else: + fmt = "{:<6s} {:<35s} {:>12s} {:>8s} {:>12s} {:>8s} {:>12s} {:>8s}" + print( + fmt.format( + "ID", + "Class", + "All pixels", + "All %", + "Train px", + "Train %", + "Val px", + "Val %", + ) + ) + print("-" * 140) + + for _, r in summary.iterrows(): + if has_hierarchy: + lvl = str(int(r["leaf_level"])) if pd.notna(r["leaf_level"]) else "" + par_id = ( + str(int(r["parent_class_id"])) + if pd.notna(r["parent_class_id"]) + else "-" + ) + par_desc = ( + str(r["parent_class_desc"])[:25] + if pd.notna(r["parent_class_desc"]) + else "-" + ) + print( + fmt.format( + str(int(r["class_id"])), + r["class_name"][:35], + lvl, + par_id, + par_desc, + f"{int(r['total_pixels']):,}", + f"{r['frac_of_all']:.3%}", + f"{int(r['train_total_pixels']):,}", + f"{r['train_frac']:.3%}", + f"{int(r['val_total_pixels']):,}", + f"{r['val_frac']:.3%}", + ) + ) + else: + print( + ( + "{:<6s} {:<35s} {:>12s} {:>8s} {:>12s} {:>8s} {:>12s} {:>8s}" + ).format( + str(int(r["class_id"])), + r["class_name"][:35], + f"{int(r['total_pixels']):,}", + f"{r['frac_of_all']:.3%}", + f"{int(r['train_total_pixels']):,}", + f"{r['train_frac']:.3%}", + f"{int(r['val_total_pixels']):,}", + f"{r['val_frac']:.3%}", + ) + ) + print("=" * 140 + "\n") + + +# Weights = 1 / observed_fraction from the baseline window_class_stats run. +MINORITY_WEIGHTS: dict[int, float] = { + 6: 1.0 / 0.01899, + 41: 1.0 / 0.01792, + 5: 1.0 / 0.00772, + 29: 1.0 / 0.00634, + 46: 1.0 / 0.00583, + 48: 1.0 / 0.00474, + 40: 1.0 / 0.00246, + 47: 1.0 / 0.00235, + 25: 1.0 / 0.00183, + 50: 1.0 / 0.00081, + 49: 1.0 / 0.00066, + 35: 1.0 / 0.00059, + 32: 1.0 / 0.00045, + 23: 1.0 / 0.00041, + 30: 1.0 / 0.00034, + 62: 1.0 / 0.00033, +} + +# Pre-build a uint8-indexed lookup table for fast vectorised scoring. +_WEIGHT_LUT = np.zeros(256, dtype=np.float64) +for _cls, _w in MINORITY_WEIGHTS.items(): + _WEIGHT_LUT[_cls] = _w + + +def _score_subwindow(patch: np.ndarray, row_off: int, col_off: int, size: int) -> float: + """Sum minority weights for all pixels in the sub-window.""" + sub = patch[row_off : row_off + size, col_off : col_off + size] + return float(_WEIGHT_LUT[sub].sum()) + + +def optimize_windows( + df: pd.DataFrame, + raster_dir: Path, + window_size: int, + search_size: int, + n_candidates: int, + rng: np.random.Generator, +) -> tuple[pd.DataFrame, pd.DataFrame]: + """Find the best 32x32 sub-window per sample and return updated CSV + per-window counts. + + Returns: + ------- + optimized_df : pd.DataFrame + Same schema as the input subsample CSV but with LON/LAT shifted to the + selected window centre. + per_window : pd.DataFrame + Per-window class counts (same format as window_class_stats output). + """ + half_search = search_size // 2 + half_win = window_size // 2 + max_offset = ( + search_size - window_size + ) # valid sub-window origin range [0, max_offset] + + updated_rows: list[dict] = [] + count_records: list[dict] = [] + + for year, group in df.groupby("YEAR"): + raster_path = raster_dir / f"brazil_coverage_{year}.tif" + if not raster_path.exists(): + raise FileNotFoundError(f"Raster not found: {raster_path}") + + print(f" Year {year}: {len(group)} samples from {raster_path.name}") + + with rasterio.open(raster_path) as src: + for _, row in group.iterrows(): + actual_lon = row["LAT"] + actual_lat = row["LON"] + r, c = src.index(actual_lon, actual_lat) + + big_win = Window( + col_off=c - half_search, + row_off=r - half_search, + width=search_size, + height=search_size, + ) + patch = src.read(1, window=big_win, boundless=True, fill_value=0) + + offsets = np.column_stack( + [ + rng.integers(0, max_offset + 1, size=n_candidates), + rng.integers(0, max_offset + 1, size=n_candidates), + ] + ) + + best_score = -1.0 + best_idx = 0 + for i in range(n_candidates): + ro, co = int(offsets[i, 0]), int(offsets[i, 1]) + score = _score_subwindow(patch, ro, co, window_size) + if score > best_score: + best_score = score + best_idx = i + + best_ro = int(offsets[best_idx, 0]) + best_co = int(offsets[best_idx, 1]) + + # Raster row/col of the selected sub-window centre + new_r = (r - half_search) + best_ro + half_win + new_c = (c - half_search) + best_co + half_win + new_lon, new_lat = src.xy(new_r, new_c) + + # Build updated CSV row (LON/LAT columns are swapped in source) + updated_rows.append( + { + "TARGETID": int(row["TARGETID"]), + "LON": new_lat, + "LAT": new_lon, + "YEAR": int(row["YEAR"]), + "CLASS": int(row["CLASS"]), + "BORDA": int(row["BORDA"]), + "COUNT": int(row["COUNT"]), + "CARTA_2": row["CARTA_2"], + "DECLIVIDAD": row["DECLIVIDAD"], + "split": row["split"], + } + ) + + # Per-window class counts for the selected sub-window + sub = patch[ + best_ro : best_ro + window_size, best_co : best_co + window_size + ] + codes, counts = np.unique(sub, return_counts=True) + rec: dict = { + "TARGETID": int(row["TARGETID"]), + "YEAR": int(row["YEAR"]), + "center_class": int(row["CLASS"]), + "split": row["split"], + } + for code, cnt in zip(codes.tolist(), counts.tolist()): + rec[f"class_{code}"] = cnt + count_records.append(rec) + + optimized_df = pd.DataFrame(updated_rows) + out_cols = [ + "TARGETID", + "LON", + "LAT", + "YEAR", + "CLASS", + "BORDA", + "COUNT", + "CARTA_2", + "DECLIVIDAD", + "split", + ] + optimized_df = ( + optimized_df[out_cols] + .sort_values( + ["split", "CLASS", "YEAR"], + ) + .reset_index(drop=True) + ) + + per_window = pd.DataFrame(count_records).fillna(0) + class_cols = sorted( + [c for c in per_window.columns if c.startswith("class_")], + key=lambda c: int(c.split("_")[1]), + ) + meta_cols = ["TARGETID", "YEAR", "center_class", "split"] + per_window = per_window[meta_cols + class_cols] + for col in class_cols: + per_window[col] = per_window[col].astype(int) + + return optimized_df, per_window + + +def main() -> None: + """Select minority-boosting 32x32 windows for each subsample point.""" + parser = argparse.ArgumentParser( + description="Select minority-boosting 32x32 windows for each subsample point.", + ) + parser.add_argument( + "--csv-path", + type=str, + default=str(DEFAULT_CSV), + help="Path to the subsample CSV (output of sample_expert_points.py).", + ) + parser.add_argument( + "--raster-dir", + type=str, + default=str(DEFAULT_RASTER_DIR), + help="Directory containing brazil_coverage_{year}.tif rasters.", + ) + parser.add_argument( + "--window-size", + type=int, + default=16, + help="Side length of the target window in pixels (default: 16).", + ) + parser.add_argument( + "--search-size", + type=int, + default=512, + help="Side length of the search neighbourhood in pixels (default: 512).", + ) + parser.add_argument( + "--n-candidates", + type=int, + default=64, + help="Number of random sub-window candidates to evaluate (default: 64).", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for reproducibility.", + ) + parser.add_argument( + "--out-dir", + type=str, + default=str(DEFAULT_OUT_DIR), + help="Directory for output files.", + ) + parser.add_argument( + "--legend-path", + type=str, + default=str(DEFAULT_LEGEND), + help="Path to the MapBiomas legend CSV (tab-separated).", + ) + parser.add_argument( + "--hierarchy-path", + type=str, + default=str(DEFAULT_HIERARCHY), + help="Path to the MapBiomas hierarchy CSV.", + ) + args = parser.parse_args() + + csv_path = Path(args.csv_path) + raster_dir = Path(args.raster_dir) + out_dir = Path(args.out_dir) + legend_path = Path(args.legend_path) + hierarchy_path = Path(args.hierarchy_path) + window_size = args.window_size + search_size = args.search_size + + legend = load_legend(legend_path) + hierarchy = load_hierarchy(hierarchy_path) if hierarchy_path.exists() else None + minority_names = { + cid: legend.get(cid, str(cid)) for cid in sorted(MINORITY_CLASSES) + } + + print("=" * 70) + print("MINORITY-BOOSTING WINDOW SELECTION") + print("=" * 70) + print( + f" Search window: {search_size}x{search_size} pixels centred on each point" + ) + print(f" Target window: {window_size}x{window_size} pixels") + print(f" Candidates: {args.n_candidates} random sub-windows per point") + print(f" Seed: {args.seed}") + print(f" Minority classes ({len(MINORITY_CLASSES)}):") + for cid in sorted(MINORITY_CLASSES): + print( + f" {cid:>3d} {minority_names[cid]:<35s} weight={MINORITY_WEIGHTS[cid]:,.1f}" + ) + print("\n Scoring: for each candidate sub-window, sum the weight of every") + print(" minority-class pixel. Keep the sub-window with the highest score.") + print("=" * 70) + + print(f"\nLoading subsample CSV: {csv_path}") + df = pd.read_csv(csv_path) + print( + f" {len(df)} samples ({(df['split'] == 'train').sum()} train, " + f"{(df['split'] == 'val').sum()} val)" + ) + + rng = np.random.default_rng(args.seed) + + print( + f"\nSearching {search_size}x{search_size} neighbourhood, " + f"{args.n_candidates} candidates per point …" + ) + optimized_df, per_window = optimize_windows( + df, + raster_dir, + window_size, + search_size, + args.n_candidates, + rng, + ) + + out_dir.mkdir(parents=True, exist_ok=True) + csv_out = out_dir / "sample_dense_raster_4k.csv" + optimized_df.to_csv(csv_out, index=False) + print(f"\nWrote optimized subsample: {csv_out} ({len(optimized_df)} rows)") + + summary = build_summary(per_window, legend, window_size, hierarchy) + summary_path = out_dir / "sample_dense_raster_window_stats_summary.csv" + summary.to_csv(summary_path, index=False) + print(f"Wrote summary stats: {summary_path} ({len(summary)} rows)") + + print_summary(summary, per_window, window_size) + + +if __name__ == "__main__": + main() diff --git a/rslp/mapbiomas/subsampling/sample_expert_points.py b/rslp/mapbiomas/subsampling/sample_expert_points.py new file mode 100644 index 000000000..d94005cd8 --- /dev/null +++ b/rslp/mapbiomas/subsampling/sample_expert_points.py @@ -0,0 +1,631 @@ +"""Build a balanced 4k subsample (3k train + 1k val) from MapBiomas validation points. + +Criteria: +- Balanced class representation across as many classes as possible (water-filling). +- Uniform sampling across CARTA_2 tiles. +- Best-effort DECLIVIDAD representation and year uniformity (2016-2022). +- Only best data points (COUNT_year == 3). +- 25% edge pixels (BORDA_year == 1), 75% interior (BORDA_year == 0) best-effort. +- Each spatial pixel (TARGETID) used at most once across all years. +- Excluded classes: 13, 23, 27, 30, 31, 32, 50. +""" + +from __future__ import annotations + +import argparse +import os +from collections import Counter +from pathlib import Path + +import geopandas as gpd +import numpy as np +import pandas as pd + +MY_ROOT = Path(os.environ.get("MY_ROOT", ".")) + +DEFAULT_LEGEND = ( + MY_ROOT / "datasets/mapbiomas/metadata/Codigos-da-legenda-colecao-10.csv" +) +DEFAULT_HIERARCHY = MY_ROOT / "datasets/mapbiomas/metadata/hierarchy.csv" + +EXCLUDED_CLASSES: set[int] = { + 13, + 23, + 27, + 30, + 31, + 32, + 50, +} # less than 100 points in data +YEAR_RANGE: range = range(2016, 2023) # 2016-2022 inclusive + + +# --------------------------------------------------------------------------- +# Step 1: load & melt +# --------------------------------------------------------------------------- + + +def load_and_melt(shp_path: str | Path) -> pd.DataFrame: + """Load the shapefile and melt year-specific columns into long format. + + Returns a DataFrame with columns: + TARGETID, LON, LAT, YEAR, CLASS, BORDA, COUNT, CARTA_2, DECLIVIDAD + filtered to COUNT == 3 and classes not in EXCLUDED_CLASSES. + """ + gdf = gpd.read_file(shp_path) + + static_cols = ["TARGETID", "LON", "LAT", "CARTA_2", "DECLIVIDAD"] + frames: list[pd.DataFrame] = [] + for year in YEAR_RANGE: + cls_col = f"CLASS_{year}" + cnt_col = f"COUNT_{year}" + brd_col = f"BORDA_{year}" + if cls_col not in gdf.columns: + continue + sub = gdf[static_cols + [cls_col, cnt_col, brd_col]].copy() + sub.columns = [*static_cols, "CLASS", "COUNT", "BORDA"] + sub["YEAR"] = year + sub["COUNT"] = pd.to_numeric(sub["COUNT"], errors="coerce") + sub["CLASS"] = pd.to_numeric(sub["CLASS"], errors="coerce") + sub["BORDA"] = pd.to_numeric(sub["BORDA"], errors="coerce") + frames.append(sub) + + long = pd.concat(frames, ignore_index=True) + + # Filter: best quality only, drop excluded classes and NaN class + long = long[long["COUNT"] == 3].copy() + long = long[long["CLASS"].notna()].copy() + long["CLASS"] = long["CLASS"].astype(int) + long = long[~long["CLASS"].isin(EXCLUDED_CLASSES)].copy() + + long = long.reset_index(drop=True) + return long + + +# --------------------------------------------------------------------------- +# Step 2: water-filling quota +# --------------------------------------------------------------------------- + + +def compute_quotas( + long: pd.DataFrame, + total: int, +) -> dict[int, int]: + """Water-fill per-class quotas so sum == total, capped by unique-pixel availability.""" + avail = long.groupby("CLASS")["TARGETID"].nunique().to_dict() + classes = sorted(avail.keys()) + + sorted_caps = sorted(avail[c] for c in classes) + n = len(classes) + level = 0.0 + remaining = total + + for i, cap in enumerate(sorted_caps): + slots = n - i + if (cap - level) * slots >= remaining: + level += remaining / slots + break + remaining -= int(cap - level) * slots + level = cap + else: + level = sorted_caps[-1] + + base = int(np.floor(level)) + quotas = {c: min(avail[c], base) for c in classes} + deficit = total - sum(quotas.values()) + + # distribute remainder one-at-a-time to classes that still have room + frac_parts: list[tuple[float, int]] = [] + for c in classes: + headroom = avail[c] - quotas[c] + if headroom > 0: + frac_parts.append((level - int(np.floor(level)), c)) + frac_parts.sort(key=lambda x: -x[0]) + + idx = 0 + while deficit > 0: + for c in classes: + if deficit <= 0: + break + if avail[c] - quotas[c] > 0: + quotas[c] += 1 + deficit -= 1 + idx += 1 + if idx > len(classes): + break + + return quotas + + +# --------------------------------------------------------------------------- +# Step 3: stratified selection (rarest class first) +# --------------------------------------------------------------------------- + + +def select_pixels( + long: pd.DataFrame, + quotas: dict[int, int], + edge_frac: float, + rng: np.random.Generator, +) -> pd.DataFrame: + """Select pixels per class, rarest first, with CARTA_2/DECLIVIDAD/YEAR balance. + + Each TARGETID is used at most once globally. Edge/interior mix is best-effort. + """ + used_ids: set[int] = set() + selected_rows: list[pd.DataFrame] = [] + + class_order = sorted(quotas.keys(), key=lambda c: quotas[c]) + active_classes = [c for c in class_order if quotas[c] > 0] + total_quota = sum(quotas[c] for c in active_classes) + cumulative = 0 + + for i, cls in enumerate(active_classes, 1): + quota = quotas[cls] + + pool = long[(long["CLASS"] == cls) & (~long["TARGETID"].isin(used_ids))].copy() + if pool.empty: + print( + f" [{i}/{len(active_classes)}] class {cls:>2d}: " + f"quota={quota}, pool empty — skipped" + ) + continue + + n_edge_target = int(round(edge_frac * quota)) + n_interior_target = quota - n_edge_target + + edge_pool = pool[pool["BORDA"] == 1] + interior_pool = pool[pool["BORDA"] == 0] + + chosen_edge = _stratified_pick( + edge_pool, + n_edge_target, + used_ids, + rng, + label=f"class {cls} edge", + ) + newly_used = set(chosen_edge["TARGETID"]) + used_ids.update(newly_used) + + interior_pool = interior_pool[~interior_pool["TARGETID"].isin(used_ids)] + chosen_interior = _stratified_pick( + interior_pool, + n_interior_target, + used_ids, + rng, + label=f"class {cls} interior", + ) + used_ids.update(chosen_interior["TARGETID"]) + + # Backfill shortfall from the other pool + got = len(chosen_edge) + len(chosen_interior) + shortfall = quota - got + if shortfall > 0: + remaining = pool[~pool["TARGETID"].isin(used_ids)] + backfill = _stratified_pick( + remaining, + shortfall, + used_ids, + rng, + label=f"class {cls} backfill", + ) + used_ids.update(backfill["TARGETID"]) + selected_rows.extend([chosen_edge, chosen_interior, backfill]) + got += len(backfill) + else: + selected_rows.extend([chosen_edge, chosen_interior]) + + cumulative += got + print( + f" [{i}/{len(active_classes)}] class {cls:>2d}: " + f"picked {got}/{quota} — cumulative {cumulative}/{total_quota}" + ) + + result = pd.concat(selected_rows, ignore_index=True) + return result + + +def _stratified_pick( + pool: pd.DataFrame, + n: int, + used_ids: set[int], + rng: np.random.Generator, + label: str = "", +) -> pd.DataFrame: + """Pick up to *n* rows from *pool* spread across CARTA_2 / DECLIVIDAD / YEAR. + + Within each tile, prefer rows that fill under-represented DECLIVIDAD and YEAR + buckets. Each picked TARGETID is immediately added to used_ids so no pixel + repeats. + """ + if n <= 0 or pool.empty: + return pool.iloc[:0] + + pool = pool[~pool["TARGETID"].isin(used_ids)].copy() + if pool.empty: + return pool + + pool = pool.sample(frac=1, random_state=int(rng.integers(2**31))).reset_index( + drop=True + ) + + tiles = sorted(pool["CARTA_2"].unique()) + tile_pools: dict[str, pd.DataFrame] = { + t: pool[pool["CARTA_2"] == t].copy() for t in tiles + } + + year_counts: Counter[int] = Counter() + decliv_counts: Counter[str] = Counter() + picked_indices: list[int] = [] + local_used: set[int] = set() + + log_interval = max(1, n // 5) + tag = f" {label}: " if label else " pick: " + + tile_idx = 0 + stall_counter = 0 + while len(picked_indices) < n and stall_counter < len(tiles) + 1: + tile = tiles[tile_idx % len(tiles)] + tile_idx += 1 + tp = tile_pools[tile] + tp = tp[~tp["TARGETID"].isin(local_used)] + tile_pools[tile] = tp + if tp.empty: + stall_counter += 1 + continue + stall_counter = 0 + + scores = np.zeros(len(tp)) + for i, (_, row) in enumerate(tp.iterrows()): + yr = row["YEAR"] + dc = row["DECLIVIDAD"] + scores[i] = -(year_counts.get(yr, 0) + decliv_counts.get(dc, 0)) + + best_idx = int(np.argmax(scores)) + chosen_row = tp.iloc[best_idx] + picked_indices.append(tp.index[best_idx]) + tid = chosen_row["TARGETID"] + local_used.add(tid) + used_ids.add(tid) + year_counts[chosen_row["YEAR"]] += 1 + decliv_counts[chosen_row["DECLIVIDAD"]] += 1 + + for t2 in tiles: + tp2 = tile_pools[t2] + if not tp2.empty: + tile_pools[t2] = tp2[tp2["TARGETID"] != tid] + + picked = len(picked_indices) + if picked % log_interval == 0 or picked == n: + print(f"{tag}{picked}/{n}", flush=True) + + if not picked_indices: + return pool.iloc[:0] + return pool.loc[pool.index.isin(picked_indices)].copy() + + +# --------------------------------------------------------------------------- +# Step 4: train / val split +# --------------------------------------------------------------------------- + + +def train_val_split( + selected: pd.DataFrame, + n_train: int, + n_val: int, + rng: np.random.Generator, +) -> pd.DataFrame: + """Stratified per-class 75/25 train/val split.""" + total = n_train + n_val + train_frac = n_train / total + + parts: list[pd.DataFrame] = [] + for cls, grp in selected.groupby("CLASS"): + grp = grp.sample(frac=1, random_state=int(rng.integers(2**31))) + n_tr = int(round(len(grp) * train_frac)) + grp = grp.copy() + grp["split"] = "val" + grp.iloc[:n_tr, grp.columns.get_loc("split")] = "train" + parts.append(grp) + + result = pd.concat(parts, ignore_index=True) + + # Adjust global counts to hit exact targets + n_train_actual = (result["split"] == "train").sum() + diff = n_train_actual - n_train + if diff > 0: + train_idx = ( + result[result["split"] == "train"] + .sample(n=diff, random_state=int(rng.integers(2**31))) + .index + ) + result.loc[train_idx, "split"] = "val" + elif diff < 0: + val_idx = ( + result[result["split"] == "val"] + .sample(n=-diff, random_state=int(rng.integers(2**31))) + .index + ) + result.loc[val_idx, "split"] = "train" + + return result + + +# --------------------------------------------------------------------------- +# Step 5: summary report +# --------------------------------------------------------------------------- + + +def load_legend(path: Path) -> dict[int, str]: + """Return {class_id: english_description} from the MapBiomas legend CSV.""" + legend = pd.read_csv(path, sep="\t") + legend["Description"] = legend["Description"].str.strip() + return dict(zip(legend["Class_ID"].astype(int), legend["Description"])) + + +def load_hierarchy(path: Path) -> dict[int, dict]: + """Parse the hierarchy CSV and return per-class hierarchy info. + + Returns ``{class_id: {"leaf_level": int, + "parent_class_id": int | None, + "parent_leaf_level": int | None, + "parent_class_desc": str | None}}``. + """ + h = pd.read_csv(path) + result: dict[int, dict] = {} + for cid, grp in h.groupby("Class_ID"): + cid = int(cid) + leaf_level = int(grp["Leaf_Level"].iloc[0]) + + if leaf_level > 1: + parent_row = grp[grp["Hierarchy_Level"] == leaf_level - 1].iloc[0] + parent_class_id = int(parent_row["Level_Class_ID"]) + parent_leaf_level = leaf_level - 1 + parent_class_desc = str(parent_row["Level_Description"]) + else: + parent_class_id = None + parent_leaf_level = None + parent_class_desc = None + + result[cid] = { + "leaf_level": leaf_level, + "parent_class_id": parent_class_id, + "parent_leaf_level": parent_leaf_level, + "parent_class_desc": parent_class_desc, + } + return result + + +def build_summary( + df: pd.DataFrame, + legend: dict[int, str], + hierarchy: dict[int, dict] | None = None, +) -> pd.DataFrame: + """Build a class-level summary with train/val breakdown and diversity metrics.""" + n_total = len(df) + rows: list[dict] = [] + + for cls, grp in df.groupby("CLASS"): + n = len(grp) + cid = int(cls) + train_mask = grp["split"] == "train" + n_train = int(train_mask.sum()) + n_val = n - n_train + + rec: dict = { + "class_id": cid, + "class_name": legend.get(cid, "unknown"), + "leaf_level": None, + "parent_class_id": None, + "parent_leaf_level": None, + "parent_class_desc": None, + "total_points": n, + "frac_of_all": n / n_total, + "train_points": n_train, + "train_frac": n_train / n if n else 0.0, + "val_points": n_val, + "val_frac": n_val / n if n else 0.0, + "edge_frac": float((grp["BORDA"] == 1).mean()), + "n_tiles": int(grp["CARTA_2"].nunique()), + "n_years": int(grp["YEAR"].nunique()), + } + + if hierarchy and cid in hierarchy: + hi = hierarchy[cid] + rec["leaf_level"] = hi["leaf_level"] + rec["parent_class_id"] = hi["parent_class_id"] + rec["parent_leaf_level"] = hi["parent_leaf_level"] + rec["parent_class_desc"] = hi["parent_class_desc"] + + rows.append(rec) + + summary = ( + pd.DataFrame(rows) + .sort_values("total_points", ascending=False) + .reset_index(drop=True) + ) + return summary + + +def print_summary( + df: pd.DataFrame, + summary: pd.DataFrame | None = None, +) -> None: + """Print a diagnostic summary of the subsample.""" + print("\n" + "=" * 70) + print("SUBSAMPLE SUMMARY") + print("=" * 70) + + print(f"\nTotal rows: {len(df)}") + print(f" Train: {(df['split'] == 'train').sum()}") + print(f" Val: {(df['split'] == 'val').sum()}") + print(f" Unique TARGETIDs: {df['TARGETID'].nunique()} (should == {len(df)})") + + print("\n--- Per-class counts ---") + cls_split = df.groupby(["CLASS", "split"]).size().unstack(fill_value=0) + cls_split["total"] = cls_split.sum(axis=1) + print(cls_split.to_string()) + + if ( + summary is not None + and "leaf_level" in summary.columns + and summary["leaf_level"].notna().any() + ): + print("\n--- Class hierarchy ---") + fmt = "{:<6s} {:<35s} {:<5s} {:<6s} {:<5s} {:<25s}" + print(fmt.format("ID", "Class", "Lvl", "ParID", "ParLv", "Parent")) + print("-" * 90) + for _, r in summary.sort_values("class_id").iterrows(): + lvl = str(int(r["leaf_level"])) if pd.notna(r["leaf_level"]) else "" + par_id = ( + str(int(r["parent_class_id"])) + if pd.notna(r["parent_class_id"]) + else "-" + ) + par_lv = ( + str(int(r["parent_leaf_level"])) + if pd.notna(r["parent_leaf_level"]) + else "-" + ) + par_desc = ( + str(r["parent_class_desc"])[:25] + if pd.notna(r["parent_class_desc"]) + else "-" + ) + print( + fmt.format( + str(int(r["class_id"])), + str(r["class_name"])[:35], + lvl, + par_id, + par_lv, + par_desc, + ) + ) + + print("\n--- BORDA distribution ---") + borda_split = df.groupby(["BORDA", "split"]).size().unstack(fill_value=0) + borda_split["total"] = borda_split.sum(axis=1) + print(borda_split.to_string()) + n_edge = (df["BORDA"] == 1).sum() + print(f" Edge fraction: {n_edge / len(df):.3f} (target 0.250)") + + print("\n--- YEAR distribution ---") + year_split = df.groupby(["YEAR", "split"]).size().unstack(fill_value=0) + year_split["total"] = year_split.sum(axis=1) + print(year_split.to_string()) + + print("\n--- DECLIVIDAD distribution ---") + dec_split = df.groupby(["DECLIVIDAD", "split"]).size().unstack(fill_value=0) + dec_split["total"] = dec_split.sum(axis=1) + print(dec_split.to_string()) + + print("\n--- CARTA_2 tile coverage ---") + print(f" Tiles represented: {df['CARTA_2'].nunique()}") + + print("=" * 70 + "\n") + + +# --------------------------------------------------------------------------- +# main +# --------------------------------------------------------------------------- + + +def main() -> None: + """Build a balanced MapBiomas subsample (3k train + 1k val).""" + parser = argparse.ArgumentParser( + description="Build a balanced MapBiomas subsample (3k train + 1k val)." + ) + parser.add_argument( + "--shp-path", + type=str, + default=str( + MY_ROOT / "datasets/mapbiomas/metadata/mapbiomas_85k_points_validation.shp" + ), + help="Path to the 85k validation shapefile.", + ) + parser.add_argument( + "--out-path", + type=str, + default=str( + MY_ROOT + / "rslearn_projects/rslp/mapbiomas/subsampling/sample_expert_points_4k.csv" + ), + help="Output CSV path.", + ) + parser.add_argument("--n-train", type=int, default=3000) + parser.add_argument("--n-val", type=int, default=1000) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--edge-frac", type=float, default=0.25) + parser.add_argument( + "--legend-path", + type=str, + default=str(DEFAULT_LEGEND), + help="Path to the MapBiomas legend CSV (tab-separated).", + ) + parser.add_argument( + "--hierarchy-path", + type=str, + default=str(DEFAULT_HIERARCHY), + help="Path to the MapBiomas hierarchy CSV.", + ) + args = parser.parse_args() + + total = args.n_train + args.n_val + rng = np.random.default_rng(args.seed) + + print(f"Loading shapefile: {args.shp_path}") + long = load_and_melt(args.shp_path) + n_unique = long["TARGETID"].nunique() + n_classes = long["CLASS"].nunique() + print(f" Long-format rows (COUNT==3, kept classes): {len(long)}") + print(f" Unique pixels: {n_unique}, Classes: {n_classes}") + + quotas = compute_quotas(long, total) + print(f"\nPer-class quotas (total {sum(quotas.values())}):") + for cls in sorted(quotas, key=lambda c: -quotas[c]): + avail = long[long["CLASS"] == cls]["TARGETID"].nunique() + print(f" class {cls:>2d}: quota={quotas[cls]:>4d} (avail={avail})") + + print("\nSelecting pixels …") + selected = select_pixels(long, quotas, args.edge_frac, rng) + print(f" Selected: {len(selected)} rows") + + selected = train_val_split(selected, args.n_train, args.n_val, rng) + + out_cols = [ + "TARGETID", + "LON", + "LAT", + "YEAR", + "CLASS", + "BORDA", + "COUNT", + "CARTA_2", + "DECLIVIDAD", + "split", + ] + selected = ( + selected[out_cols] + .sort_values(["split", "CLASS", "YEAR"]) + .reset_index(drop=True) + ) + + out_path = Path(args.out_path) + out_path.parent.mkdir(parents=True, exist_ok=True) + selected.to_csv(out_path, index=False) + print(f"\nWrote {len(selected)} rows to {out_path}") + + legend = load_legend(Path(args.legend_path)) + hierarchy_path = Path(args.hierarchy_path) + hierarchy = load_hierarchy(hierarchy_path) if hierarchy_path.exists() else None + summary = build_summary(selected, legend, hierarchy) + summary_path = out_path.parent / "sample_expert_points_summary.csv" + summary.to_csv(summary_path, index=False) + print(f"Wrote summary stats: {summary_path} ({len(summary)} rows)") + + print_summary(selected, summary) + + +if __name__ == "__main__": + main() diff --git a/rslp/mapbiomas/subsampling/visualize_sample_expert_points.py b/rslp/mapbiomas/subsampling/visualize_sample_expert_points.py new file mode 100644 index 000000000..66d55f259 --- /dev/null +++ b/rslp/mapbiomas/subsampling/visualize_sample_expert_points.py @@ -0,0 +1,191 @@ +"""Visualize the MapBiomas subsample produced by sample_expert_points.py. + +Generates a 2x3 figure: + (0,0) Geographic scatter – train only + (0,1) Geographic scatter – val only + (0,2) CARTA_2 histogram (overlaid train/val) + (1,0) Class distribution bar chart + (1,1) Year distribution bar chart + (1,2) DECLIVIDAD distribution bar chart +""" + +from __future__ import annotations + +import argparse +import os +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +MY_ROOT = Path(os.environ.get("MY_ROOT", ".")) + +SPLIT_COLORS = {"train": "#1f77b4", "val": "#ff7f0e"} + + +def main() -> None: + """Visualize the MapBiomas 4k subsample.""" + parser = argparse.ArgumentParser( + description="Visualize the MapBiomas 4k subsample." + ) + parser.add_argument( + "--csv-path", + type=str, + default=str( + MY_ROOT + / "rslearn_projects/rslp/mapbiomas/subsampling/sample_expert_points_4k.csv" + ), + help="Path to the subsample CSV produced by sample_expert_points.py.", + ) + parser.add_argument( + "--out-path", + type=str, + default=None, + help="If set, save the figure to this path instead of showing it.", + ) + args = parser.parse_args() + + df = pd.read_csv(args.csv_path) + train = df[df["split"] == "train"] + val = df[df["split"] == "val"] + + fig, axes = plt.subplots(2, 3, figsize=(22, 12)) + fig.suptitle("MapBiomas Subsample Overview", fontsize=16, fontweight="bold") + + # --- (0,0) Geographic scatter – train --- + ax = axes[0, 0] + ax.scatter( + train["LAT"], + train["LON"], + s=4, + alpha=0.4, + color=SPLIT_COLORS["train"], + rasterized=True, + ) + ax.set_xlabel("Longitude") + ax.set_ylabel("Latitude") + ax.set_title(f"Train ({len(train)} pts)") + + # --- (0,1) Geographic scatter – val --- + ax = axes[0, 1] + ax.scatter( + val["LAT"], + val["LON"], + s=4, + alpha=0.4, + color=SPLIT_COLORS["val"], + rasterized=True, + ) + ax.set_xlabel("Longitude") + ax.set_ylabel("Latitude") + ax.set_title(f"Val ({len(val)} pts)") + + # Share axis limits between the two geo plots + all_lon = df["LAT"] + all_lat = df["LON"] + lon_pad = (all_lon.max() - all_lon.min()) * 0.03 + lat_pad = (all_lat.max() - all_lat.min()) * 0.03 + for a in axes[0, :2]: + a.set_xlim(all_lon.min() - lon_pad, all_lon.max() + lon_pad) + a.set_ylim(all_lat.min() - lat_pad, all_lat.max() + lat_pad) + + # --- (0,2) CARTA_2 histogram --- + ax = axes[0, 2] + tiles_sorted = sorted(df["CARTA_2"].unique()) + tile_to_idx = {t: i for i, t in enumerate(tiles_sorted)} + train_idx = train["CARTA_2"].map(tile_to_idx) + val_idx = val["CARTA_2"].map(tile_to_idx) + bins = np.linspace(-0.5, len(tiles_sorted) - 0.5, min(50, len(tiles_sorted)) + 1) + ax.hist(train_idx, bins=bins, alpha=0.5, label="train", color=SPLIT_COLORS["train"]) + ax.hist(val_idx, bins=bins, alpha=0.5, label="val", color=SPLIT_COLORS["val"]) + ax.set_xlabel("CARTA_2 tile (index)") + ax.set_ylabel("Count") + ax.set_title(f"CARTA_2 Distribution ({len(tiles_sorted)} tiles)") + ax.legend() + + # --- (1,0) Class distribution --- + _grouped_bar( + axes[1, 0], + df, + group_col="CLASS", + title="Class Distribution", + xlabel="Class", + ylabel="Count", + ) + + # --- (1,1) Year distribution --- + _grouped_bar( + axes[1, 1], + df, + group_col="YEAR", + title="Year Distribution", + xlabel="Year", + ylabel="Count", + ) + + # --- (1,2) DECLIVIDAD distribution --- + _grouped_bar( + axes[1, 2], + df, + group_col="DECLIVIDAD", + title="DECLIVIDAD Distribution", + xlabel="DECLIVIDAD", + ylabel="Count", + ) + + fig.tight_layout(rect=[0, 0, 1, 0.96]) + + if args.out_path: + out = Path(args.out_path) + out.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(out, dpi=150, bbox_inches="tight") + print(f"Saved figure to {out}") + else: + plt.show() + + +def _grouped_bar( + ax: plt.Axes, + df: pd.DataFrame, + group_col: str, + title: str, + xlabel: str, + ylabel: str, +) -> None: + """Draw a side-by-side bar chart of train vs val counts for *group_col*.""" + ct = df.groupby([group_col, "split"]).size().unstack(fill_value=0) + for split in ["train", "val"]: + if split not in ct.columns: + ct[split] = 0 + ct = ct[["train", "val"]].sort_index() + + labels = [str(v) for v in ct.index] + x = range(len(labels)) + bar_width = 0.4 + + ax.bar( + [i - bar_width / 2 for i in x], + ct["train"], + width=bar_width, + label="train", + color=SPLIT_COLORS["train"], + ) + ax.bar( + [i + bar_width / 2 for i in x], + ct["val"], + width=bar_width, + label="val", + color=SPLIT_COLORS["val"], + ) + + ax.set_xticks(list(x)) + ax.set_xticklabels(labels, rotation=45, ha="right") + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + ax.set_title(title) + ax.legend() + + +if __name__ == "__main__": + main() diff --git a/rslp/utils/beaker.py b/rslp/utils/beaker.py index fe5ce52b4..9e94be03e 100644 --- a/rslp/utils/beaker.py +++ b/rslp/utils/beaker.py @@ -7,7 +7,7 @@ from beaker.client import Beaker DEFAULT_WORKSPACE = "ai2/earth-systems" -DEFAULT_BUDGET = "ai2/es-platform" +DEFAULT_BUDGET = "ai2/atec-olmoearth" @dataclass From 4c4ffc1fa5b36819e92306c9b6514abf42aefa34 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 9 Jun 2026 21:31:30 +0000 Subject: [PATCH 6/6] deleted lfmc files wrongly ported over this mapbiomas branch --- .../v2_lfmc/subsample/woody3k_lprobe.yaml | 109 ---- .../v2_lfmc/subsample/woody3k_unet.yaml | 101 ---- rslp/lfmc/compute_target_stats.py | 285 ---------- rslp/lfmc/subsample_oep_eval.py | 486 ------------------ rslp/lfmc/visualize_oep_eval.py | 182 ------- 5 files changed, 1163 deletions(-) delete mode 100644 data/helios/v2_lfmc/subsample/woody3k_lprobe.yaml delete mode 100644 data/helios/v2_lfmc/subsample/woody3k_unet.yaml delete mode 100644 rslp/lfmc/compute_target_stats.py delete mode 100644 rslp/lfmc/subsample_oep_eval.py delete mode 100644 rslp/lfmc/visualize_oep_eval.py diff --git a/data/helios/v2_lfmc/subsample/woody3k_lprobe.yaml b/data/helios/v2_lfmc/subsample/woody3k_lprobe.yaml deleted file mode 100644 index d048e0e10..000000000 --- a/data/helios/v2_lfmc/subsample/woody3k_lprobe.yaml +++ /dev/null @@ -1,109 +0,0 @@ -model: - class_path: rslearn.train.lightning_module.RslearnLightningModule - init_args: - model: - class_path: rslearn.models.multitask.MultiTaskModel - init_args: - encoder: - - class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth - init_args: - model_id: OLMOEARTH_V1_1_BASE - patch_size: 4 - decoders: - lfmc_estimation: - - class_path: rslearn.models.upsample.Upsample - init_args: - scale_factor: 4 - - class_path: rslearn.models.conv.Conv - init_args: - in_channels: 768 - out_channels: 1 - kernel_size: 1 - activation: - class_path: torch.nn.Identity - - class_path: rslearn.train.tasks.per_pixel_regression.PerPixelRegressionHead - lr: 0.00002 - scheduler: - class_path: rslearn.train.scheduler.PlateauScheduler - init_args: - factor: 0.2 - patience: 2 - min_lr: 0 - cooldown: 10 -data: - class_path: rslearn.train.data_module.RslearnDataModule - init_args: - path: /weka/dfive-default/rslearn-eai/datasets/lfmc/20251023/woody/scratch/dataset - inputs: - sentinel2_l2a: - data_type: "raster" - layers: ["sentinel2"] - bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] - passthrough: true - dtype: FLOAT32 - load_all_item_groups: true - load_all_layers: true - targets: - data_type: "raster" - layers: ["labels"] - bands: ["value"] - dtype: FLOAT32 - is_target: true - task: - class_path: rslearn.train.tasks.multi_task.MultiTask - init_args: - tasks: - lfmc_estimation: - class_path: rslearn.train.tasks.per_pixel_regression.PerPixelRegressionTask - init_args: - nodata_value: -1 - metrics: ["rmse", "r2"] - #scale_factor: 0.0024038 - target_mean: 108.389520 - target_std: 36.616698 - - input_mapping: - lfmc_estimation: - targets: "targets" - batch_size: 32 - num_workers: 16 - default_config: - crop_size: 32 - transforms: - - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize - init_args: - band_names: - sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] - train_config: - groups: ["spatial_split"] - tags: - split: "train" - oep_eval_big: "" - val_config: - groups: ["spatial_split"] - tags: - split: "val" - oep_eval_big: "" - test_config: - groups: ["spatial_split"] - tags: - split: "test" - oep_eval_big: "" -trainer: - max_epochs: 100 - callbacks: - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: "epoch" - - class_path: rslearn.train.callbacks.checkpointing.ManagedBestLastCheckpoint - init_args: - monitor: val_loss - mode: min - - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze - init_args: - module_selector: ["model", "encoder", 0] - unfreeze_at_epoch: 50 - unfreeze_lr_factor: 10 -management_dir: ${RSLP_PREFIX}/projects -project_name: 20260528_lfmc_finetune -run_name: placeholder diff --git a/data/helios/v2_lfmc/subsample/woody3k_unet.yaml b/data/helios/v2_lfmc/subsample/woody3k_unet.yaml deleted file mode 100644 index a3c570f03..000000000 --- a/data/helios/v2_lfmc/subsample/woody3k_unet.yaml +++ /dev/null @@ -1,101 +0,0 @@ -model: - class_path: rslearn.train.lightning_module.RslearnLightningModule - init_args: - model: - class_path: rslearn.models.multitask.MultiTaskModel - init_args: - encoder: - - class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth - init_args: - model_id: OLMOEARTH_V1_1_BASE - patch_size: 4 - decoders: - lfmc_estimation: - - class_path: rslearn.models.unet.UNetDecoder - init_args: - in_channels: [[4, 768]] - out_channels: 1 - conv_layers_per_resolution: 2 - num_channels: {8: 512, 4: 512, 2: 256, 1: 128} - - class_path: rslearn.train.tasks.per_pixel_regression.PerPixelRegressionHead - lr: 0.00002 - scheduler: - class_path: rslearn.train.scheduler.PlateauScheduler - init_args: - factor: 0.2 - patience: 2 - min_lr: 0 - cooldown: 10 -data: - class_path: rslearn.train.data_module.RslearnDataModule - init_args: - path: /weka/dfive-default/rslearn-eai/datasets/lfmc/20251023/woody/scratch/dataset - inputs: - sentinel2_l2a: - data_type: "raster" - layers: ["sentinel2"] - bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] - passthrough: true - dtype: FLOAT32 - load_all_item_groups: true - load_all_layers: true - targets: - data_type: "raster" - layers: ["labels"] - bands: ["value"] - dtype: FLOAT32 - is_target: true - task: - class_path: rslearn.train.tasks.multi_task.MultiTask - init_args: - tasks: - lfmc_estimation: - class_path: rslearn.train.tasks.per_pixel_regression.PerPixelRegressionTask - init_args: - nodata_value: -1 - metrics: ["rmse", "r2"] - input_mapping: - lfmc_estimation: - targets: "targets" - batch_size: 32 - num_workers: 16 - default_config: - crop_size: 32 - transforms: - - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize - init_args: - band_names: - sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] - train_config: - groups: ["spatial_split"] - tags: - split: "train" - oep_eval_big: "" - val_config: - groups: ["spatial_split"] - tags: - split: "val" - oep_eval_big: "" - test_config: - groups: ["spatial_split"] - tags: - split: "test" - oep_eval_big: "" -trainer: - max_epochs: 100 - callbacks: - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: "epoch" - - class_path: rslearn.train.callbacks.checkpointing.ManagedBestLastCheckpoint - init_args: - monitor: val_loss - mode: min - - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze - init_args: - module_selector: ["model", "encoder", 0] - unfreeze_at_epoch: 20 - unfreeze_lr_factor: 10 -management_dir: ${RSLP_PREFIX}/projects -project_name: 20260528_lfmc_finetune -run_name: placeholder diff --git a/rslp/lfmc/compute_target_stats.py b/rslp/lfmc/compute_target_stats.py deleted file mode 100644 index 96e51b745..000000000 --- a/rslp/lfmc/compute_target_stats.py +++ /dev/null @@ -1,285 +0,0 @@ -"""Compute mean/std of LFMC target values over a split (and optional tag). - -These statistics are intended to be plugged into the ``target_mean`` / ``target_std`` -arguments of ``rslearn.train.tasks.per_pixel_regression.PerPixelRegressionTask`` so that -the regression target is normalized during training while metrics are still reported in -the original units. - -The statistics are computed only over *valid* pixels: pixels equal to the ignore value -(``-1`` by default, matching the task's ``nodata_value``) and non-finite pixels are -excluded. Windows are filtered the same way the training data module filters them: by -group, by the ``split`` option, and (optionally) by requiring a tag option to be present. - -IMPORTANT: only compute these statistics on the train split to avoid leaking validation -or test information into the normalization. - -Usage: - python -m rslp.lfmc.compute_target_stats \ - --dataset_path /weka/dfive-default/rslearn-eai/datasets/lfmc/20251023/woody/scratch/dataset \ - --split train \ - --tag oep_eval_big -""" - -import argparse -import json -import math -import os -from collections.abc import Iterable -from concurrent.futures import ProcessPoolExecutor, as_completed -from dataclasses import dataclass - -import numpy as np -import rasterio - - -@dataclass -class PixelStats: - """Streaming accumulator for per-pixel statistics over valid pixels.""" - - count: int = 0 - total: float = 0.0 - total_sq: float = 0.0 - minimum: float = math.inf - maximum: float = -math.inf - - def merge(self, other: "PixelStats") -> None: - """Merge another accumulator into this one.""" - self.count += other.count - self.total += other.total - self.total_sq += other.total_sq - self.minimum = min(self.minimum, other.minimum) - self.maximum = max(self.maximum, other.maximum) - - @property - def mean(self) -> float: - """Mean over valid pixels.""" - if self.count == 0: - raise ValueError("no valid pixels found") - return self.total / self.count - - def std(self, ddof: int = 0) -> float: - """Standard deviation over valid pixels. - - Args: - ddof: delta degrees of freedom. Use 0 (population std) for normalization. - """ - if self.count - ddof <= 0: - raise ValueError("not enough valid pixels to compute std") - # Var = E[x^2] - E[x]^2, scaled to the requested ddof. Clamp tiny negatives that - # can arise from floating point error. - variance = (self.total_sq - self.total**2 / self.count) / (self.count - ddof) - return math.sqrt(max(variance, 0.0)) - - -def _matches_filters( - options: dict, - split: str | None, - tag: str | None, -) -> bool: - """Return whether a window's options pass the split/tag filters. - - Mirrors the tag filtering done by rslearn's ModelDataset: the tag key must be present - in the window options (an empty configured value just requires presence). - """ - if split is not None and options.get("split") != split: - return False - if tag is not None and tag not in options: - return False - return True - - -def _select_window_dirs( - dataset_path: str, - group: str, - split: str | None, - tag: str | None, -) -> list[str]: - """Find window directories matching the split/tag filters.""" - windows_dir = os.path.join(dataset_path, "windows", group) - names = sorted(os.listdir(windows_dir)) - selected = [] - for name in names: - window_dir = os.path.join(windows_dir, name) - meta_path = os.path.join(window_dir, "metadata.json") - if not os.path.isfile(meta_path): - continue - with open(meta_path) as f: - meta = json.load(f) - if _matches_filters(meta.get("options", {}), split, tag): - selected.append(window_dir) - return selected - - -# Globals set per worker process so the geotiff path can be built without re-passing -# constant arguments for every task. -_LAYER = "labels" -_IGNORE_VALUE = -1.0 - - -def _init_worker(layer: str, ignore_value: float) -> None: - global _LAYER, _IGNORE_VALUE - _LAYER = layer - _IGNORE_VALUE = ignore_value - - -def _window_stats(window_dir: str) -> PixelStats: - """Read one window's label raster and accumulate stats over valid pixels.""" - stats = PixelStats() - geotiff_path = os.path.join(window_dir, "layers", _LAYER, "value", "geotiff.tif") - if not os.path.isfile(geotiff_path): - return stats - with rasterio.open(geotiff_path) as ds: - array = ds.read().astype(np.float64) - valid = np.isfinite(array) & (array != _IGNORE_VALUE) - values = array[valid] - if values.size == 0: - return stats - stats.count = int(values.size) - stats.total = float(values.sum()) - stats.total_sq = float(np.square(values).sum()) - stats.minimum = float(values.min()) - stats.maximum = float(values.max()) - return stats - - -def compute_stats( - window_dirs: Iterable[str], - layer: str, - ignore_value: float, - workers: int, -) -> PixelStats: - """Compute aggregate pixel statistics across the given window directories.""" - window_dirs = list(window_dirs) - total = PixelStats() - if workers <= 1: - _init_worker(layer, ignore_value) - for i, window_dir in enumerate(window_dirs): - total.merge(_window_stats(window_dir)) - if (i + 1) % 5000 == 0: - print(f" ... processed {i + 1}/{len(window_dirs)} windows") - return total - - with ProcessPoolExecutor( - max_workers=workers, - initializer=_init_worker, - initargs=(layer, ignore_value), - ) as executor: - futures = { - executor.submit(_window_stats, window_dir): window_dir - for window_dir in window_dirs - } - for i, future in enumerate(as_completed(futures)): - total.merge(future.result()) - if (i + 1) % 5000 == 0: - print(f" ... processed {i + 1}/{len(window_dirs)} windows") - return total - - -def main() -> None: - """Compute and report LFMC target mean/std for a split.""" - parser = argparse.ArgumentParser( - description="Compute mean/std of LFMC target values over a split." - ) - parser.add_argument("--dataset_path", type=str, required=True) - parser.add_argument( - "--group", - type=str, - default="spatial_split", - help="Window group to read (default: spatial_split).", - ) - parser.add_argument( - "--split", - type=str, - default="train", - help="Only include windows whose 'split' option equals this. " - "Pass an empty string to disable the split filter.", - ) - parser.add_argument( - "--tag", - type=str, - default=None, - help="If set, only include windows that have this tag option (e.g. " - "oep_eval_big).", - ) - parser.add_argument( - "--layer", - type=str, - default="labels", - help="Raster layer holding the target (default: labels).", - ) - parser.add_argument( - "--ignore_value", - type=float, - default=-1.0, - help="Pixel value to treat as invalid/nodata (default: -1).", - ) - parser.add_argument( - "--workers", - type=int, - default=16, - help="Number of worker processes for reading rasters (default: 16).", - ) - parser.add_argument( - "--output", - type=str, - default=None, - help="Optional path to write the computed stats as JSON.", - ) - args = parser.parse_args() - - split = args.split if args.split else None - - print( - f"Selecting windows from group '{args.group}' " - f"(split={split}, tag={args.tag})..." - ) - window_dirs = _select_window_dirs(args.dataset_path, args.group, split, args.tag) - print(f" Selected {len(window_dirs)} windows") - if not window_dirs: - raise RuntimeError("no windows matched the given filters") - - print("Reading target rasters and accumulating statistics...") - stats = compute_stats(window_dirs, args.layer, args.ignore_value, args.workers) - - if stats.count == 0: - raise RuntimeError("no valid target pixels found") - - mean = stats.mean - std_pop = stats.std(ddof=0) - std_sample = stats.std(ddof=1) - - result = { - "dataset_path": args.dataset_path, - "group": args.group, - "split": split, - "tag": args.tag, - "layer": args.layer, - "ignore_value": args.ignore_value, - "num_windows": len(window_dirs), - "num_valid_pixels": stats.count, - "mean": mean, - "std": std_pop, - "std_sample": std_sample, - "min": stats.minimum, - "max": stats.maximum, - } - - print("\n=== LFMC target statistics (valid pixels only) ===") - print(f" windows: {len(window_dirs)}") - print(f" valid pixels: {stats.count}") - print(f" mean: {mean:.6f}") - print(f" std (pop): {std_pop:.6f}") - print(f" std (sample): {std_sample:.6f}") - print(f" min / max: {stats.minimum:.6f} / {stats.maximum:.6f}") - print("\nUse these in the PerPixelRegressionTask config:") - print(f" target_mean: {mean:.6f}") - print(f" target_std: {std_pop:.6f}") - - if args.output: - with open(args.output, "w") as f: - json.dump(result, f, indent=2) - print(f"\nWrote stats to {args.output}") - - -if __name__ == "__main__": - main() diff --git a/rslp/lfmc/subsample_oep_eval.py b/rslp/lfmc/subsample_oep_eval.py deleted file mode 100644 index 83b702795..000000000 --- a/rslp/lfmc/subsample_oep_eval.py +++ /dev/null @@ -1,486 +0,0 @@ -"""Tag a spatially-balanced subset of LFMC windows with the oep_eval tag. - -WARNING: This script is intended to be run exactly once per dataset. It will -refuse to run if a manifest file (oep_eval_manifest.json) from a previous run -is found. - -For train: selects ~1000 windows distributed evenly across US states, operating -at the location level. Each selected location contributes at most 1 year of -samples (the densest 365-day window), maximizing geographic coverage. -For val/test: tags ALL windows. - -Usage: - python subsample_oep_eval.py --dataset_path /path/to/dataset --tag oep_eval --target_train 1000 --seed 42 -""" - -import argparse -import json -import os -from collections import defaultdict -from datetime import datetime, timedelta -from typing import Any - -import geopandas as gpd -import numpy as np -import pandas as pd -from pyproj import Transformer -from shapely.geometry import Point - -LocKey = tuple[str, tuple[Any, ...]] - -US_STATES_SHP = ( - "/weka/dfive-default/hadriens/datasets/Misc/Us states/cb_2018_us_state_20m.shp" -) - - -def load_windows(dataset_path: str) -> list[dict]: - """Load all window metadata from the dataset.""" - windows_dir = os.path.join(dataset_path, "windows", "spatial_split") - names = os.listdir(windows_dir) - print(f" Found {len(names)} window directories, loading metadata...") - windows = [] - for i, name in enumerate(names): - if i % 5000 == 0 and i > 0: - print(f" ... loaded {i}/{len(names)}") - meta_path = os.path.join(windows_dir, name, "metadata.json") - if not os.path.isfile(meta_path): - continue - with open(meta_path) as f: - meta = json.load(f) - meta["_dir"] = os.path.join(windows_dir, name) - meta["_name"] = name - windows.append(meta) - return windows - - -def compute_location_centroid_wgs84( - crs: str, bounds: list, x_resolution: float, y_resolution: float -) -> tuple[float, float]: - """Convert the center of pixel bounds to WGS84 lon/lat. - - Bounds are stored in pixel coordinates; multiply by resolution to get - real-world UTM coordinates. - """ - cx_pixel = (bounds[0] + bounds[2]) / 2.0 - cy_pixel = (bounds[1] + bounds[3]) / 2.0 - cx_utm = cx_pixel * x_resolution - cy_utm = cy_pixel * y_resolution - transformer = Transformer.from_crs(crs, "EPSG:4326", always_xy=True) - lon, lat = transformer.transform(cx_utm, cy_utm) - return lon, lat - - -def get_window_date(window: dict) -> datetime: - """Extract the sample date from a window's time_range.""" - return datetime.fromisoformat(window["time_range"][0]) - - -def best_one_year_subset(windows: list[dict]) -> list[dict]: - """Find the densest 365-day sliding window and return only those samples. - - Slides across all sample dates and picks the 365-day interval containing - the most samples. - """ - if len(windows) <= 1: - return windows - - sorted_windows = sorted(windows, key=get_window_date) - dates = [get_window_date(w) for w in sorted_windows] - year_delta = timedelta(days=365) - - best_start = 0 - best_end = 0 - best_count = 0 - - for i, start_date in enumerate(dates): - end_date = start_date + year_delta - # Find how many samples fall within [start_date, start_date + 365 days] - j = i - while j < len(dates) and dates[j] <= end_date: - j += 1 - count = j - i - if count > best_count: - best_count = count - best_start = i - best_end = j - - return sorted_windows[best_start:best_end] - - -def assign_locations_to_states( - locations: dict[LocKey, list[dict]], states_gdf: gpd.GeoDataFrame -) -> dict[LocKey, str]: - """Assign each location key to a US state via point-in-polygon.""" - loc_keys = list(locations.keys()) - points = [] - for key in loc_keys: - crs, bounds = key[0], list(key[1]) - first_window = locations[key][0] - x_res = first_window["projection"]["x_resolution"] - y_res = first_window["projection"]["y_resolution"] - lon, lat = compute_location_centroid_wgs84(crs, bounds, x_res, y_res) - points.append(Point(lon, lat)) - - points_gdf = gpd.GeoDataFrame( - {"loc_key": loc_keys}, geometry=points, crs="EPSG:4326" - ) - states_gdf_4326 = states_gdf.to_crs("EPSG:4326") - joined = gpd.sjoin(points_gdf, states_gdf_4326[["NAME", "geometry"]], how="left") - - loc_to_state: dict[LocKey, str] = {} - for _, row in joined.iterrows(): - state = row.get("NAME") - if isinstance(state, str) and not pd.isna(state): - loc_to_state[row["loc_key"]] = state - else: - loc_to_state[row["loc_key"]] = "Unknown" - return loc_to_state - - -def select_locations_balanced( - locations: dict[LocKey, list[dict]], - loc_to_state: dict[LocKey, str], - loc_one_year_count: dict[LocKey, int], - target: int, - seed: int, -) -> list[LocKey]: - """Select locations with even distribution across states. - - Uses the 1-year capped sample count for each location when computing quotas, - so that more locations can be selected for better geographic coverage. - Returns list of selected location keys. - """ - rng = np.random.default_rng(seed) - - state_to_locs: dict[str, list[LocKey]] = defaultdict(list) - for loc_key, state in loc_to_state.items(): - state_to_locs[state].append(loc_key) - - states_with_data = [s for s in state_to_locs if s != "Unknown"] - if not states_with_data: - states_with_data = list(state_to_locs.keys()) - - selected: list[LocKey] = [] - remaining_target = target - - # Iteratively allocate: states with fewer capped samples than quota get all - settled_states = set() - for _ in range(20): - unsettled = [s for s in states_with_data if s not in settled_states] - if not unsettled: - break - - per_state_quota = remaining_target / len(unsettled) if unsettled else 0 - newly_settled = [] - - for state in unsettled: - locs = state_to_locs[state] - total_capped = sum(loc_one_year_count[lk] for lk in locs) - if total_capped <= per_state_quota: - selected.extend(locs) - remaining_target -= total_capped - newly_settled.append(state) - - if not newly_settled: - break - settled_states.update(newly_settled) - - # For remaining states, select locations to fill quota - unsettled = [s for s in states_with_data if s not in settled_states] - if unsettled: - per_state_quota = remaining_target / len(unsettled) if unsettled else 0 - for state in unsettled: - locs = state_to_locs[state] - locs_with_count = [(lk, loc_one_year_count[lk]) for lk in locs] - - # Shuffle for randomness, then sort by count (prefer dense locations) - shuffled = list(locs_with_count) - rng.shuffle(shuffled) - shuffled.sort(key=lambda x: x[1], reverse=True) - - state_selected: list[LocKey] = [] - state_total = 0 - for lk, count in shuffled: - if state_total + count > per_state_quota and state_selected: - break - state_selected.append(lk) - state_total += count - - selected.extend(state_selected) - remaining_target -= state_total - - # Include "Unknown" state locations if any remain and we're under target - if "Unknown" in state_to_locs and remaining_target > 0: - unknown_locs = state_to_locs["Unknown"] - unknown_with_count = [(lk, loc_one_year_count[lk]) for lk in unknown_locs] - rng.shuffle(unknown_with_count) - unknown_with_count.sort(key=lambda x: x[1], reverse=True) - for lk, count in unknown_with_count: - if remaining_target <= 0: - break - selected.append(lk) - remaining_target -= count - - return selected - - -def tag_windows(windows: list[dict], tag: str) -> int: - """Add the given tag to the windows' metadata.json files.""" - tagged = 0 - for i, w in enumerate(windows): - if i % 2000 == 0 and i > 0: - print(f" ... tagged {i}/{len(windows)}") - meta_path = os.path.join(w["_dir"], "metadata.json") - with open(meta_path) as f: - meta = json.load(f) - if tag not in meta.get("options", {}): - meta.setdefault("options", {})[tag] = "" - with open(meta_path, "w") as f: - json.dump(meta, f) - tagged += 1 - return tagged - - -VAL_MAX_SAMPLES = 800 -TEST_MAX_SAMPLES = 500 - - -def subsample_split( - windows: list[dict], - target: int, - states_gdf: gpd.GeoDataFrame, - seed: int, - split_name: str, -) -> tuple[list[dict], dict]: - """Subsample a split using spatially-balanced selection with 1-year cap. - - Returns (selected_windows, split_stats_dict). - """ - # Group by location - locations: dict[LocKey, list[dict]] = defaultdict(list) - for w in windows: - key: LocKey = (w["projection"]["crs"], tuple(w["bounds"])) - locations[key].append(w) - print(f" Unique {split_name} locations: {len(locations)}") - - # Precompute best 1-year subset for each location - print(f" Computing best 1-year window per {split_name} location...") - loc_one_year: dict[LocKey, list[dict]] = {} - loc_one_year_count: dict[LocKey, int] = {} - for loc_key, loc_windows in locations.items(): - subset = best_one_year_subset(loc_windows) - loc_one_year[loc_key] = subset - loc_one_year_count[loc_key] = len(subset) - avg_capped = np.mean(list(loc_one_year_count.values())) - print(f" Avg samples per location (1-year cap): {avg_capped:.1f}") - - # Assign locations to states - print(f" Assigning {split_name} locations to states...") - loc_to_state = assign_locations_to_states(locations, states_gdf) - - # State distribution - state_total_samples: dict[str, int] = defaultdict(int) - state_total_locations: dict[str, int] = defaultdict(int) - for loc_key, state in loc_to_state.items(): - state_total_samples[state] += len(locations[loc_key]) - state_total_locations[state] += 1 - - # Select locations - print(f" Selecting locations for ~{target} {split_name} samples...") - selected_locs = select_locations_balanced( - locations, loc_to_state, loc_one_year_count, target, seed - ) - - # Collect the 1-year subset - selected_windows = [] - for loc_key in selected_locs: - selected_windows.extend(loc_one_year[loc_key]) - print( - f" Selected {len(selected_locs)} locations -> {len(selected_windows)} {split_name} samples" - ) - - # Per-state reporting - selected_state_samples: dict[str, int] = defaultdict(int) - selected_state_locations: dict[str, int] = defaultdict(int) - for loc_key in selected_locs: - state = loc_to_state[loc_key] - selected_state_samples[state] += loc_one_year_count[loc_key] - selected_state_locations[state] += 1 - print(f"\n Per-state {split_name} selection:") - for state in sorted(selected_state_samples.keys()): - total_locs = state_total_locations[state] - sel_locs = selected_state_locations[state] - total_samp = state_total_samples[state] - sel_samp = selected_state_samples[state] - print( - f" {state}: {sel_locs}/{total_locs} locations, {sel_samp}/{total_samp} samples" - ) - - # Build stats for this split - all_states = sorted( - set(list(selected_state_samples.keys()) + list(state_total_locations.keys())) - ) - split_stats = { - "target_samples": target, - "total_locations": len(locations), - "selected_locations": len(selected_locs), - "total_samples": len(windows), - "selected_samples": len(selected_windows), - "avg_samples_per_location_1yr_cap": float(avg_capped), - "per_state": { - state: { - "total_locations": state_total_locations[state], - "selected_locations": selected_state_locations.get(state, 0), - "total_samples": state_total_samples[state], - "selected_samples": selected_state_samples.get(state, 0), - } - for state in all_states - }, - } - - return selected_windows, split_stats - - -def main() -> None: - """Tag a spatially-balanced subset for LFMC datasets.""" - parser = argparse.ArgumentParser(description="Tag a subset for LFMC") - parser.add_argument("--dataset_path", type=str, required=True) - parser.add_argument("--tag", type=str, default="oep_eval", help="Tag name to apply") - parser.add_argument("--target_train", type=int, default=1000) - parser.add_argument("--seed", type=int, default=42) - parser.add_argument( - "--states_shp", - type=str, - default=US_STATES_SHP, - help="Path to US states shapefile", - ) - parser.add_argument( - "--dry_run", action="store_true", help="Print stats without tagging" - ) - args = parser.parse_args() - - manifest_path = os.path.join(args.dataset_path, "oep_eval_manifest.json") - if os.path.exists(manifest_path): - raise RuntimeError( - f"Manifest file already exists at {manifest_path}. " - "This script is intended to be run exactly once per dataset. " - "Remove the manifest file manually if you need to re-run." - ) - - print(f"Loading windows from {args.dataset_path}...") - all_windows = load_windows(args.dataset_path) - print(f" Total windows: {len(all_windows)}") - - # Split by train/val/test - train_windows = [ - w for w in all_windows if w.get("options", {}).get("split") == "train" - ] - val_windows = [w for w in all_windows if w.get("options", {}).get("split") == "val"] - test_windows = [ - w for w in all_windows if w.get("options", {}).get("split") == "test" - ] - print( - f" Train: {len(train_windows)}, Val: {len(val_windows)}, Test: {len(test_windows)}" - ) - - # Load states shapefile - print("Loading US states shapefile...") - states_gdf = gpd.read_file(args.states_shp) - - # --- Train subsampling --- - print(f"\n--- TRAIN (target: {args.target_train}) ---") - selected_train, train_stats = subsample_split( - train_windows, args.target_train, states_gdf, args.seed, "train" - ) - - # --- Val: subsample if over threshold, else use all --- - if len(val_windows) > VAL_MAX_SAMPLES: - print( - f"\n--- VAL (target: {VAL_MAX_SAMPLES}, total {len(val_windows)} > {VAL_MAX_SAMPLES}) ---" - ) - selected_val, val_stats = subsample_split( - val_windows, VAL_MAX_SAMPLES, states_gdf, args.seed + 1, "val" - ) - else: - print( - f"\n--- VAL: using all {len(val_windows)} samples (below {VAL_MAX_SAMPLES} threshold) ---" - ) - selected_val = val_windows - val_stats = { - "total_samples": len(val_windows), - "selected_samples": len(val_windows), - } - - # --- Test: subsample if over threshold, else use all --- - if len(test_windows) > TEST_MAX_SAMPLES: - print( - f"\n--- TEST (target: {TEST_MAX_SAMPLES}, total {len(test_windows)} > {TEST_MAX_SAMPLES}) ---" - ) - selected_test, test_stats = subsample_split( - test_windows, TEST_MAX_SAMPLES, states_gdf, args.seed + 2, "test" - ) - else: - print( - f"\n--- TEST: using all {len(test_windows)} samples (below {TEST_MAX_SAMPLES} threshold) ---" - ) - selected_test = test_windows - test_stats = { - "total_samples": len(test_windows), - "selected_samples": len(test_windows), - } - - if args.dry_run: - print("\n[DRY RUN] No tagging performed.") - return - - # Tag windows - print(f"\nTagging selected train samples with '{args.tag}'...") - train_tagged = tag_windows(selected_train, args.tag) - print(f" Tagged {train_tagged} train samples") - - print(f"Tagging selected val samples with '{args.tag}'...") - val_tagged = tag_windows(selected_val, args.tag) - print(f" Tagged {val_tagged} val samples") - - print(f"Tagging selected test samples with '{args.tag}'...") - test_tagged = tag_windows(selected_test, args.tag) - print(f" Tagged {test_tagged} test samples") - - # Write manifest - manifest = { - "train": [w["_name"] for w in selected_train], - "val": [w["_name"] for w in selected_val], - "test": [w["_name"] for w in selected_test], - } - manifest_path = os.path.join(args.dataset_path, "oep_eval_manifest.json") - with open(manifest_path, "w") as f: - json.dump(manifest, f, indent=2) - print(f"\nWrote manifest to {manifest_path}") - - # Write stats - stats = { - "tag": args.tag, - "seed": args.seed, - "one_year_cap_days": 365, - "val_max_samples_threshold": VAL_MAX_SAMPLES, - "test_max_samples_threshold": TEST_MAX_SAMPLES, - "total_samples": len(all_windows), - "tagged_samples": { - "train": len(selected_train), - "val": len(selected_val), - "test": len(selected_test), - }, - "train": train_stats, - "val": val_stats, - "test": test_stats, - } - stats_path = os.path.join(args.dataset_path, f"{args.tag}_stats.json") - with open(stats_path, "w") as f: - json.dump(stats, f, indent=2) - print(f"Wrote stats to {stats_path}") - - print("\nDone!") - - -if __name__ == "__main__": - main() diff --git a/rslp/lfmc/visualize_oep_eval.py b/rslp/lfmc/visualize_oep_eval.py deleted file mode 100644 index 90c9f0d14..000000000 --- a/rslp/lfmc/visualize_oep_eval.py +++ /dev/null @@ -1,182 +0,0 @@ -"""Visualize tagged subsampled locations vs all original locations on a US map. - -Produces a 2x2 figure: - - Row 1: Woody dataset - - Row 2: Herbaceous dataset - - Left column: all original locations - - Right column: tagged subset only -Points are colored by split (train=blue, val=orange, test=green). - -Usage: - python visualize_oep_eval.py \ - --woody_path /path/to/woody/dataset \ - --herbaceous_path /path/to/herbaceous/dataset \ - --tag oep_eval \ - --output oep_eval_map.png -""" - -import argparse -import json -import os -from collections import defaultdict - -import geopandas as gpd -import matplotlib.pyplot as plt -from pyproj import Transformer - -US_STATES_SHP = ( - "/weka/dfive-default/hadriens/datasets/Misc/Us states/cb_2018_us_state_20m.shp" -) - -SPLIT_COLORS = {"train": "#2176AE", "val": "#F77F00", "test": "#06D6A0"} - - -def load_locations(dataset_path: str) -> dict[str, list[tuple[float, float]]]: - """Load unique locations per split, return {split: [(lon, lat), ...]}.""" - windows_dir = os.path.join(dataset_path, "windows", "spatial_split") - seen: dict[str, set[tuple]] = defaultdict(set) - locations: dict[str, list[tuple[float, float]]] = defaultdict(list) - - for name in os.listdir(windows_dir): - meta_path = os.path.join(windows_dir, name, "metadata.json") - if not os.path.isfile(meta_path): - continue - with open(meta_path) as f: - meta = json.load(f) - - split = meta.get("options", {}).get("split", "unknown") - crs = meta["projection"]["crs"] - bounds = meta["bounds"] - loc_key = (crs, tuple(bounds)) - - if loc_key in seen[split]: - continue - seen[split].add(loc_key) - - x_res = meta["projection"]["x_resolution"] - y_res = meta["projection"]["y_resolution"] - cx = (bounds[0] + bounds[2]) / 2.0 * x_res - cy = (bounds[1] + bounds[3]) / 2.0 * y_res - transformer = Transformer.from_crs(crs, "EPSG:4326", always_xy=True) - lon, lat = transformer.transform(cx, cy) - locations[split].append((lon, lat)) - - return dict(locations) - - -def load_tagged_locations( - dataset_path: str, tag: str -) -> dict[str, list[tuple[float, float]]]: - """Load unique locations that have the given tag, per split.""" - windows_dir = os.path.join(dataset_path, "windows", "spatial_split") - seen: dict[str, set[tuple]] = defaultdict(set) - locations: dict[str, list[tuple[float, float]]] = defaultdict(list) - - for name in os.listdir(windows_dir): - meta_path = os.path.join(windows_dir, name, "metadata.json") - if not os.path.isfile(meta_path): - continue - with open(meta_path) as f: - meta = json.load(f) - - if tag not in meta.get("options", {}): - continue - - split = meta.get("options", {}).get("split", "unknown") - crs = meta["projection"]["crs"] - bounds = meta["bounds"] - loc_key = (crs, tuple(bounds)) - - if loc_key in seen[split]: - continue - seen[split].add(loc_key) - - x_res = meta["projection"]["x_resolution"] - y_res = meta["projection"]["y_resolution"] - cx = (bounds[0] + bounds[2]) / 2.0 * x_res - cy = (bounds[1] + bounds[3]) / 2.0 * y_res - transformer = Transformer.from_crs(crs, "EPSG:4326", always_xy=True) - lon, lat = transformer.transform(cx, cy) - locations[split].append((lon, lat)) - - return dict(locations) - - -def plot_locations_on_ax( - ax: plt.Axes, - locations: dict[str, list[tuple[float, float]]], - states_gdf: gpd.GeoDataFrame, - title: str, -) -> None: - """Plot location points on a US map axis.""" - states_gdf.boundary.plot(ax=ax, linewidth=0.5, color="gray") - - for split in ["test", "val", "train"]: - if split not in locations: - continue - pts = locations[split] - if not pts: - continue - lons, lats = zip(*pts) - ax.scatter( - lons, - lats, - c=SPLIT_COLORS[split], - s=15, - alpha=0.7, - edgecolors="none", - label=f"{split} ({len(pts)} locs)", - zorder=3, - ) - - ax.set_xlim(-130, -65) - ax.set_ylim(24, 50) - ax.set_title(title, fontsize=11, fontweight="bold") - ax.set_aspect("equal") - ax.tick_params(labelsize=7) - ax.legend(loc="lower left", fontsize=7, framealpha=0.8) - - -def main() -> None: - """Visualize tagged subsampled vs original locations on a US map.""" - parser = argparse.ArgumentParser(description="Visualize tagged locations") - parser.add_argument("--woody_path", type=str, required=True) - parser.add_argument("--herbaceous_path", type=str, required=True) - parser.add_argument( - "--tag", type=str, default="oep_eval", help="Tag name to filter on" - ) - parser.add_argument("--states_shp", type=str, default=US_STATES_SHP) - parser.add_argument("--output", type=str, default="oep_eval_map.png") - args = parser.parse_args() - - print("Loading US states shapefile...") - states_gdf = gpd.read_file(args.states_shp) - states_gdf = states_gdf[~states_gdf["NAME"].isin(["Alaska", "Hawaii"])] - states_gdf = states_gdf.to_crs("EPSG:4326") - - print("Loading woody locations...") - woody_all = load_locations(args.woody_path) - woody_tagged = load_tagged_locations(args.woody_path, args.tag) - - print("Loading herbaceous locations...") - herb_all = load_locations(args.herbaceous_path) - herb_tagged = load_tagged_locations(args.herbaceous_path, args.tag) - - fig, axes = plt.subplots(2, 2, figsize=(16, 10)) - - plot_locations_on_ax(axes[0, 0], woody_all, states_gdf, "Woody — All Locations") - plot_locations_on_ax( - axes[0, 1], woody_tagged, states_gdf, f"Woody — {args.tag} Subset" - ) - plot_locations_on_ax(axes[1, 0], herb_all, states_gdf, "Herbaceous — All Locations") - plot_locations_on_ax( - axes[1, 1], herb_tagged, states_gdf, f"Herbaceous — {args.tag} Subset" - ) - - plt.tight_layout() - plt.savefig(args.output, dpi=150, bbox_inches="tight") - print(f"Saved figure to {args.output}") - - -if __name__ == "__main__": - main()