11import logging
22import os
33import sys
4- from datetime import datetime
4+ from datetime import datetime , timedelta
55from pathlib import Path
6- from typing import Iterable
76
87eccodes_definition_path = Path (sys .prefix ) / "share/eccodes-cosmo-resources/definitions"
98os .environ ["ECCODES_DEFINITION_PATH" ] = str (eccodes_definition_path )
1615LOG = logging .getLogger (__name__ )
1716
1817
18+ def _select_valid_times (ds , times : np .datetime64 ):
19+ # (handle special case where some valid times are not in the dataset, e.g. at the end)
20+ times_np = np .asarray (times , dtype = "datetime64[ns]" )
21+ times_included = np .isin (times_np , ds .time .values )
22+ if times_included .all ():
23+ return ds .sel (time = times_np )
24+ elif times_included .any ():
25+ LOG .warning (
26+ "Some valid times are not included in the dataset: \n %s" ,
27+ times_np [~ times_included ],
28+ )
29+ return ds .sel (time = times_np [times_included ])
30+ else :
31+ raise ValueError (
32+ "Valid times are not included in the dataset. "
33+ "Please check the valid times and the dataset."
34+ )
35+
36+
def parse_steps(steps: str) -> list[int]:
    """Parse a "start/stop/step" string into an inclusive list of lead times.

    E.g. "0/12/6" -> [0, 6, 12]. The stop value is included when reached.

    Raises:
        ValueError: if `steps` does not have exactly three "/"-separated
            fields, or if a field is not an integer.
    """
    parts = steps.split("/")
    # A single length check covers the no-slash case too (len(parts) == 1),
    # so no separate "/" membership test is needed.
    if len(parts) != 3:
        raise ValueError(f"Expected steps in format 'start/stop/step', got '{steps}'")
    start, end, step = map(int, parts)
    return list(range(start, end + 1, step))
