Skip to content

Commit 25a8afe

Browse files
authored
Merge pull request #220 from csiro-coasts/test-helpers
Split tests.utils and tests.test_utils
2 parents 612392f + 26348d7 commit 25a8afe

20 files changed

Lines changed: 313 additions & 281 deletions

docs/releases/development.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,7 @@ Next release (in development)
2121
(:pr:`219`).
2222
* Defer ShocSimple coordinate detection to the CFGrid2D base class
2323
(:issue:`217`, :pr:`218`).
24+
* Split `tests.utils` in to multiple `tests.helpers` submodules
25+
(:pr:`220`).
26+
* Split `tests.test_utils` in to multiple `tests.utils.test_component` submodules
27+
(:pr:`220`).

tests/conventions/test_cfgrid1d.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717
CFGrid1D, CFGrid1DTopology, CFGridKind, CFGridTopology
1818
)
1919
from emsarray.operations import geometry
20-
from tests.utils import (
21-
assert_property_not_cached, box, mask_from_strings, track_peak_memory_usage
22-
)
20+
from tests.helpers.array import mask_from_strings
21+
from tests.helpers.functools import assert_property_not_cached
22+
from tests.helpers.geometry import box
23+
from tests.helpers.memory import track_peak_memory_usage
2324

2425
logger = logging.getLogger(__name__)
2526

tests/conventions/test_cfgrid2d.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@
2626
from emsarray.conventions.shoc import ShocSimple
2727
from emsarray.exceptions import NoSuchCoordinateError
2828
from emsarray.operations import geometry
29-
from tests.utils import (
29+
from tests.helpers.datasets import (
3030
AxisAlignedShocGrid, DiagonalShocGrid, ShocGridGenerator,
31-
ShocLayerGenerator, assert_property_not_cached, plot_geometry,
32-
track_peak_memory_usage
31+
ShocLayerGenerator
3332
)
33+
from tests.helpers.functools import assert_property_not_cached
34+
from tests.helpers.geometry import plot_geometry
35+
from tests.helpers.memory import track_peak_memory_usage
3436

3537
logger = logging.getLogger(__name__)
3638

tests/conventions/test_shoc_standard.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@
1919
)
2020
from emsarray.conventions.shoc import ShocStandard
2121
from emsarray.operations import geometry
22-
from tests.utils import (
23-
DiagonalShocGrid, ShocGridGenerator, ShocLayerGenerator, mask_from_strings,
24-
track_peak_memory_usage
22+
from tests.helpers.array import mask_from_strings
23+
from tests.helpers.datasets import (
24+
DiagonalShocGrid, ShocGridGenerator, ShocLayerGenerator
2525
)
26+
from tests.helpers.memory import track_peak_memory_usage
2627

2728
logger = logging.getLogger(__name__)
2829

tests/conventions/test_ugrid.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@
2323
ConventionViolationError, ConventionViolationWarning
2424
)
2525
from emsarray.operations import geometry
26-
from tests.utils import (
27-
assert_property_not_cached, filter_warning, track_peak_memory_usage
28-
)
26+
from tests.helpers.functools import assert_property_not_cached
27+
from tests.helpers.memory import track_peak_memory_usage
28+
from tests.helpers.warnings import filter_warning
2929

3030
logger = logging.getLogger(__name__)
3131

tests/helpers/__init__.py

Whitespace-only changes.

tests/helpers/array.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import itertools
2+
3+
import numpy
4+
5+
6+
def reduce_axes(arr: numpy.ndarray, axes: tuple[bool, ...] | None = None) -> numpy.ndarray:
7+
"""
8+
Reduce the size of an array by one on an axis-by-axis basis. If an axis is
9+
reduced, neigbouring values are averaged together
10+
11+
:param arr: The array to reduce.
12+
:param axes: A tuple of booleans indicating which axes should be reduced. Optional, defaults to reducing along all axes.
13+
:returns: A new array with the same number of axes, but one size smaller in each axis that was reduced.
14+
"""
15+
if axes is None:
16+
axes = tuple(True for _ in arr.shape)
17+
axes_slices = [[numpy.s_[+1:], numpy.s_[:-1]] if axis else [numpy.s_[:]] for axis in axes]
18+
return numpy.mean([arr[tuple(p)] for p in itertools.product(*axes_slices)], axis=0) # type: ignore
19+
20+
21+
def mask_from_strings(mask_strings: list[str]) -> numpy.ndarray:
22+
"""
23+
Make a boolean mask array from a list of strings:
24+
25+
>>> mask_from_strings([
26+
... "101",
27+
... "010",
28+
... "111",
29+
... ])
30+
array([[ True, False, True],
31+
[False, True, False],
32+
[ True, True, True]])
33+
"""
34+
return numpy.array([list(map(int, line)) for line in mask_strings]).astype(bool)

tests/utils.py renamed to tests/helpers/datasets.py

Lines changed: 4 additions & 197 deletions
Original file line numberDiff line numberDiff line change
@@ -1,81 +1,15 @@
11
import abc
2-
import contextlib
3-
import importlib.metadata
4-
import itertools
5-
import tracemalloc
6-
import warnings
7-
from collections.abc import Hashable
82
from functools import cached_property
9-
from types import TracebackType
10-
from typing import Any, Type
3+
from typing import Hashable
114

