Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

This file was deleted.

This file was deleted.

211 changes: 211 additions & 0 deletions ocf_data_sampler/lightarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
"""A lightweight DataArray-like class."""


import numpy as np
import xarray as xr
from tensorstore import Future as TensorStoreFuture
from tensorstore import TensorStore
from xarray_tensorstore import _TensorStoreAdapter


class LightDataArray:
    """A lightweight DataArray-like class.

    Holds a raw data handle (a NumPy array or a TensorStore), the names of its
    dimensions, and a set of scalar or 1-D coordinates. It supports integer
    based selection, explicit loading of lazy data, and conversion to and from
    `xarray.DataArray`.
    """

    __slots__ = ["attrs", "coord_dims", "coords", "data", "dims", "future"]

    def __init__(
        self,
        data: np.ndarray | TensorStore,
        dims: tuple[str, ...],
        coords: dict[str, np.ndarray],
        coord_dims: dict[str, tuple[str, ...]],
        attrs: None | dict = None,
    ) -> None:
        """Create a LightDataArray.

        Args:
            data: The underlying data, either a NumPy array or a TensorStore handle.
            dims: The dimension names, one per axis of `data`.
            coords: Mapping from coordinate name to coordinate values.
            coord_dims: Mapping from coordinate name to the dimensions it spans. An empty
                tuple marks a scalar coordinate.
            attrs: Optional metadata. An empty dict is used if None (or empty) is given.
        """
        self.data = data
        self.dims = dims
        self.coords = coords
        self.coord_dims = coord_dims
        self.attrs = attrs or {}
        # Pending asynchronous read started by `read()`, if any
        self.future: None | TensorStoreFuture = None

    @classmethod
    def from_xarray(cls, da: xr.DataArray) -> "LightDataArray":
        """Create a LightDataArray from an Xarray DataArray.

        Args:
            da: The DataArray to convert. Its data must be backed by a NumPy array or by the
                xarray-tensorstore adapter, and all coordinates must be scalar or 1-D.

        Raises:
            ValueError: If the data backend is not supported, or if any coordinate has more
                than one dimension.
        """
        # Get raw data handle which can be a numpy array or TensorStore
        raw = da.variable._data
        data: TensorStore | np.ndarray
        if isinstance(raw, _TensorStoreAdapter):
            data = raw.array
        elif isinstance(raw, np.ndarray):
            data = raw
        else:
            raise ValueError(f"Data backend of type {type(raw)} not supported.")

        coord_values: dict[str, np.ndarray] = {}
        coord_dims: dict[str, tuple[str, ...]] = {}

        for k, v in da.coords.items():
            if v.ndim > 1:
                raise ValueError(
                    "Coordinates with more than 1 dimension not supported. "
                    f"Found coord '{k}' with shape {v.shape}.",
                )
            coord_values[k] = v.values
            coord_dims[k] = v.dims

        return cls(
            data=data,
            dims=da.dims,
            coords=coord_values,
            coord_dims=coord_dims,
            attrs=da.attrs,
        )

    def to_xarray(self) -> xr.DataArray:
        """Convert to an Xarray DataArray.

        Note this loads the data eagerly.
        """
        coords_dict = {}
        for c, v in self.coords.items():
            cdims = self.coord_dims.get(c, ())

            # A 1D coord is attached to its dimension only if that dimension is recorded and
            # still present. The emptiness check avoids an IndexError on `cdims[0]`.
            if np.ndim(v) == 1 and cdims and cdims[0] in self.dims:
                coords_dict[c] = (cdims, v)
            else:
                # It's a scalar or a non-indexed coordinate
                coords_dict[c] = v

        return xr.DataArray(
            data=self.values,
            dims=self.dims,
            coords=coords_dict,
            attrs=self.attrs,
        )

    def isel(
        self,
        indexers: None | dict[str, int | slice | list] = None,
        **indexers_kwargs: object,
    ) -> "LightDataArray":
        """Select data by integer index along specified dimensions.

        Each dimension is indexed independently of the others. Integer indexers remove their
        dimension from the result; slices and array-likes keep it.

        Args:
            indexers: A dict with keys matching dimensions and values given by integers, slice
                objects or arrays. `indexer` can be an integer, slice or array-like.
            **indexers_kwargs: The keyword arguments form of indexers.

        Returns:
            A new LightDataArray holding the selected data and coordinates. The `attrs` dict
            is shared with this object, not copied.

        Raises:
            KeyError: If a key is not one of this object's dimensions.
        """
        if indexers is not None:
            indexers_kwargs.update(indexers)

        # Integers and slices ("basic" indexers) can safely be applied together in one step
        basic_indexers: list = [slice(None)] * len(self.dims)
        # Array-like indexers are applied afterwards, one axis at a time. Putting them in the
        # same index tuple as integers or other arrays would trigger NumPy-style vectorised
        # indexing, which can reorder axes or pair indices pointwise, so the result would no
        # longer match the dims returned below.
        array_indexers: dict[str, object] = {}

        new_coords = self.coords.copy()
        new_coord_dims = self.coord_dims.copy()
        dims_to_remove = []

        for dim, indexer in indexers_kwargs.items():
            if dim not in self.dims:
                raise KeyError(
                    f"'{dim}' is not a valid dimension or coordinate for data with dimensions "
                    f"{self.dims}",
                )

            # An integer index collapses the dimension (e.g. .isel(time=0))
            is_point = isinstance(indexer, int | np.integer)
            if is_point:
                dims_to_remove.append(dim)

            if is_point or isinstance(indexer, slice):
                basic_indexers[self.dims.index(dim)] = indexer
            else:
                array_indexers[dim] = indexer

            # Slice the coords which depend on this dimension
            for c_name, c_dim_name in self.coord_dims.items():
                if c_dim_name == (dim,):
                    new_coords[c_name] = new_coords[c_name][indexer]
                    if is_point:
                        # The coord has been reduced to a scalar
                        new_coord_dims[c_name] = ()

        # Slice the underlying data with the integer and slice indexers
        sliced_data = self.data[tuple(basic_indexers)]

        # Remove dims that have been reduced to points
        remaining_dims = tuple(d for d in self.dims if d not in dims_to_remove)

        # Apply each array-like indexer on its own axis of the already reduced data
        for dim, indexer in array_indexers.items():
            axis = remaining_dims.index(dim)
            sliced_data = sliced_data[(slice(None),) * axis + (indexer,)]

        return LightDataArray(
            data=sliced_data,
            dims=remaining_dims,
            coords=new_coords,
            coord_dims=new_coord_dims,
            attrs=self.attrs,
        )

    def read(self) -> None:
        """Trigger reading of the data if it's a lazy handle.

        The resulting future is stored on `self.future` and is consumed by `values`. Nothing
        happens if the data is not a TensorStore.
        """
        if isinstance(self.data, TensorStore):
            self.future = self.data.read()

    def load(self) -> "LightDataArray":
        """Load data in-place and return self."""
        self.data = self.values
        self.future = None
        return self

    @property
    def values(self) -> np.ndarray:
        """Get the underlying data as numpy array, loading it if necessary."""
        if not isinstance(self.data, TensorStore):
            return np.asarray(self.data)

        # Reuse the read started by `read()` if there is one, otherwise read now
        future = self.data.read() if self.future is None else self.future
        return np.asarray(future.result())

    def __getattr__(self, name: str) -> "LightDataArray":
        """Allow access to coordinates via attribute syntax, e.g., da.time.

        Raises:
            AttributeError: If `name` is not a coordinate.
        """
        # This hook only runs when normal lookup fails. If the missing name is one of our own
        # slots (e.g. `coords` on an instance that has not been initialised yet) we must not
        # touch `self.coords`, as that would re-enter this method and recurse without end.
        if name not in self.__slots__ and name in self.coords:
            return self[name]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def __getitem__(self, key: str) -> "LightDataArray":
        """Allow access to coordinates via indexing syntax, e.g., da['time'].

        The returned object uses the coordinate values as its data and carries that
        coordinate as its only coordinate.

        Raises:
            KeyError: If `key` is not a coordinate.
        """
        if key in self.coords:
            return LightDataArray(
                data=self.coords[key],
                dims=self.coord_dims[key],
                coords={key: self.coords[key]},
                coord_dims={key: self.coord_dims[key]},
            )
        raise KeyError(f"Coordinate '{key}' not found.")

    def __getstate__(self) -> dict:
        """Prepare state for pickling, excluding un-picklable attributes."""
        # `future` is deliberately left out
        return {
            "data": self.data,
            "dims": self.dims,
            "coords": self.coords,
            "attrs": self.attrs,
            "coord_dims": self.coord_dims,
        }

    def __setstate__(self, state: dict) -> None:
        """Restore state after unpickling."""
        for k, v in state.items():
            setattr(self, k, v)
        # Restore the un-picklable attribute to a default state
        self.future = None

    @property
    def shape(self) -> tuple[int, ...]:
        """Return the shape of the underlying data array."""
        return self.data.shape

    def __len__(self) -> int:
        """Return the length of the first dimension of the underlying data array."""
        return self.shape[0]
10 changes: 6 additions & 4 deletions ocf_data_sampler/load/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import numpy as np
import xarray as xr

from ocf_data_sampler.load.utils import assert_values_unique_increasing


def open_generation(zarr_path: str, public: bool = False) -> xr.DataArray:
"""Open and eagerly load the generation data and validates its data types.
Expand All @@ -41,18 +43,18 @@ def open_generation(zarr_path: str, public: bool = False) -> xr.DataArray:
backend_kwargs=backend_kwargs,
)

ds = ds.assign_coords(capacity_mwp=ds.capacity_mwp)
da = ds.to_dataarray("gen_param").transpose("time_utc", "location_id", "gen_param")

da = ds.generation_mw
assert_values_unique_increasing(ds.time_utc.values, "time_utc")
assert_values_unique_increasing(ds.location_id.values, "location_id")

# Validate data types
if not np.issubdtype(da.dtype, np.floating):
raise TypeError(f"generation_mw should be floating, not {da.dtype}")
raise TypeError(f"generation and capacity values should be floating, not {da.dtype}")

coord_dtypes = {
"time_utc": np.datetime64,
"location_id": np.integer,
"capacity_mwp": np.floating,
"longitude": np.floating,
"latitude": np.floating,
}
Expand Down
6 changes: 3 additions & 3 deletions ocf_data_sampler/load/nwp/providers/cloudcasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
from ocf_data_sampler.load.utils import (
check_time_unique_increasing,
assert_values_unique_increasing,
get_xr_data_array_from_xr_dataset,
make_spatial_coords_increasing,
)
Expand Down Expand Up @@ -38,8 +38,8 @@ def open_cloudcasting(zarr_path: str | list[str]) -> xr.DataArray:
},
)

# Check the timestamps are unique and increasing
check_time_unique_increasing(ds.init_time_utc)
assert_values_unique_increasing(ds.init_time_utc.values, "init_time_utc")
assert_values_unique_increasing(ds.step.values, "step")

# Make sure the spatial coords are in increasing order
ds = make_spatial_coords_increasing(ds, x_coord="x_geostationary", y_coord="y_geostationary")
Expand Down
5 changes: 3 additions & 2 deletions ocf_data_sampler/load/nwp/providers/ecmwf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
from ocf_data_sampler.load.utils import (
check_time_unique_increasing,
assert_values_unique_increasing,
get_xr_data_array_from_xr_dataset,
make_spatial_coords_increasing,
)
Expand All @@ -24,7 +24,8 @@ def open_ifs(zarr_path: str | list[str]) -> xr.DataArray:
# LEGACY SUPPORT - rename variable to channel if it exists
ds = ds.rename({"init_time": "init_time_utc", "variable": "channel"})

check_time_unique_increasing(ds.init_time_utc)
assert_values_unique_increasing(ds.init_time_utc.values, "init_time_utc")
assert_values_unique_increasing(ds.step.values, "step")

ds = make_spatial_coords_increasing(ds, x_coord="longitude", y_coord="latitude")

Expand Down
5 changes: 3 additions & 2 deletions ocf_data_sampler/load/nwp/providers/gdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
from ocf_data_sampler.load.utils import (
check_time_unique_increasing,
assert_values_unique_increasing,
get_xr_data_array_from_xr_dataset,
make_spatial_coords_increasing,
)
Expand All @@ -21,7 +21,8 @@ def open_gdm(zarr_path: str | list[str]) -> xr.DataArray:
"""
ds = open_zarr_paths(zarr_path, backend="tensorstore", time_dim="init_time_utc")

