From 2a2cb9a30007990afa9159400095759f0b63b4f8 Mon Sep 17 00:00:00 2001 From: martinaoliver Date: Tue, 3 Mar 2026 14:57:00 +0000 Subject: [PATCH] TF adapted to work with other non-homosapiens species (#333) * changed to tf on ensembl * added info on genes expressed and mapped * logger * github pr rerun checks * Bump version to 1.8.1 --------- Co-authored-by: maxime --- .../transcriptformer/data/dataloader.py | 21 +++++++++++++++++-- pyproject.toml | 2 +- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/helical/models/transcriptformer/data/dataloader.py b/helical/models/transcriptformer/data/dataloader.py index f8272db8..b4d8f201 100644 --- a/helical/models/transcriptformer/data/dataloader.py +++ b/helical/models/transcriptformer/data/dataloader.py @@ -35,8 +35,13 @@ def load_gene_features(adata, gene_col_name, species: str = "hsapiens"): message = f"Gene column '{gene_col_name}' not found in adata.var.columns. Available columns: {adata.var.columns}. Modify config accordingly." logging.error(message) raise ValueError(message) - adata = map_gene_symbols_to_ensembl_ids(adata, gene_names=gene_col_name, species=species) - gene_names = np.array(list(adata.var["ensembl_id"].values)) + if adata.var[gene_col_name].str.contains("ENS", na=False).mean() <= 0.5: + adata = map_gene_symbols_to_ensembl_ids(adata, gene_names=gene_col_name, species=species) + gene_names = np.array(list(adata.var["ensembl_id"].values)) + else: + adata.var["ensembl_id"] = adata.var[gene_col_name] + gene_names = np.array(list(adata.var["ensembl_id"].values)) + return gene_names, True @@ -52,7 +57,10 @@ def apply_filters( ): """Apply filters to the data.""" if filter_to_vocab: + n_total_genes = len(gene_names) filter_idx = [i for i, name in enumerate(gene_names) if name in vocab] + print(len(vocab)) + not_in_vocab = n_total_genes - len(filter_idx) X = X[:, filter_idx] gene_names = gene_names[filter_idx] if X.shape[1] == 0: @@ -60,6 +68,15 @@ def apply_filters( logging.warning(f"Available genes: {len(gene_names)}") logging.warning(f"Number of non-zero genes: {np.sum(X > 0, axis=1).mean()}") return None, None, None + zero_expr = int((X == 0).all(axis=0).sum()) + nonzero_expr = len(filter_idx) - zero_expr + logging.info( + f"Gene mapping: {len(filter_idx)} / {n_total_genes} in vocab | " + f"not in vocab: {not_in_vocab} | " + f"in vocab but zero expression: {zero_expr} | " + f"in vocab and expressed: {nonzero_expr}" + ) + if filter_outliers > 0: expr_counts = X.sum(axis=1) diff --git a/pyproject.toml b/pyproject.toml index e381dd77..300ee304 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "helical" -version = "1.8.0" +version = "1.8.1" authors = [ { name="Helical Team", email="support@helical-ai.com" }, ]