From 181f9435941760604c0be9959ce4a2bc212ea254 Mon Sep 17 00:00:00 2001 From: James Fulton Date: Mon, 23 Mar 2026 17:26:08 +0000 Subject: [PATCH 1/5] Add optional input clipping --- ocf_data_sampler/config/model.py | 18 ++++++- .../torch_datasets/pvnet_dataset.py | 32 ++++++++---- .../config_normalization_values_to_dicts.py | 35 ++++++++++--- .../test_data/configs/pvnet_test_config.yaml | 2 + tests/test_data/configs/test_config.yaml | 49 ------------------- 5 files changed, 70 insertions(+), 66 deletions(-) delete mode 100644 tests/test_data/configs/test_config.yaml diff --git a/ocf_data_sampler/config/model.py b/ocf_data_sampler/config/model.py index d9175481..bff3211d 100644 --- a/ocf_data_sampler/config/model.py +++ b/ocf_data_sampler/config/model.py @@ -157,10 +157,26 @@ class SpatialWindowMixin(Base): class NormalisationValues(Base): - """Normalisation mean and standard deviation.""" + """Normalisation parameters.""" mean: float = Field(..., description="Mean value for normalization") std: float = Field(..., gt=0, description="Standard deviation (must be positive)") + clip_min: float = Field( + float('-inf'), + description="Minimum value to clip to before normalisation. If None, no clipping is applied" + ) + clip_max: float = Field( + float('inf'), + description="Maximum value to clip to before normalisation. 
If None, no clipping is applied" + ) + @model_validator(mode="after") + def validate_clip_range(self) -> "NormalisationValues": + """Validate that clip_min is less than clip_max.""" + if self.clip_min >= self.clip_max: + raise ValueError( + f"clip_min ({self.clip_min}) must be less than clip_max ({self.clip_max})" + ) + return self class NormalisationConstantsMixin(Base): """Normalisation constants for multiple channels.""" diff --git a/ocf_data_sampler/torch_datasets/pvnet_dataset.py b/ocf_data_sampler/torch_datasets/pvnet_dataset.py index c7adff72..d7119149 100644 --- a/ocf_data_sampler/torch_datasets/pvnet_dataset.py +++ b/ocf_data_sampler/torch_datasets/pvnet_dataset.py @@ -174,9 +174,13 @@ def __init__( ) # Extract the normalisation values from the config for faster access - means_dict, stds_dict = config_normalization_values_to_dicts(config) - self.means_dict = means_dict - self.stds_dict = stds_dict + mean_dict, std_dict, clip_min_dict, clip_max_dict = ( + config_normalization_values_to_dicts(config) + ) + self.mean_dict = mean_dict + self.std_dict = std_dict + self.clip_min_dict = clip_min_dict + self.clip_max_dict = clip_max_dict def process_and_combine_datasets( self, @@ -194,15 +198,25 @@ def process_and_combine_datasets( # Normalise NWP if "nwp" in dataset_dict: for nwp_key, da_nwp in dataset_dict["nwp"].items(): - channel_means = self.means_dict["nwp"][nwp_key] - channel_stds = self.stds_dict["nwp"][nwp_key] - dataset_dict["nwp"][nwp_key] = (da_nwp - channel_means) / channel_stds + channel_means = self.mean_dict["nwp"][nwp_key] + channel_stds = self.std_dict["nwp"][nwp_key] + channel_mins = self.clip_min_dict["nwp"][nwp_key] + channel_maxs = self.clip_max_dict["nwp"][nwp_key] + dataset_dict["nwp"][nwp_key].data = ( + (da_nwp.data.clip(channel_mins, channel_maxs) - channel_means) + / channel_stds + ) # Normalise satellite if "sat" in dataset_dict: - channel_means = self.means_dict["sat"] - channel_stds = self.stds_dict["sat"] - dataset_dict["sat"] = 
(dataset_dict["sat"] - channel_means) / channel_stds + channel_means = self.mean_dict["sat"] + channel_stds = self.std_dict["sat"] + channel_mins = self.clip_min_dict["sat"] + channel_maxs = self.clip_max_dict["sat"] + dataset_dict["sat"].data = ( + (dataset_dict["sat"].data.clip(channel_mins, channel_maxs) - channel_means) + / channel_stds + ) # Normalise generation by capacity if "generation" in dataset_dict: diff --git a/ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py b/ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py index 9f8b1123..0e65a695 100644 --- a/ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py +++ b/ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py @@ -8,7 +8,7 @@ def config_normalization_values_to_dicts( config: Configuration, ) -> tuple[dict[str, np.ndarray | dict[str, np.ndarray]]]: - """Construct numpy arrays of mean and std values from the config normalisation constants. + """Construct numpy arrays of mean, std, and clip values from the config normalisation constants. Args: config: Data configuration. 
@@ -16,44 +16,65 @@ def config_normalization_values_to_dicts( Returns: Means dict Stds dict + Clip min dict + Clip max dict """ means_dict = {} stds_dict = {} + clip_min_dict = {} + clip_max_dict = {} if config.input_data.nwp is not None: means_dict["nwp"] = {} stds_dict["nwp"] = {} + clip_min_dict["nwp"] = {} + clip_max_dict["nwp"] = {} for nwp_key in config.input_data.nwp: nwp_config = config.input_data.nwp[nwp_key] means_list = [] stds_list = [] + clip_min_list = [] + clip_max_list = [] for channel in list(nwp_config.channels): # These accumulated channels are diffed and renamed if channel in nwp_config.accum_channels: channel =f"diff_{channel}" - means_list.append(nwp_config.normalisation_constants[channel].mean) - stds_list.append(nwp_config.normalisation_constants[channel].std) + norm_conf = nwp_config.normalisation_constants[channel] + + means_list.append(norm_conf.mean) + stds_list.append(norm_conf.std) + clip_min_list.append(norm_conf.clip_min) + clip_max_list.append(norm_conf.clip_max) means_dict["nwp"][nwp_key] = np.array(means_list)[None, :, None, None] stds_dict["nwp"][nwp_key] = np.array(stds_list)[None, :, None, None] + clip_min_dict["nwp"][nwp_key] = np.array(clip_min_list)[None, :, None, None] + clip_max_dict["nwp"][nwp_key] = np.array(clip_max_list)[None, :, None, None] if config.input_data.satellite is not None: sat_config = config.input_data.satellite means_list = [] stds_list = [] + clip_min_list = [] + clip_max_list = [] - for channel in list(config.input_data.satellite.channels): - means_list.append(sat_config.normalisation_constants[channel].mean) - stds_list.append(sat_config.normalisation_constants[channel].std) + for channel in list(sat_config.channels): + norm_conf = sat_config.normalisation_constants[channel] + means_list.append(norm_conf.mean) + stds_list.append(norm_conf.std) + clip_min_list.append(norm_conf.clip_min) + clip_max_list.append(norm_conf.clip_max) # Convert to array and expand dimensions so we can normalise the 4D sat and 
NWP sources means_dict["sat"] = np.array(means_list)[None, :, None, None] stds_dict["sat"] = np.array(stds_list)[None, :, None, None] + clip_min_dict["sat"] = np.array(clip_min_list)[None, :, None, None] + clip_max_dict["sat"] = np.array(clip_max_list)[None, :, None, None] - return means_dict, stds_dict + return means_dict, stds_dict, clip_min_dict, clip_max_dict diff --git a/tests/test_data/configs/pvnet_test_config.yaml b/tests/test_data/configs/pvnet_test_config.yaml index 55d7ccd8..38664772 100644 --- a/tests/test_data/configs/pvnet_test_config.yaml +++ b/tests/test_data/configs/pvnet_test_config.yaml @@ -30,6 +30,8 @@ input_data: t: mean: 283.64913206 std: 4.38818501 + clip_min: 270 + clip_max: 310 satellite: zarr_path: set_in_temp_file diff --git a/tests/test_data/configs/test_config.yaml b/tests/test_data/configs/test_config.yaml deleted file mode 100644 index 38203a80..00000000 --- a/tests/test_data/configs/test_config.yaml +++ /dev/null @@ -1,49 +0,0 @@ -general: - description: test example configuration - name: example - -input_data: - generation: - zarr_path: tests/data/gsp/test.zarr - interval_start_minutes: -60 - interval_end_minutes: 120 - time_resolution_minutes: 30 - dropout_timedeltas_minutes: [-30] - dropout_fraction: 0.1 - nwp: - ukv: - zarr_path: tests/data/nwp_data/test.zarr - provider: "ukv" - interval_start_minutes: -60 - interval_end_minutes: 120 - time_resolution_minutes: 60 - channels: - - t - accum_channels: - - t - image_size_pixels_height: 2 - image_size_pixels_width: 2 - dropout_timedeltas_minutes: [-180] - dropout_fraction: 1.0 - max_staleness_minutes: null - normalisation_constants: - t: - mean: 283.64913206 - std: 4.38818501 - diff_t: - mean: 0.0 - std: 1.0 - - satellite: - zarr_path: tests/data/sat_data.zarr - time_resolution_minutes: 15 - interval_start_minutes: -60 - interval_end_minutes: 0 - channels: - - IR_016 - image_size_pixels_height: 24 - image_size_pixels_width: 24 - normalisation_constants: - IR_016: - mean: 0.17594202 - 
std: 0.21462157 From dbd78eda19d9bff3291fe1fbee2a5093d313ef08 Mon Sep 17 00:00:00 2001 From: James Fulton Date: Mon, 23 Mar 2026 18:01:34 +0000 Subject: [PATCH 2/5] update for new satellite format --- ocf_data_sampler/load/satellite.py | 35 +++++++++++++++++++----------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/ocf_data_sampler/load/satellite.py b/ocf_data_sampler/load/satellite.py index 30371823..731b8174 100755 --- a/ocf_data_sampler/load/satellite.py +++ b/ocf_data_sampler/load/satellite.py @@ -1,6 +1,7 @@ """Satellite loader.""" import numpy as np import xarray as xr +import json from ocf_data_sampler.load.open_xarray_tensorstore import open_zarr, open_zarrs from ocf_data_sampler.load.utils import ( @@ -23,24 +24,32 @@ def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray: else: ds = open_zarr(zarr_path) - check_time_unique_increasing(ds.time) + rename_dict = { + "variable": "channel", + "time": "time_utc", + } - ds = ds.rename( - { - "variable": "channel", - "time": "time_utc", - }, - ) + for old_name, new_name in list(rename_dict.items()): + if old_name not in ds: + if new_name in ds: + del rename_dict[old_name] + else: + raise KeyError(f"Expected either '{old_name}' or '{new_name}' to be in dataset") + ds = ds.rename(rename_dict) check_time_unique_increasing(ds.time_utc) ds = make_spatial_coords_increasing(ds, x_coord="x_geostationary", y_coord="y_geostationary") ds = ds.transpose("time_utc", "channel", "x_geostationary", "y_geostationary") - data_array = get_xr_data_array_from_xr_dataset(ds) + da = get_xr_data_array_from_xr_dataset(ds) + + # Copy the area attribute if missing + if "area" not in da.attrs: + da.attrs["area"] = json.dumps(ds.attrs["area"]) # Validate data types directly loading function - if not np.issubdtype(data_array.dtype, np.number): - raise TypeError(f"Satellite data should be numeric, not {data_array.dtype}") + if not np.issubdtype(da.dtype, np.number): + raise TypeError(f"Satellite data should be numeric, 
not {da.dtype}") coord_dtypes = { "time_utc": np.datetime64, @@ -50,8 +59,8 @@ def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray: } for coord, expected_dtype in coord_dtypes.items(): - if not np.issubdtype(data_array.coords[coord].dtype, expected_dtype): - dtype = data_array.coords[coord].dtype + if not np.issubdtype(da.coords[coord].dtype, expected_dtype): + dtype = da.coords[coord].dtype raise TypeError(f"{coord} should be {expected_dtype.__name__}, not {dtype}") - return data_array + return da From 6c295a59043b3dbac57c8c2b79c4d9870642524e Mon Sep 17 00:00:00 2001 From: James Fulton Date: Mon, 23 Mar 2026 18:24:40 +0000 Subject: [PATCH 3/5] Fix tests --- ocf_data_sampler/config/model.py | 16 ++++++++++------ .../config_normalization_values_to_dicts.py | 8 ++++---- tests/load/test_load_satellite.py | 2 ++ 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/ocf_data_sampler/config/model.py b/ocf_data_sampler/config/model.py index bff3211d..8eba1f36 100644 --- a/ocf_data_sampler/config/model.py +++ b/ocf_data_sampler/config/model.py @@ -160,24 +160,28 @@ class NormalisationValues(Base): """Normalisation parameters.""" mean: float = Field(..., description="Mean value for normalization") std: float = Field(..., gt=0, description="Standard deviation (must be positive)") - clip_min: float = Field( - float('-inf'), + clip_min: float | None = Field( + None, description="Minimum value to clip to before normalisation. If None, no clipping is applied" ) - clip_max: float = Field( - float('inf'), + clip_max: float | None = Field( + None, description="Maximum value to clip to before normalisation. 
If None, no clipping is applied" ) @model_validator(mode="after") def validate_clip_range(self) -> "NormalisationValues": - """Validate that clip_min is less than clip_max.""" - if self.clip_min >= self.clip_max: + if ( + self.clip_min is not None + and self.clip_max is not None + and self.clip_min >= self.clip_max + ): raise ValueError( f"clip_min ({self.clip_min}) must be less than clip_max ({self.clip_max})" ) return self + class NormalisationConstantsMixin(Base): """Normalisation constants for multiple channels.""" normalisation_constants: dict[str, NormalisationValues] diff --git a/ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py b/ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py index 0e65a695..eeaad34f 100644 --- a/ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py +++ b/ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py @@ -48,8 +48,8 @@ def config_normalization_values_to_dicts( means_list.append(norm_conf.mean) stds_list.append(norm_conf.std) - clip_min_list.append(norm_conf.clip_min) - clip_max_list.append(norm_conf.clip_max) + clip_min_list.append(norm_conf.clip_min or -np.inf) + clip_max_list.append(norm_conf.clip_max or np.inf) means_dict["nwp"][nwp_key] = np.array(means_list)[None, :, None, None] stds_dict["nwp"][nwp_key] = np.array(stds_list)[None, :, None, None] @@ -68,8 +68,8 @@ def config_normalization_values_to_dicts( norm_conf = sat_config.normalisation_constants[channel] means_list.append(norm_conf.mean) stds_list.append(norm_conf.std) - clip_min_list.append(norm_conf.clip_min) - clip_max_list.append(norm_conf.clip_max) + clip_min_list.append(norm_conf.clip_min or -np.inf) + clip_max_list.append(norm_conf.clip_max or np.inf) # Convert to array and expand dimensions so we can normalise the 4D sat and NWP sources means_dict["sat"] = np.array(means_list)[None, :, None, None] diff --git a/tests/load/test_load_satellite.py 
b/tests/load/test_load_satellite.py index ae86c9cf..d243e973 100755 --- a/tests/load/test_load_satellite.py +++ b/tests/load/test_load_satellite.py @@ -39,6 +39,7 @@ def test_open_satellite_bad_dtype(tmp_path: Path): "y_geostationary": np.arange(4), "x_geostationary": np.arange(4), }, + attrs={"area": "area_info"}, ) bad_ds.to_zarr(zarr_path) @@ -60,6 +61,7 @@ def test_open_satellite_bad_dtype_spatial_coords(tmp_path: Path): "y_geostationary": np.arange(4), "x_geostationary": np.arange(4), }, + attrs={"area": "area_info"}, ) bad_ds.to_zarr(zarr_path) with pytest.raises(TypeError, match="geostationary should be floating"): From 6b021833bde8559a03d0a87bbf1b4838099a2a52 Mon Sep 17 00:00:00 2001 From: James Fulton Date: Mon, 23 Mar 2026 18:33:08 +0000 Subject: [PATCH 4/5] lint --- ocf_data_sampler/config/model.py | 13 ++++++++----- ocf_data_sampler/load/satellite.py | 3 ++- ocf_data_sampler/torch_datasets/pvnet_dataset.py | 2 +- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/ocf_data_sampler/config/model.py b/ocf_data_sampler/config/model.py index 8eba1f36..9f65af05 100644 --- a/ocf_data_sampler/config/model.py +++ b/ocf_data_sampler/config/model.py @@ -162,22 +162,25 @@ class NormalisationValues(Base): std: float = Field(..., gt=0, description="Standard deviation (must be positive)") clip_min: float | None = Field( None, - description="Minimum value to clip to before normalisation. If None, no clipping is applied" + description="Minimum value to clip to before normalisation. If None, no clipping is " + "applied", ) clip_max: float | None = Field( None, - description="Maximum value to clip to before normalisation. If None, no clipping is applied" + description="Maximum value to clip to before normalisation. 
If None, no clipping is " + "applied", ) @model_validator(mode="after") def validate_clip_range(self) -> "NormalisationValues": + """Validate that if both clip_min and clip_max are provided, then clip_min < clip_max.""" if ( - self.clip_min is not None - and self.clip_max is not None + self.clip_min is not None + and self.clip_max is not None and self.clip_min >= self.clip_max ): raise ValueError( - f"clip_min ({self.clip_min}) must be less than clip_max ({self.clip_max})" + f"clip_min ({self.clip_min}) must be less than clip_max ({self.clip_max})", ) return self diff --git a/ocf_data_sampler/load/satellite.py b/ocf_data_sampler/load/satellite.py index 731b8174..eacfe5d8 100755 --- a/ocf_data_sampler/load/satellite.py +++ b/ocf_data_sampler/load/satellite.py @@ -1,7 +1,8 @@ """Satellite loader.""" +import json + import numpy as np import xarray as xr -import json from ocf_data_sampler.load.open_xarray_tensorstore import open_zarr, open_zarrs from ocf_data_sampler.load.utils import ( diff --git a/ocf_data_sampler/torch_datasets/pvnet_dataset.py b/ocf_data_sampler/torch_datasets/pvnet_dataset.py index d7119149..6980f952 100644 --- a/ocf_data_sampler/torch_datasets/pvnet_dataset.py +++ b/ocf_data_sampler/torch_datasets/pvnet_dataset.py @@ -214,7 +214,7 @@ def process_and_combine_datasets( channel_mins = self.clip_min_dict["sat"] channel_maxs = self.clip_max_dict["sat"] dataset_dict["sat"].data = ( - (dataset_dict["sat"].data.clip(channel_mins, channel_maxs) - channel_means) + (dataset_dict["sat"].data.clip(channel_mins, channel_maxs) - channel_means) / channel_stds ) From 67bcd918838cc956fd2cf3b02eca640a6ff21d84 Mon Sep 17 00:00:00 2001 From: James Fulton Date: Thu, 26 Mar 2026 09:56:26 +0000 Subject: [PATCH 5/5] fix edge case --- .../utils/config_normalization_values_to_dicts.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py 
b/ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py index eeaad34f..25def6e1 100644 --- a/ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py +++ b/ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py @@ -48,8 +48,8 @@ def config_normalization_values_to_dicts( means_list.append(norm_conf.mean) stds_list.append(norm_conf.std) - clip_min_list.append(norm_conf.clip_min or -np.inf) - clip_max_list.append(norm_conf.clip_max or np.inf) + clip_min_list.append(-np.inf if norm_conf.clip_min is None else norm_conf.clip_min) + clip_max_list.append(np.inf if norm_conf.clip_max is None else norm_conf.clip_max) means_dict["nwp"][nwp_key] = np.array(means_list)[None, :, None, None] stds_dict["nwp"][nwp_key] = np.array(stds_list)[None, :, None, None] @@ -68,8 +68,8 @@ def config_normalization_values_to_dicts( norm_conf = sat_config.normalisation_constants[channel] means_list.append(norm_conf.mean) stds_list.append(norm_conf.std) - clip_min_list.append(norm_conf.clip_min or -np.inf) - clip_max_list.append(norm_conf.clip_max or np.inf) + clip_min_list.append(-np.inf if norm_conf.clip_min is None else norm_conf.clip_min) + clip_max_list.append(np.inf if norm_conf.clip_max is None else norm_conf.clip_max) # Convert to array and expand dimensions so we can normalise the 4D sat and NWP sources means_dict["sat"] = np.array(means_list)[None, :, None, None]