diff --git a/README.md b/README.md index f2fdf5b4..32ad0a86 100644 --- a/README.md +++ b/README.md @@ -51,9 +51,9 @@ baselines: root: /store_new/mch/msopr/ml/COSMO-E steps: 0/120/6 -analysis: +truth: label: COSMO KENDA - analysis_zarr: /scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-6h-v3-pl13.zarr + root: /scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-6h-v3-pl13.zarr locations: output_root: output/ diff --git a/config/forecasters-co1e.yaml b/config/forecasters-co1e.yaml index 251e1d02..83846ec4 100644 --- a/config/forecasters-co1e.yaml +++ b/config/forecasters-co1e.yaml @@ -27,9 +27,9 @@ baselines: root: /store_new/mch/msopr/ml/COSMO-1E steps: 0/33/6 -analysis: +truth: label: COSMO KENDA - analysis_zarr: /scratch/mch/fzanetta/data/anemoi/datasets/mch-co1e-an-archive-0p01-2019-2024-1h-v1-pl13.zarr + root: /scratch/mch/fzanetta/data/anemoi/datasets/mch-co1e-an-archive-0p01-2019-2024-1h-v1-pl13.zarr stratification: regions: diff --git a/config/forecasters-co2-disentangled.yaml b/config/forecasters-co2-disentangled.yaml index 1826c9d0..c9a75f2c 100644 --- a/config/forecasters-co2-disentangled.yaml +++ b/config/forecasters-co2-disentangled.yaml @@ -48,9 +48,9 @@ baselines: root: /store_new/mch/msopr/ml/COSMO-E steps: 0/120/6 -analysis: +truth: label: COSMO KENDA - analysis_zarr: /scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-6h-v3-pl13.zarr + root: /scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-6h-v3-pl13.zarr stratification: regions: diff --git a/config/forecasters-co2.yaml b/config/forecasters-co2.yaml index d27c9f16..059ab97e 100644 --- a/config/forecasters-co2.yaml +++ b/config/forecasters-co2.yaml @@ -23,9 +23,9 @@ baselines: root: /store_new/mch/msopr/ml/COSMO-E steps: 0/120/6 -analysis: +truth: label: COSMO KENDA - analysis_zarr: /scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-6h-v3-pl13.zarr + root: 
/scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-6h-v3-pl13.zarr stratification: regions: diff --git a/config/forecasters-ich1-oper.yaml b/config/forecasters-ich1-oper.yaml index fc1ae591..6e5b011f 100644 --- a/config/forecasters-ich1-oper.yaml +++ b/config/forecasters-ich1-oper.yaml @@ -11,7 +11,6 @@ dates: - 2025-02-01T06:00 - 2025-03-01T12:00 - runs: - forecaster: checkpoint: https://servicedepl.meteoswiss.ch/mlstore#/experiments/409/runs/b30acf68520a4bbd8324c44666561696 @@ -26,13 +25,18 @@ runs: baselines: - baseline: baseline_id: ICON-CH1-EPS - label: ICON-CH1-EPS - root: /store_new/mch/msopr/ml/ICON-CH1-EPS + label: ICON-CH1-ctrl + root: /scratch/mch/cmerker/ICON-CH1-EPS steps: 0/33/6 + - baseline: + baseline_id: ICON-CH2-EPS + label: ICON-CH2-ctrl + root: /scratch/mch/cmerker/ICON-CH2-EPS + steps: 0/120/6 -analysis: +truth: label: KENDA-CH1 - analysis_zarr: /store_new/mch/msopr/ml/datasets/mch-ich1-1km-2024-2025-1h-pl13-v1.0.zarr + root: /store_new/mch/msopr/ml/datasets/mch-ich1-1km-2024-2025-1h-pl13-v1.0.zarr stratification: regions: diff --git a/config/forecasters-ich1.yaml b/config/forecasters-ich1.yaml index b4831717..a5de9b54 100644 --- a/config/forecasters-ich1.yaml +++ b/config/forecasters-ich1.yaml @@ -41,14 +41,14 @@ runs: baselines: - baseline: - baseline_id: ICON-CH1-EPS - label: ICON-CH1-EPS - root: /store_new/mch/msopr/ml/ICON-CH1-EPS - steps: 0/33/6 + baseline_id: ICON-CH2-EPS + label: ICON-CH2-EPS + root: /scratch/mch/cmerker/ICON-CH2-EPS + steps: 0/120/6 -analysis: - label: REA-L-CH1 - analysis_zarr: /store_new/mch/msopr/ml/datasets/mch-realch1-fdb-1km-2005-2025-1h-pl13-v1.0.zarr +truth: + label: KENDA-CH1 + root: /store_new/mch/msopr/ml/datasets/mch-ich1-1km-2024-2025-1h-pl13-v1.0.zarr stratification: regions: diff --git a/config/interpolators-co2.yaml b/config/interpolators-co2.yaml index 5830dbcb..115dd6bc 100644 --- a/config/interpolators-co2.yaml +++ b/config/interpolators-co2.yaml @@ -54,9 +54,9 @@ baselines: 
def _select_valid_times(ds, times: np.ndarray):
    """Select the requested valid times from ``ds`` along its ``time`` axis.

    Requested times that are missing from the dataset (e.g. lead times that
    run past the end of the archive) are dropped with a warning, as long as
    at least one requested time is available.

    Parameters
    ----------
    ds
        Object exposing ``ds.time.values`` and ``ds.sel(time=...)``.
    times
        Requested valid times; anything ``np.asarray`` can convert to
        ``datetime64[ns]``.

    Returns
    -------
    The result of ``ds.sel(time=...)`` for the requested times that exist.

    Raises
    ------
    ValueError
        If none of the requested times is present in the dataset.
    """
    times_np = np.asarray(times, dtype="datetime64[ns]")
    times_included = np.isin(times_np, ds.time.values)

    # Common case: everything requested is available.
    if times_included.all():
        return ds.sel(time=times_np)

    # Nothing usable at all: fail loudly and say what was asked for.
    if not times_included.any():
        raise ValueError(
            "Valid times are not included in the dataset. "
            "Please check the valid times and the dataset. "
            f"Requested valid times: {times_np}"
        )

    # Partial overlap: keep what exists and report what was dropped.
    LOG.warning(
        "Some valid times are not included in the dataset: \n%s",
        times_np[~times_included],
    )
    return ds.sel(time=times_np[times_included])
def parse_steps(steps: str) -> list[int]:
    """Expand a ``"start/stop/step"`` string into the list of lead times.

    The stop value is inclusive, e.g. ``"0/12/6"`` yields ``[0, 6, 12]``.

    Parameters
    ----------
    steps
        Three integers separated by ``/``.

    Returns
    -------
    list[int]
        ``start, start + step, ...`` up to and including ``stop`` when it is
        reachable.

    Raises
    ------
    ValueError
        If the string does not have exactly three ``/``-separated fields, a
        field is not an integer, or the increment is zero.
    """
    parts = steps.split("/")
    if len(parts) != 3:
        raise ValueError(f"Expected steps in format 'start/stop/step', got '{steps}'")
    try:
        start, end, step = (int(part) for part in parts)
    except ValueError:
        raise ValueError(
            f"Expected integer values in format 'start/stop/step', got '{steps}'"
        ) from None
    if step == 0:
        raise ValueError(f"Step increment must be non-zero, got '{steps}'")
    # Keep the stop value inclusive for ascending and descending ranges alike.
    inclusive_stop = end + 1 if step > 0 else end - 1
    return list(range(start, inclusive_stop, step))
at the end) - times_included = times.isin(ds.time.values).values - if all(times_included): - ds = ds.sel(time=times) - elif np.sum(times_included) < len(times_included): - LOG.warning( - "Some valid times are not included in the dataset: \n%s", - times[~times_included].values, - ) - ds = ds.sel(time=times[times_included]) - else: - raise ValueError( - "Valid times are not included in the dataset. " - "Please check the valid times and the dataset." - ) - return ds + times = np.datetime64(reftime) + np.asarray(steps, dtype="timedelta64[h]") + return _select_valid_times(ds, times) def load_fct_data_from_grib( - grib_output_dir: Path, reftime: datetime, steps: list[int], params: list[str] + root: Path, reftime: datetime, steps: list[int], params: list[str] ) -> xr.Dataset: """Load forecast data from GRIB files for a specific valid time.""" - files = sorted(grib_output_dir.glob("20*.grib")) + files = sorted(root.glob(f"{reftime:%Y%m%d%H%M}*.grib")) fds = data_source.FileDataSource(datafiles=files) ds = grib_decoder.load(fds, {"param": params, "step": steps}) for var, da in ds.items(): @@ -127,13 +140,13 @@ def load_fct_data_from_grib( def load_baseline_from_zarr( - zarr_path: Path, reftime: datetime, steps: list[int], params: list[str] + root: Path, reftime: datetime, steps: list[int], params: list[str] ) -> xr.Dataset: """Load forecast data from a Zarr dataset.""" try: - baseline = xr.open_zarr(zarr_path, consolidated=True, decode_timedelta=True) + baseline = xr.open_zarr(root, consolidated=True, decode_timedelta=True) except ValueError: - raise ValueError(f"Could not open baseline zarr at {zarr_path}") + raise ValueError(f"Could not open baseline zarr at {root}") baseline = baseline.rename( {"forecast_reference_time": "ref_time", "step": "lead_time"} @@ -156,14 +169,116 @@ def load_baseline_from_zarr( lead_time=np.array(steps, dtype="timedelta64[h]"), ) baseline = baseline.assign_coords(time=baseline.ref_time + baseline.lead_time) + if "latitude" in baseline.coords 
def load_obs_data_from_peakweather(
    root, reftime: datetime, steps: list[int], params: list[str], freq: str = "1h"
) -> xr.Dataset:
    """Load PeakWeather station observations into an xarray Dataset.

    Returns a Dataset with dimensions `time` and `values`, values coordinates
    (`lat`, `lon`), and variables renamed to ICON parameter names.
    Temperatures are converted to Kelvin when present.

    Parameters
    ----------
    root
        Root directory handed to ``PeakWeatherDataset``.
    reftime
        Forecast reference time; observations are requested from here on.
    steps
        Lead times in hours; must not be empty.
    params
        Requested parameter names; only ``T_2M``, ``U_10M`` and ``V_10M``
        are mapped to PeakWeather parameters.
    freq
        Sampling frequency forwarded to ``PeakWeatherDataset``.

    Raises
    ------
    ValueError
        If ``steps`` is empty, or if none of the valid times is available.
    """
    from peakweather.dataset import PeakWeatherDataset

    if not steps:
        raise ValueError("steps must contain at least one lead time")

    param_names = {
        "temperature": "T_2M",
        "wind_u": "U_10M",
        "wind_v": "V_10M",
    }
    param_names = {k: v for k, v in param_names.items() if v in params}

    start = reftime
    end = start + timedelta(hours=max(steps))
    if len(steps) > 1:
        end += timedelta(hours=steps[-1] - steps[-2])  # extend by 1 extra step
    # Request every calendar year the window touches, not only the two ends.
    years = list(range(start.year, end.year + 1))
    pw = PeakWeatherDataset(root=root, years=years, freq=freq)
    ds, mask = pw.get_observations(
        parameters=list(param_names),
        first_date=f"{start:%Y-%m-%d %H:%M}",
        last_date=f"{end:%Y-%m-%d %H:%M}",
        return_mask=True,
    )
    ds = (
        ds.stack(["nat_abbr", "name"], future_stack=True)
        .to_xarray()
        .to_dataset(dim="name")
    )
    mask = (
        mask.stack(["nat_abbr", "name"], future_stack=True)
        .to_xarray()
        .to_dataset(dim="name")
    )
    ds = ds.where(mask)
    ds = ds.rename({"datetime": "time", "nat_abbr": "values"})
    ds = ds.rename(param_names)
    # Verification works with naive UTC timestamps.
    ds = ds.assign_coords(time=ds.indexes["time"].tz_convert("UTC").tz_localize(None))
    ds = ds.assign_coords(values=ds.indexes["values"])
    # NOTE(review): coordinates are attached positionally; this assumes
    # pw.stations_table has the same stations in the same order as the
    # "values" index -- TODO confirm against the PeakWeather API.
    ds = ds.assign_coords(lon=("values", pw.stations_table["longitude"]))
    ds = ds.assign_coords(lat=("values", pw.stations_table["latitude"]))
    if "T_2M" in ds:
        ds["T_2M"] = ds["T_2M"] + 273.15  # convert to Kelvin
    ds = ds.dropna("values", how="all")

    times = np.datetime64(reftime) + np.asarray(steps, dtype="timedelta64[h]")
    return _select_valid_times(ds, times)
def load_truth_data(
    root, reftime: datetime, steps: list[int], params: list[str]
) -> xr.Dataset:
    """Load truth data from analysis Zarr or PeakWeather observations.

    Parameters
    ----------
    root
        Location of the truth data, as a ``Path`` or a plain string. A path
        ending in ``.zarr`` is read as an analysis dataset; otherwise a path
        containing ``peakweather`` is read as station observations.
    reftime
        Forecast reference time.
    steps
        Lead times in hours.
    params
        Parameter names to load.

    Raises
    ------
    ValueError
        If ``root`` matches neither supported kind of truth data.
    """
    # Accept plain strings (e.g. straight from the config) as well as Paths.
    root = Path(root)

    if root.suffix == ".zarr":
        LOG.info("Loading ground truth from an analysis zarr dataset...")
        truth = load_analysis_data_from_zarr(
            root=root,
            reftime=reftime,
            steps=steps,
            params=params,
        )
        # Keep each horizontal field in a single chunk once it is in memory.
        spatial_chunks = (
            {"y": -1, "x": -1}
            if "y" in truth.dims and "x" in truth.dims
            else {"values": -1}
        )
        return truth.compute().chunk(spatial_chunks)

    if "peakweather" in str(root):
        LOG.info("Loading ground truth from PeakWeather observations...")
        return load_obs_data_from_peakweather(
            root=root,
            reftime=reftime,
            steps=steps,
            params=params,
        )

    raise ValueError(f"Unsupported truth root: {root}")


def load_forecast_data(
    root, reftime: datetime, steps: list[int], params: list[str]
) -> xr.Dataset:
    """Load forecast data from GRIB files or a baseline Zarr dataset.

    ``root`` (a ``Path`` or a plain string) is treated as a GRIB output
    directory when it directly contains at least one ``*.grib`` file;
    otherwise it is opened as a baseline Zarr dataset.
    """
    # Accept plain strings as well as Paths.
    root = Path(root)

    if any(root.glob("*.grib")):
        LOG.info("Loading forecasts from GRIB files...")
        loader = load_fct_data_from_grib
    else:
        LOG.info("Loading baseline forecasts from zarr dataset...")
        loader = load_baseline_from_zarr

    return loader(root=root, reftime=reftime, steps=steps, params=params)
class TruthConfig(BaseModel):
    """Configuration for the truth data used in the verification."""

    # Required, non-empty display name of the truth source.
    label: str = Field(
        ...,
        min_length=1,
        description="Label that will be used in experiment results such as reports and figures.",
    )
    # Required, non-empty location of the truth dataset.
    # NOTE(review): presumably either an analysis ``.zarr`` store or a
    # PeakWeather directory -- confirm against the truth loader.
    root: str = Field(
        ...,
        min_length=1,
        description="Path to the root of the dataset.",
    )
""" - dim = ["x", "y"] if "x" in data.dims and "y" in data.dims else ["values"] stats = xr.Dataset( { f"{prefix}mean{suffix}": data.mean(dim=dim, skipna=True), @@ -146,6 +146,7 @@ def verify( fcst_label: str, obs_label: str, regions: list[str] | None = None, + dim: list[str] | None = None, ) -> xr.Dataset: """ Compare two xarray Datasets (fcst and obs) and return pandas DataFrame with @@ -153,15 +154,21 @@ def verify( """ start = time.time() + if dim is None: + if "x" in fcst.dims and "y" in fcst.dims: + dim = ["x", "y"] + elif "values" in fcst.dims: + dim = ["values"] + else: + dim = ["values"] + # rewrite the verification to use dask and xarray # chunk the data to avoid memory issues # compute the metrics in parallel # return the results as a xarray Dataset fcst_aligned, obs_aligned = xr.align(fcst, obs, join="inner", copy=False) region_polygons = ShapefileSpatialAggregationMasks(shp=regions) - masks = region_polygons.get_masks( - lon=obs_aligned["longitude"], lat=obs_aligned["latitude"] - ) + masks = region_polygons.get_masks(lon=obs_aligned["lon"], lat=obs_aligned["lat"]) scores = [] statistics = [] @@ -180,19 +187,29 @@ def verify( # scores vs time (reduce spatially) score.append( _compute_scores( - fcst_param, obs_param, prefix=param + ".", source=fcst_label + fcst_param, + obs_param, + prefix=param + ".", + source=fcst_label, + dim=dim, ).expand_dims(region=[region]) ) # statistics vs time (reduce spatially) fcst_statistics.append( _compute_statistics( - fcst_param, prefix=param + ".", source=fcst_label + fcst_param, + prefix=param + ".", + source=fcst_label, + dim=dim, ).expand_dims(region=[region]) ) obs_statistics.append( _compute_statistics( - obs_param, prefix=param + ".", source=obs_label + obs_param, + prefix=param + ".", + source=obs_label, + dim=dim, ).expand_dims(region=[region]) ) diff --git a/src/verification/spatial.py b/src/verification/spatial.py new file mode 100644 index 00000000..a5186d5f --- /dev/null +++ b/src/verification/spatial.py @@ 
def _unit_sphere_xyz(lat_deg: np.ndarray, lon_deg: np.ndarray) -> np.ndarray:
    """Project latitude/longitude in degrees onto the unit sphere as (n, 3) xyz."""
    lat = np.deg2rad(lat_deg)
    lon = np.deg2rad(lon_deg)
    cos_lat = np.cos(lat)
    return np.column_stack((cos_lat * np.cos(lon), cos_lat * np.sin(lon), np.sin(lat)))


def spherical_nearest_neighbor_indices(
    source_lat: np.ndarray,
    source_lon: np.ndarray,
    target_lat: np.ndarray,
    target_lon: np.ndarray,
) -> np.ndarray:
    """Return indices of nearest source points for each target point.

    Distances are computed in 3D Cartesian space after projecting latitude and
    longitude (degrees) onto the unit sphere. This avoids distortions from
    Euclidean distance in degree space (e.g. across the date line).

    Parameters
    ----------
    source_lat, source_lon
        Latitude and longitude of source points in degrees. Inputs of any
        shape are flattened.
    target_lat, target_lon
        Latitude and longitude of target points in degrees. Inputs of any
        shape are flattened.

    Returns
    -------
    np.ndarray
        Integer indices into the flattened source points, one index per
        target point.

    Raises
    ------
    ValueError
        If there are no source points, or if a latitude array and its
        longitude array differ in size.
    """
    source_lat = np.asarray(source_lat).ravel()
    source_lon = np.asarray(source_lon).ravel()
    target_lat = np.asarray(target_lat).ravel()
    target_lon = np.asarray(target_lon).ravel()

    if source_lat.shape != source_lon.shape:
        raise ValueError("source_lat and source_lon must have the same number of points")
    if target_lat.shape != target_lon.shape:
        raise ValueError("target_lat and target_lon must have the same number of points")
    # Without this guard an empty tree would silently report the
    # out-of-range "missing neighbour" index for every target.
    if source_lat.size == 0:
        raise ValueError("At least one source point is required")

    tree = cKDTree(_unit_sphere_xyz(source_lat, source_lon))
    _, nearest_idx = tree.query(_unit_sphere_xyz(target_lat, target_lon), k=1)
    return np.asarray(nearest_idx, dtype=int)
def nearest_grid_yx_indices(
    grid: xr.Dataset | xr.DataArray, target_lat: np.ndarray, target_lon: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """Find nearest `(y, x)` grid indices for target coordinates.

    Parameters
    ----------
    grid
        Dataset or DataArray with `lat` and `lon` coordinates defined on a
        `(y, x)` grid.
    target_lat, target_lon
        Target coordinates in degrees.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        Arrays of `y` and `x` indices for each target location.

    Raises
    ------
    ValueError
        If `lat`/`lon` are missing, are not 2D, or differ in shape.
    """
    # Look the coordinates up by key rather than testing ``"lat" in grid``:
    # membership on a DataArray checks the array *values*, not coordinate
    # names, which would reject valid DataArray input.
    try:
        lat2d = np.asarray(grid["lat"].values)
        lon2d = np.asarray(grid["lon"].values)
    except KeyError:
        raise ValueError("Input must provide 'lat' and 'lon' coordinates") from None

    if lat2d.ndim != 2 or lon2d.ndim != 2:
        raise ValueError("'lat' and 'lon' must be 2D on (y, x) for y/x indexing")
    if lat2d.shape != lon2d.shape:
        raise ValueError("'lat' and 'lon' must have the same (y, x) shape")

    flat_idx = spherical_nearest_neighbor_indices(
        source_lat=lat2d.ravel(),
        source_lon=lon2d.ravel(),
        target_lat=target_lat,
        target_lon=target_lon,
    )
    # Convert positions in the flattened grid back to (row, column) pairs.
    y_idx, x_idx = np.unravel_index(flat_idx, lat2d.shape)
    return np.asarray(y_idx, dtype=int), np.asarray(x_idx, dtype=int)
def map_forecast_to_truth(fcst: xr.Dataset, truth: xr.Dataset) -> xr.Dataset:
    """Sample the forecast at the point nearest to each truth location.

    A forecast given on `(y, x)` is first flattened to a single `values`
    dimension. For every truth location the nearest forecast point is picked,
    and the result takes over the truth's `lat`, `lon` and `values`
    coordinates so that later alignment against the truth works naturally.
    When the truth is gridded, the result is unstacked back to `(y, x)`.

    Parameters
    ----------
    fcst
        Forecast dataset with `lat` and `lon` coordinates on either `(y, x)`
        or `values`.
    truth
        Reference dataset with `lat` and `lon` coordinates on either `(y, x)`
        or `values`.

    Returns
    -------
    xr.Dataset
        Mapped forecast dataset.
    """
    # TODO: return fcst unchanged when forecast and truth are already aligned

    def _on_grid(ds):
        return "y" in ds.dims and "x" in ds.dims

    truth_is_grid = _on_grid(truth)

    if _on_grid(fcst):
        fcst = fcst.stack(values=("y", "x"))
    if truth_is_grid:
        truth = truth.stack(values=("y", "x"))

    nearest_idx = spherical_nearest_neighbor_indices(
        source_lat=fcst["lat"].values,
        source_lon=fcst["lon"].values,
        target_lat=truth["lat"].values,
        target_lon=truth["lon"].values,
    )

    # Pick the matched forecast points and drop the forecast's own spatial
    # labels before taking over the truth's.
    sampled = fcst.isel(values=nearest_idx).drop_vars(
        ["x", "y", "values"], errors="ignore"
    )
    sampled = sampled.assign_coords(lon=("values", truth.lon.data))
    sampled = sampled.assign_coords(lat=("values", truth.lat.data))
    sampled = sampled.assign_coords(values=truth["values"])

    return sampled.unstack("values") if truth_is_grid else sampled
def test_map_forecast_to_truth_maps_forecast_to_truth_locations():
    """A gridded forecast is sampled at the grid point nearest each station."""
    fcst_time = np.array(
        ["2024-01-01T00:00", "2024-01-01T01:00"], dtype="datetime64[ns]"
    )
    truth_time = np.array(
        ["2024-01-01T00:00", "2024-01-01T01:00", "2024-01-01T02:00"],
        dtype="datetime64[ns]",
    )

    # 2x2 forecast grid with two time steps.
    grid_lat = np.array([[46.0, 46.0], [47.0, 47.0]])
    grid_lon = np.array([[7.0, 8.0], [7.0, 8.0]])
    fcst_values = np.array(
        [
            [[1.0, 2.0], [3.0, 4.0]],
            [[10.0, 20.0], [30.0, 40.0]],
        ]
    )
    fcst = xr.Dataset(
        data_vars={"T_2M": (("time", "y", "x"), fcst_values)},
        coords={
            "time": fcst_time,
            "y": [0, 1],
            "x": [0, 1],
            "lat": (("y", "x"), grid_lat),
            "lon": (("y", "x"), grid_lon),
        },
    )

    # Two stations, closest to grid points (0, 0) and (1, 1) respectively.
    station_ids = np.array(["STA1", "STA2"])
    station_lat = np.array([46.1, 46.9])
    station_lon = np.array([7.1, 7.9])
    truth = xr.Dataset(
        data_vars={"T_2M": (("time", "values"), np.zeros((3, 2)))},
        coords={
            "time": truth_time,
            "values": ["STA1", "STA2"],
            "lat": ("values", station_lat),
            "lon": ("values", station_lon),
        },
    )

    mapped_fcst = map_forecast_to_truth(fcst, truth)

    assert mapped_fcst["T_2M"].dims == ("time", "values")
    assert np.array_equal(mapped_fcst["time"].values, fcst_time)
    assert np.array_equal(mapped_fcst["values"].values, station_ids)
    assert np.allclose(mapped_fcst["lat"].values, station_lat)
    assert np.allclose(mapped_fcst["lon"].values, station_lon)
    expected = np.array([[1.0, 4.0], [10.0, 40.0]])
    assert np.allclose(mapped_fcst["T_2M"].values, expected)
def test_map_forecast_to_truth_restores_grid_when_truth_is_gridded():
    """With a gridded truth the mapped forecast comes back on `(y, x)`."""
    fcst_time = np.array(["2024-01-01T00:00"], dtype="datetime64[ns]")
    grid_index = [0, 1]

    fcst_values = np.array([[[1.0, 2.0], [3.0, 4.0]]])
    fcst = xr.Dataset(
        data_vars={"T_2M": (("time", "y", "x"), fcst_values)},
        coords={
            "time": fcst_time,
            "y": grid_index,
            "x": grid_index,
            "lat": (("y", "x"), np.array([[46.0, 46.0], [47.0, 47.0]])),
            "lon": (("y", "x"), np.array([[7.0, 8.0], [7.0, 8.0]])),
        },
    )

    # Truth grid is slightly shifted, so each cell maps onto the forecast
    # cell with the same (y, x) position.
    truth = xr.Dataset(
        data_vars={"T_2M": (("time", "y", "x"), np.zeros((1, 2, 2)))},
        coords={
            "time": fcst_time,
            "y": grid_index,
            "x": grid_index,
            "lat": (("y", "x"), np.array([[46.1, 46.1], [46.9, 46.9]])),
            "lon": (("y", "x"), np.array([[7.1, 7.9], [7.1, 7.9]])),
        },
    )

    mapped_fcst = map_forecast_to_truth(fcst, truth)

    assert mapped_fcst["T_2M"].dims == ("time", "y", "x")
    assert np.array_equal(mapped_fcst["y"].values, np.array([0, 1]))
    assert np.array_equal(mapped_fcst["x"].values, np.array([0, 1]))
    assert np.allclose(mapped_fcst["lat"].values, truth["lat"].values)
    assert np.allclose(mapped_fcst["lon"].values, truth["lon"].values)
    assert np.allclose(mapped_fcst["T_2M"].values, fcst_values)
def _get_available_baselines(wc) -> list[dict[str, str]]:
    """Collect every configured baseline whose zarr store exists on disk.

    For each entry in ``BASELINE_CONFIGS`` the store ``<root>/FCST<yy>.zarr``
    is checked, where ``<yy>`` is ``wc.init_time[2:4]``. Each hit is reported
    as a dict with the keys ``zarr``, ``steps`` and ``label`` (the label
    falls back to the baseline id).

    Raises
    ------
    ValueError
        If no configured baseline has a store for this init time.
    """
    # Presumably the two-digit year of a YYYYmmddHHMM init time -- confirm.
    year = wc.init_time[2:4]

    found = []
    for baseline_id in BASELINE_CONFIGS:
        cfg = BASELINE_CONFIGS[baseline_id]
        root = cfg.get("root")
        baseline_zarr = f"{root}/FCST{year}.zarr"
        if not Path(baseline_zarr).exists():
            continue
        found.append(
            {
                "zarr": baseline_zarr,
                "steps": cfg.get("steps"),
                "label": cfg.get("label", baseline_id),
            }
        )

    if found:
        return found
    raise ValueError(f"No baseline zarr found for init time {wc.init_time}")
{input.baseline_zarr} --baseline_steps {params.baseline_steps} \ - --peakweather {input.peakweather_dir} \ - --date {wildcards.init_time} --outfn {output[0]} \ - --param {wildcards.param} --station {wildcards.sta} + + BASELINE_ZARRS=({params.baseline_zarrs:q}) + BASELINE_STEPS=({params.baseline_steps:q}) + BASELINE_LABELS=({params.baseline_labels:q}) + + CMD_ARGS=( + --forecast {params.fcst_grib:q} + --forecast_steps {params.fcst_steps:q} + --forecast_label {params.fcst_label:q} + --analysis {input.truth:q} + --analysis_label {params.ana_label:q} + --peakweather {input.peakweather_dir:q} + --date {wildcards.init_time:q} + --outfn {output[0]:q} + --param {wildcards.param:q} + --station {wildcards.sta:q} + ) + + for i in "${{!BASELINE_ZARRS[@]}}"; do + CMD_ARGS+=(--baseline "${{BASELINE_ZARRS[$i]}}") + CMD_ARGS+=(--baseline_steps "${{BASELINE_STEPS[$i]}}") + CMD_ARGS+=(--baseline_label "${{BASELINE_LABELS[$i]}}") + done + + python {input.script} "${{CMD_ARGS[@]}}" # interactive editing (needs to set localrule: True and use only one core) - # marimo edit {input.script} -- \ - # --forecast {params.grib_out_dir} --analysis {input.analysis_zarr} \ - # --baseline {input.baseline_zarr} --peakweather {input.peakweather_dir} \ - # --date {wildcards.init_time} --outfn {output[0]} \ - # --param {wildcards.param} --station {wildcards.sta} + # marimo edit {input.script} -- "${{CMD_ARGS[@]}}" """ diff --git a/workflow/rules/report.smk b/workflow/rules/report.smk index 7122e60a..b0acc44b 100644 --- a/workflow/rules/report.smk +++ b/workflow/rules/report.smk @@ -9,10 +9,10 @@ include: "common.smk" def make_header_text(): dates = config["dates"] - analysis = config["analysis"]["label"] + truth = config["truth"]["label"] if isinstance(dates, list): return f"Explicit initializations from {len(dates)} runs have been used." 
def make_header_text():
    """Build the one-line header describing what the report verifies.

    Reads ``config["dates"]`` and ``config["truth"]["label"]``. A list of
    dates is summarised by its length; otherwise the ``start``, ``end`` and
    ``frequency`` entries are reported together with the truth label.
    """
    dates = config["dates"]
    truth = config["truth"]["label"]
    if not isinstance(dates, list):
        return f"Verification against {truth} with initializations from {dates.get('start')} to {dates.get('end')} by {dates.get('frequency')}"
    return f"Explicit initializations from {len(dates)} runs have been used."
analysis_label=config["analysis"].get("label"), + truth_label=config["truth"]["label"], regions=REGIONS, grib_out_dir=lambda wc: ( Path(OUT_ROOT) / f"data/runs/{wc.run_id}/{wc.init_time}/grib" @@ -85,11 +85,11 @@ rule verif_metrics: """ uv run {input.script} \ --forecast {params.grib_out_dir} \ - --analysis_zarr {input.analysis_zarr} \ + --truth {input.truth} \ --reftime {wildcards.init_time} \ --steps "{params.fcst_steps}" \ --label "{params.fcst_label}" \ - --analysis_label "{params.analysis_label}" \ + --truth_label "{params.truth_label}" \ --regions "{params.regions}" \ --output {output} > {log} 2>&1 """ diff --git a/workflow/scripts/plot_meteogram.mo.py b/workflow/scripts/plot_meteogram.mo.py index 8abb7c54..c26d2c24 100644 --- a/workflow/scripts/plot_meteogram.mo.py +++ b/workflow/scripts/plot_meteogram.mo.py @@ -1,59 +1,91 @@ import marimo -__generated_with = "0.16.5" +__generated_with = "0.19.6" app = marimo.App(width="medium") @app.cell def _(): from argparse import ArgumentParser + from datetime import datetime from pathlib import Path import matplotlib.pyplot as plt import numpy as np - import xarray as xr - from meteodatalab import data_source, grib_decoder from peakweather import PeakWeatherDataset from data_input import ( - load_analysis_data_from_zarr, - load_baseline_from_zarr, parse_steps, + load_forecast_data, + load_truth_data, ) + from verification.spatial import map_forecast_to_truth return ( ArgumentParser, Path, PeakWeatherDataset, - data_source, - grib_decoder, - load_analysis_data_from_zarr, - load_baseline_from_zarr, - parse_steps, + datetime, + load_forecast_data, + load_truth_data, + map_forecast_to_truth, np, + parse_steps, plt, - xr, ) @app.cell -def _(ArgumentParser, Path, parse_steps): +def _(ArgumentParser, Path, datetime, parse_steps): parser = ArgumentParser() parser.add_argument( "--forecast", type=str, default=None, help="Directory to forecast grib data" ) parser.add_argument( - "--analysis", type=str, default=None, help="Path 
to analysis zarr data" + "--forecast_steps", + type=parse_steps, + default="0/120/6", + help="Forecast steps in the format 'start/stop/step' (default: 0/120/6).", ) parser.add_argument( - "--baseline", type=str, default=None, help="Path to baseline zarr data" + "--forecast_label", + type=str, + default="forecast", + help="Label for forecast line in plot legend.", + ) + parser.add_argument( + "--baseline", + action="append", + type=str, + default=[], + help="Path to baseline zarr data (repeatable).", ) parser.add_argument( "--baseline_steps", + action="append", type=parse_steps, - default="0/120/6", - help="Forecast steps in the format 'start/stop/step' (default: 0/120/6).", + default=[], + help=( + "Forecast steps in the format 'start/stop/step' for each baseline " + "(repeatable, must match --baseline count)." + ), + ) + parser.add_argument( + "--baseline_label", + action="append", + type=str, + default=[], + help="Label for each baseline line in plot legend (repeatable).", + ) + parser.add_argument( + "--analysis", type=str, default=None, help="Path to analysis zarr data" + ) + parser.add_argument( + "--analysis_label", + type=str, + default="truth", + help="Label for analysis line in plot legend.", ) parser.add_argument( "--peakweather", type=str, default=None, help="Path to PeakWeather dataset" @@ -64,25 +96,43 @@ def _(ArgumentParser, Path, parse_steps): parser.add_argument("--station", type=str, help="station") args = parser.parse_args() - grib_dir = Path(args.forecast) - zarr_dir_ana = Path(args.analysis) - zarr_dir_base = Path(args.baseline) + forecast_grib_dir = Path(args.forecast) + forecast_steps = args.forecast_steps + forecast_label = args.forecast_label + analysis_zarr = Path(args.analysis) + analysis_label = args.analysis_label + baseline_zarrs = [Path(path) for path in args.baseline] baseline_steps = args.baseline_steps + baseline_labels = args.baseline_label + if len(baseline_zarrs) != len(baseline_steps): + raise ValueError( + "Mismatched baseline 
arguments: --baseline and --baseline_steps " + "must be provided the same number of times." + ) + if len(baseline_labels) != len(baseline_zarrs): + raise ValueError( + "Mismatched baseline arguments: --baseline and --baseline_label " + "must be provided the same number of times." + ) peakweather_dir = Path(args.peakweather) - init_time = args.date + init_time = datetime.strptime(args.date, "%Y%m%d%H%M") outfn = Path(args.outfn) station = args.station param = args.param return ( - grib_dir, + analysis_label, + analysis_zarr, + baseline_labels, + baseline_steps, + baseline_zarrs, + forecast_label, + forecast_steps, + forecast_grib_dir, init_time, outfn, param, peakweather_dir, station, - zarr_dir_ana, - zarr_dir_base, - baseline_steps, ) @@ -113,25 +163,23 @@ def preprocess_ds(ds, param: str): "name": '"Wind speed', } ds = ds.drop_vars(["U", "V"]) - return ds + return ds.squeeze() return (preprocess_ds,) @app.cell -def load_grib_data( - data_source, +def load_data( + analysis_zarr, baseline_steps, - grib_decoder, - grib_dir, + baseline_zarrs, + forecast_steps, + forecast_grib_dir, init_time, - load_analysis_data_from_zarr, - load_baseline_from_zarr, + load_forecast_data, + load_truth_data, param, preprocess_ds, - xr, - zarr_dir_ana, - zarr_dir_base, ): if param == "SP_10M": paramlist = ["U_10M", "V_10M"] @@ -140,162 +188,95 @@ def load_grib_data( else: paramlist = [param] - grib_files = sorted(grib_dir.glob(f"{init_time}*.grib")) - fds = data_source.FileDataSource(datafiles=grib_files) - ds_fct = xr.Dataset(grib_decoder.load(fds, {"param": paramlist})) - ds_fct = preprocess_ds(ds_fct, param) - da_fct = ds_fct[param].squeeze() - - ds_ana = load_analysis_data_from_zarr(zarr_dir_ana, da_fct.valid_time, paramlist) - ds_ana = preprocess_ds(ds_ana, param) - da_ana = ds_ana[param].squeeze() - - ds_base = load_baseline_from_zarr( - zarr_dir_base, da_fct.ref_time, baseline_steps, paramlist + forecast_ds = load_forecast_data( + forecast_grib_dir, init_time, forecast_steps, 
paramlist ) - ds_base = preprocess_ds(ds_base, param) - da_base = ds_base[param].squeeze() - return da_ana, da_base, da_fct + forecast_ds = preprocess_ds(forecast_ds, param) + steps = forecast_ds.lead_time.dt.total_seconds().values / 3600 + analysis_ds = load_truth_data(analysis_zarr, init_time, steps, paramlist) + analysis_ds = preprocess_ds(analysis_ds, param) -@app.cell -def _(PeakWeatherDataset, da_fct, np, param, peakweather_dir, station): - if param == "T_2M": - parameter = "temperature" - offset = 273.15 # K to C - elif param == "SP_10M": - parameter = "wind_speed" - offset = 0 - elif param == "TOT_PREC": - parameter = "precipitation" - offset = 0 - else: - raise NotImplementedError( - f"The mapping for {param=} to PeakWeather is not implemented" + baseline_ds_list = [ + preprocess_ds( + load_forecast_data(zarr, init_time, step, paramlist), + param, ) + for zarr, step in zip(baseline_zarrs, baseline_steps) + ] - peakweather = PeakWeatherDataset(root=peakweather_dir, freq="1h") - obs, mask = peakweather.get_observations( - parameters=[parameter], - stations=station, - first_date=np.datetime_as_string(da_fct.valid_time.values[0]), - last_date=np.datetime_as_string(da_fct.valid_time.values[-1]), - return_mask=True, - ) - obs = obs.loc[:, mask.iloc[0]].droplevel("name", axis=1) - obs - return obs, offset, peakweather + return analysis_ds, baseline_ds_list, forecast_ds @app.cell -def _(peakweather): +def _(PeakWeatherDataset, peakweather_dir, station): + peakweather = PeakWeatherDataset(root=peakweather_dir) stations = peakweather.stations_table - stations.index.names = ["station"] - stations - return (stations,) + stations.index.names = ["values"] + station_ds = stations.to_xarray().sel(values=[station]) # keep singleton dim + station_ds = station_ds.rename({"latitude": "lat", "longitude": "lon"}) + station_ds = station_ds.set_coords(("lat", "lon", "station_name")) + station_ds = station_ds.drop_vars(list(station_ds.data_vars)) + station_ds + return 
(station_ds,) @app.cell -def _(da_ana, da_base, da_fct, np, stations): - def nearest_indexers_euclid(ds, lat_s, lon_s): - """ - Return a dict of indexers usable as: ds.isel(**indexers) - - Examples: - - 2D structured grid -> {"y": y_idx, "x": x_idx} - - 1D unstructured grid -> {"point": i_idx} (or {"cell": i_idx}, etc.) - """ - try: - lat = ds["lat"] - lon = ds["lon"] - except KeyError: - lat = ds["latitude"] - lon = ds["longitude"] - - dist = (lat - lat_s) ** 2 + (lon - lon_s) ** 2 - arr = dist.values - - flat_idx = int(np.nanargmin(arr)) - - if dist.ndim == 1: - return {dist.dims[0]: flat_idx} - - unr = np.unravel_index(flat_idx, dist.shape) - return {dim: int(i) for dim, i in zip(dist.dims, unr)} - - def get_idx_row(row, da): - return nearest_indexers_euclid(da, row["latitude"], row["longitude"]) - - # store dicts (indexers) in columns - sta_idxs = stations.copy() - sta_idxs["fct_isel"] = sta_idxs.apply(lambda r: get_idx_row(r, da_fct), axis=1) - sta_idxs["ana_isel"] = sta_idxs.apply(lambda r: get_idx_row(r, da_ana), axis=1) - sta_idxs["base_isel"] = sta_idxs.apply(lambda r: get_idx_row(r, da_base), axis=1) - sta_idxs - return (sta_idxs,) +def _(analysis_ds, baseline_ds_list, forecast_ds, station_ds, map_forecast_to_truth): + forecast_station_ds = map_forecast_to_truth(forecast_ds, station_ds) + analysis_station_ds = map_forecast_to_truth(analysis_ds, station_ds) + baseline_station_ds_list = [ + map_forecast_to_truth(ds, station_ds) for ds in baseline_ds_list + ] + return analysis_station_ds, baseline_station_ds_list, forecast_station_ds @app.cell def _( - da_ana, - da_base, - da_fct, + analysis_label, + baseline_labels, + analysis_station_ds, + baseline_station_ds_list, + forecast_label, + forecast_ds, + forecast_station_ds, init_time, - obs, - offset, outfn, + param, plt, - sta_idxs, station, ): - # station indices - row = sta_idxs.loc[station] - fct_isel = row.fct_isel - ana_isel = row.ana_isel - base_isel = row.base_isel - fig, ax = plt.subplots() - # 
station + # truth ax.plot( - obs.index.to_pydatetime(), - obs.to_numpy() + offset, + analysis_station_ds["time"].values, + analysis_station_ds[param].values, color="k", ls="--", - label=station, - ) - - # analysis - ana2plot = da_ana.isel(**ana_isel) - ax.plot( - ana2plot["time"].values, - ana2plot.values, - color="k", - ls="-", - label="analysis", - ) - - # baseline - base2plot = da_base.isel(**base_isel) - ax.plot( - base2plot["time"].values, - base2plot.values, - color="C1", - label="baseline", + label=analysis_label, ) - + # baselines + for i, (baseline_label, baseline_station_ds) in enumerate( + zip(baseline_labels, baseline_station_ds_list), start=1 + ): + ax.plot( + baseline_station_ds["time"].values, + baseline_station_ds[param].values, + color=f"C{i}", + label=f"{baseline_label}", + ) # forecast - fct2plot = da_fct.isel(**fct_isel) ax.plot( - fct2plot["valid_time"].values, - fct2plot.values, + forecast_station_ds["time"].values, + forecast_station_ds[param].values, color="C0", - label="forecast", + label=forecast_label, ) ax.legend() - param2plot = da_fct.attrs.get("parameter", {}) + param2plot = forecast_ds[param].attrs.get("parameter", {}) short = param2plot.get("shortName", "") units = param2plot.get("units", "") name = param2plot.get("name", "") @@ -304,11 +285,7 @@ def _( ax.set_title(f"{init_time} {name} at {station}") plt.savefig(outfn) - return - - -@app.cell -def _(): + print(f"saved: {outfn}") return diff --git a/workflow/scripts/verif_aggregation.py b/workflow/scripts/verif_aggregation.py index deb35d65..9057b19c 100644 --- a/workflow/scripts/verif_aggregation.py +++ b/workflow/scripts/verif_aggregation.py @@ -34,7 +34,7 @@ def aggregate_results(ds: xr.Dataset) -> xr.Dataset: ds = ds.assign_coords( season=lambda ds: ds.ref_time.dt.season, init_hour=lambda ds: ds.ref_time.dt.hour, - ).drop_vars(["time"]) + ).drop_vars(["time"], errors="ignore") # compute mean with grouping by all permutations of season and init_hour ds_mean = [] diff --git 
a/workflow/scripts/verif_single_init.py b/workflow/scripts/verif_single_init.py index 5ad505ef..421078de 100644 --- a/workflow/scripts/verif_single_init.py +++ b/workflow/scripts/verif_single_init.py @@ -6,11 +6,11 @@ from verification import verify # noqa: E402 +from verification.spatial import map_forecast_to_truth # noqa: E402 from data_input import ( - load_baseline_from_zarr, - load_analysis_data_from_zarr, - load_fct_data_from_grib, parse_steps, + load_forecast_data, + load_truth_data, ) # noqa: E402 LOG = logging.getLogger(__name__) @@ -23,7 +23,7 @@ class ScriptConfig(Namespace): """Configuration for the script to verify baseline forecast data.""" archive_root: Path = None - analysis_zarr: Path = None + truth: Path = None baseline_zarr: Path = None reftime: datetime = None params: list[str] = ["T_2M", "TD_2M", "U_10M", "V_10M"] @@ -35,8 +35,8 @@ def program_summary_log(args): LOG.info("=" * 80) LOG.info("Running verification of baseline forecast data") LOG.info("=" * 80) - LOG.info("baseline zarr dataset: %s", args.baseline_zarr) - LOG.info("Zarr dataset for analysis: %s", args.analysis_zarr) + LOG.info("Baseline dataset: %s", args.baseline_zarr) + LOG.info("Truth dataset: %s", args.truth) LOG.info("Reference time: %s", args.reftime) LOG.info("Parameters to verify: %s", args.params) LOG.info("Lead time: %s", args.lead_time) @@ -48,29 +48,9 @@ def main(args: ScriptConfig): """Main function to verify baseline forecast data.""" # get baseline forecast data - now = datetime.now() - # try to open the baselin as a zarr, and if it fails load from grib - if not args.forecast: - raise ValueError("--forecast must be provided.") - - if any(args.forecast.glob("*.grib")): - LOG.info("Loading forecasts from GRIB files...") - fcst = load_fct_data_from_grib( - grib_output_dir=args.forecast, - reftime=args.reftime, - steps=args.steps, - params=args.params, - ) - else: - LOG.info("Loading baseline forecasts from zarr dataset...") - fcst = load_baseline_from_zarr( - 
zarr_path=args.forecast, - reftime=args.reftime, - steps=args.steps, - params=args.params, - ) + fcst = load_forecast_data(args.forecast, args.reftime, args.steps, args.params) LOG.info( "Loaded forecast data in %s seconds: \n%s", @@ -78,33 +58,21 @@ def main(args: ScriptConfig): fcst, ) - # get truth data (aka analysis data) + # get truth data now = datetime.now() - if args.analysis_zarr: - analysis = ( - load_analysis_data_from_zarr( - analysis_zarr=args.analysis_zarr, - times=fcst.time, - params=args.params, - ) - .compute() - .chunk( - {"y": -1, "x": -1} - if "y" in fcst.dims and "x" in fcst.dims - else {"values": -1} - ) - ) - else: - raise ValueError("--analysis_zarr must be provided.") + truth = load_truth_data(args.truth, args.reftime, args.steps, args.params) LOG.info( - "Loaded analysis data in %s seconds: \n%s", + "Loaded truth data in %s seconds: \n%s", (datetime.now() - now).total_seconds(), - analysis, + truth, ) - # compute metrics and statistics + # align forecast and truth data spatially and temporally + fcst = map_forecast_to_truth(fcst, truth) + truth = truth.sel(time=fcst.time) - results = verify(fcst, analysis, args.label, args.analysis_label, args.regions) + # compute metrics and statistics + results = verify(fcst, truth, args.label, args.truth_label, args.regions) # save results to NetCDF args.output.parent.mkdir(parents=True, exist_ok=True) @@ -125,11 +93,10 @@ def main(args: ScriptConfig): help="Path to the directory containing the grib forecast or to the zarr dataset containing baseline data.", ) parser.add_argument( - "--analysis_zarr", + "--truth", type=Path, required=True, - default="/scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-6h-v3-pl13.zarr", - help="Path to the zarr dataset containing analysis data.", + help="Path to the truth data.", ) parser.add_argument( "--reftime", @@ -155,10 +122,10 @@ def main(args: ScriptConfig): help="Label for the forecast or baseline data (default: COSMO-E).", ) 
parser.add_argument( - "--analysis_label", + "--truth_label", type=str, default="COSMO KENDA", - help="Label for the analysis data (default: COSMO KENDA).", + help="Label for the truth data (default: COSMO KENDA).", ) parser.add_argument( "--regions", diff --git a/workflow/tools/config.schema.json b/workflow/tools/config.schema.json index ea6186f4..cc66d7fb 100644 --- a/workflow/tools/config.schema.json +++ b/workflow/tools/config.schema.json @@ -1,28 +1,5 @@ { "$defs": { - "AnalysisConfig": { - "description": "Configuration for the analysis data used in the verification.", - "properties": { - "label": { - "description": "Label for the analysis that will be used in experiment results such as reports and figures.", - "minLength": 1, - "title": "Label", - "type": "string" - }, - "analysis_zarr": { - "description": "Path to the zarr dataset containing the analysis data.", - "minLength": 1, - "title": "Analysis Zarr", - "type": "string" - } - }, - "required": [ - "label", - "analysis_zarr" - ], - "title": "AnalysisConfig", - "type": "object" - }, "BaselineConfig": { "description": "Configuration for a single baseline to include in the verification.", "properties": { @@ -504,6 +481,29 @@ ], "title": "Stratification", "type": "object" + }, + "TruthConfig": { + "description": "Configuration for the truth data used in the verification.", + "properties": { + "label": { + "description": "Label that will be used in experiment results such as reports and figures.", + "minLength": 1, + "title": "Label", + "type": "string" + }, + "root": { + "description": "Path to the root of the dataset.", + "minLength": 1, + "title": "Root", + "type": "string" + } + }, + "required": [ + "label", + "root" + ], + "title": "TruthConfig", + "type": "object" } }, "additionalProperties": false, @@ -561,8 +561,15 @@ "title": "Baselines", "type": "array" }, - "analysis": { - "$ref": "#/$defs/AnalysisConfig" + "truth": { + "anyOf": [ + { + "$ref": "#/$defs/TruthConfig" + }, + { + "type": "null" + } + 
] }, "stratification": { "$ref": "#/$defs/Stratification" @@ -579,7 +586,7 @@ "dates", "runs", "baselines", - "analysis", + "truth", "stratification", "locations", "profile"