Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ci/tests/test_transcriptformer/test_transcriptformer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ def test_process_data__correct_ensembl_ids(self):
dataset = self.transcriptformer.process_data([self.data])
assert len(dataset) == 1
assert all(
dataset.files_list[0].var["ensembl_id"].values
dataset.files_list[0].var["ensembl_id"].astype(str).values
== [
"ENSG00000121410",
"ENSG00000036549",
None,
"nan",
"ENSG00000074755",
]
)
Expand Down
12 changes: 5 additions & 7 deletions ci/tests/test_utils/test_mapping.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from helical.utils.mapping import map_gene_symbols_to_ensembl_ids
from helical.utils.mapping import map_ensembl_ids_to_gene_symbols
from helical.utils.mapping import convert_list_ensembl_ids_to_gene_symbols, convert_list_gene_symbols_to_ensembl_ids
from pyensembl.species import human
from pyensembl.species import macaque
import anndata as ad
import pytest

Expand All @@ -15,7 +13,7 @@ def test_map_gene_symbols_to_ensembl_ids():
CD99 should be mapped to ENSG00000002586.
"""
adata.var["gene_names"] = ["CD99"] * adata.var.shape[0]
map_gene_symbols_to_ensembl_ids(adata, gene_names="gene_names", species=human)
map_gene_symbols_to_ensembl_ids(adata, gene_names="gene_names", species='hsapiens')
assert all(adata.var["ensembl_id"] == ["ENSG00000002586"] * adata.var.shape[0])


Expand All @@ -25,7 +23,7 @@ def test_map_ensembl_ids_to_gene_symbols():
ENSG00000002330 should be mapped to BAD.
"""
adata.var["ensembl_id"] = ["ENSG00000002330"] * adata.var.shape[0]
map_ensembl_ids_to_gene_symbols(adata, ensembl_id_key="ensembl_id", species=human)
map_ensembl_ids_to_gene_symbols(adata, ensembl_id_key="ensembl_id", species='hsapiens')
assert all(adata.var["gene_names"] == ["BAD"] * adata.var.shape[0])


Expand All @@ -39,17 +37,17 @@ def test_map_gene_symbols_to_ensembl_ids_macaque():
Note, this test may be long the first time it is being run because the database for macaque needs to be downloaded.
"""
adata.var["gene_names"] = ["CD99"] * adata.var.shape[0]
map_gene_symbols_to_ensembl_ids(adata, gene_names="gene_names", species=macaque)
map_gene_symbols_to_ensembl_ids(adata, gene_names="gene_names", species='mfascicularis')
assert all(adata.var["ensembl_id"] == ["ENSMFAG00000000608"] * adata.var.shape[0])


def test_convert_list_ensembl_ids_to_gene_symbols():
ensembl_ids = ["ENSG00000139618", "ENSG00000139620"]
gene_symbols = convert_list_ensembl_ids_to_gene_symbols(ensembl_ids, species=human)
gene_symbols = convert_list_ensembl_ids_to_gene_symbols(ensembl_ids, species='hsapiens')
assert gene_symbols == ["BRCA2", "KANSL2"]


def test_convert_list_gene_symbols_to_ensembl_ids():
gene_symbols = ["BRCA2", "KANSL2"]
ensembl_ids = convert_list_gene_symbols_to_ensembl_ids(gene_symbols, species=human)
ensembl_ids = convert_list_gene_symbols_to_ensembl_ids(gene_symbols, species='hsapiens')
assert ensembl_ids == ["ENSG00000139618", "ENSG00000139620"]
189 changes: 76 additions & 113 deletions helical/utils/mapping.py
Original file line number Diff line number Diff line change
@@ -1,168 +1,131 @@
import logging
from pyensembl import genome_for_reference_name
from typing import List, Optional, Sequence

import pandas as pd
from pyensembl import EnsemblRelease
from pyensembl.species import human
from pyensembl.species import Species
from anndata import AnnData
from typing import List, Optional
import pybiomart

LOGGER = logging.getLogger(__name__)


def _get_ensembl_mart_df(species: str = "hsapiens") -> pd.DataFrame:
"""
Fetch a (species)_gene_ensembl table via pybiomart.

Parameters
----------
species : str, default "hsapiens"
Species prefix used by Ensembl Biomart (e.g., "hsapiens", "mmusculus").

