Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
name: CI

on:
push:
branches: [main]
pull_request:
branches: [main]

jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11", "3.12", "3.13"]
steps:
- uses: actions/checkout@v4

- uses: astral-sh/setup-uv@v5
with:
enable-cache: true

- name: Set up Python ${{ matrix.python-version }}
run: uv python install ${{ matrix.python-version }}

- name: Install dependencies
run: uv sync --extra dev

- name: Run tests
run: uv run pytest irtk/tests/ -x -q --timeout=60
env:
JAX_PLATFORMS: cpu

build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- uses: astral-sh/setup-uv@v5
with:
enable-cache: true

- name: Build package
run: uv build

- name: Check wheel contents
run: python3 -m zipfile -l dist/*.whl | grep -c "test_" | xargs -I{} test {} -eq 0
39 changes: 39 additions & 0 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
name: Publish to PyPI

on:
release:
types: [published]

jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- uses: astral-sh/setup-uv@v5
with:
enable-cache: true

- name: Build package
run: uv build

- name: Check wheel contents
run: python3 -m zipfile -l dist/*.whl | grep -c "test_" | xargs -I{} test {} -eq 0

- uses: actions/upload-artifact@v4
with:
name: dist
path: dist/

publish:
needs: build
runs-on: ubuntu-latest
permissions:
id-token: write
steps:
- uses: actions/download-artifact@v4
with:
name: dist
path: dist/

- uses: pypa/gh-action-pypi-publish@release/v1
6 changes: 0 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,6 @@ Initially vibe-coded by Opus, so YMMV. PRs welcome.
pip install irtk-jax
```

For Apple Silicon GPU acceleration via [jax-mps](https://github.com/danielpcox/jax-mps):

```bash
pip install irtk-jax[mps]
```

### Development setup

```bash
Expand Down
4 changes: 2 additions & 2 deletions irtk/tests/test_model_internal_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_logit_lens_values(model_and_tokens):
model, tokens = model_and_tokens
result = logit_lens_consistency(model, tokens, position=-1)
for p in result["per_layer"]:
assert -1.0 <= p["logit_cosine"] <= 1.0
assert -1.0 - 1e-6 <= p["logit_cosine"] <= 1.0 + 1e-6
assert result["per_layer"][-1]["agrees_with_final"] is True


Expand Down Expand Up @@ -74,7 +74,7 @@ def test_orthogonality_structure(model_and_tokens):
def test_orthogonality_values(model_and_tokens):
model, tokens = model_and_tokens
result = component_orthogonality(model, tokens, layer=0, position=-1)
assert -1.0 <= result["cosine"] <= 1.0
assert -1.0 - 1e-6 <= result["cosine"] <= 1.0 + 1e-6
assert 0 <= result["orthogonality"] <= 1.0


Expand Down
9 changes: 1 addition & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@ dev = [
"pytest>=7.0",
"torch>=2.0",
]
mps = [
"jax-mps",
]

[tool.hatch.build.targets.wheel]
packages = ["irtk"]
exclude = ["irtk/tests"]
Expand All @@ -49,12 +45,9 @@ exclude = ["notebooks", "irtk/tests"]
[tool.pytest.ini_options]
testpaths = ["irtk/tests"]

[tool.uv.sources]
jax-mps = { path = "../jax-mps", editable = true }

[dependency-groups]
dev = [
"jax-mps",
"jupyterlab>=4.5.5",
"pytest-timeout>=2.2.0",
"pytest-xdist>=3.8.0",
]
14 changes: 14 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading