From 5cb0b66103dd8e8cb94f8f35b53f8cbca5ee8d28 Mon Sep 17 00:00:00 2001 From: Matthew Wood <62712722+mattwoodx@users.noreply.github.com> Date: Mon, 11 Aug 2025 14:55:32 +0200 Subject: [PATCH] Fix deprecation warning geneformer (#260) * Revert "Hot fix remove notebook from release run" This reverts commit 52bba026f1cb6149cf7e6a7f57f1ee261a6c4769. * Add aliasing and a deprecation message for old Geneformer models * Update helical version for release * fixup! Update helical version for release * fixup! Add aliasing and a deprecation message for old Geneformer models * Caplog not working --- .../test_geneformer/test_geneformer_model.py | 15 ++++++++++ .../models/geneformer/geneformer_config.py | 30 ++++++++++++++++++- pyproject.toml | 2 +- 3 files changed, 45 insertions(+), 2 deletions(-) diff --git a/ci/tests/test_geneformer/test_geneformer_model.py b/ci/tests/test_geneformer/test_geneformer_model.py index a325e853..7b846f06 100644 --- a/ci/tests/test_geneformer/test_geneformer_model.py +++ b/ci/tests/test_geneformer/test_geneformer_model.py @@ -288,3 +288,18 @@ def test_layer_to_quant(self, model_name, emb_layer): geneformer = Geneformer(config) assert geneformer.layer_to_quant == emb_layer + + @pytest.mark.parametrize( + "old_model_name,new_model_name", + [ + ("gf-6L-30M-i2048", "gf-6L-10M-i2048"), + ("gf-12L-30M-i2048", "gf-12L-40M-i2048"), + ("gf-12L-95M-i4096", "gf-12L-38M-i4096"), + ("gf-12L-95M-i4096-CLcancer", "gf-12L-38M-i4096-CLcancer"), + ("gf-20L-95M-i4096", "gf-20L-151M-i4096"), + ], + ) + def test_model_name_mapping(self, old_model_name, new_model_name): + config = GeneformerConfig(model_name=old_model_name) + + assert config.config["model_name"] == new_model_name diff --git a/helical/models/geneformer/geneformer_config.py b/helical/models/geneformer/geneformer_config.py index b040bf60..ec3c4f53 100644 --- a/helical/models/geneformer/geneformer_config.py +++ b/helical/models/geneformer/geneformer_config.py @@ -2,6 +2,9 @@ from pathlib import Path from helical.constants.paths import CACHE_DIR_HELICAL from typing import Literal +import logging + +LOGGER = logging.getLogger(__name__) class GeneformerConfig: @@ -48,6 +51,7 @@ class GeneformerConfig: def __init__( self, model_name: Literal[ + # new and renamed models "gf-6L-10M-i2048", "gf-12L-38M-i4096", "gf-12L-38M-i4096-CLcancer", @@ -57,6 +61,12 @@ def __init__( "gf-20L-151M-i4096", "gf-18L-316M-i4096", "gf-12L-40M-i2048-CZI-CellxGene", + # old models + "gf-6L-30M-i2048", + "gf-12L-30M-i2048", + "gf-12L-95M-i4096", + "gf-12L-95M-i4096-CLcancer", + "gf-20L-95M-i4096", ] = "gf-12L-38M-i4096", batch_size: int = 24, emb_layer: int = -1, @@ -66,6 +76,23 @@ def __init__( custom_attr_name_dict: Optional[dict] = None, ): + old_model_to_new_model_map = { + "gf-6L-30M-i2048": "gf-6L-10M-i2048", + "gf-12L-30M-i2048": "gf-12L-40M-i2048", + "gf-12L-95M-i4096": "gf-12L-38M-i4096", + "gf-12L-95M-i4096-CLcancer": "gf-12L-38M-i4096-CLcancer", + "gf-20L-95M-i4096": "gf-20L-151M-i4096", + } + + if model_name in old_model_to_new_model_map: + message = ( + f"Setting model to {old_model_to_new_model_map[model_name]}. Model name {model_name} is deprecated. " + "Please use the new name going forward to avoid code breakages." + "Geneformer models have been renamed to better reflect their size." + ) + LOGGER.warning(message) + model_name = old_model_to_new_model_map[model_name] + # model specific parameters self.model_map = { "gf-12L-38M-i4096": { @@ -144,7 +171,8 @@ def __init__( # Add model weight files to download based on the model version (v1 or v2) if ( - self.model_map[model_name]["model_version"] != "v1" or model_name == "gf-12L-40M-i2048-CZI-CellxGene" + self.model_map[model_name]["model_version"] != "v1" + or model_name == "gf-12L-40M-i2048-CZI-CellxGene" ): self.list_of_files_to_download.append( f"geneformer/{self.model_map[model_name]['model_version']}/{model_name}/model.safetensors" diff --git a/pyproject.toml b/pyproject.toml index 646ac438..6b62095e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "helical" -version = "1.4.0" +version = "1.4.1" authors = [ { name="Helical Team", email="support@helical-ai.com" }, ]