Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 14 additions & 35 deletions .github/workflows/pull_request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,23 @@ jobs:
include:
- os: ubuntu-latest
python-version: "3.9"
dependency-set: minimum
dependency-set: lowest-direct
- os: macos-15-intel # We need x86 as ARM is python>= 3.11 only.
# https://github.com/actions/setup-python/issues/855
python-version: "3.9"
dependency-set: minimum
dependency-set: lowest-direct
- os: windows-latest
python-version: "3.9"
dependency-set: minimum
dependency-set: lowest-direct
- os: ubuntu-latest
python-version: "3.13"
dependency-set: maximum
dependency-set: highest
- os: macos-latest
python-version: "3.13"
dependency-set: maximum
dependency-set: highest
- os: windows-latest
python-version: "3.13"
dependency-set: maximum
dependency-set: highest
runs-on: ${{ matrix.os }}

env:
Expand All @@ -71,23 +71,8 @@ jobs:
with:
enable-cache: true

- name: Generate requirements file
run: python scripts/generate_dependencies.py ${{ matrix.dependency-set }}

- name: Install dependencies
run: |
uv pip install --system ".[all]"
# onnx is required for onnx export tests
# we don't install all dev dependencies here for speed
uv pip install --system -r requirements.txt
uv pip install --system pytest psutil
# licensecheck is required for license checking
uv pip install --system licensecheck
# onnx is not supported on python 3.13 yet https://github.com/onnx/onnx/issues/6339
if [[ "${{ matrix.python-version }}" != "3.13" ]]; then
uv pip install --system onnx
fi
shell: bash
run: uv sync --group ci --all-extras --resolution ${{ matrix.dependency-set }}

- name: Restore model cache
id: restore-model-cache
Expand All @@ -104,30 +89,24 @@ jobs:
- name: Download models from Hugging Face
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: python scripts/download_all_models.py
run: uv run --no-sync python scripts/download_all_models.py

- name: Check for forbidden licenses
if: runner.os == 'MacOS' && matrix.python-version == '3.9'
run: |
licensecheck \
uv run --no-sync licensecheck \
--requirements-paths pyproject.toml \
--only-licenses APACHE MIT BSD ISC PYTHON UNLICENSE UNKNOWN \
--ignore-packages certifi "tabpfn*" \
--show-only-failing \
-0

- name: Run Tests (Unix)
if: runner.os != 'Windows'
- name: Run Tests
env:
TABPFN_EXCLUDE_DEVICES: mps
run: |
FAST_TEST_MODE=1 pytest tests/

- name: Run Tests (Windows)
if: runner.os == 'Windows'
run: |
$env:FAST_TEST_MODE = 1
pytest tests/
TABPFN_EXCLUDE_DEVICES: "mps"
HF_TOKEN: ${{ secrets.HF_TOKEN }}
FAST_TEST_MODE: 1
run: uv run --no-sync pytest tests/

- name: Save model cache
if: github.ref == 'refs/heads/main'
Expand Down
32 changes: 17 additions & 15 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ source = "https://github.com/PriorLabs/tabpfn-extensions"

[project.optional-dependencies]
interpretability = [
"shap>=0.41.0",
"shap>=0.47.0",
"shapiq>=0.4.0",
"seaborn>=0.12.2",
]
post_hoc_ensembles = [
"llvmlite",
"llvmlite>=0.43.0",
"hyperopt>=0.2.7",
"autogluon.tabular==1.4.0"
]
Expand All @@ -81,29 +81,31 @@ unsupervised = []

# Meta-package that installs all extensions
all = [
"shap>=0.41.0",
"shapiq>=0.4.0",
"seaborn>=0.12.2",

"llvmlite>=0.30.0",
"hyperopt>=0.2.7",
# https://discuss.python.org/t/pkg-resources-removal-how-to-go-from-there/106079
"setuptools>=67.0.0,<82",
"autogluon.tabular==1.4.0",

"scikit-survival>=0.25.0; python_version >= '3.10'",
"tabpfn-extensions[interpretability]",
"tabpfn-extensions[post_hoc_ensembles]",
"tabpfn-extensions[hpo]",
"tabpfn-extensions[survival]",
"tabpfn-extensions[many_class]",
"tabpfn-extensions[classifier_as_regressor]",
"tabpfn-extensions[rf_pfn]",
"tabpfn-extensions[unsupervised]",
]

[dependency-groups]
dev = [
{include-group = "ci"},
"pre-commit>=3.0.0",
"ruff==0.8.6", # This must be the same version as in .pre-commit-config.yaml
"mypy>=1.0.0",
"build>=1.3.0",
"twine>=6.2.0",
]
# The minimum subset of the dev dependencies required to run the tests on the CI.
# The idea is to be as close to the deployment environment as possible.
ci = [
"pytest>=8.0.0",
"pytest-xdist>=3.6.0",
"pytest-mock>=3.15.1",
"build>=1.3.0",
"twine>=6.2.0",
"licensecheck>=2025.1.0",
]

Expand Down
93 changes: 0 additions & 93 deletions scripts/generate_dependencies.py

This file was deleted.

23 changes: 23 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Tests for tabpfn_extensions.utils."""

from __future__ import annotations

from pytest_mock import MockerFixture

from tabpfn_extensions.utils import infer_device


def test__infer_device__tabpfn_not_installed__returns_fake_device_with_cpu_type(
mocker: MockerFixture,
) -> None:
"""Test that we get a fake CPU device when tabpfn is not installed.

Currently our test infrastructure runs the tests for the maximum and minimum
supported versions of the tabpfn package. This means that the cases where tabpfn is
installed will be covered by the other tests in this package. However, the case
where tabpfn is not installed will not be covered, so this is a basic test for that.
"""
mocker.patch("importlib.util.find_spec", return_value=None)
assert infer_device(device="auto").type == "cpu"
assert infer_device(device="cuda").type == "cpu"
assert infer_device(device="cpu").type == "cpu"
48 changes: 34 additions & 14 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading