From 54ac5be9ddd7ed2b5792130ea9dc7c07b0dd84ca Mon Sep 17 00:00:00 2001 From: Sandor Kertesz Date: Mon, 7 Apr 2025 13:34:27 +0100 Subject: [PATCH 1/5] Virtual field --- src/earthkit/data/sources/ecmwf_api.py | 172 +++++++++++++++++++++++-- 1 file changed, 162 insertions(+), 10 deletions(-) diff --git a/src/earthkit/data/sources/ecmwf_api.py b/src/earthkit/data/sources/ecmwf_api.py index d0f781a40..717dd8305 100644 --- a/src/earthkit/data/sources/ecmwf_api.py +++ b/src/earthkit/data/sources/ecmwf_api.py @@ -44,12 +44,116 @@ class MARSAPIKeyPrompt(APIKeyPrompt): config_env = ("ECMWF_API_KEY", "ECMWF_API_URL") +class MarsFetcher: + def __init__(self, service): + self.service = service + + def get(self, request): + def retrieve(target, request): + self.service.execute(request, target) + + return self.cache_file( + retrieve, + request, + ) + + def cache_file(self, create, args, **kwargs): + import re + + from earthkit.data.core.caching import cache_file + + owner = kwargs.pop("owner", None) + if owner is None: + owner = re.sub(r"(?!^)([A-Z]+)", r"-\1", self.__class__.__name__).lower() + + return cache_file(owner, create, args, **kwargs) + + +class RequestMappaper: + defaults = { + "class": "oper", + "type": "an", + "stream": "da", + "expver": "1", + "param": "z", + "levtype": "pl", + "levelist": [1000, 850, 700, 500, 400, 300], + "time": 12, + "step": 0, + } + mandatory = ["date", "area", "grid"] + + def __init__(self, *args, request, **kwargs): + self.request = request + coords = {} + self.skipped = {} + + # res = [] + skip = ["grid", "area"] + for r in request: + # keys = [] + # vals = [] + for k, v in r.items(): + if k in skip: + self.skipped[k] = v + else: + if not isinstance(v, (list, tuple)): + v = [v] + if k not in coords: + coords[k] = v + else: + coords[k].append(v) + + self.coords = coords + self.shape = [len(v) for v in coords.values()] + + def coords_to_index(coords, shape) -> int: + """ + Map user coords to field index""" + index = 0 + n = 1 + for i in range(len(coords) - 1, -1, -1): + index += coords[i] * n + n *= shape[i] + return index + + def index_to_coords(self, index: int, shape): + assert isinstance(index, int), (index, type(index)) + + result = [None] * len(shape) + i = len(shape) - 1 + + while i >= 0: + result[i] = index % shape[i] + index = index // shape[i] + i -= 1 + + result = tuple(result) + + assert len(result) == len(shape) + return result + + def request_at(self, index): + idx = self.index_to_coords(index, self.shape) + r = {} + for i, key in enumerate(self.coords): + r[key] = self.coords[key][idx[i]] + r.update(self.skipped) + return r + + def __len__(self): + import math + + return math.prod(self.shape) + + class ECMWFApi(FileSource): - def __init__(self, *args, prompt=True, log="default", **kwargs): + def __init__(self, *args, prompt=True, log="default", lazy=False, **kwargs): super().__init__() self.prompt = prompt self.log = log + self.lazy = lazy request = {} for a in args: @@ -64,20 +168,25 @@ def __init__(self, *args, prompt=True, log="default", **kwargs): self.expect_any = True break - self.service() # Trigger password prompt before threading + self.request = requests - nthreads = min(self.config("number-of-download-threads"), len(requests)) + self.service() # Trigger password prompt before threading - if nthreads < 2: - self.path = [self._retrieve(r) for r in requests] + if lazy: + pass else: - from earthkit.data.utils.progbar import tqdm + nthreads = min(self.config("number-of-download-threads"), len(requests)) - with SoftThreadPool(nthreads=nthreads) as pool: - futures = [pool.submit(self._retrieve, r) for r in requests] + if nthreads < 2: + self.path = [self._retrieve(r) for r in requests] + else: + from earthkit.data.utils.progbar import tqdm - iterator = (f.result() for f in futures) - self.path = list(tqdm(iterator, leave=True, total=len(requests))) + with SoftThreadPool(nthreads=nthreads) as pool: + futures = [pool.submit(self._retrieve, r) for r in requests] + + iterator = (f.result() for f in futures) + self.path = list(tqdm(iterator, leave=True, total=len(requests))) def _retrieve(self, request): def retrieve(target, request): @@ -130,3 +239,46 @@ def empty_reader(self, *args, **kwargs): from .empty import EmptySource return EmptySource() + + def request_per_field(self, requests): + from itertools import product + + # print(requests) + + res = [] + skip = ["grid", "area"] + for r in requests: + keys = [] + vals = [] + skipped = {} + for k, v in r.items(): + if k in skip: + skipped[k] = v + else: + keys.append(k) + if isinstance(v, (list, tuple)): + vals.append(v) + else: + vals.append([v]) + + for v in product(*vals): + res.append(dict(zip(keys, v), **skipped)) + + return res + + def mutate(self): + if self.lazy: + print("lazy") + mapper = RequestMappaper(request=self.request) + print("coords", mapper.coords) + print("index[0]", mapper.index_to_coords(0, mapper.shape)) + print("index[0]", mapper.request_at(0)) + print("index[1]", mapper.request_at(1)) + from earthkit.data.readers.grib.virtual import VirtualGribFieldList + + return VirtualGribFieldList(mapper, MarsFetcher(self.service())) + + return self + # ref = self.request[0] + # ref_path = _retrieve(self, ref) + # return VirtualGribFieldList(ref_path, requests, fetcher=MarsFetcher(self.service())) From 779fe61379d629abc8b7eb43d34e27de52340115 Mon Sep 17 00:00:00 2001 From: Sandor Kertesz Date: Fri, 11 Apr 2025 21:00:07 +0100 Subject: [PATCH 2/5] Implement virtual fields --- src/earthkit/data/readers/grib/virtual.py | 169 +++++++++++++++++++++ src/earthkit/data/sources/ecmwf_api.py | 172 ++-------------------- src/earthkit/data/sources/fdb.py | 61 +++++++- 3 files changed, 234 insertions(+), 168 deletions(-) create mode 100644 src/earthkit/data/readers/grib/virtual.py diff --git a/src/earthkit/data/readers/grib/virtual.py b/src/earthkit/data/readers/grib/virtual.py new file mode 100644 index 000000000..d4544e843 --- /dev/null +++ b/src/earthkit/data/readers/grib/virtual.py @@ -0,0 +1,169 @@ +# (C) Copyright 2020 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import logging +from functools import cached_property + +from earthkit.data import from_source +from earthkit.data.core.fieldlist import Field +from earthkit.data.core.metadata import WrappedMetadata +from earthkit.data.utils.dates import date_to_grib +from earthkit.data.utils.dates import datetime_from_grib +from earthkit.data.utils.dates import time_to_grib +from earthkit.data.utils.dates import to_timedelta + +from .index import GribFieldList + +LOG = logging.getLogger(__name__) + + +class VirtualGribField(Field): + def __init__(self, owner, request, md, reference=None): + self.owner = owner + self.request = request + self.md = md + self.reference = reference + + self.extra = {} + if "param" in request: + if "shortName" not in request: + self.extra = self.owner._get_info(self.request["param"]) + if "shortName" not in self.extra: + self.extra["shortName"] = request["param"] + else: + self.extra["param"] = self.extra["shortName"] + + @property + def _data_datetime(self): + if "date" in self.request and "time" in self.request: + return datetime_from_grib(self.request["date"], self.request["time"]) + return None + + @property + def _valid_datetime(self): + base = self._data_datetime + if base is None: + return None + if "step" in self.request: + step = self.request["step"] + return base + to_timedelta(step) + return base + + def _attributes(self, names, remapping=None, joiner=None, default=None): + # print("CALLED") + result = {} + metadata = self.metadata + if remapping is not None: + metadata = remapping(metadata, joiner=joiner) + + for name in names: + result[name] = metadata(name, default=default) + return result + + def metadata(self, *keys, astype=None, remapping=None, patches=None, **kwargs): + # print(f"metadata keys={keys} {kwargs=}") + if (not kwargs or kwargs == {"default": None}) and keys: + if isinstance(keys[0], (list, tuple)): + keys = keys[0] + # print(f" -> keys={keys}") + if keys and isinstance(keys[0], str): + r = [] + for k in keys: + r.append(self._one_metadata(k, astype=astype, remapping=remapping, patches=patches)) + if len(r) == 1: + return r[0] + return r + + return super().metadata( + *keys, + astype=astype, + remapping=remapping, + patches=patches, + **kwargs, + ) + + def _one_metadata(self, key, **kwargs): + # print(f"one_metadata key={key} kwargs={kwargs}") + if key in self.extra: + return self.extra[key] + if key in self.request: + return self.request[key] + if key in self.md and key in self.request: + return self.request[self.md[key]] + + if key == "number": + return 0 + if key == "validityDate": + return date_to_grib(self._valid_datetime) + if key == "validityTime": + return time_to_grib(self._valid_datetime) + if key in ("forecast_reference_time", "base_time", "base_datetime"): + # print("here") + return self._data_datetime.isoformat() + if key in ("valid_datetime", "valid_time"): + return self._valid_datetime.isoformat() + + return super().metadata(key, **kwargs) + + @property + def _metadata(self): + r = {**self.request, **self.extra} + for k, v in self.md.items(): + if k not in r and v in r: + r[k] = r[v] + + return WrappedMetadata(self.owner.reference._metadata, extra=r) + + def _values(self, dtype=None): + return self._field._values(dtype=dtype) + + @cached_property + def _field(self): + if self.reference: + return self.reference + else: + p = self.owner.retriever.get(self.request) + return from_source("file", p, stream=True, read_all=True)[0] + + +class VirtualGribFieldList(GribFieldList): + def __init__(self, request_mapper, retriever): + self.mapper = request_mapper + self.retriever = retriever + + path = self.retriever.get(self.mapper.request_at(0)) + self.reference = from_source("file", path)[0] + + def __len__(self): + return len(self.mapper) + + def mutate(self): + return self + + # def ls(self): + # return self.reference.ls() + + def _getitem(self, n): + if isinstance(n, int): + if n < 0: + n += len(self) + if n >= len(self): + raise IndexError(f"Index {n} out of range") + + return VirtualGribField( + self, self.mapper.request_at(n), self.mapper.md, reference=self.reference if n == 0 else None + ) + + def _get_info(self, param): + ref_request = self.mapper.request_at(0) + if param == ref_request.get("param"): + return self.reference._attributes(["shortName", "name", "units", "cfName"]) + else: + md = self.reference.metadata().override(paramId=param) + return {k: md.get(k, None) for k in ["shortName", "name", "units", "cfName"]} diff --git a/src/earthkit/data/sources/ecmwf_api.py b/src/earthkit/data/sources/ecmwf_api.py index 717dd8305..d0f781a40 100644 --- a/src/earthkit/data/sources/ecmwf_api.py +++ b/src/earthkit/data/sources/ecmwf_api.py @@ -44,116 +44,12 @@ class MARSAPIKeyPrompt(APIKeyPrompt): config_env = ("ECMWF_API_KEY", "ECMWF_API_URL") -class MarsFetcher: - def __init__(self, service): - self.service = service - - def get(self, request): - def retrieve(target, request): - self.service.execute(request, target) - - return self.cache_file( - retrieve, - request, - ) - - def cache_file(self, create, args, **kwargs): - import re - - from earthkit.data.core.caching import cache_file - - owner = kwargs.pop("owner", None) - if owner is None: - owner = re.sub(r"(?!^)([A-Z]+)", r"-\1", self.__class__.__name__).lower() - - return cache_file(owner, create, args, **kwargs) - - -class RequestMappaper: - defaults = { - "class": "oper", - "type": "an", - "stream": "da", - "expver": "1", - "param": "z", - "levtype": "pl", - "levelist": [1000, 850, 700, 500, 400, 300], - "time": 12, - "step": 0, - } - mandatory = ["date", "area", "grid"] - - def __init__(self, *args, request, **kwargs): - self.request = request - coords = {} - self.skipped = {} - - # res = [] - skip = ["grid", "area"] - for r in request: - # keys = [] - # vals = [] - for k, v in r.items(): - if k in skip: - self.skipped[k] = v - else: - if not isinstance(v, (list, tuple)): - v = [v] - if k not in coords: - coords[k] = v - else: - coords[k].append(v) - - self.coords = coords - self.shape = [len(v) for v in coords.values()] - - def coords_to_index(coords, shape) -> int: - """ - Map user coords to field index""" - index = 0 - n = 1 - for i in range(len(coords) - 1, -1, -1): - index += coords[i] * n - n *= shape[i] - return index - - def index_to_coords(self, index: int, shape): - assert isinstance(index, int), (index, type(index)) - - result = [None] * len(shape) - i = len(shape) - 1 - - while i >= 0: - result[i] = index % shape[i] - index = index // shape[i] - i -= 1 - - result = tuple(result) - - assert len(result) == len(shape) - return result - - def request_at(self, index): - idx = self.index_to_coords(index, self.shape) - r = {} - for i, key in enumerate(self.coords): - r[key] = self.coords[key][idx[i]] - r.update(self.skipped) - return r - - def __len__(self): - import math - - return math.prod(self.shape) - - class ECMWFApi(FileSource): - def __init__(self, *args, prompt=True, log="default", lazy=False, **kwargs): + def __init__(self, *args, prompt=True, log="default", **kwargs): super().__init__() self.prompt = prompt self.log = log - self.lazy = lazy request = {} for a in args: @@ -168,25 +64,20 @@ def __init__(self, *args, prompt=True, log="default", lazy=False, **kwargs): self.expect_any = True break - self.request = requests - self.service() # Trigger password prompt before threading - if lazy: - pass - else: - nthreads = min(self.config("number-of-download-threads"), len(requests)) + nthreads = min(self.config("number-of-download-threads"), len(requests)) - if nthreads < 2: - self.path = [self._retrieve(r) for r in requests] - else: - from earthkit.data.utils.progbar import tqdm + if nthreads < 2: + self.path = [self._retrieve(r) for r in requests] + else: + from earthkit.data.utils.progbar import tqdm - with SoftThreadPool(nthreads=nthreads) as pool: - futures = [pool.submit(self._retrieve, r) for r in requests] + with SoftThreadPool(nthreads=nthreads) as pool: + futures = [pool.submit(self._retrieve, r) for r in requests] - iterator = (f.result() for f in futures) - self.path = list(tqdm(iterator, leave=True, total=len(requests))) + iterator = (f.result() for f in futures) + self.path = list(tqdm(iterator, leave=True, total=len(requests))) def _retrieve(self, request): def retrieve(target, request): @@ -239,46 +130,3 @@ def empty_reader(self, *args, **kwargs): from .empty import EmptySource return EmptySource() - - def request_per_field(self, requests): - from itertools import product - - # print(requests) - - res = [] - skip = ["grid", "area"] - for r in requests: - keys = [] - vals = [] - skipped = {} - for k, v in r.items(): - if k in skip: - skipped[k] = v - else: - keys.append(k) - if isinstance(v, (list, tuple)): - vals.append(v) - else: - vals.append([v]) - - for v in product(*vals): - res.append(dict(zip(keys, v), **skipped)) - - return res - - def mutate(self): - if self.lazy: - print("lazy") - mapper = RequestMappaper(request=self.request) - print("coords", mapper.coords) - print("index[0]", mapper.index_to_coords(0, mapper.shape)) - print("index[0]", mapper.request_at(0)) - print("index[1]", mapper.request_at(1)) - from earthkit.data.readers.grib.virtual import VirtualGribFieldList - - return VirtualGribFieldList(mapper, MarsFetcher(self.service())) - - return self - # ref = self.request[0] - # ref_path = _retrieve(self, ref) - # return VirtualGribFieldList(ref_path, requests, fetcher=MarsFetcher(self.service())) diff --git a/src/earthkit/data/sources/fdb.py b/src/earthkit/data/sources/fdb.py index 3482ecfff..9c839cf79 100644 --- a/src/earthkit/data/sources/fdb.py +++ b/src/earthkit/data/sources/fdb.py @@ -10,6 +10,7 @@ import logging import os import shutil +from functools import cached_property try: import pyfdb @@ -25,13 +26,14 @@ class FDBSource(Source): - def __init__(self, *args, stream=True, config=None, userconfig=None, **kwargs): + def __init__(self, *args, stream=True, config=None, userconfig=None, lazy=False, **kwargs): super().__init__() for k in ["group_by", "batch_size"]: if k in kwargs: raise ValueError(f"Invalid argument '{k}' for FDBSource. Deprecated since 0.8.0.") + self.lazy = lazy self._fdb_kwargs = {} if config is not None: self._fdb_kwargs["config"] = config @@ -64,12 +66,19 @@ def _check_env(self): ) def mutate(self): - fdb = pyfdb.FDB(**self._fdb_kwargs) - if self.stream: - stream = fdb.retrieve(self.request) - return StreamSource(stream, **self._stream_kwargs) + if not self.lazy: + fdb = pyfdb.FDB(**self._fdb_kwargs) + if self.stream: + stream = fdb.retrieve(self.request) + return StreamSource(stream, **self._stream_kwargs) + else: + return FDBFileSource(fdb, self.request) else: - return FDBFileSource(fdb, self.request) + mapper = RequestMappaper(self._fdb_kwargs, self.request) + retriever = FdbRetriever(self._fdb_kwargs) + from earthkit.data.readers.grib.virtual import VirtualGribFieldList + + return VirtualGribFieldList(mapper, retriever) class FDBFileSource(FileSource): @@ -89,4 +98,44 @@ def retrieve(target, request): ) +class FdbRetriever: + def __init__(self, fdb_kwargs): + self.fdb_kwargs = fdb_kwargs + + def get(self, request): + fdb = pyfdb.FDB(**self.fdb_kwargs) + s = FDBFileSource(fdb, request) + return s.path + + +class RequestMappaper: + def __init__(self, fdb_kwargs, request, **kwargs): + self.fdb_kwargs = fdb_kwargs + self.request = request + self.md = { + "stepRange": "step", + "typeOfLevel": "leveltype", + "level": "levelist", + "dataDate": "date", + "dataTime": "time", + } + + @cached_property + def field_requests(self): + return self._scan() + + def _scan(self): + r = [] + fdb = pyfdb.FDB(**self.fdb_kwargs) + for el in fdb.list(self.request, True, True): + r.append(el["keys"]) + return r + + def request_at(self, index): + return self.field_requests[index] + + def __len__(self): + return len(self.field_requests) + + source = FDBSource From 36ed2cbd8b798e85a81a88cc9d511939b5944ba8 Mon Sep 17 00:00:00 2001 From: Sandor Kertesz Date: Tue, 22 Apr 2025 11:22:45 +0100 Subject: [PATCH 3/5] Add test --- src/earthkit/data/readers/grib/virtual.py | 8 +-- tests/lazy/test_lazy_fdb.py | 87 +++++++++++++++++++++++ 2 files changed, 90 insertions(+), 5 deletions(-) create mode 100644 tests/lazy/test_lazy_fdb.py diff --git a/src/earthkit/data/readers/grib/virtual.py b/src/earthkit/data/readers/grib/virtual.py index d4544e843..809b23604 100644 --- a/src/earthkit/data/readers/grib/virtual.py +++ b/src/earthkit/data/readers/grib/virtual.py @@ -88,7 +88,7 @@ def metadata(self, *keys, astype=None, remapping=None, patches=None, **kwargs): **kwargs, ) - def _one_metadata(self, key, **kwargs): + def _one_metadata(self, key, remapping=None, patches=None, **kwargs): # print(f"one_metadata key={key} kwargs={kwargs}") if key in self.extra: return self.extra[key] @@ -109,7 +109,8 @@ def _one_metadata(self, key, **kwargs): if key in ("valid_datetime", "valid_time"): return self._valid_datetime.isoformat() - return super().metadata(key, **kwargs) + return self._metadata.get(key, **kwargs) + # return super().one_metadata(key, **kwargs) @property def _metadata(self): @@ -146,9 +147,6 @@ def __len__(self): def mutate(self): return self - # def ls(self): - # return self.reference.ls() - def _getitem(self, n): if isinstance(n, int): if n < 0: diff --git a/tests/lazy/test_lazy_fdb.py b/tests/lazy/test_lazy_fdb.py new file mode 100644 index 000000000..7829496ad --- /dev/null +++ b/tests/lazy/test_lazy_fdb.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 + +# (C) Copyright 2020 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import os + +import numpy as np +import pytest + +from earthkit.data import from_source +from earthkit.data.core.temporary import temp_directory +from earthkit.data.testing import NO_FDB +from earthkit.data.testing import earthkit_test_data_file + +TEST_GRIB_REQUEST = { + "class": "od", + "expver": "0001", + "stream": "oper", + "date": [20240603, 20240604], + "time": [0, 1200], + "domain": "g", + "type": "fc", + "levtype": "pl", + "levelist": [500, 700], + "step": [0, 6], + "param": [130, 157], +} + + +def make_fdb_config(path): + fdb_schema = earthkit_test_data_file("fdb_schema.txt") + fdb_dir = path + os.makedirs(fdb_dir, exist_ok=True) + config = { + "type": "local", + "engine": "toc", + "schema": fdb_schema, + "spaces": [{"handler": "Default", "roots": [{"path": fdb_dir}]}], + } + return config + + +def make_fdb(path): + ds = from_source("sample", "pl.grib") + config = make_fdb_config(path) + ds.to_target("fdb", config=config) + return ds, config + + +@pytest.mark.skipif(NO_FDB, reason="No access to FDB") +@pytest.mark.cache +def test_lazy_fdb(): + with temp_directory() as tmpdir: + ds, config = make_fdb(os.path.join(tmpdir, "_fdb")) + + ds = from_source("fdb", TEST_GRIB_REQUEST, config=config, stream=False, lazy=True) + assert len(ds) == 32 + + assert ds[0].shape == (19, 36) + assert ds[1].shape == (19, 36) + assert ds[0].metadata(["shortName", "param", "units", "cfName"]) == ["t", "t", "K", "air_temperature"] + assert ds[1].metadata(["shortName", "param", "units", "cfName"]) == [ + "r", + "r", + "%", + "relative_humidity", + ] + + assert not hasattr(ds, "path") + assert not hasattr(ds[0], "path") + + a = ds.to_xarray(time_dim_mode="forecast") + assert a["t"].values.shape == (4, 2, 2, 19, 36) + assert a["r"].values.shape == (4, 2, 2, 19, 36) + + m = a.mean("step").load() + assert m["t"].values.shape == (4, 2, 19, 36) + assert m["r"].values.shape == (4, 2, 19, 36) + assert np.allclose(m["r"].values.flatten()[85:87], [47.66908598, 53.43959379]) + assert np.allclose(m["t"].values.flatten()[85:87], [253.22625732, 252.78778076]) From 0792b71d950a28cd9e286077d2ae5bbe430f42cf Mon Sep 17 00:00:00 2001 From: Sandor Kertesz Date: Tue, 6 May 2025 17:58:54 +0100 Subject: [PATCH 4/5] Implement virtual fields --- src/earthkit/data/readers/grib/virtual.py | 35 ++++++++++++++-------- src/earthkit/data/sources/fdb.py | 30 +++++++------------ src/earthkit/data/utils/request.py | 36 +++++++++++++++++++++++ 3 files changed, 68 insertions(+), 33 deletions(-) create mode 100644 src/earthkit/data/utils/request.py diff --git a/src/earthkit/data/readers/grib/virtual.py b/src/earthkit/data/readers/grib/virtual.py index 809b23604..e90da5331 100644 --- a/src/earthkit/data/readers/grib/virtual.py +++ b/src/earthkit/data/readers/grib/virtual.py @@ -24,10 +24,10 @@ class VirtualGribField(Field): - def __init__(self, owner, request, md, reference=None): + def __init__(self, owner, request, metadata_alias, reference=None): self.owner = owner self.request = request - self.md = md + self.metadata_alias = metadata_alias self.reference = reference self.extra = {} @@ -94,8 +94,8 @@ def _one_metadata(self, key, remapping=None, patches=None, **kwargs): return self.extra[key] if key in self.request: return self.request[key] - if key in self.md and key in self.request: - return self.request[self.md[key]] + if key in self.metadata_alias and key in self.request: + return self.request[self.metadata_alias[key]] if key == "number": return 0 @@ -110,12 +110,11 @@ def _one_metadata(self, key, remapping=None, patches=None, **kwargs): return self._valid_datetime.isoformat() return self._metadata.get(key, **kwargs) - # return super().one_metadata(key, **kwargs) @property def _metadata(self): r = {**self.request, **self.extra} - for k, v in self.md.items(): + for k, v in self.metadata_alias.items(): if k not in r and v in r: r[k] = r[v] @@ -135,14 +134,15 @@ def _field(self): class VirtualGribFieldList(GribFieldList): def __init__(self, request_mapper, retriever): - self.mapper = request_mapper + self.request_mapper = request_mapper self.retriever = retriever - path = self.retriever.get(self.mapper.request_at(0)) + path = self.retriever.get(self.request_mapper.request_at(0)) self.reference = from_source("file", path)[0] + self._info_cache = {} def __len__(self): - return len(self.mapper) + return len(self.request_mapper) def mutate(self): return self @@ -155,13 +155,22 @@ def _getitem(self, n): raise IndexError(f"Index {n} out of range") return VirtualGribField( - self, self.mapper.request_at(n), self.mapper.md, reference=self.reference if n == 0 else None + self, + self.request_mapper.request_at(n), + self.request_mapper.metadata_alias, + reference=self.reference if n == 0 else None, ) def _get_info(self, param): - ref_request = self.mapper.request_at(0) + if param in self._info_cache: + return self._info_cache[param] + + ref_request = self.request_mapper.request_at(0) if param == ref_request.get("param"): - return self.reference._attributes(["shortName", "name", "units", "cfName"]) + r = self.reference._attributes(["shortName", "name", "units", "cfName"]) else: md = self.reference.metadata().override(paramId=param) - return {k: md.get(k, None) for k in ["shortName", "name", "units", "cfName"]} + r = {k: md.get(k, None) for k in ["shortName", "name", "units", "cfName"]} + + self._info_cache[param] = r + return r diff --git a/src/earthkit/data/sources/fdb.py b/src/earthkit/data/sources/fdb.py index 9c839cf79..2862311c2 100644 --- a/src/earthkit/data/sources/fdb.py +++ b/src/earthkit/data/sources/fdb.py @@ -10,7 +10,6 @@ import logging import os import shutil -from functools import cached_property try: import pyfdb @@ -19,6 +18,7 @@ from earthkit.data.sources.file import FileSource from earthkit.data.sources.stream import StreamSource +from earthkit.data.utils.request import RequestMapper from . import Source @@ -74,8 +74,8 @@ def mutate(self): else: return FDBFileSource(fdb, self.request) else: - mapper = RequestMappaper(self._fdb_kwargs, self.request) - retriever = FdbRetriever(self._fdb_kwargs) + mapper = FDBRequestMapper(self.request, fdb_kwargs=self._fdb_kwargs) + retriever = FDBRetriever(self._fdb_kwargs) from earthkit.data.readers.grib.virtual import VirtualGribFieldList return VirtualGribFieldList(mapper, retriever) @@ -98,7 +98,7 @@ def retrieve(target, request): ) -class FdbRetriever: +class FDBRetriever: def __init__(self, fdb_kwargs): self.fdb_kwargs = fdb_kwargs @@ -108,11 +108,11 @@ def get(self, request): return s.path -class RequestMappaper: - def __init__(self, fdb_kwargs, request, **kwargs): - self.fdb_kwargs = fdb_kwargs - self.request = request - self.md = { +class FDBRequestMapper(RequestMapper): + def __init__(self, request, fdb_kwargs=None, **kwargs): + super().__init__(request, **kwargs) + self.fdb_kwargs = fdb_kwargs or {} + self.metadata_alias = { "stepRange": "step", "typeOfLevel": "leveltype", "level": "levelist", @@ -120,22 +120,12 @@ def __init__(self, fdb_kwargs, request, **kwargs): "dataTime": "time", } - @cached_property - def field_requests(self): - return self._scan() - - def _scan(self): + def _build(self): r = [] fdb = pyfdb.FDB(**self.fdb_kwargs) for el in fdb.list(self.request, True, True): r.append(el["keys"]) return r - def request_at(self, index): - return self.field_requests[index] - - def __len__(self): - return len(self.field_requests) - source = FDBSource diff --git a/src/earthkit/data/utils/request.py b/src/earthkit/data/utils/request.py new file mode 100644 index 000000000..9700854d8 --- /dev/null +++ b/src/earthkit/data/utils/request.py @@ -0,0 +1,36 @@ +# (C) Copyright 2022 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import logging +from abc import ABCMeta +from abc import abstractmethod +from functools import cached_property + +LOG = logging.getLogger(__name__) + + +class RequestMapper(metaclass=ABCMeta): + metadata_alias = None + + def __init__(self, request, **kwargs): + self.request = request + + @cached_property + def field_requests(self): + return self._build() + + @abstractmethod + def _build(self): + pass + + def request_at(self, index): + return self.field_requests[index] + + def __len__(self): + return len(self.field_requests) From 9c5a256dcd15d53868c3101d477c9defb4df7165 Mon Sep 17 00:00:00 2001 From: Sandor Kertesz Date: Tue, 6 May 2025 18:00:31 +0100 Subject: [PATCH 5/5] Implement virtual fields --- src/earthkit/data/readers/grib/virtual.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/earthkit/data/readers/grib/virtual.py b/src/earthkit/data/readers/grib/virtual.py index e90da5331..21a290953 100644 --- a/src/earthkit/data/readers/grib/virtual.py +++ b/src/earthkit/data/readers/grib/virtual.py @@ -56,7 +56,6 @@ def _valid_datetime(self): return base def _attributes(self, names, remapping=None, joiner=None, default=None): - # print("CALLED") result = {} metadata = self.metadata if remapping is not None: @@ -67,11 +66,9 @@ def _attributes(self, names, remapping=None, joiner=None, default=None): return result def metadata(self, *keys, astype=None, remapping=None, patches=None, **kwargs): - # print(f"metadata keys={keys} {kwargs=}") if (not kwargs or kwargs == {"default": None}) and keys: if isinstance(keys[0], (list, tuple)): keys = keys[0] - # print(f" -> keys={keys}") if keys and isinstance(keys[0], str): r = [] for k in keys: