diff --git a/ci/tests/test_geneformer/test_geneformer_model.py b/ci/tests/test_geneformer/test_geneformer_model.py index a325e853..7b846f06 100644 --- a/ci/tests/test_geneformer/test_geneformer_model.py +++ b/ci/tests/test_geneformer/test_geneformer_model.py @@ -288,3 +288,18 @@ def test_layer_to_quant(self, model_name, emb_layer): geneformer = Geneformer(config) assert geneformer.layer_to_quant == emb_layer + + @pytest.mark.parametrize( + "old_model_name,new_model_name", + [ + ("gf-6L-30M-i2048", "gf-6L-10M-i2048"), + ("gf-12L-30M-i2048", "gf-12L-40M-i2048"), + ("gf-12L-95M-i4096", "gf-12L-38M-i4096"), + ("gf-12L-95M-i4096-CLcancer", "gf-12L-38M-i4096-CLcancer"), + ("gf-20L-95M-i4096", "gf-20L-151M-i4096"), + ], + ) + def test_model_name_mapping(self, old_model_name, new_model_name): + config = GeneformerConfig(model_name=old_model_name) + + assert config.config["model_name"] == new_model_name diff --git a/helical/models/geneformer/geneformer_config.py b/helical/models/geneformer/geneformer_config.py index b040bf60..ec3c4f53 100644 --- a/helical/models/geneformer/geneformer_config.py +++ b/helical/models/geneformer/geneformer_config.py @@ -2,6 +2,9 @@ from pathlib import Path from helical.constants.paths import CACHE_DIR_HELICAL from typing import Literal +import logging + +LOGGER = logging.getLogger(__name__) class GeneformerConfig: @@ -48,6 +51,7 @@ class GeneformerConfig: def __init__( self, model_name: Literal[ + # new and renamed models "gf-6L-10M-i2048", "gf-12L-38M-i4096", "gf-12L-38M-i4096-CLcancer", @@ -57,6 +61,12 @@ def __init__( "gf-20L-151M-i4096", "gf-18L-316M-i4096", "gf-12L-40M-i2048-CZI-CellxGene", + # old models + "gf-6L-30M-i2048", + "gf-12L-30M-i2048", + "gf-12L-95M-i4096", + "gf-12L-95M-i4096-CLcancer", + "gf-20L-95M-i4096", ] = "gf-12L-38M-i4096", batch_size: int = 24, emb_layer: int = -1, @@ -66,6 +76,23 @@ def __init__( custom_attr_name_dict: Optional[dict] = None, ): + old_model_to_new_model_map = { + "gf-6L-30M-i2048": "gf-6L-10M-i2048", + "gf-12L-30M-i2048": "gf-12L-40M-i2048", + "gf-12L-95M-i4096": "gf-12L-38M-i4096", + "gf-12L-95M-i4096-CLcancer": "gf-12L-38M-i4096-CLcancer", + "gf-20L-95M-i4096": "gf-20L-151M-i4096", + } + + if model_name in old_model_to_new_model_map: + message = ( + f"Setting model to {old_model_to_new_model_map[model_name]}. Model name {model_name} is deprecated. " + "Please use the new name going forward to avoid code breakages." + "Geneformer models have been renamed to better reflect their size." + ) + LOGGER.warning(message) + model_name = old_model_to_new_model_map[model_name] + # model specific parameters self.model_map = { "gf-12L-38M-i4096": { @@ -144,7 +171,8 @@ def __init__( # Add model weight files to download based on the model version (v1 or v2) if ( - self.model_map[model_name]["model_version"] != "v1" or model_name == "gf-12L-40M-i2048-CZI-CellxGene" + self.model_map[model_name]["model_version"] != "v1" + or model_name == "gf-12L-40M-i2048-CZI-CellxGene" ): self.list_of_files_to_download.append( f"geneformer/{self.model_map[model_name]['model_version']}/{model_name}/model.safetensors" diff --git a/pyproject.toml b/pyproject.toml index 646ac438..6b62095e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "helical" -version = "1.4.0" +version = "1.4.1" authors = [ { name="Helical Team", email="support@helical-ai.com" }, ]