From de37100a1b1c10016446fb02715e4fd1e00ee638 Mon Sep 17 00:00:00 2001 From: Francesco Zanetta Date: Fri, 6 Feb 2026 14:40:13 +0100 Subject: [PATCH 01/13] add configs for evaluation with operational analyses --- config/forecasters-ich1-oper.yaml | 60 +++++++++++++++++++ pyproject.toml | 2 +- .../sgm-forecaster-global-ich1-oper.yaml | 54 +++++++++++++++++ 3 files changed, 115 insertions(+), 1 deletion(-) create mode 100644 config/forecasters-ich1-oper.yaml create mode 100644 resources/inference/configs/sgm-forecaster-global-ich1-oper.yaml diff --git a/config/forecasters-ich1-oper.yaml b/config/forecasters-ich1-oper.yaml new file mode 100644 index 00000000..94163eae --- /dev/null +++ b/config/forecasters-ich1-oper.yaml @@ -0,0 +1,60 @@ +# yaml-language-server: $schema=../workflow/tools/config.schema.json +description: | + Evaluate skill of ICON-CH1 single. + +dates: + start: 2025-01-01T00:00 + end: 2025-10-20T00:00 + frequency: 54h + + +runs: + - forecaster: + mlflow_id: b30acf68520a4bbd8324c44666561696 + label: stage_C_icon_1km + steps: 0/120/6 + config: resources/inference/configs/sgm-forecaster-global-ich1-oper.yaml + disable_local_eccodes_definitions: true + extra_dependencies: + - git+https://github.com/ecmwf/anemoi-inference.git@main + +baselines: + - baseline: + baseline_id: ICON-CH1-EPS + label: ICON-CH1-EPS + root: /store_new/mch/msopr/ml/ICON-CH1-EPS + steps: 0/33/6 + +analysis: + label: KENDA-CH1 + analysis_zarr: /store_new/mch/msopr/ml/datasets/mch-ich1-1km-2024-2025-1h-pl13-v1.0.zarr + +stratification: + regions: + - jura + - mittelland + - voralpen + - alpennordhang + - innerealpentaeler + - alpensuedseite + root: /scratch/mch/bhendj/regions/Prognoseregionen_LV95_20220517 + +locations: + output_root: output/ + mlflow_uri: + - https://servicedepl.meteoswiss.ch/mlstore + - https://mlflow.ecmwf.int + +profile: + executor: slurm + global_resources: + gpus: 16 + default_resources: + slurm_partition: "postproc" + cpus_per_task: 1 + mem_mb_per_cpu: 1800 + runtime: "1h" + gpus: 0 + jobs: 50 + batch_rules: + plot_forecast_frame: 32 diff --git a/pyproject.toml b/pyproject.toml index 76680860..f754384d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ dependencies = [ "snakemake-executor-plugin-slurm", "click", "meteodata-lab>=0.4.0", - "anemoi-datasets>=0.5.25", + "anemoi-datasets>=0.5.31", "mlflow>=3.1.1", "pydantic>=2.11.7", "toml>=0.10.2", diff --git a/resources/inference/configs/sgm-forecaster-global-ich1-oper.yaml b/resources/inference/configs/sgm-forecaster-global-ich1-oper.yaml new file mode 100644 index 00000000..7e52968e --- /dev/null +++ b/resources/inference/configs/sgm-forecaster-global-ich1-oper.yaml @@ -0,0 +1,54 @@ +lead_time: 120h +write_initial_state: true +allow_nans: true + +env: + ANEMOI_INFERENCE_NUM_CHUNKS: 8 # OOM error if not set + +# inputs +input: + test: + use_original_paths: true + + +# outputs +output: + tee: + - grib: + path: grib/{date}{time:04}_{step:03}.grib + encoding: + typeOfGeneratingProcess: 2 + templates: + - file: + path: resources/icon-ch1-typeOfLevel=surface.grib + variables: [lsm, msl, sp, z, skt, tp] + - file: + path: resources/icon-ch1-typeOfLevel=heightAboveGround.grib + variables: [2t, 2d, 10u, 10v] + - file: resources/icon-ch1-typeOfLevel=isobaricInhPa.grib + post_processors: + - extract_mask: # removes global points + mask: "lam_0/cutout_mask" + as_slice: true + - grib: + path: grib/ifs-{date}{time:04}_{step:03}.grib + encoding: + typeOfGeneratingProcess: 2 + templates: + samples: resources/templates_index_ifs.yaml + post_processors: + - extract_mask: # removes lam points + mask: "lam_0/cutout_mask" + as_slice: true + inverse: true + - assign_mask: # fill local/global overlapping points with nan + mask: "global/cutout_mask" + +patch_metadata: + config: + dataloader: + test: + dataset: + cutout: + - dataset: /store_new/mch/msopr/ml/datasets/mch-ich1-1km-2024-2025-1h-pl13-ifsnames-v1.0.zarr + - dataset: /store_new/mch/msopr/ml/datasets/aifs-od-an-oper-0001-mars-n320-2016-2025-6h-v1-for-single-v2.zarr From b8573704df87a69213084bf439505400fb1e702f Mon Sep 17 00:00:00 2001 From: Francesco Zanetta Date: Fri, 6 Feb 2026 14:40:34 +0100 Subject: [PATCH 02/13] update lockfile --- uv.lock | 124 ++++++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 98 insertions(+), 26 deletions(-) diff --git a/uv.lock b/uv.lock index 1c33f702..05fe9dc8 100644 --- a/uv.lock +++ b/uv.lock @@ -36,47 +36,54 @@ wheels = [ [[package]] name = "anemoi-datasets" -version = "0.5.26" +version = "0.5.31" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anemoi-transform" }, - { name = "anemoi-utils", extra = ["provenance"] }, + { name = "anemoi-utils" }, { name = "cfunits" }, + { name = "glom" }, + { name = "jsonschema" }, { name = "numcodecs" }, { name = "numpy" }, + { name = "pytest" }, + { name = "pytest-xdist" }, { name = "pyyaml" }, + { name = "ruamel-yaml" }, { name = "semantic-version" }, { name = "tqdm" }, { name = "zarr" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/24/9b/ddce8751a14879349b46f65153ea95a654dc874043f869e7c6a9c4a28a25/anemoi_datasets-0.5.26.tar.gz", hash = "sha256:75c3c0c55c26a985fb3b58cb8af483383820fb0ad6768b2eb62058c2b1f36bd2", size = 1771624, upload-time = "2025-07-15T06:53:02.194Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8d/6c/abb645f9862da60ed62ecea4f71635528ecde83aeb0cdc820bd31299b258/anemoi_datasets-0.5.31.tar.gz", hash = "sha256:3ffbc382e6f6b512ded580a0a0ed481da686944c840f8778c0354a3e98fe4675", size = 1793730, upload-time = "2026-02-05T11:46:03.881Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ca/00/8fd84aa44be8efb99c2b74d630a7db1ec48bd7a0620a4634a560437175ea/anemoi_datasets-0.5.26-py3-none-any.whl", hash = "sha256:2508fc15bd825d04477530ed285186a748da25abe006499e1d39ce7310d3296b", size = 282350, upload-time = "2025-07-15T06:53:00.71Z" }, + { url = "https://files.pythonhosted.org/packages/d0/81/09ac820664b56cc54522560dc56d9b8a776d2b866af127d5f98d9a11f650/anemoi_datasets-0.5.31-py3-none-any.whl", hash = "sha256:9e46242f0d8910b25631923fdf4034ae10f2b33618d2c12a9d3bef8c250c01b9", size = 285973, upload-time = "2026-02-05T11:46:02.017Z" }, ] [[package]] name = "anemoi-transform" -version = "0.1.16" +version = "0.1.23" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anemoi-utils" }, { name = "cfunits" }, { name = "earthkit-data" }, + { name = "earthkit-geo" }, { name = "earthkit-meteo" }, { name = "earthkit-regrid" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c8/f0/544c4a0943d649ab089ac2bf6737ea7e5959b68fb6e6a7b6e3695757d9dc/anemoi_transform-0.1.16.tar.gz", hash = "sha256:771a4134bd35cf8ef713047e704558defdfab6a363ab4ea61356d4da26c6bee6", size = 132995, upload-time = "2025-08-08T07:49:42.267Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/ee/0e16c0735ce83b03eefb7e2a69b5d1b3be0ad56736826ab06fc60db16b19/anemoi_transform-0.1.23.tar.gz", hash = "sha256:fcf06a4aaa6dcb8a8f56350489e4b48418b3500766b2b882d8c29fc0a8b25ba5", size = 163027, upload-time = "2026-02-05T10:30:47.186Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/14/2d/c41440ff4924a4a8696efb8d795e710491882fc7ba76d1b0b479dd99ce83/anemoi_transform-0.1.16-py3-none-any.whl", hash = "sha256:2987a40467fb26d6a01b2b0a9a2be55e2a6bf8a8a4f0f3b28e61795da3bec2b4", size = 76819, upload-time = "2025-08-08T07:49:41.307Z" }, + { url = "https://files.pythonhosted.org/packages/7c/87/505614c01e023306b707d13b673628b577fd4c1332f5b44f0c9e1ff2a09b/anemoi_transform-0.1.23-py3-none-any.whl", hash = "sha256:233f20eefb1ce8fe68ff96e39f4e1daf9a29cced9dda6a30bdf867dc49571869", size = 106826, upload-time = "2026-02-05T10:30:45.835Z" }, ] [[package]] name = "anemoi-utils" -version = "0.4.35" +version = "0.4.43" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aniso8601" }, { name = "deprecation" }, + { name = "entrypoints" }, { name = "multiurl" }, { name = "numpy" }, { name = "pydantic" }, @@ -85,15 +92,9 @@ dependencies = [ { name = "rich" }, { name = "tqdm" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/70/70/50e7bd93327bd610b0d808d017c5dc9b87a21a3f6172a449307ca1480cfe/anemoi_utils-0.4.35.tar.gz", hash = "sha256:c259c918e7ae9582c6f8cac3bb137fb1326759703d82334c7ec568c00a07afe5", size = 136476, upload-time = "2025-08-12T16:11:52.902Z" } +sdist = { url = "https://files.pythonhosted.org/packages/9e/40/79ee8859b327806c193a22c840ba02715aacf4993270f3f713d8f61f7fa3/anemoi_utils-0.4.43.tar.gz", hash = "sha256:e8ab8577e1b68b252e75772aa1e3eb9733bf30cbd280b94375ca93bafd23c012", size = 145251, upload-time = "2026-01-21T13:34:00.239Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d4/d0/219f4f280745d7462bb534a1de9197afe1ab81e81ae2768e23569cb1cdce/anemoi_utils-0.4.35-py3-none-any.whl", hash = "sha256:b0c98359852231c3e18c5b90dd8a5ba278a106d527fa97d9b578fbf162b185a2", size = 91019, upload-time = "2025-08-12T16:11:51.43Z" }, -] - -[package.optional-dependencies] -provenance = [ - { name = "gitpython" }, - { name = "nvsmi" }, + { url = "https://files.pythonhosted.org/packages/9e/48/bfbd65958625bb75e41b77fc198420461a5a1dcdd32ef1763daf2ae28cab/anemoi_utils-0.4.43-py3-none-any.whl", hash = "sha256:5c04312d48168197d94bf3fa97228a2537a6e8ff756845be95193a4233297b68", size = 93307, upload-time = "2026-01-21T13:33:57.637Z" }, ] [[package]] @@ -216,6 +217,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl", hash = "sha256:ba0efaa9080b619ff2f3459d1d500c57bddea4a6b424b60a91141db6fd2f08bc", size = 8458, upload-time = "2024-11-08T17:25:46.184Z" }, ] +[[package]] +name = "boltons" +version = "25.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/63/54/71a94d8e02da9a865587fb3fff100cb0fc7aa9f4d5ed9ed3a591216ddcc7/boltons-25.0.0.tar.gz", hash = "sha256:e110fbdc30b7b9868cb604e3f71d4722dd8f4dcb4a5ddd06028ba8f1ab0b5ace", size = 246294, upload-time = "2025-02-03T05:57:59.129Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/45/7f/0e961cf3908bc4c1c3e027de2794f867c6c89fb4916fc7dba295a0e80a2d/boltons-25.0.0-py3-none-any.whl", hash = "sha256:dc9fb38bf28985715497d1b54d00b62ea866eca3938938ea9043e254a3a6ca62", size = 194210, upload-time = "2025-02-03T05:57:56.705Z" }, +] + [[package]] name = "cachetools" version = "5.5.2" @@ -878,6 +888,20 @@ covjsonkit = [ { name = "covjsonkit" }, ] +[[package]] +name = "earthkit-geo" +version = "0.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyproj" }, + { name = "requests" }, + { name = "scipy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ca/2f/6dcb9cb89ef0520b5752c3cfb93dc15775bfe9ed59265b4a9b4adedf66b7/earthkit_geo-0.4.0.tar.gz", hash = "sha256:7aeb989abd97f06de9fd870676ddec97e215b91cf80cc225c4b2198878940b38", size = 644334, upload-time = "2025-10-24T10:44:36.69Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8b/82/61bd75d544c959213abfeafc425835d203334c5fed63d62e51d661d81244/earthkit_geo-0.4.0-py3-none-any.whl", hash = "sha256:4d89ae8a8dfbf7f7e83a95656c608cefb018ab0bb732be4f2678cf6cc4d9a469", size = 20037, upload-time = "2025-10-24T10:44:34.709Z" }, +] + [[package]] name = "earthkit-meteo" version = "0.4.1" @@ -1031,7 +1055,7 @@ dev = [ [package.metadata] requires-dist = [ - { name = "anemoi-datasets", specifier = ">=0.5.25" }, + { name = "anemoi-datasets", specifier = ">=0.5.31" }, { name = "cartopy", specifier = ">=0.25.0" }, { name = "click" }, { name = "earthkit-plots" }, @@ -1060,6 +1084,27 @@ dev = [ { name = "snakefmt", specifier = ">=0.11.0" }, ] +[[package]] +name = "execnet" +version = "2.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bf/89/780e11f9588d9e7128a3f87788354c7946a9cbb1401ad38a48c4db9a4f07/execnet-2.1.2.tar.gz", hash = "sha256:63d83bfdd9a23e35b9c6a3261412324f964c2ec8dcd8d3c6916ee9373e0befcd", size = 166622, upload-time = "2025-11-12T09:56:37.75Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ab/84/02fc1827e8cdded4aa65baef11296a9bbe595c474f0d6d758af082d849fd/execnet-2.1.2-py3-none-any.whl", hash = "sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec", size = 40708, upload-time = "2025-11-12T09:56:36.333Z" }, +] + +[[package]] +name = "face" +version = "24.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "boltons" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ac/79/2484075a8549cd64beae697a8f664dee69a5ccf3a7439ee40c8f93c1978a/face-24.0.0.tar.gz", hash = "sha256:611e29a01ac5970f0077f9c577e746d48c082588b411b33a0dd55c4d872949f6", size = 62732, upload-time = "2024-11-02T05:24:26.095Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/47/21867c2e5fd006c8d36a560df9e32cb4f1f566b20c5dd41f5f8a2124f7de/face-24.0.0-py3-none-any.whl", hash = "sha256:0e2c17b426fa4639a4e77d1de9580f74a98f4869ba4c7c8c175b810611622cd3", size = 54742, upload-time = "2024-11-02T05:24:24.939Z" }, +] + [[package]] name = "fastapi" version = "0.116.1" @@ -1289,6 +1334,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/01/61/d4b89fec821f72385526e1b9d9a3a0385dda4a72b206d28049e2c7cd39b8/gitpython-3.1.45-py3-none-any.whl", hash = "sha256:8908cb2e02fb3b93b7eb0f2827125cb699869470432cc885f019b8fd0fccff77", size = 208168, upload-time = "2025-07-24T03:45:52.517Z" }, ] +[[package]] +name = "glom" +version = "25.12.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "boltons" }, + { name = "face" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/78/74/8387f95565ba7c30cd152a585b275ebb9a834d1d32782425c5d2fe0a102c/glom-25.12.0.tar.gz", hash = "sha256:1ae7da88be3693df40ad27bdf57a765a55c075c86c971bcddd67927403eb0069", size = 196128, upload-time = "2025-12-29T06:29:07.274Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/e6/4129d9a3baa72d747533bb33376543ccadd9a7f9944e5a6e3ae2e245f5d6/glom-25.12.0-py3-none-any.whl", hash = "sha256:b9f21e77f71a6576a43864e85066b8cc3f0f778d0d50961563f8981377a6dcb1", size = 103295, upload-time = "2025-12-29T06:29:06.074Z" }, +] + [[package]] name = "google-auth" version = "2.40.3" @@ -2318,15 +2377,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/e3/6690b3f85a05506733c7e90b577e4762517404ea78bab2ca3a5cb1aeb78d/numpy-2.3.2-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:6936aff90dda378c09bea075af0d9c675fe3a977a9d2402f95a87f440f59f619", size = 12977811, upload-time = "2025-07-24T21:29:18.234Z" }, ] -[[package]] -name = "nvsmi" -version = "0.4.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/9d/13/c5da04d29f4e5f830a8558601b3e179163d0d94e0da06529d5a8e62eed9e/nvsmi-0.4.2.tar.gz", hash = "sha256:c1a391c7c4dadc6ec572909ff0372451d464ebadc144e5aa5fbbcc893dcb7bfa", size = 5248, upload-time = "2020-02-28T09:32:05.357Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d2/d5/6ec6d6410b434463ba76900d2363a1f75c474f3442a4365557b2588fa14b/nvsmi-0.4.2-py3-none-any.whl", hash = "sha256:718894c24bdf7b58b8ecdfd282dceb06ef120a4b4e0b8517193cba876174945e", size = 5466, upload-time = "2020-02-28T09:32:03.88Z" }, -] - [[package]] name = "opentelemetry-api" version = "1.36.0" @@ -3029,6 +3079,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" }, ] +[[package]] +name = "pytest-xdist" +version = "3.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "execnet" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/78/b4/439b179d1ff526791eb921115fca8e44e596a13efeda518b9d845a619450/pytest_xdist-3.8.0.tar.gz", hash = "sha256:7e578125ec9bc6050861aa93f2d59f1d8d085595d6551c2c90b6f4fad8d3a9f1", size = 88069, upload-time = "2025-07-01T13:30:59.346Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ca/31/d4e37e9e550c2b92a9cbc2e4d0b7420a27224968580b5a447f420847c975/pytest_xdist-3.8.0-py3-none-any.whl", hash = "sha256:202ca578cfeb7370784a8c33d6d05bc6e13b4f25b5053c30a152269fd10f0b88", size = 46396, upload-time = "2025-07-01T13:30:56.632Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -3287,6 +3350,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/64/8d/0133e4eb4beed9e425d9a98ed6e081a55d195481b7632472be1af08d2f6b/rsa-4.9.1-py3-none-any.whl", hash = "sha256:68635866661c6836b8d39430f97a996acbd61bfa49406748ea243539fe239762", size = 34696, upload-time = "2025-04-16T09:51:17.142Z" }, ] +[[package]] +name = "ruamel-yaml" +version = "0.19.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c7/3b/ebda527b56beb90cb7652cb1c7e4f91f48649fbcd8d2eb2fb6e77cd3329b/ruamel_yaml-0.19.1.tar.gz", hash = "sha256:53eb66cd27849eff968ebf8f0bf61f46cdac2da1d1f3576dd4ccee9b25c31993", size = 142709, upload-time = "2026-01-02T16:50:31.84Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b8/0c/51f6841f1d84f404f92463fc2b1ba0da357ca1e3db6b7fbda26956c3b82a/ruamel_yaml-0.19.1-py3-none-any.whl", hash = "sha256:27592957fedf6e0b62f281e96effd28043345e0e66001f97683aa9a40c667c93", size = 118102, upload-time = "2026-01-02T16:50:29.201Z" }, +] + [[package]] name = "scikit-learn" version = "1.7.1" From 0453af8fd6fcdad7cbcfa545fa8c2ee1e4316a05 Mon Sep 17 00:00:00 2001 From: Daniele Nerini Date: Thu, 12 Feb 2026 15:14:13 +0100 Subject: [PATCH 03/13] Map fcst to truth --- README.md | 4 +- config/forecasters-co1e.yaml | 4 +- config/forecasters-co2-disentangled.yaml | 4 +- config/forecasters-co2.yaml | 4 +- config/forecasters-ich1-oper.yaml | 12 +- config/forecasters-ich1.yaml | 14 +-- config/interpolators-co2.yaml | 4 +- src/data_input/__init__.py | 118 ++++++++++++++----- src/evalml/config.py | 12 +- src/verification/__init__.py | 33 ++++-- workflow/rules/data.smk | 10 +- workflow/rules/inference.smk | 1 - workflow/rules/plot.smk | 8 +- workflow/rules/report.smk | 4 +- workflow/rules/verif.smk | 16 +-- workflow/scripts/plot_meteogram.mo.py | 15 ++- workflow/scripts/verif_aggregation.py | 2 +- workflow/scripts/verif_single_init.py | 143 +++++++++++++++++------ workflow/tools/config.schema.json | 59 +++++----- 19 files changed, 311 insertions(+), 156 deletions(-) diff --git a/README.md b/README.md index f2fdf5b4..32ad0a86 100644 --- a/README.md +++ b/README.md @@ -51,9 +51,9 @@ baselines: root: /store_new/mch/msopr/ml/COSMO-E steps: 0/120/6 -analysis: +truth: label: COSMO KENDA - analysis_zarr: /scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-6h-v3-pl13.zarr + root: /scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-6h-v3-pl13.zarr locations: output_root: output/ diff --git a/config/forecasters-co1e.yaml b/config/forecasters-co1e.yaml index ade81dd3..9ee74a7c 100644 --- a/config/forecasters-co1e.yaml +++ b/config/forecasters-co1e.yaml @@ -27,9 +27,9 @@ baselines: root: /store_new/mch/msopr/ml/COSMO-1E steps: 0/33/6 -analysis: +truth: label: COSMO KENDA - analysis_zarr: /scratch/mch/fzanetta/data/anemoi/datasets/mch-co1e-an-archive-0p01-2019-2024-1h-v1-pl13.zarr + root: /scratch/mch/fzanetta/data/anemoi/datasets/mch-co1e-an-archive-0p01-2019-2024-1h-v1-pl13.zarr stratification: regions: diff --git a/config/forecasters-co2-disentangled.yaml b/config/forecasters-co2-disentangled.yaml index 5c595cff..cabc5f08 100644 --- a/config/forecasters-co2-disentangled.yaml +++ b/config/forecasters-co2-disentangled.yaml @@ -42,9 +42,9 @@ baselines: root: /store_new/mch/msopr/ml/COSMO-E steps: 0/120/6 -analysis: +truth: label: COSMO KENDA - analysis_zarr: /scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-6h-v3-pl13.zarr + root: /scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-6h-v3-pl13.zarr stratification: regions: diff --git a/config/forecasters-co2.yaml b/config/forecasters-co2.yaml index ef881d49..00343cc2 100644 --- a/config/forecasters-co2.yaml +++ b/config/forecasters-co2.yaml @@ -23,9 +23,9 @@ baselines: root: /store_new/mch/msopr/ml/COSMO-E steps: 0/120/6 -analysis: +truth: label: COSMO KENDA - analysis_zarr: /scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-6h-v3-pl13.zarr + root: /scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-6h-v3-pl13.zarr stratification: regions: diff --git a/config/forecasters-ich1-oper.yaml b/config/forecasters-ich1-oper.yaml index 94163eae..534fc30a 100644 --- a/config/forecasters-ich1-oper.yaml +++ b/config/forecasters-ich1-oper.yaml @@ -20,14 +20,14 @@ runs: baselines: - baseline: - baseline_id: ICON-CH1-EPS - label: ICON-CH1-EPS - root: /store_new/mch/msopr/ml/ICON-CH1-EPS - steps: 0/33/6 + baseline_id: ICON-CH2-EPS + label: ICON-CH2-EPS + root: /scratch/mch/cmerker/ICON-CH2-EPS + steps: 0/120/6 -analysis: +truth: label: KENDA-CH1 - analysis_zarr: /store_new/mch/msopr/ml/datasets/mch-ich1-1km-2024-2025-1h-pl13-v1.0.zarr + root: /store_new/mch/msopr/ml/datasets/mch-ich1-1km-2024-2025-1h-pl13-v1.0.zarr stratification: regions: diff --git a/config/forecasters-ich1.yaml b/config/forecasters-ich1.yaml index e97c681c..d1c1b600 100644 --- a/config/forecasters-ich1.yaml +++ b/config/forecasters-ich1.yaml @@ -19,14 +19,14 @@ runs: baselines: - baseline: - baseline_id: ICON-CH1-EPS - label: ICON-CH1-EPS - root: /store_new/mch/msopr/ml/ICON-CH1-EPS - steps: 0/33/6 + baseline_id: ICON-CH2-EPS + label: ICON-CH2-EPS + root: /scratch/mch/cmerker/ICON-CH2-EPS + steps: 0/120/6 -analysis: - label: REA-L-CH1 - analysis_zarr: /store_new/mch/msopr/ml/datasets/mch-realch1-fdb-1km-2005-2025-1h-pl13-v1.0.zarr +truth: + label: KENDA-CH1 + root: /store_new/mch/msopr/ml/datasets/mch-ich1-1km-2024-2025-1h-pl13-v1.0.zarr stratification: regions: diff --git a/config/interpolators-co2.yaml b/config/interpolators-co2.yaml index cf235360..a90f7630 100644 --- a/config/interpolators-co2.yaml +++ b/config/interpolators-co2.yaml @@ -52,9 +52,9 @@ baselines: root: /store_new/mch/msopr/ml/COSMO-E_hourly steps: 0/120/1 -analysis: +truth: label: COSMO KENDA - analysis_zarr: /scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-1h-v3-pl13.zarr + root: /scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-1h-v3-pl13.zarr stratification: regions: diff --git a/src/data_input/__init__.py b/src/data_input/__init__.py index 047fff05..2404e792 100644 --- a/src/data_input/__init__.py +++ b/src/data_input/__init__.py @@ -1,9 +1,8 @@ import logging import os import sys -from datetime import datetime +from datetime import datetime, timedelta from pathlib import Path -from typing import Iterable eccodes_definition_path = Path(sys.prefix) / "share/eccodes-cosmo-resources/definitions" os.environ["ECCODES_DEFINITION_PATH"] = str(eccodes_definition_path) @@ -16,8 +15,27 @@ LOG = logging.getLogger(__name__) +def _select_valid_times(ds, times: np.datetime64): + # (handle special case where some valid times are not in the dataset, e.g. at the end) + times_np = np.asarray(times, dtype="datetime64[ns]") + times_included = np.isin(times_np, ds.time.values) + if times_included.all(): + return ds.sel(time=times_np) + elif times_included.any(): + LOG.warning( + "Some valid times are not included in the dataset: \n%s", + times_np[~times_included], + ) + return ds.sel(time=times_np[times_included]) + else: + raise ValueError( + "Valid times are not included in the dataset. " + "Please check the valid times and the dataset." + ) + + def load_analysis_data_from_zarr( - analysis_zarr: Path, times: Iterable[datetime], params: list[str] + root: Path, reftime: datetime, steps: list[int], params: list[str] ) -> xr.Dataset: """Load analysis data from an anemoi-generated Zarr dataset @@ -36,9 +54,9 @@ def load_analysis_data_from_zarr( PARAMS_MAP_COSMO1 = { v: v.replace("TOT_PREC", "TOT_PREC_6H") for v in PARAMS_MAP_COSMO2.keys() } - PARAMS_MAP = PARAMS_MAP_COSMO2 if "co2" in analysis_zarr.name else PARAMS_MAP_COSMO1 + PARAMS_MAP = PARAMS_MAP_COSMO2 if "co2" in root.name else PARAMS_MAP_COSMO1 - ds = xr.open_zarr(analysis_zarr, consolidated=False) + ds = xr.open_zarr(root, consolidated=False) # rename "dates" to "time" and set it as index ds = ds.set_index(time="dates") @@ -59,8 +77,8 @@ def load_analysis_data_from_zarr( # set lat lon as coords (optional) if "latitudes" in ds and "longitudes" in ds: - ds = ds.rename({"latitudes": "latitude", "longitudes": "longitude"}) - ds = ds.set_coords(["latitude", "longitude"]) + ds = ds.rename({"latitudes": "lat", "longitudes": "lon"}) + ds = ds.set_coords(["lat", "lon"]) ds = ( ds["data"] .to_dataset("variable") @@ -71,30 +89,15 @@ def load_analysis_data_from_zarr( if "cell" in ds.dims: ds = ds.rename({"cell": "values"}) - # select valid times - # (handle special case where some valid times are not in the dataset, e.g. at the end) - times_included = times.isin(ds.time.values).values - if all(times_included): - ds = ds.sel(time=times) - elif np.sum(times_included) < len(times_included): - LOG.warning( - "Some valid times are not included in the dataset: \n%s", - times[~times_included].values, - ) - ds = ds.sel(time=times[times_included]) - else: - raise ValueError( - "Valid times are not included in the dataset. " - "Please check the valid times and the dataset." - ) - return ds + times = np.datetime64(reftime) + np.asarray(steps, dtype="timedelta64[h]") + return _select_valid_times(ds, times) def load_fct_data_from_grib( - grib_output_dir: Path, reftime: datetime, steps: list[int], params: list[str] + root: Path, reftime: datetime, steps: list[int], params: list[str] ) -> xr.Dataset: """Load forecast data from GRIB files for a specific valid time.""" - files = sorted(grib_output_dir.glob("20*.grib")) + files = sorted(root.glob("20*.grib")) fds = data_source.FileDataSource(datafiles=files) ds = grib_decoder.load(fds, {"param": params, "step": steps}) for var, da in ds.items(): @@ -127,13 +130,13 @@ def load_fct_data_from_grib( def load_baseline_from_zarr( - zarr_path: Path, reftime: datetime, steps: list[int], params: list[str] + root: Path, reftime: datetime, steps: list[int], params: list[str] ) -> xr.Dataset: """Load forecast data from a Zarr dataset.""" try: - baseline = xr.open_zarr(zarr_path, consolidated=True, decode_timedelta=True) + baseline = xr.open_zarr(root, consolidated=True, decode_timedelta=True) except ValueError: - raise ValueError(f"Could not open baseline zarr at {zarr_path}") + raise ValueError(f"Could not open baseline zarr at {root}") baseline = baseline.rename( {"forecast_reference_time": "ref_time", "step": "lead_time"} @@ -156,4 +159,61 @@ def load_baseline_from_zarr( lead_time=np.array(steps, dtype="timedelta64[h]"), ) baseline = baseline.assign_coords(time=baseline.ref_time + baseline.lead_time) + if "latitude" in baseline.coords and "longitude" in baseline: + baseline = baseline.rename({"latitude": "lat", "longitude": "lon"}) return baseline + + +def load_obs_data_from_peakweather( + root, reftime: datetime, steps: list[int], params: list[str], freq: str = "1h" +) -> xr.Dataset: + """Load PeakWeather station observations into an xarray Dataset. + + Returns a Dataset with dimensions `time` and `values`, values coordinates + (`lat`, `lon`), and variables renamed to ICON parameter names. + Temperatures are converted to Kelvin when present. + """ + from peakweather.dataset import PeakWeatherDataset + + param_names = { + "temperature": "T_2M", + "wind_u": "U_10M", + "wind_v": "V_10M", + } + param_names = {k: v for k, v in param_names.items() if v in params} + + start = reftime + end = start + timedelta(hours=max(steps)) + if len(steps) > 1: + end += timedelta(hours=steps[-1] - steps[-2]) # extend by 1 extra step + years = list(set([start.year, end.year])) + pw = PeakWeatherDataset(root=root, years=years, freq=freq) + ds, mask = pw.get_observations( + parameters=[k for k in param_names.keys()], + first_date=f"{start:%Y-%m-%d %H:%M}", + last_date=f"{end:%Y-%m-%d %H:%M}", + return_mask=True, + ) + ds = ( + ds.stack(["nat_abbr", "name"], future_stack=True) + .to_xarray() + .to_dataset(dim="name") + ) + mask = ( + mask.stack(["nat_abbr", "name"], future_stack=True) + .to_xarray() + .to_dataset(dim="name") + ) + ds = ds.where(mask) + ds = ds.rename({"datetime": "time", "nat_abbr": "values"}) + ds = ds.rename(param_names) + ds = ds.assign_coords(time=ds.indexes["time"].tz_convert("UTC").tz_localize(None)) + ds = ds.assign_coords(values=ds.indexes["values"]) + ds = ds.assign_coords(lon=("values", pw.stations_table["longitude"])) + ds = ds.assign_coords(lat=("values", pw.stations_table["latitude"])) + if "T_2M" in ds: + ds["T_2M"] = ds["T_2M"] + 273.15 # convert to Kelvin + ds = ds.dropna("values", how="all") + + times = np.datetime64(reftime) + np.asarray(steps, dtype="timedelta64[h]") + return _select_valid_times(ds, times) diff --git a/src/evalml/config.py b/src/evalml/config.py index d36ceeda..c8e88e48 100644 --- a/src/evalml/config.py +++ b/src/evalml/config.py @@ -173,18 +173,18 @@ class BaselineConfig(BaseModel): ) -class AnalysisConfig(BaseModel): - """Configuration for the analysis data used in the verification.""" +class TruthConfig(BaseModel): + """Configuration for the truth data used in the verification.""" label: str = Field( ..., min_length=1, - description="Label for the analysis that will be used in experiment results such as reports and figures.", + description="Label that will be used in experiment results such as reports and figures.", ) - analysis_zarr: str = Field( + root: str = Field( ..., min_length=1, - description="Path to the zarr dataset containing the analysis data.", + description="Path to the root of the dataset.", ) @@ -310,7 +310,7 @@ class ConfigModel(BaseModel): ..., description="Dictionary of baselines to include in the verification.", ) - analysis: AnalysisConfig + truth: TruthConfig | None stratification: Stratification locations: Locations profile: Profile diff --git a/src/verification/__init__.py b/src/verification/__init__.py index db70f4fc..97a62505 100644 --- a/src/verification/__init__.py +++ b/src/verification/__init__.py @@ -75,6 +75,7 @@ def _mask_from_polygons( def _compute_scores( fcst: xr.DataArray, obs: xr.DataArray, + dim: list[str], prefix="", suffix="", source="", @@ -83,7 +84,6 @@ def _compute_scores( Compute basic verification metrics between two xarray DataArrays (fcst and obs). Returns a xarray Dataset with the computed metrics. """ - dim = ["x", "y"] if "x" in fcst.dims and "y" in fcst.dims else ["values"] error = fcst - obs scores = xr.Dataset( { @@ -101,6 +101,7 @@ def _compute_scores( def _compute_statistics( data: xr.DataArray, + dim: list[str], prefix="", suffix="", source="", @@ -109,7 +110,6 @@ def _compute_statistics( Compute basic statistics of a xarray DataArray (data). Returns a xarray Dataset with the computed statistics. """ - dim = ["x", "y"] if "x" in data.dims and "y" in data.dims else ["values"] stats = xr.Dataset( { f"{prefix}mean{suffix}": data.mean(dim=dim, skipna=True), @@ -146,6 +146,7 @@ def verify( fcst_label: str, obs_label: str, regions: list[str] | None = None, + dim: list[str] | None = None, ) -> xr.Dataset: """ Compare two xarray Datasets (fcst and obs) and return pandas DataFrame with @@ -153,15 +154,21 @@ def verify( """ start = time.time() + if dim is None: + if "x" in fcst.dims and "y" in fcst.dims: + dim = ["x", "y"] + elif "values" in fcst.dims: + dim = ["values"] + else: + dim = ["values"] + # rewrite the verification to use dask and xarray # chunk the data to avoid memory issues # compute the metrics in parallel # return the results as a xarray Dataset fcst_aligned, obs_aligned = xr.align(fcst, obs, join="inner", copy=False) region_polygons = ShapefileSpatialAggregationMasks(shp=regions) - masks = region_polygons.get_masks( - lon=obs_aligned["longitude"], lat=obs_aligned["latitude"] - ) + masks = region_polygons.get_masks(lon=obs_aligned["lon"], lat=obs_aligned["lat"]) scores = [] statistics = [] @@ -180,19 +187,29 @@ def verify( # scores vs time (reduce spatially) score.append( _compute_scores( - fcst_param, obs_param, prefix=param + ".", source=fcst_label + fcst_param, + obs_param, + prefix=param + ".", + source=fcst_label, + dim=dim, ).expand_dims(region=[region]) ) # statistics vs time (reduce spatially) fcst_statistics.append( _compute_statistics( - fcst_param, prefix=param + ".", source=fcst_label + fcst_param, + prefix=param + ".", + source=fcst_label, + dim=dim, ).expand_dims(region=[region]) ) obs_statistics.append( _compute_statistics( - obs_param, prefix=param + ".", source=obs_label + obs_param, + prefix=param + ".", + source=obs_label, + dim=dim, ).expand_dims(region=[region]) ) diff --git a/workflow/rules/data.smk b/workflow/rules/data.smk index 9892f8f6..91a48a79 100644 --- a/workflow/rules/data.smk +++ b/workflow/rules/data.smk @@ -4,12 +4,18 @@ from pathlib import Path include: "common.smk" +if config["truth"]["root"].endswith("peakweather"): + output_peakweather_root = config["truth"]["root"] +else: + output_peakweather_root = OUT_ROOT / "data/observations/peakweather" + + rule download_obs_from_peakweather: localrule: True output: - peakweather=directory(OUT_ROOT / "data/observations/peakweather"), + root=directory(output_peakweather_root), run: from peakweather.dataset import PeakWeatherDataset # Download the data from Huggingface - ds = PeakWeatherDataset(root=output.peakweather) + ds = PeakWeatherDataset(root=output.root) diff --git a/workflow/rules/inference.smk b/workflow/rules/inference.smk index 20043e94..267715f8 100644 --- a/workflow/rules/inference.smk +++ b/workflow/rules/inference.smk @@ -17,7 +17,6 @@ rule create_inference_pyproject: """ input: toml="workflow/envs/anemoi_inference.toml", - summary=rules.write_summary.output, output: pyproject=OUT_ROOT / "data/runs/{run_id}/pyproject.toml", params: diff --git a/workflow/rules/plot.smk b/workflow/rules/plot.smk index ec3a08fc..60a036b0 100644 --- a/workflow/rules/plot.smk +++ b/workflow/rules/plot.smk @@ -24,9 +24,9 @@ rule plot_meteogram: input: script="workflow/scripts/plot_meteogram.mo.py", inference_okfile=rules.execute_inference.output.okfile, - analysis_zarr=config["analysis"].get("analysis_zarr"), + truth=config["truth"]["root"], baseline_zarr=lambda wc: _use_first_baseline_zarr(wc), - peakweather_dir=rules.download_obs_from_peakweather.output.peakweather, + peakweather_dir=rules.download_obs_from_peakweather.output.root, output: OUT_ROOT / "results/{showcase}/{run_id}/{init_time}/{init_time}_{param}_{sta}.png", @@ -43,13 +43,13 @@ rule plot_meteogram: """ export ECCODES_DEFINITION_PATH=$(realpath .venv/share/eccodes-cosmo-resources/definitions) python {input.script} \ - --forecast {params.grib_out_dir} --analysis {input.analysis_zarr} \ + --forecast {params.grib_out_dir} --analysis {input.truth} \ --baseline {input.baseline_zarr} --peakweather {input.peakweather_dir} \ --date {wildcards.init_time} --outfn {output[0]} \ --param {wildcards.param} --station {wildcards.sta} # interactive editing (needs to set localrule: True and use only one core) # marimo edit {input.script} -- \ - # --forecast {params.grib_out_dir} --analysis {input.analysis_zarr} \ + # --forecast {params.grib_out_dir} --analysis {input.truth} \ # --baseline {input.baseline_zarr} --peakweather {input.peakweather_dir} \ # --date {wildcards.init_time} --outfn {output[0]} \ # --param {wildcards.param} --station {wildcards.sta} diff --git a/workflow/rules/report.smk b/workflow/rules/report.smk index 7122e60a..b0acc44b 100644 --- a/workflow/rules/report.smk +++ b/workflow/rules/report.smk @@ -9,10 +9,10 @@ include: "common.smk" def make_header_text(): dates = config["dates"] - analysis = config["analysis"]["label"] + truth = config["truth"]["label"] if isinstance(dates, list): return f"Explicit initializations from {len(dates)} runs have been used." - return f"Verification against {analysis} with initializations from {dates.get('start')} to {dates.get('end')} by {dates.get('frequency')}" + return f"Verification against {truth} with initializations from {dates.get('start')} to {dates.get('end')} by {dates.get('frequency')}" rule report_experiment_dashboard: diff --git a/workflow/rules/verif.smk b/workflow/rules/verif.smk index f8397c3d..ebbc8a8c 100644 --- a/workflow/rules/verif.smk +++ b/workflow/rules/verif.smk @@ -20,11 +20,11 @@ rule verif_metrics_baseline: root=BASELINE_CONFIGS[wc.baseline_id].get("root"), year=wc.init_time[2:4], ), - analysis_zarr=config["analysis"].get("analysis_zarr"), + truth=config["truth"]["root"], params: baseline_label=lambda wc: BASELINE_CONFIGS[wc.baseline_id].get("label"), baseline_steps=lambda wc: BASELINE_CONFIGS[wc.baseline_id]["steps"], - analysis_label=config["analysis"].get("label"), + truth_label=config["truth"]["label"], regions=REGION_TXT, output: OUT_ROOT / "data/baselines/{baseline_id}/{init_time}/verif.nc", @@ -38,11 +38,11 @@ rule verif_metrics_baseline: """ uv run {input.script} \ --forecast {input.baseline_zarr} \ - --analysis_zarr {input.analysis_zarr} \ + --truth {input.truth} \ --reftime {wildcards.init_time} \ --steps "{params.baseline_steps}" \ --label "{params.baseline_label}" \ - --analysis_label "{params.analysis_label}" \ + --truth_label "{params.truth_label}" \ --regions "{params.regions}" \ --output {output} > {log} 2>&1 """ @@ -61,7 +61,7 @@ rule verif_metrics: "src/data_input/__init__.py", script="workflow/scripts/verif_single_init.py", inference_okfile=rules.execute_inference.output.okfile, - analysis_zarr=config["analysis"].get("analysis_zarr"), + truth=config["truth"]["root"], output: OUT_ROOT / "data/runs/{run_id}/{init_time}/verif.nc", # wildcard_constraints: @@ -70,7 +70,7 @@ rule verif_metrics: params: fcst_label=lambda wc: RUN_CONFIGS[wc.run_id].get("label"), fcst_steps=lambda wc: RUN_CONFIGS[wc.run_id]["steps"], - analysis_label=config["analysis"].get("label"), + truth_label=config["truth"]["label"], regions=REGION_TXT, grib_out_dir=lambda wc: ( Path(OUT_ROOT) / f"data/runs/{wc.run_id}/{wc.init_time}/grib" @@ -85,11 +85,11 @@ rule verif_metrics: """ uv run {input.script} \ --forecast {params.grib_out_dir} \ - --analysis_zarr {input.analysis_zarr} \ + --truth {input.truth} \ --reftime {wildcards.init_time} \ --steps "{params.fcst_steps}" \ --label "{params.fcst_label}" \ - --analysis_label "{params.analysis_label}" \ + --truth_label "{params.truth_label}" \ --regions "{params.regions}" \ --output {output} > {log} 2>&1 """ diff --git a/workflow/scripts/plot_meteogram.mo.py b/workflow/scripts/plot_meteogram.mo.py index 1738c0ee..e36f45d2 100644 --- a/workflow/scripts/plot_meteogram.mo.py +++ b/workflow/scripts/plot_meteogram.mo.py @@ -134,7 +134,11 @@ def load_grib_data( ds_fct = preprocess_ds(ds_fct, param) da_fct = ds_fct[param].squeeze() - ds_ana = load_analysis_data_from_zarr(zarr_dir_ana, da_fct.valid_time, paramlist) + reftime = da_fct.ref_time.values + steps = list( + range(da_fct.sizes["lead_time"]) + ) # FIX: this will fail if lead_time is not 0,1,2,... + ds_ana = load_analysis_data_from_zarr(zarr_dir_ana, reftime, steps, paramlist) ds_ana = preprocess_ds(ds_ana, param) da_ana = ds_ana[param].squeeze() @@ -191,13 +195,8 @@ def nearest_yx_euclid(ds, lat_s, lon_s): Return (y_idx, x_idx) of the grid point nearest to (lat_s, lon_s) using Euclidean distance in degrees. """ - try: - lat2d = ds["lat"] # (y, x) - lon2d = ds["lon"] # (y, x) - except KeyError: - lat2d = ds["latitude"] # (y, x) - lon2d = ds["longitude"] # (y, x) - + lat2d = ds["lat"] # (y, x) + lon2d = ds["lon"] # (y, x) dist2 = (lat2d - lat_s) ** 2 + (lon2d - lon_s) ** 2 flat_idx = np.nanargmin(dist2.values) y_idx, x_idx = np.unravel_index(flat_idx, dist2.shape) diff --git a/workflow/scripts/verif_aggregation.py b/workflow/scripts/verif_aggregation.py index deb35d65..9057b19c 100644 --- a/workflow/scripts/verif_aggregation.py +++ b/workflow/scripts/verif_aggregation.py @@ -34,7 +34,7 @@ def aggregate_results(ds: xr.Dataset) -> xr.Dataset: ds = ds.assign_coords( season=lambda ds: ds.ref_time.dt.season, init_hour=lambda ds: ds.ref_time.dt.hour, - ).drop_vars(["time"]) + ).drop_vars(["time"], errors="ignore") # compute mean with grouping by all permutations of season and init_hour ds_mean = [] diff --git a/workflow/scripts/verif_single_init.py b/workflow/scripts/verif_single_init.py index 57b1e2c4..e162a838 100644 --- a/workflow/scripts/verif_single_init.py +++ b/workflow/scripts/verif_single_init.py @@ -4,12 +4,16 @@ from datetime import datetime from pathlib import Path +import numpy as np +import xarray as xr +from scipy.spatial import cKDTree from verification import verify # noqa: E402 from data_input import ( load_baseline_from_zarr, load_analysis_data_from_zarr, load_fct_data_from_grib, + load_obs_data_from_peakweather, ) # noqa: E402 LOG = logging.getLogger(__name__) @@ -32,7 +36,7 @@ class ScriptConfig(Namespace): """Configuration for the script to verify baseline forecast data.""" archive_root: Path = None - analysis_zarr: Path = None + truth: Path = None baseline_zarr: Path = None reftime: datetime = None params: list[str] = ["T_2M", "TD_2M", "U_10M", "V_10M"] @@ -44,8 +48,8 @@ def program_summary_log(args): LOG.info("=" * 80) LOG.info("Running verification of baseline forecast data") LOG.info("=" * 80) - LOG.info("baseline zarr dataset: %s", args.baseline_zarr) - LOG.info("Zarr dataset for analysis: %s", args.analysis_zarr) + LOG.info("Baseline dataset: %s", args.baseline_zarr) + LOG.info("Truth dataset: %s", args.truth) LOG.info("Reference time: %s", args.reftime) LOG.info("Parameters to verify: %s", args.params) LOG.info("Lead time: %s", args.lead_time) @@ -53,21 +57,59 @@ def program_summary_log(args): LOG.info("=" * 80) -def main(args: ScriptConfig): - """Main function to verify baseline forecast data.""" +def _map_fcst_to_truth( + fcst: xr.Dataset, truth: xr.Dataset +) -> tuple[xr.Dataset, xr.Dataset]: + """Map forecasts to the truth grid or station locations via nearest-neighbor lookup.""" - # get baseline forecast data + truth = truth.sel(time=fcst.time) # swap time dimension to lead_time - now = datetime.now() + if "y" in fcst.dims and "x" in fcst.dims: + fcst = fcst.stack(values=("y", "x")) + fcst_lat = fcst["lat"].values.ravel() + fcst_lon = fcst["lon"].values.ravel() + + if "y" in truth.dims and "x" in truth.dims: + truth = truth.stack(values=("y", "x")) + truth_lat = truth["lat"].values.ravel() + truth_lon = truth["lon"].values.ravel() + + # TODO: Project to a metric CRS for a proper distance metric + fcst_lat_rad = np.deg2rad(fcst_lat) + fcst_lon_rad = np.deg2rad(fcst_lon) + truth_lat_rad = np.deg2rad(truth_lat) + truth_lon_rad = np.deg2rad(truth_lon) + + fcst_xyz = np.c_[ + np.cos(fcst_lat_rad) * np.cos(fcst_lon_rad), + np.cos(fcst_lat_rad) * np.sin(fcst_lon_rad), + np.sin(fcst_lat_rad), + ] + truth_xyz = np.c_[ + np.cos(truth_lat_rad) * np.cos(truth_lon_rad), + np.cos(truth_lat_rad) * np.sin(truth_lon_rad), + np.sin(truth_lat_rad), + ] + + fcst_tree = cKDTree(fcst_xyz) + _, fi = fcst_tree.query(truth_xyz, k=1) + fi = np.asarray(fi) + fcst = fcst.isel(values=fi) + fcst = fcst.drop_vars(["x", "y", "values"], errors="ignore") + fcst = fcst.assign_coords(lon=("values", truth.lon.data)) + fcst = fcst.assign_coords(lat=("values", truth.lat.data)) + fcst = fcst.assign_coords(values=truth["values"]) + + return fcst, truth - # try to open the baselin as a zarr, and if it fails load from grib - if not args.forecast: - raise ValueError("--forecast must be provided.") + +def _load_forecast(args: ScriptConfig) -> xr.Dataset: + """Load forecast data from GRIB files or a baseline Zarr dataset.""" if any(args.forecast.glob("*.grib")): LOG.info("Loading forecasts from GRIB files...") fcst = load_fct_data_from_grib( - grib_output_dir=args.forecast, + root=args.forecast, reftime=args.reftime, steps=args.steps, params=args.params, @@ -75,45 +117,71 @@ def main(args: ScriptConfig): else: LOG.info("Loading baseline forecasts from zarr dataset...") fcst = load_baseline_from_zarr( - zarr_path=args.forecast, + root=args.forecast, reftime=args.reftime, steps=args.steps, params=args.params, ) + return fcst + + +def _load_truth(args: ScriptConfig) -> xr.Dataset: + """Load truth data from analysis Zarr or PeakWeather observations.""" + LOG.info("Loading ground truth from an analysis zarr dataset...") + if args.truth.suffix == ".zarr": + truth = load_analysis_data_from_zarr( + root=args.truth, + reftime=args.reftime, + steps=args.steps, + params=args.params, + ) + truth = truth.compute().chunk( + {"y": -1, "x": -1} + if "y" in truth.dims and "x" in truth.dims + else {"values": -1} + ) + elif "peakweather" in str(args.truth): + LOG.info("Loading ground truth from PeakWeather observations...") + # TODO: replace with OGD data + truth = load_obs_data_from_peakweather( + root=args.truth, + reftime=args.reftime, + steps=args.steps, + params=args.params, + ) + else: + raise ValueError(f"Unsupported truth root: {args.truth}") + return truth + + +def main(args: ScriptConfig): + """Main function to verify baseline forecast data.""" + + # get baseline forecast data + now = datetime.now() + + fcst = _load_forecast(args) + LOG.info( "Loaded forecast data in %s seconds: \n%s", (datetime.now() - now).total_seconds(), fcst, ) - # get truth data (aka analysis data) + # get truth data now = datetime.now() - if args.analysis_zarr: - analysis = ( - load_analysis_data_from_zarr( - analysis_zarr=args.analysis_zarr, - times=fcst.time, - params=args.params, - ) - .compute() - .chunk( - {"y": -1, "x": -1} - if "y" in fcst.dims and "x" in fcst.dims - else {"values": -1} - ) - ) - else: - raise ValueError("--analysis_zarr must be provided.") + truth = _load_truth(args) LOG.info( - "Loaded analysis data in %s seconds: \n%s", + "Loaded truth data in %s seconds: \n%s", (datetime.now() - now).total_seconds(), - analysis, + truth, ) - # compute metrics and statistics + fcst, truth = _map_fcst_to_truth(fcst, truth) - results = verify(fcst, analysis, args.label, args.analysis_label, args.regions) + # compute metrics and statistics + results = verify(fcst, truth, args.label, args.truth_label, args.regions) # save results to NetCDF args.output.parent.mkdir(parents=True, exist_ok=True) @@ -134,11 +202,10 @@ def main(args: ScriptConfig): help="Path to the directory containing the grib forecast or to the zarr dataset containing baseline data.", ) parser.add_argument( - "--analysis_zarr", + "--truth", type=Path, required=True, - default="/scratch/mch/fzanetta/data/anemoi/datasets/mch-co2-an-archive-0p02-2015-2020-6h-v3-pl13.zarr", - help="Path to the zarr dataset containing analysis data.", + help="Path to the truth data.", ) parser.add_argument( "--reftime", @@ -164,10 +231,10 @@ def main(args: ScriptConfig): help="Label for the forecast or baseline data (default: COSMO-E).", ) parser.add_argument( - "--analysis_label", + "--truth_label", type=str, default="COSMO KENDA", - help="Label for the analysis data (default: COSMO KENDA).", + help="Label for the truth data (default: COSMO KENDA).", ) parser.add_argument( "--regions", diff --git a/workflow/tools/config.schema.json b/workflow/tools/config.schema.json index 85e8add1..ec77ab41 100644 --- a/workflow/tools/config.schema.json +++ b/workflow/tools/config.schema.json @@ -1,28 +1,5 @@ { "$defs": { - "AnalysisConfig": { - "description": "Configuration for the analysis data used in the verification.", - "properties": { - "label": { - "description": "Label for the analysis that will be used in experiment results such as reports and figures.", - "minLength": 1, - "title": "Label", - "type": "string" - }, - "analysis_zarr": { - "description": "Path to the zarr dataset containing the analysis data.", - "minLength": 1, - "title": "Analysis Zarr", - "type": "string" - } - }, - "required": [ - "label", - "analysis_zarr" - ], - "title": "AnalysisConfig", - "type": "object" - }, "BaselineConfig": { "description": "Configuration for a single baseline to include in the verification.", "properties": { @@ -518,6 +495,29 @@ ], "title": "Stratification", "type": "object" + }, + "TruthConfig": { + "description": "Configuration for the truth data used in the verification.", + "properties": { + "label": { + "description": "Label that will be used in experiment results such as reports and figures.", + "minLength": 1, + "title": "Label", + "type": "string" + }, + "root": { + "description": "Path to the root of the dataset.", + "minLength": 1, + "title": "Root", + "type": "string" + } + }, + "required": [ + "label", + "root" + ], + "title": "TruthConfig", + "type": "object" } }, "additionalProperties": false, @@ -575,8 +575,15 @@ "title": "Baselines", "type": "array" }, - "analysis": { - "$ref": "#/$defs/AnalysisConfig" + "truth": { + "anyOf": [ + { + "$ref": "#/$defs/TruthConfig" + }, + { + "type": "null" + } + ] }, "stratification": { "$ref": "#/$defs/Stratification" @@ -593,7 +600,7 @@ "dates", "runs", "baselines", - "analysis", + "truth", "stratification", "locations", "profile" From 288a3bddf0d57b8862a296d19a33388d76470f61 Mon Sep 17 00:00:00 2001 From: Daniele Nerini Date: Wed, 25 Feb 2026 22:01:34 +0100 Subject: [PATCH 04/13] Fix steps parameter for analysis data --- workflow/rules/plot.smk | 1 - workflow/scripts/plot_meteogram.mo.py | 17 +++++------------ 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/workflow/rules/plot.smk b/workflow/rules/plot.smk index 78d83df2..ceaf60a2 100644 --- a/workflow/rules/plot.smk +++ b/workflow/rules/plot.smk @@ -28,7 +28,6 @@ rule plot_meteogram: truth=config["truth"]["root"], baseline_zarr=lambda wc: _use_first_baseline_zarr(wc)[0], peakweather_dir=rules.download_obs_from_peakweather.output.root, - output: OUT_ROOT / "results/{showcase}/{run_id}/{init_time}/{init_time}_{param}_{sta}.png", diff --git a/workflow/scripts/plot_meteogram.mo.py b/workflow/scripts/plot_meteogram.mo.py index bfc1aade..6e775d84 100644 --- a/workflow/scripts/plot_meteogram.mo.py +++ b/workflow/scripts/plot_meteogram.mo.py @@ -1,6 +1,6 @@ import marimo -__generated_with = "0.16.5" +__generated_with = "0.19.6" app = marimo.App(width="medium") @@ -29,8 +29,8 @@ def _(): grib_decoder, load_analysis_data_from_zarr, load_baseline_from_zarr, - parse_steps, np, + parse_steps, plt, xr, ) @@ -74,6 +74,7 @@ def _(ArgumentParser, Path, parse_steps): station = args.station param = args.param return ( + baseline_steps, grib_dir, init_time, outfn, @@ -82,7 +83,6 @@ def _(ArgumentParser, Path, parse_steps): station, zarr_dir_ana, zarr_dir_base, - baseline_steps, ) @@ -120,8 +120,8 @@ def preprocess_ds(ds, param: str): @app.cell def load_grib_data( - data_source, baseline_steps, + data_source, grib_decoder, grib_dir, init_time, @@ -147,9 +147,7 @@ def load_grib_data( da_fct = ds_fct[param].squeeze() reftime = da_fct.ref_time.values - steps = list( - range(da_fct.sizes["lead_time"]) - ) # FIX: this will fail if lead_time is not 0,1,2,... + steps = da_fct.lead_time.dt.total_seconds() / 3600 ds_ana = load_analysis_data_from_zarr(zarr_dir_ana, reftime, steps, paramlist) ds_ana = preprocess_ds(ds_ana, param) da_ana = ds_ana[param].squeeze() @@ -306,10 +304,5 @@ def _( return -@app.cell -def _(): - return - - if __name__ == "__main__": app.run() From 7ed566f8749bbef2572ab00f1d2010f58c635366 Mon Sep 17 00:00:00 2001 From: Daniele Nerini Date: Wed, 25 Feb 2026 23:12:27 +0100 Subject: [PATCH 05/13] Small refactoring --- src/data_input/__init__.py | 57 ++++++++++++- workflow/rules/plot.smk | 7 +- workflow/scripts/plot_meteogram.mo.py | 97 ++++++++------------- workflow/scripts/verif_single_init.py | 116 ++------------------------ 4 files changed, 105 insertions(+), 172 deletions(-) diff --git a/src/data_input/__init__.py b/src/data_input/__init__.py index c770d9c6..e67816fb 100644 --- a/src/data_input/__init__.py +++ b/src/data_input/__init__.py @@ -107,7 +107,7 @@ def load_fct_data_from_grib( root: Path, reftime: datetime, steps: list[int], params: list[str] ) -> xr.Dataset: """Load forecast data from GRIB files for a specific valid time.""" - files = sorted(root.glob("20*.grib")) + files = sorted(root.glob(f"{reftime:%Y%m%d%H%M}*.grib")) fds = data_source.FileDataSource(datafiles=files) ds = grib_decoder.load(fds, {"param": params, "step": steps}) for var, da in ds.items(): @@ -227,3 +227,58 @@ def load_obs_data_from_peakweather( times = np.datetime64(reftime) + np.asarray(steps, dtype="timedelta64[h]") return _select_valid_times(ds, times) + + +def load_truth_data( + root, reftime: datetime, steps: list[int], params: list[str] +) -> xr.Dataset: + """Load truth data from analysis Zarr or PeakWeather observations.""" + if root.suffix == ".zarr": + LOG.info("Loading ground truth from an analysis zarr dataset...") + truth = load_analysis_data_from_zarr( + root=root, + reftime=reftime, + steps=steps, + params=params, + ) + truth = truth.compute().chunk( + {"y": -1, "x": -1} + if "y" in truth.dims and "x" in truth.dims + else {"values": -1} + ) + elif "peakweather" in str(root): + LOG.info("Loading ground truth from PeakWeather observations...") + truth = load_obs_data_from_peakweather( + root=root, + reftime=reftime, + steps=steps, + params=params, + ) + else: + raise ValueError(f"Unsupported truth root: {root}") + return truth + + +def load_forecast_data( + root, reftime: datetime, steps: list[int], params: list[str] +) -> xr.Dataset: + """Load forecast data from GRIB files or a baseline Zarr dataset.""" + + if any(root.glob("*.grib")): + LOG.info("Loading forecasts from GRIB files...") + fcst = load_fct_data_from_grib( + root=root, + reftime=reftime, + steps=steps, + params=params, + ) + else: + LOG.info("Loading baseline forecasts from zarr dataset...") + fcst = load_baseline_from_zarr( + root=root, + reftime=reftime, + steps=steps, + params=params, + ) + + return fcst diff --git a/workflow/rules/plot.smk b/workflow/rules/plot.smk index ceaf60a2..a41c54ac 100644 --- a/workflow/rules/plot.smk +++ b/workflow/rules/plot.smk @@ -40,20 +40,23 @@ rule plot_meteogram: grib_out_dir=lambda wc: ( Path(OUT_ROOT) / f"data/runs/{wc.run_id}/{wc.init_time}/grib" ).resolve(), + fcst_steps=lambda wc: RUN_CONFIGS[wc.run_id]["steps"], baseline_steps=lambda wc: _use_first_baseline_zarr(wc)[1], shell: """ export ECCODES_DEFINITION_PATH=$(realpath .venv/share/eccodes-cosmo-resources/definitions) python {input.script} \ - --forecast {params.grib_out_dir} --analysis {input.truth} \ + --forecast {params.grib_out_dir} --steps {params.fcst_steps} \ --baseline {input.baseline_zarr} --baseline_steps {params.baseline_steps} \ + --analysis {input.truth} \ --peakweather {input.peakweather_dir} \ --date {wildcards.init_time} --outfn {output[0]} \ --param {wildcards.param} --station {wildcards.sta} # interactive editing (needs to set localrule: True and use only one core) # marimo edit {input.script} -- \ - # --forecast {params.grib_out_dir} --analysis {input.truth} \ + # --forecast {params.grib_out_dir} --steps {params.fcst_steps} \ # --baseline {input.baseline_zarr} --baseline_steps {params.baseline_steps} \ + # --analysis {input.truth} \ # --peakweather {input.peakweather_dir} \ # --date {wildcards.init_time} --outfn {output[0]} \ # --param {wildcards.param} --station {wildcards.sta} diff --git a/workflow/scripts/plot_meteogram.mo.py b/workflow/scripts/plot_meteogram.mo.py index 6e775d84..e9b9ce84 100644 --- a/workflow/scripts/plot_meteogram.mo.py +++ b/workflow/scripts/plot_meteogram.mo.py @@ -7,44 +7,44 @@ @app.cell def _(): from argparse import ArgumentParser + from datetime import datetime from pathlib import Path import matplotlib.pyplot as plt import numpy as np - import xarray as xr - from meteodatalab import data_source, grib_decoder from peakweather import PeakWeatherDataset from data_input import ( - load_analysis_data_from_zarr, - load_baseline_from_zarr, parse_steps, + load_forecast_data, + load_truth_data, ) return ( ArgumentParser, Path, PeakWeatherDataset, - data_source, - grib_decoder, - load_analysis_data_from_zarr, - load_baseline_from_zarr, + datetime, + load_forecast_data, + load_truth_data, np, parse_steps, plt, - xr, ) @app.cell -def _(ArgumentParser, Path, parse_steps): +def _(ArgumentParser, Path, datetime, parse_steps): parser = ArgumentParser() parser.add_argument( "--forecast", type=str, default=None, help="Directory to forecast grib data" ) parser.add_argument( - "--analysis", type=str, default=None, help="Path to analysis zarr data" + "--steps", + type=parse_steps, + default="0/120/6", + help="Forecast steps in the format 'start/stop/step' (default: 0/120/6).", ) parser.add_argument( "--baseline", type=str, default=None, help="Path to baseline zarr data" @@ -55,6 +55,9 @@ def _(ArgumentParser, Path, parse_steps): default="0/120/6", help="Forecast steps in the format 'start/stop/step' (default: 0/120/6).", ) + parser.add_argument( + "--analysis", type=str, default=None, help="Path to analysis zarr data" + ) parser.add_argument( "--peakweather", type=str, default=None, help="Path to PeakWeather dataset" ) @@ -65,16 +68,18 @@ def _(ArgumentParser, Path, parse_steps): args = parser.parse_args() grib_dir = Path(args.forecast) + forecast_steps = args.steps zarr_dir_ana = Path(args.analysis) zarr_dir_base = Path(args.baseline) baseline_steps = args.baseline_steps peakweather_dir = Path(args.peakweather) - init_time = args.date + init_time = datetime.strptime(args.date, "%Y%m%d%H%M") outfn = Path(args.outfn) station = args.station param = args.param return ( baseline_steps, + forecast_steps, grib_dir, init_time, outfn, @@ -121,15 +126,14 @@ def preprocess_ds(ds, param: str): @app.cell def load_grib_data( baseline_steps, - data_source, - grib_decoder, + forecast_steps, grib_dir, init_time, - load_analysis_data_from_zarr, - load_baseline_from_zarr, + load_forecast_data, + load_truth_data, param, + peakweather_dir, preprocess_ds, - xr, zarr_dir_ana, zarr_dir_base, ): @@ -140,57 +144,28 @@ def load_grib_data( else: paramlist = [param] - grib_files = sorted(grib_dir.glob(f"{init_time}*.grib")) - fds = data_source.FileDataSource(datafiles=grib_files) - ds_fct = xr.Dataset(grib_decoder.load(fds, {"param": paramlist})) + ds_fct = load_forecast_data(grib_dir, init_time, forecast_steps, paramlist) ds_fct = preprocess_ds(ds_fct, param) da_fct = ds_fct[param].squeeze() - reftime = da_fct.ref_time.values - steps = da_fct.lead_time.dt.total_seconds() / 3600 - ds_ana = load_analysis_data_from_zarr(zarr_dir_ana, reftime, steps, paramlist) + steps = da_fct.lead_time.dt.total_seconds().values / 3600 + ds_ana = load_truth_data(zarr_dir_ana, init_time, steps, paramlist) ds_ana = preprocess_ds(ds_ana, param) da_ana = ds_ana[param].squeeze() - ds_base = load_baseline_from_zarr( - zarr_dir_base, da_fct.ref_time, baseline_steps, paramlist - ) + ds_base = load_forecast_data(zarr_dir_base, init_time, baseline_steps, paramlist) ds_base = preprocess_ds(ds_base, param) da_base = ds_base[param].squeeze() - return da_ana, da_base, da_fct - -@app.cell -def _(PeakWeatherDataset, da_fct, np, param, peakweather_dir, station): - if param == "T_2M": - parameter = "temperature" - offset = 273.15 # K to C - elif param == "SP_10M": - parameter = "wind_speed" - offset = 0 - elif param == "TOT_PREC": - parameter = "precipitation" - offset = 0 - else: - raise NotImplementedError( - f"The mapping for {param=} to PeakWeather is not implemented" - ) - - peakweather = PeakWeatherDataset(root=peakweather_dir, freq="1h") - obs, mask = peakweather.get_observations( - parameters=[parameter], - stations=station, - first_date=np.datetime_as_string(da_fct.valid_time.values[0]), - last_date=np.datetime_as_string(da_fct.valid_time.values[-1]), - return_mask=True, - ) - obs = obs.loc[:, mask.iloc[0]].droplevel("name", axis=1) - obs - return obs, offset, peakweather + ds_obs = load_truth_data(peakweather_dir, init_time, steps, paramlist) + ds_obs = preprocess_ds(ds_obs, param) + da_obs = ds_obs[param].squeeze() + return da_ana, da_base, da_fct, da_obs @app.cell -def _(peakweather): +def _(PeakWeatherDataset, peakweather_dir): + peakweather = PeakWeatherDataset(root=peakweather_dir, freq="1h") stations = peakweather.stations_table stations.index.names = ["station"] stations @@ -237,9 +212,8 @@ def _( da_ana, da_base, da_fct, + da_obs, init_time, - obs, - offset, outfn, plt, sta_idxs, @@ -254,9 +228,10 @@ def _( fig, ax = plt.subplots() # station + obs2plot = da_obs.sel(values=station) ax.plot( - obs.index.to_pydatetime(), - obs.to_numpy() + offset, + obs2plot["time"].values, + obs2plot.values, color="k", ls="--", label=station, @@ -284,7 +259,7 @@ def _( # forecast fct2plot = da_fct.isel(**fct_isel) ax.plot( - fct2plot["valid_time"].values, + fct2plot["time"].values, fct2plot.values, color="C0", label="forecast", diff --git a/workflow/scripts/verif_single_init.py b/workflow/scripts/verif_single_init.py index edeec10b..02afe548 100644 --- a/workflow/scripts/verif_single_init.py +++ b/workflow/scripts/verif_single_init.py @@ -4,17 +4,13 @@ from datetime import datetime from pathlib import Path -import numpy as np -import xarray as xr -from scipy.spatial import cKDTree from verification import verify # noqa: E402 +from verification.spatial import map_forecast_to_truth # noqa: E402 from data_input import ( - load_baseline_from_zarr, - load_analysis_data_from_zarr, - load_fct_data_from_grib, - load_obs_data_from_peakweather, parse_steps, + load_forecast_data, + load_truth_data, ) # noqa: E402 LOG = logging.getLogger(__name__) @@ -48,111 +44,13 @@ def program_summary_log(args): LOG.info("=" * 80) -def _map_fcst_to_truth( - fcst: xr.Dataset, truth: xr.Dataset -) -> tuple[xr.Dataset, xr.Dataset]: - """Map forecasts to the truth grid or station locations via nearest-neighbor lookup.""" - - truth = truth.sel(time=fcst.time) # swap time dimension to lead_time - - if "y" in fcst.dims and "x" in fcst.dims: - fcst = fcst.stack(values=("y", "x")) - fcst_lat = fcst["lat"].values.ravel() - fcst_lon = fcst["lon"].values.ravel() - - if "y" in truth.dims and "x" in truth.dims: - truth = truth.stack(values=("y", "x")) - truth_lat = truth["lat"].values.ravel() - truth_lon = truth["lon"].values.ravel() - - # TODO: Project to a metric CRS for a proper distance metric - fcst_lat_rad = np.deg2rad(fcst_lat) - fcst_lon_rad = np.deg2rad(fcst_lon) - truth_lat_rad = np.deg2rad(truth_lat) - truth_lon_rad = np.deg2rad(truth_lon) - - fcst_xyz = np.c_[ - np.cos(fcst_lat_rad) * np.cos(fcst_lon_rad), - np.cos(fcst_lat_rad) * np.sin(fcst_lon_rad), - np.sin(fcst_lat_rad), - ] - truth_xyz = np.c_[ - np.cos(truth_lat_rad) * np.cos(truth_lon_rad), - np.cos(truth_lat_rad) * np.sin(truth_lon_rad), - np.sin(truth_lat_rad), - ] - - fcst_tree = cKDTree(fcst_xyz) - _, fi = fcst_tree.query(truth_xyz, k=1) - fi = np.asarray(fi) - fcst = fcst.isel(values=fi) - fcst = fcst.drop_vars(["x", "y", "values"], errors="ignore") - fcst = fcst.assign_coords(lon=("values", truth.lon.data)) - fcst = fcst.assign_coords(lat=("values", truth.lat.data)) - fcst = fcst.assign_coords(values=truth["values"]) - - return fcst, truth - - -def _load_forecast(args: ScriptConfig) -> xr.Dataset: - """Load forecast data from GRIB files or a baseline Zarr dataset.""" - - if any(args.forecast.glob("*.grib")): - LOG.info("Loading forecasts from GRIB files...") - fcst = load_fct_data_from_grib( - root=args.forecast, - reftime=args.reftime, - steps=args.steps, - params=args.params, - ) - else: - LOG.info("Loading baseline forecasts from zarr dataset...") - fcst = load_baseline_from_zarr( - root=args.forecast, - reftime=args.reftime, - steps=args.steps, - params=args.params, - ) - - return fcst - - -def _load_truth(args: ScriptConfig) -> xr.Dataset: - """Load truth data from analysis Zarr or PeakWeather observations.""" - LOG.info("Loading ground truth from an analysis zarr dataset...") - if args.truth.suffix == ".zarr": - truth = load_analysis_data_from_zarr( - root=args.truth, - reftime=args.reftime, - steps=args.steps, - params=args.params, - ) - truth = truth.compute().chunk( - {"y": -1, "x": -1} - if "y" in truth.dims and "x" in truth.dims - else {"values": -1} - ) - elif "peakweather" in str(args.truth): - LOG.info("Loading ground truth from PeakWeather observations...") - # TODO: replace with OGD data - truth = load_obs_data_from_peakweather( - root=args.truth, - reftime=args.reftime, - steps=args.steps, - params=args.params, - ) - else: - raise ValueError(f"Unsupported truth root: {args.truth}") - return truth - - def main(args: ScriptConfig): """Main function to verify baseline forecast data.""" # get baseline forecast data now = datetime.now() - fcst = _load_forecast(args) + fcst = load_forecast_data(args.root, args.reftime, args.steps, args.params) LOG.info( "Loaded forecast data in %s seconds: \n%s", @@ -162,14 +60,16 @@ def main(args: ScriptConfig): # get truth data now = datetime.now() - truth = _load_truth(args) + truth = load_truth_data(args.root, args.reftime, args.steps, args.params) LOG.info( "Loaded truth data in %s seconds: \n%s", (datetime.now() - now).total_seconds(), truth, ) - fcst, truth = _map_fcst_to_truth(fcst, truth) + # align forecast and truth data spatially and temporally + fcst = map_forecast_to_truth(fcst, truth) + truth = truth.sel(time=fcst.time) # compute metrics and statistics results = verify(fcst, truth, args.label, args.truth_label, args.regions) From f2a1bd0dce6647760ff6fed4814b764558af2ad9 Mon Sep 17 00:00:00 2001 From: Daniele Nerini Date: Sun, 1 Mar 2026 22:58:36 +0100 Subject: [PATCH 06/13] Refactor --- workflow/scripts/plot_meteogram.mo.py | 115 ++++++++------------------ 1 file changed, 35 insertions(+), 80 deletions(-) diff --git a/workflow/scripts/plot_meteogram.mo.py b/workflow/scripts/plot_meteogram.mo.py index e9b9ce84..92a621c1 100644 --- a/workflow/scripts/plot_meteogram.mo.py +++ b/workflow/scripts/plot_meteogram.mo.py @@ -19,6 +19,7 @@ def _(): load_forecast_data, load_truth_data, ) + from verification.spatial import map_forecast_to_truth return ( ArgumentParser, @@ -27,6 +28,7 @@ def _(): datetime, load_forecast_data, load_truth_data, + map_forecast_to_truth, np, parse_steps, plt, @@ -118,13 +120,13 @@ def preprocess_ds(ds, param: str): "name": '"Wind speed', } ds = ds.drop_vars(["U", "V"]) - return ds + return ds.squeeze() return (preprocess_ds,) @app.cell -def load_grib_data( +def load_data( baseline_steps, forecast_steps, grib_dir, @@ -146,128 +148,80 @@ def load_grib_data( ds_fct = load_forecast_data(grib_dir, init_time, forecast_steps, paramlist) ds_fct = preprocess_ds(ds_fct, param) - da_fct = ds_fct[param].squeeze() - steps = da_fct.lead_time.dt.total_seconds().values / 3600 + steps = ds_fct.lead_time.dt.total_seconds().values / 3600 ds_ana = load_truth_data(zarr_dir_ana, init_time, steps, paramlist) ds_ana = preprocess_ds(ds_ana, param) - da_ana = ds_ana[param].squeeze() ds_base = load_forecast_data(zarr_dir_base, init_time, baseline_steps, paramlist) ds_base = preprocess_ds(ds_base, param) - da_base = ds_base[param].squeeze() ds_obs = load_truth_data(peakweather_dir, init_time, steps, paramlist) ds_obs = preprocess_ds(ds_obs, param) - da_obs = ds_obs[param].squeeze() - return da_ana, da_base, da_fct, da_obs + return ds_ana, ds_base, ds_fct @app.cell -def _(PeakWeatherDataset, peakweather_dir): - peakweather = PeakWeatherDataset(root=peakweather_dir, freq="1h") +def _(PeakWeatherDataset, peakweather_dir, station): + peakweather = PeakWeatherDataset(root=peakweather_dir) stations = peakweather.stations_table - stations.index.names = ["station"] - stations - return (stations,) + stations.index.names = ["values"] + ds_sta = stations.to_xarray().sel(values=[station]) # keep singleton dim + ds_sta = ds_sta.rename({"latitude": "lat", "longitude": "lon"}) + ds_sta = ds_sta.set_coords(("lat", "lon", "station_name")) + ds_sta = ds_sta.drop_vars(list(ds_sta.data_vars)) + ds_sta + return (ds_sta,) @app.cell -def _(da_ana, da_base, da_fct, np, stations): - def nearest_indexers_euclid(ds, lat_s, lon_s): - """ - Return a dict of indexers usable as: ds.isel(**indexers) - - Examples: - - 2D structured grid -> {"y": y_idx, "x": x_idx} - - 1D unstructured grid -> {"point": i_idx} (or {"cell": i_idx}, etc.) - """ - lat = ds["lat"] - lon = ds["lon"] - dist = (lat - lat_s) ** 2 + (lon - lon_s) ** 2 - arr = dist.values - - flat_idx = int(np.nanargmin(arr)) - - if dist.ndim == 1: - return {dist.dims[0]: flat_idx} - - unr = np.unravel_index(flat_idx, dist.shape) - return {dim: int(i) for dim, i in zip(dist.dims, unr)} - - def get_idx_row(row, da): - return nearest_indexers_euclid(da, row["latitude"], row["longitude"]) - - # store dicts (indexers) in columns - sta_idxs = stations.copy() - sta_idxs["fct_isel"] = sta_idxs.apply(lambda r: get_idx_row(r, da_fct), axis=1) - sta_idxs["ana_isel"] = sta_idxs.apply(lambda r: get_idx_row(r, da_ana), axis=1) - sta_idxs["base_isel"] = sta_idxs.apply(lambda r: get_idx_row(r, da_base), axis=1) - sta_idxs - return (sta_idxs,) +def _(ds_ana, ds_base, ds_fct, ds_sta, map_forecast_to_truth): + ds_fct_sta = map_forecast_to_truth(ds_fct, ds_sta) + ds_ana_sta = map_forecast_to_truth(ds_ana, ds_sta) + ds_base_sta = map_forecast_to_truth(ds_base, ds_sta) + return ds_ana_sta, ds_base_sta, ds_fct_sta @app.cell def _( - da_ana, - da_base, - da_fct, - da_obs, + ds_ana_sta, + ds_base_sta, + ds_fct, + ds_fct_sta, init_time, outfn, + param, plt, - sta_idxs, station, ): - # station indices - row = sta_idxs.loc[station] - fct_isel = row.fct_isel - ana_isel = row.ana_isel - base_isel = row.base_isel - fig, ax = plt.subplots() - # station - obs2plot = da_obs.sel(values=station) + # truth ax.plot( - obs2plot["time"].values, - obs2plot.values, + ds_ana_sta["time"].values, + ds_ana_sta[param].values, color="k", ls="--", - label=station, + label="truth", ) - - # analysis - ana2plot = da_ana.isel(**ana_isel) - ax.plot( - ana2plot["time"].values, - ana2plot.values, - color="k", - ls="-", - label="analysis", - ) - # baseline - base2plot = da_base.isel(**base_isel) ax.plot( - base2plot["time"].values, - base2plot.values, + ds_base_sta["time"].values, + ds_base_sta[param].values, color="C1", label="baseline", ) - # forecast - fct2plot = da_fct.isel(**fct_isel) ax.plot( - fct2plot["time"].values, - fct2plot.values, + ds_fct_sta["time"].values, + ds_fct_sta[param].values, color="C0", label="forecast", ) ax.legend() - param2plot = da_fct.attrs.get("parameter", {}) + param2plot = ds_fct[param].attrs.get("parameter", {}) short = param2plot.get("shortName", "") units = param2plot.get("units", "") name = param2plot.get("name", "") @@ -276,6 +230,7 @@ def _( ax.set_title(f"{init_time} {name} at {station}") plt.savefig(outfn) + print(f"saved: {outfn}") return From 20d1452b80c828d81678c3d35dae44db67a25898 Mon Sep 17 00:00:00 2001 From: Daniele Nerini Date: Mon, 2 Mar 2026 14:20:05 +0100 Subject: [PATCH 07/13] Include new spatial module --- src/verification/spatial.py | 143 +++++++++++++++++++++++++++++ tests/unit/test_spatial_mapping.py | 86 +++++++++++++++++ 2 files changed, 229 insertions(+) create mode 100644 src/verification/spatial.py create mode 100644 tests/unit/test_spatial_mapping.py diff --git a/src/verification/spatial.py b/src/verification/spatial.py new file mode 100644 index 00000000..67161333 --- /dev/null +++ b/src/verification/spatial.py @@ -0,0 +1,143 @@ +"""Spatial mapping helpers for aligning forecasts and references. + +This module contains reusable nearest-neighbor utilities used by verification +and plotting scripts to map data between different spatial supports. +""" + +from __future__ import annotations + +import numpy as np +import xarray as xr +from scipy.spatial import cKDTree + + +def spherical_nearest_neighbor_indices( + source_lat: np.ndarray, + source_lon: np.ndarray, + target_lat: np.ndarray, + target_lon: np.ndarray, +) -> np.ndarray: + """Return indices of nearest source points for each target point. + + Distances are computed in 3D Cartesian space after projecting latitude and + longitude (degrees) onto the unit sphere. This avoids distortions from + Euclidean distance in degree space. + + Parameters + ---------- + source_lat, source_lon + Latitude and longitude of source points in degrees. + target_lat, target_lon + Latitude and longitude of target points in degrees. + + Returns + ------- + np.ndarray + Integer indices into source points, one index per target point. + """ + + source_lat = np.asarray(source_lat).ravel() + source_lon = np.asarray(source_lon).ravel() + target_lat = np.asarray(target_lat).ravel() + target_lon = np.asarray(target_lon).ravel() + + source_lat_rad = np.deg2rad(source_lat) + source_lon_rad = np.deg2rad(source_lon) + target_lat_rad = np.deg2rad(target_lat) + target_lon_rad = np.deg2rad(target_lon) + + source_xyz = np.c_[ + np.cos(source_lat_rad) * np.cos(source_lon_rad), + np.cos(source_lat_rad) * np.sin(source_lon_rad), + np.sin(source_lat_rad), + ] + target_xyz = np.c_[ + np.cos(target_lat_rad) * np.cos(target_lon_rad), + np.cos(target_lat_rad) * np.sin(target_lon_rad), + np.sin(target_lat_rad), + ] + + tree = cKDTree(source_xyz) + _, nearest_idx = tree.query(target_xyz, k=1) + return np.asarray(nearest_idx, dtype=int) + + +def nearest_grid_yx_indices( + grid: xr.Dataset | xr.DataArray, target_lat: np.ndarray, target_lon: np.ndarray +) -> tuple[np.ndarray, np.ndarray]: + """Find nearest `(y, x)` grid indices for target coordinates. + + Parameters + ---------- + grid + Dataset or DataArray with `lat` and `lon` coordinates defined on a + `(y, x)` grid. + target_lat, target_lon + Target coordinates in degrees. + + Returns + ------- + tuple[np.ndarray, np.ndarray] + Arrays of `y` and `x` indices for each target location. + """ + + if "lat" not in grid or "lon" not in grid: + raise ValueError("Input must provide 'lat' and 'lon' coordinates") + + lat2d = np.asarray(grid["lat"].values) + lon2d = np.asarray(grid["lon"].values) + if lat2d.ndim != 2 or lon2d.ndim != 2: + raise ValueError("'lat' and 'lon' must be 2D on (y, x) for y/x indexing") + + flat_idx = spherical_nearest_neighbor_indices( + source_lat=lat2d.ravel(), + source_lon=lon2d.ravel(), + target_lat=target_lat, + target_lon=target_lon, + ) + y_idx, x_idx = np.unravel_index(flat_idx, lat2d.shape) + return np.asarray(y_idx, dtype=int), np.asarray(x_idx, dtype=int) + + +def map_forecast_to_truth(fcst: xr.Dataset, truth: xr.Dataset) -> xr.Dataset: + """Map forecast points to truth locations using nearest-neighbor matching. + + The forecast is flattened to a single spatial `values` dimension (when + provided as `(y, x)`), then sampled at the nearest points to each truth + location. Returned forecast coordinates are overwritten with truth station + coordinates to make subsequent verification align naturally. + + Parameters + ---------- + fcst + Forecast dataset with `lat` and `lon` coordinates on either `(y, x)` or + `values`. + truth + Reference dataset with `lat` and `lon` coordinates on either `(y, x)` or + `values`. + + Returns + ------- + xr.Dataset + Mapped forecast dataset. + """ + + if "y" in fcst.dims and "x" in fcst.dims: + fcst = fcst.stack(values=("y", "x")) + if "y" in truth.dims and "x" in truth.dims: + truth = truth.stack(values=("y", "x")) + + nearest_idx = spherical_nearest_neighbor_indices( + source_lat=fcst["lat"].values, + source_lon=fcst["lon"].values, + target_lat=truth["lat"].values, + target_lon=truth["lon"].values, + ) + + fcst = fcst.isel(values=nearest_idx) + fcst = fcst.drop_vars(["x", "y", "values"], errors="ignore") + fcst = fcst.assign_coords(lon=("values", truth.lon.data)) + fcst = fcst.assign_coords(lat=("values", truth.lat.data)) + fcst = fcst.assign_coords(values=truth["values"]) + + return fcst diff --git a/tests/unit/test_spatial_mapping.py b/tests/unit/test_spatial_mapping.py new file mode 100644 index 00000000..399893be --- /dev/null +++ b/tests/unit/test_spatial_mapping.py @@ -0,0 +1,86 @@ +import numpy as np +import xarray as xr + +from verification.spatial import ( + map_forecast_to_truth, + nearest_grid_yx_indices, + spherical_nearest_neighbor_indices, +) + + +def test_spherical_nearest_neighbor_indices_returns_expected_points(): + source_lat = np.array([46.0, 46.0, 47.0, 47.0]) + source_lon = np.array([7.0, 8.0, 7.0, 8.0]) + target_lat = np.array([46.1, 46.9]) + target_lon = np.array([7.1, 7.9]) + + idx = spherical_nearest_neighbor_indices( + source_lat=source_lat, + source_lon=source_lon, + target_lat=target_lat, + target_lon=target_lon, + ) + + assert np.array_equal(idx, np.array([0, 3])) + + +def test_nearest_grid_yx_indices_returns_grid_indices(): + lat = xr.DataArray([[46.0, 46.0], [47.0, 47.0]], dims=("y", "x")) + lon = xr.DataArray([[7.0, 8.0], [7.0, 8.0]], dims=("y", "x")) + grid = xr.Dataset(coords={"lat": lat, "lon": lon}) + + y_idx, x_idx = nearest_grid_yx_indices( + grid=grid, + target_lat=np.array([46.1, 46.9]), + target_lon=np.array([7.1, 7.9]), + ) + + assert np.array_equal(y_idx, np.array([0, 1])) + assert np.array_equal(x_idx, np.array([0, 1])) + + +def test_map_forecast_to_truth_maps_and_aligns_time(): + fcst_time = np.array(["2024-01-01T00:00", "2024-01-01T01:00"], dtype="datetime64[ns]") + truth_time = np.array( + ["2024-01-01T00:00", "2024-01-01T01:00", "2024-01-01T02:00"], + dtype="datetime64[ns]", + ) + + fcst = xr.Dataset( + data_vars={ + "T_2M": ( + ("time", "y", "x"), + np.array( + [ + [[1.0, 2.0], [3.0, 4.0]], + [[10.0, 20.0], [30.0, 40.0]], + ] + ), + ) + }, + coords={ + "time": fcst_time, + "y": [0, 1], + "x": [0, 1], + "lat": (("y", "x"), np.array([[46.0, 46.0], [47.0, 47.0]])), + "lon": (("y", "x"), np.array([[7.0, 8.0], [7.0, 8.0]])), + }, + ) + truth = xr.Dataset( + data_vars={"T_2M": (("time", "values"), np.zeros((3, 2)))}, + coords={ + "time": truth_time, + "values": ["STA1", "STA2"], + "lat": ("values", np.array([46.1, 46.9])), + "lon": ("values", np.array([7.1, 7.9])), + }, + ) + + mapped_fcst, mapped_truth = map_forecast_to_truth(fcst, truth) + + assert mapped_fcst["T_2M"].dims == ("time", "values") + assert np.array_equal(mapped_truth["time"].values, fcst_time) + assert np.array_equal(mapped_fcst["values"].values, np.array(["STA1", "STA2"])) + assert np.allclose(mapped_fcst["lat"].values, np.array([46.1, 46.9])) + assert np.allclose(mapped_fcst["lon"].values, np.array([7.1, 7.9])) + assert np.allclose(mapped_fcst["T_2M"].values, np.array([[1.0, 4.0], [10.0, 40.0]])) From 998b8fcb7039e08416ef85ffea0bf9ac1219a25c Mon Sep 17 00:00:00 2001 From: Daniele Nerini Date: Mon, 2 Mar 2026 14:25:54 +0100 Subject: [PATCH 08/13] Liniting --- src/verification/spatial.py | 2 +- tests/unit/test_spatial_mapping.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/verification/spatial.py b/src/verification/spatial.py index 67161333..ff08eaca 100644 --- a/src/verification/spatial.py +++ b/src/verification/spatial.py @@ -120,7 +120,7 @@ def map_forecast_to_truth(fcst: xr.Dataset, truth: xr.Dataset) -> xr.Dataset: ------- xr.Dataset Mapped forecast dataset. - """ + """ if "y" in fcst.dims and "x" in fcst.dims: fcst = fcst.stack(values=("y", "x")) diff --git a/tests/unit/test_spatial_mapping.py b/tests/unit/test_spatial_mapping.py index 399893be..e55ed2c3 100644 --- a/tests/unit/test_spatial_mapping.py +++ b/tests/unit/test_spatial_mapping.py @@ -40,7 +40,9 @@ def test_nearest_grid_yx_indices_returns_grid_indices(): def test_map_forecast_to_truth_maps_and_aligns_time(): - fcst_time = np.array(["2024-01-01T00:00", "2024-01-01T01:00"], dtype="datetime64[ns]") + fcst_time = np.array( + ["2024-01-01T00:00", "2024-01-01T01:00"], dtype="datetime64[ns]" + ) truth_time = np.array( ["2024-01-01T00:00", "2024-01-01T01:00", "2024-01-01T02:00"], dtype="datetime64[ns]", From 12d5b80d21047a21f46a3d8ab4e355f4a17b3893 Mon Sep 17 00:00:00 2001 From: Daniele Nerini Date: Mon, 2 Mar 2026 14:31:14 +0100 Subject: [PATCH 09/13] Fix tests --- tests/unit/test_spatial_mapping.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_spatial_mapping.py b/tests/unit/test_spatial_mapping.py index e55ed2c3..1e54b126 100644 --- a/tests/unit/test_spatial_mapping.py +++ b/tests/unit/test_spatial_mapping.py @@ -39,7 +39,7 @@ def test_nearest_grid_yx_indices_returns_grid_indices(): assert np.array_equal(x_idx, np.array([0, 1])) -def test_map_forecast_to_truth_maps_and_aligns_time(): +def test_map_forecast_to_truth_maps_forecast_to_truth_locations(): fcst_time = np.array( ["2024-01-01T00:00", "2024-01-01T01:00"], dtype="datetime64[ns]" ) @@ -78,11 +78,14 @@ def test_map_forecast_to_truth_maps_and_aligns_time(): }, ) - mapped_fcst, mapped_truth = map_forecast_to_truth(fcst, truth) + mapped_fcst = map_forecast_to_truth(fcst, truth) assert mapped_fcst["T_2M"].dims == ("time", "values") - assert np.array_equal(mapped_truth["time"].values, fcst_time) + assert np.array_equal(mapped_fcst["time"].values, fcst_time) assert np.array_equal(mapped_fcst["values"].values, np.array(["STA1", "STA2"])) assert np.allclose(mapped_fcst["lat"].values, np.array([46.1, 46.9])) assert np.allclose(mapped_fcst["lon"].values, np.array([7.1, 7.9])) - assert np.allclose(mapped_fcst["T_2M"].values, np.array([[1.0, 4.0], [10.0, 40.0]])) + assert np.allclose( + mapped_fcst["T_2M"].values, + np.array([[1.0, 4.0], [10.0, 40.0]]), + ) From 68b3d637367cf993bfbaa01287e665814f353a49 Mon Sep 17 00:00:00 2001 From: Louis-Frey Date: Tue, 3 Mar 2026 16:51:28 +0100 Subject: [PATCH 10/13] Bug Fix suggested by Claude Code. --- workflow/scripts/verif_single_init.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/workflow/scripts/verif_single_init.py b/workflow/scripts/verif_single_init.py index 02afe548..421078de 100644 --- a/workflow/scripts/verif_single_init.py +++ b/workflow/scripts/verif_single_init.py @@ -50,7 +50,7 @@ def main(args: ScriptConfig): # get baseline forecast data now = datetime.now() - fcst = load_forecast_data(args.root, args.reftime, args.steps, args.params) + fcst = load_forecast_data(args.forecast, args.reftime, args.steps, args.params) LOG.info( "Loaded forecast data in %s seconds: \n%s", @@ -60,7 +60,7 @@ def main(args: ScriptConfig): # get truth data now = datetime.now() - truth = load_truth_data(args.root, args.reftime, args.steps, args.params) + truth = load_truth_data(args.truth, args.reftime, args.steps, args.params) LOG.info( "Loaded truth data in %s seconds: \n%s", (datetime.now() - now).total_seconds(), From 02eeeea323410018269df79bf32e8ac583d33c12 Mon Sep 17 00:00:00 2001 From: Daniele Nerini Date: Mon, 9 Mar 2026 14:19:24 +0100 Subject: [PATCH 11/13] Allow for multiple baselines --- config/forecasters-ich1-oper.yaml | 9 +- workflow/rules/plot.smk | 62 ++++++---- workflow/scripts/plot_meteogram.mo.py | 163 +++++++++++++++++--------- 3 files changed, 156 insertions(+), 78 deletions(-) diff --git a/config/forecasters-ich1-oper.yaml b/config/forecasters-ich1-oper.yaml index 64666cc5..6e5b011f 100644 --- a/config/forecasters-ich1-oper.yaml +++ b/config/forecasters-ich1-oper.yaml @@ -11,7 +11,6 @@ dates: - 2025-02-01T06:00 - 2025-03-01T12:00 - runs: - forecaster: checkpoint: https://servicedepl.meteoswiss.ch/mlstore#/experiments/409/runs/b30acf68520a4bbd8324c44666561696 @@ -24,9 +23,14 @@ runs: - git+https://github.com/ecmwf/anemoi-inference.git@main baselines: + - baseline: + baseline_id: ICON-CH1-EPS + label: ICON-CH1-ctrl + root: /scratch/mch/cmerker/ICON-CH1-EPS + steps: 0/33/6 - baseline: baseline_id: ICON-CH2-EPS - label: ICON-CH2-EPS + label: ICON-CH2-ctrl root: /scratch/mch/cmerker/ICON-CH2-EPS steps: 0/120/6 @@ -34,7 +38,6 @@ truth: label: KENDA-CH1 root: /store_new/mch/msopr/ml/datasets/mch-ich1-1km-2024-2025-1h-pl13-v1.0.zarr - stratification: regions: - jura diff --git a/workflow/rules/plot.smk b/workflow/rules/plot.smk index a41c54ac..73badb2b 100644 --- a/workflow/rules/plot.smk +++ b/workflow/rules/plot.smk @@ -9,16 +9,20 @@ include: "common.smk" import pandas as pd -def _use_first_baseline_zarr(wc) -> tuple[str, str]: - """Get the first available baseline zarr for the given init time.""" +def _get_available_baselines(wc) -> list[dict[str, str]]: + """Get all available baseline zarr datasets for the given init time.""" + baselines = [] for baseline_id in BASELINE_CONFIGS: root = BASELINE_CONFIGS[baseline_id].get("root") steps = BASELINE_CONFIGS[baseline_id].get("steps") + label = BASELINE_CONFIGS[baseline_id].get("label", baseline_id) year = wc.init_time[2:4] baseline_zarr = f"{root}/FCST{year}.zarr" if Path(baseline_zarr).exists(): - return baseline_zarr, steps - raise ValueError(f"No baseline zarr found for init time {wc.init_time}") + baselines.append({"zarr": baseline_zarr, "steps": steps, "label": label}) + if not baselines: + raise ValueError(f"No baseline zarr found for init time {wc.init_time}") + return baselines rule plot_meteogram: @@ -26,7 +30,6 @@ rule plot_meteogram: script="workflow/scripts/plot_meteogram.mo.py", inference_okfile=rules.execute_inference.output.okfile, truth=config["truth"]["root"], - baseline_zarr=lambda wc: _use_first_baseline_zarr(wc)[0], peakweather_dir=rules.download_obs_from_peakweather.output.root, output: OUT_ROOT @@ -37,29 +40,46 @@ rule plot_meteogram: cpus_per_task=1, runtime="10m", params: - grib_out_dir=lambda wc: ( + ana_label=lambda wc: config["truth"]["label"], + fcst_grib=lambda wc: ( Path(OUT_ROOT) / f"data/runs/{wc.run_id}/{wc.init_time}/grib" ).resolve(), fcst_steps=lambda wc: RUN_CONFIGS[wc.run_id]["steps"], - baseline_steps=lambda wc: _use_first_baseline_zarr(wc)[1], + fcst_label=lambda wc: RUN_CONFIGS[wc.run_id]["label"], + baseline_zarrs=lambda wc: [x["zarr"] for x in _get_available_baselines(wc)], + baseline_steps=lambda wc: [x["steps"] for x in _get_available_baselines(wc)], + baseline_labels=lambda wc: [x["label"] for x in _get_available_baselines(wc)], shell: """ + set -euo pipefail export ECCODES_DEFINITION_PATH=$(realpath .venv/share/eccodes-cosmo-resources/definitions) - python {input.script} \ - --forecast {params.grib_out_dir} --steps {params.fcst_steps} \ - --baseline {input.baseline_zarr} --baseline_steps {params.baseline_steps} \ - --analysis {input.truth} \ - --peakweather {input.peakweather_dir} \ - --date {wildcards.init_time} --outfn {output[0]} \ - --param {wildcards.param} --station {wildcards.sta} + + BASELINE_ZARRS=({params.baseline_zarrs:q}) + BASELINE_STEPS=({params.baseline_steps:q}) + BASELINE_LABELS=({params.baseline_labels:q}) + + CMD_ARGS=( + --forecast {params.fcst_grib:q} + --forecast_steps {params.fcst_steps:q} + --forecast_label {params.fcst_label:q} + --analysis {input.truth:q} + --analysis_label {params.ana_label:q} + --peakweather {input.peakweather_dir:q} + --date {wildcards.init_time:q} + --outfn {output[0]:q} + --param {wildcards.param:q} + --station {wildcards.sta:q} + ) + + for i in "${{!BASELINE_ZARRS[@]}}"; do + CMD_ARGS+=(--baseline "${{BASELINE_ZARRS[$i]}}") + CMD_ARGS+=(--baseline_steps "${{BASELINE_STEPS[$i]}}") + CMD_ARGS+=(--baseline_label "${{BASELINE_LABELS[$i]}}") + done + + python {input.script} "${{CMD_ARGS[@]}}" # interactive editing (needs to set localrule: True and use only one core) - # marimo edit {input.script} -- \ - # --forecast {params.grib_out_dir} --steps {params.fcst_steps} \ - # --baseline {input.baseline_zarr} --baseline_steps {params.baseline_steps} \ - # --analysis {input.truth} \ - # --peakweather {input.peakweather_dir} \ - # --date {wildcards.init_time} --outfn {output[0]} \ - # --param {wildcards.param} --station {wildcards.sta} + # marimo edit {input.script} -- "${{CMD_ARGS[@]}}" """ diff --git a/workflow/scripts/plot_meteogram.mo.py b/workflow/scripts/plot_meteogram.mo.py index 92a621c1..c26d2c24 100644 --- a/workflow/scripts/plot_meteogram.mo.py +++ b/workflow/scripts/plot_meteogram.mo.py @@ -43,23 +43,50 @@ def _(ArgumentParser, Path, datetime, parse_steps): "--forecast", type=str, default=None, help="Directory to forecast grib data" ) parser.add_argument( - "--steps", + "--forecast_steps", type=parse_steps, default="0/120/6", help="Forecast steps in the format 'start/stop/step' (default: 0/120/6).", ) parser.add_argument( - "--baseline", type=str, default=None, help="Path to baseline zarr data" + "--forecast_label", + type=str, + default="forecast", + help="Label for forecast line in plot legend.", + ) + parser.add_argument( + "--baseline", + action="append", + type=str, + default=[], + help="Path to baseline zarr data (repeatable).", ) parser.add_argument( "--baseline_steps", + action="append", type=parse_steps, - default="0/120/6", - help="Forecast steps in the format 'start/stop/step' (default: 0/120/6).", + default=[], + help=( + "Forecast steps in the format 'start/stop/step' for each baseline " + "(repeatable, must match --baseline count)." + ), + ) + parser.add_argument( + "--baseline_label", + action="append", + type=str, + default=[], + help="Label for each baseline line in plot legend (repeatable).", ) parser.add_argument( "--analysis", type=str, default=None, help="Path to analysis zarr data" ) + parser.add_argument( + "--analysis_label", + type=str, + default="truth", + help="Label for analysis line in plot legend.", + ) parser.add_argument( "--peakweather", type=str, default=None, help="Path to PeakWeather dataset" ) @@ -69,27 +96,43 @@ def _(ArgumentParser, Path, datetime, parse_steps): parser.add_argument("--station", type=str, help="station") args = parser.parse_args() - grib_dir = Path(args.forecast) - forecast_steps = args.steps - zarr_dir_ana = Path(args.analysis) - zarr_dir_base = Path(args.baseline) + forecast_grib_dir = Path(args.forecast) + forecast_steps = args.forecast_steps + forecast_label = args.forecast_label + analysis_zarr = Path(args.analysis) + analysis_label = args.analysis_label + baseline_zarrs = [Path(path) for path in args.baseline] baseline_steps = args.baseline_steps + baseline_labels = args.baseline_label + if len(baseline_zarrs) != len(baseline_steps): + raise ValueError( + "Mismatched baseline arguments: --baseline and --baseline_steps " + "must be provided the same number of times." + ) + if len(baseline_labels) != len(baseline_zarrs): + raise ValueError( + "Mismatched baseline arguments: --baseline and --baseline_label " + "must be provided the same number of times." + ) peakweather_dir = Path(args.peakweather) init_time = datetime.strptime(args.date, "%Y%m%d%H%M") outfn = Path(args.outfn) station = args.station param = args.param return ( + analysis_label, + analysis_zarr, + baseline_labels, baseline_steps, + baseline_zarrs, + forecast_label, forecast_steps, - grib_dir, + forecast_grib_dir, init_time, outfn, param, peakweather_dir, station, - zarr_dir_ana, - zarr_dir_base, ) @@ -127,17 +170,16 @@ def preprocess_ds(ds, param: str): @app.cell def load_data( + analysis_zarr, baseline_steps, + baseline_zarrs, forecast_steps, - grib_dir, + forecast_grib_dir, init_time, load_forecast_data, load_truth_data, param, - peakweather_dir, preprocess_ds, - zarr_dir_ana, - zarr_dir_base, ): if param == "SP_10M": paramlist = ["U_10M", "V_10M"] @@ -146,19 +188,24 @@ def load_data( else: paramlist = [param] - ds_fct = load_forecast_data(grib_dir, init_time, forecast_steps, paramlist) - ds_fct = preprocess_ds(ds_fct, param) + forecast_ds = load_forecast_data( + forecast_grib_dir, init_time, forecast_steps, paramlist + ) + forecast_ds = preprocess_ds(forecast_ds, param) - steps = ds_fct.lead_time.dt.total_seconds().values / 3600 - ds_ana = load_truth_data(zarr_dir_ana, init_time, steps, paramlist) - ds_ana = preprocess_ds(ds_ana, param) + steps = forecast_ds.lead_time.dt.total_seconds().values / 3600 + analysis_ds = load_truth_data(analysis_zarr, init_time, steps, paramlist) + analysis_ds = preprocess_ds(analysis_ds, param) - ds_base = load_forecast_data(zarr_dir_base, init_time, baseline_steps, paramlist) - ds_base = preprocess_ds(ds_base, param) + baseline_ds_list = [ + preprocess_ds( + load_forecast_data(zarr, init_time, step, paramlist), + param, + ) + for zarr, step in zip(baseline_zarrs, baseline_steps) + ] - ds_obs = load_truth_data(peakweather_dir, init_time, steps, paramlist) - ds_obs = preprocess_ds(ds_obs, param) - return ds_ana, ds_base, ds_fct + return analysis_ds, baseline_ds_list, forecast_ds @app.cell @@ -166,28 +213,33 @@ def _(PeakWeatherDataset, peakweather_dir, station): peakweather = PeakWeatherDataset(root=peakweather_dir) stations = peakweather.stations_table stations.index.names = ["values"] - ds_sta = stations.to_xarray().sel(values=[station]) # keep singleton dim - ds_sta = ds_sta.rename({"latitude": "lat", "longitude": "lon"}) - ds_sta = ds_sta.set_coords(("lat", "lon", "station_name")) - ds_sta = ds_sta.drop_vars(list(ds_sta.data_vars)) - ds_sta - return (ds_sta,) + station_ds = stations.to_xarray().sel(values=[station]) # keep singleton dim + station_ds = station_ds.rename({"latitude": "lat", "longitude": "lon"}) + station_ds = station_ds.set_coords(("lat", "lon", "station_name")) + station_ds = station_ds.drop_vars(list(station_ds.data_vars)) + station_ds + return (station_ds,) @app.cell -def _(ds_ana, ds_base, ds_fct, ds_sta, map_forecast_to_truth): - ds_fct_sta = map_forecast_to_truth(ds_fct, ds_sta) - ds_ana_sta = map_forecast_to_truth(ds_ana, ds_sta) - ds_base_sta = map_forecast_to_truth(ds_base, ds_sta) - return ds_ana_sta, ds_base_sta, ds_fct_sta +def _(analysis_ds, baseline_ds_list, forecast_ds, station_ds, map_forecast_to_truth): + forecast_station_ds = map_forecast_to_truth(forecast_ds, station_ds) + analysis_station_ds = map_forecast_to_truth(analysis_ds, station_ds) + baseline_station_ds_list = [ + map_forecast_to_truth(ds, station_ds) for ds in baseline_ds_list + ] + return analysis_station_ds, baseline_station_ds_list, forecast_station_ds @app.cell def _( - ds_ana_sta, - ds_base_sta, - ds_fct, - ds_fct_sta, + analysis_label, + baseline_labels, + analysis_station_ds, + baseline_station_ds_list, + forecast_label, + forecast_ds, + forecast_station_ds, init_time, outfn, param, @@ -198,30 +250,33 @@ def _( # truth ax.plot( - ds_ana_sta["time"].values, - ds_ana_sta[param].values, + analysis_station_ds["time"].values, + analysis_station_ds[param].values, color="k", ls="--", - label="truth", - ) - # baseline - ax.plot( - ds_base_sta["time"].values, - ds_base_sta[param].values, - color="C1", - label="baseline", + label=analysis_label, ) + # baselines + for i, (baseline_label, baseline_station_ds) in enumerate( + zip(baseline_labels, baseline_station_ds_list), start=1 + ): + ax.plot( + baseline_station_ds["time"].values, + baseline_station_ds[param].values, + color=f"C{i}", + label=f"{baseline_label}", + ) # forecast ax.plot( - ds_fct_sta["time"].values, - ds_fct_sta[param].values, + forecast_station_ds["time"].values, + forecast_station_ds[param].values, color="C0", - label="forecast", + label=forecast_label, ) ax.legend() - param2plot = ds_fct[param].attrs.get("parameter", {}) + param2plot = forecast_ds[param].attrs.get("parameter", {}) short = param2plot.get("shortName", "") units = param2plot.get("units", "") name = param2plot.get("name", "") From b10f024895b4035597da12d1022d02c2e68f2e6c Mon Sep 17 00:00:00 2001 From: Daniele Nerini Date: Wed, 11 Mar 2026 22:32:31 +0100 Subject: [PATCH 12/13] Reshape fcst to grid if truth is gridded analysis --- src/verification/spatial.py | 8 +++++- tests/unit/test_spatial_mapping.py | 42 ++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/src/verification/spatial.py b/src/verification/spatial.py index ff08eaca..d440d524 100644 --- a/src/verification/spatial.py +++ b/src/verification/spatial.py @@ -121,10 +121,13 @@ def map_forecast_to_truth(fcst: xr.Dataset, truth: xr.Dataset) -> xr.Dataset: xr.Dataset Mapped forecast dataset. """ + # TODO: return fcst unchanged when forecast and truth are already aligned + + truth_is_grid = "y" in truth.dims and "x" in truth.dims if "y" in fcst.dims and "x" in fcst.dims: fcst = fcst.stack(values=("y", "x")) - if "y" in truth.dims and "x" in truth.dims: + if truth_is_grid: truth = truth.stack(values=("y", "x")) nearest_idx = spherical_nearest_neighbor_indices( @@ -140,4 +143,7 @@ def map_forecast_to_truth(fcst: xr.Dataset, truth: xr.Dataset) -> xr.Dataset: fcst = fcst.assign_coords(lat=("values", truth.lat.data)) fcst = fcst.assign_coords(values=truth["values"]) + if truth_is_grid: + fcst = fcst.unstack("values") + return fcst diff --git a/tests/unit/test_spatial_mapping.py b/tests/unit/test_spatial_mapping.py index 1e54b126..73d56954 100644 --- a/tests/unit/test_spatial_mapping.py +++ b/tests/unit/test_spatial_mapping.py @@ -89,3 +89,45 @@ def test_map_forecast_to_truth_maps_forecast_to_truth_locations(): mapped_fcst["T_2M"].values, np.array([[1.0, 4.0], [10.0, 40.0]]), ) + + +def test_map_forecast_to_truth_restores_grid_when_truth_is_gridded(): + fcst_time = np.array(["2024-01-01T00:00"], dtype="datetime64[ns]") + + fcst = xr.Dataset( + data_vars={ + "T_2M": ( + ("time", "y", "x"), + np.array([[[1.0, 2.0], [3.0, 4.0]]]), + ) + }, + coords={ + "time": fcst_time, + "y": [0, 1], + "x": [0, 1], + "lat": (("y", "x"), np.array([[46.0, 46.0], [47.0, 47.0]])), + "lon": (("y", "x"), np.array([[7.0, 8.0], [7.0, 8.0]])), + }, + ) + truth = xr.Dataset( + data_vars={"T_2M": (("time", "y", "x"), np.zeros((1, 2, 2)))}, + coords={ + "time": fcst_time, + "y": [0, 1], + "x": [0, 1], + "lat": (("y", "x"), np.array([[46.1, 46.1], [46.9, 46.9]])), + "lon": (("y", "x"), np.array([[7.1, 7.9], [7.1, 7.9]])), + }, + ) + + mapped_fcst = map_forecast_to_truth(fcst, truth) + + assert mapped_fcst["T_2M"].dims == ("time", "y", "x") + assert np.array_equal(mapped_fcst["y"].values, np.array([0, 1])) + assert np.array_equal(mapped_fcst["x"].values, np.array([0, 1])) + assert np.allclose(mapped_fcst["lat"].values, truth["lat"].values) + assert np.allclose(mapped_fcst["lon"].values, truth["lon"].values) + assert np.allclose( + mapped_fcst["T_2M"].values, + np.array([[[1.0, 2.0], [3.0, 4.0]]]), + ) From 6efe4305cbd2c266393632bb53e01ede36293485 Mon Sep 17 00:00:00 2001 From: Daniele Nerini Date: Wed, 11 Mar 2026 22:35:32 +0100 Subject: [PATCH 13/13] Lint --- src/verification/spatial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/verification/spatial.py b/src/verification/spatial.py index d440d524..a5186d5f 100644 --- a/src/verification/spatial.py +++ b/src/verification/spatial.py @@ -122,7 +122,7 @@ def map_forecast_to_truth(fcst: xr.Dataset, truth: xr.Dataset) -> xr.Dataset: Mapped forecast dataset. """ # TODO: return fcst unchanged when forecast and truth are already aligned - + truth_is_grid = "y" in truth.dims and "x" in truth.dims if "y" in fcst.dims and "x" in fcst.dims: