diff --git a/README.md b/README.md index 458cac33..3dd75048 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,9 @@ or in case you're installing from the Helical repo cloned locally: pip install .[mamba-ssm] ``` -Note: make sure your machine has GPU(s) and Cuda installed. Currently this is a requirement for the packages mamba-ssm and causal-conv1d. +Note: +- Make sure your machine has GPU(s) and Cuda installed. Currently this is a requirement for the packages mamba-ssm and causal-conv1d. +- The package `causal_conv1d` requires `torch` to be installed already. First installing `helical` separately (without `[mamba-ssm]`) will install `torch` for you. A second installation (with `[mamba-ssm]`), installs the packages correctly. ### Singularity (Optional) If you desire to run your code in a singularity file, you can use the [singularity.def](./singularity.def) file and build an apptainer with it: diff --git a/ci/tests/test_geneformer/test_geneformer_model.py b/ci/tests/test_geneformer/test_geneformer_model.py index 623966ad..11702a09 100644 --- a/ci/tests/test_geneformer/test_geneformer_model.py +++ b/ci/tests/test_geneformer/test_geneformer_model.py @@ -47,6 +47,12 @@ def test_process_data_mapping_to_ensemble_ids(self, geneformer, mock_data): assert mock_data.var[mock_data.var['gene_symbols'] == 'PLEKHN1']['ensembl_id'].values[0] == 'ENSG00000187583' assert mock_data.var[mock_data.var['gene_symbols'] == 'HES4']['ensembl_id'].values[0] == 'ENSG00000188290' + def test_process_data_mapping_to_ensemble_ids_resulting_in_0_genes(self, geneformer, mock_data): + # provide a gene that does not exist in the ensembl database + mock_data.var['gene_symbols'] = ['1', '2', '3'] + with pytest.raises(ValueError): + geneformer.process_data(mock_data, gene_names="gene_symbols") + @pytest.mark.parametrize("invalid_model_names", ["gf-12L-35M-i2048", "gf-34L-30M-i5000"]) def test_pass_invalid_model_name(self, invalid_model_names): with pytest.raises(ValueError): diff --git a/ci/tests/test_scgpt/test_scgpt_model.py b/ci/tests/test_scgpt/test_scgpt_model.py index b2839752..e4ace5c8 100644 --- a/ci/tests/test_scgpt/test_scgpt_model.py +++ b/ci/tests/test_scgpt/test_scgpt_model.py @@ -89,6 +89,13 @@ def test_ensure_data_validity__value_error(self, data): self.scgpt.ensure_data_validity(data, "index", False) assert "total_counts" in data.obs + def test_process_data_no_matching_genes(self): + self.dummy_data.var['gene_ids'] = [1]*self.dummy_data.n_vars + model = scGPT() + + with pytest.raises(ValueError): + model.process_data(self.dummy_data, gene_names='gene_ids') + np_arr_data = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad") csr_data = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad") csr_data.X = csr_matrix(np.random.poisson(1, size=(100, 5)), dtype=np.float32) diff --git a/ci/tests/test_uce/test_gene_embeddings.py b/ci/tests/test_uce/test_gene_embeddings.py new file mode 100644 index 00000000..06839620 --- /dev/null +++ b/ci/tests/test_uce/test_gene_embeddings.py @@ -0,0 +1,39 @@ +from helical.models.uce.gene_embeddings import load_gene_embeddings_adata +from anndata import AnnData +import pandas as pd +import numpy as np +from pathlib import Path +import pytest +from pathlib import Path +CACHE_DIR_HELICAL = Path(Path.home(), '.cache', 'helical', 'models') + +class TestUCEGeneEmbeddings: + + adata = AnnData( + X=np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + obs=pd.DataFrame({"species": ["human", "mouse", "rat"]}), + var=pd.DataFrame({"gene": ["gene1", "gene2", "gene3"]}) + ) + species = ["human"] + embedding_model = "ESM2" + embeddings_path = Path(CACHE_DIR_HELICAL, 'uce', "protein_embeddings") + + def test_load_gene_embeddings_adata_filtering_all_genes(self): + with pytest.raises(ValueError): + load_gene_embeddings_adata(self.adata, self.species, self.embedding_model, self.embeddings_path) + + def test_load_gene_embeddings_adata_filtering_no_genes(self): + self.adata.var_names = ['hoxa6', 'cav2', 'txk'] + anndata, mapping_dict = load_gene_embeddings_adata(self.adata, self.species, self.embedding_model, self.embeddings_path) + assert (anndata.var_names == ['hoxa6', 'cav2', 'txk']).all() + assert (anndata.obs == self.adata.obs).all().all() + assert (anndata.X == self.adata.X).all() + assert len(mapping_dict['human']) == 19790 + + def test_load_gene_embeddings_adata_filtering_some_genes(self): + self.adata.var_names = ['hoxa6', 'cav2', '1'] + anndata, mapping_dict = load_gene_embeddings_adata(self.adata, self.species, self.embedding_model, self.embeddings_path) + assert (anndata.var_names == ['hoxa6', 'cav2']).all() + assert (anndata.obs == self.adata.obs).all().all() + assert (anndata.X == [[1, 2], [4, 5], [7, 8]]).all() + assert len(mapping_dict['human']) == 19790 \ No newline at end of file diff --git a/examples/notebooks/Helix-mRNA.ipynb b/examples/notebooks/Helix-mRNA.ipynb index d34513b8..3433a3db 100644 --- a/examples/notebooks/Helix-mRNA.ipynb +++ b/examples/notebooks/Helix-mRNA.ipynb @@ -1,443 +1,432 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Use Helix-mRNA" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Access the Helical GitHub [here](https://github.com/helicalAI)!" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**In this notebook we will dive into using our latest mRNA Bio Foundation Model, Helix-mRNA.**\n", - "\n", - "**We will get and plot embeddings for our data.**\n", - "\n", - "**We will fine-tune the model both using the Helical package**" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## If running on colab, run the cell below. Comment out if running locally" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!pip uninstall torch torchtext torchaudio -y\n", - "!pip install torch==2.3.0 torchtext torchaudio --index-url https://download.pytorch.org/whl/cu121\n", - "!pip uninstall mamba-ssm causal-conv1d -y\n", - "!pip install mamba-ssm==2.2.2 --no-cache-dir" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!pip install helical\n", - "!pip uninstall causal-conv1d -y" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Imports" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:datasets:PyTorch version 2.3.0 available.\n", - "INFO:datasets:Polars version 0.20.31 available.\n", - "INFO:datasets:JAX version 0.4.31 available.\n" - ] - } - ], - "source": [ - "from helical import HelixmRNAConfig, HelixmRNA, HelixmRNAFineTuningModel\n", - "import subprocess\n", - "import torch\n", - "import pandas as pd\n", - "\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Download one of CodonBERT's fine-tuning benchmarks" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "url = \"https://raw.githubusercontent.com/Sanofi-Public/CodonBERT/refs/heads/master/benchmarks/CodonBERT/data/fine-tune/mRFP_Expression.csv\"\n", - "\n", - "output_filename = \"mRFP_Expression.csv\"\n", - "wget_command = [\"wget\", \"-O\", output_filename, url]\n", - "\n", - "try:\n", - " subprocess.run(wget_command, check=True)\n", - " print(f\"File downloaded successfully as {output_filename}\")\n", - "except subprocess.CalledProcessError as e:\n", - " print(f\"Error occurred: {e}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Load the dataset as a pandas dataframe and get the splits\n", - "- For this example we take a subset of the splits, feel free to run it on the entire dataset!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dataset = pd.read_csv(output_filename)\n", - "train_data = dataset[dataset[\"Split\"] == \"train\"][:10]\n", - "eval_data = dataset[dataset[\"Split\"] == \"val\"][:5]\n", - "test_data = dataset[dataset[\"Split\"] == \"test\"][:5]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Define our Helix-mRNA model and desired configs" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:helical.models.helix_mrna.model:Helix-mRNA initialized successfully.\n" - ] - } - ], - "source": [ - "# We set the max length to the maximum length of the sequences in the training data + 10 to include space for special tokens\n", - "helix_mrna_config = HelixmRNAConfig(device=device, batch_size=5, max_length=max(len(s) for s in train_data[\"Sequence\"])+10)\n", - "helix_mrna = HelixmRNA(helix_mrna_config)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Process our training sequences to tokenize them and prepare them for the model" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "processed_train_data = helix_mrna.process_data(train_data[\"Sequence\"].to_list())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Generate embeddings for the train data" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "- We get an embeddings for each letter/token in the sequence, in this case 100 embeddings for each of the 688 tokens and our embedding dimension is 256\n", - "- Because the model has a recurrent nature, our final non-special token embedding at the second last position encapsulates everything that came before it" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Getting embeddings: 100%|██████████| 20/20 [00:00<00:00, 71.50it/s]" - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Use Helix-mRNA" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Access the Helical GitHub [here](https://github.com/helicalAI)!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**In this notebook we will dive into using our latest mRNA Bio Foundation Model, Helix-mRNA.**\n", + "\n", + "**We will get and plot embeddings for our data.**\n", + "\n", + "**We will fine-tune the model both using the Helical package**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If running on a CUDA device compatible with mamba-ssm and causal-conv1d install the package below, otherwise remove the [mamba-ssm] optional dependency\n", + "- If running on colab, remove the [mamba-ssm] dependency" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install --upgrade helical[mamba-ssm]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:datasets:PyTorch version 2.3.0 available.\n", + "INFO:datasets:Polars version 0.20.31 available.\n", + "INFO:datasets:JAX version 0.4.31 available.\n" + ] + } + ], + "source": [ + "from helical import HelixmRNAConfig, HelixmRNA, HelixmRNAFineTuningModel\n", + "import subprocess\n", + "import torch\n", + "import pandas as pd\n", + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Download one of CodonBERT's fine-tuning benchmarks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "url = \"https://raw.githubusercontent.com/Sanofi-Public/CodonBERT/refs/heads/master/benchmarks/CodonBERT/data/fine-tune/mRFP_Expression.csv\"\n", + "\n", + "output_filename = \"mRFP_Expression.csv\"\n", + "wget_command = [\"wget\", \"-O\", output_filename, url]\n", + "\n", + "try:\n", + " subprocess.run(wget_command, check=True)\n", + " print(f\"File downloaded successfully as {output_filename}\")\n", + "except subprocess.CalledProcessError as e:\n", + " print(f\"Error occurred: {e}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load the dataset as a pandas dataframe and get the splits\n", + "- For this example we take a subset of the splits, feel free to run it on the entire dataset!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = pd.read_csv(output_filename)\n", + "train_data = dataset[dataset[\"Split\"] == \"train\"][:10]\n", + "eval_data = dataset[dataset[\"Split\"] == \"val\"][:5]\n", + "test_data = dataset[dataset[\"Split\"] == \"test\"][:5]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define our Helix-mRNA model and desired configs" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:helical.models.helix_mrna.model:Helix-mRNA initialized successfully.\n" + ] + } + ], + "source": [ + "# We set the max length to the maximum length of the sequences in the training data + 10 to include space for special tokens\n", + "helix_mrna_config = HelixmRNAConfig(device=device, batch_size=1, max_length=max(len(s) for s in train_data[\"Sequence\"])+10)\n", + "helix_mrna = HelixmRNA(helix_mrna_config)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Process our training sequences to tokenize them and prepare them for the model" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "processed_train_data = helix_mrna.process_data(train_data[\"Sequence\"].to_list())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Generate embeddings for the train data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- We get an embeddings for each letter/token in the sequence, in this case 100 embeddings for each of the 688 tokens and our embedding dimension is 256\n", + "- Because the model has a recurrent nature, our final non-special token embedding at the second last position encapsulates everything that came before it" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Getting embeddings: 100%|██████████| 20/20 [00:00<00:00, 71.50it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(100, 256)\n", + "[[-9.70379915e-04 5.94667019e-03 1.07590854e-02 5.22067677e-03\n", + " -9.54915071e-04 -6.74154516e-03 -2.91207526e-03 -1.49831397e-03\n", + " 1.78750437e-02 5.13957115e-03 8.79890576e-04 1.21943112e-02\n", + " -1.92209042e-03 3.27306171e-03 -2.27077748e-03 4.50014602e-04\n", + " 7.30314665e-03 -8.66744318e-04 -8.81821662e-03 -7.57190645e-01\n", + " 1.89280566e-02 -4.05776373e-04 6.08320069e-03 -1.78794132e-03\n", + " -8.79776548e-04 -8.19147026e-05 9.60175938e-04 -8.30806512e-03\n", + " 5.66601008e-03 -5.93393855e-03 -5.19109843e-03 6.86887605e-03\n", + " -7.94085041e-02 -5.38914884e-03 -1.55241350e-02 -2.42359545e-02\n", + " 2.57678051e-03 -9.53892432e-03 -7.16619950e-04 1.50164040e-02\n", + " -9.01486576e-01 -4.68801707e-03 3.71015654e-03 -1.07593695e-02\n", + " 9.67101427e-04 5.75249782e-03 2.86138593e-03 -6.41007500e-04\n", + " -3.93231586e-03 -5.53809397e-04 1.72096007e-02 -8.10448000e-06\n", + " 1.20042302e-02 -7.83413649e-03 2.40328256e-03 1.44813021e-04\n", + " 6.37711585e-03 -2.75100190e-02 -9.19151399e-03 2.25025918e-02\n", + " -2.71240231e-02 3.18764849e-03 -8.12906027e-03 -9.14498232e-03\n", + " 3.32334498e-03 8.77279043e-03 2.19076360e-03 -4.82588913e-03\n", + " 7.93280269e-05 -4.37264703e-03 -1.03688613e-02 -1.15277218e-02\n", + " -4.73860680e-04 1.04218733e-03 4.24548006e-03 3.50359431e-03\n", + " -1.29856959e-01 1.34938598e-01 -3.02679911e-02 -2.39217840e-02\n", + " -2.36590131e-04 -4.36108152e-04 4.19709226e-03 -8.65293201e-03\n", + " 1.99613022e-03 1.79194030e-03 2.84497248e-04 1.43997103e-03\n", + " -1.53440237e-01 -2.64659454e-03 2.34050487e-04 -2.68558436e-03\n", + " -2.99103018e-02 -4.54764161e-03 7.06061395e-03 -2.53863144e-03\n", + " 8.90936144e-03 -7.19320118e-01 7.57683277e-01 -1.85208302e-02\n", + " -1.54839545e-01 -2.54138529e-01 8.30988749e-04 -8.85983836e-03\n", + " -9.01097711e-03 -7.00991787e-03 -1.82072073e-01 -5.14741063e-01\n", + " -2.93075689e-03 2.55425880e-03 1.37590896e-03 5.38261468e-03\n", + " -5.74133135e-02 -4.50886175e-04 1.39132570e-02 -9.26930178e-03\n", + " 3.89014231e-03 3.58247235e-02 1.38020525e-02 4.48753638e-03\n", + " 4.69827838e-03 5.32380529e-02 7.67468300e-04 -2.27806643e-02\n", + " 9.79826669e-04 3.29421629e-04 2.56255385e-03 -3.15385172e-04\n", + " 1.13730943e-02 5.02255885e-03 7.63128162e-04 -4.30183439e-03\n", + " -1.41088907e-02 -7.07946122e-02 2.18413552e-04 -4.30437940e-04\n", + " 5.93306264e-03 3.88289336e-03 -6.69274572e-03 -1.05123809e-02\n", + " 7.17154052e-03 9.30194370e-03 -2.66307388e-02 -2.35042372e-03\n", + " -3.61418119e-03 -1.88636947e-02 4.10996377e-03 1.86230207e-03\n", + " -7.77591905e-03 1.07999649e-02 -2.15348396e-02 -1.56054425e-03\n", + " -4.75367473e-04 -2.42964807e-03 1.37075689e-03 -1.18554395e-03\n", + " 1.96172502e-02 8.72136280e-03 -2.54987436e-03 -1.78763457e-03\n", + " 1.48834437e-01 4.15487972e-04 -8.82838969e-04 -4.85490542e-04\n", + " 9.73013118e-02 1.01735163e-02 9.76046920e-03 7.66289607e-03\n", + " 3.93118672e-02 5.41610224e-03 -7.19898380e-03 -4.61950190e-02\n", + " 6.28079474e-03 2.30385065e-02 -1.32811114e-01 2.61072395e-03\n", + " 2.72905454e-03 -8.26253928e-03 2.76575685e-02 -1.16535993e-02\n", + " 7.09296510e-05 -5.02431765e-03 -2.00841855e-02 -9.82477888e-03\n", + " 1.99634713e-04 -2.33941106e-03 -1.01937279e-02 -6.17030673e-02\n", + " 5.41278534e-03 9.48928879e-04 9.36821289e-03 -7.82263931e-03\n", + " -1.20594129e-02 -6.56401785e-03 -8.18305537e-02 8.73102434e-03\n", + " -2.41522095e-03 6.06243312e-03 -2.66978621e-01 8.72417178e-04\n", + " 8.10213108e-03 -1.89128786e-01 8.86955822e-04 1.45062711e-02\n", + " -4.65695048e-03 -3.56003083e-03 -1.77745167e-02 -3.33940163e-02\n", + " 1.01557758e-04 -8.14760383e-03 8.52145813e-03 -9.25995596e-03\n", + " 4.04966250e-03 3.44780415e-01 -4.55286279e-02 1.27975168e-02\n", + " 5.34113357e-03 -1.58847857e-03 4.65576863e-03 -3.07517336e-03\n", + " 2.26003435e-02 -3.24756862e-03 7.61093199e-02 -8.04630481e-03\n", + " 1.46187656e-02 -4.15891828e-03 -7.69484183e-03 -6.56060642e-03\n", + " 6.32394385e-03 1.89167308e-03 -1.65223163e-02 -7.15268105e-02\n", + " -5.13067655e-02 8.95772595e-03 3.47553840e-04 3.91185429e-04\n", + " 1.50305599e-01 5.39071718e-03 -3.56106623e-03 -1.07512353e-02\n", + " -1.70928031e-01 2.04306114e-02 2.14800192e-03 -1.82585061e-01\n", + " -4.20546830e-02 1.20053962e-02 -6.52526272e-04 1.29553266e-02\n", + " 2.05104008e-01 -3.85842402e-03 8.15556012e-03 -6.55666053e-01\n", + " -6.22088835e-03 1.99010246e-03 -1.32145118e-02 1.12704304e-03]]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "embeddings = helix_mrna.get_embeddings(processed_train_data)\n", + "embeddings = embeddings[:, -2, :]\n", + "print(embeddings.shape)\n", + "print(embeddings[:1])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Fine-tuning the model on our data\n", + "- This is a regression task and so our output is 1 continuous value" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:helical.models.helix_mrna.model:Helix-mRNA initialized successfully.\n" + ] + } + ], + "source": [ + "helix_mrna_fine_tuning_model = HelixmRNAFineTuningModel(helix_mrna_config=helix_mrna_config, fine_tuning_head=\"regression\", output_size=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Our training data is already processed since the standard Helix-mRNA model and fine-tuning model take the same input!\n", + "- We process our eval and test data" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "processed_eval_data = helix_mrna_fine_tuning_model.process_data(eval_data[\"Sequence\"].to_list())\n", + "processed_test_data = helix_mrna_fine_tuning_model.process_data(test_data[\"Sequence\"].to_list())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run fine-tuning on the model for this small sample of data" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:helical.models.helix_mrna.fine_tuning_model:Unfreezing the last 2 layers of the Helix_mRNA model.\n", + "INFO:helical.models.helix_mrna.fine_tuning_model:Starting Fine-Tuning\n", + "Fine-Tuning: epoch 1/5: 100%|██████████| 20/20 [00:00<00:00, 49.31it/s, loss=86.9]\n", + "Fine-Tuning Validation: 100%|██████████| 4/4 [00:00<00:00, 56.86it/s, val_loss=88.8]\n", + "Fine-Tuning: epoch 2/5: 100%|██████████| 20/20 [00:00<00:00, 59.90it/s, loss=80.6]\n", + "Fine-Tuning Validation: 100%|██████████| 4/4 [00:00<00:00, 53.01it/s, val_loss=83.3]\n", + "Fine-Tuning: epoch 3/5: 100%|██████████| 20/20 [00:00<00:00, 51.32it/s, loss=74.9]\n", + "Fine-Tuning Validation: 100%|██████████| 4/4 [00:00<00:00, 167.73it/s, val_loss=76.6]\n", + "Fine-Tuning: epoch 4/5: 100%|██████████| 20/20 [00:00<00:00, 49.26it/s, loss=67.3]\n", + "Fine-Tuning Validation: 100%|██████████| 4/4 [00:00<00:00, 54.98it/s, val_loss=67.4]\n", + "Fine-Tuning: epoch 5/5: 100%|██████████| 20/20 [00:00<00:00, 60.40it/s, loss=59.8]\n", + "Fine-Tuning Validation: 100%|██████████| 4/4 [00:00<00:00, 50.23it/s, val_loss=61.1]\n", + "INFO:helical.models.helix_mrna.fine_tuning_model:Fine-Tuning Complete. Epochs: 5\n" + ] + } + ], + "source": [ + "helix_mrna_fine_tuning_model.train(train_dataset=processed_train_data, \n", + " train_labels=train_data[\"Value\"].to_numpy().reshape(-1, 1),\n", + " validation_dataset=processed_eval_data, \n", + " validation_labels= eval_data[\"Value\"].to_numpy().reshape(-1, 1),\n", + " epochs=5,\n", + " loss_function=torch.nn.MSELoss(),\n", + " trainable_layers=2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Get outputs from our model on the test data" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating outputs: 100%|██████████| 2/2 [00:00<00:00, 46.17it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[2.682281 ]\n", + " [2.4182768]\n", + " [2.4362845]\n", + " [2.6120207]\n", + " [2.6543183]\n", + " [2.6988027]\n", + " [2.671821 ]\n", + " [2.144202 ]\n", + " [2.6866376]\n", + " [2.6734226]]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "outputs = helix_mrna_fine_tuning_model.get_outputs(processed_test_data)\n", + "print(outputs)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "helical-package", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(100, 256)\n", - "[[-9.70379915e-04 5.94667019e-03 1.07590854e-02 5.22067677e-03\n", - " -9.54915071e-04 -6.74154516e-03 -2.91207526e-03 -1.49831397e-03\n", - " 1.78750437e-02 5.13957115e-03 8.79890576e-04 1.21943112e-02\n", - " -1.92209042e-03 3.27306171e-03 -2.27077748e-03 4.50014602e-04\n", - " 7.30314665e-03 -8.66744318e-04 -8.81821662e-03 -7.57190645e-01\n", - " 1.89280566e-02 -4.05776373e-04 6.08320069e-03 -1.78794132e-03\n", - " -8.79776548e-04 -8.19147026e-05 9.60175938e-04 -8.30806512e-03\n", - " 5.66601008e-03 -5.93393855e-03 -5.19109843e-03 6.86887605e-03\n", - " -7.94085041e-02 -5.38914884e-03 -1.55241350e-02 -2.42359545e-02\n", - " 2.57678051e-03 -9.53892432e-03 -7.16619950e-04 1.50164040e-02\n", - " -9.01486576e-01 -4.68801707e-03 3.71015654e-03 -1.07593695e-02\n", - " 9.67101427e-04 5.75249782e-03 2.86138593e-03 -6.41007500e-04\n", - " -3.93231586e-03 -5.53809397e-04 1.72096007e-02 -8.10448000e-06\n", - " 1.20042302e-02 -7.83413649e-03 2.40328256e-03 1.44813021e-04\n", - " 6.37711585e-03 -2.75100190e-02 -9.19151399e-03 2.25025918e-02\n", - " -2.71240231e-02 3.18764849e-03 -8.12906027e-03 -9.14498232e-03\n", - " 3.32334498e-03 8.77279043e-03 2.19076360e-03 -4.82588913e-03\n", - " 7.93280269e-05 -4.37264703e-03 -1.03688613e-02 -1.15277218e-02\n", - " -4.73860680e-04 1.04218733e-03 4.24548006e-03 3.50359431e-03\n", - " -1.29856959e-01 1.34938598e-01 -3.02679911e-02 -2.39217840e-02\n", - " -2.36590131e-04 -4.36108152e-04 4.19709226e-03 -8.65293201e-03\n", - " 1.99613022e-03 1.79194030e-03 2.84497248e-04 1.43997103e-03\n", - " -1.53440237e-01 -2.64659454e-03 2.34050487e-04 -2.68558436e-03\n", - " -2.99103018e-02 -4.54764161e-03 7.06061395e-03 -2.53863144e-03\n", - " 8.90936144e-03 -7.19320118e-01 7.57683277e-01 -1.85208302e-02\n", - " -1.54839545e-01 -2.54138529e-01 8.30988749e-04 -8.85983836e-03\n", - " -9.01097711e-03 -7.00991787e-03 -1.82072073e-01 -5.14741063e-01\n", - " -2.93075689e-03 2.55425880e-03 1.37590896e-03 5.38261468e-03\n", - " -5.74133135e-02 -4.50886175e-04 1.39132570e-02 -9.26930178e-03\n", - " 3.89014231e-03 3.58247235e-02 1.38020525e-02 4.48753638e-03\n", - " 4.69827838e-03 5.32380529e-02 7.67468300e-04 -2.27806643e-02\n", - " 9.79826669e-04 3.29421629e-04 2.56255385e-03 -3.15385172e-04\n", - " 1.13730943e-02 5.02255885e-03 7.63128162e-04 -4.30183439e-03\n", - " -1.41088907e-02 -7.07946122e-02 2.18413552e-04 -4.30437940e-04\n", - " 5.93306264e-03 3.88289336e-03 -6.69274572e-03 -1.05123809e-02\n", - " 7.17154052e-03 9.30194370e-03 -2.66307388e-02 -2.35042372e-03\n", - " -3.61418119e-03 -1.88636947e-02 4.10996377e-03 1.86230207e-03\n", - " -7.77591905e-03 1.07999649e-02 -2.15348396e-02 -1.56054425e-03\n", - " -4.75367473e-04 -2.42964807e-03 1.37075689e-03 -1.18554395e-03\n", - " 1.96172502e-02 8.72136280e-03 -2.54987436e-03 -1.78763457e-03\n", - " 1.48834437e-01 4.15487972e-04 -8.82838969e-04 -4.85490542e-04\n", - " 9.73013118e-02 1.01735163e-02 9.76046920e-03 7.66289607e-03\n", - " 3.93118672e-02 5.41610224e-03 -7.19898380e-03 -4.61950190e-02\n", - " 6.28079474e-03 2.30385065e-02 -1.32811114e-01 2.61072395e-03\n", - " 2.72905454e-03 -8.26253928e-03 2.76575685e-02 -1.16535993e-02\n", - " 7.09296510e-05 -5.02431765e-03 -2.00841855e-02 -9.82477888e-03\n", - " 1.99634713e-04 -2.33941106e-03 -1.01937279e-02 -6.17030673e-02\n", - " 5.41278534e-03 9.48928879e-04 9.36821289e-03 -7.82263931e-03\n", - " -1.20594129e-02 -6.56401785e-03 -8.18305537e-02 8.73102434e-03\n", - " -2.41522095e-03 6.06243312e-03 -2.66978621e-01 8.72417178e-04\n", - " 8.10213108e-03 -1.89128786e-01 8.86955822e-04 1.45062711e-02\n", - " -4.65695048e-03 -3.56003083e-03 -1.77745167e-02 -3.33940163e-02\n", - " 1.01557758e-04 -8.14760383e-03 8.52145813e-03 -9.25995596e-03\n", - " 4.04966250e-03 3.44780415e-01 -4.55286279e-02 1.27975168e-02\n", - " 5.34113357e-03 -1.58847857e-03 4.65576863e-03 -3.07517336e-03\n", - " 2.26003435e-02 -3.24756862e-03 7.61093199e-02 -8.04630481e-03\n", - " 1.46187656e-02 -4.15891828e-03 -7.69484183e-03 -6.56060642e-03\n", - " 6.32394385e-03 1.89167308e-03 -1.65223163e-02 -7.15268105e-02\n", - " -5.13067655e-02 8.95772595e-03 3.47553840e-04 3.91185429e-04\n", - " 1.50305599e-01 5.39071718e-03 -3.56106623e-03 -1.07512353e-02\n", - " -1.70928031e-01 2.04306114e-02 2.14800192e-03 -1.82585061e-01\n", - " -4.20546830e-02 1.20053962e-02 -6.52526272e-04 1.29553266e-02\n", - " 2.05104008e-01 -3.85842402e-03 8.15556012e-03 -6.55666053e-01\n", - " -6.22088835e-03 1.99010246e-03 -1.32145118e-02 1.12704304e-03]]\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], - "source": [ - "embeddings = helix_mrna.get_embeddings(processed_train_data)\n", - "embeddings = embeddings[:, -2, :]\n", - "print(embeddings.shape)\n", - "print(embeddings[:1])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Fine-tuning the model on our data\n", - "- This is a regression task and so our output is 1 continuous value" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:helical.models.helix_mrna.model:Helix-mRNA initialized successfully.\n" - ] - } - ], - "source": [ - "helix_mrna_fine_tuning_model = HelixmRNAFineTuningModel(helix_mrna_config=helix_mrna_config, fine_tuning_head=\"regression\", output_size=1)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Our training data is already processed since the standard Helix-mRNA model and fine-tuning model take the same input!\n", - "- We process our eval and test data" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "processed_eval_data = helix_mrna_fine_tuning_model.process_data(eval_data[\"Sequence\"].to_list())\n", - "processed_test_data = helix_mrna_fine_tuning_model.process_data(test_data[\"Sequence\"].to_list())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Run fine-tuning on the model for this small sample of data" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:helical.models.helix_mrna.fine_tuning_model:Unfreezing the last 2 layers of the Helix_mRNA model.\n", - "INFO:helical.models.helix_mrna.fine_tuning_model:Starting Fine-Tuning\n", - "Fine-Tuning: epoch 1/5: 100%|██████████| 20/20 [00:00<00:00, 49.31it/s, loss=86.9]\n", - "Fine-Tuning Validation: 100%|██████████| 4/4 [00:00<00:00, 56.86it/s, val_loss=88.8]\n", - "Fine-Tuning: epoch 2/5: 100%|██████████| 20/20 [00:00<00:00, 59.90it/s, loss=80.6]\n", - "Fine-Tuning Validation: 100%|██████████| 4/4 [00:00<00:00, 53.01it/s, val_loss=83.3]\n", - "Fine-Tuning: epoch 3/5: 100%|██████████| 20/20 [00:00<00:00, 51.32it/s, loss=74.9]\n", - "Fine-Tuning Validation: 100%|██████████| 4/4 [00:00<00:00, 167.73it/s, val_loss=76.6]\n", - "Fine-Tuning: epoch 4/5: 100%|██████████| 20/20 [00:00<00:00, 49.26it/s, loss=67.3]\n", - "Fine-Tuning Validation: 100%|██████████| 4/4 [00:00<00:00, 54.98it/s, val_loss=67.4]\n", - "Fine-Tuning: epoch 5/5: 100%|██████████| 20/20 [00:00<00:00, 60.40it/s, loss=59.8]\n", - "Fine-Tuning Validation: 100%|██████████| 4/4 [00:00<00:00, 50.23it/s, val_loss=61.1]\n", - "INFO:helical.models.helix_mrna.fine_tuning_model:Fine-Tuning Complete. Epochs: 5\n" - ] - } - ], - "source": [ - "helix_mrna_fine_tuning_model.train(train_dataset=processed_train_data, \n", - " train_labels=train_data[\"Value\"].to_numpy().reshape(-1, 1),\n", - " validation_dataset=processed_eval_data, \n", - " validation_labels= eval_data[\"Value\"].to_numpy().reshape(-1, 1),\n", - " epochs=5,\n", - " loss_function=torch.nn.MSELoss(),\n", - " trainable_layers=2)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Get outputs from our model on the test data" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Generating outputs: 100%|██████████| 2/2 [00:00<00:00, 46.17it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[2.682281 ]\n", - " [2.4182768]\n", - " [2.4362845]\n", - " [2.6120207]\n", - " [2.6543183]\n", - " [2.6988027]\n", - " [2.671821 ]\n", - " [2.144202 ]\n", - " [2.6866376]\n", - " [2.6734226]]\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], - "source": [ - "outputs = helix_mrna_fine_tuning_model.get_outputs(processed_test_data)\n", - "print(outputs)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "helical-package", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.8" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} + "nbformat": 4, + "nbformat_minor": 2 + } + \ No newline at end of file diff --git a/helical/__init__.py b/helical/__init__.py index e5fa83ab..65619bb3 100644 --- a/helical/__init__.py +++ b/helical/__init__.py @@ -3,10 +3,6 @@ logging.captureWarnings(True) -class InfoAndErrorFilter(logging.Filter): - def filter(self, record): - return record.levelno in (logging.INFO, logging.ERROR) - for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) @@ -16,8 +12,6 @@ def filter(self, record): handler = logging.StreamHandler() handler.setLevel(logging.INFO) -handler.addFilter(InfoAndErrorFilter()) - formatter = logging.Formatter('%(levelname)s:%(name)s:%(message)s') handler.setFormatter(formatter) @@ -41,4 +35,4 @@ def filter(self, record): from .models.caduceus import Caduceus, CaduceusConfig, CaduceusFineTuningModel except: LOGGER = logging.getLogger(__name__) - LOGGER.info("Caduceus not available: If you want to use this model, ensure you have a CUDA GPU and have installed the optional helical[mamba-ssm] dependencies.") \ No newline at end of file + LOGGER.info("Caduceus not available: If you want to use this model, ensure you have a CUDA GPU and have installed the optional helical[mamba-ssm] dependencies.") diff --git a/helical/models/geneformer/model.py b/helical/models/geneformer/model.py index bdf2d109..d088b05a 100644 --- a/helical/models/geneformer/model.py +++ b/helical/models/geneformer/model.py @@ -153,7 +153,11 @@ def process_data(self, raise ValueError(message) adata = map_gene_symbols_to_ensembl_ids(adata, gene_names) - + if adata.var["ensembl_id"].isnull().all(): + message = "All gene symbols could not be mapped to Ensembl IDs. Please check the input data." + LOGGER.info(message) + raise ValueError(message) + tokenized_cells, cell_metadata = self.tk.tokenize_anndata(adata) # tokenized_cells, cell_metadata = self.tk.tokenize_anndata(adata) diff --git a/helical/models/scgpt/model.py b/helical/models/scgpt/model.py index 0557ddad..4ce4614d 100644 --- a/helical/models/scgpt/model.py +++ b/helical/models/scgpt/model.py @@ -268,7 +268,13 @@ def process_data(self, # filtering adata.var["id_in_vocab"] = [ self.vocab[gene] if gene in self.vocab else -1 for gene in adata.var[self.gene_names] ] - LOGGER.info(f"Filtering out {np.sum(adata.var['id_in_vocab'] < 0)} genes to a total of {np.sum(adata.var['id_in_vocab'] >= 0)} genes with an id in the scGPT vocabulary.") + LOGGER.info(f"Filtering out {np.sum(adata.var['id_in_vocab'] < 0)} genes to a total of {np.sum(adata.var['id_in_vocab'] >= 0)} genes with an ID in the scGPT vocabulary.") + + if np.sum(adata.var["id_in_vocab"] >= 0) == 0: + message = "No matching genes found between input data and scGPT gene vocabulary. Please check the gene names in .var of the anndata input object." + LOGGER.error(message) + raise ValueError(message) + adata = adata[:, adata.var["id_in_vocab"] >= 0] # Binning will be applied after tokenization. A possible way to do is to use the unified way of binning in the data collator. diff --git a/helical/models/uce/gene_embeddings.py b/helical/models/uce/gene_embeddings.py index 55f5eee9..1fb607df 100644 --- a/helical/models/uce/gene_embeddings.py +++ b/helical/models/uce/gene_embeddings.py @@ -79,6 +79,11 @@ def load_gene_embeddings_adata(adata: AnnData, species: list, embedding_model: s filtered = adata.var_names.shape[0] - filtered_adata.var_names.shape[0] LOGGER.info(f'Filtered out {filtered} genes to a total of {filtered_adata.var_names.shape[0]} genes with embeddings.') + if filtered_adata.var_names.shape[0] == 0: + message = "No matching genes found between input data and UCE gene embedding vocabulary. Please check the gene names in .var of the anndata input object." + LOGGER.error(message) + raise ValueError(message) + # Load gene symbols for desired species for later use with indexes species_to_all_gene_symbols = { species: [ diff --git a/helical/utils/downloader.py b/helical/utils/downloader.py index 35c115b2..d4a143c7 100644 --- a/helical/utils/downloader.py +++ b/helical/utils/downloader.py @@ -17,6 +17,7 @@ HASH_DICT = { 'uce/4layer_model.torch': '16430370e0d672c8db6e275440e7974d2fd0a21f29aa9299e141085f82a5a886', + 'uce/33l_8ep_1024t_1280.torch': 'aa6457a0eb2e91d8382d96fb455456e40a9423a00509ea296079a75b1a9390c0', 'uce/all_tokens.torch': 'e3e3ad03a9f8fdca8babec5b0c72f7f4043a4bec2e3eb009b8fe1b28d984c93a', 'uce/species_chrom.csv': '7f5d32e6adcc3786c613043a4de8e2a47187935cfb9a1d3fcf7373eb50caebf7', 'uce/species_offsets.pkl': 'abda5b2bc4018187e408623b292686a061912f449daceb4c9c9603caf0d62538', diff --git a/mkdocs.yml b/mkdocs.yml index 3e228d6f..cac2519b 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -4,17 +4,7 @@ repo_url: https://github.com/helicalAI/helical repo_name: helicalAI/helical copyright: Helical Team
Copyright © 2024 nav: - - RNA Models: - - Helix-mRNA: - - Model Card: ./model_cards/helix_mrna.md - - Config: ./configs/helix_mrna_config.md - - Model: ./models/helix_mrna.md - - Fine-Tuning Model: ./fine_tuning_models/helix_mrna_fine_tune.md - - Mamba2-mRNA: - - Model Card: ./model_cards/mamba2_mrna.md - - Config: ./configs/mamba2_mrna_config.md - - Model: ./models/mamba2_mrna.md - - Fine-Tuning Model: ./fine_tuning_models/mamba2_mrna_fine_tune.md + - Single-Cell Models: - Geneformer: - Model Card: ./model_cards/geneformer.md - Config: ./configs/geneformer_config.md @@ -30,7 +20,18 @@ nav: - Config: ./configs/uce_config.md - Model: ./models/uce.md - Fine-Tuning Model: ./fine_tuning_models/uce_fine_tune.md - - DNA Models: + - RNA Sequence Models: + - Helix-mRNA: + - Model Card: ./model_cards/helix_mrna.md + - Config: ./configs/helix_mrna_config.md + - Model: ./models/helix_mrna.md + - Fine-Tuning Model: ./fine_tuning_models/helix_mrna_fine_tune.md + - Mamba2-mRNA: + - Model Card: ./model_cards/mamba2_mrna.md + - Config: ./configs/mamba2_mrna_config.md + - Model: ./models/mamba2_mrna.md + - Fine-Tuning Model: ./fine_tuning_models/mamba2_mrna_fine_tune.md + - DNA Sequence Models: - HyenaDNA: - Model Card: ./model_cards/hyenadna.md - Config: ./configs/hyenadna_config.md @@ -41,8 +42,8 @@ nav: - Config: ./configs/caduceus_config.md - Model: ./models/caduceus.md - Fine-Tuning Model: ./fine_tuning_models/caduceus_fine_tune.md - - Helical Base Models: ./models/base_models.md - - Fine-Tuning Heads: ./models/fine_tuning_heads.md + - Helical Base Models: ./models/base_models.md + - Fine-Tuning Heads: ./models/fine_tuning_heads.md - Example Notebooks: - Quick-Start-Tutorial: ./notebooks/Quick-Start-Tutorial.ipynb - Helix-mRNA: ./notebooks/Helix-mRNA.ipynb diff --git a/pyproject.toml b/pyproject.toml index 91630042..9ce6aca1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "helical" -version = "0.0.1a20" +version = "0.0.1a21" authors = [ { name="Helical Team", email="support@helical-ai.com" }, ] @@ -48,7 +48,7 @@ dependencies = [ [project.optional-dependencies] mamba-ssm = [ 'mamba-ssm==2.2.4', - 'causal-conv1d==1.4.0', + 'causal-conv1d==1.5.0.post8', ] [tool.hatch.metadata]