Skip to content
Merged

Main #344

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions ci/tests/test_transcriptformer/test_transcriptformer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,30 @@ def test_model_loads_with_list_of_pretrained_embeddings(self, tmp_path):
# All genes from both embedding files should be present in the updated vocab
for gene in self.GENES_FILE_1 + self.GENES_FILE_2:
assert gene in model.gene_vocab

def test_special_token_indices_preserved_after_surgery(self, tmp_path):
path1 = str(tmp_path / "embeddings_1.h5")
_write_dummy_embedding_h5(path1, self.GENES_FILE_1 + self.GENES_FILE_2)

base_configurer = TranscriptFormerConfig(emb_mode="cell")
base_model = TranscriptFormer(base_configurer)
base_special_token_indices = {
token: idx
for token, idx in base_model.gene_vocab.items()
if token.startswith("[") or token == "unknown"
}

surg_configurer = TranscriptFormerConfig(
emb_mode="cell",
pretrained_embedding=path1,
)
surg_model = TranscriptFormer(surg_configurer)

# Every special token must retain its original index after surgery so that
# _pad_mask (which uses model.gene_vocab.pad_idx) stays consistent with
# the PAD token written by process_batch (which uses gene_vocab["[PAD]"]).
for token, orig_idx in base_special_token_indices.items():
assert surg_model.gene_vocab[token] == orig_idx, (
f"Special token '{token}' index changed after surgery: "
f"expected {orig_idx}, got {surg_model.gene_vocab[token]}"
)
89 changes: 89 additions & 0 deletions helical/models/geneformer/geneformer_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,95 @@ def __init__(
# protein-coding and miRNA gene list dictionary for selecting .h5ad columns for tokenization
self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))

    def get_gene_ranks(self, adata_obj, gene_names: str = "index"):
        """
        Compute per-cell gene ranks obtained from standard rank value encoding.

        Each cell's raw counts are divided by the per-gene median
        (``self.gene_median_dict``) and the resulting values are ranked in
        descending order, mirroring the filtering done by ``tokenize_anndata``.

        Parameters
        ----------
        adata_obj : AnnData
            Raw counts scRNAseq data. Gene symbols will be mapped to
            Ensembl IDs if 'ensembl_id' is not already in adata.var.
        gene_names : str
            Column in adata.var containing gene names, or "index" to use
            var_names. Set to "ensembl_id" if the column already exists.

        Returns
        -------
        rank_matrix : np.ndarray
            int32 array of shape (n_obs, n_genes) of the AnnData *after*
            Ensembl-ID collapsing by ``sum_ensembl_ids`` (column count may
            differ from the input if duplicate IDs are summed; row count is
            the full number of cells, not just those passing QC).
            Entries are 1-indexed ranks (1 = highest median-normalized
            expression in that cell). Zero means the gene is not expressed,
            not in the model vocabulary, or the cell did not pass the
            ``filter_pass`` QC filter (such rows are left all-zero).
        context_length : int
            Effective context length (model_input_size, or model_input_size - 2
            if special tokens are used). Genes with rank <= context_length
            are "in context" for the model.
        """
        # Map gene symbols to Ensembl IDs only when the caller has not
        # already provided an 'ensembl_id' column.
        if "ensembl_id" not in adata_obj.var.columns:
            from helical.utils.mapping import map_gene_symbols_to_ensembl_ids
            col = gene_names if gene_names != "index" else None
            adata_obj = map_gene_symbols_to_ensembl_ids(adata_obj, col)

        # Collapse duplicate Ensembl IDs the same way tokenization does, so
        # column indices below refer to the collapsed AnnData.
        adata = sum_ensembl_ids(
            adata_obj,
            self.collapse_gene_ids,
            self.gene_mapping_dict,
            self.gene_token_dict,
            file_format="h5ad",
            chunk_size=self.chunk_size,
        )

        # Identify vocabulary genes (same filter as tokenize_anndata)
        coding_miRNA_loc = np.where(
            [self.genelist_dict.get(i, False) for i in adata.var["ensembl_id"]]
        )[0]
        # Per-gene median normalization factors, aligned with coding_miRNA_loc.
        norm_factor_vector = np.array(
            [
                self.gene_median_dict[i]
                for i in adata.var["ensembl_id"].iloc[coding_miRNA_loc]
            ]
        )

        # Effective context length (account for CLS/EOS tokens)
        context_length = (
            self.model_input_size - 2 if self.special_token else self.model_input_size
        )

        # Restrict ranking to QC-passing cells when the flag column exists;
        # otherwise rank every cell.
        if "filter_pass" in adata.obs.columns:
            filter_pass_loc = np.where(
                [i == 1 for i in adata.obs["filter_pass"]]
            )[0]
        else:
            filter_pass_loc = np.arange(adata.shape[0])

        n_cells = len(filter_pass_loc)
        n_genes = adata.shape[1]
        # Full-height output: rows for cells failing QC remain all-zero.
        rank_matrix = np.zeros((adata.shape[0], n_genes), dtype=np.int32)

        # Process QC-passing cells in chunks to bound memory usage.
        for i in range(0, n_cells, self.chunk_size):
            idx = filter_pass_loc[i : i + self.chunk_size]
            # NOTE(review): assumes adata.X slicing yields a type that scipy
            # can divide by a dense row vector (sparse matrix or ndarray) —
            # confirm against the AnnData backing format used by callers.
            X_view = adata[idx, :].X[:, coding_miRNA_loc]

            # Median-scale: e_{c,g} = r_{c,g} / m_g
            X_scaled = sp.csr_matrix(X_view / norm_factor_vector.reshape(1, -1))

            # Compute ranks per cell and write into dense matrix
            for j in range(X_scaled.shape[0]):
                row = X_scaled.getrow(j)
                if row.nnz == 0:
                    # Cell expresses no vocabulary gene: leave its row zero.
                    continue

                # Descending sort of the nonzero scaled values; ties broken
                # by argsort order (stable only w.r.t. the sparse data order).
                order = np.argsort(-row.data)
                ranks_arr = np.empty_like(order, dtype=np.int32)
                ranks_arr[order] = np.arange(1, len(order) + 1, dtype=np.int32)

                # Scatter ranks back to the original (uncollapsed-filter)
                # column positions of the full gene axis.
                orig_cols = coding_miRNA_loc[row.indices]
                rank_matrix[idx[j], orig_cols] = ranks_arr

        return rank_matrix, context_length

