diff --git a/.github/ci/recipe.yaml b/.github/ci/recipe.yaml index 2a887e8b0..388c60b2d 100644 --- a/.github/ci/recipe.yaml +++ b/.github/ci/recipe.yaml @@ -36,6 +36,8 @@ requirements: - numpy >=2.1.0 - tqdm >=4.50.0 - xarray >=2025.8.0,<2026.4.0 # TODO: remove upper pin when https://github.com/UXARRAY/uxarray/issues/1490 is resolved + - pandas >=2.2 + - pyarrow >=20.0.0 - cf_xarray >=0.8.6 - xgcm >=0.9.0 - zarr >=2.15.0,!=2.18.0,<3 diff --git a/.gitignore b/.gitignore index 45b79435b..6f07f7d82 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ out-* *.pyc **/*.zarr/* .DS_Store +*.parquet .vscode .env diff --git a/docs/user_guide/v4-migration.md b/docs/user_guide/v4-migration.md index 7e33be83c..0e06bbc7b 100644 --- a/docs/user_guide/v4-migration.md +++ b/docs/user_guide/v4-migration.md @@ -36,7 +36,9 @@ Version 4 of Parcels is unreleased at the moment. The information in this migrat ## ParticleFile - Particlefiles should be created by `ParticleFile(...)` instead of `pset.ParticleFile(...)` -- The `name` argument in `ParticleFile` has been replaced by `store` and can now be a string, a Path or a zarr store. +- `ParticleFile` output is now in Parquet format +- `ParticleFile` writing behaviour now errors out if there's existing output (this is being further discussed in https://github.com/Parcels-code/Parcels/issues/2593 ) +- A utility to read in ParticleFile output is now available. 
`parcels.read_particlefile()` ## Field diff --git a/pixi.toml b/pixi.toml index a71be98ff..286d7f28c 100644 --- a/pixi.toml +++ b/pixi.toml @@ -24,6 +24,8 @@ netcdf4 = ">=1.6.0" numpy = ">=2.1.0" tqdm = ">=4.50.0" xarray = ">=2024.5.0,<2026.4.0" # TODO: remove upper pin when https://github.com/UXARRAY/uxarray/issues/1490 is resolved +pandas = ">=2.2" +pyarrow = ">=20.0.0" holoviews = ">=1.22.0" # https://github.com/prefix-dev/rattler-build/issues/2326 uxarray = ">=2025.3.0" dask = ">=2024.5.1" @@ -51,6 +53,8 @@ netcdf4 = "1.6.*" numpy = "2.1.*" tqdm = "4.50.*" xarray = "2025.8.*" +pandas = "2.2.*" +pyarrow = "20.0.*" uxarray = "2025.3.*" dask = "2024.6.*" zarr = "2.18.*" diff --git a/pyproject.toml b/pyproject.toml index 85aba3a67..072da3b77 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,8 @@ dependencies = [ "zarr >=2.15.0,!=2.18.0,<3", "tqdm >=4.50.0", "xarray >=2024.5.0,<2026.4.0", # TODO: remove upper pin when https://github.com/UXARRAY/uxarray/issues/1490 is resolved + "pandas >= 2.2", + "pyarrow >=20.0.0", "uxarray >=2025.3.0", "pooch >=1.8.0", "xgcm >=0.9.0", diff --git a/src/parcels/__init__.py b/src/parcels/__init__.py index 2a7854cde..7ae1f6928 100644 --- a/src/parcels/__init__.py +++ b/src/parcels/__init__.py @@ -11,7 +11,7 @@ from parcels._core.fieldset import FieldSet from parcels._core.particleset import ParticleSet -from parcels._core.particlefile import ParticleFile +from parcels._core.particlefile import ParticleFile, read_particlefile from parcels._core.particle import ( Variable, Particle, @@ -67,6 +67,7 @@ "ParticleSetWarning", # Utilities "logger", + "read_particlefile", ] _stdlib_warnings.warn( diff --git a/src/parcels/_core/particle.py b/src/parcels/_core/particle.py index dc39a063c..15c55a519 100644 --- a/src/parcels/_core/particle.py +++ b/src/parcels/_core/particle.py @@ -8,7 +8,6 @@ from parcels._compat import _attrgetter_helper from parcels._core.statuscodes import StatusCode from parcels._core.utils.string import 
_assert_str_and_python_varname -from parcels._core.utils.time import TimeInterval from parcels._reprs import particleclass_repr, variable_repr __all__ = ["Particle", "ParticleClass", "Variable"] @@ -149,7 +148,11 @@ def get_default_particle(spatial_dtype: type[np.float32] | type[np.float64]) -> Variable( "time", dtype=np.float64, - attrs={"standard_name": "time", "units": "seconds", "axis": "T"}, + attrs={ + "standard_name": "time", + "units": "seconds", + "axis": "T", + }, # "units" and "calendar" gets updated/set if working with cftime time domain ), Variable( "trajectory", @@ -160,7 +163,6 @@ def get_default_particle(spatial_dtype: type[np.float32] | type[np.float64]) -> "cf_role": "trajectory_id", }, ), - Variable("obs_written", dtype=np.int32, initial=0, to_write=False), Variable("dt", dtype=np.float64, initial=1.0, to_write=False), Variable("state", dtype=np.int32, initial=StatusCode.Evaluate, to_write=False), ] @@ -176,7 +178,6 @@ def create_particle_data( pclass: ParticleClass, nparticles: int, ngrids: int, - time_interval: TimeInterval, initial: dict[str, np.ndarray] | None = None, ): if initial is None: @@ -207,16 +208,9 @@ def create_particle_data( name_to_copy = var.initial(_attrgetter_helper) data[var.name] = data[name_to_copy].copy() else: - data[var.name] = _create_array_for_variable(var, nparticles, time_interval) + data[var.name] = np.full( + shape=(nparticles,), + fill_value=var.initial, + dtype=var.dtype, + ) return data - - -def _create_array_for_variable(variable: Variable, nparticles: int, time_interval: TimeInterval): - assert not isinstance(variable.initial, operator.attrgetter), ( - "This function cannot handle attrgetter initial values." 
- ) - return np.full( - shape=(nparticles,), - fill_value=variable.initial, - dtype=variable.dtype, - ) diff --git a/src/parcels/_core/particlefile.py b/src/parcels/_core/particlefile.py index 788c6e572..90384f356 100644 --- a/src/parcels/_core/particlefile.py +++ b/src/parcels/_core/particlefile.py @@ -2,21 +2,21 @@ from __future__ import annotations -import os -from datetime import datetime, timedelta +from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Any, Literal -import cftime import numpy as np +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq import xarray as xr -import zarr -from zarr.storage import DirectoryStore import parcels from parcels._core.particle import ParticleClass from parcels._core.utils.time import timedelta_to_float from parcels._reprs import particlefile_repr +from parcels._typing import PathLike if TYPE_CHECKING: from parcels._core.particle import Variable @@ -25,20 +25,25 @@ __all__ = ["ParticleFile"] -_DATATYPES_TO_FILL_VALUES = { - np.dtype(np.float16): np.nan, - np.dtype(np.float32): np.nan, - np.dtype(np.float64): np.nan, - np.dtype(np.bool_): np.iinfo(np.int8).max, - np.dtype(np.int8): np.iinfo(np.int8).max, - np.dtype(np.int16): np.iinfo(np.int16).max, - np.dtype(np.int32): np.iinfo(np.int32).max, - np.dtype(np.int64): np.iinfo(np.int64).min, - np.dtype(np.uint8): np.iinfo(np.uint8).max, - np.dtype(np.uint16): np.iinfo(np.uint16).max, - np.dtype(np.uint32): np.iinfo(np.uint32).max, - np.dtype(np.uint64): np.iinfo(np.uint64).max, -} + +def _get_schema( + particle: parcels.ParticleClass, file_metadata: dict[Any, Any], fset_time_interval: TimeInterval | None +) -> pa.Schema: + + fields = [] + for v in _get_vars_to_write(particle): + attrs = v.attrs.copy() + if v.name == "time": + if fset_time_interval is not None: + attrs.update(fset_time_interval.get_cf_attrs()) + fields.append( + pa.field( + v.name, + 
pa.from_numpy_dtype(v.dtype), + metadata=attrs, + ) + ) + return pa.schema(fields, metadata=file_metadata.copy()) class ParticleFile: @@ -46,18 +51,12 @@ class ParticleFile: Parameters ---------- - name : str - Basename of the output file. This can also be a Zarr store object. - particleset : - ParticleSet to output + path : PathLike + Path of the output Parquet file. outputdt : Interval which dictates the update frequency of file output while ParticleFile is given as an argument of ParticleSet.execute() It is either a numpy.timedelta64, a datimetime.timedelta object or a positive float (in seconds). - chunks : - Tuple (trajs, obs) to control the size of chunks in the zarr output. - create_new_zarrfile : bool - Whether to create a new file. Default is True Returns ------- @@ -65,36 +64,34 @@ class ParticleFile: ParticleFile object that can be used to write particle data to file """ - def __init__(self, store, outputdt, chunks=None, create_new_zarrfile=True): + def __init__(self, path: PathLike, outputdt): if not isinstance(outputdt, (np.timedelta64, timedelta, float)): raise ValueError( f"Expected outputdt to be a np.timedelta64, datetime.timedelta or float (in seconds), got {type(outputdt)}" ) outputdt = timedelta_to_float(outputdt) + path = Path(path) + + if path.suffix != ".parquet": + raise ValueError( + f"ParticleFile data is stored in Parquet files - file extension must be '.parquet'. Got {path.suffix=!r}." + ) if outputdt <= 0: raise ValueError(f"outputdt must be positive/non-zero. 
Got {outputdt=!r}") self._outputdt = outputdt - _assert_valid_chunks_tuple(chunks) - self._chunks = chunks - self._maxids = 0 - self._pids_written = {} - self.metadata = {} - self._create_new_zarrfile = create_new_zarrfile - - if not isinstance(store, zarr.storage.Store): - store = _get_store_from_pathlike(store) - - self._store = store + self._path = path # TODO v4: Consider https://arrow.apache.org/docs/python/getstarted.html#working-with-large-data - though a significant question becomes how to partition, perhaps using a particle variable "partition"? + self._writer: pq.ParquetWriter | None = None + if path.exists(): + # TODO: Add logic for recovering/appending to existing parquet file + raise ValueError(f"{path=!r} already exists. Either delete this file or use a path that doesn't exist.") + if not path.parent.exists(): + raise ValueError(f"Folder location for {path=!r} does not exist. Create the folder location first.") - # TODO v4: Enable once updating to zarr v3 - # if store.read_only: - # raise ValueError(f"Store {store} is read-only. 
Please provide a writable store.") - - # TODO v4: Add check that if create_new_zarrfile is False, the store already exists + self.metadata = {} def __repr__(self) -> str: return particlefile_repr(self) @@ -115,31 +112,8 @@ def outputdt(self): return self._outputdt @property - def chunks(self): - return self._chunks - - @property - def store(self): - return self._store - - @property - def create_new_zarrfile(self): - return self._create_new_zarrfile - - def _extend_zarr_dims(self, Z, store, dtype, axis): # noqa: N803 - if axis == 1: - a = np.full((Z.shape[0], self.chunks[1]), _DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype) - obs = zarr.group(store=store, overwrite=False)["obs"] - if len(obs) == Z.shape[1]: - obs.append(np.arange(self.chunks[1]) + obs[-1] + 1) - else: - extra_trajs = self._maxids - Z.shape[0] - if len(Z.shape) == 2: - a = np.full((extra_trajs, Z.shape[1]), _DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype) - else: - a = np.full((extra_trajs,), _DATATYPES_TO_FILL_VALUES[dtype], dtype=dtype) - Z.append(a, axis=axis) - zarr.consolidate_metadata(store) + def path(self): + return self._path def write(self, pset: ParticleSet, time, indices=None): """Write all data from one time step to the zarr file, @@ -156,124 +130,32 @@ def write(self, pset: ParticleSet, time, indices=None): time_interval = pset.fieldset.time_interval particle_data = pset._data - self._write_particle_data( - particle_data=particle_data, pclass=pclass, time_interval=time_interval, time=time, indices=indices - ) + if self._writer is None: + assert not self.path.exists(), "If the file exists, the writer should already be set" + self._writer = pq.ParquetWriter(self.path, _get_schema(pclass, self.metadata, pset.fieldset.time_interval)) - def _write_particle_data(self, *, particle_data, pclass, time_interval, time, indices=None): - # if pset._data._ncount == 0: - # warnings.warn( - # f"ParticleSet is empty on writing as array at time {time:g}", - # RuntimeWarning, - # stacklevel=2, - # ) - # 
return if isinstance(time, (np.timedelta64, np.datetime64)): time = timedelta_to_float(time - time_interval.left) - nparticles = len(particle_data["trajectory"]) vars_to_write = _get_vars_to_write(pclass) if indices is None: indices_to_write = _to_write_particles(particle_data, time) else: indices_to_write = indices - if len(indices_to_write) == 0: - return - - pids = particle_data["trajectory"][indices_to_write] - to_add = sorted(set(pids) - set(self._pids_written.keys())) - for i, pid in enumerate(to_add): - self._pids_written[pid] = self._maxids + i - ids = np.array([self._pids_written[p] for p in pids], dtype=int) - self._maxids = len(self._pids_written) - - once_ids = np.where(particle_data["obs_written"][indices_to_write] == 0)[0] - if len(once_ids) > 0: - ids_once = ids[once_ids] - indices_to_write_once = indices_to_write[once_ids] - - store = self.store - if self.create_new_zarrfile: - if self.chunks is None: - self._chunks = (nparticles, 1) - if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]): - arrsize = (self._maxids, self.chunks[1]) - else: - arrsize = (len(ids), self.chunks[1]) - ds = xr.Dataset( - attrs=self.metadata, - coords={"trajectory": ("trajectory", pids), "obs": ("obs", np.arange(arrsize[1], dtype=np.int32))}, - ) - attrs = _create_variables_attribute_dict(pclass, time_interval) - obs = np.zeros((self._maxids), dtype=np.int32) - for var in vars_to_write: - if var.name not in ["trajectory"]: # because 'trajectory' is written as coordinate - if var.to_write == "once": - data = np.full( - (arrsize[0],), - _DATATYPES_TO_FILL_VALUES[var.dtype], - dtype=var.dtype, - ) - data[ids_once] = particle_data[var.name][indices_to_write_once] - dims = ["trajectory"] - else: - data = np.full(arrsize, _DATATYPES_TO_FILL_VALUES[var.dtype], dtype=var.dtype) - data[ids, 0] = particle_data[var.name][indices_to_write] - dims = ["trajectory", "obs"] - ds[var.name] = xr.DataArray(data=data, dims=dims, attrs=attrs[var.name]) - 
ds[var.name].encoding["chunks"] = self.chunks[0] if var.to_write == "once" else self.chunks - ds.to_zarr(store, mode="w") - self._create_new_zarrfile = False - else: - Z = zarr.group(store=store, overwrite=False) - obs = particle_data["obs_written"][indices_to_write] - for var in vars_to_write: - if self._maxids > Z[var.name].shape[0]: - self._extend_zarr_dims(Z[var.name], store, dtype=var.dtype, axis=0) - if var.to_write == "once": - if len(once_ids) > 0: - Z[var.name].vindex[ids_once] = particle_data[var.name][indices_to_write_once] - else: - if max(obs) >= Z[var.name].shape[1]: - self._extend_zarr_dims(Z[var.name], store, dtype=var.dtype, axis=1) - Z[var.name].vindex[ids, obs] = particle_data[var.name][indices_to_write] - - particle_data["obs_written"][indices_to_write] = obs + 1 - - -def _get_store_from_pathlike(path: Path | str) -> DirectoryStore: - path = str(Path(path)) # Ensure valid path, and convert to string - extension = os.path.splitext(path)[1] - if extension != ".zarr": - raise ValueError(f"ParticleFile name must end with '.zarr' extension. Got path {path!r}.") + self._writer.write_table( + pa.table({v.name: pa.array(particle_data[v.name][indices_to_write]) for v in vars_to_write}), + ) - return DirectoryStore(path) + def close(self): + if self._writer is not None: + self._writer.close() + self._writer = None def _get_vars_to_write(particle: ParticleClass) -> list[Variable]: return [v for v in particle.variables if v.to_write is not False] -def _create_variables_attribute_dict(particle: ParticleClass, time_interval: TimeInterval) -> dict: - """Creates the dictionary with variable attributes. - - Notes - ----- - For ParticleSet structures other than SoA, and structures where ID != index, this has to be overridden. 
- """ - attrs = {} - - vars = [var for var in particle.variables if var.to_write is not False] - for var in vars: - fill_value = {"_FillValue": _DATATYPES_TO_FILL_VALUES[var.dtype]} - - attrs[var.name] = {**var.attrs, **fill_value} - - attrs["time"].update(_get_calendar_and_units(time_interval)) - - return attrs - - def _to_write_particles(particle_data, time): """Return the Particles that need to be written at time: if particle.time is between time-dt/2 and time+dt (/2)""" return np.where( @@ -282,15 +164,17 @@ def _to_write_particles(particle_data, time): time - np.abs(particle_data["dt"] / 2), particle_data["time"], where=np.isfinite(particle_data["time"]), + out=None, ) & np.greater_equal( time + np.abs(particle_data["dt"] / 2), particle_data["time"], where=np.isfinite(particle_data["time"]), + out=None, ) # check time - dt/2 <= particle_data["time"] <= time + dt/2 | ( (np.isnan(particle_data["dt"])) - & np.equal(time, particle_data["time"], where=np.isfinite(particle_data["time"])) + & np.equal(time, particle_data["time"], where=np.isfinite(particle_data["time"]), out=None) ) # or dt is NaN and time matches particle_data["time"] ) & (np.isfinite(particle_data["trajectory"])) @@ -298,33 +182,64 @@ def _to_write_particles(particle_data, time): )[0] -def _get_calendar_and_units(time_interval: TimeInterval) -> dict[str, str]: - calendar = None - units = "seconds" - if time_interval: - if isinstance(time_interval.left, (np.datetime64, datetime)): - calendar = "standard" - elif isinstance(time_interval.left, cftime.datetime): - calendar = time_interval.left.calendar +def read_particlefile(path: PathLike, decode_times: bool = True) -> pd.DataFrame: + """Read a Parcels particlefile (Parquet format) into a pandas DataFrame. + + Parameters + ---------- + path : PathLike + Path to the ``.parquet`` particlefile. 
+ decode_times : bool, optional + If ``True`` (default), use Xarray to decode the numeric ``time`` column from CF + conventions into ``datetime`` or ``cftime.datetime`` values using the units stored in + the column metadata. If ``False``, the raw numeric values are + returned unchanged. + + Returns + ------- + pd.DataFrame + DataFrame containing the particle data. When *decode_times* is + ``True``, the ``time`` column contains datetime-like values; + otherwise it contains the original numeric representation. + + Notes + ----- + For larger datasets, consider using `Polars `_ directly, + e.g. ``polars.read_parquet(path)``, which offers better performance and lower + memory usage than pandas for large Parquet files. + """ + path = Path(path) + + assert path.suffix == ".parquet", "Only Parquet files are supported" + + table = pq.read_table(path) + + try: + time_field = table.field("time") + except KeyError as e: + raise ValueError( + f"Could not find 'time' column in parquet file. Are you sure {path=!r} is a particlefile?" 
+ ) from e + + assert pa.types.is_floating(time_field.type) or pa.types.is_integer(time_field.type), ( + f"'time' column must be numeric, got {time_field.type}" + ) - if calendar is not None: - units += f" since {time_interval.left}" + try: + assert b"units" in time_field.metadata + except AssertionError as e: + raise ValueError(f"Could not find 'units' in the 'time' column metadata for parquet {path=!r}.") from e - attrs = {"units": units} - if calendar is not None: - attrs["calendar"] = calendar + attrs = {k.decode(): v.decode() for k, v in time_field.metadata.items()} - return attrs + df = pd.read_parquet(path) + if not decode_times: + return df + values = table.column("time").to_numpy() + var = xr.Variable(("time",), values, attrs) + values = xr.coders.CFDatetimeCoder(time_unit="s").decode(var).values -def _assert_valid_chunks_tuple(chunks): - e = ValueError(f"chunks must be a tuple of integers with length 2, got {chunks=!r} instead.") - if chunks is None: - return + df["time"] = values - if not isinstance(chunks, tuple): - raise e - if len(chunks) != 2: - raise e - if not all(isinstance(c, int) for c in chunks): - raise e + return df diff --git a/src/parcels/_core/particleset.py b/src/parcels/_core/particleset.py index 5483ffbe4..b25c269a6 100644 --- a/src/parcels/_core/particleset.py +++ b/src/parcels/_core/particleset.py @@ -20,7 +20,7 @@ ) from parcels._core.warnings import ParticleSetWarning from parcels._logger import logger -from parcels._reprs import _format_zarr_output_location, particleset_repr +from parcels._reprs import particleset_repr __all__ = ["ParticleSet"] @@ -111,7 +111,6 @@ def __init__( pclass=pclass, nparticles=lon.size, ngrids=len(fieldset.gridset), - time_interval=fieldset.time_interval, initial=dict( lon=lon, lat=lat, @@ -415,7 +414,7 @@ def execute( # Set up pbar if output_file: - logger.info(f"Output files are stored in {_format_zarr_output_location(output_file.store)}") + logger.info(f"Output files are stored in {output_file.path}") 
if verbose_progress: pbar = tqdm(total=end_time - start_time, file=sys.stdout) @@ -451,6 +450,9 @@ def execute( time = next_time + if output_file is not None: + output_file.close() + if verbose_progress: pbar.close() diff --git a/src/parcels/_core/utils/time.py b/src/parcels/_core/utils/time.py index b76473a3f..bdce7fc13 100644 --- a/src/parcels/_core/utils/time.py +++ b/src/parcels/_core/utils/time.py @@ -1,7 +1,7 @@ from __future__ import annotations from datetime import datetime, timedelta -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Literal, TypeVar, cast import cftime import numpy as np @@ -85,6 +85,39 @@ def intersection(self, other: TimeInterval) -> TimeInterval | None: return TimeInterval(start, end) if start <= end else None + def get_cf_attrs(self) -> dict[Literal["units", "calendar"], str]: + """Return the cf-attrs that would correspond to x seconds from the left edge.""" + return _get_cf_attrs(self.left) + + +def _get_cf_attrs(dt: TimeLike) -> dict[Literal["units", "calendar"], str]: + if isinstance(dt, cftime.datetime): + dt = cast(cftime.datetime, dt) + return {"units": f"seconds since {dt.strftime(dt.format)}", "calendar": dt.calendar} + + if isinstance(dt, np.timedelta64): + return {"units": "seconds"} + + from pandas import Timestamp + + if isinstance(dt, np.datetime64): + dt = Timestamp(dt) + + if isinstance(dt, (Timestamp, datetime)): + dt_cf = cftime.datetime( + year=dt.year, + month=dt.month, + day=dt.day, + hour=dt.hour, + minute=dt.minute, + second=dt.second, + microsecond=dt.microsecond, + calendar="gregorian", # What is the cftime proleptic_gregorian calendar? is that relevant here? 
+ ) + return _get_cf_attrs(dt_cf) + + raise NotImplementedError(f"Not implemented for time object {type(dt)=!r}") + def is_compatible( t1: datetime | cftime.datetime | np.timedelta64, t2: datetime | cftime.datetime | np.timedelta64 diff --git a/src/parcels/_reprs.py b/src/parcels/_reprs.py index ad6d0cca2..d27eee379 100644 --- a/src/parcels/_reprs.py +++ b/src/parcels/_reprs.py @@ -7,7 +7,6 @@ import numpy as np import xarray as xr -from zarr.storage import DirectoryStore if TYPE_CHECKING: from parcels import Field, FieldSet, ParticleSet @@ -128,7 +127,7 @@ def timeinterval_repr(ti: Any) -> str: def particlefile_repr(pfile: Any) -> str: """Return a pretty repr for ParticleFile""" out = f"""<{type(pfile).__name__}> - store : {_format_zarr_output_location(pfile.store)} + path : {pfile.path} outputdt : {pfile.outputdt!r} chunks : {pfile.chunks!r} create_new_zarrfile : {pfile.create_new_zarrfile!r} @@ -178,11 +177,5 @@ def _format_list_items_multiline(items: list[str] | dict, level: int = 1, with_b return "\n".join([textwrap.indent(e, indentation_str) for e in entries]) -def _format_zarr_output_location(zarr_obj): - if isinstance(zarr_obj, DirectoryStore): - return zarr_obj.path - return repr(zarr_obj) - - def is_builtin_object(obj): return obj.__class__.__module__ == "builtins" diff --git a/tests/conftest.py b/tests/conftest.py index 82020c37e..0fd949880 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,6 @@ import pytest -from zarr.storage import MemoryStore - - -@pytest.fixture() -def tmp_zarrfile(tmp_path, request): - test_name = request.node.name - yield tmp_path / f"{test_name}-output.zarr" @pytest.fixture -def tmp_store(): - return MemoryStore() +def tmp_parquet(tmp_path): + return tmp_path / "tmp.parquet" diff --git a/tests/test_advection.py b/tests/test_advection.py index d8c6d2a45..e5c000132 100644 --- a/tests/test_advection.py +++ b/tests/test_advection.py @@ -1,4 +1,5 @@ import numpy as np +import pandas as pd import pytest import xarray as 
xr @@ -36,7 +37,7 @@ AdvectionRK4_3D, AdvectionRK45, ) -from tests.utils import DEFAULT_PARTICLES +from tests.utils import DEFAULT_PARTICLES, assert_cftime_like_particlefile @pytest.mark.parametrize("mesh", ["spherical", "flat"]) @@ -60,7 +61,7 @@ def test_advection_zonal(mesh, npart=10): np.testing.assert_allclose(pset.lat, startlat, atol=1e-5) -def test_advection_zonal_with_particlefile(tmp_store): +def test_advection_zonal_with_particlefile(tmp_parquet): """Particles at high latitude move geographically faster due to the pole correction.""" npart = 10 ds = simple_UV_dataset(mesh="flat") @@ -68,12 +69,14 @@ def test_advection_zonal_with_particlefile(tmp_store): fieldset = FieldSet.from_sgrid_conventions(ds, mesh="flat") pset = ParticleSet(fieldset, lon=np.zeros(npart) + 20.0, lat=np.linspace(0, 80, npart)) - pfile = ParticleFile(tmp_store, outputdt=np.timedelta64(30, "m")) + pfile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(30, "m")) pset.execute(AdvectionRK4, runtime=np.timedelta64(2, "h"), dt=np.timedelta64(15, "m"), output_file=pfile) assert (np.diff(pset.lon) < 1.0e-4).all() - ds = xr.open_zarr(tmp_store) - np.testing.assert_allclose(ds.isel(obs=-1).lon.values, pset.lon) + df = pd.read_parquet(tmp_parquet) + final_time = df["time"].max() + np.testing.assert_allclose(df[df["time"] == final_time]["lon"].values, pset.lon, atol=1e-5) + assert_cftime_like_particlefile(tmp_parquet) def periodicBC(particles, fieldset): diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index b2b05d33f..6eeef20a6 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -3,6 +3,7 @@ import cf_xarray # noqa: F401 import cftime import numpy as np +import pandas as pd import pytest import xarray as xr @@ -95,7 +96,7 @@ def test_fieldset_gridset(fieldset): assert len(fieldset.gridset) == 2 -def test_fieldset_no_UV(tmp_zarrfile): +def test_fieldset_no_UV(tmp_parquet): grid = XGrid.from_dataset(ds, mesh="flat") fieldset = FieldSet([Field("P", ds["U_A_grid"], 
grid, interp_method=XLinear)]) @@ -103,11 +104,11 @@ def SampleP(particles, fieldset): particles.dlon += fieldset.P[particles] pset = ParticleSet(fieldset, lon=0, lat=0) - ofile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")) + ofile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s")) pset.execute(SampleP, runtime=np.timedelta64(1, "s"), dt=np.timedelta64(1, "s"), output_file=ofile) - ds_out = xr.open_zarr(tmp_zarrfile) - assert ds_out["lon"].shape == (1, 2) + df = pd.read_parquet(tmp_parquet) + assert len(df["lon"]) == 2 @pytest.mark.parametrize("ds", [pytest.param(ds, id=k) for k, ds in datasets_structured.items()]) diff --git a/tests/test_particle.py b/tests/test_particle.py index dabe6944c..62eb65cff 100644 --- a/tests/test_particle.py +++ b/tests/test_particle.py @@ -7,8 +7,6 @@ Variable, create_particle_data, ) -from parcels._core.utils.time import TimeInterval -from parcels._datasets.structured.generic import TIME def test_variable_init(): @@ -140,9 +138,8 @@ def test_particleclass_add_variable_collision(): ) @pytest.mark.parametrize("nparticles", [5, 10]) def test_create_particle_data(particle, nparticles): - time_interval = TimeInterval(TIME[0], TIME[-1]) ngrids = 4 - data = create_particle_data(pclass=particle, nparticles=nparticles, ngrids=ngrids, time_interval=time_interval) + data = create_particle_data(pclass=particle, nparticles=nparticles, ngrids=ngrids) assert isinstance(data, dict) assert len(data) == len(particle.variables) + 1 # ei variable is separate diff --git a/tests/test_particlefile.py b/tests/test_particlefile.py index 84cb90ffa..d782fb171 100755 --- a/tests/test_particlefile.py +++ b/tests/test_particlefile.py @@ -1,12 +1,13 @@ -import os import tempfile from contextlib import nullcontext as does_not_raise from datetime import datetime, timedelta import numpy as np +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq import pytest import xarray as xr -from zarr.storage import MemoryStore import 
parcels.tutorial from parcels import ( @@ -20,7 +21,8 @@ VectorField, XGrid, ) -from parcels._core.particle import Particle, create_particle_data, get_default_particle +from parcels._core.particle import Particle, get_default_particle +from parcels._core.particlefile import _get_schema from parcels._core.utils.time import TimeInterval, timedelta_to_float from parcels._datasets.structured.generated import peninsula_dataset from parcels._datasets.structured.generic import datasets @@ -44,35 +46,17 @@ def fieldset() -> FieldSet: # TODO v4: Move into a `conftest.py` file and remov ) -def test_metadata(fieldset, tmp_zarrfile): +def test_metadata(fieldset, tmp_parquet): pset = ParticleSet(fieldset, pclass=Particle, lon=0, lat=0) - ofile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")) + ofile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s")) pset.execute(DoNothing, runtime=np.timedelta64(1, "s"), dt=np.timedelta64(1, "s"), output_file=ofile) - ds = xr.open_zarr(tmp_zarrfile) - assert ds.attrs["parcels_kernels"].lower() == "DoNothing".lower() + tab = pq.read_table(tmp_parquet) + assert tab.schema.metadata[b"parcels_kernels"].decode().lower() == "DoNothing".lower() -def test_pfile_array_write_zarr_memorystore(fieldset): - """Check that writing to a Zarr MemoryStore works.""" - npart = 10 - zarr_store = MemoryStore() - pset = ParticleSet( - fieldset, - pclass=Particle, - lon=np.linspace(0, 1, npart), - lat=0.5 * np.ones(npart), - time=fieldset.time_interval.left, - ) - pfile = ParticleFile(zarr_store, outputdt=np.timedelta64(1, "s")) - pfile.write(pset, time=fieldset.time_interval.left) - - ds = xr.open_zarr(zarr_store) - assert ds.sizes["trajectory"] == npart - - -def test_write_fieldset_without_time(tmp_zarrfile): +def test_write_fieldset_without_time(tmp_parquet): ds = peninsula_dataset() # DataSet without time assert "time" not in ds.dims grid = XGrid.from_dataset(ds, mesh="flat") @@ -80,14 +64,17 @@ def 
test_write_fieldset_without_time(tmp_zarrfile): pset = ParticleSet(fieldset, pclass=Particle, lon=0, lat=0) - ofile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")) + ofile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s")) pset.execute(DoNothing, runtime=np.timedelta64(1, "s"), dt=np.timedelta64(1, "s"), output_file=ofile) - ds = xr.open_zarr(tmp_zarrfile) - assert ds.time.values[0, 1] == np.timedelta64(1, "s") + table = pq.read_table(tmp_parquet) + assert table.schema.field("time").metadata[b"units"] == b"seconds" + assert b"calendar" not in table.schema.field("time").metadata + assert table["time"].to_numpy()[1] == 1.0 -def test_pfile_array_remove_particles(fieldset, tmp_zarrfile): +def test_pfile_array_remove_particles(fieldset, tmp_parquet): + """If a particle from the middle of a particleset is removed, that writing doesn't crash""" npart = 10 pset = ParticleSet( fieldset, @@ -96,20 +83,17 @@ def test_pfile_array_remove_particles(fieldset, tmp_zarrfile): lat=0.5 * np.ones(npart), time=fieldset.time_interval.left, ) - pfile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")) + pfile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s")) pset._data["time"][:] = 0 pfile.write(pset, time=fieldset.time_interval.left) pset.remove_indices(3) new_time = 86400 # s in a day pset._data["time"][:] = new_time pfile.write(pset, new_time) - ds = xr.open_zarr(tmp_zarrfile) - timearr = ds["time"][:] - assert (np.isnat(timearr[3, 1])) and (np.isfinite(timearr[3, 0])) + pfile.close() -@pytest.mark.parametrize("chunks_obs", [1, None]) -def test_pfile_array_remove_all_particles(fieldset, chunks_obs, tmp_zarrfile): +def test_pfile_array_remove_all_particles(fieldset, tmp_parquet): npart = 10 pset = ParticleSet( fieldset, @@ -118,39 +102,19 @@ def test_pfile_array_remove_all_particles(fieldset, chunks_obs, tmp_zarrfile): lat=0.5 * np.ones(npart), time=fieldset.time_interval.left, ) - chunks = (npart, chunks_obs) if chunks_obs else None - pfile 
= ParticleFile(tmp_zarrfile, chunks=chunks, outputdt=np.timedelta64(1, "s")) + pfile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s")) pfile.write(pset, time=0) for _ in range(npart): pset.remove_indices(-1) pfile.write(pset, fieldset.time_interval.left + np.timedelta64(1, "D")) pfile.write(pset, fieldset.time_interval.left + np.timedelta64(2, "D")) + pfile.close() - ds = xr.open_zarr(tmp_zarrfile) - np.testing.assert_allclose(ds["time"][:, 0] - fieldset.time_interval.left, np.timedelta64(0, "s")) - if chunks_obs is not None: - assert ds["time"][:].shape == chunks - else: - assert ds["time"][:].shape[0] == npart - assert np.all(np.isnan(ds["time"][:, 1:])) - + df = pd.read_parquet(tmp_parquet) + assert df["trajectory"].nunique() == npart -def test_variable_write_double(fieldset, tmp_zarrfile): - def Update_lon(particles, fieldset): # pragma: no cover - particles.dlon += 0.1 - dt = np.timedelta64(1, "s") - particle = get_default_particle(np.float64) - pset = ParticleSet(fieldset, pclass=particle, lon=[0], lat=[0]) - ofile = ParticleFile(tmp_zarrfile, outputdt=dt) - pset.execute(Update_lon, runtime=np.timedelta64(10, "s"), dt=dt, output_file=ofile) - - ds = xr.open_zarr(tmp_zarrfile) - lons = ds["lon"][:] - assert isinstance(lons.values[0, 0], np.float64) - - -def test_write_dtypes_pfile(fieldset, tmp_zarrfile): +def test_write_dtypes_pfile(fieldset, tmp_parquet): dtypes = [ np.float32, np.float64, @@ -169,14 +133,13 @@ def test_write_dtypes_pfile(fieldset, tmp_zarrfile): MyParticle = Particle.add_variable(extra_vars) pset = ParticleSet(fieldset, pclass=MyParticle, lon=0, lat=0, time=fieldset.time_interval.left) - pfile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")) + pfile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s")) pfile.write(pset, time=fieldset.time_interval.left) + pfile.close() - ds = xr.open_zarr( - tmp_zarrfile, mask_and_scale=False - ) # Note masking issue at 
https://stackoverflow.com/questions/68460507/xarray-loading-int-data-as-float + tab = pq.read_table(tmp_parquet) for d in dtypes: - assert ds[f"v_{d.__name__}"].dtype == d + assert tab[f"v_{d.__name__}"].type == pa.from_numpy_dtype(d) def test_variable_written_once(): @@ -187,7 +150,7 @@ def test_variable_written_once(): @pytest.mark.skip(reason="Pending ParticleFile refactor; see issue #2386") @pytest.mark.parametrize("dt", [-np.timedelta64(1, "s"), np.timedelta64(1, "s")]) @pytest.mark.parametrize("maxvar", [2, 4, 10]) -def test_pset_repeated_release_delayed_adding_deleting(fieldset, tmp_zarrfile, dt, maxvar): +def test_pset_repeated_release_delayed_adding_deleting(fieldset, tmp_parquet, dt, maxvar): """Tests that if particles are released and deleted based on age that resulting output file is correct.""" npart = 10 fieldset.add_constant("maxvar", maxvar) @@ -203,7 +166,7 @@ def test_pset_repeated_release_delayed_adding_deleting(fieldset, tmp_zarrfile, d pclass=MyParticle, time=fieldset.time_interval.left + [np.timedelta64(i + 1, "s") for i in range(npart)], ) - pfile = ParticleFile(tmp_zarrfile, outputdt=abs(dt), chunks=(1, 1)) + pfile = ParticleFile(tmp_parquet, outputdt=abs(dt)) def IncrLon(particles, fieldset): # pragma: no cover particles.sample_var += 1.0 @@ -216,19 +179,17 @@ def IncrLon(particles, fieldset): # pragma: no cover for _ in range(npart): pset.execute(IncrLon, dt=dt, runtime=np.timedelta64(1, "s"), output_file=pfile) - ds = xr.open_zarr(tmp_zarrfile) + ds = xr.open_zarr(tmp_parquet) samplevar = ds["sample_var"][:] assert samplevar.shape == (npart, min(maxvar, npart + 1)) # test whether samplevar[:, k] = k for k in range(samplevar.shape[1]): assert np.allclose([p for p in samplevar[:, k] if np.isfinite(p)], k + 1) - filesize = os.path.getsize(str(tmp_zarrfile)) - assert filesize < 1024 * 65 # test that chunking leads to filesize less than 65KB -def test_file_warnings(fieldset, tmp_zarrfile): +def test_file_warnings(fieldset, tmp_parquet): pset = 
ParticleSet(fieldset, lon=[0, 0], lat=[0, 0], time=[np.timedelta64(0, "s"), np.timedelta64(1, "s")]) - pfile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(2, "s")) + pfile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(2, "s")) with pytest.warns(ParticleSetWarning, match="Some of the particles have a start time difference.*"): pset.execute(AdvectionRK4, runtime=3, dt=1, output_file=pfile) @@ -244,32 +205,33 @@ def test_file_warnings(fieldset, tmp_zarrfile): (-np.timedelta64(5, "s"), pytest.raises(ValueError)), ], ) -def test_outputdt_types(outputdt, expectation, tmp_zarrfile): +def test_outputdt_types(outputdt, expectation, tmp_parquet): with expectation: - pfile = ParticleFile(tmp_zarrfile, outputdt=outputdt) + pfile = ParticleFile(tmp_parquet, outputdt=outputdt) assert pfile.outputdt == timedelta_to_float(outputdt) -def test_write_timebackward(fieldset, tmp_zarrfile): +def test_write_timebackward(fieldset, tmp_parquet): release_time = fieldset.time_interval.left + [np.timedelta64(i + 1, "s") for i in range(3)] pset = ParticleSet(fieldset, lat=[0, 1, 2], lon=[0, 0, 0], time=release_time) - pfile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(1, "s")) + pfile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s")) pset.execute(DoNothing, runtime=np.timedelta64(3, "s"), dt=-np.timedelta64(1, "s"), output_file=pfile) - ds = xr.open_zarr(tmp_zarrfile) - trajs = ds["trajectory"][:] + df = pd.read_parquet(tmp_parquet) - output_time = ds["time"][:].values - - assert trajs.values.dtype == "int64" - assert np.all(np.diff(trajs.values) < 0) # all particles written in order of release - doutput_time = np.diff(output_time, axis=1) - assert np.all(doutput_time[~np.isnan(doutput_time)] < 0) # all times written in decreasing order + assert df["trajectory"].dtype == "int64" + assert bool( + df.groupby("trajectory") + .apply( + lambda x: (np.diff(x["time"]) < 0).all() # for each particle - set True if it has decreasing time + ) + .all() # ensure for all 
particles + ) @pytest.mark.xfail @pytest.mark.v4alpha -def test_write_xiyi(fieldset, tmp_zarrfile): +def test_write_xiyi(fieldset, tmp_parquet): fieldset.U.data[:] = 1 # set a non-zero zonal velocity fieldset.add_field( Field(name="P", data=np.zeros((3, 20)), lon=np.linspace(0, 1, 20), lat=[-2, 0, 2], interp_method=XLinear) @@ -300,10 +262,10 @@ def SampleP(particles, fieldset): # pragma: no cover _ = fieldset.P[particles] # To trigger sampling of the P field pset = ParticleSet(fieldset, pclass=XiYiParticle, lon=[0, 0.2], lat=[0.2, 1]) - pfile = ParticleFile(tmp_zarrfile, outputdt=dt) + pfile = ParticleFile(tmp_parquet, outputdt=dt) pset.execute([SampleP, Get_XiYi, AdvectionRK4], endtime=10 * dt, dt=dt, output_file=pfile) - ds = xr.open_zarr(tmp_zarrfile) + ds = xr.open_zarr(tmp_parquet) pxi0 = ds["pxi0"][:].values.astype(np.int32) pxi1 = ds["pxi1"][:].values.astype(np.int32) lons = ds["lon"][:].values @@ -323,7 +285,7 @@ def SampleP(particles, fieldset): # pragma: no cover @pytest.mark.parametrize("outputdt", [np.timedelta64(1, "s"), np.timedelta64(2, "s"), np.timedelta64(3, "s")]) -def test_time_is_age(fieldset, tmp_zarrfile, outputdt): +def test_time_is_age(fieldset, tmp_parquet, outputdt): # Test that particle age is same as time - initial_time npart = 10 @@ -334,19 +296,19 @@ def IncreaseAge(particles, fieldset): # pragma: no cover time = fieldset.time_interval.left + np.arange(npart) * np.timedelta64(1, "s") pset = ParticleSet(fieldset, pclass=AgeParticle, lon=npart * [0], lat=npart * [0], time=time) - ofile = ParticleFile(tmp_zarrfile, outputdt=outputdt) + ofile = ParticleFile(tmp_parquet, outputdt=outputdt) pset.execute(IncreaseAge, runtime=np.timedelta64(npart * 2, "s"), dt=np.timedelta64(1, "s"), output_file=ofile) - ds = xr.open_zarr(tmp_zarrfile) - age = ds["age"][:].values.astype("timedelta64[s]") - ds_timediff = np.zeros_like(age) - for i in range(npart): - ds_timediff[i, :] = ds.time.values[i, :] - time[i] - np.testing.assert_equal(age, ds_timediff) 
+ df = parcels.read_particlefile(tmp_parquet) + + # Map sorted trajectory IDs to release times (0, 1, ..., npart-1 seconds) + for index, df_traj in df.groupby("trajectory"): + release_time = time[index] + np.testing.assert_equal(df_traj["age"].astype("timedelta64[s]").values, (df_traj["time"] - release_time).values) -def test_reset_dt(fieldset, tmp_zarrfile): +def test_reset_dt(fieldset, tmp_parquet): # Assert that p.dt gets reset when a write_time is not a multiple of dt # for p.dt=0.02 to reach outputdt=0.05 and endtime=0.1, the steps should be [0.2, 0.2, 0.1, 0.2, 0.2, 0.1], resulting in 6 kernel executions dt = np.timedelta64(20, "s") @@ -356,13 +318,13 @@ def Update_lon(particles, fieldset): # pragma: no cover particle = get_default_particle(np.float64) pset = ParticleSet(fieldset, pclass=particle, lon=[0], lat=[0]) - ofile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(50, "s")) + ofile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(50, "s")) pset.execute(Update_lon, runtime=5 * dt, dt=dt, output_file=ofile) assert np.allclose(pset.lon, 0.6) -def test_correct_misaligned_outputdt_dt(fieldset, tmp_zarrfile): +def test_correct_misaligned_outputdt_dt(fieldset, tmp_parquet): """Testing that outputdt does not need to be a multiple of dt.""" def Update_lon(particles, fieldset): # pragma: no cover @@ -370,12 +332,12 @@ def Update_lon(particles, fieldset): # pragma: no cover particle = get_default_particle(np.float64) pset = ParticleSet(fieldset, pclass=particle, lon=[0], lat=[0]) - ofile = ParticleFile(tmp_zarrfile, outputdt=np.timedelta64(3, "s")) + ofile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(3, "s")) pset.execute(Update_lon, runtime=np.timedelta64(11, "s"), dt=np.timedelta64(2, "s"), output_file=ofile) - ds = xr.open_zarr(tmp_zarrfile) - assert np.allclose(ds.lon.values, [0, 3, 6, 9]) - assert np.allclose(timedelta_to_float(ds.time.values - ds.time.values[0, 0]), [0, 3, 6, 9]) + df = pd.read_parquet(tmp_parquet) + assert 
np.allclose(df["lon"].values, [0, 3, 6, 9]) + assert np.allclose(df.time - df.time.min(), [0, 3, 6, 9]) def setup_pset_execute(*, fieldset: FieldSet, outputdt: timedelta, execute_kwargs, particle_class=Particle): @@ -389,13 +351,13 @@ def setup_pset_execute(*, fieldset: FieldSet, outputdt: timedelta, execute_kwarg ) with tempfile.TemporaryDirectory() as dir: - name = f"{dir}/test.zarr" + name = f"{dir}/tmp.parquet" output_file = ParticleFile(name, outputdt=outputdt) pset.execute(DoNothing, output_file=output_file, **execute_kwargs) - ds = xr.open_zarr(name).load() + df = parcels.read_particlefile(name) - return ds + return df def test_pset_execute_outputdt_forwards(fieldset): @@ -404,9 +366,10 @@ def test_pset_execute_outputdt_forwards(fieldset): runtime = timedelta(hours=5) dt = timedelta(minutes=5) - ds = setup_pset_execute(fieldset=fieldset, outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt)) + df = setup_pset_execute(fieldset=fieldset, outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt)) + particle_0_times = df[df.trajectory == 0].time.values - assert np.all(ds.isel(trajectory=0).time.diff(dim="obs").values == np.timedelta64(outputdt)) + np.testing.assert_equal(np.diff(particle_0_times), outputdt.seconds) def test_pset_execute_output_time_forwards(fieldset): @@ -415,12 +378,9 @@ def test_pset_execute_output_time_forwards(fieldset): runtime = np.timedelta64(5, "h") dt = np.timedelta64(5, "m") - ds = setup_pset_execute(fieldset=fieldset, outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt)) - - assert ( - ds.time[0, 0].values == fieldset.time_interval.left - and ds.time[0, -1].values == fieldset.time_interval.left + runtime - ) + df = setup_pset_execute(fieldset=fieldset, outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt)) + assert df.time.min() == pd.Timestamp(fieldset.time_interval.left) + assert df.time.max() - df.time.min() == runtime def test_pset_execute_outputdt_backwards(fieldset): @@ -429,9 +389,9 @@ def 
test_pset_execute_outputdt_backwards(fieldset): runtime = timedelta(days=2) dt = -timedelta(minutes=5) - ds = setup_pset_execute(fieldset=fieldset, outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt)) - file_outputdt = ds.isel(trajectory=0).time.diff(dim="obs").values - assert np.all(file_outputdt == np.timedelta64(-outputdt)) + df = setup_pset_execute(fieldset=fieldset, outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt)) + particle_0_times = df[df.trajectory == 0].time.values + np.testing.assert_equal(np.diff(particle_0_times), -outputdt.seconds) def test_pset_execute_outputdt_backwards_fieldset_timevarying(): @@ -448,61 +408,26 @@ def test_pset_execute_outputdt_backwards_fieldset_timevarying(): ds_fset = copernicusmarine_to_sgrid(fields=fields) fieldset = FieldSet.from_sgrid_conventions(ds_fset) - ds = setup_pset_execute(outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt), fieldset=fieldset) - file_outputdt = ds.isel(trajectory=0).time.diff(dim="obs").values - assert np.all(file_outputdt == np.timedelta64(-outputdt)), (file_outputdt, np.timedelta64(-outputdt)) + df = setup_pset_execute(outputdt=outputdt, execute_kwargs=dict(runtime=runtime, dt=dt), fieldset=fieldset) + particle_0_times = df[df.trajectory == 0].time.values + np.testing.assert_equal(np.diff(particle_0_times), -outputdt.seconds) -def test_particlefile_init(tmp_store): - ParticleFile(tmp_store, outputdt=np.timedelta64(1, "s"), chunks=(1, 3)) +def test_particlefile_init(tmp_parquet): + ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s")) -@pytest.mark.parametrize("name", ["store", "outputdt", "chunks", "create_new_zarrfile"]) -def test_particlefile_readonly_attrs(tmp_store, name): - pfile = ParticleFile(tmp_store, outputdt=np.timedelta64(1, "s"), chunks=(1, 3)) +@pytest.mark.parametrize("name", ["path", "outputdt"]) +def test_particlefile_readonly_attrs(tmp_parquet, name): + pfile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s")) with 
pytest.raises(AttributeError, match="property .* of 'ParticleFile' object has no setter"): setattr(pfile, name, "something") -def test_particlefile_init_invalid(tmp_store): # TODO: Add test for read only store - with pytest.raises(ValueError, match="chunks must be a tuple"): - ParticleFile(tmp_store, outputdt=np.timedelta64(1, "s"), chunks=1) - - -def test_particlefile_write_particle_data(tmp_store): - nparticles = 100 - - pfile = ParticleFile(tmp_store, outputdt=np.timedelta64(1, "s"), chunks=(nparticles, 40)) - pclass = Particle - - left, right = np.datetime64("2019-05-30T12:00:00.000000000", "ns"), np.datetime64("2020-01-02", "ns") - time_interval = TimeInterval(left=left, right=right) - - initial_lon = np.linspace(0, 1, nparticles) - data = create_particle_data( - pclass=pclass, - nparticles=nparticles, - ngrids=4, - time_interval=time_interval, - initial={ - "time": np.full(nparticles, fill_value=0), - "lon": initial_lon, - "dt": np.full(nparticles, fill_value=1.0), - "trajectory": np.arange(nparticles), - }, - ) - np.testing.assert_array_equal(data["time"], 0) - pfile._write_particle_data( - particle_data=data, - pclass=pclass, - time_interval=time_interval, - time=left, - ) - ds = xr.open_zarr(tmp_store) - assert ds.time.dtype == "datetime64[ns]" - np.testing.assert_equal(ds["time"].isel(obs=0).values, left) - assert ds.sizes["trajectory"] == nparticles - np.testing.assert_allclose(ds["lon"].isel(obs=0).values, initial_lon) +def test_particlefile_init_invalid(tmp_path): + path = tmp_path / "file.not-parquet" + with pytest.raises(ValueError, match="file extension must be '.parquet'"): + ParticleFile(path, outputdt=np.timedelta64(1, "s")) def test_pfile_write_custom_particle(): @@ -514,19 +439,19 @@ def test_pfile_write_custom_particle(): @pytest.mark.xfail( reason="set_variable_write_status should be removed - with Particle writing defined on the particle level. 
GH2186" ) -def test_pfile_set_towrite_False(fieldset, tmp_zarrfile): +def test_pfile_set_towrite_False(fieldset, tmp_parquet): npart = 10 pset = ParticleSet(fieldset, pclass=Particle, lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart)) pset.set_variable_write_status("z", False) pset.set_variable_write_status("lat", False) - pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1) + pfile = pset.ParticleFile(tmp_parquet, outputdt=1) def Update_lon(particles, fieldset): # pragma: no cover particles.dlon += 0.1 pset.execute(Update_lon, runtime=10, output_file=pfile) - ds = xr.open_zarr(tmp_zarrfile) + ds = xr.open_zarr(tmp_parquet) assert "time" in ds assert "z" not in ds assert "lat" not in ds @@ -535,3 +460,51 @@ def Update_lon(particles, fieldset): # pragma: no cover # For pytest purposes, we need to reset to original status pset.set_variable_write_status("z", True) pset.set_variable_write_status("lat", True) + + +@pytest.mark.parametrize( + "particle", + [ + Particle, + parcels.ParticleClass( + variables=[ + Variable( + "lon", + dtype=np.float32, + attrs={"standard_name": "longitude", "units": "degrees_east", "axis": "X"}, + ), + Variable( + "lat", + dtype=np.float32, + attrs={"standard_name": "latitude", "units": "degrees_north", "axis": "Y"}, + ), + Variable( + "z", + dtype=np.float32, + attrs={"standard_name": "vertical coordinate", "units": "m", "positive": "down"}, + ), + ] + ), + ], +) +def test_particle_schema(particle): + s = _get_schema(particle, {}, TimeInterval(datetime(2023, 1, 1, 12, 0), datetime(2023, 1, 2, 12, 0))) + + written_variables = [v for v in particle.variables if v.to_write] + + assert len(s.names) == len(written_variables), ( + "Number of fields in the output schema should be the same as the number of writable variables in the ParticleClass object."
+ ) + + for variable, pyarrow_field in zip( + written_variables, + s, + strict=False, + ): + assert variable.name == pyarrow_field.name + if variable.name != "time": + assert variable.attrs == {k.decode(): v.decode() for k, v in pyarrow_field.metadata.items()} + else: + assert b"units" in pyarrow_field.metadata + assert b"calendar" in pyarrow_field.metadata + assert pa.from_numpy_dtype(variable.dtype) == pyarrow_field.type diff --git a/tests/test_utils.py b/tests/test_utils.py index b42e13330..d8d695c18 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,7 @@ import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq +import pytest from tests import utils @@ -17,3 +20,27 @@ def test_round_and_hash_float_array(): arr_test = arr + 0.51 * delta h3 = utils.round_and_hash_float_array(arr_test, decimals=decimals) assert h3 != h + + +@pytest.mark.parametrize("cal", ["julian", "proleptic_gregorian", "365_day", "366_day", "360_day"]) +def test_assert_cftime_like_particlefile(tmp_path, cal): + path = tmp_path / "test.parquet" + attrs = {"units": "seconds since 2000-01-01 17:00:00", "calendar": cal} + field = pa.field("time", pa.float64(), metadata=attrs) + schema = pa.schema([field]) + table = pa.table({"time": pa.array([-20.0, 1.0])}, schema=schema) + pq.write_table(table, path) + + utils.assert_cftime_like_particlefile(path) + + +def test_assert_cftime_like_particlefile_broken_parquet(tmp_path): + path = tmp_path / "test.parquet" + attrs = {"units": "broken-units", "calendar": "365_day"} + field = pa.field("time", pa.float64(), metadata=attrs) + schema = pa.schema([field]) + table = pa.table({"time": pa.array([-20.0, 1.0])}, schema=schema) + pq.write_table(table, path) + + with pytest.raises(Exception, match="CF-time values in Parquet did not get properly decoded"): + utils.assert_cftime_like_particlefile(path) diff --git a/tests/test_uxadvection.py b/tests/test_uxadvection.py index 3f27536f8..d3db9aecd 100644 --- a/tests/test_uxadvection.py 
+++ b/tests/test_uxadvection.py @@ -1,6 +1,6 @@ import numpy as np +import pandas as pd import pytest -import xarray as xr import parcels from parcels._datasets.unstructured.generic import datasets as datasets_unstructured @@ -12,17 +12,17 @@ @pytest.mark.parametrize("integrator", [AdvectionEE, AdvectionRK2, AdvectionRK4]) -def test_ux_constant_flow_face_centered_2D(integrator, tmp_zarrfile): +def test_ux_constant_flow_face_centered_2D(integrator, tmp_parquet): ds = datasets_unstructured["ux_constant_flow_face_centered_2D"] T = np.timedelta64(3600, "s") dt = np.timedelta64(300, "s") fieldset = parcels.FieldSet.from_ugrid_conventions(ds, mesh="flat") pset = parcels.ParticleSet(fieldset, lon=[5.0], lat=[5.0]) - pfile = parcels.ParticleFile(store=tmp_zarrfile, outputdt=dt) + pfile = parcels.ParticleFile(path=tmp_parquet, outputdt=dt) pset.execute(integrator, runtime=T, dt=dt, output_file=pfile, verbose_progress=False) expected_lon = 8.6 np.testing.assert_allclose(pset.lon, expected_lon, atol=1e-5) - ds_out = xr.open_zarr(tmp_zarrfile) - np.testing.assert_allclose(ds_out["lon"][:, -1], expected_lon, atol=1e-5) + df = pd.read_parquet(tmp_parquet) + np.testing.assert_allclose(df["lon"].iloc[-1], expected_lon, atol=1e-5) diff --git a/tests/utils.py b/tests/utils.py index 3213abd31..33d6e0012 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,6 +6,7 @@ from collections import defaultdict from pathlib import Path +import cftime import numpy as np import xarray as xr @@ -151,3 +152,15 @@ def round_and_hash_float_array(arr, decimals=6): # Mimic Java's HashMap hash transformation h ^= (h >> 20) ^ (h >> 12) return h ^ (h >> 7) ^ (h >> 4) + + +def assert_cftime_like_particlefile(parquet_path: Path) -> None: + assert parquet_path.suffix == ".parquet", "Path must be a parquet file" + + df = parcels.read_particlefile(parquet_path, decode_times=True) + + # check first value (and hence rest of array) is what we expect + assert isinstance(df["time"].values[0], (cftime.datetime, 
np.datetime64)), ( + "CF-time values in Parquet did not get properly decoded. Are the attributes correct?" + ) + return diff --git a/tests/utils/test_time.py b/tests/utils/test_time.py index ef1f39346..26cc39c1c 100644 --- a/tests/utils/test_time.py +++ b/tests/utils/test_time.py @@ -8,7 +8,12 @@ from hypothesis import given from hypothesis import strategies as st -from parcels._core.utils.time import TimeInterval, maybe_convert_python_timedelta_to_numpy, timedelta_to_float +from parcels._core.utils.time import ( + TimeInterval, + _get_cf_attrs, + maybe_convert_python_timedelta_to_numpy, + timedelta_to_float, +) calendar_strategy = st.sampled_from( [ @@ -215,3 +220,9 @@ def test_timedelta_to_float(input, expected): def test_timedelta_to_float_exceptions(): with pytest.raises((ValueError, TypeError)): timedelta_to_float("invalid_type") + + +@given(datetime_strategy()) +def test_datetime_get_cf_attrs(dt): + attrs = _get_cf_attrs(dt) + assert "seconds" in attrs["units"]