Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion ocf_data_sampler/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,32 @@ class SpatialWindowMixin(Base):


class NormalisationValues(Base):
    """Normalisation parameters.

    Holds the mean/std used for standard-score normalisation of a channel,
    plus optional bounds the raw data is clipped to *before* normalisation.
    """

    mean: float = Field(..., description="Mean value for normalization")
    std: float = Field(..., gt=0, description="Standard deviation (must be positive)")
    clip_min: float | None = Field(
        None,
        description="Minimum value to clip to before normalisation. If None, no clipping is "
        "applied",
    )
    clip_max: float | None = Field(
        None,
        description="Maximum value to clip to before normalisation. If None, no clipping is "
        "applied",
    )

    @model_validator(mode="after")
    def validate_clip_range(self) -> "NormalisationValues":
        """Validate that if both clip_min and clip_max are provided, then clip_min < clip_max."""
        # Only checked when both bounds are set; a single bound is always valid
        if (
            self.clip_min is not None
            and self.clip_max is not None
            and self.clip_min >= self.clip_max
        ):
            raise ValueError(
                f"clip_min ({self.clip_min}) must be less than clip_max ({self.clip_max})",
            )
        return self


class NormalisationConstantsMixin(Base):
Expand Down
36 changes: 23 additions & 13 deletions ocf_data_sampler/load/satellite.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Satellite loader."""
import json

import numpy as np
import xarray as xr

Expand All @@ -23,24 +25,32 @@ def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray:
else:
ds = open_zarr(zarr_path)

check_time_unique_increasing(ds.time)
rename_dict = {
"variable": "channel",
"time": "time_utc",
}

ds = ds.rename(
{
"variable": "channel",
"time": "time_utc",
},
)
for old_name, new_name in list(rename_dict.items()):
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This bit is to allow for the new satellite data where the channel dimension is already called channel

if old_name not in ds:
if new_name in ds:
del rename_dict[old_name]
else:
raise KeyError(f"Expected either '{old_name}' or '{new_name}' to be in dataset")
ds = ds.rename(rename_dict)

check_time_unique_increasing(ds.time_utc)
ds = make_spatial_coords_increasing(ds, x_coord="x_geostationary", y_coord="y_geostationary")
ds = ds.transpose("time_utc", "channel", "x_geostationary", "y_geostationary")

data_array = get_xr_data_array_from_xr_dataset(ds)
da = get_xr_data_array_from_xr_dataset(ds)
Copy link
Copy Markdown
Member Author

@dfulu dfulu Mar 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I renamed data_array --> da to make the style here more consistent


# Copy the area attribute if missing
if "area" not in da.attrs:
da.attrs["area"] = json.dumps(ds.attrs["area"])

# Validate data types directly loading function
if not np.issubdtype(data_array.dtype, np.number):
raise TypeError(f"Satellite data should be numeric, not {data_array.dtype}")
if not np.issubdtype(da.dtype, np.number):
raise TypeError(f"Satellite data should be numeric, not {da.dtype}")

coord_dtypes = {
"time_utc": np.datetime64,
Expand All @@ -50,8 +60,8 @@ def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray:
}

for coord, expected_dtype in coord_dtypes.items():
if not np.issubdtype(data_array.coords[coord].dtype, expected_dtype):
dtype = data_array.coords[coord].dtype
if not np.issubdtype(da.coords[coord].dtype, expected_dtype):
dtype = da.coords[coord].dtype
raise TypeError(f"{coord} should be {expected_dtype.__name__}, not {dtype}")

return data_array
return da
32 changes: 23 additions & 9 deletions ocf_data_sampler/torch_datasets/pvnet_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,13 @@ def __init__(
)

# Extract the normalisation values from the config for faster access
means_dict, stds_dict = config_normalization_values_to_dicts(config)
self.means_dict = means_dict
self.stds_dict = stds_dict
mean_dict, std_dict, clip_min_dict, clip_max_dict = (
config_normalization_values_to_dicts(config)
)
self.mean_dict = mean_dict
self.std_dict = std_dict
self.clip_min_dict = clip_min_dict
self.clip_max_dict = clip_max_dict

def process_and_combine_datasets(
self,
Expand All @@ -194,15 +198,25 @@ def process_and_combine_datasets(
# Normalise NWP
if "nwp" in dataset_dict:
for nwp_key, da_nwp in dataset_dict["nwp"].items():
channel_means = self.means_dict["nwp"][nwp_key]
channel_stds = self.stds_dict["nwp"][nwp_key]
dataset_dict["nwp"][nwp_key] = (da_nwp - channel_means) / channel_stds
channel_means = self.mean_dict["nwp"][nwp_key]
channel_stds = self.std_dict["nwp"][nwp_key]
channel_mins = self.clip_min_dict["nwp"][nwp_key]
channel_maxs = self.clip_max_dict["nwp"][nwp_key]
dataset_dict["nwp"][nwp_key].data = (
(da_nwp.data.clip(channel_mins, channel_maxs) - channel_means)
/ channel_stds
)

# Normalise satellite
if "sat" in dataset_dict:
channel_means = self.means_dict["sat"]
channel_stds = self.stds_dict["sat"]
dataset_dict["sat"] = (dataset_dict["sat"] - channel_means) / channel_stds
channel_means = self.mean_dict["sat"]
channel_stds = self.std_dict["sat"]
channel_mins = self.clip_min_dict["sat"]
channel_maxs = self.clip_max_dict["sat"]
dataset_dict["sat"].data = (
(dataset_dict["sat"].data.clip(channel_mins, channel_maxs) - channel_means)
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This applies a different clip value per channel

/ channel_stds
)

# Normalise generation by capacity
if "generation" in dataset_dict:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,52 +8,73 @@
def config_normalization_values_to_dicts(
config: Configuration,
) -> tuple[dict[str, np.ndarray | dict[str, np.ndarray]]]:
"""Construct numpy arrays of mean and std values from the config normalisation constants.
"""Construct numpy arrays of mean, std, and clip values from the config normalisation constants.

Args:
config: Data configuration.

Returns:
Means dict
Stds dict
Clip min dict
Clip max dict
"""
means_dict = {}
stds_dict = {}
clip_min_dict = {}
clip_max_dict = {}

