class NormalisationValues(Base):
    """Normalisation parameters for a single channel.

    Holds the mean and standard deviation used to normalise the channel, plus
    optional lower/upper bounds to clip the values to before normalisation.
    """

    # Required: the value subtracted during normalisation
    mean: float = Field(..., description="Mean value for normalization")
    # Required: the divisor during normalisation; `gt=0` rejects zero and negative values
    std: float = Field(..., gt=0, description="Standard deviation (must be positive)")
    # Optional lower clip bound; None leaves the lower side unclipped
    clip_min: float | None = Field(
        None,
        description="Minimum value to clip to before normalisation. If None, no clipping is "
        "applied",
    )
    # Optional upper clip bound; None leaves the upper side unclipped
    clip_max: float | None = Field(
        None,
        description="Maximum value to clip to before normalisation. If None, no clipping is "
        "applied",
    )

    @model_validator(mode="after")
    def validate_clip_range(self) -> "NormalisationValues":
        """Validate that if both clip_min and clip_max are provided, then clip_min < clip_max.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: If both bounds are set and clip_min is not strictly less than clip_max.
        """
        # The range is only checked when both bounds are given; a single bound is always valid
        if (
            self.clip_min is not None
            and self.clip_max is not None
            and self.clip_min >= self.clip_max
        ):
            raise ValueError(
                f"clip_min ({self.clip_min}) must be less than clip_max ({self.clip_max})",
            )
        return self
