From ca4727ede3e29186b7ecefca5f96ba61044d15c5 Mon Sep 17 00:00:00 2001 From: Simon Graf <82808503+sgraf2002@users.noreply.github.com> Date: Thu, 20 Nov 2025 19:02:34 +0100 Subject: [PATCH 1/2] C2s fixes and extensions (#306) * fix c2s conversion * further fixes * use base 10 instead of 2 for normalization * fix formula * minor changes * smaller fixes to pass tests * replace list accumulation by running mean --------- Co-authored-by: Maxime Allard Co-authored-by: Benoit Putzeys <157973952+bputzeys@users.noreply.github.com> --- ci/tests/test_c2s/test_c2s.py | 4 +- helical/models/c2s/model.py | 119 ++++++++++++++++++++++------------ 2 files changed, 79 insertions(+), 44 deletions(-) diff --git a/ci/tests/test_c2s/test_c2s.py b/ci/tests/test_c2s/test_c2s.py index 43bdbf11..1df71869 100644 --- a/ci/tests/test_c2s/test_c2s.py +++ b/ci/tests/test_c2s/test_c2s.py @@ -115,6 +115,8 @@ def test_process_data_basic_properties(self, cell2sen_model, sample_anndata): # Check dataset type and structure assert isinstance(dataset, Dataset) + + # Check if the dataset has the right size assert len(dataset) == sample_anndata.n_obs # Check all required columns exist @@ -228,8 +230,8 @@ def test_get_embeddings_with_different_batch_sizes(self, cell2sen_model, process def test_get_embeddings_empty_dataset(self, cell2sen_model, sample_anndata): """Test embeddings with empty dataset.""" - empty_dataset = cell2sen_model.process_data(sample_anndata[:0]) with pytest.raises((ValueError, IndexError, AssertionError)): + empty_dataset = cell2sen_model.process_data(sample_anndata[:0]) cell2sen_model.get_embeddings(empty_dataset) def test_get_embeddings_attention_shapes(self, cell2sen_model, processed_dataset_basic): diff --git a/helical/models/c2s/model.py b/helical/models/c2s/model.py index 7a9f4705..76a6fe6d 100644 --- a/helical/models/c2s/model.py +++ b/helical/models/c2s/model.py @@ -9,6 +9,7 @@ import torch import anndata +import scanpy as sc import numpy as np from tqdm import tqdm from 
helical.models.base_models import HelicalBaseFoundationModel @@ -107,7 +108,8 @@ def __init__(self, configurer: Cell2SenConfig = None) -> None: def process_data( self, - anndata: anndata.AnnData, + adata: anndata.AnnData, + max_genes: int = None, ): """ Process anndata to create a HuggingFace Dataset with cell sentences and fit parameters. @@ -116,20 +118,8 @@ def process_data( ----------- anndata : AnnData Annotated data object with gene expression - min_genes : int - Minimum number of genes expressed per cell - min_counts : int - Minimum total counts per cell - min_cells : int - Minimum number of cells expressing a gene - max_cells : int, optional - Maximum number of cells to process max_genes : int, optional - Maximum number of genes to process - organism : str, optional - Organism name. If None, tries to extract from anndata.uns or uses default - perturbation_column : str, optional - Column name in anndata.obs to use for perturbations. If None, no perturbations are created. + Maximum number of genes to process per cell in descending expression order Returns: -------- dataset : Dataset @@ -137,17 +127,25 @@ def process_data( """ LOGGER.info("Processing data") - + if adata.n_obs == 0: + raise ValueError("Anndata is empty. 
Please provide a valid anndata object.") + + # standard log-normalization, enables accurate expression reconstruction + anndata = adata.copy() + sc.pp.normalize_total(anndata, target_sum=1e4) + sc.pp.log1p(anndata, base=10) + X = anndata.X if hasattr(X, 'toarray'): X = X.toarray() - X_log = np.log10(X + 1) # gene names corresponding to each cell in order # anndata.X[i, j] is the expression of the j-th gene in the i-th cell - gene_names = anndata.var_names.values cell_sentences = [] - fit_parameters = [] # Will be list of lists: [slope, intercept, r_squared] for each cell + + # Collect ranks and corresponding expression means as training data for reconstruction model + rank_to_mean = {} + rank_to_count = {} if self.organism is None: if 'organism' in anndata.uns: @@ -163,36 +161,71 @@ def process_data( self.organism = "unknown" # Default if not found # Process each cell - progress_bar = tqdm(total=X_log.shape[0], desc="Processing cells") - for cell_idx in range(X_log.shape[0]): - cell_expr = X_log[cell_idx, :] - - # Rank genes by expression (highest = rank 1) + progress_bar = tqdm(total=X.shape[0], desc="Processing cells") + for cell_idx in range(X.shape[0]): + gene_names = anndata.var_names.values + cell_expr = X[cell_idx, :] + # Rank nonzero genes by expression (highest = rank 1) + non_zero_mask = cell_expr > 0 + if non_zero_mask.sum() == 0: + LOGGER.warning(f"No genes expressed above zero in cell {cell_idx}. 
Using empty sentence.") + cell_sentence = "" + cell_sentences.append(cell_sentence) + progress_bar.update(1) + continue + + cell_expr = cell_expr[non_zero_mask] + gene_names = gene_names[non_zero_mask] + ranked_indices = np.argsort(cell_expr)[::-1] - assert len(ranked_indices) != 0, "No genes expressed in cell" expr_values = cell_expr[ranked_indices] # Expression values in descending order + gene_names = gene_names[ranked_indices] # Gene names in descending order by expression - if self.return_fit: - non_zero_mask = expr_values > 0 - if non_zero_mask.sum() > 0: - last_non_zero_idx = np.where(non_zero_mask)[0][-1] # Last index where expr > 0 - # Fit only up to the last non-zero gene - ranks_to_fit = np.arange(1, last_non_zero_idx + 2) # +1 because rank starts at 1, +1 for inclusive - expr_to_fit = expr_values[:last_non_zero_idx + 1] - # Fit linear model - model = LinearRegression() - model.fit(ranks_to_fit.reshape(-1, 1), expr_to_fit) - slope, intercept = model.coef_[0], model.intercept_ - r_squared = model.score(ranks_to_fit.reshape(-1, 1), expr_to_fit) - else: - slope, intercept, r_squared = 0.0, 0.0, 0.0 - fit_parameters.append({"slope": float(slope), "intercept": float(intercept), "r_squared": float(r_squared)}) - else: - fit_parameters.append(None) + # Cut at max_genes if desired + if max_genes: + if len(gene_names) > max_genes: + gene_names = gene_names[:max_genes] + expr_values = expr_values[:max_genes] - cell_sentence = " ".join(gene_names[ranked_indices]) + if self.return_fit: + ranks = np.arange(1, len(gene_names) + 1) + for rank, expr in zip(ranks, expr_values): + r = int(rank) + + if r not in rank_to_mean: + # first time seeing this rank + rank_to_mean[r] = expr + rank_to_count[r] = 1 + else: + # online mean update + count = rank_to_count[r] + 1 + old_mean = rank_to_mean[r] + new_mean = old_mean + (expr - old_mean) / count + + rank_to_mean[r] = new_mean + rank_to_count[r] = count + + + cell_sentence = " ".join(gene_names) 
cell_sentences.append(cell_sentence) progress_bar.update(1) + + + if self.return_fit: + log_ranks_to_fit = np.log10(list(rank_to_mean.keys())) + expr_to_fit = np.array(list(rank_to_mean.values())) + + # Fit linear model to predict log-normalized expression from log rank: expr(g) = slope * log10(rank(g)) + intercept + model = LinearRegression() + model.fit(log_ranks_to_fit.reshape(-1, 1), np.array(expr_to_fit)) + slope, intercept = model.coef_[0], model.intercept_ + r_squared = model.score(log_ranks_to_fit.reshape(-1, 1), expr_to_fit) + + fit_parameters = {"slope": float(slope), "intercept": float(intercept), "r_squared": float(r_squared)} + + else: + fit_parameters = None + progress_bar.close() if self.perturbation_column is not None: @@ -204,7 +237,7 @@ def process_data( dataset = Dataset.from_dict({ 'cell_sentence': cell_sentences, - 'fit_parameters': fit_parameters, + 'fit_parameters': [fit_parameters] * len(cell_sentences), 'organism': [self.organism] * len(cell_sentences), 'perturbations': perturbations }) From 54e6636a93cd3770a053823410897e89cbbb38cc Mon Sep 17 00:00:00 2001 From: Benoit Putzeys <157973952+bputzeys@users.noreply.github.com> Date: Fri, 21 Nov 2025 08:51:43 +0100 Subject: [PATCH 2/2] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 60b69123..da6a32d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "helical" -version = "1.4.15" +version = "1.4.16" authors = [ { name="Helical Team", email="support@helical-ai.com" }, ]