if config.input_data.nwp is not None:

means_dict["nwp"] = {}
stds_dict["nwp"] = {}
clip_min_dict["nwp"] = {}
clip_max_dict["nwp"] = {}

for nwp_key in config.input_data.nwp:
nwp_config = config.input_data.nwp[nwp_key]

means_list = []
stds_list = []
clip_min_list = []
clip_max_list = []

for channel in list(nwp_config.channels):
# These accumulated channels are diffed and renamed
if channel in nwp_config.accum_channels:
channel =f"diff_{channel}"

means_list.append(nwp_config.normalisation_constants[channel].mean)
stds_list.append(nwp_config.normalisation_constants[channel].std)
norm_conf = nwp_config.normalisation_constants[channel]

means_list.append(norm_conf.mean)
stds_list.append(norm_conf.std)
clip_min_list.append(-np.inf if norm_conf.clip_min is None else norm_conf.clip_min)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cheers!

clip_max_list.append(np.inf if norm_conf.clip_max is None else norm_conf.clip_max)

means_dict["nwp"][nwp_key] = np.array(means_list)[None, :, None, None]
stds_dict["nwp"][nwp_key] = np.array(stds_list)[None, :, None, None]
clip_min_dict["nwp"][nwp_key] = np.array(clip_min_list)[None, :, None, None]
clip_max_dict["nwp"][nwp_key] = np.array(clip_max_list)[None, :, None, None]

if config.input_data.satellite is not None:
sat_config = config.input_data.satellite

means_list = []
stds_list = []
clip_min_list = []
clip_max_list = []

for channel in list(config.input_data.satellite.channels):
means_list.append(sat_config.normalisation_constants[channel].mean)
stds_list.append(sat_config.normalisation_constants[channel].std)
for channel in list(sat_config.channels):
norm_conf = sat_config.normalisation_constants[channel]
means_list.append(norm_conf.mean)
stds_list.append(norm_conf.std)
clip_min_list.append(-np.inf if norm_conf.clip_min is None else norm_conf.clip_min)
clip_max_list.append(np.inf if norm_conf.clip_max is None else norm_conf.clip_max)

# Convert to array and expand dimensions so we can normalise the 4D sat and NWP sources
means_dict["sat"] = np.array(means_list)[None, :, None, None]
stds_dict["sat"] = np.array(stds_list)[None, :, None, None]
clip_min_dict["sat"] = np.array(clip_min_list)[None, :, None, None]
clip_max_dict["sat"] = np.array(clip_max_list)[None, :, None, None]

return means_dict, stds_dict
return means_dict, stds_dict, clip_min_dict, clip_max_dict
2 changes: 2 additions & 0 deletions tests/load/test_load_satellite.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def test_open_satellite_bad_dtype(tmp_path: Path):
"y_geostationary": np.arange(4),
"x_geostationary": np.arange(4),
},
attrs={"area": "area_info"},
)
bad_ds.to_zarr(zarr_path)

Expand All @@ -60,6 +61,7 @@ def test_open_satellite_bad_dtype_spatial_coords(tmp_path: Path):
"y_geostationary": np.arange(4),
"x_geostationary": np.arange(4),
},
attrs={"area": "area_info"},
)
bad_ds.to_zarr(zarr_path)
with pytest.raises(TypeError, match="geostationary should be floating"):
Expand Down
2 changes: 2 additions & 0 deletions tests/test_data/configs/pvnet_test_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ input_data:
t:
mean: 283.64913206
std: 4.38818501
clip_min: 270
clip_max: 310

satellite:
zarr_path: set_in_temp_file
Expand Down
49 changes: 0 additions & 49 deletions tests/test_data/configs/test_config.yaml

This file was deleted.

Loading