KeyError(f"Expected either '{old_name}' or '{new_name}' to be in dataset") + ds = ds.rename(rename_dict) check_time_unique_increasing(ds.time_utc) ds = make_spatial_coords_increasing(ds, x_coord="x_geostationary", y_coord="y_geostationary") ds = ds.transpose("time_utc", "channel", "x_geostationary", "y_geostationary") - data_array = get_xr_data_array_from_xr_dataset(ds) + da = get_xr_data_array_from_xr_dataset(ds) + + # Copy the area attribute if missing + if "area" not in da.attrs: + da.attrs["area"] = json.dumps(ds.attrs["area"]) # Validate data types directly loading function - if not np.issubdtype(data_array.dtype, np.number): - raise TypeError(f"Satellite data should be numeric, not {data_array.dtype}") + if not np.issubdtype(da.dtype, np.number): + raise TypeError(f"Satellite data should be numeric, not {da.dtype}") coord_dtypes = { "time_utc": np.datetime64, @@ -50,8 +60,8 @@ def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray: } for coord, expected_dtype in coord_dtypes.items(): - if not np.issubdtype(data_array.coords[coord].dtype, expected_dtype): - dtype = data_array.coords[coord].dtype + if not np.issubdtype(da.coords[coord].dtype, expected_dtype): + dtype = da.coords[coord].dtype raise TypeError(f"{coord} should be {expected_dtype.__name__}, not {dtype}") - return data_array + return da diff --git a/ocf_data_sampler/torch_datasets/pvnet_dataset.py b/ocf_data_sampler/torch_datasets/pvnet_dataset.py index c7adff72..6980f952 100644 --- a/ocf_data_sampler/torch_datasets/pvnet_dataset.py +++ b/ocf_data_sampler/torch_datasets/pvnet_dataset.py @@ -174,9 +174,13 @@ def __init__( ) # Extract the normalisation values from the config for faster access - means_dict, stds_dict = config_normalization_values_to_dicts(config) - self.means_dict = means_dict - self.stds_dict = stds_dict + mean_dict, std_dict, clip_min_dict, clip_max_dict = ( + config_normalization_values_to_dicts(config) + ) + self.mean_dict = mean_dict + self.std_dict = std_dict + 
self.clip_min_dict = clip_min_dict + self.clip_max_dict = clip_max_dict def process_and_combine_datasets( self, @@ -194,15 +198,25 @@ def process_and_combine_datasets( # Normalise NWP if "nwp" in dataset_dict: for nwp_key, da_nwp in dataset_dict["nwp"].items(): - channel_means = self.means_dict["nwp"][nwp_key] - channel_stds = self.stds_dict["nwp"][nwp_key] - dataset_dict["nwp"][nwp_key] = (da_nwp - channel_means) / channel_stds + channel_means = self.mean_dict["nwp"][nwp_key] + channel_stds = self.std_dict["nwp"][nwp_key] + channel_mins = self.clip_min_dict["nwp"][nwp_key] + channel_maxs = self.clip_max_dict["nwp"][nwp_key] + dataset_dict["nwp"][nwp_key].data = ( + (da_nwp.data.clip(channel_mins, channel_maxs) - channel_means) + / channel_stds + ) # Normalise satellite if "sat" in dataset_dict: - channel_means = self.means_dict["sat"] - channel_stds = self.stds_dict["sat"] - dataset_dict["sat"] = (dataset_dict["sat"] - channel_means) / channel_stds + channel_means = self.mean_dict["sat"] + channel_stds = self.std_dict["sat"] + channel_mins = self.clip_min_dict["sat"] + channel_maxs = self.clip_max_dict["sat"] + dataset_dict["sat"].data = ( + (dataset_dict["sat"].data.clip(channel_mins, channel_maxs) - channel_means) + / channel_stds + ) # Normalise generation by capacity if "generation" in dataset_dict: diff --git a/ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py b/ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py index 9f8b1123..25def6e1 100644 --- a/ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py +++ b/ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py @@ -8,7 +8,7 @@ def config_normalization_values_to_dicts( config: Configuration, ) -> tuple[dict[str, np.ndarray | dict[str, np.ndarray]]]: - """Construct numpy arrays of mean and std values from the config normalisation constants. 
+ """Construct numpy arrays of mean, std, and clip values from the config normalisation constants. Args: config: Data configuration. @@ -16,44 +16,65 @@ def config_normalization_values_to_dicts( Returns: Means dict Stds dict + Clip min dict + Clip max dict """ means_dict = {} stds_dict = {} + clip_min_dict = {} + clip_max_dict = {} if config.input_data.nwp is not None: means_dict["nwp"] = {} stds_dict["nwp"] = {} + clip_min_dict["nwp"] = {} + clip_max_dict["nwp"] = {} for nwp_key in config.input_data.nwp: nwp_config = config.input_data.nwp[nwp_key] means_list = [] stds_list = [] + clip_min_list = [] + clip_max_list = [] for channel in list(nwp_config.channels): # These accumulated channels are diffed and renamed if channel in nwp_config.accum_channels: channel =f"diff_{channel}" - means_list.append(nwp_config.normalisation_constants[channel].mean) - stds_list.append(nwp_config.normalisation_constants[channel].std) + norm_conf = nwp_config.normalisation_constants[channel] + + means_list.append(norm_conf.mean) + stds_list.append(norm_conf.std) + clip_min_list.append(-np.inf if norm_conf.clip_min is None else norm_conf.clip_min) + clip_max_list.append(np.inf if norm_conf.clip_max is None else norm_conf.clip_max) means_dict["nwp"][nwp_key] = np.array(means_list)[None, :, None, None] stds_dict["nwp"][nwp_key] = np.array(stds_list)[None, :, None, None] + clip_min_dict["nwp"][nwp_key] = np.array(clip_min_list)[None, :, None, None] + clip_max_dict["nwp"][nwp_key] = np.array(clip_max_list)[None, :, None, None] if config.input_data.satellite is not None: sat_config = config.input_data.satellite means_list = [] stds_list = [] + clip_min_list = [] + clip_max_list = [] - for channel in list(config.input_data.satellite.channels): - means_list.append(sat_config.normalisation_constants[channel].mean) - stds_list.append(sat_config.normalisation_constants[channel].std) + for channel in list(sat_config.channels): + norm_conf = sat_config.normalisation_constants[channel] + 
means_list.append(norm_conf.mean) + stds_list.append(norm_conf.std) + clip_min_list.append(-np.inf if norm_conf.clip_min is None else norm_conf.clip_min) + clip_max_list.append(np.inf if norm_conf.clip_max is None else norm_conf.clip_max) # Convert to array and expand dimensions so we can normalise the 4D sat and NWP sources means_dict["sat"] = np.array(means_list)[None, :, None, None] stds_dict["sat"] = np.array(stds_list)[None, :, None, None] + clip_min_dict["sat"] = np.array(clip_min_list)[None, :, None, None] + clip_max_dict["sat"] = np.array(clip_max_list)[None, :, None, None] - return means_dict, stds_dict + return means_dict, stds_dict, clip_min_dict, clip_max_dict diff --git a/tests/load/test_load_satellite.py b/tests/load/test_load_satellite.py index ae86c9cf..d243e973 100755 --- a/tests/load/test_load_satellite.py +++ b/tests/load/test_load_satellite.py @@ -39,6 +39,7 @@ def test_open_satellite_bad_dtype(tmp_path: Path): "y_geostationary": np.arange(4), "x_geostationary": np.arange(4), }, + attrs={"area": "area_info"}, ) bad_ds.to_zarr(zarr_path) @@ -60,6 +61,7 @@ def test_open_satellite_bad_dtype_spatial_coords(tmp_path: Path): "y_geostationary": np.arange(4), "x_geostationary": np.arange(4), }, + attrs={"area": "area_info"}, ) bad_ds.to_zarr(zarr_path) with pytest.raises(TypeError, match="geostationary should be floating"): diff --git a/tests/test_data/configs/pvnet_test_config.yaml b/tests/test_data/configs/pvnet_test_config.yaml index 55d7ccd8..38664772 100644 --- a/tests/test_data/configs/pvnet_test_config.yaml +++ b/tests/test_data/configs/pvnet_test_config.yaml @@ -30,6 +30,8 @@ input_data: t: mean: 283.64913206 std: 4.38818501 + clip_min: 270 + clip_max: 310 satellite: zarr_path: set_in_temp_file diff --git a/tests/test_data/configs/test_config.yaml b/tests/test_data/configs/test_config.yaml deleted file mode 100644 index 38203a80..00000000 --- a/tests/test_data/configs/test_config.yaml +++ /dev/null @@ -1,49 +0,0 @@ -general: - description: test 
example configuration - name: example - -input_data: - generation: - zarr_path: tests/data/gsp/test.zarr - interval_start_minutes: -60 - interval_end_minutes: 120 - time_resolution_minutes: 30 - dropout_timedeltas_minutes: [-30] - dropout_fraction: 0.1 - nwp: - ukv: - zarr_path: tests/data/nwp_data/test.zarr - provider: "ukv" - interval_start_minutes: -60 - interval_end_minutes: 120 - time_resolution_minutes: 60 - channels: - - t - accum_channels: - - t - image_size_pixels_height: 2 - image_size_pixels_width: 2 - dropout_timedeltas_minutes: [-180] - dropout_fraction: 1.0 - max_staleness_minutes: null - normalisation_constants: - t: - mean: 283.64913206 - std: 4.38818501 - diff_t: - mean: 0.0 - std: 1.0 - - satellite: - zarr_path: tests/data/sat_data.zarr - time_resolution_minutes: 15 - interval_start_minutes: -60 - interval_end_minutes: 0 - channels: - - IR_016 - image_size_pixels_height: 24 - image_size_pixels_width: 24 - normalisation_constants: - IR_016: - mean: 0.17594202 - std: 0.21462157