Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ci/tests/test_c2s/test_c2s.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ def test_process_data_basic_properties(self, cell2sen_model, sample_anndata):

# Check dataset type and structure
assert isinstance(dataset, Dataset)

# Check if the dataset has the right size
assert len(dataset) == sample_anndata.n_obs

# Check all required columns exist
Expand Down Expand Up @@ -228,8 +230,8 @@ def test_get_embeddings_with_different_batch_sizes(self, cell2sen_model, process

def test_get_embeddings_empty_dataset(self, cell2sen_model, sample_anndata):
"""Test embeddings with empty dataset."""
empty_dataset = cell2sen_model.process_data(sample_anndata[:0])
with pytest.raises((ValueError, IndexError, AssertionError)):
empty_dataset = cell2sen_model.process_data(sample_anndata[:0])
cell2sen_model.get_embeddings(empty_dataset)

def test_get_embeddings_attention_shapes(self, cell2sen_model, processed_dataset_basic):
Expand Down
119 changes: 76 additions & 43 deletions helical/models/c2s/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import torch
import anndata
import scanpy as sc
import numpy as np
from tqdm import tqdm
from helical.models.base_models import HelicalBaseFoundationModel
Expand Down Expand Up @@ -107,7 +108,8 @@ def __init__(self, configurer: Cell2SenConfig = None) -> None:

def process_data(
self,
anndata: anndata.AnnData,
adata: anndata.AnnData,
max_genes: int = None,
):
"""
Process anndata to create a HuggingFace Dataset with cell sentences and fit parameters.
Expand All @@ -116,38 +118,34 @@ def process_data(
-----------
adata : AnnData
Annotated data object with gene expression
min_genes : int
Minimum number of genes expressed per cell
min_counts : int
Minimum total counts per cell
min_cells : int
Minimum number of cells expressing a gene
max_cells : int, optional
Maximum number of cells to process
max_genes : int, optional
Maximum number of genes to process
organism : str, optional
Organism name. If None, tries to extract from anndata.uns or uses default
perturbation_column : str, optional
Column name in anndata.obs to use for perturbations. If None, no perturbations are created.
Maximum number of genes to process per cell in descending expression order
Returns:
--------
dataset : Dataset
HuggingFace Dataset with fields: cell_sentence, fit_parameters, organism, perturbations
"""

LOGGER.info("Processing data")

if adata.n_obs == 0:
raise ValueError("Anndata is empty. Please provide a valid anndata object.")

# standard log-normalization, enables accurate expression reconstruction
anndata = adata.copy()
sc.pp.normalize_total(anndata, target_sum=1e4)
sc.pp.log1p(anndata, base=10)

X = anndata.X
if hasattr(X, 'toarray'):
X = X.toarray()

X_log = np.log10(X + 1)
# gene names corresponding to each cell in order
# anndata.X[i, j] is the expression of the j-th gene in the i-th cell
gene_names = anndata.var_names.values
cell_sentences = []
fit_parameters = [] # Will be list of lists: [slope, intercept, r_squared] for each cell

# Collect ranks and corresponding expression means as training data for reconstruction model
rank_to_mean = {}
rank_to_count = {}

if self.organism is None:
if 'organism' in anndata.uns:
Expand All @@ -163,36 +161,71 @@ def process_data(
self.organism = "unknown" # Default if not found

# Process each cell
progress_bar = tqdm(total=X_log.shape[0], desc="Processing cells")
for cell_idx in range(X_log.shape[0]):
cell_expr = X_log[cell_idx, :]

# Rank genes by expression (highest = rank 1)
progress_bar = tqdm(total=X.shape[0], desc="Processing cells")
for cell_idx in range(X.shape[0]):
gene_names = anndata.var_names.values
cell_expr = X[cell_idx, :]
# Rank nonzero genes by expression (highest = rank 1)
non_zero_mask = cell_expr > 0
if non_zero_mask.sum() == 0:
LOGGER.warning(f"No genes expressed above zero in cell {cell_idx}. Using empty sentence.")
cell_sentence = ""
cell_sentences.append(cell_sentence)
progress_bar.update(1)
continue

cell_expr = cell_expr[non_zero_mask]
gene_names = gene_names[non_zero_mask]

ranked_indices = np.argsort(cell_expr)[::-1]
assert len(ranked_indices) != 0, "No genes expressed in cell"
expr_values = cell_expr[ranked_indices] # Expression values in descending order
gene_names = gene_names[ranked_indices] # Gene names in descending order by expression

if self.return_fit:
non_zero_mask = expr_values > 0
if non_zero_mask.sum() > 0:
last_non_zero_idx = np.where(non_zero_mask)[0][-1] # Last index where expr > 0
# Fit only up to the last non-zero gene
ranks_to_fit = np.arange(1, last_non_zero_idx + 2) # +1 because rank starts at 1, +1 for inclusive
expr_to_fit = expr_values[:last_non_zero_idx + 1]
# Fit linear model
model = LinearRegression()
model.fit(ranks_to_fit.reshape(-1, 1), expr_to_fit)
slope, intercept = model.coef_[0], model.intercept_
r_squared = model.score(ranks_to_fit.reshape(-1, 1), expr_to_fit)
else:
slope, intercept, r_squared = 0.0, 0.0, 0.0
fit_parameters.append({"slope": float(slope), "intercept": float(intercept), "r_squared": float(r_squared)})
else:
fit_parameters.append(None)
# Cut at max_genes if desired
if max_genes:
if len(gene_names) > max_genes:
gene_names = gene_names[:max_genes]
expr_values = expr_values[:max_genes]

cell_sentence = " ".join(gene_names[ranked_indices])
if self.return_fit:
ranks = np.arange(1, len(gene_names) + 1)
for rank, expr in zip(ranks, expr_values):
r = int(rank)

if r not in rank_to_mean:
# first time seeing this rank
rank_to_mean[r] = expr
rank_to_count[r] = 1
else:
# online mean update
count = rank_to_count[r] + 1
old_mean = rank_to_mean[r]
new_mean = old_mean + (expr - old_mean) / count

rank_to_mean[r] = new_mean
rank_to_count[r] = count


cell_sentence = " ".join(gene_names)
cell_sentences.append(cell_sentence)
progress_bar.update(1)


if self.return_fit:
log_ranks_to_fit = np.log10(list(rank_to_mean.keys()))
expr_to_fit = np.array(list(rank_to_mean.values()))

# Fit linear model to predict log-normalized expression from log10 rank: expr(g) = slope * log10(rank(g)) + intercept
model = LinearRegression()
model.fit(log_ranks_to_fit.reshape(-1, 1), np.array(expr_to_fit))
slope, intercept = model.coef_[0], model.intercept_
r_squared = model.score(log_ranks_to_fit.reshape(-1, 1), expr_to_fit)

fit_parameters = {"slope": float(slope), "intercept": float(intercept), "r_squared": float(r_squared)}

else:
fit_parameters = None

progress_bar.close()

if self.perturbation_column is not None:
Expand All @@ -204,7 +237,7 @@ def process_data(

dataset = Dataset.from_dict({
'cell_sentence': cell_sentences,
'fit_parameters': fit_parameters,
'fit_parameters': [fit_parameters] * len(cell_sentences),
'organism': [self.organism] * len(cell_sentences),
'perturbations': perturbations
})
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "helical"
version = "1.4.15"
version = "1.4.16"
authors = [
{ name="Helical Team", email="support@helical-ai.com" },
]
Expand Down
Loading