From 2a84f522046e35ee3e1526b3cd78ac506ac236f2 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Thu, 21 May 2026 08:43:10 -0700 Subject: [PATCH 1/6] Add embedding explorer web app --- .gitignore | 1 + rslp/embedding_explorer/README.md | 106 ++++ rslp/embedding_explorer/__init__.py | 1 + rslp/embedding_explorer/app.py | 589 +++++++++++++++++++ rslp/embedding_explorer/config.json | 49 ++ rslp/embedding_explorer/config.yaml | 50 ++ rslp/embedding_explorer/static/app.js | 503 ++++++++++++++++ rslp/embedding_explorer/static/style.css | 228 +++++++ rslp/embedding_explorer/templates/index.html | 110 ++++ 9 files changed, 1637 insertions(+) create mode 100644 rslp/embedding_explorer/README.md create mode 100644 rslp/embedding_explorer/__init__.py create mode 100644 rslp/embedding_explorer/app.py create mode 100644 rslp/embedding_explorer/config.json create mode 100644 rslp/embedding_explorer/config.yaml create mode 100644 rslp/embedding_explorer/static/app.js create mode 100644 rslp/embedding_explorer/static/style.css create mode 100644 rslp/embedding_explorer/templates/index.html diff --git a/.gitignore b/.gitignore index 4d73ea87c..b3598d519 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ wandb/ rslp/__pycache__/ **__pycache__ **/.DS_Store +*.pyc repos/ project_data diff --git a/rslp/embedding_explorer/README.md b/rslp/embedding_explorer/README.md new file mode 100644 index 000000000..fb7881da5 --- /dev/null +++ b/rslp/embedding_explorer/README.md @@ -0,0 +1,106 @@ +# Embedding Explorer + +WARNING: all of the code and most of the documentation here was AI-generated and not +carefully reviewed. + +A small Flask app for interactively exploring per-pixel OlmoEarth embeddings +on a Leaflet map. Click points on the map and the server computes a similarity +overlay using one of three modes: + +- **Cosine** — cosine similarity to the embedding under the (single, latest) + positive point. +- **KNN** — fraction of the K nearest labeled points (by cosine similarity) + that are positive. +- **Linear Probe** — a logistic regression (768 -> 2) is trained on CPU from the + labeled points and predict_proba is run over the whole window. Updates only + on **Apply** (training is fast but not free, so we don't refresh on every + click). + +Left-click adds a positive point, right-click adds a negative point. The +threshold slider and gradient/threshold toggle are pure client-side rendering +and stay live in all modes. + +## Setting up a dataset + +The explorer expects an rslearn dataset where each window has a raster layer +named `embeddings` (768 bands, float32) alongside the source images. Two +reference configs sit next to this README: + +- [config.json](config.json) — minimal dataset config (Sentinel-2 L2A from + Planetary Computer + an `embeddings` layer). +- [config.yaml](config.yaml) — model config that runs OlmoEarth and writes + outputs into the `embeddings` layer via `RslearnWriter`. + +For a fuller treatment of these configs (Sentinel-1 / Landsat layers, model +sizes, etc.) see +[OlmoEarthEmbeddings.md](https://github.com/allenai/rslearn/blob/master/docs/examples/OlmoEarthEmbeddings.md) +in rslearn. + +### 1. Create a dataset and add window(s) + +Copy `config.json` into a fresh dataset directory, then add one or more +2048×2048 windows. The explorer is happy with multiple windows and you +switch between them in the sidebar dropdown. + +```bash +export DATASET_PATH=./dataset +mkdir -p $DATASET_PATH +cp rslp/embedding_explorer/config.json $DATASET_PATH/config.json + +# A single 2048x2048 window over Seattle. +rslearn dataset add_windows --root $DATASET_PATH \ + --group default --name seattle \ + --utm --resolution 10 --src_crs EPSG:4326 \ + --box=-122.42,47.58,-122.22,47.78 \ + --start 2024-06-01T00:00:00+00:00 --end 2024-09-01T00:00:00+00:00 \ + --grid_size 2048 + +# Or tile a larger area into multiple 2048x2048 windows. +rslearn dataset add_windows --root $DATASET_PATH \ + --group default --name puget_sound \ + --utm --resolution 10 --src_crs EPSG:4326 \ + --box=-122.7,47.2,-122.0,47.9 \ + --start 2024-06-01T00:00:00+00:00 --end 2024-09-01T00:00:00+00:00 \ + --grid_size 2048 +``` + +### 2. Materialize Sentinel-2 + +```bash +rslearn dataset prepare --root $DATASET_PATH --workers 32 \ + --enabled-layers sentinel2_l2a \ + --retry-max-attempts 5 --retry-backoff-seconds 5 +rslearn dataset materialize --root $DATASET_PATH --workers 32 \ + --no-use-initial-job --enabled-layers sentinel2_l2a \ + --retry-max-attempts 5 --retry-backoff-seconds 5 +``` + +### 3. Compute embeddings + +```bash +rslearn model predict --config rslp/embedding_explorer/config.yaml +``` + +After this finishes, each window will have a populated `layers/embeddings/` +directory with a 768-band GeoTIFF. + +## Run the app + +The app needs Flask and scikit-learn (neither are pulled in by `rslp`'s base +requirements): + +```bash +pip install flask scikit-learn +``` + +Then point it at your dataset: + +```bash +python -m rslp.embedding_explorer.app \ + --dataset-path $DATASET_PATH \ + --port 5000 +``` + +Open `http://localhost:5000` and pick a window from the sidebar. Clicking on +the map adds points; in cosine/KNN modes the overlay updates immediately, and +in Linear Probe mode you press **Apply** to (re)train. diff --git a/rslp/embedding_explorer/__init__.py b/rslp/embedding_explorer/__init__.py new file mode 100644 index 000000000..3621f37d0 --- /dev/null +++ b/rslp/embedding_explorer/__init__.py @@ -0,0 +1 @@ +"""Embedding similarity explorer web app.""" diff --git a/rslp/embedding_explorer/app.py b/rslp/embedding_explorer/app.py new file mode 100644 index 000000000..c14fbf7a8 --- /dev/null +++ b/rslp/embedding_explorer/app.py @@ -0,0 +1,589 @@ +"""Flask app for exploring embedding similarity in rslearn datasets. + +Usage: + python -m rslp.embedding_explorer.app --dataset-path /path/to/dataset --port 5000 +""" + +import argparse +import io +import json +from pathlib import Path + +import numpy as np +import rasterio +import rasterio.transform +import rasterio.warp +from flask import Flask, Response, render_template, request +from PIL import Image +from pyproj import Transformer +from rasterio.enums import Resampling +from sklearn.linear_model import LogisticRegression + +EPSG_3857 = "EPSG:3857" + + +def find_geotiff(layer_dir: Path) -> Path | None: + """Find geotiff.tif inside a layer directory (inside the bandset subdir).""" + for subdir in layer_dir.iterdir(): + if subdir.is_dir() and subdir.name != "completed": + tif = subdir / "geotiff.tif" + if tif.exists(): + return tif + return None + + +def reproject_to_webmercator( + data: np.ndarray, + src_crs: str, + src_transform: rasterio.transform.Affine, + src_shape: tuple, +) -> tuple[np.ndarray, rasterio.transform.Affine, tuple[int, int]]: + """Reproject a (C, H, W) or (H, W) array to EPSG:3857. + + Returns (reprojected_data, dst_transform, (dst_height, dst_width)). + """ + if data.ndim == 2: + data = data[np.newaxis, ...] + squeeze = True + else: + squeeze = False + + dst_transform, dst_width, dst_height = rasterio.warp.calculate_default_transform( + src_crs, + EPSG_3857, + src_shape[1], + src_shape[0], + *rasterio.transform.array_bounds(src_shape[0], src_shape[1], src_transform), + ) + + dst = np.zeros((data.shape[0], dst_height, dst_width), dtype=data.dtype) + for i in range(data.shape[0]): + rasterio.warp.reproject( + source=data[i], + destination=dst[i], + src_transform=src_transform, + src_crs=src_crs, + dst_transform=dst_transform, + dst_crs=EPSG_3857, + resampling=Resampling.nearest, + ) + + if squeeze: + dst = dst[0] + return dst, dst_transform, (dst_height, dst_width) + + +def webmercator_bounds( + transform: rasterio.transform.Affine, shape: tuple[int, int] +) -> tuple[float, float, float, float]: + """Get (left, bottom, right, top) in EPSG:3857 meters from a raster grid.""" + return rasterio.transform.array_bounds(shape[0], shape[1], transform) + + +def webmercator_bounds_to_latlon( + transform: rasterio.transform.Affine, shape: tuple[int, int] +) -> list[list[float]]: + """Get [[lat_min, lon_min], [lat_max, lon_max]] from a Web Mercator raster.""" + bounds = webmercator_bounds(transform, shape) + transformer = Transformer.from_crs(EPSG_3857, "EPSG:4326", always_xy=True) + lon_min, lat_min = transformer.transform(bounds[0], bounds[1]) + lon_max, lat_max = transformer.transform(bounds[2], bounds[3]) + return [[lat_min, lon_min], [lat_max, lon_max]] + + +def latlon_to_pixel( + lat: float, lon: float, crs: str, transform: rasterio.transform.Affine, shape: tuple +) -> tuple[int, int]: + """Convert lat/lon to pixel row, col using the raster's transform.""" + proj_transformer = Transformer.from_crs("EPSG:4326", crs, always_xy=True) + x, y = proj_transformer.transform(lon, lat) + row, col = rasterio.transform.rowcol(transform, x, y) + row = max(0, min(int(row), shape[0] - 1)) + col = max(0, min(int(col), shape[1] - 1)) + return row, col + + +def render_rgb_png(data: np.ndarray, bands: tuple[int, int, int] = (0, 1, 2)) -> bytes: + """Render selected bands as a stretched RGB PNG.""" + rgb = np.stack([data[b] for b in bands], axis=-1).astype(np.float32) + for i in range(3): + band = rgb[:, :, i] + valid = band[np.isfinite(band) & (band != 0)] + if len(valid) == 0: + continue + lo = np.percentile(valid, 2) + hi = np.percentile(valid, 98) + if hi - lo < 1e-6: + hi = lo + 1 + rgb[:, :, i] = (band - lo) / (hi - lo) + rgb = np.clip(rgb * 255, 0, 255).astype(np.uint8) + img = Image.fromarray(rgb, mode="RGB") + buf = io.BytesIO() + img.save(buf, format="PNG") + return buf.getvalue() + + +def render_rgb_png_with_alpha( + data: np.ndarray, bands: tuple[int, int, int] = (0, 1, 2) +) -> bytes: + """Render selected bands as stretched RGBA PNG (nodata pixels transparent).""" + rgb = np.stack([data[b] for b in bands], axis=-1).astype(np.float32) + # Mask where all selected bands are zero (nodata from reprojection) + nodata_mask = np.all(rgb == 0, axis=-1) + + for i in range(3): + band = rgb[:, :, i] + valid = band[np.isfinite(band) & (~nodata_mask)] + if len(valid) == 0: + continue + lo = np.percentile(valid, 2) + hi = np.percentile(valid, 98) + if hi - lo < 1e-6: + hi = lo + 1 + rgb[:, :, i] = (band - lo) / (hi - lo) + + rgb = np.clip(rgb * 255, 0, 255).astype(np.uint8) + alpha = np.where(nodata_mask, 0, 255).astype(np.uint8) + rgba = np.dstack([rgb, alpha]) + img = Image.fromarray(rgba, mode="RGBA") + buf = io.BytesIO() + img.save(buf, format="PNG") + return buf.getvalue() + + +def compute_cosine_similarity( + embeddings: np.ndarray, ref_vector: np.ndarray +) -> np.ndarray: + """Cosine similarity between all pixels and a reference vector. + + Returns (H, W) float array in [-1, 1]. + """ + ref_norm = np.linalg.norm(ref_vector) + if ref_norm < 1e-8: + return np.zeros(embeddings.shape[1:], dtype=np.float32) + ref_unit = ref_vector / ref_norm + pixel_norms = np.linalg.norm(embeddings, axis=0) + valid = pixel_norms > 1e-8 + dot = np.tensordot(ref_unit, embeddings, axes=(0, 0)) + similarity = np.zeros(embeddings.shape[1:], dtype=np.float32) + similarity[valid] = dot[valid] / pixel_norms[valid] + return np.clip(similarity, -1.0, 1.0) + + +def compute_knn(embeddings: np.ndarray, points: list[dict], k: int) -> np.ndarray: + """KNN classification: positive_votes / k for each pixel. + + Returns (H, W) float array in [0, 1]. + """ + C, H, W = embeddings.shape + flat = embeddings.reshape(C, -1).T # (N, C) + flat_norms = np.linalg.norm(flat, axis=1, keepdims=True) + flat_norms[flat_norms < 1e-8] = 1.0 + flat_unit = flat / flat_norms + + n_points = len(points) + labels = np.array([1.0 if p["label"] == "positive" else 0.0 for p in points]) + point_vecs = np.stack([p["vector"] for p in points]) + point_norms = np.linalg.norm(point_vecs, axis=1, keepdims=True) + point_norms[point_norms < 1e-8] = 1.0 + point_unit = point_vecs / point_norms + + sims = flat_unit @ point_unit.T # (N, n_points) + + effective_k = min(k, n_points) + if effective_k == n_points: + top_k_indices = np.broadcast_to(np.arange(n_points), (sims.shape[0], n_points)) + else: + top_k_indices = np.argpartition(sims, -effective_k, axis=1)[:, -effective_k:] + + top_k_labels = labels[top_k_indices] + score = top_k_labels.mean(axis=1) + return score.reshape(H, W).astype(np.float32) + + +def compute_linear_probe(embeddings: np.ndarray, points: list[dict]) -> np.ndarray: + """Train a logistic regression on labeled points and predict per-pixel probability. + + Returns (H, W) float array in [0, 1] giving probability of the positive class. + """ + C, H, W = embeddings.shape + + X = np.stack([p["vector"] for p in points]).astype(np.float32) + y = np.array([1 if p["label"] == "positive" else 0 for p in points], dtype=np.int64) + + # max_iter is generous; with ~10-30 samples LBFGS converges in a handful of steps. + clf = LogisticRegression(C=1.0, solver="lbfgs", max_iter=1000) + clf.fit(X, y) + + pos_idx = int(np.where(clf.classes_ == 1)[0][0]) + flat = embeddings.reshape(C, -1).T.astype(np.float32) + probs = clf.predict_proba(flat)[:, pos_idx] + return probs.reshape(H, W).astype(np.float32) + + +def similarity_to_png(similarity: np.ndarray, mode: str) -> bytes: + """Encode similarity as 8-bit grayscale PNG. + + Cosine: [-1, 1] -> [0, 255] + KNN / linear_probe: [0, 1] -> [0, 255] + """ + if mode == "cosine": + img_data = ((similarity + 1.0) / 2.0 * 255.0).astype(np.uint8) + else: + img_data = (similarity * 255.0).astype(np.uint8) + img = Image.fromarray(img_data, mode="L") + buf = io.BytesIO() + img.save(buf, format="PNG") + return buf.getvalue() + + +def load_dataset(dataset_path: Path, embedding_layer: str) -> dict: + """Load dataset: embeddings into memory, image layer paths stored for on-demand serving.""" + windows_dir = dataset_path / "windows" + dataset: dict = {"windows": {}, "embedding_layer": embedding_layer} + + for group_dir in sorted(windows_dir.iterdir()): + if not group_dir.is_dir(): + continue + for window_dir in sorted(group_dir.iterdir()): + if not window_dir.is_dir(): + continue + meta_path = window_dir / "metadata.json" + if not meta_path.exists(): + continue + + with open(meta_path) as f: + metadata = json.load(f) + + layers_dir = window_dir / "layers" + if not layers_dir.exists(): + continue + + window_key = f"{group_dir.name}/{window_dir.name}" + window_info = { + "metadata": metadata, + "embeddings": None, + "image_layers": {}, + "path": window_dir, + } + + for layer_dir in sorted(layers_dir.iterdir()): + if not layer_dir.is_dir(): + continue + layer_name = layer_dir.name + base_name = layer_name.split(".")[0] + + tif = find_geotiff(layer_dir) + if tif is None: + continue + + if base_name == embedding_layer: + print(f" Loading embedding: {window_key}/{layer_name}") + with rasterio.open(tif) as src: + data = src.read().astype(np.float32) + src_crs = str(src.crs) + src_transform = src.transform + src_shape = (src.height, src.width) + window_info["embeddings"] = data + window_info["embedding_crs"] = src_crs + window_info["embedding_transform"] = src_transform + window_info["embedding_shape"] = src_shape + print(f" shape: {data.shape}") + + # Pre-compute Web Mercator transform for serving + print(" computing Web Mercator reprojection info...") + _, wm_transform, wm_shape = reproject_to_webmercator( + data[:1], src_crs, src_transform, src_shape + ) + window_info["wm_transform"] = wm_transform + window_info["wm_shape"] = wm_shape + wm_bounds = webmercator_bounds(wm_transform, wm_shape) + window_info["wm_bounds"] = wm_bounds # (left, bottom, right, top) + print(f" Web Mercator shape: {wm_shape}") + else: + # Compute WM bounds for this image layer at startup + with rasterio.open(tif) as src: + layer_crs = str(src.crs) + layer_transform = src.transform + layer_shape = (src.height, src.width) + wm_dst_transform, wm_dst_width, wm_dst_height = ( + rasterio.warp.calculate_default_transform( + layer_crs, + EPSG_3857, + layer_shape[1], + layer_shape[0], + *rasterio.transform.array_bounds( + layer_shape[0], layer_shape[1], layer_transform + ), + ) + ) + layer_wm_bounds = webmercator_bounds( + wm_dst_transform, (wm_dst_height, wm_dst_width) + ) + if base_name not in window_info["image_layers"]: + window_info["image_layers"][base_name] = [] + window_info["image_layers"][base_name].append( + { + "name": layer_name, + "path": str(tif), + "wm_bounds": layer_wm_bounds, + } + ) + + if window_info["embeddings"] is None: + print( + f" WARNING: no '{embedding_layer}' layer found in {window_key}, skipping" + ) + continue + + dataset["windows"][window_key] = window_info + print(f"Loaded window: {window_key}") + print(f" mercator bounds: {window_info['wm_bounds']}") + print(f" embedding shape: {window_info['embeddings'].shape}") + print(f" image layers: {list(window_info['image_layers'].keys())}") + + return dataset + + +def reproject_single_band_to_wm( + data: np.ndarray, + src_crs: str, + src_transform: rasterio.transform.Affine, + src_shape: tuple, + wm_transform: rasterio.transform.Affine, + wm_shape: tuple, +) -> np.ndarray: + """Reproject a single (H, W) array to a pre-computed Web Mercator grid.""" + dst = np.zeros(wm_shape, dtype=data.dtype) + rasterio.warp.reproject( + source=data, + destination=dst, + src_transform=src_transform, + src_crs=src_crs, + dst_transform=wm_transform, + dst_crs=EPSG_3857, + resampling=Resampling.nearest, + ) + return dst + + +def create_app(dataset_path: Path, embedding_layer: str = "embeddings") -> Flask: + """Create and configure the Flask application.""" + app_dir = Path(__file__).parent + app = Flask( + __name__, + template_folder=str(app_dir / "templates"), + static_folder=str(app_dir / "static"), + ) + + print(f"Loading dataset from {dataset_path}...") + print(f"Embedding layer: {embedding_layer}") + dataset = load_dataset(dataset_path, embedding_layer) + print(f"Loaded {len(dataset['windows'])} window(s)") + + @app.route("/") + def index() -> str: + windows_info: dict = {} + for key, win in dataset["windows"].items(): + wm = win["wm_bounds"] + windows_info[key] = { + "mercator_bounds": [wm[0], wm[1], wm[2], wm[3]], + "image_layers": { + name: [ + { + "name": g["name"], + "mercator_bounds": list(g["wm_bounds"]), + } + for g in groups + ] + for name, groups in win["image_layers"].items() + }, + } + return render_template( + "index.html", + windows=json.dumps(windows_info), + embedding_layer=embedding_layer, + ) + + @app.route("/api/image/") + def serve_image(layer_path: str) -> Response | tuple[str, int]: + """Serve a layer as RGB PNG reprojected to Web Mercator. + + layer_path: window_key/layer_name (e.g. default/default/sentinel2_l2a) + Query params: + bands: comma-separated 0-based band indices (default: 0,1,2) + """ + parts = layer_path.split("/") + if len(parts) < 3: + return "Invalid path", 400 + window_key = f"{parts[0]}/{parts[1]}" + layer_name = "/".join(parts[2:]) + + win = dataset["windows"].get(window_key) + if win is None: + return "Window not found", 404 + + bands_param = request.args.get("bands", "0,1,2") + bands = tuple(int(b) for b in bands_param.split(",")) + + # Serve embedding RGB from memory (reproject to WM) + if layer_name == dataset["embedding_layer"]: + data = win["embeddings"] + selected = np.stack([data[b] for b in bands]) + reprojected, wm_transform, wm_shape = reproject_to_webmercator( + selected, + win["embedding_crs"], + win["embedding_transform"], + win["embedding_shape"], + ) + png = render_rgb_png_with_alpha(reprojected, (0, 1, 2)) + left, bottom, right, top = webmercator_bounds(wm_transform, wm_shape) + return Response( + png, + mimetype="image/png", + headers={ + "X-Mercator-Bounds": f"{left},{bottom},{right},{top}", + "Cache-Control": "public, max-age=3600", + }, + ) + + # Find in image layers + tif_path = None + for base_name, groups in win["image_layers"].items(): + for g in groups: + if g["name"] == layer_name: + tif_path = g["path"] + break + if tif_path: + break + + if tif_path is None: + return "Layer not found", 404 + + with rasterio.open(tif_path) as src: + band_indices = [b + 1 for b in bands] # rasterio is 1-indexed + data = src.read(indexes=band_indices).astype(np.float32) + src_crs = str(src.crs) + src_transform = src.transform + src_shape = (src.height, src.width) + + reprojected, wm_transform, wm_shape = reproject_to_webmercator( + data, src_crs, src_transform, src_shape + ) + png = render_rgb_png_with_alpha(reprojected, (0, 1, 2)) + left, bottom, right, top = webmercator_bounds(wm_transform, wm_shape) + return Response( + png, + mimetype="image/png", + headers={ + "X-Mercator-Bounds": f"{left},{bottom},{right},{top}", + "Cache-Control": "public, max-age=3600", + }, + ) + + @app.route("/api/similarity", methods=["POST"]) + def compute_similarity_route() -> Response | tuple[str, int]: + """Compute similarity in native CRS, reproject result to Web Mercator. + + Request JSON: + mode: "cosine" or "knn" + points: list of {lat, lon, label} + k: int (for knn mode) + window: window key + """ + body = request.get_json() + window_key = body.get("window") + mode = body.get("mode", "cosine") + points = body.get("points", []) + k = body.get("k", 3) + + win = dataset["windows"].get(window_key) + if win is None: + return "Window not found", 404 + + embeddings = win["embeddings"] + if not points: + return "No points provided", 400 + + crs = win["embedding_crs"] + transform = win["embedding_transform"] + shape = embeddings.shape[1:] + + if mode == "cosine": + pos_points = [p for p in points if p.get("label") == "positive"] + if not pos_points: + return "No positive points for cosine mode", 400 + p = pos_points[0] + row, col = latlon_to_pixel(p["lat"], p["lon"], crs, transform, shape) + ref_vector = embeddings[:, row, col] + similarity = compute_cosine_similarity(embeddings, ref_vector) + elif mode == "linear_probe": + labeled_points = [] + for p in points: + row, col = latlon_to_pixel(p["lat"], p["lon"], crs, transform, shape) + labeled_points.append( + { + "vector": embeddings[:, row, col].copy(), + "label": p.get("label", "positive"), + } + ) + n_pos = sum(1 for p in labeled_points if p["label"] == "positive") + n_neg = len(labeled_points) - n_pos + if n_pos < 1 or n_neg < 1: + return ( + "Linear probe requires at least one positive and one negative point", + 400, + ) + similarity = compute_linear_probe(embeddings, labeled_points) + else: + labeled_points = [] + for p in points: + row, col = latlon_to_pixel(p["lat"], p["lon"], crs, transform, shape) + labeled_points.append( + { + "vector": embeddings[:, row, col].copy(), + "label": p.get("label", "positive"), + } + ) + similarity = compute_knn(embeddings, labeled_points, k) + + # Reproject similarity to Web Mercator + sim_wm = reproject_single_band_to_wm( + similarity, crs, transform, shape, win["wm_transform"], win["wm_shape"] + ) + + png = similarity_to_png(sim_wm, mode) + left, bottom, right, top = webmercator_bounds( + win["wm_transform"], win["wm_shape"] + ) + return Response( + png, + mimetype="image/png", + headers={ + "X-Mercator-Bounds": f"{left},{bottom},{right},{top}", + "X-Mode": mode, + }, + ) + + return app + + +def main() -> None: + """CLI entry point for the embedding explorer web app.""" + parser = argparse.ArgumentParser(description="Embedding similarity explorer") + parser.add_argument("--dataset-path", type=Path, required=True) + parser.add_argument( + "--embedding-layer", + default="embeddings", + help="Layer name to load as embeddings (default: 'embeddings')", + ) + parser.add_argument("--port", type=int, default=5000) + parser.add_argument("--host", default="0.0.0.0") + args = parser.parse_args() + + app = create_app(args.dataset_path, embedding_layer=args.embedding_layer) + app.run(host=args.host, port=args.port, debug=False) + + +if __name__ == "__main__": + main() diff --git a/rslp/embedding_explorer/config.json b/rslp/embedding_explorer/config.json new file mode 100644 index 000000000..34ecedc80 --- /dev/null +++ b/rslp/embedding_explorer/config.json @@ -0,0 +1,49 @@ +{ + "layers": { + "embeddings": { + "band_sets": [ + { + "dtype": "float32", + "num_bands": 768 + } + ], + "type": "raster" + }, + "sentinel2_l2a": { + "band_sets": [ + { + "bands": [ + "B01", + "B02", + "B03", + "B04", + "B05", + "B06", + "B07", + "B08", + "B8A", + "B09", + "B11", + "B12" + ], + "dtype": "uint16" + } + ], + "data_source": { + "class_path": "rslearn.data_sources.planetary_computer.Sentinel2", + "ingest": false, + "init_args": { + "cache_dir": "cache/planetary_computer", + "harmonize": true, + "sort_by": "eo:cloud_cover" + }, + "query_config": { + "max_matches": 12, + "period_duration": "30d", + "space_mode": "PER_PERIOD_MOSAIC" + } + }, + "type": "raster" + } + } +} diff --git a/rslp/embedding_explorer/config.yaml b/rslp/embedding_explorer/config.yaml new file mode 100644 index 000000000..021a37945 --- /dev/null +++ b/rslp/embedding_explorer/config.yaml @@ -0,0 +1,50 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.singletask.SingleTaskModel + init_args: + encoder: + - class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth + init_args: + model_id: OLMOEARTH_V1_BASE + patch_size: 4 + decoder: + - class_path: rslearn.train.tasks.embedding.EmbeddingHead + optimizer: + class_path: rslearn.train.optimizer.AdamW +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: ${DATASET_PATH} + inputs: + sentinel2_l2a: + data_type: "raster" + layers: ["sentinel2_l2a", "sentinel2_l2a.1", "sentinel2_l2a.2", "sentinel2_l2a.3"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + load_all_layers: true + task: + class_path: rslearn.train.tasks.embedding.EmbeddingTask + batch_size: 8 + num_workers: 32 + predict_config: + transforms: + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + load_all_crops: true + crop_size: 64 + overlap_pixels: 32 +trainer: + callbacks: + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + output_layer: embeddings + merger: + class_path: rslearn.train.prediction_writer.RasterMerger + init_args: + overlap_pixels: 8 + downsample_factor: 4 diff --git a/rslp/embedding_explorer/static/app.js b/rslp/embedding_explorer/static/app.js new file mode 100644 index 000000000..e3484be5a --- /dev/null +++ b/rslp/embedding_explorer/static/app.js @@ -0,0 +1,503 @@ +(function () { + "use strict"; + + // State + let currentWindow = null; + let points = []; + let similarityImg = null; + let similarityBounds = null; + let similarityMode = null; + let overlayLayer = null; + let activeOverlay = null; + let activeOverlayKey = null; + + // Map setup + const map = L.map("map", { zoomControl: true }).setView([0, 0], 3); + + // Bing Maps tile layer (uses quadkey URLs). + const BingTileLayer = L.TileLayer.extend({ + getTileUrl: function (coords) { + const x = coords.x; + const y = coords.y; + const z = coords.z; + let quadkey = ""; + for (let i = z; i > 0; i--) { + let digit = 0; + const mask = 1 << (i - 1); + if ((x & mask) !== 0) digit += 1; + if ((y & mask) !== 0) digit += 2; + quadkey += digit.toString(); + } + const sub = (x + y) % 4; + return `https://ecn.t${sub}.tiles.virtualearth.net/tiles/a${quadkey}.jpeg?g=761`; + }, + }); + + function makeOsmLayer() { + return L.tileLayer("https://{s}.tile.openstreetmap.org/{z}/{x}/{y}.png", { + attribution: "© OSM contributors", + maxZoom: 19, + }); + } + + function makeBingLayer() { + return new BingTileLayer("", { + attribution: "© Microsoft Bing Maps", + maxZoom: 19, + }); + } + + // DOM refs + const windowSelect = document.getElementById("window-select"); + const layerToggles = document.getElementById("layer-toggles"); + const pointsList = document.getElementById("points-list"); + const clearBtn = document.getElementById("clear-points"); + const thresholdSlider = document.getElementById("threshold-slider"); + const thresholdValue = document.getElementById("threshold-value"); + const kInput = document.getElementById("k-input"); + const applyProbeBtn = document.getElementById("apply-probe"); + const loadingSpinner = document.getElementById("loading-spinner"); + const linearProbeGroup = document.getElementById("linear-probe-group"); + const kGroup = document.getElementById("k-group"); + const opacitySlider = document.getElementById("opacity-slider"); + const opacityValue = document.getElementById("opacity-value"); + const thresholdColor = document.getElementById("threshold-color"); + + // Populate window dropdown + Object.keys(WINDOWS_DATA).forEach((key) => { + const opt = document.createElement("option"); + opt.value = key; + opt.textContent = key; + windowSelect.appendChild(opt); + }); + + windowSelect.addEventListener("change", () => { + selectWindow(windowSelect.value); + }); + + function mercatorBoundsToLatLng(mb) { + const sw = L.CRS.EPSG3857.unproject(L.point(mb[0], mb[1])); + const ne = L.CRS.EPSG3857.unproject(L.point(mb[2], mb[3])); + return L.latLngBounds(sw, ne); + } + + const layerRadios = {}; + + function selectWindow(key) { + currentWindow = key; + const win = WINDOWS_DATA[key]; + points = []; + clearSimilarityOverlay(); + renderPoints(); + + const bounds = mercatorBoundsToLatLng(win.mercator_bounds); + map.flyToBounds(bounds, { padding: [20, 20] }); + + const previousKey = activeOverlayKey; + removeActiveOverlay(); + layerToggles.innerHTML = ""; + Object.keys(layerRadios).forEach((k) => delete layerRadios[k]); + + addLayerRadio("OpenStreetMap", { type: "tile", source: "osm" }); + addLayerRadio("Bing Maps", { type: "tile", source: "bing" }); + + addLayerRadio(EMBEDDING_LAYER + " (RGB)", { + type: "image", + url: `/api/image/${key}/${EMBEDDING_LAYER}?bands=0,1,2`, + mercatorBounds: win.mercator_bounds, + }); + + Object.entries(win.image_layers).forEach(([baseName, groups]) => { + groups.forEach((layerEntry) => { + const bands = baseName.includes("sentinel2") ? "3,2,1" : "0,1,2"; + addLayerRadio(layerEntry.name, { + type: "image", + url: `/api/image/${key}/${layerEntry.name}?bands=${bands}`, + mercatorBounds: layerEntry.mercator_bounds, + }); + }); + }); + + // Restore previous selection if still available, otherwise default to OSM. + const targetKey = + previousKey && layerRadios[previousKey] + ? previousKey + : "OpenStreetMap"; + selectLayerRadio(targetKey); + } + + function selectLayerRadio(name) { + const entry = layerRadios[name]; + if (!entry) return; + entry.input.checked = true; + applyOverlay(name, entry.spec); + } + + function addLayerRadio(name, spec) { + const div = document.createElement("div"); + div.className = "layer-toggle"; + const r = document.createElement("input"); + r.type = "radio"; + r.name = "layer-radio"; + r.id = `layer-${name}`; + r.addEventListener("change", () => { + if (r.checked) applyOverlay(name, spec); + }); + const label = document.createElement("label"); + label.htmlFor = r.id; + label.textContent = name; + div.appendChild(r); + div.appendChild(label); + layerToggles.appendChild(div); + layerRadios[name] = { input: r, spec: spec }; + } + + function applyOverlay(name, spec) { + removeActiveOverlay(); + if (spec.type === "tile") { + activeOverlay = (spec.source === "bing" ? makeBingLayer() : makeOsmLayer()).addTo(map); + } else if (spec.type === "image") { + const bounds = mercatorBoundsToLatLng(spec.mercatorBounds); + activeOverlay = L.imageOverlay(spec.url, bounds, { opacity: 1.0 }).addTo(map); + } + activeOverlayKey = name; + // Keep similarity overlay above the new image overlay. + if (overlayLayer && overlayLayer.bringToFront) { + overlayLayer.bringToFront(); + } + } + + function removeActiveOverlay() { + if (activeOverlay) { + map.removeLayer(activeOverlay); + activeOverlay = null; + } + activeOverlayKey = null; + } + + // Map click -> add point + map.on("click", (e) => { + if (!currentWindow) return; + points.push({ lat: e.latlng.lat, lon: e.latlng.lng, label: "positive" }); + renderPoints(); + maybeAutoRefresh(); + }); + + map.on("contextmenu", (e) => { + if (!currentWindow) return; + e.originalEvent.preventDefault(); + points.push({ lat: e.latlng.lat, lon: e.latlng.lng, label: "negative" }); + renderPoints(); + maybeAutoRefresh(); + }); + + // Points rendering + let pointMarkers = []; + + function renderPoints() { + pointMarkers.forEach((m) => map.removeLayer(m)); + pointMarkers = []; + pointsList.innerHTML = ""; + + points.forEach((p, i) => { + const color = p.label === "positive" ? "#22c55e" : "#ef4444"; + const marker = L.circleMarker([p.lat, p.lon], { + radius: 8, + fillColor: color, + color: "#fff", + weight: 2, + fillOpacity: 0.9, + }).addTo(map); + pointMarkers.push(marker); + + const div = document.createElement("div"); + div.className = "point-item"; + div.innerHTML = ` + ${p.label === "positive" ? "+" : "-"} + ${p.lat.toFixed(4)}, ${p.lon.toFixed(4)} + + + `; + pointsList.appendChild(div); + }); + + pointsList.querySelectorAll(".point-toggle").forEach((btn) => { + btn.addEventListener("click", () => { + const idx = parseInt(btn.dataset.idx); + points[idx].label = points[idx].label === "positive" ? "negative" : "positive"; + renderPoints(); + maybeAutoRefresh(); + }); + }); + pointsList.querySelectorAll(".point-delete").forEach((btn) => { + btn.addEventListener("click", () => { + const idx = parseInt(btn.dataset.idx); + points.splice(idx, 1); + renderPoints(); + if (points.length > 0) { + maybeAutoRefresh(); + } else { + clearSimilarityOverlay(); + } + }); + }); + + updateApplyButton(); + } + + clearBtn.addEventListener("click", () => { + points = []; + renderPoints(); + clearSimilarityOverlay(); + }); + + applyProbeBtn.addEventListener("click", () => { + requestSimilarity(); + }); + + // Similarity + function getMode() { + return document.querySelector('input[name="mode"]:checked').value; + } + + function getDisplay() { + return document.querySelector('input[name="display"]:checked').value; + } + + function maybeAutoRefresh() { + // Linear probe is only refreshed via the Apply button. + if (getMode() === "linear_probe") return; + if (points.length === 0) { + clearSimilarityOverlay(); + return; + } + requestSimilarity(); + } + + function requestSimilarity() { + if (!currentWindow || points.length === 0) return; + + const mode = getMode(); + const k = parseInt(kInput.value) || 3; + + if (mode === "cosine" && !points.some((p) => p.label === "positive")) return; + if (mode === "linear_probe") { + const hasPos = points.some((p) => p.label === "positive"); + const hasNeg = points.some((p) => p.label === "negative"); + if (!hasPos || !hasNeg) return; + } + + const win = WINDOWS_DATA[currentWindow]; + similarityBounds = mercatorBoundsToLatLng(win.mercator_bounds); + similarityMode = mode; + + const isProbe = mode === "linear_probe"; + if (isProbe) { + loadingSpinner.style.display = "flex"; + applyProbeBtn.disabled = true; + } + + fetch("/api/similarity", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + window: currentWindow, + mode: mode, + points: points, + k: k, + }), + }) + .then((resp) => { + if (!resp.ok) throw new Error(resp.statusText); + return resp.blob(); + }) + .then((blob) => { + const url = URL.createObjectURL(blob); + const img = new window.Image(); + img.onload = () => { + similarityImg = img; + renderSimilarityOverlay(); + if (isProbe) { + loadingSpinner.style.display = "none"; + updateApplyButton(); + } + }; + img.src = url; + }) + .catch((err) => { + console.error("Similarity request failed:", err); + if (isProbe) { + loadingSpinner.style.display = "none"; + updateApplyButton(); + } + }); + } + + function updateApplyButton() { + const hasPos = points.some((p) => p.label === "positive"); + const hasNeg = points.some((p) => p.label === "negative"); + applyProbeBtn.disabled = !(hasPos && hasNeg); + } + + function renderSimilarityOverlay() { + if (!similarityImg || !similarityBounds) return; + + const canvas = document.createElement("canvas"); + canvas.width = similarityImg.width; + canvas.height = similarityImg.height; + const ctx = canvas.getContext("2d"); + + ctx.drawImage(similarityImg, 0, 0); + const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height); + const pixels = imageData.data; + + const display = getDisplay(); + const thresholdRaw = parseInt(thresholdSlider.value); + const opacity = parseInt(opacitySlider.value) / 100.0; + const [tr, tg, tb] = thresholdColor.value.split(",").map((v) => parseInt(v)); + + const output = ctx.createImageData(canvas.width, canvas.height); + const out = output.data; + + for (let i = 0; i < pixels.length; i += 4) { + const gray = pixels[i]; + + if (display === "threshold") { + if (gray >= thresholdRaw) { + out[i] = tr; + out[i + 1] = tg; + out[i + 2] = tb; + out[i + 3] = Math.round(255 * opacity); + } else { + out[i + 3] = 0; + } + } else { + // Gradient: blue -> yellow -> red + const t = gray / 255.0; + let baseAlpha; + if (t < 0.5) { + const s = t * 2; + out[i] = Math.round(s * 255); + out[i + 1] = Math.round(s * 255); + out[i + 2] = Math.round((1 - s) * 255); + baseAlpha = 50 + t * 300; + } else { + const s = (t - 0.5) * 2; + out[i] = 255; + out[i + 1] = Math.round((1 - s) * 255); + out[i + 2] = 0; + baseAlpha = 100 + s * 155; + } + out[i + 3] = Math.min(255, Math.round(baseAlpha * opacity)); + } + } + + ctx.putImageData(output, 0, 0); + const dataUrl = canvas.toDataURL("image/png"); + + if (overlayLayer) { + map.removeLayer(overlayLayer); + } + overlayLayer = L.imageOverlay(dataUrl, similarityBounds, { + opacity: 1.0, + interactive: false, + }); + if (showOverlayCheckbox.checked) { + overlayLayer.addTo(map); + } + } + + function clearSimilarityOverlay() { + similarityImg = null; + similarityBounds = null; + if (overlayLayer) { + map.removeLayer(overlayLayer); + overlayLayer = null; + } + } + + // Show/hide overlay toggle + const showOverlayCheckbox = document.getElementById("show-overlay"); + showOverlayCheckbox.addEventListener("change", () => { + if (overlayLayer) { + if (showOverlayCheckbox.checked) { + overlayLayer.addTo(map); + } else { + map.removeLayer(overlayLayer); + } + } + }); + + // Re-render on display setting changes (no network request) + thresholdSlider.addEventListener("input", () => { + const raw = parseInt(thresholdSlider.value); + let displayVal; + if (similarityMode === "cosine") { + displayVal = ((raw / 255.0) * 2.0 - 1.0).toFixed(3); + } else { + displayVal = (raw / 255.0).toFixed(3); + } + thresholdValue.textContent = displayVal; + renderSimilarityOverlay(); + }); + + function syncDisplayUi() { + const isThreshold = getDisplay() === "threshold"; + document.getElementById("threshold-color-group").style.display = isThreshold + ? "block" + : "none"; + } + + document.querySelectorAll('input[name="display"]').forEach((radio) => { + radio.addEventListener("change", () => { + syncDisplayUi(); + renderSimilarityOverlay(); + }); + }); + + thresholdColor.addEventListener("change", renderSimilarityOverlay); + + opacitySlider.addEventListener("input", () => { + opacityValue.textContent = `${opacitySlider.value}%`; + renderSimilarityOverlay(); + }); + + function syncModeUi() { + const mode = getMode(); + kGroup.style.display = mode === "knn" ? "block" : "none"; + linearProbeGroup.style.display = mode === "linear_probe" ? "block" : "none"; + if (mode !== "linear_probe") { + loadingSpinner.style.display = "none"; + } + } + + document.querySelectorAll('input[name="mode"]').forEach((radio) => { + radio.addEventListener("change", () => { + syncModeUi(); + const mode = getMode(); + if (mode === "linear_probe") { + // Switching into linear probe should not auto-refresh; clear any + // stale overlay from the previous mode so the user re-applies. + clearSimilarityOverlay(); + updateApplyButton(); + } else if (points.length > 0) { + requestSimilarity(); + } + }); + }); + + kInput.addEventListener("change", () => { + if (getMode() === "knn" && points.length > 0) requestSimilarity(); + }); + + // Initial state + syncModeUi(); + syncDisplayUi(); + loadingSpinner.style.display = "none"; + updateApplyButton(); + + if (Object.keys(WINDOWS_DATA).length > 0) { + windowSelect.value = Object.keys(WINDOWS_DATA)[0]; + selectWindow(windowSelect.value); + } +})(); diff --git a/rslp/embedding_explorer/static/style.css b/rslp/embedding_explorer/static/style.css new file mode 100644 index 000000000..14ee61aea --- /dev/null +++ b/rslp/embedding_explorer/static/style.css @@ -0,0 +1,228 @@ +* { + margin: 0; + padding: 0; + box-sizing: border-box; +} + +body { + font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif; + display: flex; + height: 100vh; + overflow: hidden; +} + +#sidebar { + width: 320px; + min-width: 320px; + height: 100vh; + overflow-y: auto; + padding: 16px; + background: #1a1a2e; + color: #e0e0e0; + border-right: 1px solid #333; +} + +#sidebar h2 { + margin-bottom: 16px; + font-size: 18px; + color: #fff; +} + +#sidebar h3 { + margin-bottom: 8px; + font-size: 14px; + color: #ccc; +} + +#sidebar hr { + border: none; + border-top: 1px solid #333; + margin: 12px 0; +} + +#map { + flex: 1; + height: 100vh; +} + +.control-group { + margin-bottom: 12px; +} + +.control-group > label { + display: block; + font-size: 12px; + color: #aaa; + margin-bottom: 4px; + text-transform: uppercase; + letter-spacing: 0.5px; +} + +select, input[type="number"] { + width: 100%; + padding: 6px 8px; + background: #16213e; + border: 1px solid #444; + color: #e0e0e0; + border-radius: 4px; + font-size: 13px; +} + +select:focus, input:focus { + outline: none; + border-color: #667eea; +} + +input[type="range"] { + width: 100%; + margin-top: 4px; +} + +.radio-group { + display: flex; + gap: 12px; +} + +.radio-group label { + font-size: 13px; + cursor: pointer; + display: flex; + align-items: center; + gap: 4px; +} + +.layer-toggle { + display: flex; + align-items: center; + gap: 6px; + margin-bottom: 4px; + font-size: 12px; +} + +.layer-toggle label { + cursor: pointer; + word-break: break-all; +} + +.checkbox-label { + font-size: 13px; + cursor: pointer; + display: flex; + align-items: center; + gap: 6px; +} + +.hint { + font-size: 11px; + color: #888; + margin-bottom: 8px; +} + +#points-list { + max-height: 200px; + overflow-y: auto; + margin-bottom: 8px; +} + +.point-item { + display: flex; + align-items: center; + gap: 6px; + padding: 4px 0; + font-size: 12px; + border-bottom: 1px solid #222; +} + +.point-label { + display: inline-block; + width: 18px; + height: 18px; + line-height: 18px; + text-align: center; + border-radius: 50%; + font-weight: bold; + font-size: 14px; +} + +.point-label.positive { + background: #22c55e; + color: #fff; +} + +.point-label.negative { + background: #ef4444; + color: #fff; +} + +.point-coords { + flex: 1; + color: #999; + font-family: monospace; + font-size: 11px; +} + +.point-toggle, .point-delete { + background: none; + border: 1px solid #555; + color: #ccc; + padding: 2px 6px; + border-radius: 3px; + cursor: pointer; + font-size: 11px; +} + +.point-toggle:hover, .point-delete:hover { + background: #333; + color: #fff; +} + +.btn { + display: block; + width: 100%; + padding: 8px; + background: #16213e; + border: 1px solid #444; + color: #e0e0e0; + border-radius: 4px; + cursor: pointer; + font-size: 13px; +} + +.btn:hover { + background: #1a3a5c; + border-color: #667eea; +} + +.btn:disabled { + opacity: 0.5; + cursor: not-allowed; +} + +.btn:disabled:hover { + background: #16213e; + border-color: #444; +} + +.spinner-row { + display: none; + align-items: center; + gap: 8px; + margin-top: 8px; + font-size: 12px; + color: #ccc; +} + +.spinner { + width: 14px; + height: 14px; + border: 2px solid #444; + border-top-color: #667eea; + border-radius: 50%; + animation: spinner-rotate 0.8s linear infinite; +} + +@keyframes spinner-rotate { + to { + transform: rotate(360deg); + } +} diff --git a/rslp/embedding_explorer/templates/index.html b/rslp/embedding_explorer/templates/index.html new file mode 100644 index 000000000..25b50bfab --- /dev/null +++ b/rslp/embedding_explorer/templates/index.html @@ -0,0 +1,110 @@ + + + + + + Embedding Similarity Explorer + + + + + + +
+ + + + + + From 046ff1cddaf8fc0ab5d48ebbb5b1f9d03764e922 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Thu, 21 May 2026 14:15:54 -0700 Subject: [PATCH 2/6] aef support --- rslp/embedding_explorer/README.md | 42 +++++++- rslp/embedding_explorer/app.py | 105 +++++++++++-------- rslp/embedding_explorer/static/app.js | 37 +++++-- rslp/embedding_explorer/templates/index.html | 7 +- 4 files changed, 139 insertions(+), 52 deletions(-) diff --git a/rslp/embedding_explorer/README.md b/rslp/embedding_explorer/README.md index fb7881da5..8b499969a 100644 --- a/rslp/embedding_explorer/README.md +++ b/rslp/embedding_explorer/README.md @@ -75,14 +75,39 @@ rslearn dataset materialize --root $DATASET_PATH --workers 32 \ --retry-max-attempts 5 --retry-backoff-seconds 5 ``` -### 3. Compute embeddings +### 2b. Materialize AEF embeddings (optional) + +AEF provides pre-computed 64-dimensional satellite embeddings at 10m resolution +from [source.coop/tge-labs/aef](https://source.coop/tge-labs/aef). Use +`config_with_aef.json` instead of `config.json`: + +```bash +cp rslp/embedding_explorer/config_with_aef.json $DATASET_PATH/config.json +rslearn dataset prepare --root $DATASET_PATH --workers 32 --enabled-layers aef +rslearn dataset materialize --root $DATASET_PATH --workers 32 \ + --no-use-initial-job --enabled-layers aef +``` + +### 3. Compute OlmoEarth embeddings (40m, default) ```bash rslearn model predict --config rslp/embedding_explorer/config.yaml ``` After this finishes, each window will have a populated `layers/embeddings/` -directory with a 768-band GeoTIFF. +directory with a 768-band GeoTIFF at 40m/pixel (patch_size=4). + +### 3b. Compute OlmoEarth embeddings at 10m (optional) + +For a fair comparison with AEF (both at 10m), use `config_olmoearth_10m.yaml` +which runs OlmoEarth-v1-Base with `patch_size=1` (one embedding per 10m pixel): + +```bash +rslearn model predict --config rslp/embedding_explorer/config_olmoearth_10m.yaml +``` + +This writes to the same `embeddings` layer as the 40m config (so only run one +or the other per dataset). ## Run the app @@ -101,6 +126,19 @@ python -m rslp.embedding_explorer.app \ --port 5000 ``` +To load multiple embedding layers (e.g. OlmoEarth + AEF), pass them all: + +```bash +python -m rslp.embedding_explorer.app \ + --dataset-path $DATASET_PATH \ + --embedding-layer embeddings aef \ + --port 5000 +``` + +When multiple layers are loaded, a dropdown appears in the sidebar to select +which embedding is used for similarity queries. Each layer can have a different +resolution — the overlay adapts to the selected layer's grid. + Open `http://localhost:5000` and pick a window from the sidebar. Clicking on the map adds points; in cosine/KNN modes the overlay updates immediately, and in Linear Probe mode you press **Apply** to (re)train. diff --git a/rslp/embedding_explorer/app.py b/rslp/embedding_explorer/app.py index c14fbf7a8..a2c6ee79e 100644 --- a/rslp/embedding_explorer/app.py +++ b/rslp/embedding_explorer/app.py @@ -2,6 +2,8 @@ Usage: python -m rslp.embedding_explorer.app --dataset-path /path/to/dataset --port 5000 + python -m rslp.embedding_explorer.app --dataset-path /path/to/dataset \ + --embedding-layer embeddings aef --port 5000 """ import argparse @@ -237,10 +239,10 @@ def similarity_to_png(similarity: np.ndarray, mode: str) -> bytes: return buf.getvalue() -def load_dataset(dataset_path: Path, embedding_layer: str) -> dict: +def load_dataset(dataset_path: Path, embedding_layers: list[str]) -> dict: """Load dataset: embeddings into memory, image layer paths stored for on-demand serving.""" windows_dir = dataset_path / "windows" - dataset: dict = {"windows": {}, "embedding_layer": embedding_layer} + dataset: dict = {"windows": {}, "embedding_layers": embedding_layers} for group_dir in sorted(windows_dir.iterdir()): if not group_dir.is_dir(): @@ -260,9 +262,9 @@ def load_dataset(dataset_path: Path, embedding_layer: str) -> dict: continue window_key = f"{group_dir.name}/{window_dir.name}" - window_info = { + window_info: dict = { "metadata": metadata, - "embeddings": None, + "embeddings": {}, "image_layers": {}, "path": window_dir, } @@ -277,31 +279,32 @@ def load_dataset(dataset_path: Path, embedding_layer: str) -> dict: if tif is None: continue - if base_name == embedding_layer: + if base_name in embedding_layers: print(f" Loading embedding: {window_key}/{layer_name}") with rasterio.open(tif) as src: data = src.read().astype(np.float32) src_crs = str(src.crs) src_transform = src.transform src_shape = (src.height, src.width) - window_info["embeddings"] = data - window_info["embedding_crs"] = src_crs - window_info["embedding_transform"] = src_transform - window_info["embedding_shape"] = src_shape print(f" shape: {data.shape}") - # Pre-compute Web Mercator transform for serving print(" computing Web Mercator reprojection info...") _, wm_transform, wm_shape = reproject_to_webmercator( data[:1], src_crs, src_transform, src_shape ) - window_info["wm_transform"] = wm_transform - window_info["wm_shape"] = wm_shape - wm_bounds = webmercator_bounds(wm_transform, wm_shape) - window_info["wm_bounds"] = wm_bounds # (left, bottom, right, top) + wm_b = webmercator_bounds(wm_transform, wm_shape) print(f" Web Mercator shape: {wm_shape}") + + window_info["embeddings"][base_name] = { + "data": data, + "crs": src_crs, + "transform": src_transform, + "shape": src_shape, + "wm_transform": wm_transform, + "wm_shape": wm_shape, + "wm_bounds": wm_b, + } else: - # Compute WM bounds for this image layer at startup with rasterio.open(tif) as src: layer_crs = str(src.crs) layer_transform = src.transform @@ -330,16 +333,15 @@ def load_dataset(dataset_path: Path, embedding_layer: str) -> dict: } ) - if window_info["embeddings"] is None: - print( - f" WARNING: no '{embedding_layer}' layer found in {window_key}, skipping" - ) + if not window_info["embeddings"]: + print(f" WARNING: no embedding layers found in {window_key}, skipping") continue dataset["windows"][window_key] = window_info print(f"Loaded window: {window_key}") - print(f" mercator bounds: {window_info['wm_bounds']}") - print(f" embedding shape: {window_info['embeddings'].shape}") + for emb_name, emb_info in window_info["embeddings"].items(): + print(f" embedding '{emb_name}': shape={emb_info['data'].shape}") + print(f" mercator bounds: {emb_info['wm_bounds']}") print(f" image layers: {list(window_info['image_layers'].keys())}") return dataset @@ -367,8 +369,10 @@ def reproject_single_band_to_wm( return dst -def create_app(dataset_path: Path, embedding_layer: str = "embeddings") -> Flask: +def create_app(dataset_path: Path, embedding_layers: list[str] | None = None) -> Flask: """Create and configure the Flask application.""" + if embedding_layers is None: + embedding_layers = ["embeddings"] app_dir = Path(__file__).parent app = Flask( __name__, @@ -377,17 +381,20 @@ def create_app(dataset_path: Path, embedding_layer: str = "embeddings") -> Flask ) print(f"Loading dataset from {dataset_path}...") - print(f"Embedding layer: {embedding_layer}") - dataset = load_dataset(dataset_path, embedding_layer) + print(f"Embedding layers: {embedding_layers}") + dataset = load_dataset(dataset_path, embedding_layers) print(f"Loaded {len(dataset['windows'])} window(s)") @app.route("/") def index() -> str: windows_info: dict = {} for key, win in dataset["windows"].items(): - wm = win["wm_bounds"] + # Use the first available embedding layer's bounds as the window bounds + first_emb = next(iter(win["embeddings"].values())) + wm = first_emb["wm_bounds"] windows_info[key] = { "mercator_bounds": [wm[0], wm[1], wm[2], wm[3]], + "embedding_layers": list(win["embeddings"].keys()), "image_layers": { name: [ { @@ -402,7 +409,7 @@ def index() -> str: return render_template( "index.html", windows=json.dumps(windows_info), - embedding_layer=embedding_layer, + embedding_layers=json.dumps(embedding_layers), ) @app.route("/api/image/") @@ -427,14 +434,12 @@ def serve_image(layer_path: str) -> Response | tuple[str, int]: bands = tuple(int(b) for b in bands_param.split(",")) # Serve embedding RGB from memory (reproject to WM) - if layer_name == dataset["embedding_layer"]: - data = win["embeddings"] + if layer_name in win["embeddings"]: + emb = win["embeddings"][layer_name] + data = emb["data"] selected = np.stack([data[b] for b in bands]) reprojected, wm_transform, wm_shape = reproject_to_webmercator( - selected, - win["embedding_crs"], - win["embedding_transform"], - win["embedding_shape"], + selected, emb["crs"], emb["transform"], emb["shape"] ) png = render_rgb_png_with_alpha(reprojected, (0, 1, 2)) left, bottom, right, top = webmercator_bounds(wm_transform, wm_shape) @@ -486,29 +491,37 @@ def compute_similarity_route() -> Response | tuple[str, int]: """Compute similarity in native CRS, reproject result to Web Mercator. Request JSON: - mode: "cosine" or "knn" + mode: "cosine" or "knn" or "linear_probe" points: list of {lat, lon, label} k: int (for knn mode) window: window key + layer: embedding layer name (default: first available) """ body = request.get_json() window_key = body.get("window") mode = body.get("mode", "cosine") points = body.get("points", []) k = body.get("k", 3) + layer_name = body.get("layer") win = dataset["windows"].get(window_key) if win is None: return "Window not found", 404 - embeddings = win["embeddings"] - if not points: - return "No points provided", 400 + if not layer_name: + layer_name = next(iter(win["embeddings"])) + if layer_name not in win["embeddings"]: + return f"Embedding layer '{layer_name}' not found", 404 - crs = win["embedding_crs"] - transform = win["embedding_transform"] + emb = win["embeddings"][layer_name] + embeddings = emb["data"] + crs = emb["crs"] + transform = emb["transform"] shape = embeddings.shape[1:] + if not points: + return "No points provided", 400 + if mode == "cosine": pos_points = [p for p in points if p.get("label") == "positive"] if not pos_points: @@ -549,12 +562,17 @@ def compute_similarity_route() -> Response | tuple[str, int]: # Reproject similarity to Web Mercator sim_wm = reproject_single_band_to_wm( - similarity, crs, transform, shape, win["wm_transform"], win["wm_shape"] + similarity, + crs, + transform, + shape, + emb["wm_transform"], + emb["wm_shape"], ) png = similarity_to_png(sim_wm, mode) left, bottom, right, top = webmercator_bounds( - win["wm_transform"], win["wm_shape"] + emb["wm_transform"], emb["wm_shape"] ) return Response( png, @@ -574,14 +592,15 @@ def main() -> None: parser.add_argument("--dataset-path", type=Path, required=True) parser.add_argument( "--embedding-layer", - default="embeddings", - help="Layer name to load as embeddings (default: 'embeddings')", + nargs="+", + default=["embeddings"], + help="Embedding layer name(s) to load (default: 'embeddings')", ) parser.add_argument("--port", type=int, default=5000) parser.add_argument("--host", default="0.0.0.0") args = parser.parse_args() - app = create_app(args.dataset_path, embedding_layer=args.embedding_layer) + app = create_app(args.dataset_path, embedding_layers=args.embedding_layer) app.run(host=args.host, port=args.port, debug=False) diff --git a/rslp/embedding_explorer/static/app.js b/rslp/embedding_explorer/static/app.js index e3484be5a..bd1f0caee 100644 --- a/rslp/embedding_explorer/static/app.js +++ b/rslp/embedding_explorer/static/app.js @@ -50,6 +50,8 @@ // DOM refs const windowSelect = document.getElementById("window-select"); const layerToggles = document.getElementById("layer-toggles"); + const embeddingSelect = document.getElementById("embedding-select"); + const embeddingSelectGroup = document.getElementById("embedding-select-group"); const pointsList = document.getElementById("points-list"); const clearBtn = document.getElementById("clear-points"); const thresholdSlider = document.getElementById("threshold-slider"); @@ -101,10 +103,13 @@ addLayerRadio("OpenStreetMap", { type: "tile", source: "osm" }); addLayerRadio("Bing Maps", { type: "tile", source: "bing" }); - addLayerRadio(EMBEDDING_LAYER + " (RGB)", { - type: "image", - url: `/api/image/${key}/${EMBEDDING_LAYER}?bands=0,1,2`, - mercatorBounds: win.mercator_bounds, + // Add RGB overlay for each embedding layer + (win.embedding_layers || []).forEach((embName) => { + addLayerRadio(embName + " (RGB)", { + type: "image", + url: `/api/image/${key}/${embName}?bands=0,1,2`, + mercatorBounds: win.mercator_bounds, + }); }); Object.entries(win.image_layers).forEach(([baseName, groups]) => { @@ -118,6 +123,18 @@ }); }); + // Populate embedding select dropdown + embeddingSelect.innerHTML = ""; + (win.embedding_layers || []).forEach((embName) => { + const opt = document.createElement("option"); + opt.value = embName; + opt.textContent = embName; + embeddingSelect.appendChild(opt); + }); + // Show dropdown only if multiple layers available + embeddingSelectGroup.style.display = + (win.embedding_layers || []).length > 1 ? "block" : "none"; + // Restore previous selection if still available, otherwise default to OSM. const targetKey = previousKey && layerRadios[previousKey] @@ -287,8 +304,6 @@ if (!hasPos || !hasNeg) return; } - const win = WINDOWS_DATA[currentWindow]; - similarityBounds = mercatorBoundsToLatLng(win.mercator_bounds); similarityMode = mode; const isProbe = mode === "linear_probe"; @@ -305,10 +320,16 @@ mode: mode, points: points, k: k, + layer: embeddingSelect.value || undefined, }), }) .then((resp) => { if (!resp.ok) throw new Error(resp.statusText); + const boundsHeader = resp.headers.get("X-Mercator-Bounds"); + if (boundsHeader) { + const parts = boundsHeader.split(",").map(Number); + similarityBounds = mercatorBoundsToLatLng(parts); + } return resp.blob(); }) .then((blob) => { @@ -490,6 +511,10 @@ if (getMode() === "knn" && points.length > 0) requestSimilarity(); }); + embeddingSelect.addEventListener("change", () => { + if (points.length > 0) maybeAutoRefresh(); + }); + // Initial state syncModeUi(); syncDisplayUi(); diff --git a/rslp/embedding_explorer/templates/index.html b/rslp/embedding_explorer/templates/index.html index 25b50bfab..0a78b06ce 100644 --- a/rslp/embedding_explorer/templates/index.html +++ b/rslp/embedding_explorer/templates/index.html @@ -21,6 +21,11 @@

Embedding Explorer

+
+ + +
+

Similarity

@@ -103,7 +108,7 @@

Points

From 0c9193ed4b409c929422af577a7f3750a5430995 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Thu, 21 May 2026 14:18:04 -0700 Subject: [PATCH 3/6] add missing files --- .../config_olmoearth_10m.yaml | 50 +++++++ rslp/embedding_explorer/config_with_aef.json | 133 ++++++++++++++++++ 2 files changed, 183 insertions(+) create mode 100644 rslp/embedding_explorer/config_olmoearth_10m.yaml create mode 100644 rslp/embedding_explorer/config_with_aef.json diff --git a/rslp/embedding_explorer/config_olmoearth_10m.yaml b/rslp/embedding_explorer/config_olmoearth_10m.yaml new file mode 100644 index 000000000..82e59d2eb --- /dev/null +++ b/rslp/embedding_explorer/config_olmoearth_10m.yaml @@ -0,0 +1,50 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.singletask.SingleTaskModel + init_args: + encoder: + - class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth + init_args: + model_id: OLMOEARTH_V1_BASE + patch_size: 1 + decoder: + - class_path: rslearn.train.tasks.embedding.EmbeddingHead + optimizer: + class_path: rslearn.train.optimizer.AdamW +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: ${DATASET_PATH} + inputs: + sentinel2_l2a: + data_type: "raster" + layers: ["sentinel2_l2a", "sentinel2_l2a.1", "sentinel2_l2a.2", "sentinel2_l2a.3"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + load_all_layers: true + task: + class_path: rslearn.train.tasks.embedding.EmbeddingTask + batch_size: 8 + num_workers: 32 + predict_config: + transforms: + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + load_all_crops: true + crop_size: 16 + overlap_pixels: 8 +trainer: + callbacks: + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + output_layer: embeddings + merger: + class_path: rslearn.train.prediction_writer.RasterMerger + init_args: + overlap_pixels: 8 + downsample_factor: 1 diff --git a/rslp/embedding_explorer/config_with_aef.json b/rslp/embedding_explorer/config_with_aef.json new file mode 100644 index 000000000..6f5b69b47 --- /dev/null +++ b/rslp/embedding_explorer/config_with_aef.json @@ -0,0 +1,133 @@ +{ + "layers": { + "aef": { + "band_sets": [ + { + "bands": [ + "A00", + "A01", + "A02", + "A03", + "A04", + "A05", + "A06", + "A07", + "A08", + "A09", + "A10", + "A11", + "A12", + "A13", + "A14", + "A15", + "A16", + "A17", + "A18", + "A19", + "A20", + "A21", + "A22", + "A23", + "A24", + "A25", + "A26", + "A27", + "A28", + "A29", + "A30", + "A31", + "A32", + "A33", + "A34", + "A35", + "A36", + "A37", + "A38", + "A39", + "A40", + "A41", + "A42", + "A43", + "A44", + "A45", + "A46", + "A47", + "A48", + "A49", + "A50", + "A51", + "A52", + "A53", + "A54", + "A55", + "A56", + "A57", + "A58", + "A59", + "A60", + "A61", + "A62", + "A63" + ], + "dtype": "float32" + } + ], + "data_source": { + "class_path": "rslearn.data_sources.aws_google_satellite_embedding_v1.GoogleSatelliteEmbeddingV1", + "ingest": false, + "init_args": { + "metadata_cache_dir": "cache/aef" + }, + "query_config": { + "max_matches": 1 + } + }, + "type": "raster" + }, + "embeddings": { + "band_sets": [ + { + "dtype": "float32", + "num_bands": 768 + } + ], + "type": "raster" + }, + "sentinel2_l2a": { + "band_sets": [ + { + "bands": [ + "B01", + "B02", + "B03", + "B04", + "B05", + "B06", + "B07", + "B08", + "B8A", + "B09", + "B11", + "B12" + ], + "dtype": "uint16" + } + ], + "data_source": { + "class_path": "rslearn.data_sources.planetary_computer.Sentinel2", + "ingest": false, + "init_args": { + "cache_dir": "cache/planetary_computer", + "harmonize": true, + "sort_by": "eo:cloud_cover" + }, + "query_config": { + "max_matches": 12, + "period_duration": "30d", + "space_mode": "PER_PERIOD_MOSAIC" + } + }, + "type": "raster" + } + } +} From 7ae6b12bd3de7e66b9402b0486158dca54edd69d Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Sat, 6 Jun 2026 20:42:52 -0700 Subject: [PATCH 4/6] add presto option --- rslp/embedding_explorer/README.md | 33 +++++++++++ rslp/embedding_explorer/config_presto.yaml | 44 ++++++++++++++ .../config_with_presto.json | 58 +++++++++++++++++++ 3 files changed, 135 insertions(+) create mode 100644 rslp/embedding_explorer/config_presto.yaml create mode 100644 rslp/embedding_explorer/config_with_presto.json diff --git a/rslp/embedding_explorer/README.md b/rslp/embedding_explorer/README.md index 8b499969a..fccc6429e 100644 --- a/rslp/embedding_explorer/README.md +++ b/rslp/embedding_explorer/README.md @@ -30,6 +30,10 @@ reference configs sit next to this README: Planetary Computer + an `embeddings` layer). - [config.yaml](config.yaml) — model config that runs OlmoEarth and writes outputs into the `embeddings` layer via `RslearnWriter`. +- [config_with_presto.json](config_with_presto.json) — dataset config that also + declares a 128-dimensional `presto` embedding layer. +- [config_presto.yaml](config_presto.yaml) — model config that runs Presto and + writes outputs into the `presto` layer via `RslearnWriter`. For a fuller treatment of these configs (Sentinel-1 / Landsat layers, model sizes, etc.) see @@ -109,6 +113,26 @@ rslearn model predict --config rslp/embedding_explorer/config_olmoearth_10m.yaml This writes to the same `embeddings` layer as the 40m config (so only run one or the other per dataset). +### 3c. Compute Presto embeddings (optional) + +Presto is supported in rslearn as `rslearn.models.presto.Presto`. Use +`config_with_presto.json` instead of `config.json` when creating the dataset, +then materialize Sentinel-2 as above and run: + +```bash +rslearn model predict --config rslp/embedding_explorer/config_presto.yaml +``` + +This writes 128-dimensional embeddings at 10m/pixel to the `presto` layer. The +provided config uses Sentinel-2 only, matching the default explorer dataset. +Presto can also consume Sentinel-1 when the dataset has a compatible `s1` input. +The first run may need to download or otherwise populate the Presto checkpoint +cache used by `rslearn.models.presto.Presto`. + +Tessera is not currently available as an rslearn model config in this project. +There is a Tessera eval wrapper in `olmoearth_pretrain`, but it has not been +adapted to rslearn's `FeatureExtractor` interface for `rslearn model predict`. + ## Run the app The app needs Flask and scikit-learn (neither are pulled in by `rslp`'s base @@ -139,6 +163,15 @@ When multiple layers are loaded, a dropdown appears in the sidebar to select which embedding is used for similarity queries. Each layer can have a different resolution — the overlay adapts to the selected layer's grid. +For example, to compare OlmoEarth and Presto: + +```bash +python -m rslp.embedding_explorer.app \ + --dataset-path $DATASET_PATH \ + --embedding-layer embeddings presto \ + --port 5000 +``` + Open `http://localhost:5000` and pick a window from the sidebar. Clicking on the map adds points; in cosine/KNN modes the overlay updates immediately, and in Linear Probe mode you press **Apply** to (re)train. diff --git a/rslp/embedding_explorer/config_presto.yaml b/rslp/embedding_explorer/config_presto.yaml new file mode 100644 index 000000000..9f651056d --- /dev/null +++ b/rslp/embedding_explorer/config_presto.yaml @@ -0,0 +1,44 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.singletask.SingleTaskModel + init_args: + encoder: + - class_path: rslearn.models.presto.Presto + init_args: + pixel_batch_size: 4096 + decoder: + - class_path: rslearn.train.tasks.embedding.EmbeddingHead + optimizer: + class_path: rslearn.train.optimizer.AdamW +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: ${DATASET_PATH} + inputs: + s2: + data_type: "raster" + layers: ["sentinel2_l2a", "sentinel2_l2a.1", "sentinel2_l2a.2", "sentinel2_l2a.3"] + bands: ["B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + load_all_layers: true + task: + class_path: rslearn.train.tasks.embedding.EmbeddingTask + batch_size: 8 + num_workers: 32 + predict_config: + load_all_crops: true + crop_size: 64 + overlap_pixels: 0 +trainer: + callbacks: + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + output_layer: presto + merger: + class_path: rslearn.train.prediction_writer.RasterMerger + init_args: + overlap_pixels: 0 + downsample_factor: 1 diff --git a/rslp/embedding_explorer/config_with_presto.json b/rslp/embedding_explorer/config_with_presto.json new file mode 100644 index 000000000..a07ff3621 --- /dev/null +++ b/rslp/embedding_explorer/config_with_presto.json @@ -0,0 +1,58 @@ +{ + "layers": { + "embeddings": { + "band_sets": [ + { + "dtype": "float32", + "num_bands": 768 + } + ], + "type": "raster" + }, + "presto": { + "band_sets": [ + { + "dtype": "float32", + "num_bands": 128 + } + ], + "type": "raster" + }, + "sentinel2_l2a": { + "band_sets": [ + { + "bands": [ + "B01", + "B02", + "B03", + "B04", + "B05", + "B06", + "B07", + "B08", + "B8A", + "B09", + "B11", + "B12" + ], + "dtype": "uint16" + } + ], + "data_source": { + "class_path": "rslearn.data_sources.planetary_computer.Sentinel2", + "ingest": false, + "init_args": { + "cache_dir": "cache/planetary_computer", + "harmonize": true, + "sort_by": "eo:cloud_cover" + }, + "query_config": { + "max_matches": 12, + "period_duration": "30d", + "space_mode": "PER_PERIOD_MOSAIC" + } + }, + "type": "raster" + } + } +} From 327fa86d5b191a2bed8616279ff75858c90162b5 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Tue, 16 Jun 2026 10:43:14 -0700 Subject: [PATCH 5/6] add tessera embeddings, move configs to data/embedding_explorer/ --- {rslp => data}/embedding_explorer/config.json | 23 ++- {rslp => data}/embedding_explorer/config.yaml | 10 +- .../config_olmoearth_10m.yaml | 10 +- .../config_olmoearth_v1_1_base_ps1.yaml | 56 ++++++ .../config_olmoearth_v1_1_base_ps4.yaml | 56 ++++++ .../config_olmoearth_v1_1_nano_ps1.yaml | 56 ++++++ .../config_olmoearth_v1_1_nano_ps4.yaml | 56 ++++++ .../embedding_explorer/config_presto.yaml | 12 +- data/embedding_explorer/config_tessera.yaml | 74 ++++++++ .../embedding_explorer/config_with_aef.json | 23 ++- .../config_with_tessera.json | 108 +++++++++++ rslp/embedding_explorer/README.md | 169 +++++++++++++----- rslp/embedding_explorer/app.py | 12 +- .../config_with_presto.json | 58 ------ 14 files changed, 586 insertions(+), 137 deletions(-) rename {rslp => data}/embedding_explorer/config.json (59%) rename {rslp => data}/embedding_explorer/config.yaml (86%) rename {rslp => data}/embedding_explorer/config_olmoearth_10m.yaml (86%) create mode 100644 data/embedding_explorer/config_olmoearth_v1_1_base_ps1.yaml create mode 100644 data/embedding_explorer/config_olmoearth_v1_1_base_ps4.yaml create mode 100644 data/embedding_explorer/config_olmoearth_v1_1_nano_ps1.yaml create mode 100644 data/embedding_explorer/config_olmoearth_v1_1_nano_ps4.yaml rename {rslp => data}/embedding_explorer/config_presto.yaml (82%) create mode 100644 data/embedding_explorer/config_tessera.yaml rename {rslp => data}/embedding_explorer/config_with_aef.json (84%) create mode 100644 data/embedding_explorer/config_with_tessera.json delete mode 100644 rslp/embedding_explorer/config_with_presto.json diff --git a/rslp/embedding_explorer/config.json b/data/embedding_explorer/config.json similarity index 59% rename from rslp/embedding_explorer/config.json rename to data/embedding_explorer/config.json index 34ecedc80..790bab59b 100644 --- a/rslp/embedding_explorer/config.json +++ b/data/embedding_explorer/config.json @@ -1,14 +1,5 @@ { "layers": { - "embeddings": { - "band_sets": [ - { - "dtype": "float32", - "num_bands": 768 - } - ], - "type": "raster" - }, "sentinel2_l2a": { "band_sets": [ { @@ -30,17 +21,23 @@ } ], "data_source": { - "class_path": "rslearn.data_sources.planetary_computer.Sentinel2", + "class_path": "olmoearth_run.runner.tools.rslearn_data_sources.olmoearth_datasets.sentinel2_l2a.Sentinel2L2A", "ingest": false, "init_args": { - "cache_dir": "cache/planetary_computer", + "cache_dir": "cache/olmoearth_datasets", "harmonize": true, - "sort_by": "eo:cloud_cover" + "query": { + "sort_by": "CLOUD_COVER", + "sort_direction": "ASC" + }, + "timeout": "0:0:10" }, "query_config": { "max_matches": 12, + "min_matches": 12, + "per_period_mosaic_reverse_time_order": false, "period_duration": "30d", - "space_mode": "PER_PERIOD_MOSAIC" + "space_mode": "MOSAIC" } }, "type": "raster" diff --git a/rslp/embedding_explorer/config.yaml b/data/embedding_explorer/config.yaml similarity index 86% rename from rslp/embedding_explorer/config.yaml rename to data/embedding_explorer/config.yaml index 021a37945..7d9203444 100644 --- a/rslp/embedding_explorer/config.yaml +++ b/data/embedding_explorer/config.yaml @@ -20,11 +20,12 @@ data: inputs: sentinel2_l2a: data_type: "raster" - layers: ["sentinel2_l2a", "sentinel2_l2a.1", "sentinel2_l2a.2", "sentinel2_l2a.3"] + layers: ["sentinel2_l2a"] bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] passthrough: true dtype: FLOAT32 load_all_layers: true + load_all_item_groups: true task: class_path: rslearn.train.tasks.embedding.EmbeddingTask batch_size: 8 @@ -42,7 +43,12 @@ trainer: callbacks: - class_path: rslearn.train.prediction_writer.RslearnWriter init_args: - output_layer: embeddings + output_layer: embeddings_olmoearth_v1_base_ps4_ws64 + layer_config: + type: RASTER + band_sets: + - dtype: FLOAT32 + num_bands: 768 merger: class_path: rslearn.train.prediction_writer.RasterMerger init_args: diff --git a/rslp/embedding_explorer/config_olmoearth_10m.yaml b/data/embedding_explorer/config_olmoearth_10m.yaml similarity index 86% rename from rslp/embedding_explorer/config_olmoearth_10m.yaml rename to data/embedding_explorer/config_olmoearth_10m.yaml index 82e59d2eb..023abaa3e 100644 --- a/rslp/embedding_explorer/config_olmoearth_10m.yaml +++ b/data/embedding_explorer/config_olmoearth_10m.yaml @@ -20,11 +20,12 @@ data: inputs: sentinel2_l2a: data_type: "raster" - layers: ["sentinel2_l2a", "sentinel2_l2a.1", "sentinel2_l2a.2", "sentinel2_l2a.3"] + layers: ["sentinel2_l2a"] bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] passthrough: true dtype: FLOAT32 load_all_layers: true + load_all_item_groups: true task: class_path: rslearn.train.tasks.embedding.EmbeddingTask batch_size: 8 @@ -42,7 +43,12 @@ trainer: callbacks: - class_path: rslearn.train.prediction_writer.RslearnWriter init_args: - output_layer: embeddings + output_layer: embeddings_olmoearth_v1_base_ps1_ws16 + layer_config: + type: RASTER + band_sets: + - dtype: FLOAT32 + num_bands: 768 merger: class_path: rslearn.train.prediction_writer.RasterMerger init_args: diff --git a/data/embedding_explorer/config_olmoearth_v1_1_base_ps1.yaml b/data/embedding_explorer/config_olmoearth_v1_1_base_ps1.yaml new file mode 100644 index 000000000..35c380789 --- /dev/null +++ b/data/embedding_explorer/config_olmoearth_v1_1_base_ps1.yaml @@ -0,0 +1,56 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.singletask.SingleTaskModel + init_args: + encoder: + - class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth + init_args: + model_id: OLMOEARTH_V1_1_BASE + patch_size: 1 + decoder: + - class_path: rslearn.train.tasks.embedding.EmbeddingHead + optimizer: + class_path: rslearn.train.optimizer.AdamW +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: ${DATASET_PATH} + inputs: + sentinel2_l2a: + data_type: "raster" + layers: ["sentinel2_l2a"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + load_all_layers: true + load_all_item_groups: true + task: + class_path: rslearn.train.tasks.embedding.EmbeddingTask + batch_size: 8 + num_workers: 32 + predict_config: + transforms: + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + load_all_crops: true + crop_size: 16 + overlap_pixels: 8 +trainer: + callbacks: + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + output_layer: embeddings_olmoearth_v1_1_base_ps1_ws16 + layer_config: + type: RASTER + band_sets: + - dtype: FLOAT32 + num_bands: 768 + merger: + class_path: rslearn.train.prediction_writer.RasterMerger + init_args: + overlap_pixels: 8 + downsample_factor: 1 diff --git a/data/embedding_explorer/config_olmoearth_v1_1_base_ps4.yaml b/data/embedding_explorer/config_olmoearth_v1_1_base_ps4.yaml new file mode 100644 index 000000000..bb75ef455 --- /dev/null +++ b/data/embedding_explorer/config_olmoearth_v1_1_base_ps4.yaml @@ -0,0 +1,56 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.singletask.SingleTaskModel + init_args: + encoder: + - class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth + init_args: + model_id: OLMOEARTH_V1_1_BASE + patch_size: 4 + decoder: + - class_path: rslearn.train.tasks.embedding.EmbeddingHead + optimizer: + class_path: rslearn.train.optimizer.AdamW +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: ${DATASET_PATH} + inputs: + sentinel2_l2a: + data_type: "raster" + layers: ["sentinel2_l2a"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + load_all_layers: true + load_all_item_groups: true + task: + class_path: rslearn.train.tasks.embedding.EmbeddingTask + batch_size: 8 + num_workers: 32 + predict_config: + transforms: + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + load_all_crops: true + crop_size: 64 + overlap_pixels: 32 +trainer: + callbacks: + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + output_layer: embeddings_olmoearth_v1_1_base_ps4_ws64 + layer_config: + type: RASTER + band_sets: + - dtype: FLOAT32 + num_bands: 768 + merger: + class_path: rslearn.train.prediction_writer.RasterMerger + init_args: + overlap_pixels: 8 + downsample_factor: 4 diff --git a/data/embedding_explorer/config_olmoearth_v1_1_nano_ps1.yaml b/data/embedding_explorer/config_olmoearth_v1_1_nano_ps1.yaml new file mode 100644 index 000000000..b1e2f7fd3 --- /dev/null +++ b/data/embedding_explorer/config_olmoearth_v1_1_nano_ps1.yaml @@ -0,0 +1,56 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.singletask.SingleTaskModel + init_args: + encoder: + - class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth + init_args: + model_id: OLMOEARTH_V1_1_NANO + patch_size: 1 + decoder: + - class_path: rslearn.train.tasks.embedding.EmbeddingHead + optimizer: + class_path: rslearn.train.optimizer.AdamW +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: ${DATASET_PATH} + inputs: + sentinel2_l2a: + data_type: "raster" + layers: ["sentinel2_l2a"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + load_all_layers: true + load_all_item_groups: true + task: + class_path: rslearn.train.tasks.embedding.EmbeddingTask + batch_size: 8 + num_workers: 32 + predict_config: + transforms: + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + load_all_crops: true + crop_size: 16 + overlap_pixels: 8 +trainer: + callbacks: + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + output_layer: embeddings_olmoearth_v1_1_nano_ps1_ws16 + layer_config: + type: RASTER + band_sets: + - dtype: FLOAT32 + num_bands: 768 + merger: + class_path: rslearn.train.prediction_writer.RasterMerger + init_args: + overlap_pixels: 8 + downsample_factor: 1 diff --git a/data/embedding_explorer/config_olmoearth_v1_1_nano_ps4.yaml b/data/embedding_explorer/config_olmoearth_v1_1_nano_ps4.yaml new file mode 100644 index 000000000..02ea93a49 --- /dev/null +++ b/data/embedding_explorer/config_olmoearth_v1_1_nano_ps4.yaml @@ -0,0 +1,56 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.singletask.SingleTaskModel + init_args: + encoder: + - class_path: rslearn.models.olmoearth_pretrain.model.OlmoEarth + init_args: + model_id: OLMOEARTH_V1_1_NANO + patch_size: 4 + decoder: + - class_path: rslearn.train.tasks.embedding.EmbeddingHead + optimizer: + class_path: rslearn.train.optimizer.AdamW +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: ${DATASET_PATH} + inputs: + sentinel2_l2a: + data_type: "raster" + layers: ["sentinel2_l2a"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + load_all_layers: true + load_all_item_groups: true + task: + class_path: rslearn.train.tasks.embedding.EmbeddingTask + batch_size: 8 + num_workers: 32 + predict_config: + transforms: + - class_path: rslearn.models.olmoearth_pretrain.norm.OlmoEarthNormalize + init_args: + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + load_all_crops: true + crop_size: 64 + overlap_pixels: 32 +trainer: + callbacks: + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + output_layer: embeddings_olmoearth_v1_1_nano_ps4_ws64 + layer_config: + type: RASTER + band_sets: + - dtype: FLOAT32 + num_bands: 768 + merger: + class_path: rslearn.train.prediction_writer.RasterMerger + init_args: + overlap_pixels: 8 + downsample_factor: 4 diff --git a/rslp/embedding_explorer/config_presto.yaml b/data/embedding_explorer/config_presto.yaml similarity index 82% rename from rslp/embedding_explorer/config_presto.yaml rename to data/embedding_explorer/config_presto.yaml index 9f651056d..6b9579cc2 100644 --- a/rslp/embedding_explorer/config_presto.yaml +++ b/data/embedding_explorer/config_presto.yaml @@ -19,24 +19,30 @@ data: inputs: s2: data_type: "raster" - layers: ["sentinel2_l2a", "sentinel2_l2a.1", "sentinel2_l2a.2", "sentinel2_l2a.3"] + layers: ["sentinel2_l2a"] bands: ["B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B11", "B12"] passthrough: true dtype: FLOAT32 load_all_layers: true + load_all_item_groups: true task: class_path: rslearn.train.tasks.embedding.EmbeddingTask batch_size: 8 num_workers: 32 predict_config: load_all_crops: true - crop_size: 64 + crop_size: 16 overlap_pixels: 0 trainer: callbacks: - class_path: rslearn.train.prediction_writer.RslearnWriter init_args: - output_layer: presto + output_layer: embeddings_presto + layer_config: + type: RASTER + band_sets: + - dtype: FLOAT32 + num_bands: 128 merger: class_path: rslearn.train.prediction_writer.RasterMerger init_args: diff --git a/data/embedding_explorer/config_tessera.yaml b/data/embedding_explorer/config_tessera.yaml new file mode 100644 index 000000000..2106a27a8 --- /dev/null +++ b/data/embedding_explorer/config_tessera.yaml @@ -0,0 +1,74 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.singletask.SingleTaskModel + init_args: + encoder: + - class_path: rslearn.models.tessera.tessera.Tessera + init_args: + checkpoint_path: ${TESSERA_CHECKPOINT_PATH} + pixel_batch_size: 1024 + decoder: + - class_path: rslearn.train.tasks.embedding.EmbeddingHead + optimizer: + class_path: rslearn.train.optimizer.AdamW +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: ${DATASET_PATH} + inputs: + s2: + data_type: "raster" + layers: ["sentinel2_l2a"] + bands: ["B04", "B02", "B03", "B08", "B8A", "B05", "B06", "B07", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + load_all_layers: true + load_all_item_groups: true + s1_ascending: + data_type: "raster" + layers: ["sentinel1_ascending"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + load_all_layers: true + load_all_item_groups: true + s1_descending: + data_type: "raster" + layers: ["sentinel1_descending"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + load_all_layers: true + load_all_item_groups: true + task: + class_path: rslearn.train.tasks.embedding.EmbeddingTask + batch_size: 1 + num_workers: 32 + predict_config: + transforms: + - class_path: rslearn.train.transforms.sentinel1.Sentinel1ToDecibels + init_args: + selectors: ["s1_ascending", "s1_descending"] + - class_path: rslearn.models.tessera.tessera.TesseraNormalize + init_args: + data_source: mpc + load_all_crops: true + crop_size: 16 + overlap_pixels: 0 +trainer: + callbacks: + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + output_layer: embeddings_tessera + layer_config: + type: RASTER + band_sets: + - dtype: FLOAT32 + num_bands: 128 + merger: + class_path: rslearn.train.prediction_writer.RasterMerger + init_args: + overlap_pixels: 0 + downsample_factor: 1 diff --git a/rslp/embedding_explorer/config_with_aef.json b/data/embedding_explorer/config_with_aef.json similarity index 84% rename from rslp/embedding_explorer/config_with_aef.json rename to data/embedding_explorer/config_with_aef.json index 6f5b69b47..c2c8e5ac5 100644 --- a/rslp/embedding_explorer/config_with_aef.json +++ b/data/embedding_explorer/config_with_aef.json @@ -84,15 +84,6 @@ }, "type": "raster" }, - "embeddings": { - "band_sets": [ - { - "dtype": "float32", - "num_bands": 768 - } - ], - "type": "raster" - }, "sentinel2_l2a": { "band_sets": [ { @@ -114,17 +105,23 @@ } ], "data_source": { - "class_path": "rslearn.data_sources.planetary_computer.Sentinel2", + "class_path": "olmoearth_run.runner.tools.rslearn_data_sources.olmoearth_datasets.sentinel2_l2a.Sentinel2L2A", "ingest": false, "init_args": { - "cache_dir": "cache/planetary_computer", + "cache_dir": "cache/olmoearth_datasets", "harmonize": true, - "sort_by": "eo:cloud_cover" + "query": { + "sort_by": "CLOUD_COVER", + "sort_direction": "ASC" + }, + "timeout": "0:0:10" }, "query_config": { "max_matches": 12, + "min_matches": 12, + "per_period_mosaic_reverse_time_order": false, "period_duration": "30d", - "space_mode": "PER_PERIOD_MOSAIC" + "space_mode": "MOSAIC" } }, "type": "raster" diff --git a/data/embedding_explorer/config_with_tessera.json b/data/embedding_explorer/config_with_tessera.json new file mode 100644 index 000000000..03d471d45 --- /dev/null +++ b/data/embedding_explorer/config_with_tessera.json @@ -0,0 +1,108 @@ +{ + "layers": { + "sentinel1_ascending": { + "band_sets": [ + { + "bands": [ + "vv", + "vh" + ], + "dtype": "float32" + } + ], + "data_source": { + "class_path": "olmoearth_run.runner.tools.rslearn_data_sources.olmoearth_datasets.sentinel1_rtc.Sentinel1RTC", + "ingest": false, + "init_args": { + "query": { + "orbit_direction": { + "eq": "ASCENDING" + } + }, + "timeout": "0:0:10" + }, + "query_config": { + "max_matches": 12, + "min_matches": 12, + "per_period_mosaic_reverse_time_order": false, + "period_duration": "30d", + "space_mode": "MOSAIC" + } + }, + "type": "raster" + }, + "sentinel1_descending": { + "band_sets": [ + { + "bands": [ + "vv", + "vh" + ], + "dtype": "float32" + } + ], + "data_source": { + "class_path": "olmoearth_run.runner.tools.rslearn_data_sources.olmoearth_datasets.sentinel1_rtc.Sentinel1RTC", + "ingest": false, + "init_args": { + "query": { + "orbit_direction": { + "eq": "DESCENDING" + } + }, + "timeout": "0:0:10" + }, + "query_config": { + "max_matches": 12, + "min_matches": 12, + "per_period_mosaic_reverse_time_order": false, + "period_duration": "30d", + "space_mode": "MOSAIC" + } + }, + "type": "raster" + }, + "sentinel2_l2a": { + "band_sets": [ + { + "bands": [ + "B01", + "B02", + "B03", + "B04", + "B05", + "B06", + "B07", + "B08", + "B8A", + "B09", + "B11", + "B12" + ], + "dtype": "uint16" + } + ], + "data_source": { + "class_path": "olmoearth_run.runner.tools.rslearn_data_sources.olmoearth_datasets.sentinel2_l2a.Sentinel2L2A", + "ingest": false, + "init_args": { + "cache_dir": "cache/olmoearth_datasets", + "harmonize": true, + "query": { + "sort_by": "CLOUD_COVER", + "sort_direction": "ASC" + }, + "timeout": "0:0:10" + }, + "query_config": { + "max_matches": 12, + "min_matches": 12, + "per_period_mosaic_reverse_time_order": false, + "period_duration": "30d", + "space_mode": "MOSAIC" + } + }, + "type": "raster" + } + } +} diff --git a/rslp/embedding_explorer/README.md b/rslp/embedding_explorer/README.md index fccc6429e..8304ae814 100644 --- a/rslp/embedding_explorer/README.md +++ b/rslp/embedding_explorer/README.md @@ -3,18 +3,17 @@ WARNING: all of the code and most of the documentation here was AI-generated and not carefully reviewed. -A small Flask app for interactively exploring per-pixel OlmoEarth embeddings -on a Leaflet map. Click points on the map and the server computes a similarity +A small Flask app for interactively exploring per-pixel embedding layers on a +Leaflet map. Click points on the map and the server computes a similarity overlay using one of three modes: - **Cosine** — cosine similarity to the embedding under the (single, latest) positive point. - **KNN** — fraction of the K nearest labeled points (by cosine similarity) that are positive. -- **Linear Probe** — a logistic regression (768 -> 2) is trained on CPU from the - labeled points and predict_proba is run over the whole window. Updates only - on **Apply** (training is fast but not free, so we don't refresh on every - click). +- **Linear Probe** — a logistic regression is trained on CPU from the labeled + points and predict_proba is run over the whole window. Updates only on + **Apply** (training is fast but not free, so we don't refresh on every click). Left-click adds a positive point, right-click adds a negative point. The threshold slider and gradient/threshold toggle are pure client-side rendering @@ -22,18 +21,48 @@ and stay live in all modes. ## Setting up a dataset -The explorer expects an rslearn dataset where each window has a raster layer -named `embeddings` (768 bands, float32) alongside the source images. Two -reference configs sit next to this README: +The explorer expects an rslearn dataset where each window has at least one +embedding raster layer alongside the source images. Reference configs live in +[`data/embedding_explorer`](../../data/embedding_explorer): + +- [config.json](../../data/embedding_explorer/config.json) — minimal dataset + config (Sentinel-2 L2A from OlmoEarth Datasets). +- [config.yaml](../../data/embedding_explorer/config.yaml) — model config that + runs OlmoEarth and writes outputs into the + `embeddings_olmoearth_v1_base_ps4_ws64` layer via `RslearnWriter`. +- [config_olmoearth_10m.yaml](../../data/embedding_explorer/config_olmoearth_10m.yaml) + — model config that runs OlmoEarth at 10m/pixel into the + `embeddings_olmoearth_v1_base_ps1_ws16` layer. +- [config_olmoearth_v1_1_base_ps4.yaml](../../data/embedding_explorer/config_olmoearth_v1_1_base_ps4.yaml) + and [config_olmoearth_v1_1_base_ps1.yaml](../../data/embedding_explorer/config_olmoearth_v1_1_base_ps1.yaml) + — OlmoEarth v1.1 Base configs that write distinct `embeddings_...` layers. +- [config_olmoearth_v1_1_nano_ps4.yaml](../../data/embedding_explorer/config_olmoearth_v1_1_nano_ps4.yaml) + and [config_olmoearth_v1_1_nano_ps1.yaml](../../data/embedding_explorer/config_olmoearth_v1_1_nano_ps1.yaml) + — OlmoEarth v1.1 Nano configs that write distinct `embeddings_...` layers. +- [config_with_aef.json](../../data/embedding_explorer/config_with_aef.json) — + dataset config that also declares a pre-computed 64-dimensional `aef` + embedding layer. +- [config_presto.yaml](../../data/embedding_explorer/config_presto.yaml) — + model config that runs Presto and writes outputs into the `embeddings_presto` + layer via `RslearnWriter`. +- [config_with_tessera.json](../../data/embedding_explorer/config_with_tessera.json) + — dataset config that adds Sentinel-1 ascending/descending inputs for Tessera. +- [config_tessera.yaml](../../data/embedding_explorer/config_tessera.yaml) — + model config that runs Tessera and writes outputs into the + `embeddings_tessera` layer via `RslearnWriter`. + +Model-generated output layer definitions (`embeddings_...`) live in the +corresponding model YAMLs under `RslearnWriter.layer_config`, so the dataset +JSONs only need layers that are prepared or materialized from data sources. + +The dataset JSONs use OlmoEarth Datasets-backed sources for Sentinel-2 and +Sentinel-1. Before `rslearn dataset prepare` or `materialize`, make sure +`olmoearth_run[runner]` is importable and set: -- [config.json](config.json) — minimal dataset config (Sentinel-2 L2A from - Planetary Computer + an `embeddings` layer). -- [config.yaml](config.yaml) — model config that runs OlmoEarth and writes - outputs into the `embeddings` layer via `RslearnWriter`. -- [config_with_presto.json](config_with_presto.json) — dataset config that also - declares a 128-dimensional `presto` embedding layer. -- [config_presto.yaml](config_presto.yaml) — model config that runs Presto and - writes outputs into the `presto` layer via `RslearnWriter`. +```bash +export OEDATASETS_API_URL=https://datasets.olmoearth.allenai.org +export DATASETS_API_TOKEN= +``` For a fuller treatment of these configs (Sentinel-1 / Landsat layers, model sizes, etc.) see @@ -49,22 +78,21 @@ switch between them in the sidebar dropdown. ```bash export DATASET_PATH=./dataset mkdir -p $DATASET_PATH -cp rslp/embedding_explorer/config.json $DATASET_PATH/config.json +cp data/embedding_explorer/config.json $DATASET_PATH/config.json -# A single 2048x2048 window over Seattle. +# A single window over Seattle. rslearn dataset add_windows --root $DATASET_PATH \ --group default --name seattle \ --utm --resolution 10 --src_crs EPSG:4326 \ - --box=-122.42,47.58,-122.22,47.78 \ - --start 2024-06-01T00:00:00+00:00 --end 2024-09-01T00:00:00+00:00 \ - --grid_size 2048 + --box=-122.5,47.5,-122.2,47.8 \ + --start 2025-01-01T00:00:00+00:00 --end 2026-01-01T00:00:00+00:00 # Or tile a larger area into multiple 2048x2048 windows. rslearn dataset add_windows --root $DATASET_PATH \ --group default --name puget_sound \ --utm --resolution 10 --src_crs EPSG:4326 \ --box=-122.7,47.2,-122.0,47.9 \ - --start 2024-06-01T00:00:00+00:00 --end 2024-09-01T00:00:00+00:00 \ + --start 2025-01-01T00:00:00+00:00 --end 2026-01-01T00:00:00+00:00 \ --grid_size 2048 ``` @@ -86,7 +114,7 @@ from [source.coop/tge-labs/aef](https://source.coop/tge-labs/aef). Use `config_with_aef.json` instead of `config.json`: ```bash -cp rslp/embedding_explorer/config_with_aef.json $DATASET_PATH/config.json +cp data/embedding_explorer/config_with_aef.json $DATASET_PATH/config.json rslearn dataset prepare --root $DATASET_PATH --workers 32 --enabled-layers aef rslearn dataset materialize --root $DATASET_PATH --workers 32 \ --no-use-initial-job --enabled-layers aef @@ -95,11 +123,12 @@ rslearn dataset materialize --root $DATASET_PATH --workers 32 \ ### 3. Compute OlmoEarth embeddings (40m, default) ```bash -rslearn model predict --config rslp/embedding_explorer/config.yaml +rslearn model predict --config data/embedding_explorer/config.yaml ``` -After this finishes, each window will have a populated `layers/embeddings/` -directory with a 768-band GeoTIFF at 40m/pixel (patch_size=4). +After this finishes, each window will have a populated +`layers/embeddings_olmoearth_v1_base_ps4_ws64/` directory with a 768-band +GeoTIFF at 40m/pixel (patch_size=4). ### 3b. Compute OlmoEarth embeddings at 10m (optional) @@ -107,31 +136,77 @@ For a fair comparison with AEF (both at 10m), use `config_olmoearth_10m.yaml` which runs OlmoEarth-v1-Base with `patch_size=1` (one embedding per 10m pixel): ```bash -rslearn model predict --config rslp/embedding_explorer/config_olmoearth_10m.yaml +rslearn model predict --config data/embedding_explorer/config_olmoearth_10m.yaml ``` -This writes to the same `embeddings` layer as the 40m config (so only run one -or the other per dataset). +This writes to `embeddings_olmoearth_v1_base_ps1_ws16`, so it can coexist with +the 40m output in the same dataset. -### 3c. Compute Presto embeddings (optional) +### 3c. Compute OlmoEarth v1.1 embeddings (optional) -Presto is supported in rslearn as `rslearn.models.presto.Presto`. Use -`config_with_presto.json` instead of `config.json` when creating the dataset, -then materialize Sentinel-2 as above and run: +The v1.1 Base and Nano configs write to distinct output layers, so you can run +multiple variants in the same dataset: ```bash -rslearn model predict --config rslp/embedding_explorer/config_presto.yaml +rslearn model predict --config data/embedding_explorer/config_olmoearth_v1_1_base_ps4.yaml +rslearn model predict --config data/embedding_explorer/config_olmoearth_v1_1_base_ps1.yaml +rslearn model predict --config data/embedding_explorer/config_olmoearth_v1_1_nano_ps4.yaml +rslearn model predict --config data/embedding_explorer/config_olmoearth_v1_1_nano_ps1.yaml ``` -This writes 128-dimensional embeddings at 10m/pixel to the `presto` layer. The -provided config uses Sentinel-2 only, matching the default explorer dataset. +Pass the corresponding output layer names to the app, such as +`embeddings_olmoearth_v1_1_base_ps4_ws64` or +`embeddings_olmoearth_v1_1_nano_ps1_ws16`. + +### 3d. Compute Presto embeddings (optional) + +Presto is supported in rslearn as `rslearn.models.presto.Presto`. The provided +config uses Sentinel-2 only, so `config.json` is enough. Materialize Sentinel-2 as +above and run: + +```bash +rslearn model predict --config data/embedding_explorer/config_presto.yaml +``` + +This writes 128-dimensional embeddings at 10m/pixel to the `embeddings_presto` +layer. The provided config uses Sentinel-2 only, matching the default explorer +dataset. Presto can also consume Sentinel-1 when the dataset has a compatible `s1` input. The first run may need to download or otherwise populate the Presto checkpoint cache used by `rslearn.models.presto.Presto`. -Tessera is not currently available as an rslearn model config in this project. -There is a Tessera eval wrapper in `olmoearth_pretrain`, but it has not been -adapted to rslearn's `FeatureExtractor` interface for `rslearn model predict`. +### 3e. Compute Tessera embeddings (optional) + +Tessera is supported in rslearn as `rslearn.models.tessera.tessera.Tessera`. Use +`config_with_tessera.json` instead of `config.json` when creating the dataset. +Tessera uses Sentinel-2 plus separate ascending and descending Sentinel-1 RTC +time series, so materialize all three source layers: + +```bash +cp data/embedding_explorer/config_with_tessera.json $DATASET_PATH/config.json +rslearn dataset prepare --root $DATASET_PATH --workers 32 \ + --enabled-layers sentinel2_l2a,sentinel1_ascending,sentinel1_descending \ + --retry-max-attempts 5 --retry-backoff-seconds 5 +rslearn dataset materialize --root $DATASET_PATH --workers 32 \ + --no-use-initial-job \ + --enabled-layers sentinel2_l2a,sentinel1_ascending,sentinel1_descending \ + --retry-max-attempts 5 --retry-backoff-seconds 5 +``` + +Download an encoder-only Tessera v1.1 checkpoint and point the model config at +it, then run prediction: + +```bash +export TESSERA_CHECKPOINT_PATH=/path/to/tessera_v1_1_mpc_encoder.pt +rslearn model predict --config data/embedding_explorer/config_tessera.yaml +``` + +This writes 128-dimensional Tessera embeddings at 10m/pixel to the +`embeddings_tessera` layer. The provided config converts the OlmoEarth Datasets +Sentinel-1 RTC layers to standard dB with `Sentinel1ToDecibels`, then applies +`TesseraNormalize(data_source="mpc")` before the model runs. If your Sentinel-1 +layers are already stored in standard dB, skip `Sentinel1ToDecibels` and still +run `TesseraNormalize` so the model receives checkpoint-ready normalized inputs. ## Run the app @@ -147,6 +222,7 @@ Then point it at your dataset: ```bash python -m rslp.embedding_explorer.app \ --dataset-path $DATASET_PATH \ + --embedding-layer embeddings_olmoearth_v1_base_ps4_ws64 \ --port 5000 ``` @@ -155,7 +231,7 @@ To load multiple embedding layers (e.g. OlmoEarth + AEF), pass them all: ```bash python -m rslp.embedding_explorer.app \ --dataset-path $DATASET_PATH \ - --embedding-layer embeddings aef \ + --embedding-layer embeddings_olmoearth_v1_base_ps4_ws64 aef \ --port 5000 ``` @@ -168,7 +244,16 @@ For example, to compare OlmoEarth and Presto: ```bash python -m rslp.embedding_explorer.app \ --dataset-path $DATASET_PATH \ - --embedding-layer embeddings presto \ + --embedding-layer embeddings_olmoearth_v1_base_ps4_ws64 embeddings_presto \ + --port 5000 +``` + +Or compare OlmoEarth and Tessera: + +```bash +python -m rslp.embedding_explorer.app \ + --dataset-path $DATASET_PATH \ + --embedding-layer embeddings_olmoearth_v1_base_ps4_ws64 embeddings_tessera \ --port 5000 ``` diff --git a/rslp/embedding_explorer/app.py b/rslp/embedding_explorer/app.py index a2c6ee79e..612156214 100644 --- a/rslp/embedding_explorer/app.py +++ b/rslp/embedding_explorer/app.py @@ -3,7 +3,7 @@ Usage: python -m rslp.embedding_explorer.app --dataset-path /path/to/dataset --port 5000 python -m rslp.embedding_explorer.app --dataset-path /path/to/dataset \ - --embedding-layer embeddings aef --port 5000 + --embedding-layer embeddings_olmoearth_v1_base_ps4_ws64 aef --port 5000 """ import argparse @@ -22,6 +22,7 @@ from sklearn.linear_model import LogisticRegression EPSG_3857 = "EPSG:3857" +DEFAULT_EMBEDDING_LAYERS = ["embeddings_olmoearth_v1_base_ps4_ws64"] def find_geotiff(layer_dir: Path) -> Path | None: @@ -372,7 +373,7 @@ def reproject_single_band_to_wm( def create_app(dataset_path: Path, embedding_layers: list[str] | None = None) -> Flask: """Create and configure the Flask application.""" if embedding_layers is None: - embedding_layers = ["embeddings"] + embedding_layers = DEFAULT_EMBEDDING_LAYERS app_dir = Path(__file__).parent app = Flask( __name__, @@ -593,8 +594,11 @@ def main() -> None: parser.add_argument( "--embedding-layer", nargs="+", - default=["embeddings"], - help="Embedding layer name(s) to load (default: 'embeddings')", + default=DEFAULT_EMBEDDING_LAYERS, + help=( + "Embedding layer name(s) to load " + f"(default: {', '.join(DEFAULT_EMBEDDING_LAYERS)})" + ), ) parser.add_argument("--port", type=int, default=5000) parser.add_argument("--host", default="0.0.0.0") diff --git a/rslp/embedding_explorer/config_with_presto.json b/rslp/embedding_explorer/config_with_presto.json deleted file mode 100644 index a07ff3621..000000000 --- a/rslp/embedding_explorer/config_with_presto.json +++ /dev/null @@ -1,58 +0,0 @@ -{ - "layers": { - "embeddings": { - "band_sets": [ - { - "dtype": "float32", - "num_bands": 768 - } - ], - "type": "raster" - }, - "presto": { - "band_sets": [ - { - "dtype": "float32", - "num_bands": 128 - } - ], - "type": "raster" - }, - "sentinel2_l2a": { - "band_sets": [ - { - "bands": [ - "B01", - "B02", - "B03", - "B04", - "B05", - "B06", - "B07", - "B08", - "B8A", - "B09", - "B11", - "B12" - ], - "dtype": "uint16" - } - ], - "data_source": { - "class_path": "rslearn.data_sources.planetary_computer.Sentinel2", - "ingest": false, - "init_args": { - "cache_dir": "cache/planetary_computer", - "harmonize": true, - "sort_by": "eo:cloud_cover" - }, - "query_config": { - "max_matches": 12, - "period_duration": "30d", - "space_mode": "PER_PERIOD_MOSAIC" - } - }, - "type": "raster" - } - } -} From f238064d01fd0a2cada8b4874986c6aa32c79aeb Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Tue, 16 Jun 2026 13:53:14 -0700 Subject: [PATCH 6/6] fix tessera --- data/embedding_explorer/config_tessera.yaml | 6 +++--- rslp/embedding_explorer/README.md | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/data/embedding_explorer/config_tessera.yaml b/data/embedding_explorer/config_tessera.yaml index 2106a27a8..f4135a9a2 100644 --- a/data/embedding_explorer/config_tessera.yaml +++ b/data/embedding_explorer/config_tessera.yaml @@ -5,7 +5,7 @@ model: class_path: rslearn.models.singletask.SingleTaskModel init_args: encoder: - - class_path: rslearn.models.tessera.tessera.Tessera + - class_path: rslearn.models.tessera.Tessera init_args: checkpoint_path: ${TESSERA_CHECKPOINT_PATH} pixel_batch_size: 1024 @@ -51,7 +51,7 @@ data: - class_path: rslearn.train.transforms.sentinel1.Sentinel1ToDecibels init_args: selectors: ["s1_ascending", "s1_descending"] - - class_path: rslearn.models.tessera.tessera.TesseraNormalize + - class_path: rslearn.models.tessera.TesseraNormalize init_args: data_source: mpc load_all_crops: true @@ -66,7 +66,7 @@ trainer: type: RASTER band_sets: - dtype: FLOAT32 - num_bands: 128 + num_bands: 192 merger: class_path: rslearn.train.prediction_writer.RasterMerger init_args: diff --git a/rslp/embedding_explorer/README.md b/rslp/embedding_explorer/README.md index 8304ae814..699be8fd7 100644 --- a/rslp/embedding_explorer/README.md +++ b/rslp/embedding_explorer/README.md @@ -177,7 +177,7 @@ cache used by `rslearn.models.presto.Presto`. ### 3e. Compute Tessera embeddings (optional) -Tessera is supported in rslearn as `rslearn.models.tessera.tessera.Tessera`. Use +Tessera is supported in rslearn as `rslearn.models.tessera.Tessera`. Use `config_with_tessera.json` instead of `config.json` when creating the dataset. Tessera uses Sentinel-2 plus separate ascending and descending Sentinel-1 RTC time series, so materialize all three source layers: @@ -201,7 +201,7 @@ export TESSERA_CHECKPOINT_PATH=/path/to/tessera_v1_1_mpc_encoder.pt rslearn model predict --config data/embedding_explorer/config_tessera.yaml ``` -This writes 128-dimensional Tessera embeddings at 10m/pixel to the +This writes 192-dimensional Tessera embeddings at 10m/pixel to the `embeddings_tessera` layer. The provided config converts the OlmoEarth Datasets Sentinel-1 RTC layers to standard dB with `Sentinel1ToDecibels`, then applies `TesseraNormalize(data_source="mpc")` before the model runs. If your Sentinel-1