Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion helical/models/transcriptformer/transcriptformer_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from omegaconf import OmegaConf
from typing import Literal, List, Union
from pathlib import Path
from helical.constants.paths import CACHE_DIR_HELICAL


class TranscriptFormerConfig:
Expand All @@ -25,7 +27,9 @@ class TranscriptFormerConfig:
load_checkpoint: str = None
Path to model weights file (automatically set by inference.py)
pretrained_embedding: Union[str, List[str]] = None
Path or list of paths to pretrained embeddings for out-of-distribution species
Path or list of paths to pretrained embeddings for out-of-distribution species. Mutually exclusive with `pretrained_embedding_species`.
pretrained_embedding_species: Union[str, List[str]] = None
Underscore-separated species name or list of names used to resolve embedding paths. Example: `pretrained_embedding_species="mus_musculus"` or `pretrained_embedding_species=["mus_musculus", "sus_scrofa"]`. Mutually exclusive with `pretrained_embedding`.
gene_col_name: str = "ensembl_id"
Column name in AnnData.var containing gene names which will be mapped to ensembl ids. If index is set, .var_names will be used.
clip_counts: int = 30
Expand Down Expand Up @@ -58,6 +62,7 @@ def __init__(
output_path: str = "./inference_results",
load_checkpoint: str = None,
pretrained_embedding: Union[str, List[str]] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you adjust the docstring for this new parameter? also is there a way to only have 1 and not 2 of these variables?

pretrained_embedding_species: Union[str, List[str]] = None,
gene_col_name: str = "index",
clip_counts: int = 30,
filter_to_vocabs: bool = True,
Expand All @@ -68,6 +73,30 @@ def __init__(
min_expressed_genes: int = 0,
):

if (
pretrained_embedding_species is not None
and pretrained_embedding is not None
):
raise ValueError(
"pretrained_embedding_species and pretrained_embedding are mutually exclusive"
)

if pretrained_embedding_species is not None and pretrained_embedding is None:
species_list = (
[pretrained_embedding_species]
if isinstance(pretrained_embedding_species, str)
else pretrained_embedding_species
)
vocab_base = (
Path(CACHE_DIR_HELICAL)
/ "models/transcriptformer"
/ model_name
/ "vocabs"
)
pretrained_embedding = [
str(vocab_base / f"{s}_gene.h5") for s in species_list
]

inference_config: dict = {
"batch_size": batch_size,
"output_keys": output_keys,
Expand Down
Loading