12-
import matplotlib.pyplot as plt
135
import numpy
14-
import pytest
15-
import shapely
166
import xarray
17-
from cartopy.mpl.geoaxes import GeoAxes
18-
from packaging.requirements import Requirement
197

208
from emsarray.conventions.arakawa_c import (
219
ArakawaCGridKind, c_mask_from_centres
2210
)
23-
from emsarray.types import Bounds, Pathish
24-
25-
26-
@contextlib.contextmanager
27-
def filter_warning(*args, record: bool = False, **kwargs):
28-
"""
29-
A shortcut wrapper around warnings.catch_warning()
30-
and warnings.filterwarnings()
31-
"""
32-
with warnings.catch_warnings(record=record) as context:
33-
warnings.filterwarnings(*args, **kwargs)
34-
yield context
35-
36-
37-
def box(minx, miny, maxx, maxy) -> shapely.Polygon:
38-
"""
39-
Make a box, with coordinates going counterclockwise
40-
starting at (minx miny).
41-
"""
42-
return shapely.Polygon([
43-
(minx, miny),
44-
(maxx, miny),
45-
(maxx, maxy),
46-
(minx, maxy),
47-
])
48-
49-
50-
def reduce_axes(arr: numpy.ndarray, axes: tuple[bool, ...] | None = None) -> numpy.ndarray:
51-
"""
52-
Reduce the size of an array by one on an axis-by-axis basis. If an axis is
53-
reduced, neigbouring values are averaged together
54-
55-
:param arr: The array to reduce.
56-
:param axes: A tuple of booleans indicating which axes should be reduced. Optional, defaults to reducing along all axes.
57-
:returns: A new array with the same number of axes, but one size smaller in each axis that was reduced.
58-
"""
59-
if axes is None:
60-
axes = tuple(True for _ in arr.shape)
61-
axes_slices = [[numpy.s_[+1:], numpy.s_[:-1]] if axis else [numpy.s_[:]] for axis in axes]
62-
return numpy.mean([arr[tuple(p)] for p in itertools.product(*axes_slices)], axis=0) # type: ignore
63-
64-
65-
def mask_from_strings(mask_strings: list[str]) -> numpy.ndarray:
66-
"""
67-
Make a boolean mask array from a list of strings:
68-
69-
>>> mask_from_strings([
70-
... "101",
71-
... "010",
72-
... "111",
73-
... ])
74-
array([[ True, False, True],
75-
[False, True, False],
76-
[ True, True, True]])
77-
"""
78-
return numpy.array([list(map(int, line)) for line in mask_strings]).astype(bool)
11+
12+
from .array import reduce_axes
7913

8014

8115
class ShocLayerGenerator(abc.ABC):
@@ -132,7 +66,7 @@ def z_centre(self) -> numpy.ndarray:
13266

13367

