Skip to content

Commit 72bdc39

Browse files
author
Niru Maheswaranathan
authored
Merge pull request #26 from nirum/codex/add-type-annotations-for-functions
Add type hints across project
2 parents 1f7e25c + 17a55d1 commit 72bdc39

10 files changed

Lines changed: 160 additions & 110 deletions

File tree

justfile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,6 @@ typecheck:
1717

1818
test:
1919
uv run pytest --cov=jetplot --cov-report=term
20+
21+
loop:
22+
find {src,tests} -name "*.py" | entr -c just test

src/jetplot/chart_utils.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Plotting utils."""
22

3+
from collections.abc import Callable
34
from functools import partial, wraps
45
from typing import Any, Literal
56

@@ -20,7 +21,7 @@
2021
]
2122

2223

23-
def figwrapper(fun):
24+
def figwrapper(fun: Callable[..., Any]) -> Callable[..., Any]:
2425
"""Decorator that adds figure handles to the kwargs of a function."""
2526

2627
@wraps(fun)
@@ -33,7 +34,7 @@ def wrapper(*args, **kwargs):
3334
return wrapper
3435

3536

36-
def plotwrapper(fun):
37+
def plotwrapper(fun: Callable[..., Any]) -> Callable[..., Any]:
3738
"""Decorator that adds figure and axes handles to the kwargs of a function."""
3839

3940
@wraps(fun)
@@ -52,7 +53,7 @@ def wrapper(*args, **kwargs):
5253
return wrapper
5354

5455

55-
def axwrapper(fun):
56+
def axwrapper(fun: Callable[..., Any]) -> Callable[..., Any]:
5657
"""Decorator that adds an axes handle to kwargs."""
5758

5859
@wraps(fun)
@@ -70,7 +71,7 @@ def wrapper(*args, **kwargs):
7071

7172

7273
@axwrapper
73-
def noticks(**kwargs):
74+
def noticks(**kwargs: Any) -> None:
7475
"""
7576
Clears tick marks (useful for images)
7677
"""
@@ -81,7 +82,13 @@ def noticks(**kwargs):
8182

8283

8384
@axwrapper
84-
def nospines(left=False, bottom=False, top=True, right=True, **kwargs):
85+
def nospines(
86+
left: bool = False,
87+
bottom: bool = False,
88+
top: bool = True,
89+
right: bool = True,
90+
**kwargs: Any,
91+
) -> plt.Axes:
8592
"""
8693
Hides the specified axis spines (by default, right and top spines)
8794
"""
@@ -160,7 +167,12 @@ def get_bounds(axis: Literal["x", "y"], ax: Axes | None = None) -> tuple[float,
160167

161168

162169
@axwrapper
163-
def breathe(xlims=None, ylims=None, padding_percent=0.05, **kwargs):
170+
def breathe(
171+
xlims: tuple[float, float] | None = None,
172+
ylims: tuple[float, float] | None = None,
173+
padding_percent: float = 0.05,
174+
**kwargs: Any,
175+
) -> plt.Axes:
164176
"""Adds space between axes and plot."""
165177
ax = kwargs["ax"]
166178

@@ -203,7 +215,7 @@ def yclamp(
203215
y0: float | None = None,
204216
y1: float | None = None,
205217
dt: float | None = None,
206-
**kwargs,
218+
**kwargs: Any,
207219
) -> Axes:
208220
"""Clamp the y-axis to evenly spaced tick marks."""
209221
ax = kwargs["ax"]
@@ -228,7 +240,7 @@ def xclamp(
228240
x0: float | None = None,
229241
x1: float | None = None,
230242
dt: float | None = None,
231-
**kwargs,
243+
**kwargs: Any,
232244
) -> Axes:
233245
"""Clamp the x-axis to evenly spaced tick marks."""
234246
ax = kwargs["ax"]

src/jetplot/colors.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
"""Colorschemes"""
22

3+
from typing import cast
4+
35
import numpy as np
46
from matplotlib import cm
57
from matplotlib import pyplot as plt
68
from matplotlib.axes import Axes
79
from matplotlib.colors import LinearSegmentedColormap, to_hex
810
from matplotlib.figure import Figure
911
from matplotlib.typing import ColorType
10-
from numpy.typing import NDArray
1112

1213
from .chart_utils import noticks
1314

@@ -18,24 +19,24 @@ class Palette(list[ColorType]):
1819
"""Color palette based on a list of values."""
1920

2021
@property
21-
def hex(self):
22+
def hex(self) -> "Palette":
2223
"""Return the palette colors as hexadecimal strings."""
23-
return Palette([to_hex(rgb) for rgb in self])
24+
return Palette([to_hex(rgb) for rgb in self]) # pyrefly: ignore
2425

2526
@property
2627
def cmap(self) -> LinearSegmentedColormap:
2728
"""Return the palette as a Matplotlib colormap."""
2829
return LinearSegmentedColormap.from_list("", self)
2930

30-
def plot(self, figsize: tuple[int, int] = (5, 1)) -> tuple[Figure, NDArray[Axes]]:
31+
def plot(self, figsize: tuple[int, int] = (5, 1)) -> tuple[Figure, list[Axes]]:
3132
"""Visualize the colors in the palette."""
3233
fig, axs = plt.subplots(1, len(self), figsize=figsize)
3334
for c, ax in zip(self, axs, strict=True): # pyrefly: ignore
3435
ax.set_facecolor(c)
3536
ax.set_aspect("equal")
3637
noticks(ax=ax)
3738

38-
return fig, axs
39+
return fig, cast(list[Axes], axs)
3940

4041

4142
def cubehelix(
@@ -46,7 +47,7 @@ def cubehelix(
4647
start: float = 0.0,
4748
rot: float = 0.4,
4849
hue: float = 0.8,
49-
):
50+
) -> Palette:
5051
"""Cubehelix parameterized colormap."""
5152
lambda_ = np.linspace(vmin, vmax, n)
5253
x = lambda_**gamma

src/jetplot/images.py

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
"""Image visualization tools."""
22

3-
from collections.abc import Callable
3+
from collections.abc import Callable, Iterable
44
from functools import partial
5+
from typing import Any, cast
56

67
import numpy as np
78
from matplotlib import pyplot as plt
9+
from matplotlib.axes import Axes
10+
from matplotlib.image import AxesImage
811
from matplotlib.ticker import FixedLocator
912

1013
from . import colors as c
@@ -15,16 +18,16 @@
1518

1619
@plotwrapper
1720
def img(
18-
data,
19-
mode="div",
20-
cmap=None,
21-
aspect="equal",
22-
vmin=None,
23-
vmax=None,
24-
cbar=True,
25-
interpolation="none",
26-
**kwargs,
27-
):
21+
data: np.ndarray,
22+
mode: str = "div",
23+
cmap: str | None = None,
24+
aspect: str = "equal",
25+
vmin: float | None = None,
26+
vmax: float | None = None,
27+
cbar: bool = True,
28+
interpolation: str = "none",
29+
**kwargs: Any,
30+
) -> AxesImage:
2831
"""Visualize a matrix as an image.
2932
3033
Args:
@@ -86,7 +89,7 @@ def fsurface(
8689
yrng: tuple[float, float] | None = None,
8790
n: int = 100,
8891
nargs: int = 2,
89-
**kwargs,
92+
**kwargs: Any,
9093
) -> None:
9194
"""Plot a 2‑D function as a filled surface."""
9295
xrng = (-1, 1) if xrng is None else xrng
@@ -112,22 +115,22 @@ def fsurface(
112115

113116
@plotwrapper
114117
def cmat(
115-
arr,
116-
labels=None,
117-
annot=True,
118-
cmap="gist_heat_r",
119-
cbar=False,
120-
fmt="0.0%",
121-
dark_color="#222222",
122-
light_color="#dddddd",
123-
grid_color=c.gray[9],
124-
theta=0.5,
125-
label_fontsize=10.0,
126-
fontsize=10.0,
127-
vmin=0.0,
128-
vmax=1.0,
129-
**kwargs,
130-
):
118+
arr: np.ndarray,
119+
labels: Iterable[str] | None = None,
120+
annot: bool = True,
121+
cmap: str = "gist_heat_r",
122+
cbar: bool = False,
123+
fmt: str = "0.0%",
124+
dark_color: str = "#222222",
125+
light_color: str = "#dddddd",
126+
grid_color: str = cast(str, c.gray[9]),
127+
theta: float = 0.5,
128+
label_fontsize: float = 10.0,
129+
fontsize: float = 10.0,
130+
vmin: float = 0.0,
131+
vmax: float = 1.0,
132+
**kwargs: Any,
133+
) -> tuple[AxesImage, Axes]:
131134
"""Plot confusion matrix."""
132135
num_rows, num_cols = arr.shape
133136

@@ -138,8 +141,8 @@ def cmat(
138141

139142
for x, y, value in zip(xs.flat, ys.flat, arr.flat, strict=True): # pyrefly: ignore
140143
color = dark_color if (value <= theta) else light_color
141-
annot = f"{{:{fmt}}}".format(value)
142-
ax.text(x, y, annot, ha="center", va="center", color=color, fontsize=fontsize)
144+
label = f"{{:{fmt}}}".format(value)
145+
ax.text(x, y, label, ha="center", va="center", color=color, fontsize=fontsize)
143146

144147
if labels is not None:
145148
ax.set_xticks(np.arange(num_cols))

0 commit comments

Comments
 (0)