def tokenize_data(
self,
data_directory: Path | str,
Expand Down
23 changes: 18 additions & 5 deletions helical/models/transcriptformer/model_dir/embedding_surgery.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,31 @@ def change_embedding_layer(
[token for token in special_tokens if token not in old_special_tokens],
)

# Concatenate the old special token embeddings with the new embeddings
# Build a special-token embedding matrix that preserves the original indices.
# The training vocab may assign special tokens to indices that differ from
# the order they appear in the SPECIAL_TOKENS list, so we must place each
# token's embedding at its original row rather than re-numbering them.
n_special = len(old_special_tokens)
special_emb_matrix = torch.zeros(
n_special, old_special_token_embeddings.shape[1], device="cuda"
)
for i, token in enumerate(old_special_tokens):
orig_idx = model.gene_vocab.vocab_dict[token]
special_emb_matrix[orig_idx] = old_special_token_embeddings[i]

# Concatenate the (correctly ordered) special token embeddings with the new gene embeddings
new_embedding_matrix = torch.cat(
[old_special_token_embeddings, torch.Tensor(new_embedding_matrix).to("cuda")],
[special_emb_matrix, torch.Tensor(new_embedding_matrix).to("cuda")],
dim=0,
)

# Update the vocab indices of the model
gene_vocab = {
gene: idx + len(old_special_tokens) for gene, idx in gene_vocab.items()
gene: idx + n_special for gene, idx in gene_vocab.items()
}

# Update the gene_vocab with the special tokens
gene_vocab.update({token: idx for idx, token in enumerate(old_special_tokens)})
# Restore special tokens at their original indices from the training vocab
gene_vocab.update({token: model.gene_vocab.vocab_dict[token] for token in old_special_tokens})

# Create a new embedding layer
new_embedding = torch.nn.Embedding(
Expand All @@ -80,5 +92,6 @@ def change_embedding_layer(
# Update the gene_vocab
model.gene_vocab.vocab_dict = gene_vocab
model.gene_vocab.embedding_matrix = new_embedding_matrix
model.token_to_gene_dict = {v: k for k, v in gene_vocab.items()}

return model, gene_vocab
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "helical"
version = "1.8.2"
version = "1.8.3"
authors = [
{ name="Helical Team", email="support@helical-ai.com" },
]
Expand Down
Loading