diff --git a/pynanigans/__init__.py b/pynanigans/__init__.py index 0f825fd..e181608 100644 --- a/pynanigans/__init__.py +++ b/pynanigans/__init__.py @@ -1,6 +1,6 @@ -__version__ = "0.1.0" +__version__ = "0.2.0" -from .grids import get_distances, get_metrics, get_coords, get_grid +from .grids import get_metrics, get_coords, get_grid from .utils import * from . import pnplot from . import utils diff --git a/pynanigans/grids.py b/pynanigans/grids.py index 62cf597..4ab61c9 100644 --- a/pynanigans/grids.py +++ b/pynanigans/grids.py @@ -1,59 +1,34 @@ -def get_coords(ds, topology="PPN",): +def get_coords(topology="PPN"): """ - Constructs the coords dict for ds to be passed to xgcm.Grid + Constructs the coords dict to be passed to xgcm.Grid Flat dimensions (F) are treated the same as Periodic ones (P) """ - per = dict(left='xF', center='xC') - nper = dict(outer='xF', center='xC') - per = { dim : dict(left=f"{dim}F", center=f"{dim}C") for dim in "xyz" } - nper = { dim : dict(outer=f"{dim}F", center=f"{dim}C") for dim in "xyz" } - coords = { dim : per[dim] if top in "FP" else nper[dim] for dim, top in zip("xyz", topology) } - - return coords + x_per = dict(left="x_faa", center="x_caa") + y_per = dict(left="y_afa", center="y_aca") + z_per = dict(left="z_aaf", center="z_aac") + x_nper = dict(outer="x_faa", center="x_caa") + y_nper = dict(outer="y_afa", center="y_aca") + z_nper = dict(outer="z_aaf", center="z_aac") -def get_distances(ds, dim="x", topology="P"): - """ - Get distance metrics for Center and Face points of one specific dimension ξ. - If the topology of this dimension is periodic, len(ξC)==len(ξF), but if it - is nonperiodic, then len(ξC)+1==len(ξF). + per = dict(x = x_per, y = y_per, z = z_per) + nper = dict(x = x_nper, y = y_nper, z = z_nper) - Currently does not deal with stretched domains where ΔξC!=ΔξF in the interior. - """ - import numpy as np - import xarray as xr + coords = { dim : per[dim] if top in "FP" else nper[dim] for dim, top in zip("xyz", topology) } - Δξ_mean = ds[dim+"C"].diff(dim+"C").mean().item() - ΔξC = xr.DataArray(np.ones(len(ds[dim+"C"])), dims=[dim+'C']) - if topology=="P" or topology=="F": - ΔξF = xr.DataArray(np.ones(len(ds[dim+"F"])), dims=[dim+'F']) - elif topology=="N": - if len(ds[dim+"F"]) != 1: - interior = np.ones(len(ds[dim+"F"])-2) - ΔξF = xr.DataArray(np.hstack([0.5, interior, 0.5]), dims=[dim+'F']) - else: # Especial case of a slice in a non-periodic dimension - ΔξF = xr.DataArray([1], dims=[dim+'F']) - return Δξ_mean * xr.Dataset({f"Δ{dim}C" : ΔξC, f"Δ{dim}F" : ΔξF}) + return coords def get_metrics(ds, topology="PPN"): + """ + Constructs the metric dict for `ds`. """ - Constructs the metric dict for ds. - (Not sure if the metrics are correct at the boundary points - """ - - for ξ, top in zip('xyz', topology): - ξdist = get_distances(ds, dim=ξ, topology=top) - ds.coords[f"Δ{ξ}C"] = ξdist[f"Δ{ξ}C"] - ds.coords[f"Δ{ξ}F"] = ξdist[f"Δ{ξ}F"] - metrics = { - ('x',): ['ΔxC', 'ΔxF'], # X distances - ('y',): ['ΔyC', 'ΔyF'], # Y distances - ('z',): ['ΔzC', 'ΔzF'], # Z distances + ("x",): ["Δx_caa", "Δx_faa"], # X distances + ("y",): ["Δy_aca", "Δy_afa"], # Y distances + ("z",): ["Δz_aac", "Δz_aaf"], # Z distances } - return metrics @@ -62,7 +37,7 @@ def get_grid(ds, coords=None, metrics=None, topology="PPN", **kwargs): import xgcm as xg if coords is None: - coords = get_coords(ds, topology=topology) + coords = get_coords(topology) if metrics is None: metrics = get_metrics(ds, topology=topology) diff --git a/pynanigans/pnplot.py b/pynanigans/pnplot.py index 695be0e..8893598 100644 --- a/pynanigans/pnplot.py +++ b/pynanigans/pnplot.py @@ -5,8 +5,8 @@ def pnplot(darray, surjection=surjection, **kwargs): """ Bijects darray to rename the dimensions before calling plot(). - This makes plot easier as, instead of calling, `ds.u.plot(x='xF', y='zC')`, - you can call `ds.pnplot(x='x', y='z')` + This makes plot easier as, instead of calling, `ds.u.plot(x="x_faa", y="z_aac")`, + you can call `ds.pnplot(x="x", y="z")` """ return biject(darray, surjection=surjection).plot(**kwargs) xr.DataArray.pnplot = pnplot diff --git a/pynanigans/utils.py b/pynanigans/utils.py index f962c05..aa5a2fa 100644 --- a/pynanigans/utils.py +++ b/pynanigans/utils.py @@ -1,12 +1,12 @@ import xarray as xr from .grids import get_grid -surjection = dict(xC='x', - xF='x', - yC='y', - yF='y', - zC='z', - zF='z', +surjection = dict(x_caa="x", + x_faa="x", + y_aca="y", + y_afa="y", + z_aac="z", + z_aaf="z", ) @@ -17,7 +17,7 @@ def biject(darray, *args, surjection=surjection): If `*args` is provided, only those dimensions will be renamed. If not, `x`, `y` and `z` will be automatically renamed. - This makes calling functions easier as instead of calling `darray.u.plot(x='xF', y='zC')`, + This makes calling functions easier as instead of calling `darray.u.plot(x='x_faa', y='z_aac')`, you can call `darray.pnplot(x='x', y='z')` """ da_dims = darray.dims @@ -44,9 +44,9 @@ def normalize_time_by(darray, seconds=1, new_units="seconds"): object while normalizing it by number of seconds `seconds`. """ import numpy as np - if darray.time.dtype == '