Skip to content
173 changes: 173 additions & 0 deletions src/earthkit/data/readers/grib/virtual.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# (C) Copyright 2020 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

import logging
from functools import cached_property

from earthkit.data import from_source
from earthkit.data.core.fieldlist import Field
from earthkit.data.core.metadata import WrappedMetadata
from earthkit.data.utils.dates import date_to_grib
from earthkit.data.utils.dates import datetime_from_grib
from earthkit.data.utils.dates import time_to_grib
from earthkit.data.utils.dates import to_timedelta

from .index import GribFieldList

LOG = logging.getLogger(__name__)


class VirtualGribField(Field):
    """GRIB field described by a request; its data is only fetched on demand.

    Metadata lookups are answered from the request (and from the parameter
    information supplied by the owner) whenever possible. Everything else is
    delegated to the metadata of the owner's reference field.

    Parameters
    ----------
    owner:
        Object providing ``_get_info()``, ``reference`` and ``retriever``.
    request: dict
        Request identifying this field.
    metadata_alias: dict
        Maps a metadata key to the request key holding the same value.
    reference: optional
        Already loaded field to use for the data instead of retrieving it.
    """

    def __init__(self, owner, request, metadata_alias, reference=None):
        self.owner = owner
        self.request = request
        self.metadata_alias = metadata_alias
        self.reference = reference

        self.extra = {}
        if "param" in request and "shortName" not in request:
            # Work on a copy: the owner caches this dict and hands the same
            # object to every field with the same param.
            self.extra = dict(self.owner._get_info(request["param"]))
            if self.extra.get("shortName") is None:
                # No usable short name could be determined (the key can be
                # present with a None value), so fall back to the request.
                self.extra["shortName"] = request["param"]
            else:
                self.extra["param"] = self.extra["shortName"]

    @property
    def _data_datetime(self):
        """Datetime built from the request's date and time, or None."""
        if "date" in self.request and "time" in self.request:
            return datetime_from_grib(self.request["date"], self.request["time"])
        return None

    @property
    def _valid_datetime(self):
        """Data datetime shifted by the request's step (if any), or None."""
        base = self._data_datetime
        if base is None:
            return None
        if "step" in self.request:
            return base + to_timedelta(self.request["step"])
        return base

    def _attributes(self, names, remapping=None, joiner=None, default=None):
        """Return a dict with the metadata value for each name in ``names``."""
        metadata = self.metadata
        if remapping is not None:
            metadata = remapping(metadata, joiner=joiner)
        return {name: metadata(name, default=default) for name in names}

    def metadata(self, *keys, astype=None, remapping=None, patches=None, **kwargs):
        """Return metadata values, using the request-based lookup when possible.

        String keys (given as arguments or as a single list/tuple) with no
        extra keyword arguments other than ``default=None`` are resolved one
        by one with ``_one_metadata()``; a single key yields a single value,
        several keys yield a list. All other calls go to the parent class.
        """
        if (not kwargs or kwargs == {"default": None}) and keys:
            if isinstance(keys[0], (list, tuple)):
                keys = keys[0]
            if keys and isinstance(keys[0], str):
                values = [
                    self._one_metadata(k, astype=astype, remapping=remapping, patches=patches) for k in keys
                ]
                return values[0] if len(values) == 1 else values

        return super().metadata(
            *keys,
            astype=astype,
            remapping=remapping,
            patches=patches,
            **kwargs,
        )

    def _one_metadata(self, key, remapping=None, patches=None, **kwargs):
        """Resolve a single metadata key.

        Lookup order: parameter info, request, aliased request key, keys
        computed from the request, then the wrapped reference metadata.
        """
        if key in self.extra:
            return self.extra[key]
        if key in self.request:
            return self.request[key]
        if key in self.metadata_alias:
            # The alias names the request key that holds the value, so it is
            # the alias target that must be present in the request.
            alias = self.metadata_alias[key]
            if alias in self.request:
                return self.request[alias]

        if key == "number":
            return 0
        if key == "validityDate":
            return date_to_grib(self._valid_datetime)
        if key == "validityTime":
            return time_to_grib(self._valid_datetime)
        if key in ("forecast_reference_time", "base_time", "base_datetime"):
            return self._data_datetime.isoformat()
        if key in ("valid_datetime", "valid_time"):
            return self._valid_datetime.isoformat()

        return self._metadata.get(key, **kwargs)

    @property
    def _metadata(self):
        """Reference metadata wrapped with this field's request-derived keys."""
        r = {**self.request, **self.extra}
        for k, v in self.metadata_alias.items():
            if k not in r and v in r:
                r[k] = r[v]

        return WrappedMetadata(self.owner.reference._metadata, extra=r)

    def _values(self, dtype=None):
        """Return the values of the underlying field."""
        return self._field._values(dtype=dtype)

    @cached_property
    def _field(self):
        """Underlying field: the reference if given, otherwise retrieved once."""
        if self.reference:
            return self.reference
        path = self.owner.retriever.get(self.request)
        return from_source("file", path, stream=True, read_all=True)[0]


