def select_spatial_slice_pixels_multiple(
    da: xr.DataArray,
    locations: list[Location],
    width_pixels: int,
    height_pixels: int,
) -> xr.DataArray:
    """Select a single spatial slice which covers the windows of all given locations.

    The slice runs from half a window before the lowest pixel index of any location to
    half a window after the highest pixel index of any location, along both spatial
    dimensions.

    Args:
        da: xarray DataArray to slice from
        locations: List of locations of interest that will be covered by the returned slice
        width_pixels: Width of the per-location window in pixels
        height_pixels: Height of the per-location window in pixels

    Returns:
        The selected DataArray slice.

    Raises:
        ValueError: If the window dimensions are not even, if no locations are given, or if
            the slice extends beyond the data boundaries.
    """
    if (width_pixels % 2) != 0:
        raise ValueError("Width must be an even number")
    if (height_pixels % 2) != 0:
        raise ValueError("Height must be an even number")
    if not locations:
        # Without any location there is no extent to cover; fail loudly rather than
        # returning an arbitrary (typically empty) crop of the data
        raise ValueError("At least one location is required to select a spatial slice")

    _, x_dim, y_dim = find_coord_system(da)

    data_width_pixels = len(da[x_dim])
    data_height_pixels = len(da[y_dim])

    # Pixel index of the centre of every location, as (x, y) pairs
    center_indices = [_get_pixel_index_location(da, location) for location in locations]
    center_indices_x = [idx_x for idx_x, _ in center_indices]
    center_indices_y = [idx_y for _, idx_y in center_indices]

    half_width = width_pixels // 2
    half_height = height_pixels // 2

    left_idx = int(min(center_indices_x) - half_width)
    right_idx = int(max(center_indices_x) + half_width)
    bottom_idx = int(min(center_indices_y) - half_height)
    top_idx = int(max(center_indices_y) + half_height)

    slice_unavailable = (
        left_idx < 0
        or right_idx > data_width_pixels
        or bottom_idx < 0
        or top_idx > data_height_pixels
    )

    if slice_unavailable:
        raise ValueError(
            "Multi-location window not available: "
            f"left_idx ({left_idx}), right_idx ({right_idx}), "
            f"bottom_idx ({bottom_idx}), top_idx ({top_idx}), "
            f"data_width_pixels ({data_width_pixels}), data_height_pixels ({data_height_pixels})",
        )

    # Add buffer of 1 pixel if window is 2 pixels wide to ensure the central location is within
    # the returned slice
    x_buffer = 1 if width_pixels == 2 else 0
    y_buffer = 1 if height_pixels == 2 else 0

    return da.isel(
        {
            x_dim: slice(left_idx, right_idx + x_buffer),
            y_dim: slice(bottom_idx, top_idx + y_buffer),
        },
    )
class PVNetConcurrentDataset(AbstractPVNetDataset):
    """A torch Dataset for creating concurrent PVNet location samples."""

    @override
    def __init__(
        self,
        config_filename: str,
        start_time: str | None = None,
        end_time: str | None = None,
    ) -> None:
        """Initialise the dataset and crop its data sources to the area covering all locations.

        Args:
            config_filename: Forwarded unchanged to the parent initialiser
            start_time: Forwarded unchanged to the parent initialiser
            end_time: Forwarded unchanged to the parent initialiser
        """
        super().__init__(config_filename, start_time, end_time)

        # Concurrent samples use every location at once, so the input data only needs to span
        # the region that covers all of them
        spatially_reduced = reduce_spatial_extent_of_datasets(
            datasets_dict=self.datasets_dict,
            locations=self.locations,
            config=self.config,
        )
        self.datasets_dict = spatially_reduced

    @override
    def __len__(self) -> int:
        """Return the number of valid t0 times."""
        n_valid_t0_times = len(self.valid_t0_times)
        return n_valid_t0_times
def reduce_spatial_extent_of_datasets(
    datasets_dict: dict,
    locations: list[Location],
    config: Configuration,
) -> dict:
    """Reduce the spatial extent of the datasets to only cover the locations.

    Only the "nwp" and "sat" entries are cropped. Every other entry (for example
    "generation") is passed through unchanged.

    Args:
        datasets_dict: Dictionary of the input data sources
        locations: List of locations to reduce to
        config: Configuration object

    Returns:
        A new dictionary with the same keys as `datasets_dict`, in which the "nwp" and "sat"
        entries have been replaced by their spatially reduced versions. The input dictionary
        is not modified.
    """
    # Start from a shallow copy so that sources which are not cropped here are kept rather
    # than silently dropped from the result
    reduced_datasets_dict = dict(datasets_dict)

    if "nwp" in datasets_dict:
        # Each NWP provider has its own window size, so each is cropped separately
        reduced_datasets_dict["nwp"] = {
            nwp_key: select_spatial_slice_pixels_multiple(
                datasets_dict["nwp"][nwp_key],
                locations,
                width_pixels=nwp_config.image_size_pixels_width,
                height_pixels=nwp_config.image_size_pixels_height,
            )
            for nwp_key, nwp_config in config.input_data.nwp.items()
        }

    if "sat" in datasets_dict:
        sat_config = config.input_data.satellite

        reduced_datasets_dict["sat"] = select_spatial_slice_pixels_multiple(
            datasets_dict["sat"],
            locations,
            width_pixels=sat_config.image_size_pixels_width,
            height_pixels=sat_config.image_size_pixels_height,
        )

    return reduced_datasets_dict