45+
46+
1947def load_analysis_data_from_zarr (
20- analysis_zarr : Path , times : Iterable [ datetime ], params : list [str ]
48+ root : Path , reftime : datetime , steps : list [ int ], params : list [str ]
2149) -> xr .Dataset :
2250 """Load analysis data from an anemoi-generated Zarr dataset
2351
@@ -36,9 +64,9 @@ def load_analysis_data_from_zarr(
3664 PARAMS_MAP_COSMO1 = {
3765 v : v .replace ("TOT_PREC" , "TOT_PREC_6H" ) for v in PARAMS_MAP_COSMO2 .keys ()
3866 }
39- PARAMS_MAP = PARAMS_MAP_COSMO2 if "co2" in analysis_zarr .name else PARAMS_MAP_COSMO1
67+ PARAMS_MAP = PARAMS_MAP_COSMO2 if "co2" in root .name else PARAMS_MAP_COSMO1
4068
41- ds = xr .open_zarr (analysis_zarr , consolidated = False )
69+ ds = xr .open_zarr (root , consolidated = False )
4270
4371 # rename "dates" to "time" and set it as index
4472 ds = ds .set_index (time = "dates" )
@@ -59,8 +87,8 @@ def load_analysis_data_from_zarr(
5987
6088 # set lat lon as coords (optional)
6189 if "latitudes" in ds and "longitudes" in ds :
62- ds = ds .rename ({"latitudes" : "latitude " , "longitudes" : "longitude " })
63- ds = ds .set_coords (["latitude " , "longitude " ])
90+ ds = ds .rename ({"latitudes" : "lat " , "longitudes" : "lon " })
91+ ds = ds .set_coords (["lat " , "lon " ])
6492 ds = (
6593 ds ["data" ]
6694 .to_dataset ("variable" )
@@ -71,30 +99,15 @@ def load_analysis_data_from_zarr(
7199 if "cell" in ds .dims :
72100 ds = ds .rename ({"cell" : "values" })
73101
74- # select valid times
75- # (handle special case where some valid times are not in the dataset, e.g. at the end)
76- times_included = times .isin (ds .time .values ).values
77- if all (times_included ):
78- ds = ds .sel (time = times )
79- elif np .sum (times_included ) < len (times_included ):
80- LOG .warning (
81- "Some valid times are not included in the dataset: \n %s" ,
82- times [~ times_included ].values ,
83- )
84- ds = ds .sel (time = times [times_included ])
85- else :
86- raise ValueError (
87- "Valid times are not included in the dataset. "
88- "Please check the valid times and the dataset."
89- )
90- return ds
102+ times = np .datetime64 (reftime ) + np .asarray (steps , dtype = "timedelta64[h]" )
103+ return _select_valid_times (ds , times )
91104
92105
93106def load_fct_data_from_grib (
94- grib_output_dir : Path , reftime : datetime , steps : list [int ], params : list [str ]
107+ root : Path , reftime : datetime , steps : list [int ], params : list [str ]
95108) -> xr .Dataset :
96109 """Load forecast data from GRIB files for a specific valid time."""
97- files = sorted (grib_output_dir .glob ("20 *.grib" ))
110+ files = sorted (root .glob (f" { reftime :%Y%m%d%H%M } *.grib" ))
98111 fds = data_source .FileDataSource (datafiles = files )
99112 ds = grib_decoder .load (fds , {"param" : params , "step" : steps })
100113 for var , da in ds .items ():
@@ -127,13 +140,13 @@ def load_fct_data_from_grib(
127140
128141
129142def load_baseline_from_zarr (
130- zarr_path : Path , reftime : datetime , steps : list [int ], params : list [str ]
143+ root : Path , reftime : datetime , steps : list [int ], params : list [str ]
131144) -> xr .Dataset :
132145 """Load forecast data from a Zarr dataset."""
133146 try :
134- baseline = xr .open_zarr (zarr_path , consolidated = True , decode_timedelta = True )
147+ baseline = xr .open_zarr (root , consolidated = True , decode_timedelta = True )
135148 except ValueError :
136- raise ValueError (f"Could not open baseline zarr at { zarr_path } " )
149+ raise ValueError (f"Could not open baseline zarr at { root } " )
137150
138151 baseline = baseline .rename (
139152 {"forecast_reference_time" : "ref_time" , "step" : "lead_time" }
@@ -156,14 +169,116 @@ def load_baseline_from_zarr(
156169 lead_time = np .array (steps , dtype = "timedelta64[h]" ),
157170 )
158171 baseline = baseline .assign_coords (time = baseline .ref_time + baseline .lead_time )
172+ if "latitude" in baseline .coords and "longitude" in baseline :
173+ baseline = baseline .rename ({"latitude" : "lat" , "longitude" : "lon" })
159174 return baseline
160175
161176
162- def parse_steps (steps : str ) -> list [int ]:
163- # check that steps is in the format "start/stop/step"
164- if "/" not in steps :
165- raise ValueError (f"Expected steps in format 'start/stop/step', got '{ steps } '" )
166- if len (steps .split ("/" )) != 3 :
167- raise ValueError (f"Expected steps in format 'start/stop/step', got '{ steps } '" )
168- start , end , step = map (int , steps .split ("/" ))
169- return list (range (start , end + 1 , step ))
def load_obs_data_from_peakweather(
    root, reftime: datetime, steps: list[int], params: list[str], freq: str = "1h"
) -> xr.Dataset:
    """Load PeakWeather station observations into an xarray Dataset.

    Returns a Dataset with dimensions `time` and `values`, values coordinates
    (`lat`, `lon`), and variables renamed to ICON parameter names.
    Temperatures are converted to Kelvin when present.
    """
    # Imported lazily so the module works without peakweather installed
    # when only the other loaders are used.
    from peakweather.dataset import PeakWeatherDataset

    # Map PeakWeather parameter names -> ICON short names; keep only the
    # parameters the caller actually requested.
    param_names = {
        "temperature": "T_2M",
        "wind_u": "U_10M",
        "wind_v": "V_10M",
    }
    param_names = {k: v for k, v in param_names.items() if v in params}

    # Observation window: reftime .. reftime + max lead time, padded by one
    # extra step interval so the final valid time is certainly covered.
    start = reftime
    end = start + timedelta(hours=max(steps))
    if len(steps) > 1:
        end += timedelta(hours=steps[-1] - steps[-2])  # extend by 1 extra step
    years = list(set([start.year, end.year]))
    pw = PeakWeatherDataset(root=root, years=years, freq=freq)
    ds, mask = pw.get_observations(
        parameters=[k for k in param_names.keys()],
        first_date=f"{start:%Y-%m-%d %H:%M}",
        last_date=f"{end:%Y-%m-%d %H:%M}",
        return_mask=True,
    )
    # Pivot the returned frames so each parameter becomes its own variable
    # with (datetime, station) indexing.
    # NOTE(review): presumably get_observations returns pandas objects with a
    # ("nat_abbr", "name") column structure — confirm against peakweather docs.
    ds = (
        ds.stack(["nat_abbr", "name"], future_stack=True)
        .to_xarray()
        .to_dataset(dim="name")
    )
    mask = (
        mask.stack(["nat_abbr", "name"], future_stack=True)
        .to_xarray()
        .to_dataset(dim="name")
    )
    # Blank out observations flagged invalid by the quality mask.
    ds = ds.where(mask)
    ds = ds.rename({"datetime": "time", "nat_abbr": "values"})
    ds = ds.rename(param_names)
    # Timestamps come tz-aware; normalize to naive UTC to match model data.
    ds = ds.assign_coords(time=ds.indexes["time"].tz_convert("UTC").tz_localize(None))
    ds = ds.assign_coords(values=ds.indexes["values"])
    # Attach station coordinates.
    # NOTE(review): assumes pw.stations_table is aligned with the "values"
    # (station) dimension — verify index alignment.
    ds = ds.assign_coords(lon=("values", pw.stations_table["longitude"]))
    ds = ds.assign_coords(lat=("values", pw.stations_table["latitude"]))
    if "T_2M" in ds:
        ds["T_2M"] = ds["T_2M"] + 273.15  # convert to Kelvin
    # Drop stations that reported nothing at all in the window.
    ds = ds.dropna("values", how="all")

    # Keep only the requested valid times (reftime + steps, in hours).
    times = np.datetime64(reftime) + np.asarray(steps, dtype="timedelta64[h]")
    return _select_valid_times(ds, times)
230+
231+
def load_truth_data(
    root, reftime: datetime, steps: list[int], params: list[str]
) -> xr.Dataset:
    """Load truth data from analysis Zarr or PeakWeather observations."""
    if root.suffix == ".zarr":
        LOG.info("Loading ground truth from an analysis zarr dataset...")
        truth = load_analysis_data_from_zarr(
            root=root,
            reftime=reftime,
            steps=steps,
            params=params,
        )
        # Rechunk to single chunks in space: gridded (y/x) or unstructured (values).
        spatial_chunks = (
            {"y": -1, "x": -1}
            if "y" in truth.dims and "x" in truth.dims
            else {"values": -1}
        )
        return truth.compute().chunk(spatial_chunks)

    if "peakweather" in str(root):
        LOG.info("Loading ground truth from PeakWeather observations...")
        return load_obs_data_from_peakweather(
            root=root,
            reftime=reftime,
            steps=steps,
            params=params,
        )

    raise ValueError(f"Unsupported truth root: {root}")
260+
261+
def load_forecast_data(
    root, reftime: datetime, steps: list[int], params: list[str]
) -> xr.Dataset:
    """Load forecast data from GRIB files or a baseline Zarr dataset.

    If `root` contains any `*.grib` files they are treated as the forecast
    output; otherwise `root` is assumed to be a baseline Zarr store.
    """
    has_grib = any(root.glob("*.grib"))

    if has_grib:
        LOG.info("Loading forecasts from GRIB files...")
        loader = load_fct_data_from_grib
    else:
        LOG.info("Loading baseline forecasts from zarr dataset...")
        loader = load_baseline_from_zarr

    return loader(root=root, reftime=reftime, steps=steps, params=params)
0 commit comments