13468
class ShocGridGenerator(abc.ABC):
135-
dimensions = {
69+
dimensions: dict[ArakawaCGridKind, tuple[Hashable, Hashable]] = {
13670
ArakawaCGridKind.face: ('j_centre', 'i_centre'),
13771
ArakawaCGridKind.back: ('j_back', 'i_back'),
13872
ArakawaCGridKind.left: ('j_left', 'i_left'),
@@ -376,130 +310,3 @@ def make_x_grid(self, j: numpy.ndarray, i: numpy.ndarray) -> numpy.ndarray:
376310

377311
def make_y_grid(self, j: numpy.ndarray, i: numpy.ndarray) -> numpy.ndarray:
378312
return 0.1 * (5 + j) * numpy.sin(numpy.pi - i * numpy.pi / (self.i_size)) # type: ignore
379-
380-
381-
def assert_property_not_cached(
382-
instance: Any,
383-
prop_name: str,
384-
/,
385-
) -> None:
386-
__tracebackhide__ = True # noqa
387-
cls = type(instance)
388-
prop = getattr(cls, prop_name)
389-
assert isinstance(prop, cached_property), \
390-
"{instance!r}.{prop_name} is not a cached_property"
391-
392-
cache = instance.__dict__
393-
assert prop.attrname not in cache, \
394-
f"{instance!r}.{prop_name} was cached!"
395-
396-
397-
def skip_versions(*requirements: str):
398-
"""
399-
Skips a test function if any of the version specifiers match.
400-
"""
401-
invalid_versions = []
402-
for requirement in map(Requirement, requirements):
403-
assert not requirement.extras
404-
assert requirement.url is None
405-
assert requirement.marker is None
406-
407-
try:
408-
version = importlib.metadata.version(requirement.name)
409-
except importlib.metadata.PackageNotFoundError:
410-
# The package is not installed, so an invalid version isn't installed
411-
continue
412-
413-
if version in requirement.specifier:
414-
invalid_versions.append(
415-
f'{requirement.name}=={version} matches skipped version specifier {requirement}')
416-
417-
return pytest.mark.skipif(len(invalid_versions) > 0, reason='\n'.join(invalid_versions))
418-
419-
420-
def only_versions(*requirements: str):
421-
"""
422-
Runs a test function only if all of the version specifiers match.
423-
"""
424-
invalid_versions = []
425-
for requirement in map(Requirement, requirements):
426-
assert not requirement.extras
427-
assert requirement.url is None
428-
assert requirement.marker is None
429-
430-
try:
431-
version = importlib.metadata.version(requirement.name)
432-
except importlib.metadata.PackageNotFoundError:
433-
# The package is not installed, so a required version is not installed
434-
invalid_versions.append(f'{requirement.name} is not installed')
435-
continue
436-
437-
if version not in requirement.specifier:
438-
invalid_versions.append(
439-
f'{requirement.name}=={version} does not satisfy {requirement}')
440-
441-
return pytest.mark.skipif(len(invalid_versions) > 0, reason='\n'.join(invalid_versions))
442-
443-
444-
def plot_geometry(
445-
dataset: xarray.Dataset,
446-
out: Pathish,
447-
*,
448-
figsize: tuple[float, float] = (10, 10),
449-
extent: Bounds | None = None,
450-
title: str | None = None
451-
) -> None:
452-
figure = plt.figure(layout='constrained', figsize=figsize)
453-
axes: GeoAxes = figure.add_subplot(projection=dataset.ems.data_crs)
454-
axes.set_aspect(aspect='equal', adjustable='datalim')
455-
axes.gridlines(draw_labels=['left', 'bottom'], linestyle='dashed')
456-
457-
dataset.ems.plot_geometry(axes)
458-
grid = dataset.ems.default_grid
459-
x, y = grid.centroid_coordinates.T
460-
axes.scatter(x, y, c='red')
461-
462-
if title is not None:
463-
axes.set_title(title)
464-
if extent is not None:
465-
axes.set_extent(extent)
466-
467-
figure.savefig(out)
468-
469-
470-
class TracemallocTracker:
471-
_finished = False
472-
_usage = None
473-
474-
def __enter__(self):
475-
tracemalloc.start()
476-
return self
477-
478-
@property
479-
def current(self):
480-
if not self._finished:
481-
raise RuntimeError("Context manager has not exited yet")
482-
return self._usage[0]
483-
484-
@property
485-
def peak(self):
486-
if not self._finished:
487-
raise RuntimeError("Context manager has not exited yet")
488-
return self._usage[1]
489-
490-
def __exit__(
491-
self,
492-
exc_type: Type[Exception] | None,
493-
exc_value: Exception | None,
494-
exc_traceback: TracebackType | None,
495-
) -> bool | None:
496-
self._finished = True
497-
self._usage = tracemalloc.get_traced_memory()
498-
499-
tracemalloc.stop()
500-
501-
return None
502-
503-
504-
def track_peak_memory_usage():
505-
return TracemallocTracker()

tests/helpers/functools.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from functools import cached_property
2+
from typing import Any
3+
4+
5+
def assert_property_not_cached(
6+
instance: Any,
7+
prop_name: str,
8+
/,
9+
) -> None:
10+
__tracebackhide__ = True # noqa
11+
cls = type(instance)
12+
prop = getattr(cls, prop_name)
13+
assert isinstance(prop, cached_property), \
14+
"{instance!r}.{prop_name} is not a cached_property"
15+
16+
cache = instance.__dict__
17+
assert prop.attrname not in cache, \
18+
f"{instance!r}.{prop_name} was cached!"

tests/helpers/geometry.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import shapely
2+
import xarray
3+
from cartopy.mpl.geoaxes import GeoAxes
4+
from matplotlib import pyplot as plt
5+
6+
from emsarray.types import Bounds, Pathish
7+
8+
9+
def box(minx, miny, maxx, maxy) -> shapely.Polygon:
10+
"""
11+
Make a box, with coordinates going counterclockwise
12+
starting at (minx miny).
13+
"""
14+
return shapely.Polygon([
15+
(minx, miny),
16+
(maxx, miny),
17+
(maxx, maxy),
18+
(minx, maxy),
19+
])
20+
21+
22+
def plot_geometry(
23+
dataset: xarray.Dataset,
24+
out: Pathish,
25+
*,
26+
figsize: tuple[float, float] = (10, 10),
27+
extent: Bounds | None = None,
28+
title: str | None = None
29+
) -> None:
30+
figure = plt.figure(layout='constrained', figsize=figsize)
31+
axes: GeoAxes = figure.add_subplot(projection=dataset.ems.data_crs)
32+
axes.set_aspect(aspect='equal', adjustable='datalim')
33+
axes.gridlines(draw_labels=['left', 'bottom'], linestyle='dashed')
34+
35+
dataset.ems.plot_geometry(axes)
36+
grid = dataset.ems.default_grid
37+
x, y = grid.centroid_coordinates.T
38+
axes.scatter(x, y, c='red')
39+
40+
if title is not None:
41+
axes.set_title(title)
42+
if extent is not None:
43+
axes.set_extent(extent)
44+
45+
figure.savefig(out)

0 commit comments

Comments
 (0)