diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..57b0a53 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,46 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.11", "3.12", "3.13"] + steps: + - uses: actions/checkout@v4 + + - uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + + - name: Set up Python ${{ matrix.python-version }} + run: uv python install ${{ matrix.python-version }} + + - name: Install dependencies + run: uv sync --extra dev + + - name: Run tests + run: uv run pytest irtk/tests/ -x -q --timeout=60 + env: + JAX_PLATFORMS: cpu + + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + + - name: Build package + run: uv build + + - name: Check wheel contents + run: python3 -m zipfile -l dist/*.whl | grep -c "test_" | xargs -I{} test {} -eq 0 diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..86a9953 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,39 @@ +name: Publish to PyPI + +on: + release: + types: [published] + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + + - name: Build package + run: uv build + + - name: Check wheel contents + run: python3 -m zipfile -l dist/*.whl | grep -c "test_" | xargs -I{} test {} -eq 0 + + - uses: actions/upload-artifact@v4 + with: + name: dist + path: dist/ + + publish: + needs: build + runs-on: ubuntu-latest + permissions: + id-token: write + steps: + - uses: actions/download-artifact@v4 + with: + name: dist + path: dist/ + + - uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/README.md b/README.md index 21427a7..fc25418 100644 --- a/README.md +++ b/README.md @@ -10,12 +10,6 @@ Initially vibe-coded by Opus, so YMMV. PRs welcome. pip install irtk-jax ``` -For Apple Silicon GPU acceleration via [jax-mps](https://github.com/danielpcox/jax-mps): - -```bash -pip install irtk-jax[mps] -``` - ### Development setup ```bash diff --git a/irtk/tests/test_model_internal_consistency.py b/irtk/tests/test_model_internal_consistency.py index e8a5352..6198d7b 100644 --- a/irtk/tests/test_model_internal_consistency.py +++ b/irtk/tests/test_model_internal_consistency.py @@ -43,7 +43,7 @@ def test_logit_lens_values(model_and_tokens): model, tokens = model_and_tokens result = logit_lens_consistency(model, tokens, position=-1) for p in result["per_layer"]: - assert -1.0 <= p["logit_cosine"] <= 1.0 + assert -1.0 - 1e-6 <= p["logit_cosine"] <= 1.0 + 1e-6 assert result["per_layer"][-1]["agrees_with_final"] is True @@ -74,7 +74,7 @@ def test_orthogonality_structure(model_and_tokens): def test_orthogonality_values(model_and_tokens): model, tokens = model_and_tokens result = component_orthogonality(model, tokens, layer=0, position=-1) - assert -1.0 <= result["cosine"] <= 1.0 + assert -1.0 - 1e-6 <= result["cosine"] <= 1.0 + 1e-6 assert 0 <= result["orthogonality"] <= 1.0 diff --git a/pyproject.toml b/pyproject.toml index 7fe7896..a34ea55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,10 +35,6 @@ dev = [ "pytest>=7.0", "torch>=2.0", ] -mps = [ - "jax-mps", -] - [tool.hatch.build.targets.wheel] packages = ["irtk"] exclude = ["irtk/tests"] @@ -49,12 +45,9 @@ exclude = ["notebooks", "irtk/tests"] [tool.pytest.ini_options] testpaths = ["irtk/tests"] -[tool.uv.sources] -jax-mps = { path = "../jax-mps", editable = true } - [dependency-groups] dev = [ - "jax-mps", "jupyterlab>=4.5.5", + "pytest-timeout>=2.2.0", "pytest-xdist>=3.8.0", ] diff --git a/uv.lock b/uv.lock index cd942d5..8d3e612 100644 --- a/uv.lock +++ b/uv.lock @@ -820,6 +820,7 @@ mps = [ dev = [ { name = "jax-mps" }, { name = "jupyterlab" }, + { name = "pytest-timeout" }, { name = "pytest-xdist" }, ] @@ -842,6 +843,7 @@ provides-extras = ["dev", "mps"] dev = [ { name = "jax-mps", editable = "../jax-mps" }, { name = "jupyterlab", specifier = ">=4.5.5" }, + { name = "pytest-timeout", specifier = ">=2.2.0" }, { name = "pytest-xdist", specifier = ">=3.8.0" }, ] @@ -2099,6 +2101,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, ] +[[package]] +name = "pytest-timeout" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ac/82/4c9ecabab13363e72d880f2fb504c5f750433b2b6f16e99f4ec21ada284c/pytest_timeout-2.4.0.tar.gz", hash = "sha256:7e68e90b01f9eff71332b25001f85c75495fc4e3a836701876183c4bcfd0540a", size = 17973, upload-time = "2025-05-05T19:44:34.99Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl", hash = "sha256:c42667e5cdadb151aeb5b26d114aff6bdf5a907f176a007a30b940d3d865b5c2", size = 14382, upload-time = "2025-05-05T19:44:33.502Z" }, +] + [[package]] name = "pytest-xdist" version = "3.8.0"