Merged
16 changes: 12 additions & 4 deletions .github/workflows/test-hf-models.yml
@@ -20,6 +20,8 @@ on:
 jobs:
   test-models:
     runs-on: ubuntu-latest
+    env:
+      WANDB_MODE: disabled
     strategy:
       fail-fast: false
       matrix:
@@ -48,15 +50,21 @@ jobs:
       - name: Install project dependencies
         run: uv sync --all-extras --dev

+      - name: Free package download cache
+        run: uv cache clean
+
       - name: Run tests
         env:
-          WANDB_MODE: disabled
+          HF_HOME: ${{ runner.temp }}/huggingface
+          HF_HUB_CACHE: ${{ runner.temp }}/huggingface/hub
+          HF_HUB_DISABLE_PROGRESS_BARS: "1"
+          HF_HUB_DISABLE_XET: "1"
+          HF_MODULES_CACHE: ${{ runner.temp }}/huggingface/modules
         run: |
-          uv run pytest opt/package/test_hf_org.py::${{ matrix.test-class }} \
+          uv run --no-sync pytest opt/package/test_hf_org.py::${{ matrix.test-class }} \
             -v \
             --log-cli-level=INFO \
-            --durations=10 \
-            --maxfail=5
+            --durations=10

       - name: Upload test results
         if: always()
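WANDB_MODE moves up to the job level, the new HF_* variables route every HuggingFace download and remote-code module into the runner's temp directory (wiped between jobs), and HF_HUB_DISABLE_XET falls back to plain HTTP downloads. `uv run --no-sync` keeps pytest from re-syncing the environment whose download cache the `uv cache clean` step just freed. A quick local sanity check of the cache redirection, as a minimal sketch (it assumes only that huggingface_hub resolves its cache paths from the environment at import time):

import os
import tempfile

# Point the HF caches at a throwaway directory, standing in for $RUNNER_TEMP.
tmp = tempfile.mkdtemp(prefix="hf-cache-check-")
os.environ["HF_HOME"] = tmp
os.environ["HF_HUB_CACHE"] = os.path.join(tmp, "hub")

# Import after setting the variables: huggingface_hub reads them at import time.
from huggingface_hub import constants

assert constants.HF_HUB_CACHE == os.path.join(tmp, "hub")
print("downloads will land in", constants.HF_HUB_CACHE)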
219 changes: 138 additions & 81 deletions opt/package/test_hf_org.py
@@ -1,11 +1,17 @@
"""Tests for models in the mist-models HuggingFace organization."""

from contextlib import contextmanager
from functools import lru_cache
import gc
import logging
import os
from pathlib import Path
import tempfile
from typing import Iterator

import pytest
from huggingface_hub import HfApi
from transformers import AutoModel
from transformers import AutoConfig, AutoModel

from .test_inference import (
single_molecule_smiles,
@@ -17,6 +23,101 @@
 )

 logger = logging.getLogger(__name__)
+HF_ORG = "mist-models"
+
+
+@contextmanager
+def loaded_hf_model(model_id: str, hf_token: str | None) -> Iterator[object]:
+    """Load one HF model in an isolated cache that is removed after the check."""
+    with tempfile.TemporaryDirectory(
+        prefix="hf-model-cache-", dir=os.getenv("RUNNER_TEMP") or None
+    ) as cache_dir:
+        model = AutoModel.from_pretrained(
+            model_id,
+            trust_remote_code=True,
+            token=hf_token,
+            cache_dir=cache_dir,
+        )
+        try:
+            yield model
+        finally:
+            del model
+            gc.collect()
+
+
+def load_hf_config(model_id: str, hf_token: str | None):
+    with tempfile.TemporaryDirectory(
+        prefix="hf-config-cache-", dir=os.getenv("RUNNER_TEMP") or None
+    ) as cache_dir:
+        return AutoConfig.from_pretrained(
+            model_id,
+            trust_remote_code=True,
+            token=hf_token,
+            cache_dir=cache_dir,
+        )
+
+
+@lru_cache
+def list_hf_org_model_ids(hf_token: str | None) -> tuple[str, ...]:
+    api = HfApi(token=hf_token)
+    model_ids = tuple(m.id for m in api.list_models(author=HF_ORG))
+    logger.info("Found %d models in %s organization", len(model_ids), HF_ORG)
+    return model_ids
+
+
+def parametrize_model_ids(metafunc, fixture_name: str, model_ids: tuple[str, ...]):
+    if model_ids:
+        metafunc.parametrize(
+            fixture_name, model_ids, ids=lambda m: m.rsplit("/", 1)[-1]
+        )
+        return
+
+    metafunc.parametrize(
+        fixture_name,
+        [
+            pytest.param(
+                None,
+                marks=pytest.mark.skip(reason=f"No {fixture_name} models found"),
+            )
+        ],
+        ids=["no-models"],
+    )
+
+
+def pytest_generate_tests(metafunc):
+    model_ids = list_hf_org_model_ids(os.getenv("HF_TOKEN"))
+
+    if "hf_model_id" in metafunc.fixturenames:
+        parametrize_model_ids(metafunc, "hf_model_id", model_ids)
+
+    if "single_model_id" in metafunc.fixturenames:
+        parametrize_model_ids(
+            metafunc,
+            "single_model_id",
+            tuple(
+                m for m in model_ids if get_model_type_from_path(Path(m)) == "single"
+            ),
+        )
+
+    if "conductivity_model_id" in metafunc.fixturenames:
+        parametrize_model_ids(
+            metafunc,
+            "conductivity_model_id",
+            tuple(
+                m
+                for m in model_ids
+                if get_model_type_from_path(Path(m)) == "conductivity"
+            ),
+        )
+
+    if "excess_model_id" in metafunc.fixturenames:
+        parametrize_model_ids(
+            metafunc,
+            "excess_model_id",
+            tuple(
+                m for m in model_ids if get_model_type_from_path(Path(m)) == "excess"
+            ),
+        )


 @pytest.fixture
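pytest_generate_tests replaces the old hf_org_models fixture loop: each discovered model becomes its own parametrized test case, so one broken model no longer aborts the checks for the rest, which is presumably why --maxfail=5 could be dropped from the workflow. A minimal self-contained sketch of the same pattern, with a hypothetical hard-coded model list standing in for the HfApi call:

import pytest

# Hypothetical stand-ins for the ids returned by HfApi.list_models.
FAKE_MODEL_IDS = ("mist-models/example-a", "mist-models/example-b")


def pytest_generate_tests(metafunc):
    if "hf_model_id" in metafunc.fixturenames:
        # One collected test per model, named by the repo's short name.
        metafunc.parametrize(
            "hf_model_id", FAKE_MODEL_IDS, ids=lambda m: m.rsplit("/", 1)[-1]
        )


def test_model_id_shape(hf_model_id):
    org, name = hf_model_id.split("/")
    assert org == "mist-models" and name

Run with pytest -v, this collects test_model_id_shape[example-a] and test_model_id_shape[example-b] as separate items that pass or fail independently.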
@@ -26,36 +127,14 @@ def hf_token():
     return os.getenv("HF_TOKEN")


-@pytest.fixture
-def hf_org_models(hf_token):
-    api = HfApi(token=hf_token)
-    models = list(api.list_models(author="mist-models"))
-
-    if not models:
-        pytest.skip("No models found in mist-models organization")
-
-    model_ids = [m.id for m in models]
-    logger.info(f"Found {len(model_ids)} models in mist-models organization")
-    return model_ids
-
-
 class TestHFOrgSingleMoleculeModels:
     def test_predict_single_molecules(
-        self, hf_org_models, hf_token, single_molecule_smiles
+        self, single_model_id, hf_token, single_molecule_smiles
     ):
-        single_models = [
-            m for m in hf_org_models if get_model_type_from_path(Path(m)) == "single"
-        ]
-
-        for model_id in single_models:
-            logger.info(f"Testing {model_id}")
-            model = AutoModel.from_pretrained(
-                model_id, trust_remote_code=True, token=hf_token
-            )
-
+        logger.info(f"Testing {single_model_id}")
+        with loaded_hf_model(single_model_id, hf_token) as model:
             if "RobertaPreLayerNormModel" in type(model).__name__:
-                logger.info("Skipping encoder-only model")
-                continue
+                pytest.skip("Skipping encoder-only model")

             predictions = model.predict(single_molecule_smiles)
             assert predictions is not None
@@ -66,48 +145,35 @@ def test_predict_single_molecules(
                     if isinstance(task_data, dict) and "value" in task_data:
                         values = task_data["value"]
                         assert len(values) == len(single_molecule_smiles)
-                        validate_predictions(values, name=f"{model_id}:{task_name}")
+                        validate_predictions(
+                            values, name=f"{single_model_id}:{task_name}"
+                        )
             else:
                 assert len(predictions) == len(single_molecule_smiles)
-                validate_predictions(predictions, name=model_id)
+                validate_predictions(predictions, name=single_model_id)


 class TestHFOrgConductivityModels:
-    def test_predict_mixtures(self, hf_org_models, hf_token, conductivity_test_data):
-        cond_models = [
-            m
-            for m in hf_org_models
-            if get_model_type_from_path(Path(m)) == "conductivity"
-        ]
-        for model_id in cond_models:
-            logger.info(f"Testing {model_id}")
-            model = AutoModel.from_pretrained(
-                model_id, trust_remote_code=True, token=hf_token
-            )
-
+    def test_predict_mixtures(
+        self, conductivity_model_id, hf_token, conductivity_test_data
+    ):
+        logger.info(f"Testing {conductivity_model_id}")
+        with loaded_hf_model(conductivity_model_id, hf_token) as model:
             predictions = model.predict(conductivity_test_data)
             assert predictions is not None

             if isinstance(predictions, dict):
                 for key, value in predictions.items():
-                    validate_predictions(value, name=f"{model_id}:{key}")
+                    validate_predictions(value, name=f"{conductivity_model_id}:{key}")
             else:
-                validate_predictions(predictions, name=model_id)
+                validate_predictions(predictions, name=conductivity_model_id)


 class TestHFOrgExcessPhysicsModels:
-    def test_predict_binary_mixture(self, hf_org_models, hf_token, excess_test_data):
-        excess_models = [
-            m for m in hf_org_models if get_model_type_from_path(Path(m)) == "excess"
-        ]
+    def test_predict_binary_mixture(self, excess_model_id, hf_token, excess_test_data):
         test_case = excess_test_data[0]

-        for model_id in excess_models:
-            logger.info(f"Testing {model_id}")
-            model = AutoModel.from_pretrained(
-                model_id, trust_remote_code=True, token=hf_token
-            )
-
+        logger.info(f"Testing {excess_model_id}")
+        with loaded_hf_model(excess_model_id, hf_token) as model:
             predictions = model.predict(
                 smiles_list=test_case["smiles_list"],
                 composition=test_case["composition"],
@@ -118,41 +184,32 @@ def test_predict_binary_mixture(self, hf_org_models, hf_token, excess_test_data)

             if isinstance(predictions, dict):
                 for key, value in predictions.items():
-                    validate_predictions(value, name=f"{model_id}:{key}")
+                    validate_predictions(value, name=f"{excess_model_id}:{key}")


 class TestHFOrgModelIntegrity:
-    def test_all_models_config(self, hf_org_models, hf_token):
-        for model_id in hf_org_models:
-            logger.info(f"Checking config for {model_id}")
-            model = AutoModel.from_pretrained(
-                model_id, trust_remote_code=True, token=hf_token
-            )
-            assert hasattr(model, "config")
-            assert model.config is not None
+    def test_all_models_config(self, hf_model_id, hf_token):
+        logger.info(f"Checking config for {hf_model_id}")
+        config = load_hf_config(hf_model_id, hf_token)
+        assert config is not None

-    def test_models_required_files(self, hf_org_models, hf_token):
+    def test_models_required_files(self, hf_model_id, hf_token):
         api = HfApi(token=hf_token)
+        model_info = api.model_info(hf_model_id)
+        siblings = {f.rfilename for f in model_info.siblings}

-        for model_id in hf_org_models:
-            model_info = api.model_info(model_id)
-            siblings = {f.rfilename for f in model_info.siblings}
+        assert "config.json" in siblings, f"{hf_model_id} missing config.json"
+        assert "README.md" in siblings, f"{hf_model_id} missing README.md"

-            assert "config.json" in siblings, f"{model_id} missing config.json"
-            assert "README.md" in siblings, f"{model_id} missing README.md"
+        has_weights = any(
+            "safetensors" in f or "pytorch_model.bin" in f for f in siblings
+        )
+        assert has_weights, f"{hf_model_id} missing model weights"

-            has_weights = any(
-                "safetensors" in f or "pytorch_model.bin" in f for f in siblings
-            )
-            assert has_weights, f"{model_id} missing model weights"
-
-    def test_multi_channel_model_labels(self, hf_org_models, hf_token):
-        for model_id in hf_org_models:
-            logger.info(f"Checking channels for {model_id}")
-            model = AutoModel.from_pretrained(
-                model_id, trust_remote_code=True, token=hf_token
-            )
-            check_multi_channel_labels(model, model_id)
+    def test_multi_channel_model_labels(self, hf_model_id, hf_token):
+        logger.info(f"Checking channels for {hf_model_id}")
+        config = load_hf_config(hf_model_id, hf_token)
+        check_multi_channel_labels(config, hf_model_id)


 if __name__ == "__main__":
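Outside of pytest, the new loaded_hf_model helper is also handy for one-off debugging, since each model lands in a throwaway cache and is released when the block exits. A minimal sketch of standalone use (it assumes the package is importable as opt.package and uses a hypothetical model id):

import os

from opt.package.test_hf_org import loaded_hf_model

# "mist-models/example" is a hypothetical id; HF_TOKEN is read from the
# environment exactly as the tests do.
with loaded_hf_model("mist-models/example", os.getenv("HF_TOKEN")) as model:
    print(model.predict(["CCO"]))  # a single ethanol SMILES

# On exit the TemporaryDirectory is deleted and gc.collect() runs, so disk
# and memory stay bounded while iterating over many models.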
9 changes: 4 additions & 5 deletions opt/package/test_inference.py
@@ -115,15 +115,14 @@ def get_model_type_from_path(path: Path) -> str:
return "single"


def check_multi_channel_labels(model, model_name: str):
def check_multi_channel_labels(model_or_config, model_name: str):
"""Verify that multi-output models have channel labels."""
if "RobertaPreLayerNormModel" in type(model).__name__:
type_name = type(model_or_config).__name__
if "RobertaPreLayerNorm" in type_name:
return

if not hasattr(model, "config"):
return
config = getattr(model_or_config, "config", model_or_config)

config = model.config
if not hasattr(config, "task_network"):
return

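The reworked check_multi_channel_labels accepts either a loaded model or a bare config, which is what lets the integrity tests above validate label metadata from AutoConfig alone, without downloading any weights. The hinge is the getattr fallback; a minimal sketch of that pattern with hypothetical stand-in classes:

# getattr(x, "config", x) returns x.config when the attribute exists and x
# itself otherwise. DummyConfig and DummyModel are hypothetical stand-ins.
class DummyConfig:
    task_network = {"channel_labels": ["density", "viscosity"]}


class DummyModel:
    def __init__(self):
        self.config = DummyConfig()


def resolve_config(model_or_config):
    # Models expose .config; bare configs are returned unchanged.
    return getattr(model_or_config, "config", model_or_config)


cfg = DummyConfig()
assert resolve_config(cfg) is cfg
assert isinstance(resolve_config(DummyModel()), DummyConfig)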