11"""Image visualization tools."""
22
3- from collections .abc import Callable
3+ from collections .abc import Callable , Iterable
44from functools import partial
5+ from typing import Any , cast
56
67import numpy as np
78from matplotlib import pyplot as plt
9+ from matplotlib .axes import Axes
10+ from matplotlib .image import AxesImage
811from matplotlib .ticker import FixedLocator
912
1013from . import colors as c
1518
1619@plotwrapper
1720def img (
18- data ,
19- mode = "div" ,
20- cmap = None ,
21- aspect = "equal" ,
22- vmin = None ,
23- vmax = None ,
24- cbar = True ,
25- interpolation = "none" ,
26- ** kwargs ,
27- ):
21+ data : np . ndarray ,
22+ mode : str = "div" ,
23+ cmap : str | None = None ,
24+ aspect : str = "equal" ,
25+ vmin : float | None = None ,
26+ vmax : float | None = None ,
27+ cbar : bool = True ,
28+ interpolation : str = "none" ,
29+ ** kwargs : Any ,
30+ ) -> AxesImage :
2831 """Visualize a matrix as an image.
2932
3033 Args:
@@ -86,7 +89,7 @@ def fsurface(
8689 yrng : tuple [float , float ] | None = None ,
8790 n : int = 100 ,
8891 nargs : int = 2 ,
89- ** kwargs ,
92+ ** kwargs : Any ,
9093) -> None :
9194 """Plot a 2‑D function as a filled surface."""
9295 xrng = (- 1 , 1 ) if xrng is None else xrng
@@ -112,22 +115,22 @@ def fsurface(
112115
113116@plotwrapper
114117def cmat (
115- arr ,
116- labels = None ,
117- annot = True ,
118- cmap = "gist_heat_r" ,
119- cbar = False ,
120- fmt = "0.0%" ,
121- dark_color = "#222222" ,
122- light_color = "#dddddd" ,
123- grid_color = c .gray [9 ],
124- theta = 0.5 ,
125- label_fontsize = 10.0 ,
126- fontsize = 10.0 ,
127- vmin = 0.0 ,
128- vmax = 1.0 ,
129- ** kwargs ,
130- ):
118+ arr : np . ndarray ,
119+ labels : Iterable [ str ] | None = None ,
120+ annot : bool = True ,
121+ cmap : str = "gist_heat_r" ,
122+ cbar : bool = False ,
123+ fmt : str = "0.0%" ,
124+ dark_color : str = "#222222" ,
125+ light_color : str = "#dddddd" ,
126+ grid_color : str = cast ( str , c .gray [9 ]) ,
127+ theta : float = 0.5 ,
128+ label_fontsize : float = 10.0 ,
129+ fontsize : float = 10.0 ,
130+ vmin : float = 0.0 ,
131+ vmax : float = 1.0 ,
132+ ** kwargs : Any ,
133+ ) -> tuple [ AxesImage , Axes ] :
131134 """Plot confusion matrix."""
132135 num_rows , num_cols = arr .shape
133136
@@ -138,8 +141,8 @@ def cmat(
138141
139142 for x , y , value in zip (xs .flat , ys .flat , arr .flat , strict = True ): # pyrefly: ignore
140143 color = dark_color if (value <= theta ) else light_color
141- annot = f"{{:{ fmt } }}" .format (value )
142- ax .text (x , y , annot , ha = "center" , va = "center" , color = color , fontsize = fontsize )
144+ label = f"{{:{ fmt } }}" .format (value )
145+ ax .text (x , y , label , ha = "center" , va = "center" , color = color , fontsize = fontsize )
143146
144147 if labels is not None :
145148 ax .set_xticks (np .arange (num_cols ))
0 commit comments