class VirtualGribFieldList(GribFieldList):
    """Fieldlist whose fields are created from requests when indexed.

    Parameters
    ----------
    request_mapper:
        Object providing ``request_at()``, ``metadata_alias`` and ``len()``.
    retriever:
        Object whose ``get(request)`` returns a path readable as a file source.
    """

    def __init__(self, request_mapper, retriever):
        self.request_mapper = request_mapper
        self.retriever = retriever

        # The field for the first request is loaded up front and kept as the
        # reference used by the virtual fields.
        path = self.retriever.get(self.request_mapper.request_at(0))
        self.reference = from_source("file", path)[0]
        self._info_cache = {}

    def __len__(self):
        return len(self.request_mapper)

    def mutate(self):
        return self

    def _getitem(self, n):
        """Return the virtual field at integer index ``n``.

        Negative indices count from the end.

        Raises
        ------
        IndexError
            If ``n`` is outside ``[-len(self), len(self))``.

        Returns None for non-integer ``n``
        (presumably so the caller can handle other index types - confirm).
        """
        if not isinstance(n, int):
            return None

        size = len(self)
        index = n + size if n < 0 else n
        # Check both bounds: an index below -len(self) is still negative
        # after the adjustment and must not wrap around to a valid item.
        if not 0 <= index < size:
            raise IndexError(f"Index {n} out of range")

        return VirtualGribField(
            self,
            self.request_mapper.request_at(index),
            self.request_mapper.metadata_alias,
            reference=self.reference if index == 0 else None,
        )

    def _get_info(self, param):
        """Return (and cache) shortName, name, units and cfName for ``param``."""
        if param in self._info_cache:
            return self._info_cache[param]

        keys = ["shortName", "name", "units", "cfName"]
        ref_request = self.request_mapper.request_at(0)
        if param == ref_request.get("param"):
            # The reference field is this very parameter: read it directly.
            r = self.reference._attributes(keys)
        else:
            md = self.reference.metadata().override(paramId=param)
            r = {k: md.get(k, None) for k in keys}

        self._info_cache[param] = r
        return r
51 changes: 45 additions & 6 deletions src/earthkit/data/sources/fdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,22 @@

from earthkit.data.sources.file import FileSource
from earthkit.data.sources.stream import StreamSource
from earthkit.data.utils.request import RequestMapper

from . import Source

LOG = logging.getLogger(__name__)


class FDBSource(Source):
def __init__(self, *args, stream=True, config=None, userconfig=None, **kwargs):
def __init__(self, *args, stream=True, config=None, userconfig=None, lazy=False, **kwargs):
super().__init__()

for k in ["group_by", "batch_size"]:
if k in kwargs:
raise ValueError(f"Invalid argument '{k}' for FDBSource. Deprecated since 0.8.0.")

self.lazy = lazy
self._fdb_kwargs = {}
if config is not None:
self._fdb_kwargs["config"] = config
Expand Down Expand Up @@ -64,12 +66,19 @@ def _check_env(self):
)

def mutate(self):
fdb = pyfdb.FDB(**self._fdb_kwargs)
if self.stream:
stream = fdb.retrieve(self.request)
return StreamSource(stream, **self._stream_kwargs)
if not self.lazy:
fdb = pyfdb.FDB(**self._fdb_kwargs)
if self.stream:
stream = fdb.retrieve(self.request)
return StreamSource(stream, **self._stream_kwargs)
else:
return FDBFileSource(fdb, self.request)
else:
return FDBFileSource(fdb, self.request)
mapper = FDBRequestMapper(self.request, fdb_kwargs=self._fdb_kwargs)
retriever = FDBRetriever(self._fdb_kwargs)
from earthkit.data.readers.grib.virtual import VirtualGribFieldList

return VirtualGribFieldList(mapper, retriever)


class FDBFileSource(FileSource):
Expand All @@ -89,4 +98,34 @@ def retrieve(target, request):
)


class FDBRetriever:
    """Retrieve the data for a request from FDB into a file.

    Parameters
    ----------
    fdb_kwargs: dict
        Keyword arguments passed to ``pyfdb.FDB`` on every retrieval.
    """

    def __init__(self, fdb_kwargs):
        self.fdb_kwargs = fdb_kwargs

    def get(self, request):
        """Return the path of the file source created for ``request``."""
        client = pyfdb.FDB(**self.fdb_kwargs)
        return FDBFileSource(client, request).path


