|
1 | 1 | import abc |
2 | | -import contextlib |
3 | | -import importlib.metadata |
4 | | -import itertools |
5 | | -import tracemalloc |
6 | | -import warnings |
7 | | -from collections.abc import Hashable |
8 | 2 | from functools import cached_property |
9 | | -from types import TracebackType |
10 | | -from typing import Any, Type |
| 3 | +from typing import Hashable |
11 | 4 |
|
12 | | -import matplotlib.pyplot as plt |
13 | 5 | import numpy |
14 | | -import pytest |
15 | | -import shapely |
16 | 6 | import xarray |
17 | | -from cartopy.mpl.geoaxes import GeoAxes |
18 | | -from packaging.requirements import Requirement |
19 | 7 |
|
20 | 8 | from emsarray.conventions.arakawa_c import ( |
21 | 9 | ArakawaCGridKind, c_mask_from_centres |
22 | 10 | ) |
23 | | -from emsarray.types import Bounds, Pathish |
24 | | - |
25 | | - |
26 | | -@contextlib.contextmanager |
27 | | -def filter_warning(*args, record: bool = False, **kwargs): |
28 | | - """ |
29 | | - A shortcut wrapper around warnings.catch_warning() |
30 | | - and warnings.filterwarnings() |
31 | | - """ |
32 | | - with warnings.catch_warnings(record=record) as context: |
33 | | - warnings.filterwarnings(*args, **kwargs) |
34 | | - yield context |
35 | | - |
36 | | - |
37 | | -def box(minx, miny, maxx, maxy) -> shapely.Polygon: |
38 | | - """ |
39 | | - Make a box, with coordinates going counterclockwise |
40 | | - starting at (minx miny). |
41 | | - """ |
42 | | - return shapely.Polygon([ |
43 | | - (minx, miny), |
44 | | - (maxx, miny), |
45 | | - (maxx, maxy), |
46 | | - (minx, maxy), |
47 | | - ]) |
48 | | - |
49 | | - |
50 | | -def reduce_axes(arr: numpy.ndarray, axes: tuple[bool, ...] | None = None) -> numpy.ndarray: |
51 | | - """ |
52 | | - Reduce the size of an array by one on an axis-by-axis basis. If an axis is |
53 | | - reduced, neigbouring values are averaged together |
54 | | -
|
55 | | - :param arr: The array to reduce. |
56 | | - :param axes: A tuple of booleans indicating which axes should be reduced. Optional, defaults to reducing along all axes. |
57 | | - :returns: A new array with the same number of axes, but one size smaller in each axis that was reduced. |
58 | | - """ |
59 | | - if axes is None: |
60 | | - axes = tuple(True for _ in arr.shape) |
61 | | - axes_slices = [[numpy.s_[+1:], numpy.s_[:-1]] if axis else [numpy.s_[:]] for axis in axes] |
62 | | - return numpy.mean([arr[tuple(p)] for p in itertools.product(*axes_slices)], axis=0) # type: ignore |
63 | | - |
64 | | - |
65 | | -def mask_from_strings(mask_strings: list[str]) -> numpy.ndarray: |
66 | | - """ |
67 | | - Make a boolean mask array from a list of strings: |
68 | | -
|
69 | | - >>> mask_from_strings([ |
70 | | - ... "101", |
71 | | - ... "010", |
72 | | - ... "111", |
73 | | - ... ]) |
74 | | - array([[ True, False, True], |
75 | | - [False, True, False], |
76 | | - [ True, True, True]]) |
77 | | - """ |
78 | | - return numpy.array([list(map(int, line)) for line in mask_strings]).astype(bool) |
| 11 | + |
| 12 | +from .array import reduce_axes |
79 | 13 |
|
80 | 14 |
|
81 | 15 | class ShocLayerGenerator(abc.ABC): |
@@ -132,7 +66,7 @@ def z_centre(self) -> numpy.ndarray: |
132 | 66 |
|
133 | 67 |
|
134 | 68 | class ShocGridGenerator(abc.ABC): |
135 | | - dimensions = { |
| 69 | + dimensions: dict[ArakawaCGridKind, tuple[Hashable, Hashable]] = { |
136 | 70 | ArakawaCGridKind.face: ('j_centre', 'i_centre'), |
137 | 71 | ArakawaCGridKind.back: ('j_back', 'i_back'), |
138 | 72 | ArakawaCGridKind.left: ('j_left', 'i_left'), |
@@ -376,130 +310,3 @@ def make_x_grid(self, j: numpy.ndarray, i: numpy.ndarray) -> numpy.ndarray: |
376 | 310 |
|
377 | 311 | def make_y_grid(self, j: numpy.ndarray, i: numpy.ndarray) -> numpy.ndarray: |
378 | 312 | return 0.1 * (5 + j) * numpy.sin(numpy.pi - i * numpy.pi / (self.i_size)) # type: ignore |
379 | | - |
380 | | - |
381 | | -def assert_property_not_cached( |
382 | | - instance: Any, |
383 | | - prop_name: str, |
384 | | - /, |
385 | | -) -> None: |
386 | | - __tracebackhide__ = True # noqa |
387 | | - cls = type(instance) |
388 | | - prop = getattr(cls, prop_name) |
389 | | - assert isinstance(prop, cached_property), \ |
390 | | - "{instance!r}.{prop_name} is not a cached_property" |
391 | | - |
392 | | - cache = instance.__dict__ |
393 | | - assert prop.attrname not in cache, \ |
394 | | - f"{instance!r}.{prop_name} was cached!" |
395 | | - |
396 | | - |
397 | | -def skip_versions(*requirements: str): |
398 | | - """ |
399 | | - Skips a test function if any of the version specifiers match. |
400 | | - """ |
401 | | - invalid_versions = [] |
402 | | - for requirement in map(Requirement, requirements): |
403 | | - assert not requirement.extras |
404 | | - assert requirement.url is None |
405 | | - assert requirement.marker is None |
406 | | - |
407 | | - try: |
408 | | - version = importlib.metadata.version(requirement.name) |
409 | | - except importlib.metadata.PackageNotFoundError: |
410 | | - # The package is not installed, so an invalid version isn't installed |
411 | | - continue |
412 | | - |
413 | | - if version in requirement.specifier: |
414 | | - invalid_versions.append( |
415 | | - f'{requirement.name}=={version} matches skipped version specifier {requirement}') |
416 | | - |
417 | | - return pytest.mark.skipif(len(invalid_versions) > 0, reason='\n'.join(invalid_versions)) |
418 | | - |
419 | | - |
420 | | -def only_versions(*requirements: str): |
421 | | - """ |
422 | | - Runs a test function only if all of the version specifiers match. |
423 | | - """ |
424 | | - invalid_versions = [] |
425 | | - for requirement in map(Requirement, requirements): |
426 | | - assert not requirement.extras |
427 | | - assert requirement.url is None |
428 | | - assert requirement.marker is None |
429 | | - |
430 | | - try: |
431 | | - version = importlib.metadata.version(requirement.name) |
432 | | - except importlib.metadata.PackageNotFoundError: |
433 | | - # The package is not installed, so a required version is not installed |
434 | | - invalid_versions.append(f'{requirement.name} is not installed') |
435 | | - continue |
436 | | - |
437 | | - if version not in requirement.specifier: |
438 | | - invalid_versions.append( |
439 | | - f'{requirement.name}=={version} does not satisfy {requirement}') |
440 | | - |
441 | | - return pytest.mark.skipif(len(invalid_versions) > 0, reason='\n'.join(invalid_versions)) |
442 | | - |
443 | | - |
444 | | -def plot_geometry( |
445 | | - dataset: xarray.Dataset, |
446 | | - out: Pathish, |
447 | | - *, |
448 | | - figsize: tuple[float, float] = (10, 10), |
449 | | - extent: Bounds | None = None, |
450 | | - title: str | None = None |
451 | | -) -> None: |
452 | | - figure = plt.figure(layout='constrained', figsize=figsize) |
453 | | - axes: GeoAxes = figure.add_subplot(projection=dataset.ems.data_crs) |
454 | | - axes.set_aspect(aspect='equal', adjustable='datalim') |
455 | | - axes.gridlines(draw_labels=['left', 'bottom'], linestyle='dashed') |
456 | | - |
457 | | - dataset.ems.plot_geometry(axes) |
458 | | - grid = dataset.ems.default_grid |
459 | | - x, y = grid.centroid_coordinates.T |
460 | | - axes.scatter(x, y, c='red') |
461 | | - |
462 | | - if title is not None: |
463 | | - axes.set_title(title) |
464 | | - if extent is not None: |
465 | | - axes.set_extent(extent) |
466 | | - |
467 | | - figure.savefig(out) |
468 | | - |
469 | | - |
470 | | -class TracemallocTracker: |
471 | | - _finished = False |
472 | | - _usage = None |
473 | | - |
474 | | - def __enter__(self): |
475 | | - tracemalloc.start() |
476 | | - return self |
477 | | - |
478 | | - @property |
479 | | - def current(self): |
480 | | - if not self._finished: |
481 | | - raise RuntimeError("Context manager has not exited yet") |
482 | | - return self._usage[0] |
483 | | - |
484 | | - @property |
485 | | - def peak(self): |
486 | | - if not self._finished: |
487 | | - raise RuntimeError("Context manager has not exited yet") |
488 | | - return self._usage[1] |
489 | | - |
490 | | - def __exit__( |
491 | | - self, |
492 | | - exc_type: Type[Exception] | None, |
493 | | - exc_value: Exception | None, |
494 | | - exc_traceback: TracebackType | None, |
495 | | - ) -> bool | None: |
496 | | - self._finished = True |
497 | | - self._usage = tracemalloc.get_traced_memory() |
498 | | - |
499 | | - tracemalloc.stop() |
500 | | - |
501 | | - return None |
502 | | - |
503 | | - |
504 | | -def track_peak_memory_usage(): |
505 | | - return TracemallocTracker() |
0 commit comments