Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ Introduction
The :mod:`regridding` package aims to provide Numba-accelerated resampling of
logically-rectangular curvilinear grids.

Features
--------
* 1D linear interpolation
* 1D conservative resampling
* 2D conservative resampling


API Reference
Expand Down
34 changes: 34 additions & 0 deletions regridding/_regrid/_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,40 @@ def regrid(

|

Regrid a 1D array using conservative resampling.

.. jupyter-execute::

# Define the edges of the input grid
x_input = np.linspace(-1, 1, num=21)

# Define the edges of the output grid
# with a small offset to prevent degenerate cells
x_output = np.linspace(-1, 1, num=11)[::-1] + 1e-6

# Compute the centers of the input grid
x = (x_input[1:] + x_input[:-1]) / 2

# Define an array of values for each cell
# of the input grid
values = np.exp(-(x / 0.25) ** 2 /2)

# Regrid the array of values onto the output grid
values_new = regridding.regrid(
coordinates_input=x_input,
coordinates_output=x_output,
values_input=values,
method="conservative",
)

# Plot the result
fig, ax = plt.subplots()
ax.stairs(values, x_input, label="input")
ax.stairs(values_new, x_output, label="output")
ax.legend();

|

Regrid a 2D array using conservative resampling.

.. jupyter-execute::
Expand Down
51 changes: 37 additions & 14 deletions regridding/_regrid/_regrid_from_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,25 +54,51 @@ def regrid_from_weights(

unit = getattr(values_input, "unit", None)

shape_input = np.broadcast_shapes(values_input.shape, shape_input)

ndim_input = len(shape_input)
ndim_output = len(shape_output)

axis_input = _util._normalize_axis(axis_input, ndim=ndim_input)
axis_output = _util._normalize_axis(axis_output, ndim=ndim_output)

shape_input_orthogonal = tuple(
shape_input[i]
for i in _util._normalize_axis(None, ndim=len(shape_input))
if i not in axis_input
)
shape_output_orthogonal = tuple(
shape_output[i]
for i in _util._normalize_axis(None, ndim=len(shape_output))
if i not in axis_output
)
shape_values_orthogonal = tuple(
values_input.shape[i]
for i in _util._normalize_axis(None, ndim=values_input.ndim)
if i not in axis_input
)

shape_orthogonal = (
1 if i in axis_input else shape_input[i] for i in range(-len(shape_input), 0)
shape_orthogonal = np.broadcast_shapes(
shape_input_orthogonal,
shape_output_orthogonal,
shape_values_orthogonal,
)

axis_input = tuple(sorted(axis_input))
axis_output = tuple(sorted(axis_output))

shape_input_new = list(reversed(shape_orthogonal))
for ax in reversed(axis_input):
shape_input_new.insert(~ax, shape_input[ax])
shape_input = tuple(reversed(shape_input_new))

shape_output_new = list(reversed(shape_orthogonal))
for ax in reversed(axis_output):
shape_output_new.insert(~ax, shape_output[ax])
shape_output = tuple(reversed(shape_output_new))

weights = np.broadcast_to(np.array(weights), shape_orthogonal, subok=True)
values_input = np.broadcast_to(values_input, shape_input, subok=True)

if values_output is None:
shape_output = np.broadcast_shapes(
shape_output,
tuple(
shape_input[ax] if ax not in axis_input else 1
for ax in _util._normalize_axis(None, ndim_input)
),
)
values_output = np.zeros_like(values_input, shape=shape_output)
else:
if values_output.shape != shape_output: # pragma: nocover
Expand All @@ -81,9 +107,6 @@ def regrid_from_weights(
)
values_output.fill(0)

ndim_output = len(shape_output)
axis_output = _util._normalize_axis(axis_output, ndim=ndim_output)

axis_input_numba = ~np.arange(len(axis_input))[::-1]
axis_output_numba = ~np.arange(len(axis_output))[::-1]

Expand Down
36 changes: 20 additions & 16 deletions regridding/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ def _normalize_axis(


def _normalize_input_output_coordinates(
coordinates_input: tuple[np.ndarray, ...],
coordinates_output: tuple[np.ndarray, ...],
coordinates_input: np.ndarray | tuple[np.ndarray, ...],
coordinates_output: np.ndarray | tuple[np.ndarray, ...],
axis_input: None | int | tuple[int, ...] = None,
axis_output: None | int | tuple[int, ...] = None,
) -> tuple[
Expand All @@ -26,6 +26,12 @@ def _normalize_input_output_coordinates(
tuple[int, ...],
tuple[int, ...],
]:
if isinstance(coordinates_input, np.ndarray):
coordinates_input = (coordinates_input,)

if isinstance(coordinates_output, np.ndarray):
coordinates_output = (coordinates_output,)

shape_coordinates_input = np.broadcast(*coordinates_input).shape
shape_coordinates_output = np.broadcast(*coordinates_output).shape

Expand All @@ -35,8 +41,8 @@ def _normalize_input_output_coordinates(
axis_input = _normalize_axis(axis_input, ndim=ndim_input)
axis_output = _normalize_axis(axis_output, ndim=ndim_output)

axis_input = sorted(axis_input, reverse=True)
axis_output = sorted(axis_output, reverse=True)
axis_input = tuple(sorted(axis_input, reverse=True))
axis_output = tuple(sorted(axis_output, reverse=True))

if len(axis_output) != len(axis_input):
raise ValueError(
Expand Down Expand Up @@ -74,17 +80,15 @@ def _normalize_input_output_coordinates(
shape_input_orthogonal, shape_output_orthogonal
)

shape_input = list(shape_orthogonal)
for ax in reversed(axis_input):
ax = ax % ndim_input
shape_input.insert(ax, shape_coordinates_input[ax])
shape_input = tuple(shape_input)
shape_input = list(reversed(shape_orthogonal))
for ax in axis_input:
shape_input.insert(~ax, shape_coordinates_input[ax])
shape_input = tuple(reversed(shape_input))

shape_output = list(shape_orthogonal)
for ax in reversed(axis_output):
ax = ax % ndim_input
shape_output.insert(ax, shape_coordinates_output[ax])
shape_output = tuple(shape_output)
shape_output = list(reversed(shape_orthogonal))
for ax in axis_output:
shape_output.insert(~ax, shape_coordinates_output[ax])
shape_output = tuple(reversed(shape_output))

coordinates_input = tuple(
np.broadcast_to(coord, shape_input) for coord in coordinates_input
Expand All @@ -98,7 +102,7 @@ def _normalize_input_output_coordinates(
coordinates_output,
axis_input,
axis_output,
shape_input,
shape_output,
shape_coordinates_input,
shape_coordinates_output,
shape_orthogonal,
)
112 changes: 79 additions & 33 deletions regridding/_weights/_weights_conservative.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Sequence
import multiprocessing
import concurrent.futures
import numpy as np
import numba
from regridding import _util
from ._weights_conservative_1d import weights_conservative_1d
from regridding._conservative_ramshaw import _conservative_ramshaw


Expand Down Expand Up @@ -38,38 +41,81 @@ def _weights_conservative(

weights = np.empty(shape_orthogonal, dtype=numba.typed.List)

for index in np.ndindex(*shape_orthogonal):
index_vertices_input = list(reversed(index))

for ax in axis_input:
index_vertices_input.insert(~ax, slice(None))
index_vertices_input = tuple(reversed(index_vertices_input))

index_vertices_output = list(reversed(index))
for ax in axis_output:
index_vertices_output.insert(~ax, slice(None))
index_vertices_output = tuple(reversed(index_vertices_output))

if len(axis_input) == 1:
raise NotImplementedError("1D regridding not supported")

elif len(axis_input) == 2:
coordinates_input_x, coordinates_input_y = coordinates_input
coordinates_output_x, coordinates_output_y = coordinates_output
weights[index] = _conservative_ramshaw(
grid_input=(
coordinates_input_x[index_vertices_input],
coordinates_input_y[index_vertices_input],
),
grid_output=(
coordinates_output_x[index_vertices_output],
coordinates_output_y[index_vertices_output],
),
)

else:
raise NotImplementedError(
"Regridding operations greater than 2D are not supported"
)
if len(axis_input) == 1:

threads = 5 * multiprocessing.cpu_count()

with concurrent.futures.ThreadPoolExecutor(threads) as executor:

(x_input,) = coordinates_input
(x_output,) = coordinates_output

x_input = np.moveaxis(x_input, axis_input, ~0)
x_output = np.moveaxis(x_output, axis_output, ~0)

x_input = x_input.reshape(-1, x_input.shape[~0])
x_output = x_output.reshape(-1, x_output.shape[~0])

weights = weights.reshape(-1)

step = np.ceil(x_input.shape[0] / threads).astype(int)

futures = []

for t in range(threads):

index_start = t * step
index_stop = (t + 1) * step

future = executor.submit(
weights_conservative_1d,
x_input=x_input,
x_output=x_output,
weights=weights,
index_start=index_start,
index_stop=index_stop,
)

futures.append(future)

if index_stop >= x_output.shape[0]:
break

concurrent.futures.wait(futures)

weights = weights.reshape(shape_orthogonal)

else:

for index in np.ndindex(*shape_orthogonal):
index_vertices_input = list(reversed(index))

for ax in axis_input:
index_vertices_input.insert(~ax, slice(None))
index_vertices_input = tuple(reversed(index_vertices_input))

index_vertices_output = list(reversed(index))
for ax in axis_output:
index_vertices_output.insert(~ax, slice(None))
index_vertices_output = tuple(reversed(index_vertices_output))

if len(axis_input) == 2:
coordinates_input_x, coordinates_input_y = coordinates_input
coordinates_output_x, coordinates_output_y = coordinates_output
weights[index] = _conservative_ramshaw(
grid_input=(
coordinates_input_x[index_vertices_input],
coordinates_input_y[index_vertices_input],
),
grid_output=(
coordinates_output_x[index_vertices_output],
coordinates_output_y[index_vertices_output],
),
)

else: # pragma: nocover
raise NotImplementedError(
"Regridding operations greater than 2D are not supported"
)

return weights, shape_values_input, shape_values_output
5 changes: 5 additions & 0 deletions regridding/_weights/_weights_conservative_1d/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ._weights_conservative_1d import weights_conservative_1d

__all__ = [
"weights_conservative_1d",
]
Loading
Loading