check_time_unique_increasing(ds.init_time_utc)
assert_values_unique_increasing(ds.init_time_utc.values, "init_time_utc")
assert_values_unique_increasing(ds.step.values, "step")

ds = make_spatial_coords_increasing(ds, x_coord="longitude", y_coord="latitude")

Expand Down
7 changes: 5 additions & 2 deletions ocf_data_sampler/load/nwp/providers/gfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import xarray as xr

from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
from ocf_data_sampler.load.utils import check_time_unique_increasing, make_spatial_coords_increasing
from ocf_data_sampler.load.utils import (
assert_values_unique_increasing,
make_spatial_coords_increasing,
)

_log = logging.getLogger(__name__)

Expand Down Expand Up @@ -33,7 +36,7 @@ def open_gfs(zarr_path: str | list[str], public: bool = False) -> xr.DataArray:

del gfs

check_time_unique_increasing(nwp.init_time_utc)
assert_values_unique_increasing(nwp.init_time_utc.values, "init_time_utc")
nwp = make_spatial_coords_increasing(nwp, x_coord="longitude", y_coord="latitude")

nwp = nwp.transpose("init_time_utc", "step", "channel", "longitude", "latitude")
Expand Down
7 changes: 5 additions & 2 deletions ocf_data_sampler/load/nwp/providers/icon.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import xarray as xr

