From b4bd024b271731925d491c81cf3d606269c81ccb Mon Sep 17 00:00:00 2001 From: Oscar Key Date: Mon, 24 Nov 2025 12:25:10 +0100 Subject: [PATCH] Tidy up dependency management and CI. - Switch to uv min/max dependency resolution, rather than our script, to match our other repositories - Bump shap minimum version to get min dependency resolution to work (it's still 12 months old) - Don't duplicate dependencies in "all" optional dependency, and reference other optional dependency sets instead - Use "ci" dependency group, rather than puting depedencies in workflow file - Merge windows and linux jobs --- .github/workflows/pull_request.yml | 49 +++++----------- pyproject.toml | 32 +++++----- scripts/generate_dependencies.py | 93 ------------------------------ tests/test_utils.py | 23 ++++++++ uv.lock | 48 ++++++++++----- 5 files changed, 88 insertions(+), 157 deletions(-) delete mode 100644 scripts/generate_dependencies.py create mode 100644 tests/test_utils.py diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index b4d84495..61ed5fb1 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -35,23 +35,23 @@ jobs: include: - os: ubuntu-latest python-version: "3.9" - dependency-set: minimum + dependency-set: lowest-direct - os: macos-15-intel # We need x86 as ARM is python>= 3.11 only. # https://github.com/actions/setup-python/issues/855 python-version: "3.9" - dependency-set: minimum + dependency-set: lowest-direct - os: windows-latest python-version: "3.9" - dependency-set: minimum + dependency-set: lowest-direct - os: ubuntu-latest python-version: "3.13" - dependency-set: maximum + dependency-set: highest - os: macos-latest python-version: "3.13" - dependency-set: maximum + dependency-set: highest - os: windows-latest python-version: "3.13" - dependency-set: maximum + dependency-set: highest runs-on: ${{ matrix.os }} env: @@ -71,23 +71,8 @@ jobs: with: enable-cache: true - - name: Generate requirements file - run: python scripts/generate_dependencies.py ${{ matrix.dependency-set }} - - name: Install dependencies - run: | - uv pip install --system ".[all]" - # onnx is required for onnx export tests - # we don't install all dev dependencies here for speed - uv pip install --system -r requirements.txt - uv pip install --system pytest psutil - # licensecheck is required for license checking - uv pip install --system licensecheck - # onnx is not supported on python 3.13 yet https://github.com/onnx/onnx/issues/6339 - if [[ "${{ matrix.python-version }}" != "3.13" ]]; then - uv pip install --system onnx - fi - shell: bash + run: uv sync --group ci --all-extras --resolution ${{ matrix.dependency-set }} - name: Restore model cache id: restore-model-cache @@ -104,30 +89,24 @@ jobs: - name: Download models from Hugging Face env: HF_TOKEN: ${{ secrets.HF_TOKEN }} - run: python scripts/download_all_models.py + run: uv run --no-sync python scripts/download_all_models.py - name: Check for forbidden licenses if: runner.os == 'MacOS' && matrix.python-version == '3.9' run: | - licensecheck \ + uv run --no-sync licensecheck \ --requirements-paths pyproject.toml \ --only-licenses APACHE MIT BSD ISC PYTHON UNLICENSE UNKNOWN \ --ignore-packages certifi "tabpfn*" \ --show-only-failing \ -0 - - name: Run Tests (Unix) - if: runner.os != 'Windows' + - name: Run Tests env: - TABPFN_EXCLUDE_DEVICES: mps - run: | - FAST_TEST_MODE=1 pytest tests/ - - - name: Run Tests (Windows) - if: runner.os == 'Windows' - run: | - $env:FAST_TEST_MODE = 1 - pytest tests/ + TABPFN_EXCLUDE_DEVICES: "mps" + HF_TOKEN: ${{ secrets.HF_TOKEN }} + FAST_TEST_MODE: 1 + run: uv run --no-sync pytest tests/ - name: Save model cache if: github.ref == 'refs/heads/main' diff --git a/pyproject.toml b/pyproject.toml index 19083d61..e51910a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,12 +57,12 @@ source = "https://github.com/PriorLabs/tabpfn-extensions" [project.optional-dependencies] interpretability = [ - "shap>=0.41.0", + "shap>=0.47.0", "shapiq>=0.4.0", "seaborn>=0.12.2", ] post_hoc_ensembles = [ - "llvmlite", + "llvmlite>=0.43.0", "hyperopt>=0.2.7", "autogluon.tabular==1.4.0" ] @@ -81,29 +81,31 @@ unsupervised = [] # Meta-package that installs all extensions all = [ - "shap>=0.41.0", - "shapiq>=0.4.0", - "seaborn>=0.12.2", - - "llvmlite>=0.30.0", - "hyperopt>=0.2.7", - # https://discuss.python.org/t/pkg-resources-removal-how-to-go-from-there/106079 - "setuptools>=67.0.0,<82", - "autogluon.tabular==1.4.0", - - "scikit-survival>=0.25.0; python_version >= '3.10'", + "tabpfn-extensions[interpretability]", + "tabpfn-extensions[post_hoc_ensembles]", + "tabpfn-extensions[hpo]", + "tabpfn-extensions[survival]", + "tabpfn-extensions[many_class]", + "tabpfn-extensions[classifier_as_regressor]", + "tabpfn-extensions[rf_pfn]", + "tabpfn-extensions[unsupervised]", ] [dependency-groups] dev = [ + {include-group = "ci"}, "pre-commit>=3.0.0", "ruff==0.8.6", # This must be the same version as in .pre-commit-config.yaml "mypy>=1.0.0", + "build>=1.3.0", + "twine>=6.2.0", +] +# The minimum subset of the dev dependencies required to run the tests on the CI. +# The idea is to be as close to the deployment environment as possible. +ci = [ "pytest>=8.0.0", "pytest-xdist>=3.6.0", "pytest-mock>=3.15.1", - "build>=1.3.0", - "twine>=6.2.0", "licensecheck>=2025.1.0", ] diff --git a/scripts/generate_dependencies.py b/scripts/generate_dependencies.py deleted file mode 100644 index 3a246c6e..00000000 --- a/scripts/generate_dependencies.py +++ /dev/null @@ -1,93 +0,0 @@ -"""Generate a requirements.txt file from pyproject.toml dependencies. - -This script can operate in two modes: -1. 'min': Extracts minimum versions (>=) and pins them with '=='. -2. 'max': Extracts maximum versions (<) or leaves them unpinned. -""" - -from __future__ import annotations - -import argparse -import re -from pathlib import Path - - -def parse_dependency_lines(content: str) -> list[str]: - """Finds and cleans the dependency lines from the pyproject.toml content.""" - # Find the dependencies section and match until we find the closing bracket - # that's not part of a package extra specification - deps_match = re.search(r"dependencies\s*=\s*\[(.*?)\n\]", content, re.DOTALL) - if not deps_match: - return [] - - deps_lines = deps_match.group(1).strip().split("\n") - - cleaned_deps = [] - for line in deps_lines: - # Assign the stripped line to a new variable to avoid the linter warning. - stripped_line = line.strip() - # Skip empty lines or comments - if not stripped_line or stripped_line.startswith("#"): - continue - # Clean the line by removing an optional trailing comma, then stripping quotes. - clean_dep = stripped_line.rstrip(",").strip("'\"") - cleaned_deps.append(clean_dep) - - return cleaned_deps - - -def main() -> None: - """Main function to parse arguments and generate the requirements file.""" - parser = argparse.ArgumentParser( - description="Generate requirements.txt from pyproject.toml.", - formatter_class=argparse.RawTextHelpFormatter, - ) - parser.add_argument( - "mode", - choices=["minimum", "maximum"], - help="The type of requirements to generate:\n" - "'minimum' - for minimum versions (e.g., 'package==1.2.3')\n" - "'maximum' - for maximum/unpinned versions (e.g., 'package<2.0' or 'package')", - ) - args = parser.parse_args() - - try: - content = Path("pyproject.toml").read_text() - except FileNotFoundError: - return - - # 1. Shared parsing logic - deps = parse_dependency_lines(content) - output_reqs = [] - - # 2. Mode-specific processing logic - if args.mode == "maximum": - for dep in deps: - # Check for maximum version constraint - pattern = r'([^>=<\s]+(?:\[[^\]]+\])?).*?<\s*([^,\s"\']+)' - max_version_match = re.search(pattern, dep) - if max_version_match: - package, max_ver = max_version_match.groups() - output_reqs.append(f"{package}<{max_ver}") - else: - # If no max version, just use the package name - package_match = re.match(r"([^>=<\s]+(?:\[[^\]]+\])?)", dep) - if package_match: - output_reqs.append(package_match.group(1)) - - elif args.mode == "minimum": - for dep in deps: - # Check for minimum version constraint - match = re.match(r'([^>=<\s]+(?:\[[^\]]+\])?)\s*>=\s*([^,\s"\']+)', dep) - if match: - package, min_ver = match.groups() - output_reqs.append(f"{package}=={min_ver}") - - # 3. Shared writing logic - output_filename = "requirements.txt" - with Path(output_filename).open("w") as f: - f.write("\n".join(output_reqs)) - - -if __name__ == "__main__": - main() diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..5eb1e47b --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,23 @@ +"""Tests for tabpfn_extensions.utils.""" + +from __future__ import annotations + +from pytest_mock import MockerFixture + +from tabpfn_extensions.utils import infer_device + + +def test__infer_device__tabpfn_not_installed__returns_fake_device_with_cpu_type( + mocker: MockerFixture, +) -> None: + """Test that we get a fake CPU device when tabpfn is not installed. + + Currently our test infrastructure runs the tests for the maximum and minimum + supported versions of the tabpfn package. This means that the cases where tabpfn is + installed will be covered by the other tests in this package. However, the case + where tabpfn is not installed will not be covered, so this is a basic test for that. + """ + mocker.patch("importlib.util.find_spec", return_value=None) + assert infer_device(device="auto").type == "cpu" + assert infer_device(device="cuda").type == "cpu" + assert infer_device(device="cpu").type == "cpu" diff --git a/uv.lock b/uv.lock index b5036596..a8d2bd2c 100644 --- a/uv.lock +++ b/uv.lock @@ -4651,6 +4651,13 @@ survival = [ ] [package.dev-dependencies] +ci = [ + { name = "licensecheck" }, + { name = "pytest", version = "8.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, + { name = "pytest", version = "9.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, + { name = "pytest-mock" }, + { name = "pytest-xdist" }, +] dev = [ { name = "build" }, { name = "licensecheck" }, @@ -4667,33 +4674,39 @@ dev = [ [package.metadata] requires-dist = [ - { name = "autogluon-tabular", marker = "extra == 'all'", specifier = "==1.4.0" }, { name = "autogluon-tabular", marker = "extra == 'post-hoc-ensembles'", specifier = "==1.4.0" }, - { name = "hyperopt", marker = "extra == 'all'", specifier = ">=0.2.7" }, { name = "hyperopt", marker = "extra == 'hpo'", specifier = ">=0.2.7" }, { name = "hyperopt", marker = "extra == 'post-hoc-ensembles'", specifier = ">=0.2.7" }, - { name = "llvmlite", marker = "extra == 'all'", specifier = ">=0.30.0" }, - { name = "llvmlite", marker = "extra == 'post-hoc-ensembles'" }, + { name = "llvmlite", marker = "extra == 'post-hoc-ensembles'", specifier = ">=0.43.0" }, { name = "pandas", specifier = ">=1.4.0,<3" }, { name = "scikit-learn", specifier = ">=1.6.0,<1.7" }, - { name = "scikit-survival", marker = "python_full_version >= '3.10' and extra == 'all'", specifier = ">=0.25.0" }, { name = "scikit-survival", marker = "python_full_version >= '3.10' and extra == 'survival'", specifier = ">=0.25.0" }, { name = "scipy", specifier = ">=1.11.1,<2" }, - { name = "seaborn", marker = "extra == 'all'", specifier = ">=0.12.2" }, { name = "seaborn", marker = "extra == 'interpretability'", specifier = ">=0.12.2" }, - { name = "setuptools", marker = "extra == 'all'", specifier = ">=67.0.0,<82" }, { name = "setuptools", marker = "extra == 'hpo'", specifier = ">=67.0.0,<82" }, - { name = "shap", marker = "extra == 'all'", specifier = ">=0.41.0" }, - { name = "shap", marker = "extra == 'interpretability'", specifier = ">=0.41.0" }, - { name = "shapiq", marker = "extra == 'all'", specifier = ">=0.4.0" }, + { name = "shap", marker = "extra == 'interpretability'", specifier = ">=0.47.0" }, { name = "shapiq", marker = "extra == 'interpretability'", specifier = ">=0.4.0" }, { name = "tabpfn", specifier = ">=6.0.5,<7" }, { name = "tabpfn-common-utils", extras = ["telemetry-interactive"], specifier = ">=0.2.0" }, + { name = "tabpfn-extensions", extras = ["classifier-as-regressor"], marker = "extra == 'all'" }, + { name = "tabpfn-extensions", extras = ["hpo"], marker = "extra == 'all'" }, + { name = "tabpfn-extensions", extras = ["interpretability"], marker = "extra == 'all'" }, + { name = "tabpfn-extensions", extras = ["many-class"], marker = "extra == 'all'" }, + { name = "tabpfn-extensions", extras = ["post-hoc-ensembles"], marker = "extra == 'all'" }, + { name = "tabpfn-extensions", extras = ["rf-pfn"], marker = "extra == 'all'" }, + { name = "tabpfn-extensions", extras = ["survival"], marker = "extra == 'all'" }, + { name = "tabpfn-extensions", extras = ["unsupervised"], marker = "extra == 'all'" }, { name = "torch", specifier = ">=2.1,<3" }, ] provides-extras = ["interpretability", "post-hoc-ensembles", "hpo", "survival", "many-class", "classifier-as-regressor", "rf-pfn", "unsupervised", "all"] [package.metadata.requires-dev] +ci = [ + { name = "licensecheck", specifier = ">=2025.1.0" }, + { name = "pytest", specifier = ">=8.0.0" }, + { name = "pytest-mock", specifier = ">=3.15.1" }, + { name = "pytest-xdist", specifier = ">=3.6.0" }, +] dev = [ { name = "build", specifier = ">=1.3.0" }, { name = "licensecheck", specifier = ">=2025.1.0" }, @@ -4864,10 +4877,17 @@ dependencies = [ { name = "typing-extensions", marker = "python_full_version >= '3.10'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/e3/ea/304cf7afb744aa626fa9855245526484ee55aba610d9973a0521c552a843/torch-2.10.0-1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:c37fc46eedd9175f9c81814cc47308f1b42cfe4987e532d4b423d23852f2bf63", size = 79411450, upload-time = "2026-02-06T17:37:35.75Z" }, - { url = "https://files.pythonhosted.org/packages/25/d8/9e6b8e7df981a1e3ea3907fd5a74673e791da483e8c307f0b6ff012626d0/torch-2.10.0-1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:f699f31a236a677b3118bc0a3ef3d89c0c29b5ec0b20f4c4bf0b110378487464", size = 79423460, upload-time = "2026-02-06T17:37:39.657Z" }, - { url = "https://files.pythonhosted.org/packages/c9/2f/0b295dd8d199ef71e6f176f576473d645d41357b7b8aa978cc6b042575df/torch-2.10.0-1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:6abb224c2b6e9e27b592a1c0015c33a504b00a0e0938f1499f7f514e9b7bfb5c", size = 79498197, upload-time = "2026-02-06T17:37:27.627Z" }, - { url = "https://files.pythonhosted.org/packages/a4/1b/af5fccb50c341bd69dc016769503cb0857c1423fbe9343410dfeb65240f2/torch-2.10.0-1-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:7350f6652dfd761f11f9ecb590bfe95b573e2961f7a242eccb3c8e78348d26fe", size = 79498248, upload-time = "2026-02-06T17:37:31.982Z" }, + { url = "https://files.pythonhosted.org/packages/5b/30/bfebdd8ec77db9a79775121789992d6b3b75ee5494971294d7b4b7c999bc/torch-2.10.0-2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:2b980edd8d7c0a68c4e951ee1856334a43193f98730d97408fbd148c1a933313", size = 79411457, upload-time = "2026-02-10T21:44:59.189Z" }, + { url = "https://files.pythonhosted.org/packages/0f/8b/4b61d6e13f7108f36910df9ab4b58fd389cc2520d54d81b88660804aad99/torch-2.10.0-2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:418997cb02d0a0f1497cf6a09f63166f9f5df9f3e16c8a716ab76a72127c714f", size = 79423467, upload-time = "2026-02-10T21:44:48.711Z" }, + { url = "https://files.pythonhosted.org/packages/d3/54/a2ba279afcca44bbd320d4e73675b282fcee3d81400ea1b53934efca6462/torch-2.10.0-2-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:13ec4add8c3faaed8d13e0574f5cd4a323c11655546f91fbe6afa77b57423574", size = 79498202, upload-time = "2026-02-10T21:44:52.603Z" }, + { url = "https://files.pythonhosted.org/packages/ec/23/2c9fe0c9c27f7f6cb865abcea8a4568f29f00acaeadfc6a37f6801f84cb4/torch-2.10.0-2-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:e521c9f030a3774ed770a9c011751fb47c4d12029a3d6522116e48431f2ff89e", size = 79498254, upload-time = "2026-02-10T21:44:44.095Z" }, + { url = "https://files.pythonhosted.org/packages/16/ee/efbd56687be60ef9af0c9c0ebe106964c07400eade5b0af8902a1d8cd58c/torch-2.10.0-3-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:a1ff626b884f8c4e897c4c33782bdacdff842a165fee79817b1dd549fdda1321", size = 915510070, upload-time = "2026-03-11T14:16:39.386Z" }, + { url = "https://files.pythonhosted.org/packages/36/ab/7b562f1808d3f65414cd80a4f7d4bb00979d9355616c034c171249e1a303/torch-2.10.0-3-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:ac5bdcbb074384c66fa160c15b1ead77839e3fe7ed117d667249afce0acabfac", size = 915518691, upload-time = "2026-03-11T14:15:43.147Z" }, + { url = "https://files.pythonhosted.org/packages/b3/7a/abada41517ce0011775f0f4eacc79659bc9bc6c361e6bfe6f7052a6b9363/torch-2.10.0-3-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:98c01b8bb5e3240426dcde1446eed6f40c778091c8544767ef1168fc663a05a6", size = 915622781, upload-time = "2026-03-11T14:17:11.354Z" }, + { url = "https://files.pythonhosted.org/packages/ab/c6/4dfe238342ffdcec5aef1c96c457548762d33c40b45a1ab7033bb26d2ff2/torch-2.10.0-3-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:80b1b5bfe38eb0e9f5ff09f206dcac0a87aadd084230d4a36eea5ec5232c115b", size = 915627275, upload-time = "2026-03-11T14:16:11.325Z" }, + { url = "https://files.pythonhosted.org/packages/d8/f0/72bf18847f58f877a6a8acf60614b14935e2f156d942483af1ffc081aea0/torch-2.10.0-3-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:46b3574d93a2a8134b3f5475cfb98e2eb46771794c57015f6ad1fb795ec25e49", size = 915523474, upload-time = "2026-03-11T14:17:44.422Z" }, + { url = "https://files.pythonhosted.org/packages/f4/39/590742415c3030551944edc2ddc273ea1fdfe8ffb2780992e824f1ebee98/torch-2.10.0-3-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:b1d5e2aba4eb7f8e87fbe04f86442887f9167a35f092afe4c237dfcaaef6e328", size = 915632474, upload-time = "2026-03-11T14:15:13.666Z" }, + { url = "https://files.pythonhosted.org/packages/b6/8e/34949484f764dde5b222b7fe3fede43e4a6f0da9d7f8c370bb617d629ee2/torch-2.10.0-3-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:0228d20b06701c05a8f978357f657817a4a63984b0c90745def81c18aedfa591", size = 915523882, upload-time = "2026-03-11T14:14:46.311Z" }, { url = "https://files.pythonhosted.org/packages/0c/1a/c61f36cfd446170ec27b3a4984f072fd06dab6b5d7ce27e11adb35d6c838/torch-2.10.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:5276fa790a666ee8becaffff8acb711922252521b28fbce5db7db5cf9cb2026d", size = 145992962, upload-time = "2026-01-21T16:24:14.04Z" }, { url = "https://files.pythonhosted.org/packages/b5/60/6662535354191e2d1555296045b63e4279e5a9dbad49acf55a5d38655a39/torch-2.10.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:aaf663927bcd490ae971469a624c322202a2a1e68936eb952535ca4cd3b90444", size = 915599237, upload-time = "2026-01-21T16:23:25.497Z" }, { url = "https://files.pythonhosted.org/packages/40/b8/66bbe96f0d79be2b5c697b2e0b187ed792a15c6c4b8904613454651db848/torch-2.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:a4be6a2a190b32ff5c8002a0977a25ea60e64f7ba46b1be37093c141d9c49aeb", size = 113720931, upload-time = "2026-01-21T16:24:23.743Z" },