Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion helical/models/transcriptformer/transcriptformer_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from omegaconf import OmegaConf
from typing import Literal, List, Union
from pathlib import Path
from helical.constants.paths import CACHE_DIR_HELICAL


class TranscriptFormerConfig:
Expand All @@ -25,7 +27,9 @@ class TranscriptFormerConfig:
load_checkpoint: str = None
Path to model weights file (automatically set by inference.py)
pretrained_embedding: Union[str, List[str]] = None
Path or list of paths to pretrained embeddings for out-of-distribution species
Path or list of paths to pretrained embeddings for out-of-distribution species. Mutually exclusive with `pretrained_embedding_species`.
pretrained_embedding_species: Union[str, List[str]] = None
Underscore-separated specie name or list of names to retrieve paths. Example: `pretrained_embedding_species="mus_musculus"` or `pretrained_embedding_species=["mus_musculus", "sus_scrofa"]`. Mutually exclusive with `pretrained_embedding`.
gene_col_name: str = "ensembl_id"
Column name in AnnData.var containing gene names which will be mapped to ensembl ids. If index is set, .var_names will be used.
clip_counts: int = 30
Expand Down Expand Up @@ -58,6 +62,7 @@ def __init__(
output_path: str = "./inference_results",
load_checkpoint: str = None,
pretrained_embedding: Union[str, List[str]] = None,
pretrained_embedding_species: Union[str, List[str]] = None,
gene_col_name: str = "index",
clip_counts: int = 30,
filter_to_vocabs: bool = True,
Expand All @@ -68,6 +73,30 @@ def __init__(
min_expressed_genes: int = 0,
):

if (
pretrained_embedding_species is not None
and pretrained_embedding is not None
):
raise ValueError(
"pretrained_embedding_species and pretrained_embedding are mutually exclusive"
)

if pretrained_embedding_species is not None and pretrained_embedding is None:
species_list = (
[pretrained_embedding_species]
if isinstance(pretrained_embedding_species, str)
else pretrained_embedding_species
)
vocab_base = (
Path(CACHE_DIR_HELICAL)
/ "models/transcriptformer"
/ model_name
/ "vocabs"
)
pretrained_embedding = [
str(vocab_base / f"{s}_gene.h5") for s in species_list
]

inference_config: dict = {
"batch_size": batch_size,
"output_keys": output_keys,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "helical"
version = "1.10.1"
version = "1.11.0"
authors = [
{ name="Helical Team", email="support@helical-ai.com" },
]
Expand Down
Loading