Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 81 additions & 6 deletions ocf_data_sampler/select/select_spatial_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def select_spatial_slice_pixels(
The selected DataArray slice.

Raises:
ValueError: If the dimensions are not even or the slice is not allowed
when padding is required.
ValueError: If the window dimensions are not even or the slice extends beyond the data
boundaries.
"""
if (width_pixels % 2) != 0:
raise ValueError("Width must be an even number")
Expand All @@ -83,7 +83,6 @@ def select_spatial_slice_pixels(
data_width_pixels = len(da[x_dim])
data_height_pixels = len(da[y_dim])

# Padding checks
slice_unavailable = (
left_idx < 0
or right_idx > data_width_pixels
Expand All @@ -104,7 +103,83 @@ def select_spatial_slice_pixels(
issue_details = "\n - ".join(issues)
raise ValueError(f"Window for location {location} not available: \n - {issue_details}")

# Standard selection - without padding
da = da.isel({x_dim: slice(left_idx, right_idx), y_dim: slice(bottom_idx, top_idx)})
return da.isel({x_dim: slice(left_idx, right_idx), y_dim: slice(bottom_idx, top_idx)})

return da

def select_spatial_slice_pixels_multiple(
    da: xr.DataArray,
    locations: list[Location],
    width_pixels: int,
    height_pixels: int,
) -> xr.DataArray:
    """Select a single spatial slice which covers the windows of all given locations.

    The slice runs from half a window before the smallest location pixel index to half a
    window after the largest location pixel index, along both spatial dimensions.

    Args:
        da: xarray DataArray to slice from
        locations: Non-empty list of locations of interest that will be covered by the
            returned slice
        width_pixels: Width of the per-location window in pixels
        height_pixels: Height of the per-location window in pixels

    Returns:
        The selected DataArray slice.

    Raises:
        ValueError: If the window dimensions are not even, if no locations are given, or if
            the slice extends beyond the data boundaries.
    """
    if (width_pixels % 2) != 0:
        raise ValueError("Width must be an even number")
    if (height_pixels % 2) != 0:
        raise ValueError("Height must be an even number")
    # Without any location there is no extent to cover; fail loudly rather than return an
    # arbitrary (possibly empty) slice
    if not locations:
        raise ValueError("At least one location is required to select a spatial slice")

    _, x_dim, y_dim = find_coord_system(da)

    data_width_pixels = len(da[x_dim])
    data_height_pixels = len(da[y_dim])

    # Pixel index of every location along each spatial dimension
    center_indices = [_get_pixel_index_location(da, location) for location in locations]
    x_indices = [center_idx_x for center_idx_x, _ in center_indices]
    y_indices = [center_idx_y for _, center_idx_y in center_indices]

    half_width = width_pixels // 2
    half_height = height_pixels // 2

    # Extend the bounding box of the location indices by half a window on every side
    left_idx = int(min(x_indices) - half_width)
    right_idx = int(max(x_indices) + half_width)
    bottom_idx = int(min(y_indices) - half_height)
    top_idx = int(max(y_indices) + half_height)

    slice_unavailable = (
        left_idx < 0
        or right_idx > data_width_pixels
        or bottom_idx < 0
        or top_idx > data_height_pixels
    )

    if slice_unavailable:
        raise ValueError(
            "Multi-location window not available: "
            f"left_idx ({left_idx}), right_idx ({right_idx}), "
            f"bottom_idx ({bottom_idx}), top_idx ({top_idx}), "
            f"data_width_pixels ({data_width_pixels}), data_height_pixels ({data_height_pixels})",
        )

    # Add buffer of 1 pixel if window is 2 pixels wide to ensure the central location is within the
    # returned slice. Note the buffer is applied after the bounds check above.
    x_buffer = 1 if width_pixels == 2 else 0
    y_buffer = 1 if height_pixels == 2 else 0

    return da.isel(
        {
            x_dim: slice(left_idx, right_idx + x_buffer),
            y_dim: slice(bottom_idx, top_idx + y_buffer),
        },
    )
17 changes: 17 additions & 0 deletions ocf_data_sampler/torch_datasets/pvnet_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
diff_nwp_data,
fill_nans_in_arrays,
find_valid_time_periods,
reduce_spatial_extent_of_datasets,
slice_datasets_by_space,
slice_datasets_by_time,
)
Expand Down Expand Up @@ -433,6 +434,22 @@ def validate_sample_request(self, t0: pd.Timestamp, location_id: int) -> None:
class PVNetConcurrentDataset(AbstractPVNetDataset):
"""A torch Dataset for creating concurrent PVNet location samples."""

@override
def __init__(
    self,
    config_filename: str,
    start_time: str | None = None,
    end_time: str | None = None,
) -> None:
    """Initialise via the parent class, then crop the datasets to cover the locations.

    Args:
        config_filename: Forwarded unchanged to the parent initialiser
        start_time: Forwarded unchanged to the parent initialiser
        end_time: Forwarded unchanged to the parent initialiser
    """
    super().__init__(config_filename, start_time, end_time)

    # Replace the datasets with versions reduced to the spatial extent of all locations
    spatially_reduced = reduce_spatial_extent_of_datasets(
        self.datasets_dict,
        self.locations,
        self.config,
    )
    self.datasets_dict = spatially_reduced

@override
def __len__(self) -> int:
    """Return the number of entries in ``valid_t0_times``."""
    n_valid_t0_times = len(self.valid_t0_times)
    return n_valid_t0_times
Expand Down
2 changes: 1 addition & 1 deletion ocf_data_sampler/torch_datasets/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .config_normalization_values_to_dicts import config_normalization_values_to_dicts
from .merge_and_fill_utils import fill_nans_in_arrays
from .valid_time_periods import find_valid_time_periods
from .spatial_slice_for_dataset import slice_datasets_by_space
from .spatial_slice_for_dataset import slice_datasets_by_space, reduce_spatial_extent_of_datasets
from .time_slice_for_dataset import slice_datasets_by_time
from .add_alterate_coordinate_projections import add_alterate_coordinate_projections
from .diff_nwp_data import diff_nwp_data
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
"""Functions for selecting data around a given location."""


from ocf_data_sampler.config import Configuration
from ocf_data_sampler.select.location import Location
from ocf_data_sampler.select.select_spatial_slice import select_spatial_slice_pixels
from ocf_data_sampler.select.select_spatial_slice import (
select_spatial_slice_pixels,
select_spatial_slice_pixels_multiple,
)


def slice_datasets_by_space(
Expand Down Expand Up @@ -53,3 +57,45 @@ def slice_datasets_by_space(


return sliced_datasets_dict


def reduce_spatial_extent_of_datasets(
    datasets_dict: dict,
    locations: list[Location],
    config: Configuration,
) -> dict:
    """Crop the gridded datasets to the region needed to cover every location.

    Args:
        datasets_dict: Dictionary of the input data sources
        locations: List of locations which the cropped data must cover
        config: Configuration object supplying the image size in pixels for each source

    Returns:
        A new dictionary holding the cropped "nwp" and "sat" entries and the unchanged
        "generation" entry, for whichever of those keys are present in the input.
    """
    reduced = {}

    if "nwp" in datasets_dict:
        nwp_sources = datasets_dict["nwp"]
        # One cropped array per NWP source listed in the configuration
        reduced["nwp"] = {
            nwp_key: select_spatial_slice_pixels_multiple(
                nwp_sources[nwp_key],
                locations,
                height_pixels=nwp_config.image_size_pixels_height,
                width_pixels=nwp_config.image_size_pixels_width,
            )
            for nwp_key, nwp_config in config.input_data.nwp.items()
        }

    if "sat" in datasets_dict:
        sat_config = config.input_data.satellite
        reduced["sat"] = select_spatial_slice_pixels_multiple(
            datasets_dict["sat"],
            locations,
            height_pixels=sat_config.image_size_pixels_height,
            width_pixels=sat_config.image_size_pixels_width,
        )

    # Generation data is carried over as-is, without any spatial selection
    if "generation" in datasets_dict:
        reduced["generation"] = datasets_dict["generation"]

    return reduced
Loading