-
Notifications
You must be signed in to change notification settings - Fork 13
Open
Description
I recently rewrote the `remove_bottom_values` function using numba:
from numba import float64, guvectorize
import numpy as np
import xarray as xr
@guvectorize(
    [
        # float32 must be listed first: NumPy picks the first loop the inputs
        # can be safely cast to, and float32 -> float64 is a safe cast, so a
        # float64 loop listed first would silently upcast float32 data.
        "void(float32[:], float32[:])",
        "void(float64[:], float64[:])",
    ],
    "(n)->(n)",
    nopython=True,
)
def _remove_last_value(data, output):
    """Copy ``data`` into ``output``, masking the last valid value of each run.

    An element is replaced by NaN when the element that follows it is NaN,
    and the final element is always set to NaN. For a column whose NaNs are
    all at the end (e.g. ``[1, 2, 3, nan]``) this removes the deepest non-NaN
    value (giving ``[1, 2, nan, nan]``).

    Parameters
    ----------
    data : 1-D float32/float64 array
        Input values along the core dimension ``n``.
    output : 1-D array of the same length and dtype
        Written in place (gufunc output argument).
    """
    n = data.shape[0]
    output[:] = data[:]
    # Nothing to mask. This also prevents the ``output[-1]`` write below from
    # going out of bounds (numba does not bounds-check by default).
    if n == 0:
        return
    for i in range(n - 1):
        # Test the input rather than ``output``: at this point output[i + 1]
        # is still an unmodified copy of data[i + 1], so this is equivalent
        # and does not depend on the loop direction.
        if np.isnan(data[i + 1]):
            output[i] = np.nan
    # The last element is always the bottom of its run.
    output[-1] = np.nan
def remove_bottom_values_numba(da, dim='lev'):
    """Mask the deepest valid value along ``dim`` of a DataArray.

    Thin wrapper that maps the ``_remove_last_value`` gufunc over every 1-D
    slice of ``da`` along ``dim`` via ``xarray.apply_ufunc``. Dask-backed
    inputs are processed with ``dask="parallelized"``. The result is declared
    to have the same dtype as ``da``.

    Parameters
    ----------
    da : xarray.DataArray
        Input array; must contain ``dim``.
    dim : str, default 'lev'
        Core dimension to operate along.

    Returns
    -------
    xarray.DataArray
        Masked array. ``apply_ufunc`` places the core dimension ``dim`` last.
    """
    core_dims = [dim]
    return xr.apply_ufunc(
        _remove_last_value,
        da,
        dask="parallelized",
        output_dtypes=[da.dtype],
        input_core_dims=[core_dims],
        output_core_dims=[core_dims],
    )
def remove_bottom_values_recoded(ds, dim="lev", fill_val=-1e10):
    """Remove the deepest values that are not NaN along the dimension ``dim``.

    Each data variable that has ``dim`` is passed through
    ``remove_bottom_values_numba`` and transposed back to its own original
    dimension order. That helper masks the last valid value of every
    NaN-terminated run along ``dim``. Data variables without ``dim`` are
    carried over unchanged. Coordinates, variable attributes and dataset
    attributes are copied from ``ds``.

    Parameters
    ----------
    ds : xarray.Dataset
        Input dataset.
    dim : str, default "lev"
        Dimension along which bottom values are removed. Its first value must
        not be greater than its last value (increasing order is assumed).
    fill_val : float, default -1e10
        Currently unused; kept for backward compatibility.

    Returns
    -------
    xarray.Dataset
        Dataset with the bottom values along ``dim`` set to NaN.

    Raises
    ------
    ValueError
        If the values of ``dim`` appear to be decreasing.
    """
    # For now assume that values of `dim` increase along the dimension.
    if ds[dim][0] > ds[dim][-1]:
        raise ValueError(
            f"It seems like `{dim}` has decreasing values. This is not supported yet. Please sort before."
        )

    masked_vars = {}
    for name, da in ds.data_vars.items():
        if dim not in da.dims:
            # No bottom to remove; apply_ufunc would reject the missing core dim.
            masked_vars[name] = da
            continue
        # Pass `dim` explicitly (the helper's own default is "lev"), then undo
        # apply_ufunc moving the core dimension to the end.
        masked = remove_bottom_values_numba(da, dim=dim).transpose(*da.dims)
        # apply_ufunc does not keep variable attributes by default.
        masked.attrs = da.attrs
        masked_vars[name] = masked

    ds_masked = xr.Dataset(masked_vars)
    # Restore every original coordinate exactly as it was (dimension order
    # included), plus any coordinate not attached to a data variable.
    ds_masked = ds_masked.assign_coords({co: ds[co] for co in ds.coords})
    ds_masked.attrs = ds.attrs
    return ds_masked
# I am planning on implementing this here at some point. It might also be nice
# to generalize this to optionally keep only the bottom, and maybe not just
# leave one value, but an arbitrary amount.
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels