Skip to content
4 changes: 2 additions & 2 deletions pynanigans/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__version__ = "0.1.0"
__version__ = "0.2.0"

from .grids import get_distances, get_metrics, get_coords, get_grid
from .grids import get_metrics, get_coords, get_grid
from .utils import *
from . import pnplot
from . import utils
Expand Down
61 changes: 18 additions & 43 deletions pynanigans/grids.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,34 @@

def get_coords(ds, topology="PPN",):
def get_coords(topology="PPN"):
"""
Constructs the coords dict for ds to be passed to xgcm.Grid
Constructs the coords dict to be passed to xgcm.Grid
Flat dimensions (F) are treated the same as Periodic ones (P)
"""
per = dict(left='xF', center='xC')
nper = dict(outer='xF', center='xC')
per = { dim : dict(left=f"{dim}F", center=f"{dim}C") for dim in "xyz" }
nper = { dim : dict(outer=f"{dim}F", center=f"{dim}C") for dim in "xyz" }
coords = { dim : per[dim] if top in "FP" else nper[dim] for dim, top in zip("xyz", topology) }

return coords
x_per = dict(left="x_faa", center="x_caa")
y_per = dict(left="y_afa", center="y_aca")
z_per = dict(left="z_aaf", center="z_aac")

x_nper = dict(outer="x_faa", center="x_caa")
y_nper = dict(outer="y_afa", center="y_aca")
z_nper = dict(outer="z_aaf", center="z_aac")

def get_distances(ds, dim="x", topology="P"):
"""
Get distance metrics for Center and Face points of one specific dimension ξ.
If the topology of this dimension is periodic, len(ξC)==len(ξF), but if it
is nonperiodic, then len(ξC)+1==len(ξF).
per = dict(x = x_per, y = y_per, z = z_per)
nper = dict(x = x_nper, y = y_nper, z = z_nper)

Currently does not deal with stretched domains where ΔξC!=ΔξF in the interior.
"""
import numpy as np
import xarray as xr
coords = { dim : per[dim] if top in "FP" else nper[dim] for dim, top in zip("xyz", topology) }

Δξ_mean = ds[dim+"C"].diff(dim+"C").mean().item()
ΔξC = xr.DataArray(np.ones(len(ds[dim+"C"])), dims=[dim+'C'])
if topology=="P" or topology=="F":
ΔξF = xr.DataArray(np.ones(len(ds[dim+"F"])), dims=[dim+'F'])
elif topology=="N":
if len(ds[dim+"F"]) != 1:
interior = np.ones(len(ds[dim+"F"])-2)
ΔξF = xr.DataArray(np.hstack([0.5, interior, 0.5]), dims=[dim+'F'])
else: # Especial case of a slice in a non-periodic dimension
ΔξF = xr.DataArray([1], dims=[dim+'F'])
return Δξ_mean * xr.Dataset({f"Δ{dim}C" : ΔξC, f"Δ{dim}F" : ΔξF})
return coords