Returns
-------
pandas.DataFrame
DataFrame with columns "ensembl_id" and "gene_name".
"""
server = pybiomart.Server(host="http://www.ensembl.org")
dataset = server.marts["ENSEMBL_MART_ENSEMBL"].datasets[f"{species}_gene_ensembl"]
df = dataset.query(attributes=["ensembl_gene_id", "external_gene_name"])
df = df.rename(columns={"Gene stable ID": "ensembl_id", "Gene name": "gene_name"})
return df.sort_values(by="ensembl_id")


def map_gene_symbols_to_ensembl_ids(
adata: AnnData, gene_names: str, species: Species = human
adata: AnnData, gene_names: Optional[str] = None, species: str = "hsapiens"
) -> AnnData:
"""
Map gene names to Ensembl IDs using the pyensembl library.
Due to copy events, there might be multiple genes per name. We always take the fist one.
Map gene symbols to Ensembl Gene IDs using pybiomart.

Due to duplication events, some symbols map to multiple Ensembl IDs; we take
the first occurrence after de-duplication.

Parameters
----------
adata : anndata.AnnData
Anndata object containing the gene expression data.
gene_names : str
Column name in adata.var containing the gene names.
species : pyensembl.species.Species, optional, default = human
Species for which the gene names should be mapped.
For the provided species, we take the first 'reference_assembly' as default to do the mapping.
For humans, this is the GRCh38 genome for example.
adata : AnnData
AnnData object containing gene metadata in `adata.var`.
gene_names : str, optional
Column in `adata.var` containing gene symbols. If None, uses `adata.var_names`.
species : str, default "hsapiens"
Species prefix used by Ensembl Biomart (e.g., "hsapiens", "mmusculus").

Returns
-------
anndata.AnnData
Anndata object with the gene names mapped to Ensembl IDs in adata.var["ensembl_id"]
AnnData
Same object with `adata.var["ensembl_id"]` populated.
"""
# this is one time only
ensembl_release = EnsemblRelease(species=species)
ensembl_release.download(overwrite=False)
ensembl_release.index(overwrite=False)

adata.var["ensembl_id"] = pd.Series([None] * len(adata.var), index=adata.var.index)
# we take the first reference assembly from the provided dictionary as default
genome_reference = genome_for_reference_name(
next(iter(species.reference_assemblies))
)
for index, name in adata.var[gene_names].items():
try:
adata.var.at[index, "ensembl_id"] = genome_reference.gene_ids_of_gene_name(
name
)[0]
except:
continue
var_names = adata.var[gene_names] if gene_names is not None else pd.Series(adata.var_names, index=adata.var_names)
adata.var["ensembl_id"] = convert_list_gene_symbols_to_ensembl_ids(var_names, species=species)
non_none_mappings = adata.var["ensembl_id"].notnull().sum()
LOGGER.info(
f"Mapped {non_none_mappings} genes to Ensembl IDs from a total of {adata.var.shape[0]} genes."
)
LOGGER.info("Mapped %d / %d genes to Ensembl IDs.", non_none_mappings, adata.var.shape[0])
return adata


def map_ensembl_ids_to_gene_symbols(
adata: AnnData, ensembl_id_key: str = "ensembl_id", species: Species = human
adata: AnnData, ensembl_id_key: str = "ensembl_id", species: str = "hsapiens"
) -> AnnData:
"""
Map Ensembl IDs to gene names using the pyensembl library.
We use the GRCh38 genome for mapping.
Map Ensembl Gene IDs to gene symbols using pybiomart.

Parameters
----------
adata : anndata.AnnData
Anndata object containing the gene expression data.
ensembl_id_key : str, optional, default = "ensembl_id"
Column name in adata.var containing the ensemble ids.
species : pyensembl.species.Species, optional, default = human
Species for which the Ensembl IDs should be mapped.
For the provided species, we take the first reference genome as default to do the mapping.
For humans, this is the GRCh38 genome for example.
adata : AnnData
AnnData object containing gene metadata in `adata.var`.
ensembl_id_key : str, default "ensembl_id"
Column in `adata.var` containing Ensembl Gene IDs.
species : str, default "hsapiens"
Species prefix used by Ensembl Biomart (e.g., "hsapiens", "mmusculus").

