diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 50f14bbb..07e45357 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -73,6 +73,11 @@ jobs: run: | nvidia-smi || (echo "❌ No GPU detected" && exit 1) python -c "import torch; assert torch.cuda.is_available(), '❌ GPU not found'; print('✅ Found GPU:', torch.cuda.get_device_name(0))" + - name: Cache model checkpoints + uses: actions/cache@v4 + with: + path: ~/.cache/torch/syntheseus + key: model-checkpoints-${{ hashFiles('syntheseus/reaction_prediction/inference/default_checkpoint_ids.yml') }} - name: Run single-step model tests run: | coverage run -p -m pytest \ diff --git a/pyproject.toml b/pyproject.toml index 8c07c86b..687806f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ viz = [ dev = [ "pytest", "pytest-cov", + "pytest-forked", "pytest-rerunfailures", "pre-commit" ] diff --git a/syntheseus/tests/cli/test_cli.py b/syntheseus/tests/cli/test_cli.py index ca65bad4..6c346e8b 100644 --- a/syntheseus/tests/cli/test_cli.py +++ b/syntheseus/tests/cli/test_cli.py @@ -1,6 +1,7 @@ import glob import json import math +import subprocess import sys import tempfile import urllib @@ -8,7 +9,6 @@ from pathlib import Path from typing import Generator, List -import omegaconf import pytest from syntheseus.reaction_prediction.inference.config import BackwardModelClass @@ -21,7 +21,7 @@ ) -MODEL_CLASSES_TO_TEST = set(BackwardModelClass) - {BackwardModelClass.GLN} +MODEL_CLASSES_TO_TEST = [m for m in BackwardModelClass if m is not BackwardModelClass.GLN] @pytest.fixture(scope="module") @@ -66,12 +66,13 @@ def search_cli_argv() -> List[str]: def run_cli_with_argv(argv: List[str]) -> None: - # The import below pulls in some optional dependencies, so do it locally to avoid executing it - # if the test suite is being skipped. - from syntheseus.cli.main import main + # Run in a subprocess so that model memory is fully reclaimed after each test. + result = subprocess.run( + [sys.executable, "-m", "syntheseus.cli.main"] + argv, capture_output=True, text=True + ) - sys.argv = ["syntheseus"] + argv - main() + if result.returncode != 0: + raise RuntimeError("CLI failed") def test_cli_invalid( @@ -103,7 +104,7 @@ def test_cli_invalid( ] for argv in argv_lists: - with pytest.raises((ValueError, omegaconf.errors.MissingMandatoryValue)): + with pytest.raises(RuntimeError): run_cli_with_argv(argv) diff --git a/syntheseus/tests/conftest.py b/syntheseus/tests/conftest.py index 538e74a9..ab556837 100644 --- a/syntheseus/tests/conftest.py +++ b/syntheseus/tests/conftest.py @@ -1,11 +1,17 @@ from __future__ import annotations +import os + import pytest from syntheseus.interface.bag import Bag from syntheseus.interface.molecule import Molecule from syntheseus.interface.reaction import SingleProductReaction +# Make `torch.cuda.is_available()` fork-safe so that calling it in the parent process does not +# poison `@pytest.mark.forked` tests (otherwise could run into CUDA reinitialization issues). +os.environ.setdefault("PYTORCH_NVML_BASED_CUDA_CHECK", "1") + @pytest.fixture def cocs_mol() -> Molecule: diff --git a/syntheseus/tests/reaction_prediction/inference/test_models.py b/syntheseus/tests/reaction_prediction/inference/test_models.py index 99d31128..41a4de1a 100644 --- a/syntheseus/tests/reaction_prediction/inference/test_models.py +++ b/syntheseus/tests/reaction_prediction/inference/test_models.py @@ -12,15 +12,16 @@ ) -MODEL_CLASSES_TO_TEST = set(BackwardModelClass) - {BackwardModelClass.GLN} +MODEL_CLASSES_TO_TEST = [m for m in BackwardModelClass if m is not BackwardModelClass.GLN] -@pytest.fixture(scope="module", params=list(MODEL_CLASSES_TO_TEST) * 2) +@pytest.fixture(params=MODEL_CLASSES_TO_TEST) def model(request) -> ExternalBackwardReactionModel: model_cls = request.param.value return model_cls() +@pytest.mark.forked def test_call(model: ExternalBackwardReactionModel) -> None: [result] = model([Molecule("Cc1ccc(-c2ccc(C)cc2)cc1")], num_results=20) model_predictions = [prediction.reactants for prediction in result] @@ -35,10 +36,10 @@ def test_call(model: ExternalBackwardReactionModel) -> None: # The model should recover at least two (out of six) in its top-20. assert len(set(expected_predictions) & set(model_predictions)) >= 2 - -def test_misc(model: ExternalBackwardReactionModel) -> None: import torch + # Additionally test some misc properties and methods. + assert isinstance(model.name, str) assert isinstance(model.get_model_info(), dict) assert model.is_backward() is not model.is_forward()