Skip to content

Commit 695ba96

Browse files
author
Niru Maheswaranathan
authored
Adds support for pyrefly type checking
- Adds pyrefly to pyproject - Suppresses a bunch of existing errors - Adds GitHub workflow for type checking
2 parents a77b36b + 1d50345 commit 695ba96

16 files changed

Lines changed: 113 additions & 253 deletions
Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: CI
1+
name: Tests
22

33
on:
44
push:
@@ -11,7 +11,7 @@ jobs:
1111
runs-on: ubuntu-latest
1212
strategy:
1313
matrix:
14-
python-version: ["3.9", "3.10", "3.11", "3.12"]
14+
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
1515

1616
steps:
1717
- name: Check out the code
@@ -27,14 +27,14 @@ jobs:
2727
curl -LsSf https://astral.sh/uv/install.sh | sh
2828
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
2929
30+
- name: Create virtual environment
31+
run: uv venv
32+
3033
- name: Install project with dev dependencies
31-
run: |
32-
uv pip install --system .[dev]
34+
run: uv pip install -e .[dev]
3335

3436
- name: Run ruff
35-
run: |
36-
ruff check .
37+
run: uv run ruff check .
3738

3839
- name: Run tests with pytest
39-
run: |
40-
pytest --cov --cov-report=term-missing
40+
run: uv run pytest --cov --cov-report=term-missing

.github/workflows/typecheck.yml

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
name: Typecheck
2+
3+
on:
4+
push:
5+
branches: [main, master]
6+
pull_request:
7+
branches: [main, master]
8+
9+
jobs:
10+
test:
11+
runs-on: ubuntu-latest
12+
strategy:
13+
matrix:
14+
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
15+
16+
steps:
17+
- name: Check out the code
18+
uses: actions/checkout@v4
19+
20+
- name: Set up Python
21+
uses: actions/setup-python@v5
22+
with:
23+
python-version: ${{ matrix.python-version }}
24+
25+
- name: Install uv
26+
run: |
27+
curl -LsSf https://astral.sh/uv/install.sh | sh
28+
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
29+
30+
- name: Create virtual environment
31+
run: uv venv
32+
33+
- name: Install project with dev dependencies
34+
run: uv pip install -e .[dev]
35+
36+
- name: Run Pyrefly Type Checker
37+
run: uv run pyrefly check

pyproject.toml

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,21 @@ Homepage = "https://github.com/nirum/jetplot"
2424

2525
[project.optional-dependencies]
2626
dev = [
27-
"pre-commit>=2.21.0",
2827
"pyrefly>=0.14.0",
2928
"pytest>=7.4.4",
3029
"pytest-cov>=4.1.0",
3130
"ruff>=0.11.10",
3231
]
32+
docs = [
33+
"mkdocs>=1.5.3",
34+
"mkdocs-material>=9.2.7",
35+
"mkdocstrings[python]>=0.22.0",
36+
]
37+
38+
[tool.pyrefly]
39+
search_path = [
40+
"src/"
41+
]
3342

3443
[tool.ruff]
3544
lint.extend-ignore = ["E111", "E114", "E501", "F403"]
@@ -39,13 +48,3 @@ package-dir = {"" = "src"}
3948

4049
[tool.setuptools.dynamic]
4150
version = {attr = "jetplot.__version__"}
42-
43-
[tool.uv]
44-
default-groups = ["dev", "docs"]
45-
46-
[dependency-groups]
47-
docs = [
48-
"mkdocs>=1.5.3",
49-
"mkdocs-material>=9.2.7",
50-
"mkdocstrings[python]>=0.22.0",
51-
]

requirements.txt

Lines changed: 0 additions & 4 deletions
This file was deleted.

src/jetplot/chart_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ def get_bounds(axis, ax=None):
129129
return ax.spines[spine_key].get_bounds()
130130
else:
131131
lower, upper = None, None
132+
133+
# pyrefly: ignore # no-matching-overload, bad-argument-type
132134
for tick, label in zip(ticks(), labels()):
133135
if label.get_text() != "":
134136
if lower is None:
@@ -188,6 +190,8 @@ def yclamp(y0=None, y1=None, dt=None, **kwargs):
188190
lims = ax.get_ylim()
189191
y0 = lims[0] if y0 is None else y0
190192
y1 = lims[1] if y1 is None else y1
193+
194+
# pyrefly: ignore # no-matching-overload, bad-argument-type
191195
dt = np.mean(np.diff(ax.get_yticks())) if dt is None else dt
192196

193197
new_ticks = np.arange(dt * np.floor(y0 / dt), dt * (np.ceil(y1 / dt) + 1), dt)
@@ -205,6 +209,8 @@ def xclamp(x0=None, x1=None, dt=None, **kwargs):
205209
lims = ax.get_xlim()
206210
x0 = lims[0] if x0 is None else x0
207211
x1 = lims[1] if x1 is None else x1
212+
213+
# pyrefly: ignore # no-matching-overload, bad-argument-type
208214
dt = np.mean(np.diff(ax.get_xticks())) if dt is None else dt
209215

210216
new_ticks = np.arange(dt * np.floor(x0 / dt), dt * (np.ceil(x1 / dt) + 1), dt)

