Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions ci/tests/test_geneformer/test_geneformer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,3 +288,18 @@ def test_layer_to_quant(self, model_name, emb_layer):
geneformer = Geneformer(config)

assert geneformer.layer_to_quant == emb_layer

@pytest.mark.parametrize(
    "old_model_name,new_model_name",
    [
        ("gf-6L-30M-i2048", "gf-6L-10M-i2048"),
        ("gf-12L-30M-i2048", "gf-12L-40M-i2048"),
        ("gf-12L-95M-i4096", "gf-12L-38M-i4096"),
        ("gf-12L-95M-i4096-CLcancer", "gf-12L-38M-i4096-CLcancer"),
        ("gf-20L-95M-i4096", "gf-20L-151M-i4096"),
    ],
)
def test_model_name_mapping(self, old_model_name, new_model_name):
    """Check that a deprecated model name is stored under its new name.

    Builds a ``GeneformerConfig`` from the old name and compares the
    ``"model_name"`` entry of its ``config`` dict with the expected
    replacement name.
    """
    resolved_name = GeneformerConfig(model_name=old_model_name).config["model_name"]

    assert resolved_name == new_model_name
30 changes: 29 additions & 1 deletion helical/models/geneformer/geneformer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from pathlib import Path
from helical.constants.paths import CACHE_DIR_HELICAL
from typing import Literal
import logging

LOGGER = logging.getLogger(__name__)


class GeneformerConfig:
Expand Down Expand Up @@ -48,6 +51,7 @@ class GeneformerConfig:
def __init__(
self,
model_name: Literal[
# new and renamed models
"gf-6L-10M-i2048",
"gf-12L-38M-i4096",
"gf-12L-38M-i4096-CLcancer",
Expand All @@ -57,6 +61,12 @@ def __init__(
"gf-20L-151M-i4096",
"gf-18L-316M-i4096",
"gf-12L-40M-i2048-CZI-CellxGene",
# old models
"gf-6L-30M-i2048",
"gf-12L-30M-i2048",
"gf-12L-95M-i4096",
"gf-12L-95M-i4096-CLcancer",
"gf-20L-95M-i4096",
] = "gf-12L-38M-i4096",
batch_size: int = 24,
emb_layer: int = -1,
Expand All @@ -66,6 +76,23 @@ def __init__(
custom_attr_name_dict: Optional[dict] = None,
):

# Deprecated model names mapped to the names that replace them.
old_model_to_new_model_map = {
    "gf-6L-30M-i2048": "gf-6L-10M-i2048",
    "gf-12L-30M-i2048": "gf-12L-40M-i2048",
    "gf-12L-95M-i4096": "gf-12L-38M-i4096",
    "gf-12L-95M-i4096-CLcancer": "gf-12L-38M-i4096-CLcancer",
    "gf-20L-95M-i4096": "gf-20L-151M-i4096",
}

# If a deprecated name was passed, warn the caller and continue with the
# replacement name so the rest of the configuration uses the new name.
if model_name in old_model_to_new_model_map:
    new_model_name = old_model_to_new_model_map[model_name]
    # Adjacent string literals are concatenated verbatim, so every sentence
    # except the last must end with an explicit trailing space.
    message = (
        f"Setting model to {new_model_name}. Model name {model_name} is deprecated. "
        "Please use the new name going forward to avoid code breakages. "
        "Geneformer models have been renamed to better reflect their size."
    )
    LOGGER.warning(message)
    model_name = new_model_name

# model specific parameters
self.model_map = {
"gf-12L-38M-i4096": {
Expand Down Expand Up @@ -144,7 +171,8 @@ def __init__(

# Add model weight files to download based on the model version (v1 or v2)
if (
self.model_map[model_name]["model_version"] != "v1" or model_name == "gf-12L-40M-i2048-CZI-CellxGene"
self.model_map[model_name]["model_version"] != "v1"
or model_name == "gf-12L-40M-i2048-CZI-CellxGene"
):
self.list_of_files_to_download.append(
f"geneformer/{self.model_map[model_name]['model_version']}/{model_name}/model.safetensors"
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "helical"
version = "1.4.0"
version = "1.4.1"
authors = [
{ name="Helical Team", email="support@helical-ai.com" },
]
Expand Down
Loading