from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
from ocf_data_sampler.load.utils import check_time_unique_increasing, make_spatial_coords_increasing
from ocf_data_sampler.load.utils import (
assert_values_unique_increasing,
make_spatial_coords_increasing,
)


def open_icon_eu(zarr_path: str | list[str]) -> xr.DataArray:
Expand All @@ -27,7 +30,7 @@ def open_icon_eu(zarr_path: str | list[str]) -> xr.DataArray:
else:
raise ValueError("Could not find 'icon_eu_data' DataArray in the ICON-EU Zarr file.")

check_time_unique_increasing(nwp.init_time_utc)
assert_values_unique_increasing(nwp.init_time_utc.values, "init_time_utc")

# 0-78 one hour steps, rest 3 hour steps
nwp = nwp.isel(step=slice(0, 78))
Expand Down
5 changes: 3 additions & 2 deletions ocf_data_sampler/load/nwp/providers/ukv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths
from ocf_data_sampler.load.utils import (
check_time_unique_increasing,
assert_values_unique_increasing,
get_xr_data_array_from_xr_dataset,
make_spatial_coords_increasing,
)
Expand Down Expand Up @@ -32,9 +32,10 @@ def open_ukv(zarr_path: str | list[str]) -> xr.DataArray:
# Only rename if the source key exists in the dataset's dimensions or coordinates
# This prevents KeyErrors when the new UKV data already has "x_osgb" and "y_osgb"
ds = ds.rename({k: v for k, v in rename_map.items() if k in ds.coords})
check_time_unique_increasing(ds.init_time_utc)

ds = make_spatial_coords_increasing(ds, x_coord="x_osgb", y_coord="y_osgb")
assert_values_unique_increasing(ds.init_time_utc.values, "init_time_utc")
assert_values_unique_increasing(ds.step.values, "step")

ds = ds.transpose("init_time_utc", "step", "channel", "x_osgb", "y_osgb")

Expand Down
6 changes: 2 additions & 4 deletions ocf_data_sampler/load/satellite.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ocf_data_sampler.load.open_xarray_tensorstore import open_zarr, open_zarrs
from ocf_data_sampler.load.utils import (
check_time_unique_increasing,
assert_values_unique_increasing,
get_xr_data_array_from_xr_dataset,
make_spatial_coords_increasing,
)
Expand All @@ -23,17 +23,15 @@ def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray:
else:
ds = open_zarr(zarr_path)

check_time_unique_increasing(ds.time)

ds = ds.rename(
{
"variable": "channel",
"time": "time_utc",
},
)

check_time_unique_increasing(ds.time_utc)
ds = make_spatial_coords_increasing(ds, x_coord="x_geostationary", y_coord="y_geostationary")
assert_values_unique_increasing(ds.time_utc.values, "time_utc")
ds = ds.transpose("time_utc", "channel", "x_geostationary", "y_geostationary")

data_array = get_xr_data_array_from_xr_dataset(ds)
Expand Down
Loading
Loading