Returns
-------
anndata.AnnData
Anndata object with the Ensembl IDs mapped to gene names in adata.var["gene_names"]
AnnData
Same object with `adata.var["gene_names"]` populated.
"""
# this is one time only
ensembl_release = EnsemblRelease(species=species)
ensembl_release.download(overwrite=False)
ensembl_release.index(overwrite=False)

adata.var["gene_names"] = pd.Series([None] * len(adata.var), index=adata.var.index)
# we take the first reference assembly from the provided dictionary as default
genome_reference = genome_for_reference_name(
next(iter(species.reference_assemblies))
)
for index, ensembl_id in adata.var[ensembl_id_key].items():
try:
adata.var.at[index, "gene_names"] = genome_reference.gene_name_of_gene_id(
ensembl_id
)
except:
continue
adata.var["gene_names"] = convert_list_ensembl_ids_to_gene_symbols(adata.var[ensembl_id_key], species=species)
non_none_mappings = adata.var["gene_names"].notnull().sum()
LOGGER.info(
f"Mapped {non_none_mappings} genes to Gene names from a total of {adata.var.shape[0]} Ensembl IDs."
)
LOGGER.info("Mapped %d / %d Ensembl IDs to gene names.", non_none_mappings, adata.var.shape[0])
return adata


def convert_list_ensembl_ids_to_gene_symbols(ensembl_ids: List[str], species=human) -> List[[str]]:
def convert_list_ensembl_ids_to_gene_symbols(
ensembl_ids: Sequence[str], species: str = "hsapiens"
) -> List[Optional[str]]:
"""
Map a list of Ensembl IDs to gene symbols using pyensembl.
Map a list/sequence of Ensembl Gene IDs to gene symbols using pybiomart.

Parameters
----------
ensembl_ids : List[str]
List of Ensembl Gene IDs (e.g., ENSG00000139618).
species : pyensembl.species.Species, optional
Species to use for mapping (default is human, GRCh38).
ensembl_ids : Sequence[str]
Ensembl Gene IDs (e.g., "ENSG00000139618").
species : str, default "hsapiens"
Species prefix used by Ensembl Biomart.

Returns
-------
List[Optional[str]]
List of gene symbols (or None if not found), in the same order as the input list.
Gene symbols aligned to the input order (None if not found).
"""
# Prepare pyensembl genome reference
genome_reference = genome_for_reference_name(next(iter(species.reference_assemblies)))
genome_reference.download(overwrite=False)
genome_reference.index(overwrite=False)

# Map IDs
gene_symbols = []
for eid in ensembl_ids:
try:
symbol = genome_reference.gene_name_of_gene_id(eid)
except Exception:
symbol = None
gene_symbols.append(symbol)

return gene_symbols
df = _get_ensembl_mart_df(species=species)
mapping = df.drop_duplicates(subset="ensembl_id").set_index("ensembl_id")["gene_name"]
return list(pd.Series(ensembl_ids, dtype="object").map(mapping))


def convert_list_gene_symbols_to_ensembl_ids(gene_symbols: List[str], species=human) -> List[Optional[str]]:
def convert_list_gene_symbols_to_ensembl_ids(
gene_symbols: Sequence[str], species: str = "hsapiens"
) -> List[Optional[str]]:
"""
Map a list of gene symbols to Ensembl IDs using pyensembl.
Map a list/sequence of gene symbols to Ensembl Gene IDs using pybiomart.

Parameters
----------
gene_symbols : List[str]
List of gene symbols (e.g., BRCA2, KANSL2).
species : pyensembl.species.Species, optional
Species to use for mapping (default is human, GRCh38).
gene_symbols : Sequence[str]
Gene symbols (e.g., "BRCA2", "KANSL2").
species : str, default "hsapiens"
Species prefix used by Ensembl Biomart.

Returns
-------
List[Optional[str]]
List of Ensembl Gene IDs (or None if not found), in the same order as the input list.
Ensembl Gene IDs aligned to the input order (None if not found).
"""
genome_reference = genome_for_reference_name(next(iter(species.reference_assemblies)))
genome_reference.download(overwrite=False)
genome_reference.index(overwrite=False)

ensembl_ids = []
for symbol in gene_symbols:
try:
eid = genome_reference.gene_ids_of_gene_name(symbol)[0]
except Exception:
eid = None
ensembl_ids.append(eid)

return ensembl_ids
df = _get_ensembl_mart_df(species=species)
mapping = df.drop_duplicates(subset="gene_name").set_index("gene_name")["ensembl_id"]
return list(pd.Series(gene_symbols, dtype="object").map(mapping))
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "helical"
version = "1.4.7"
version = "1.4.8"
authors = [
{ name="Helical Team", email="support@helical-ai.com" },
]
Expand Down Expand Up @@ -37,7 +37,7 @@ dependencies = [
'omegaconf==2.3.0',
'hydra-core==1.3.2',
'louvain==0.8.2',
'pyensembl',
'pybiomart',
'datasets==3.6.0'
]

Expand Down
Loading