src/jetplot/colors.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,18 @@ def cubehelix(
3737
lambda_ = np.linspace(vmin, vmax, n)
3838
x = lambda_**gamma
3939
phi = 2 * np.pi * (start / 3 + rot * lambda_)
40+
41+
# pyrefly: ignore # bad-argument-type, no-matching-overload
4042
alpha = 0.5 * hue * x * (1.0 - x)
4143
A = np.array([[-0.14861, 1.78277], [-0.29227, -0.90649], [1.97294, 0.0]])
4244
b = np.stack([np.cos(phi), np.sin(phi)])
45+
46+
# pyrefly: ignore # no-matching-overload, bad-argument-type
4347
return Palette((x + alpha * (A @ b)).T)
4448

4549

4650
def cmap_colors(cmap: str, n: int, vmin: float = 0.0, vmax: float = 1.0):
51+
# pyrefly: ignore # missing-attribute
4752
return Palette(cm.__getattribute__(cmap)(np.linspace(vmin, vmax, n)))
4853

4954

src/jetplot/images.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,10 @@ def fsurface(func, xrng=None, yrng=None, n=100, nargs=2, **kwargs):
8383
xrng = (-1, 1) if xrng is None else xrng
8484
yrng = xrng if yrng is None else yrng
8585

86+
# pyrefly: ignore # missing-argument, no-matching-overload, bad-argument-type
8687
xs = np.linspace(*xrng, n)
88+
89+
# pyrefly: ignore # missing-argument, no-matching-overload, bad-argument-type
8790
ys = np.linspace(*yrng, n)
8891

8992
xm, ym = np.meshgrid(xs, ys)
@@ -126,6 +129,8 @@ def cmat(
126129
cb = imv(arr, ax=ax, vmin=vmin, vmax=vmax, cmap=cmap, cbar=cbar)
127130

128131
xs, ys = np.meshgrid(np.arange(num_cols), np.arange(num_rows))
132+
133+
# pyrefly: ignore # no-matching-overload, bad-argument-type
129134
for x, y, value in zip(xs.flat, ys.flat, arr.flat):
130135
color = dark_color if (value <= theta) else light_color
131136
annot = f"{{:{fmt}}}".format(value)
@@ -137,7 +142,10 @@ def cmat(
137142
ax.set_yticks(np.arange(num_rows))
138143
ax.set_yticklabels(labels, fontsize=label_fontsize)
139144

145+
# pyrefly: ignore # bad-argument-type
140146
ax.xaxis.set_minor_locator(FixedLocator(np.arange(num_cols) - 0.5))
147+
148+
# pyrefly: ignore # bad-argument-type
141149
ax.yaxis.set_minor_locator(FixedLocator(np.arange(num_rows) - 0.5))
142150

143151
ax.grid(

src/jetplot/plots.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def violinplot(
5454
pc.set_edgecolor(ec)
5555
pc.set_alpha(1.0)
5656

57+
# pyrefly: ignore # no-matching-overload, bad-argument-type
5758
q1, medians, q3 = np.percentile(data, [25, 50, 75], axis=0)
5859

5960
ax.vlines(
@@ -76,6 +77,7 @@ def violinplot(
7677
if showmeans:
7778
ax.scatter(
7879
xs,
80+
# pyrefly: ignore # no-matching-overload, bad-argument-type
7981
np.mean(data, axis=0),
8082
marker="s",
8183
color=mc,
@@ -120,6 +122,8 @@ def hist2d(x, y, bins=None, range=None, cmap="hot", **kwargs):
120122
bins = 25
121123

122124
# compute the histogram
125+
126+
# pyrefly: ignore # no-matching-overload, unexpected-keyword, bad-argument-type
123127
cnt, xe, ye = np.histogram2d(x, y, bins=bins, normed=True, range=range)
124128

125129
# generate the plot
@@ -362,6 +366,7 @@ def ellipse(x, y, n_std=3.0, facecolor="none", estimator="empirical", **kwargs):
362366
mean_y = np.mean(y)
363367

364368
transform = (
369+
# pyrefly: ignore # bad-argument-type
365370
Affine2D().rotate_deg(45).scale(scale_x, scale_y).translate(mean_x, mean_y)
366371
)
367372

src/jetplot/signals.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
"""Tools for signal processing."""
22

3+
34
from typing import Callable
45

56
import numpy as np
67
from numpy.typing import ArrayLike
8+
9+
# pyrefly: ignore # missing-module-attribute
710
from scipy.ndimage import gaussian_filter1d
811

912
__all__ = ["smooth", "canoncorr", "participation_ratio", "stable_rank", "normalize"]
@@ -58,6 +61,8 @@ def canoncorr(X: ArrayLike, Y: ArrayLike) -> ArrayLike:
5861
between linear subspaces." Mathematics of computation 27.123 (1973): 579-594.
5962
"""
6063
# Orthogonalize each subspace
64+
65+
# pyrefly: ignore # no-matching-overload, bad-argument-type
6166
qu, qv = np.linalg.qr(X)[0], np.linalg.qr(Y)[0]
6267

6368
# singular values of the inner product between the orthogonalized spaces

src/jetplot/style.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Opinionated matplotlib style defaults."""
22

33
from functools import partial
4+
5+
46
from typing import Mapping, Any
57

68
from cycler import cycler

0 commit comments

Comments
 (0)