Skip to content

Commit d8f2bf1

Browse files
committed
Merge remote-tracking branch 'origin/develop' into feature/snapshots
2 parents 0582b59 + 6e215ec commit d8f2bf1

7 files changed

Lines changed: 68 additions & 21 deletions

File tree

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ Code freeze date: YYYY-MM-DD
1212

1313
### Added
1414

15+
- Better type hints and overloads signatures for ImpactFuncSet [#1250](https://github.com/CLIMADA-project/climada_python/pull/1250)
16+
1517
### Changed
1618
- Updated Impact Calculation Tutorial (`doc.climada_engine_Impact.ipynb`) [#1095](https://github.com/CLIMADA-project/climada_python/pull/1095).
1719

climada/entity/impact_funcs/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def from_step_impf(
189189
haz_type: str,
190190
mdd: tuple[float, float] = (0, 1),
191191
paa: tuple[float, float] = (1, 1),
192-
impf_id: int = 1,
192+
impf_id: int | str = 1,
193193
**kwargs,
194194
):
195195
"""Step function type impact function.
@@ -207,7 +207,7 @@ def from_step_impf(
207207
(min, max) mdd values. The default is (0, 1)
208208
paa: tuple(float, float)
209209
(min, max) paa values. The default is (1, 1)
210-
impf_id : int, optional, default=1
210+
impf_id : int|str, optional, default=1
211211
impact function id
212212
kwargs :
213213
keyword arguments passed to ImpactFunc()
@@ -250,7 +250,7 @@ def from_sigmoid_impf(
250250
k: float,
251251
x0: float,
252252
haz_type: str,
253-
impf_id: int = 1,
253+
impf_id: int | str = 1,
254254
**kwargs,
255255
):
256256
r"""Sigmoid type impact function hinging on three parameter.
@@ -320,7 +320,7 @@ def from_poly_s_shape(
320320
scale: float,
321321
exponent: float,
322322
haz_type: str,
323-
impf_id: int = 1,
323+
impf_id: int | str = 1,
324324
**kwargs,
325325
):
326326
r"""S-shape polynomial impact function hinging on four parameter.

climada/entity/impact_funcs/impact_func_set.py

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import copy
2525
import logging
2626
from itertools import repeat
27-
from typing import Iterable, Optional
27+
from typing import Iterable, Optional, Union, overload
2828

2929
import matplotlib.pyplot as plt
3030
import numpy as np
@@ -119,7 +119,7 @@ def clear(self):
119119
"""Reinitialize attributes."""
120120
self._data = dict() # {hazard_type : {id:ImpactFunc}}
121121

122-
def append(self, func):
122+
def append(self, func: ImpactFunc):
123123
"""Append a ImpactFunc. Overwrite existing if same id and haz_type.
124124
125125
Parameters
@@ -141,7 +141,9 @@ def append(self, func):
141141
self._data[func.haz_type] = dict()
142142
self._data[func.haz_type][func.id] = func
143143

144-
def remove_func(self, haz_type=None, fun_id=None):
144+
def remove_func(
145+
self, haz_type: Optional[str] = None, fun_id: Optional[str | int] = None
146+
):
145147
"""Remove impact function(s) with provided hazard type and/or id.
146148
If no input provided, all impact functions are removed.
147149
@@ -173,7 +175,29 @@ def remove_func(self, haz_type=None, fun_id=None):
173175
else:
174176
self._data = dict()
175177

176-
def get_func(self, haz_type=None, fun_id=None):
178+
@overload
179+
def get_func(
180+
self, haz_type: None = None, fun_id: None = None
181+
) -> dict[str, dict[Union[int, str], ImpactFunc]]: ...
182+
183+
@overload
184+
def get_func(
185+
self, haz_type: None = ..., fun_id: int | str = ...
186+
) -> list[ImpactFunc]: ...
187+
188+
@overload
189+
def get_func(
190+
self, haz_type: str = ..., fun_id: None = None
191+
) -> list[ImpactFunc]: ...
192+
193+
@overload
194+
def get_func(self, haz_type: str = ..., fun_id: int | str = ...) -> ImpactFunc: ...
195+
196+
def get_func(
197+
self, haz_type: Optional[str] = None, fun_id: Optional[int | str] = None
198+
) -> Union[
199+
ImpactFunc, list[ImpactFunc], dict[str, dict[Union[int, str], ImpactFunc]]
200+
]:
177201
"""Get ImpactFunc(s) of input hazard type and/or id.
178202
If no input provided, all impact functions are returned.
179203
@@ -209,7 +233,7 @@ def get_func(self, haz_type=None, fun_id=None):
209233
else:
210234
return self._data
211235

212-
def get_hazard_types(self, fun_id=None):
236+
def get_hazard_types(self, fun_id: Optional[str | int] = None) -> list[str]:
213237
"""Get impact functions hazard types contained for the id provided.
214238
Return all hazard types if no input id.
215239
@@ -231,7 +255,15 @@ def get_hazard_types(self, fun_id=None):
231255
haz_types.append(vul_haz)
232256
return haz_types
233257

234-
def get_ids(self, haz_type=None):
258+
@overload
259+
def get_ids(self, haz_type: None = None) -> dict[str, list[str | int]]: ...
260+
261+
@overload
262+
def get_ids(self, haz_type: str) -> list[int | str]: ...
263+
264+
def get_ids(
265+
self, haz_type: Optional[str] = None
266+
) -> dict[str, list[str | int]] | list[int | str]:
235267
"""Get impact functions ids contained for the hazard type provided.
236268
Return all ids for each hazard type if no input hazard type.
237269
@@ -256,7 +288,9 @@ def get_ids(self, haz_type=None):
256288
except KeyError:
257289
return list()
258290

259-
def size(self, haz_type=None, fun_id=None):
291+
def size(
292+
self, haz_type: Optional[str] = None, fun_id: Optional[str | int] = None
293+
) -> int:
260294
"""Get number of impact functions contained with input hazard type and
261295
/or id. If no input provided, get total number of impact functions.
262296
@@ -279,6 +313,7 @@ def size(self, haz_type=None, fun_id=None):
279313
return 1
280314
if (haz_type is not None) or (fun_id is not None):
281315
return len(self.get_func(haz_type, fun_id))
316+
282317
return sum(len(vul_list) for vul_list in self.get_ids().values())
283318

284319
def check(self):
@@ -300,7 +335,7 @@ def check(self):
300335
)
301336
vul.check()
302337

303-
def extend(self, impact_funcs):
338+
def extend(self, impact_funcs: "ImpactFuncSet"):
304339
"""Append impact functions of input ImpactFuncSet to current
305340
ImpactFuncSet. Overwrite ImpactFunc if same id and haz_type.
306341
@@ -323,7 +358,13 @@ def extend(self, impact_funcs):
323358
for _, vul in vul_dict.items():
324359
self.append(vul)
325360

326-
def plot(self, haz_type=None, fun_id=None, axis=None, **kwargs):
361+
def plot(
362+
self,
363+
haz_type: Optional[str] = None,
364+
fun_id: Optional[str | int] = None,
365+
axis=None,
366+
**kwargs,
367+
):
327368
"""Plot impact functions of selected hazard (all if not provided) and
328369
selected function id (all if not provided).
329370

climada/hazard/test/test_xarray.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import numpy as np
2929
import xarray as xr
3030
from pyproj import CRS
31-
from scipy.sparse import csr_matrix
31+
from scipy.sparse import csr_array, csr_matrix
3232

3333
from climada.hazard.base import Hazard
3434
from climada.util.constants import DEF_CRS
@@ -104,8 +104,8 @@ def _assert_default_types(self, hazard):
104104
self.assertIsInstance(hazard.event_id, np.ndarray)
105105
self.assertIsInstance(hazard.event_name, list)
106106
self.assertIsInstance(hazard.frequency, np.ndarray)
107-
self.assertIsInstance(hazard.intensity, csr_matrix)
108-
self.assertIsInstance(hazard.fraction, csr_matrix)
107+
self.assertIsInstance(hazard.intensity, csr_matrix | csr_array)
108+
self.assertIsInstance(hazard.fraction, csr_matrix | csr_array)
109109
self.assertIsInstance(hazard.date, np.ndarray)
110110

111111
def test_load_path(self):
@@ -149,8 +149,11 @@ def _load_and_assert(**kwargs):
149149
def test_type_error(self):
150150
"""Calling 'from_xarray_raster' with wrong data type should throw"""
151151
# Passing a DataArray
152-
with xr.open_dataset(self.netcdf_path) as dset, self.assertRaisesRegex(
153-
TypeError, "This method only supports passing xr.Dataset"
152+
with (
153+
xr.open_dataset(self.netcdf_path) as dset,
154+
self.assertRaisesRegex(
155+
TypeError, "This method only supports passing xr.Dataset"
156+
),
154157
):
155158
Hazard.from_xarray_raster(dset["intensity"], "", "")
156159

climada/hazard/xarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def _to_csr_matrix(array: xr.DataArray) -> sparse.csr_matrix:
5858
output_dtypes=[array.dtype],
5959
)
6060
sparse_coo = array.compute().data # Load into memory
61-
return sparse_coo.tocsr() # Convert sparse.COO to scipy.sparse.csr_matrix
61+
return sparse_coo.tocsr() # Convert sparse.COO to scipy.sparse.csr_array
6262

6363

6464
# Define accessors for xarray DataArrays

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ doc = [
6565
"ipython",
6666
"myst-nb",
6767
"readthedocs-sphinx-ext>=2.2",
68-
"sphinx",
68+
"sphinx>=8.1,<9.0",
6969
"sphinx-book-theme",
7070
"sphinx-markdown-tables",
7171
"sphinx-design",

requirements/env_climada.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,11 @@ dependencies:
2727
- peewee>=3.18
2828
- pint>=0.24
2929
- pip
30-
- pyarrow>=21.0
30+
- pyarrow>=20.0 # petals cannot be installed on win-64 with pyarrow 21.0
3131
- pycountry>=24.6
3232
- pyproj>=3.7
3333
- pytables>=3.10 # this is the name of the pypi 'tables' package on conda-forge
34+
- python>=3.10,<3.13
3435
- pyxlsb>=1.0
3536
- rasterio>=1.4
3637
- requests>=2.32

0 commit comments

Comments
 (0)