diff --git a/ci/tests/test_transcriptformer/test_transcriptformer_model.py b/ci/tests/test_transcriptformer/test_transcriptformer_model.py index 900088c0..141e803a 100644 --- a/ci/tests/test_transcriptformer/test_transcriptformer_model.py +++ b/ci/tests/test_transcriptformer/test_transcriptformer_model.py @@ -1,3 +1,7 @@ +import numpy as np +import h5py +import pytest +import torch from helical.models.transcriptformer.model import TranscriptFormer from helical.models.transcriptformer.transcriptformer_config import ( TranscriptFormerConfig, @@ -5,6 +9,16 @@ from anndata import AnnData +def _write_dummy_embedding_h5(path, gene_names, emb_dim=2560): + """Write a minimal HDF5 embedding file with random embeddings.""" + with h5py.File(path, "w") as f: + f.create_dataset("keys", data=np.array(gene_names, dtype="S")) + arrays_group = f.create_group("arrays") + rng = np.random.default_rng(seed=0) + for gene in gene_names: + arrays_group.create_dataset(gene, data=rng.random(emb_dim).astype(np.float32)) + + class TestTranscriptFormerModel: configurer = TranscriptFormerConfig(emb_mode="gene") transcriptformer = TranscriptFormer(configurer) @@ -40,3 +54,27 @@ def test_get_embeddings__in_gene_mode(self): assert embeddings[0]["ENSG00000121410"].shape == (2048,) assert embeddings[0]["ENSG00000036549"].shape == (2048,) assert embeddings[0]["ENSG00000074755"].shape == (2048,) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +class TestTranscriptFormerPretrainedEmbeddingList: + """Tests that a list of pretrained embedding paths is accepted and applied correctly.""" + + GENES_FILE_1 = ["ENSG00000121410", "ENSG00000036549"] + GENES_FILE_2 = ["ENSG00000074755", "ENSG00000078808"] + + def test_model_loads_with_list_of_pretrained_embeddings(self, tmp_path): + path1 = str(tmp_path / "embeddings_1.h5") + path2 = str(tmp_path / "embeddings_2.h5") + _write_dummy_embedding_h5(path1, self.GENES_FILE_1) + _write_dummy_embedding_h5(path2, self.GENES_FILE_2) + + 
configurer = TranscriptFormerConfig( + emb_mode="gene", + pretrained_embedding=[path1, path2], + ) + model = TranscriptFormer(configurer) + + # All genes from both embedding files should be present in the updated vocab + for gene in self.GENES_FILE_1 + self.GENES_FILE_2: + assert gene in model.gene_vocab diff --git a/helical/models/transcriptformer/model.py b/helical/models/transcriptformer/model.py index 64447090..3cec5952 100644 --- a/helical/models/transcriptformer/model.py +++ b/helical/models/transcriptformer/model.py @@ -11,7 +11,7 @@ from helical.models.transcriptformer.utils.utils import stack_dict from helical.models.base_models import HelicalRNAModel from helical.utils.downloader import Downloader -from omegaconf import OmegaConf +from omegaconf import OmegaConf, ListConfig import json import os import pandas as pd @@ -149,12 +149,12 @@ def __init__(self, configurer: TranscriptFormerConfig = configurer): if self.model.inference_config.pretrained_embedding is not None: logger.info("Performing embedding surgery") # Check if pretrained_embedding_paths is a list, if not convert it to a list - if not isinstance(self.model.inference_config.pretrained_embedding, list): + if not isinstance(self.model.inference_config.pretrained_embedding, (list, ListConfig)): pretrained_embedding_paths = [ self.model.inference_config.pretrained_embedding ] else: - pretrained_embedding_paths = ( + pretrained_embedding_paths = list( self.model.inference_config.pretrained_embedding ) self.model, self.gene_vocab = change_embedding_layer( diff --git a/helical/models/transcriptformer/transcriptformer_config.py b/helical/models/transcriptformer/transcriptformer_config.py index 6ad19cee..bf91d679 100644 --- a/helical/models/transcriptformer/transcriptformer_config.py +++ b/helical/models/transcriptformer/transcriptformer_config.py @@ -1,5 +1,5 @@ from omegaconf import OmegaConf -from typing import Literal, List +from typing import Literal, List, Union class TranscriptFormerConfig: @@ 
-24,8 +24,8 @@ class TranscriptFormerConfig: Directory where results will be saved load_checkpoint: str = None Path to model weights file (automatically set by inference.py) - pretrained_embedding: str = None - Path to pretrained embeddings for out-of-distribution species + pretrained_embedding: Union[str, List[str]] = None + Path or list of paths to pretrained embeddings for out-of-distribution species gene_col_name: str = "ensembl_id" Column name in AnnData.var containing gene names which will be mapped to ensembl ids. If index is set, .var_names will be used. clip_counts: int = 30 @@ -57,7 +57,7 @@ def __init__( data_files: List[str] = [None], output_path: str = "./inference_results", load_checkpoint: str = None, - pretrained_embedding: str = None, + pretrained_embedding: Union[str, List[str]] = None, gene_col_name: str = "index", clip_counts: int = 30, filter_to_vocabs: bool = True, @@ -129,6 +129,10 @@ def __init__( "transcriptformer/tf_metazoa/vocabs/oryctolagus_cuniculus_gene.h5", "transcriptformer/tf_metazoa/vocabs/spongilla_lacustris_gene.h5", "transcriptformer/tf_metazoa/vocabs/homo_sapiens_gene.h5", + "transcriptformer/tf_metazoa/vocabs/canis_lupus_familiaris_gene.h5", + "transcriptformer/tf_metazoa/vocabs/rattus_norvegicus_gene.h5", + "transcriptformer/tf_metazoa/vocabs/sus_scrofa_gene.h5", + "transcriptformer/tf_metazoa/vocabs/macaca_fascicularis_gene.h5", ] elif model_name == "tf_exemplar": self.list_of_files_to_download = [ diff --git a/pyproject.toml b/pyproject.toml index 300ee304..f10a89f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "helical" -version = "1.8.1" +version = "1.8.2" authors = [ { name="Helical Team", email="support@helical-ai.com" }, ]