def get_metrics(ds, topology="PPN"):
"""
Constructs the metric dict for `ds`.
"""
Constructs the metric dict for ds.
(Not sure if the metrics are correct at the boundary points
"""

for ξ, top in zip('xyz', topology):
ξdist = get_distances(ds, dim=ξ, topology=top)
ds.coords[f"Δ{ξ}C"] = ξdist[f"Δ{ξ}C"]
ds.coords[f"Δ{ξ}F"] = ξdist[f"Δ{ξ}F"]

metrics = {
('x',): ['ΔxC', 'ΔxF'], # X distances
('y',): ['ΔyC', 'ΔyF'], # Y distances
('z',): ['ΔzC', 'ΔzF'], # Z distances
("x",): ["Δx_caa", "Δx_faa"], # X distances
("y",): ["Δy_aca", "Δy_afa"], # Y distances
("z",): ["Δz_aac", "Δz_aaf"], # Z distances
}

return metrics


Expand All @@ -62,7 +37,7 @@ def get_grid(ds, coords=None, metrics=None, topology="PPN", **kwargs):
import xgcm as xg

if coords is None:
coords = get_coords(ds, topology=topology)
coords = get_coords(topology)
if metrics is None:
metrics = get_metrics(ds, topology=topology)

Expand Down
4 changes: 2 additions & 2 deletions pynanigans/pnplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
def pnplot(darray, surjection=surjection, **kwargs):
"""
Bijects darray to rename the dimensions before calling plot().
This makes plot easier as, instead of calling, `ds.u.plot(x='xF', y='zC')`,
you can call `ds.pnplot(x='x', y='z')`
This makes plot easier as, instead of calling, `ds.u.plot(x="x_faa", y="z_aac")`,
you can call `ds.pnplot(x="x", y="z")`
"""
return biject(darray, surjection=surjection).plot(**kwargs)
xr.DataArray.pnplot = pnplot
Expand Down
34 changes: 17 additions & 17 deletions pynanigans/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import xarray as xr
from .grids import get_grid

surjection = dict(xC='x',
xF='x',
yC='y',
yF='y',
zC='z',
zF='z',
surjection = dict(x_caa="x",
x_faa="x",
y_aca="y",
y_afa="y",
z_aac="z",
z_aaf="z",
)


Expand All @@ -17,7 +17,7 @@ def biject(darray, *args, surjection=surjection):
If `*args` is provided, only those dimensions will be renamed. If not, `x`, `y`
and `z` will be automatically renamed.

This makes calling functions easier as instead of calling `darray.u.plot(x='xF', y='zC')`,
This makes calling functions easier as instead of calling `darray.u.plot(x='x_faa', y='z_aac')`,
you can call `darray.pnplot(x='x', y='z')`
"""
da_dims = darray.dims
Expand All @@ -44,9 +44,9 @@ def normalize_time_by(darray, seconds=1, new_units="seconds"):
object while normalizing it by number of seconds `seconds`.
"""
import numpy as np
if darray.time.dtype == '<m8[ns]': # timedelta[ns]
if darray.time.dtype == "<m8[ns]": # timedelta[ns]
darray = darray.assign_coords(time = darray.time.astype(np.float64)/1e9/seconds) # From timedelta[ns] to seconds
elif darray.time.dtype == 'float64':
elif darray.time.dtype == "float64":
darray = darray.assign_coords(time = darray.time.astype(np.float64)/seconds) # From timedelta[ns] to seconds
else:
raise(TypeError("Unknown type for time"))
Expand All @@ -60,7 +60,7 @@ def downsample(darray, round_func=round, **dim_limits):
Downsamples `darray` based on dimensions given in dim_limits

dim_limits should be of the form:
dim_limits = dict(yC=1000, zF=2048)
dim_limits = dict(y_aca=1000, z_aaf=2048)
"""
for dim, dim_limit in dim_limits.items():
dim_length = len(darray[dim])
Expand Down Expand Up @@ -134,25 +134,25 @@ def open_simulation(fname,
`kwargs` are passed to `xarray.open_dataset()` and `grid_kwargs` are passed to `pynanigans.get_grid()`.
"""

#++++ Open dataset and create grid before squeezing
#+++ Open dataset and create grid before squeezing
if load:
ds = xr.load_dataset(fname, **kwargs)
else:
ds = xr.open_dataset(fname, **kwargs)
grid_ds = get_grid(ds, topology=topology, **grid_kwargs)
#----
#---

#++++ Squeeze?
#+++ Squeeze?
if squeeze: ds = ds.squeeze()
#----
#---

#++++ Returning only unique times. Useful if simulation was restarted and there's overlap in time
#+++ Returning only unique times. Useful if simulation was restarted and there's overlap in time
if unique:
import numpy as np
_, index = np.unique(ds['time'], return_index=True)
_, index = np.unique(ds["time"], return_index=True)
if verbose and (len(index)!=len(ds.time)): print("Cleaning non-unique indices")
ds = ds.isel(time=index)
#----
#---

return grid_ds, ds

36 changes: 29 additions & 7 deletions tests/test_grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,36 @@

def test_get_coords():
# Test periodic coordinates
coords = get_coords(None, topology="PPP")
coords = get_coords(topology="PPP")
assert "x" in coords
assert "y" in coords
assert "z" in coords
assert coords["x"]["left"] == "xF"
assert coords["x"]["center"] == "xC"

assert coords["x"]["left"] == "x_faa"
assert coords["x"]["center"] == "x_caa"
# Test non-periodic coordinates
coords = get_coords(None, topology="NNN")
assert coords["x"]["outer"] == "xF"
assert coords["x"]["center"] == "xC"
coords = get_coords(topology="NNN")
assert coords["x"]["outer"] == "x_faa"
assert coords["x"]["center"] == "x_caa"

def test_get_metrics():
# Create a test dataset
data = np.random.rand(10, 10, 10)
dims = ['x_caa', 'y_aca', 'z_aac']
coords = {
'x_caa': np.linspace(0, 1, 10),
'y_aca': np.linspace(0, 1, 10),
'z_aac': np.linspace(0, 1, 10)
}
ds = xr.Dataset(
data_vars={'u': (dims, data)},
coords=coords
)

# Test metrics
metrics = get_metrics(ds)
assert ('x',) in metrics
assert ('y',) in metrics
assert ('z',) in metrics
assert "Δx_caa" in metrics[('x',)]
assert "Δx_faa" in metrics[('x',)]
64 changes: 42 additions & 22 deletions tests/test_pnplot.py
Original file line number Diff line number Diff line change
@@ -1,89 +1,109 @@
import pytest
import xarray as xr
import numpy as np
from pynanigans.pnplot import pnplot, _imshow, _pcolormesh, _contour, _contourf
from pynanigans import pnplot

def test_pnplot():
# Create a test dataset
data = np.random.rand(10, 10)
dims = ['xC', 'yC']
dims = ['x_caa', 'y_aca']
coords = {
'xC': np.linspace(0, 1, 10),
'yC': np.linspace(0, 1, 10)
'x_caa': np.linspace(0, 1, 10),
'y_aca': np.linspace(0, 1, 10)
}
ds = xr.Dataset(
data_vars={'u': (dims, data)},
coords=coords
)

# Test plotting
plot = pnplot(ds.u, x='x', y='y')
plot = ds.u.pnplot(x='x', y='y')
assert plot is not None

# Test error case - invalid dimension
with pytest.raises(ValueError):
ds.u.pnplot(x='invalid_dim', y='y')

def test_imshow():
# Create a test dataset
data = np.random.rand(10, 10)
dims = ['xC', 'yC']
dims = ['x_caa', 'y_aca']
coords = {
'xC': np.linspace(0, 1, 10),
'yC': np.linspace(0, 1, 10)
'x_caa': np.linspace(0, 1, 10),
'y_aca': np.linspace(0, 1, 10)
}
ds = xr.Dataset(
data_vars={'u': (dims, data)},
coords=coords
)

# Test imshow
plot = _imshow(ds.u, x='x', y='y')
plot = ds.u.pnimshow(x='x', y='y')
assert plot is not None

# Test error case - invalid dimension
with pytest.raises(ValueError):
ds.u.pnimshow(x='invalid_dim', y='y_aca')

def test_pcolormesh():
# Create a test dataset
data = np.random.rand(10, 10)
dims = ['xC', 'yC']
dims = ['x_caa', 'y_aca']
coords = {
'xC': np.linspace(0, 1, 10),
'yC': np.linspace(0, 1, 10)
'x_caa': np.linspace(0, 1, 10),
'y_aca': np.linspace(0, 1, 10)
}
ds = xr.Dataset(
data_vars={'u': (dims, data)},
coords=coords
)

# Test pcolormesh
plot = _pcolormesh(ds.u, x='x', y='y')
plot = ds.u.pnpcolormesh(x='x', y='y')
assert plot is not None

# Test error case - invalid dimension
with pytest.raises(ValueError):
ds.u.pnpcolormesh(x='invalid_dim', y='y')

def test_contour():
# Create a test dataset
data = np.random.rand(10, 10)
dims = ['xC', 'yC']
dims = ['x_caa', 'y_aca']
coords = {
'xC': np.linspace(0, 1, 10),
'yC': np.linspace(0, 1, 10)
'x_caa': np.linspace(0, 1, 10),
'y_aca': np.linspace(0, 1, 10)
}
ds = xr.Dataset(
data_vars={'u': (dims, data)},
coords=coords
)

# Test contour
plot = _contour(ds.u, x='x', y='y')
plot = ds.u.pncontour(x='x', y='y')
assert plot is not None

# Test error case - invalid dimension
with pytest.raises(ValueError):
ds.u.pncontour(x='invalid_dim', y='y')

def test_contourf():
# Create a test dataset
data = np.random.rand(10, 10)
dims = ['xC', 'yC']
dims = ['x_caa', 'y_aca']
coords = {
'xC': np.linspace(0, 1, 10),
'yC': np.linspace(0, 1, 10)
'x_caa': np.linspace(0, 1, 10),
'y_aca': np.linspace(0, 1, 10)
}
ds = xr.Dataset(
data_vars={'u': (dims, data)},
coords=coords
)

# Test contourf
plot = _contourf(ds.u, x='x', y='y')
assert plot is not None
plot = ds.u.pncontourf(x='x', y='y')
assert plot is not None

# Test error case - invalid dimension
with pytest.raises(ValueError):
ds.u.pncontourf(x='invalid_dim', y='y')
Loading