-
-
Notifications
You must be signed in to change notification settings - Fork 47
Add optional clip #406
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add optional clip #406
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,6 @@ | ||
| """Satellite loader.""" | ||
| import json | ||
|
|
||
| import numpy as np | ||
| import xarray as xr | ||
|
|
||
|
|
@@ -23,24 +25,32 @@ def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray: | |
| else: | ||
| ds = open_zarr(zarr_path) | ||
|
|
||
| check_time_unique_increasing(ds.time) | ||
| rename_dict = { | ||
| "variable": "channel", | ||
| "time": "time_utc", | ||
| } | ||
|
|
||
| ds = ds.rename( | ||
| { | ||
| "variable": "channel", | ||
| "time": "time_utc", | ||
| }, | ||
| ) | ||
| for old_name, new_name in list(rename_dict.items()): | ||
| if old_name not in ds: | ||
| if new_name in ds: | ||
| del rename_dict[old_name] | ||
| else: | ||
| raise KeyError(f"Expected either '{old_name}' or '{new_name}' to be in dataset") | ||
| ds = ds.rename(rename_dict) | ||
|
|
||
| check_time_unique_increasing(ds.time_utc) | ||
| ds = make_spatial_coords_increasing(ds, x_coord="x_geostationary", y_coord="y_geostationary") | ||
| ds = ds.transpose("time_utc", "channel", "x_geostationary", "y_geostationary") | ||
|
|
||
| data_array = get_xr_data_array_from_xr_dataset(ds) | ||
| da = get_xr_data_array_from_xr_dataset(ds) | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I renamed `data_array` to `da` |
||
|
|
||
| # Copy the area attribute if missing | ||
| if "area" not in da.attrs: | ||
| da.attrs["area"] = json.dumps(ds.attrs["area"]) | ||
|
|
||
| # Validate data types directly loading function | ||
| if not np.issubdtype(data_array.dtype, np.number): | ||
| raise TypeError(f"Satellite data should be numeric, not {data_array.dtype}") | ||
| if not np.issubdtype(da.dtype, np.number): | ||
| raise TypeError(f"Satellite data should be numeric, not {da.dtype}") | ||
|
|
||
| coord_dtypes = { | ||
| "time_utc": np.datetime64, | ||
|
|
@@ -50,8 +60,8 @@ def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray: | |
| } | ||
|
|
||
| for coord, expected_dtype in coord_dtypes.items(): | ||
| if not np.issubdtype(data_array.coords[coord].dtype, expected_dtype): | ||
| dtype = data_array.coords[coord].dtype | ||
| if not np.issubdtype(da.coords[coord].dtype, expected_dtype): | ||
| dtype = da.coords[coord].dtype | ||
| raise TypeError(f"{coord} should be {expected_dtype.__name__}, not {dtype}") | ||
|
|
||
| return data_array | ||
| return da | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -174,9 +174,13 @@ def __init__( | |
| ) | ||
|
|
||
| # Extract the normalisation values from the config for faster access | ||
| means_dict, stds_dict = config_normalization_values_to_dicts(config) | ||
| self.means_dict = means_dict | ||
| self.stds_dict = stds_dict | ||
| mean_dict, std_dict, clip_min_dict, clip_max_dict = ( | ||
| config_normalization_values_to_dicts(config) | ||
| ) | ||
| self.mean_dict = mean_dict | ||
| self.std_dict = std_dict | ||
| self.clip_min_dict = clip_min_dict | ||
| self.clip_max_dict = clip_max_dict | ||
|
|
||
| def process_and_combine_datasets( | ||
| self, | ||
|
|
@@ -194,15 +198,25 @@ def process_and_combine_datasets( | |
| # Normalise NWP | ||
| if "nwp" in dataset_dict: | ||
| for nwp_key, da_nwp in dataset_dict["nwp"].items(): | ||
| channel_means = self.means_dict["nwp"][nwp_key] | ||
| channel_stds = self.stds_dict["nwp"][nwp_key] | ||
| dataset_dict["nwp"][nwp_key] = (da_nwp - channel_means) / channel_stds | ||
| channel_means = self.mean_dict["nwp"][nwp_key] | ||
| channel_stds = self.std_dict["nwp"][nwp_key] | ||
| channel_mins = self.clip_min_dict["nwp"][nwp_key] | ||
| channel_maxs = self.clip_max_dict["nwp"][nwp_key] | ||
| dataset_dict["nwp"][nwp_key].data = ( | ||
| (da_nwp.data.clip(channel_mins, channel_maxs) - channel_means) | ||
| / channel_stds | ||
| ) | ||
|
|
||
| # Normalise satellite | ||
| if "sat" in dataset_dict: | ||
| channel_means = self.means_dict["sat"] | ||
| channel_stds = self.stds_dict["sat"] | ||
| dataset_dict["sat"] = (dataset_dict["sat"] - channel_means) / channel_stds | ||
| channel_means = self.mean_dict["sat"] | ||
| channel_stds = self.std_dict["sat"] | ||
| channel_mins = self.clip_min_dict["sat"] | ||
| channel_maxs = self.clip_max_dict["sat"] | ||
| dataset_dict["sat"].data = ( | ||
| (dataset_dict["sat"].data.clip(channel_mins, channel_maxs) - channel_means) | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This applies a different clip value per channel |
||
| / channel_stds | ||
| ) | ||
|
|
||
| # Normalise generation by capacity | ||
| if "generation" in dataset_dict: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,52 +8,73 @@ | |
| def config_normalization_values_to_dicts( | ||
| config: Configuration, | ||
| ) -> tuple[dict[str, np.ndarray | dict[str, np.ndarray]]]: | ||
| """Construct numpy arrays of mean and std values from the config normalisation constants. | ||
| """Construct numpy arrays of mean, std, and clip values from the config normalisation constants. | ||
|
|
||
| Args: | ||
| config: Data configuration. | ||
|
|
||
| Returns: | ||
| Means dict | ||
| Stds dict | ||
| Clip min dict | ||
| Clip max dict | ||
| """ | ||
| means_dict = {} | ||
| stds_dict = {} | ||
| clip_min_dict = {} | ||
| clip_max_dict = {} | ||
|
|
||
| if config.input_data.nwp is not None: | ||
|
|
||
| means_dict["nwp"] = {} | ||
| stds_dict["nwp"] = {} | ||
| clip_min_dict["nwp"] = {} | ||
| clip_max_dict["nwp"] = {} | ||
|
|
||
| for nwp_key in config.input_data.nwp: | ||
| nwp_config = config.input_data.nwp[nwp_key] | ||
|
|
||
| means_list = [] | ||
| stds_list = [] | ||
| clip_min_list = [] | ||
| clip_max_list = [] | ||
|
|
||
| for channel in list(nwp_config.channels): | ||
| # These accumulated channels are diffed and renamed | ||
| if channel in nwp_config.accum_channels: | ||
| channel =f"diff_{channel}" | ||
|
|
||
| means_list.append(nwp_config.normalisation_constants[channel].mean) | ||
| stds_list.append(nwp_config.normalisation_constants[channel].std) | ||
| norm_conf = nwp_config.normalisation_constants[channel] | ||
|
|
||
| means_list.append(norm_conf.mean) | ||
| stds_list.append(norm_conf.std) | ||
| clip_min_list.append(-np.inf if norm_conf.clip_min is None else norm_conf.clip_min) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Cheers! |
||
| clip_max_list.append(np.inf if norm_conf.clip_max is None else norm_conf.clip_max) | ||
|
|
||
| means_dict["nwp"][nwp_key] = np.array(means_list)[None, :, None, None] | ||
| stds_dict["nwp"][nwp_key] = np.array(stds_list)[None, :, None, None] | ||
| clip_min_dict["nwp"][nwp_key] = np.array(clip_min_list)[None, :, None, None] | ||
| clip_max_dict["nwp"][nwp_key] = np.array(clip_max_list)[None, :, None, None] | ||
|
|
||
| if config.input_data.satellite is not None: | ||
| sat_config = config.input_data.satellite | ||
|
|
||
| means_list = [] | ||
| stds_list = [] | ||
| clip_min_list = [] | ||
| clip_max_list = [] | ||
|
|
||
| for channel in list(config.input_data.satellite.channels): | ||
| means_list.append(sat_config.normalisation_constants[channel].mean) | ||
| stds_list.append(sat_config.normalisation_constants[channel].std) | ||
| for channel in list(sat_config.channels): | ||
| norm_conf = sat_config.normalisation_constants[channel] | ||
| means_list.append(norm_conf.mean) | ||
| stds_list.append(norm_conf.std) | ||
| clip_min_list.append(-np.inf if norm_conf.clip_min is None else norm_conf.clip_min) | ||
| clip_max_list.append(np.inf if norm_conf.clip_max is None else norm_conf.clip_max) | ||
|
|
||
| # Convert to array and expand dimensions so we can normalise the 4D sat and NWP sources | ||
| means_dict["sat"] = np.array(means_list)[None, :, None, None] | ||
| stds_dict["sat"] = np.array(stds_list)[None, :, None, None] | ||
| clip_min_dict["sat"] = np.array(clip_min_list)[None, :, None, None] | ||
| clip_max_dict["sat"] = np.array(clip_max_list)[None, :, None, None] | ||
|
|
||
| return means_dict, stds_dict | ||
| return means_dict, stds_dict, clip_min_dict, clip_max_dict | ||
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This bit is to allow for the new satellite data where the channel dimension is already called channel