class FDBRequestMapper(RequestMapper):
    """Request mapper that lists the matching fields in FDB.

    Parameters
    ----------
    request: dict
        Request passed to ``fdb.list()``.
    fdb_kwargs: dict, optional
        Keyword arguments used to create the ``pyfdb.FDB`` object.
    """

    def __init__(self, request, fdb_kwargs=None, **kwargs):
        super().__init__(request, **kwargs)
        self.fdb_kwargs = fdb_kwargs or {}
        # Metadata key -> request key holding the same value.
        # NOTE(review): "leveltype" differs from the "levtype" key used in
        # requests elsewhere - confirm whether this spelling is intended.
        self.metadata_alias = dict(
            stepRange="step",
            typeOfLevel="leveltype",
            level="levelist",
            dataDate="date",
            dataTime="time",
        )

    def _build(self):
        """Return the ``"keys"`` entry of every item listed by FDB."""
        client = pyfdb.FDB(**self.fdb_kwargs)
        return [item["keys"] for item in client.list(self.request, True, True)]


source = FDBSource
36 changes: 36 additions & 0 deletions src/earthkit/data/utils/request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# (C) Copyright 2022 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

import logging
from abc import ABCMeta
from abc import abstractmethod
from functools import cached_property

LOG = logging.getLogger(__name__)


class RequestMapper(metaclass=ABCMeta):
    """Expose a request as an indexable collection of per-field requests.

    Subclasses implement ``_build()``; its result is computed on first access
    of ``field_requests`` and cached on the instance.
    """

    # No aliases by default; subclasses may provide their own mapping.
    metadata_alias = None

    def __init__(self, request, **kwargs):
        # Extra keyword arguments are accepted but not used by this class.
        self.request = request

    @abstractmethod
    def _build(self):
        """Return the indexable collection of per-field requests."""

    @cached_property
    def field_requests(self):
        """Per-field requests, built once by ``_build()``."""
        return self._build()

    def __len__(self):
        return len(self.field_requests)

    def request_at(self, index):
        """Return the per-field request stored at ``index``."""
        requests = self.field_requests
        return requests[index]
87 changes: 87 additions & 0 deletions tests/lazy/test_lazy_fdb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#!/usr/bin/env python3

# (C) Copyright 2020 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

import os

import numpy as np
import pytest

from earthkit.data import from_source
from earthkit.data.core.temporary import temp_directory
from earthkit.data.testing import NO_FDB
from earthkit.data.testing import earthkit_test_data_file

# Request used by the lazy FDB test. Five keys hold two values each (date,
# time, levelist, step, param), which gives the 2**5 = 32 fields the test
# expects.
TEST_GRIB_REQUEST = {
    "class": "od",
    "expver": "0001",
    "stream": "oper",
    "date": [20240603, 20240604],
    "time": [0, 1200],
    "domain": "g",
    "type": "fc",
    "levtype": "pl",
    "levelist": [500, 700],
    "step": [0, 6],
    "param": [130, 157],
}


def make_fdb_config(path):
    """Create the directory ``path`` and return an FDB config rooted in it."""
    schema_file = earthkit_test_data_file("fdb_schema.txt")
    os.makedirs(path, exist_ok=True)
    root = {"path": path}
    space = {"handler": "Default", "roots": [root]}
    return {
        "type": "local",
        "engine": "toc",
        "schema": schema_file,
        "spaces": [space],
    }


def make_fdb(path):
    """Write the ``pl.grib`` sample into an FDB at ``path``.

    Returns the sample data and the FDB config used to write it.
    """
    sample = from_source("sample", "pl.grib")
    fdb_config = make_fdb_config(path)
    sample.to_target("fdb", config=fdb_config)
    return sample, fdb_config


@pytest.mark.skipif(NO_FDB, reason="No access to FDB")
@pytest.mark.cache
def test_lazy_fdb():
    """Read a temporary FDB lazily and check fields, metadata and xarray output."""
    info_keys = ["shortName", "param", "units", "cfName"]
    with temp_directory() as tmpdir:
        fdb_path = os.path.join(tmpdir, "_fdb")
        ds, config = make_fdb(fdb_path)

        # The name is rebound on purpose, as the written sample is not needed
        # any more once the FDB has been populated.
        ds = from_source("fdb", TEST_GRIB_REQUEST, config=config, stream=False, lazy=True)
        assert len(ds) == 32

        # Field level access
        assert ds[0].shape == (19, 36)
        assert ds[1].shape == (19, 36)
        assert ds[0].metadata(info_keys) == ["t", "t", "K", "air_temperature"]
        assert ds[1].metadata(info_keys) == ["r", "r", "%", "relative_humidity"]

        # Neither the fieldlist nor its fields expose a file path
        assert not hasattr(ds, "path")
        assert not hasattr(ds[0], "path")

        # Conversion to xarray
        a = ds.to_xarray(time_dim_mode="forecast")
        for var_name in ("t", "r"):
            assert a[var_name].values.shape == (4, 2, 2, 19, 36)

        # Computation on the xarray dataset
        m = a.mean("step").load()
        for var_name in ("t", "r"):
            assert m[var_name].values.shape == (4, 2, 19, 36)
        assert np.allclose(m["r"].values.flatten()[85:87], [47.66908598, 53.43959379])
        assert np.allclose(m["t"].values.flatten()[85:87], [253.22625732, 252.78778076])
Loading