Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ jobs:
run: |
nvidia-smi || (echo "❌ No GPU detected" && exit 1)
python -c "import torch; assert torch.cuda.is_available(), '❌ GPU not found'; print('✅ Found GPU:', torch.cuda.get_device_name(0))"
- name: Cache model checkpoints
uses: actions/cache@v4
with:
path: ~/.cache/torch/syntheseus
key: model-checkpoints-${{ hashFiles('syntheseus/reaction_prediction/inference/default_checkpoint_ids.yml') }}
- name: Run single-step model tests
run: |
coverage run -p -m pytest \
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ viz = [
dev = [
"pytest",
"pytest-cov",
"pytest-forked",
"pytest-rerunfailures",
"pre-commit"
]
Expand Down
17 changes: 9 additions & 8 deletions syntheseus/tests/cli/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import glob
import json
import math
import subprocess
import sys
import tempfile
import urllib
import zipfile
from pathlib import Path
from typing import Generator, List

import omegaconf
import pytest

from syntheseus.reaction_prediction.inference.config import BackwardModelClass
Expand All @@ -21,7 +21,7 @@
)


MODEL_CLASSES_TO_TEST = set(BackwardModelClass) - {BackwardModelClass.GLN}
MODEL_CLASSES_TO_TEST = [m for m in BackwardModelClass if m is not BackwardModelClass.GLN]


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -66,12 +66,13 @@ def search_cli_argv() -> List[str]:


def run_cli_with_argv(argv: List[str]) -> None:
# The import below pulls in some optional dependencies, so do it locally to avoid executing it
# if the test suite is being skipped.
from syntheseus.cli.main import main
# Run in a subprocess so that model memory is fully reclaimed after each test.
result = subprocess.run(
[sys.executable, "-m", "syntheseus.cli.main"] + argv, capture_output=True, text=True
)

sys.argv = ["syntheseus"] + argv
main()
if result.returncode != 0:
raise RuntimeError("CLI failed")


def test_cli_invalid(
Expand Down Expand Up @@ -103,7 +104,7 @@ def test_cli_invalid(
]

for argv in argv_lists:
with pytest.raises((ValueError, omegaconf.errors.MissingMandatoryValue)):
with pytest.raises(RuntimeError):
run_cli_with_argv(argv)


Expand Down
6 changes: 6 additions & 0 deletions syntheseus/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from __future__ import annotations

import os

import pytest

from syntheseus.interface.bag import Bag
from syntheseus.interface.molecule import Molecule
from syntheseus.interface.reaction import SingleProductReaction

# Make `torch.cuda.is_available()` fork-safe so that calling it in the parent process does not
# poison `@pytest.mark.forked` tests (otherwise could run into CUDA reinitialization issues).
os.environ.setdefault("PYTORCH_NVML_BASED_CUDA_CHECK", "1")


@pytest.fixture
def cocs_mol() -> Molecule:
Expand Down
9 changes: 5 additions & 4 deletions syntheseus/tests/reaction_prediction/inference/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@
)


MODEL_CLASSES_TO_TEST = set(BackwardModelClass) - {BackwardModelClass.GLN}
MODEL_CLASSES_TO_TEST = [m for m in BackwardModelClass if m is not BackwardModelClass.GLN]


@pytest.fixture(scope="module", params=list(MODEL_CLASSES_TO_TEST) * 2)
@pytest.fixture(params=MODEL_CLASSES_TO_TEST)
def model(request) -> ExternalBackwardReactionModel:
model_cls = request.param.value
return model_cls()


@pytest.mark.forked
def test_call(model: ExternalBackwardReactionModel) -> None:
[result] = model([Molecule("Cc1ccc(-c2ccc(C)cc2)cc1")], num_results=20)
model_predictions = [prediction.reactants for prediction in result]
Expand All @@ -35,10 +36,10 @@ def test_call(model: ExternalBackwardReactionModel) -> None:
# The model should recover at least two (out of six) in its top-20.
assert len(set(expected_predictions) & set(model_predictions)) >= 2


def test_misc(model: ExternalBackwardReactionModel) -> None:
import torch

# Additionally test some misc properties and methods.

assert isinstance(model.name, str)
assert isinstance(model.get_model_info(), dict)
assert model.is_backward() is not model.is_forward()
Expand Down
Loading