From e4d55ab6d74829435b76a91ca1980359e9673c0a Mon Sep 17 00:00:00 2001 From: bputzeys Date: Thu, 23 Jan 2025 20:48:10 +0100 Subject: [PATCH 1/9] Add check if all genes are filtered --- ci/tests/test_scgpt/test_scgpt_model.py | 7 +++++++ helical/models/scgpt/model.py | 8 +++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/ci/tests/test_scgpt/test_scgpt_model.py b/ci/tests/test_scgpt/test_scgpt_model.py index b2839752..23c33b9b 100644 --- a/ci/tests/test_scgpt/test_scgpt_model.py +++ b/ci/tests/test_scgpt/test_scgpt_model.py @@ -89,6 +89,13 @@ def test_ensure_data_validity__value_error(self, data): self.scgpt.ensure_data_validity(data, "index", False) assert "total_counts" in data.obs + def test_process_data_no_matching_genes(dummy_data): + dummy_data.var['gene_ids'] = [-1, -1, -1, -1] + model = scGPT() + + with pytest.raises(ValueError): + model.process_data(dummy_data, gene_names='gene_name') + np_arr_data = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad") csr_data = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad") csr_data.X = csr_matrix(np.random.poisson(1, size=(100, 5)), dtype=np.float32) diff --git a/helical/models/scgpt/model.py b/helical/models/scgpt/model.py index 0557ddad..79f4f2a2 100644 --- a/helical/models/scgpt/model.py +++ b/helical/models/scgpt/model.py @@ -268,7 +268,13 @@ def process_data(self, # filtering adata.var["id_in_vocab"] = [ self.vocab[gene] if gene in self.vocab else -1 for gene in adata.var[self.gene_names] ] - LOGGER.info(f"Filtering out {np.sum(adata.var['id_in_vocab'] < 0)} genes to a total of {np.sum(adata.var['id_in_vocab'] >= 0)} genes with an id in the scGPT vocabulary.") + LOGGER.info(f"Filtering out {np.sum(adata.var['id_in_vocab'] < 0)} genes to a total of {np.sum(adata.var['id_in_vocab'] >= 0)} genes with an ID in the scGPT vocabulary.") + + if np.sum(adata.var["id_in_vocab"] >= 0) == 0: + message = "No matching genes found between input data and scGPT gene vocabulary. 
Please check the 'gene_names' in .var of the anndata input object." + LOGGER.error(message) + raise ValueError(message) + adata = adata[:, adata.var["id_in_vocab"] >= 0] # Binning will be applied after tokenization. A possible way to do is to use the unified way of binning in the data collator. From ba62371322d6950d12ae27953c33113a001a9e75 Mon Sep 17 00:00:00 2001 From: Benoit Putzeys Date: Thu, 23 Jan 2025 20:57:54 +0100 Subject: [PATCH 2/9] Adjust test to throw error when no matching genes are found --- ci/tests/test_scgpt/test_scgpt_model.py | 6 +++--- helical/models/scgpt/model.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ci/tests/test_scgpt/test_scgpt_model.py b/ci/tests/test_scgpt/test_scgpt_model.py index 23c33b9b..e4ace5c8 100644 --- a/ci/tests/test_scgpt/test_scgpt_model.py +++ b/ci/tests/test_scgpt/test_scgpt_model.py @@ -89,12 +89,12 @@ def test_ensure_data_validity__value_error(self, data): self.scgpt.ensure_data_validity(data, "index", False) assert "total_counts" in data.obs - def test_process_data_no_matching_genes(dummy_data): - dummy_data.var['gene_ids'] = [-1, -1, -1, -1] + def test_process_data_no_matching_genes(self): + self.dummy_data.var['gene_ids'] = [1]*self.dummy_data.n_vars model = scGPT() with pytest.raises(ValueError): - model.process_data(dummy_data, gene_names='gene_name') + model.process_data(self.dummy_data, gene_names='gene_ids') np_arr_data = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad") csr_data = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad") diff --git a/helical/models/scgpt/model.py b/helical/models/scgpt/model.py index 79f4f2a2..4ce4614d 100644 --- a/helical/models/scgpt/model.py +++ b/helical/models/scgpt/model.py @@ -271,7 +271,7 @@ def process_data(self, LOGGER.info(f"Filtering out {np.sum(adata.var['id_in_vocab'] < 0)} genes to a total of {np.sum(adata.var['id_in_vocab'] >= 0)} genes with an ID in the scGPT vocabulary.") if np.sum(adata.var["id_in_vocab"] >= 0) == 0: - message = "No 
matching genes found between input data and scGPT gene vocabulary. Please check the 'gene_names' in .var of the anndata input object." + message = "No matching genes found between input data and scGPT gene vocabulary. Please check the gene names in .var of the anndata input object." LOGGER.error(message) raise ValueError(message) From 1396691d4d7de3f1a85a29249f99e60ddd6722d3 Mon Sep 17 00:00:00 2001 From: Benoit Putzeys Date: Thu, 23 Jan 2025 21:05:58 +0100 Subject: [PATCH 3/9] Catch error when no gene names are mapped to ensembl ids --- ci/tests/test_geneformer/test_geneformer_model.py | 6 ++++++ helical/models/geneformer/model.py | 6 +++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/ci/tests/test_geneformer/test_geneformer_model.py b/ci/tests/test_geneformer/test_geneformer_model.py index 623966ad..11702a09 100644 --- a/ci/tests/test_geneformer/test_geneformer_model.py +++ b/ci/tests/test_geneformer/test_geneformer_model.py @@ -47,6 +47,12 @@ def test_process_data_mapping_to_ensemble_ids(self, geneformer, mock_data): assert mock_data.var[mock_data.var['gene_symbols'] == 'PLEKHN1']['ensembl_id'].values[0] == 'ENSG00000187583' assert mock_data.var[mock_data.var['gene_symbols'] == 'HES4']['ensembl_id'].values[0] == 'ENSG00000188290' + def test_process_data_mapping_to_ensemble_ids_resulting_in_0_genes(self, geneformer, mock_data): + # provide a gene that does not exist in the ensembl database + mock_data.var['gene_symbols'] = ['1', '2', '3'] + with pytest.raises(ValueError): + geneformer.process_data(mock_data, gene_names="gene_symbols") + @pytest.mark.parametrize("invalid_model_names", ["gf-12L-35M-i2048", "gf-34L-30M-i5000"]) def test_pass_invalid_model_name(self, invalid_model_names): with pytest.raises(ValueError): diff --git a/helical/models/geneformer/model.py b/helical/models/geneformer/model.py index bdf2d109..d088b05a 100644 --- a/helical/models/geneformer/model.py +++ b/helical/models/geneformer/model.py @@ -153,7 +153,11 @@ def 
process_data(self, raise ValueError(message) adata = map_gene_symbols_to_ensembl_ids(adata, gene_names) - + if adata.var["ensembl_id"].isnull().all(): + message = "All gene symbols could not be mapped to Ensembl IDs. Please check the input data." + LOGGER.info(message) + raise ValueError(message) + tokenized_cells, cell_metadata = self.tk.tokenize_anndata(adata) # tokenized_cells, cell_metadata = self.tk.tokenize_anndata(adata) From 7b02e00596a63d0d8fec686fe212f1d05e0744a2 Mon Sep 17 00:00:00 2001 From: Benoit Putzeys Date: Thu, 23 Jan 2025 21:47:31 +0100 Subject: [PATCH 4/9] Add tests for UCE gene embeddings retrieval --- ci/tests/test_uce/test_gene_embeddings.py | 39 +++++++++++++++++++++++ helical/models/uce/gene_embeddings.py | 5 +++ 2 files changed, 44 insertions(+) create mode 100644 ci/tests/test_uce/test_gene_embeddings.py diff --git a/ci/tests/test_uce/test_gene_embeddings.py b/ci/tests/test_uce/test_gene_embeddings.py new file mode 100644 index 00000000..06839620 --- /dev/null +++ b/ci/tests/test_uce/test_gene_embeddings.py @@ -0,0 +1,39 @@ +from helical.models.uce.gene_embeddings import load_gene_embeddings_adata +from anndata import AnnData +import pandas as pd +import numpy as np +from pathlib import Path +import pytest +from pathlib import Path +CACHE_DIR_HELICAL = Path(Path.home(), '.cache', 'helical', 'models') + +class TestUCEGeneEmbeddings: + + adata = AnnData( + X=np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + obs=pd.DataFrame({"species": ["human", "mouse", "rat"]}), + var=pd.DataFrame({"gene": ["gene1", "gene2", "gene3"]}) + ) + species = ["human"] + embedding_model = "ESM2" + embeddings_path = Path(CACHE_DIR_HELICAL, 'uce', "protein_embeddings") + + def test_load_gene_embeddings_adata_filtering_all_genes(self): + with pytest.raises(ValueError): + load_gene_embeddings_adata(self.adata, self.species, self.embedding_model, self.embeddings_path) + + def test_load_gene_embeddings_adata_filtering_no_genes(self): + self.adata.var_names = ['hoxa6', 
'cav2', 'txk'] + anndata, mapping_dict = load_gene_embeddings_adata(self.adata, self.species, self.embedding_model, self.embeddings_path) + assert (anndata.var_names == ['hoxa6', 'cav2', 'txk']).all() + assert (anndata.obs == self.adata.obs).all().all() + assert (anndata.X == self.adata.X).all() + assert len(mapping_dict['human']) == 19790 + + def test_load_gene_embeddings_adata_filtering_some_genes(self): + self.adata.var_names = ['hoxa6', 'cav2', '1'] + anndata, mapping_dict = load_gene_embeddings_adata(self.adata, self.species, self.embedding_model, self.embeddings_path) + assert (anndata.var_names == ['hoxa6', 'cav2']).all() + assert (anndata.obs == self.adata.obs).all().all() + assert (anndata.X == [[1, 2], [4, 5], [7, 8]]).all() + assert len(mapping_dict['human']) == 19790 \ No newline at end of file diff --git a/helical/models/uce/gene_embeddings.py b/helical/models/uce/gene_embeddings.py index 55f5eee9..1fb607df 100644 --- a/helical/models/uce/gene_embeddings.py +++ b/helical/models/uce/gene_embeddings.py @@ -79,6 +79,11 @@ def load_gene_embeddings_adata(adata: AnnData, species: list, embedding_model: s filtered = adata.var_names.shape[0] - filtered_adata.var_names.shape[0] LOGGER.info(f'Filtered out {filtered} genes to a total of {filtered_adata.var_names.shape[0]} genes with embeddings.') + if filtered_adata.var_names.shape[0] == 0: + message = "No matching genes found between input data and UCE gene embedding vocabulary. Please check the gene names in .var of the anndata input object." 
+ LOGGER.error(message) + raise ValueError(message) + # Load gene symbols for desired species for later use with indexes species_to_all_gene_symbols = { species: [ From 4718f8a91f6657691ccbb5f85962445557343cd7 Mon Sep 17 00:00:00 2001 From: Matthew Wood <62712722+mattwoodx@users.noreply.github.com> Date: Mon, 27 Jan 2025 09:09:03 +0100 Subject: [PATCH 5/9] Small fixes (#178) * Allow warnings to be shown by the logger * Remove excessive installs at the beginning of Helix-mRNA notebook * Update model categories in docs --- examples/notebooks/Helix-mRNA.ipynb | 871 ++++++++++++++-------------- helical/__init__.py | 8 +- mkdocs.yml | 29 +- 3 files changed, 446 insertions(+), 462 deletions(-) diff --git a/examples/notebooks/Helix-mRNA.ipynb b/examples/notebooks/Helix-mRNA.ipynb index d34513b8..3433a3db 100644 --- a/examples/notebooks/Helix-mRNA.ipynb +++ b/examples/notebooks/Helix-mRNA.ipynb @@ -1,443 +1,432 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Use Helix-mRNA" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Access the Helical GitHub [here](https://github.com/helicalAI)!" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**In this notebook we will dive into using our latest mRNA Bio Foundation Model, Helix-mRNA.**\n", - "\n", - "**We will get and plot embeddings for our data.**\n", - "\n", - "**We will fine-tune the model both using the Helical package**" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## If running on colab, run the cell below. 
Comment out if running locally" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!pip uninstall torch torchtext torchaudio -y\n", - "!pip install torch==2.3.0 torchtext torchaudio --index-url https://download.pytorch.org/whl/cu121\n", - "!pip uninstall mamba-ssm causal-conv1d -y\n", - "!pip install mamba-ssm==2.2.2 --no-cache-dir" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!pip install helical\n", - "!pip uninstall causal-conv1d -y" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Imports" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:datasets:PyTorch version 2.3.0 available.\n", - "INFO:datasets:Polars version 0.20.31 available.\n", - "INFO:datasets:JAX version 0.4.31 available.\n" - ] - } - ], - "source": [ - "from helical import HelixmRNAConfig, HelixmRNA, HelixmRNAFineTuningModel\n", - "import subprocess\n", - "import torch\n", - "import pandas as pd\n", - "\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Download one of CodonBERT's fine-tuning benchmarks" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "url = \"https://raw.githubusercontent.com/Sanofi-Public/CodonBERT/refs/heads/master/benchmarks/CodonBERT/data/fine-tune/mRFP_Expression.csv\"\n", - "\n", - "output_filename = \"mRFP_Expression.csv\"\n", - "wget_command = [\"wget\", \"-O\", output_filename, url]\n", - "\n", - "try:\n", - " subprocess.run(wget_command, check=True)\n", - " print(f\"File downloaded successfully as {output_filename}\")\n", - "except subprocess.CalledProcessError as e:\n", - " print(f\"Error occurred: {e}\")" - ] - }, - { - "cell_type": 
"markdown", - "metadata": {}, - "source": [ - "### Load the dataset as a pandas dataframe and get the splits\n", - "- For this example we take a subset of the splits, feel free to run it on the entire dataset!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dataset = pd.read_csv(output_filename)\n", - "train_data = dataset[dataset[\"Split\"] == \"train\"][:10]\n", - "eval_data = dataset[dataset[\"Split\"] == \"val\"][:5]\n", - "test_data = dataset[dataset[\"Split\"] == \"test\"][:5]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Define our Helix-mRNA model and desired configs" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:helical.models.helix_mrna.model:Helix-mRNA initialized successfully.\n" - ] - } - ], - "source": [ - "# We set the max length to the maximum length of the sequences in the training data + 10 to include space for special tokens\n", - "helix_mrna_config = HelixmRNAConfig(device=device, batch_size=5, max_length=max(len(s) for s in train_data[\"Sequence\"])+10)\n", - "helix_mrna = HelixmRNA(helix_mrna_config)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Process our training sequences to tokenize them and prepare them for the model" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "processed_train_data = helix_mrna.process_data(train_data[\"Sequence\"].to_list())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Generate embeddings for the train data" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "- We get an embeddings for each letter/token in the sequence, in this case 100 embeddings for each of the 688 tokens and our embedding dimension is 256\n", - "- Because the model has a 
recurrent nature, our final non-special token embedding at the second last position encapsulates everything that came before it" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Getting embeddings: 100%|██████████| 20/20 [00:00<00:00, 71.50it/s]" - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Use Helix-mRNA" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Access the Helical GitHub [here](https://github.com/helicalAI)!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**In this notebook we will dive into using our latest mRNA Bio Foundation Model, Helix-mRNA.**\n", + "\n", + "**We will get and plot embeddings for our data.**\n", + "\n", + "**We will fine-tune the model both using the Helical package**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If running on a CUDA device compatible with mamba-ssm and causal-conv1d install the package below, otherwise remove the [mamba-ssm] optional dependency\n", + "- If running on colab, remove the [mamba-ssm] dependency" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install --upgrade helical[mamba-ssm]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:datasets:PyTorch version 2.3.0 available.\n", + "INFO:datasets:Polars version 0.20.31 available.\n", + "INFO:datasets:JAX version 0.4.31 available.\n" + ] + } + ], + "source": [ + "from helical import HelixmRNAConfig, HelixmRNA, HelixmRNAFineTuningModel\n", + "import subprocess\n", + "import torch\n", + "import pandas as pd\n", + "\n", + "device = \"cuda\" if 
torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Download one of CodonBERT's fine-tuning benchmarks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "url = \"https://raw.githubusercontent.com/Sanofi-Public/CodonBERT/refs/heads/master/benchmarks/CodonBERT/data/fine-tune/mRFP_Expression.csv\"\n", + "\n", + "output_filename = \"mRFP_Expression.csv\"\n", + "wget_command = [\"wget\", \"-O\", output_filename, url]\n", + "\n", + "try:\n", + " subprocess.run(wget_command, check=True)\n", + " print(f\"File downloaded successfully as {output_filename}\")\n", + "except subprocess.CalledProcessError as e:\n", + " print(f\"Error occurred: {e}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load the dataset as a pandas dataframe and get the splits\n", + "- For this example we take a subset of the splits, feel free to run it on the entire dataset!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = pd.read_csv(output_filename)\n", + "train_data = dataset[dataset[\"Split\"] == \"train\"][:10]\n", + "eval_data = dataset[dataset[\"Split\"] == \"val\"][:5]\n", + "test_data = dataset[dataset[\"Split\"] == \"test\"][:5]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define our Helix-mRNA model and desired configs" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:helical.models.helix_mrna.model:Helix-mRNA initialized successfully.\n" + ] + } + ], + "source": [ + "# We set the max length to the maximum length of the sequences in the training data + 10 to include space for special tokens\n", + "helix_mrna_config = HelixmRNAConfig(device=device, batch_size=1, max_length=max(len(s) for s in train_data[\"Sequence\"])+10)\n", + "helix_mrna = HelixmRNA(helix_mrna_config)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Process our training sequences to tokenize them and prepare them for the model" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "processed_train_data = helix_mrna.process_data(train_data[\"Sequence\"].to_list())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Generate embeddings for the train data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- We get an embedding for each letter/token in the sequence, in this case 100 embeddings for each of the 688 tokens and our embedding dimension is 256\n", + "- Because the model has a recurrent nature, our final non-special token embedding at the second last position encapsulates everything that came before it" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + 
"outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Getting embeddings: 100%|██████████| 20/20 [00:00<00:00, 71.50it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(100, 256)\n", + "[[-9.70379915e-04 5.94667019e-03 1.07590854e-02 5.22067677e-03\n", + " -9.54915071e-04 -6.74154516e-03 -2.91207526e-03 -1.49831397e-03\n", + " 1.78750437e-02 5.13957115e-03 8.79890576e-04 1.21943112e-02\n", + " -1.92209042e-03 3.27306171e-03 -2.27077748e-03 4.50014602e-04\n", + " 7.30314665e-03 -8.66744318e-04 -8.81821662e-03 -7.57190645e-01\n", + " 1.89280566e-02 -4.05776373e-04 6.08320069e-03 -1.78794132e-03\n", + " -8.79776548e-04 -8.19147026e-05 9.60175938e-04 -8.30806512e-03\n", + " 5.66601008e-03 -5.93393855e-03 -5.19109843e-03 6.86887605e-03\n", + " -7.94085041e-02 -5.38914884e-03 -1.55241350e-02 -2.42359545e-02\n", + " 2.57678051e-03 -9.53892432e-03 -7.16619950e-04 1.50164040e-02\n", + " -9.01486576e-01 -4.68801707e-03 3.71015654e-03 -1.07593695e-02\n", + " 9.67101427e-04 5.75249782e-03 2.86138593e-03 -6.41007500e-04\n", + " -3.93231586e-03 -5.53809397e-04 1.72096007e-02 -8.10448000e-06\n", + " 1.20042302e-02 -7.83413649e-03 2.40328256e-03 1.44813021e-04\n", + " 6.37711585e-03 -2.75100190e-02 -9.19151399e-03 2.25025918e-02\n", + " -2.71240231e-02 3.18764849e-03 -8.12906027e-03 -9.14498232e-03\n", + " 3.32334498e-03 8.77279043e-03 2.19076360e-03 -4.82588913e-03\n", + " 7.93280269e-05 -4.37264703e-03 -1.03688613e-02 -1.15277218e-02\n", + " -4.73860680e-04 1.04218733e-03 4.24548006e-03 3.50359431e-03\n", + " -1.29856959e-01 1.34938598e-01 -3.02679911e-02 -2.39217840e-02\n", + " -2.36590131e-04 -4.36108152e-04 4.19709226e-03 -8.65293201e-03\n", + " 1.99613022e-03 1.79194030e-03 2.84497248e-04 1.43997103e-03\n", + " -1.53440237e-01 -2.64659454e-03 2.34050487e-04 -2.68558436e-03\n", + " -2.99103018e-02 -4.54764161e-03 7.06061395e-03 -2.53863144e-03\n", + " 8.90936144e-03 -7.19320118e-01 7.57683277e-01 -1.85208302e-02\n", + " 
-1.54839545e-01 -2.54138529e-01 8.30988749e-04 -8.85983836e-03\n", + " -9.01097711e-03 -7.00991787e-03 -1.82072073e-01 -5.14741063e-01\n", + " -2.93075689e-03 2.55425880e-03 1.37590896e-03 5.38261468e-03\n", + " -5.74133135e-02 -4.50886175e-04 1.39132570e-02 -9.26930178e-03\n", + " 3.89014231e-03 3.58247235e-02 1.38020525e-02 4.48753638e-03\n", + " 4.69827838e-03 5.32380529e-02 7.67468300e-04 -2.27806643e-02\n", + " 9.79826669e-04 3.29421629e-04 2.56255385e-03 -3.15385172e-04\n", + " 1.13730943e-02 5.02255885e-03 7.63128162e-04 -4.30183439e-03\n", + " -1.41088907e-02 -7.07946122e-02 2.18413552e-04 -4.30437940e-04\n", + " 5.93306264e-03 3.88289336e-03 -6.69274572e-03 -1.05123809e-02\n", + " 7.17154052e-03 9.30194370e-03 -2.66307388e-02 -2.35042372e-03\n", + " -3.61418119e-03 -1.88636947e-02 4.10996377e-03 1.86230207e-03\n", + " -7.77591905e-03 1.07999649e-02 -2.15348396e-02 -1.56054425e-03\n", + " -4.75367473e-04 -2.42964807e-03 1.37075689e-03 -1.18554395e-03\n", + " 1.96172502e-02 8.72136280e-03 -2.54987436e-03 -1.78763457e-03\n", + " 1.48834437e-01 4.15487972e-04 -8.82838969e-04 -4.85490542e-04\n", + " 9.73013118e-02 1.01735163e-02 9.76046920e-03 7.66289607e-03\n", + " 3.93118672e-02 5.41610224e-03 -7.19898380e-03 -4.61950190e-02\n", + " 6.28079474e-03 2.30385065e-02 -1.32811114e-01 2.61072395e-03\n", + " 2.72905454e-03 -8.26253928e-03 2.76575685e-02 -1.16535993e-02\n", + " 7.09296510e-05 -5.02431765e-03 -2.00841855e-02 -9.82477888e-03\n", + " 1.99634713e-04 -2.33941106e-03 -1.01937279e-02 -6.17030673e-02\n", + " 5.41278534e-03 9.48928879e-04 9.36821289e-03 -7.82263931e-03\n", + " -1.20594129e-02 -6.56401785e-03 -8.18305537e-02 8.73102434e-03\n", + " -2.41522095e-03 6.06243312e-03 -2.66978621e-01 8.72417178e-04\n", + " 8.10213108e-03 -1.89128786e-01 8.86955822e-04 1.45062711e-02\n", + " -4.65695048e-03 -3.56003083e-03 -1.77745167e-02 -3.33940163e-02\n", + " 1.01557758e-04 -8.14760383e-03 8.52145813e-03 -9.25995596e-03\n", + " 4.04966250e-03 3.44780415e-01 
-4.55286279e-02 1.27975168e-02\n", + " 5.34113357e-03 -1.58847857e-03 4.65576863e-03 -3.07517336e-03\n", + " 2.26003435e-02 -3.24756862e-03 7.61093199e-02 -8.04630481e-03\n", + " 1.46187656e-02 -4.15891828e-03 -7.69484183e-03 -6.56060642e-03\n", + " 6.32394385e-03 1.89167308e-03 -1.65223163e-02 -7.15268105e-02\n", + " -5.13067655e-02 8.95772595e-03 3.47553840e-04 3.91185429e-04\n", + " 1.50305599e-01 5.39071718e-03 -3.56106623e-03 -1.07512353e-02\n", + " -1.70928031e-01 2.04306114e-02 2.14800192e-03 -1.82585061e-01\n", + " -4.20546830e-02 1.20053962e-02 -6.52526272e-04 1.29553266e-02\n", + " 2.05104008e-01 -3.85842402e-03 8.15556012e-03 -6.55666053e-01\n", + " -6.22088835e-03 1.99010246e-03 -1.32145118e-02 1.12704304e-03]]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "embeddings = helix_mrna.get_embeddings(processed_train_data)\n", + "embeddings = embeddings[:, -2, :]\n", + "print(embeddings.shape)\n", + "print(embeddings[:1])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Fine-tuning the model on our data\n", + "- This is a regression task and so our output is 1 continuous value" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:helical.models.helix_mrna.model:Helix-mRNA initialized successfully.\n" + ] + } + ], + "source": [ + "helix_mrna_fine_tuning_model = HelixmRNAFineTuningModel(helix_mrna_config=helix_mrna_config, fine_tuning_head=\"regression\", output_size=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Our training data is already processed since the standard Helix-mRNA model and fine-tuning model take the same input!\n", + "- We process our eval and test data" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "processed_eval_data = 
helix_mrna_fine_tuning_model.process_data(eval_data[\"Sequence\"].to_list())\n", + "processed_test_data = helix_mrna_fine_tuning_model.process_data(test_data[\"Sequence\"].to_list())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run fine-tuning on the model for this small sample of data" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:helical.models.helix_mrna.fine_tuning_model:Unfreezing the last 2 layers of the Helix_mRNA model.\n", + "INFO:helical.models.helix_mrna.fine_tuning_model:Starting Fine-Tuning\n", + "Fine-Tuning: epoch 1/5: 100%|██████████| 20/20 [00:00<00:00, 49.31it/s, loss=86.9]\n", + "Fine-Tuning Validation: 100%|██████████| 4/4 [00:00<00:00, 56.86it/s, val_loss=88.8]\n", + "Fine-Tuning: epoch 2/5: 100%|██████████| 20/20 [00:00<00:00, 59.90it/s, loss=80.6]\n", + "Fine-Tuning Validation: 100%|██████████| 4/4 [00:00<00:00, 53.01it/s, val_loss=83.3]\n", + "Fine-Tuning: epoch 3/5: 100%|██████████| 20/20 [00:00<00:00, 51.32it/s, loss=74.9]\n", + "Fine-Tuning Validation: 100%|██████████| 4/4 [00:00<00:00, 167.73it/s, val_loss=76.6]\n", + "Fine-Tuning: epoch 4/5: 100%|██████████| 20/20 [00:00<00:00, 49.26it/s, loss=67.3]\n", + "Fine-Tuning Validation: 100%|██████████| 4/4 [00:00<00:00, 54.98it/s, val_loss=67.4]\n", + "Fine-Tuning: epoch 5/5: 100%|██████████| 20/20 [00:00<00:00, 60.40it/s, loss=59.8]\n", + "Fine-Tuning Validation: 100%|██████████| 4/4 [00:00<00:00, 50.23it/s, val_loss=61.1]\n", + "INFO:helical.models.helix_mrna.fine_tuning_model:Fine-Tuning Complete. 
Epochs: 5\n" + ] + } + ], + "source": [ + "helix_mrna_fine_tuning_model.train(train_dataset=processed_train_data, \n", + " train_labels=train_data[\"Value\"].to_numpy().reshape(-1, 1),\n", + " validation_dataset=processed_eval_data, \n", + " validation_labels= eval_data[\"Value\"].to_numpy().reshape(-1, 1),\n", + " epochs=5,\n", + " loss_function=torch.nn.MSELoss(),\n", + " trainable_layers=2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Get outputs from our model on the test data" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating outputs: 100%|██████████| 2/2 [00:00<00:00, 46.17it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[2.682281 ]\n", + " [2.4182768]\n", + " [2.4362845]\n", + " [2.6120207]\n", + " [2.6543183]\n", + " [2.6988027]\n", + " [2.671821 ]\n", + " [2.144202 ]\n", + " [2.6866376]\n", + " [2.6734226]]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "outputs = helix_mrna_fine_tuning_model.get_outputs(processed_test_data)\n", + "print(outputs)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "helical-package", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(100, 256)\n", - "[[-9.70379915e-04 5.94667019e-03 1.07590854e-02 5.22067677e-03\n", - " -9.54915071e-04 -6.74154516e-03 -2.91207526e-03 -1.49831397e-03\n", - " 1.78750437e-02 5.13957115e-03 8.79890576e-04 1.21943112e-02\n", - " -1.92209042e-03 3.27306171e-03 -2.27077748e-03 4.50014602e-04\n", - " 7.30314665e-03 
-8.66744318e-04 -8.81821662e-03 -7.57190645e-01\n", - " 1.89280566e-02 -4.05776373e-04 6.08320069e-03 -1.78794132e-03\n", - " -8.79776548e-04 -8.19147026e-05 9.60175938e-04 -8.30806512e-03\n", - " 5.66601008e-03 -5.93393855e-03 -5.19109843e-03 6.86887605e-03\n", - " -7.94085041e-02 -5.38914884e-03 -1.55241350e-02 -2.42359545e-02\n", - " 2.57678051e-03 -9.53892432e-03 -7.16619950e-04 1.50164040e-02\n", - " -9.01486576e-01 -4.68801707e-03 3.71015654e-03 -1.07593695e-02\n", - " 9.67101427e-04 5.75249782e-03 2.86138593e-03 -6.41007500e-04\n", - " -3.93231586e-03 -5.53809397e-04 1.72096007e-02 -8.10448000e-06\n", - " 1.20042302e-02 -7.83413649e-03 2.40328256e-03 1.44813021e-04\n", - " 6.37711585e-03 -2.75100190e-02 -9.19151399e-03 2.25025918e-02\n", - " -2.71240231e-02 3.18764849e-03 -8.12906027e-03 -9.14498232e-03\n", - " 3.32334498e-03 8.77279043e-03 2.19076360e-03 -4.82588913e-03\n", - " 7.93280269e-05 -4.37264703e-03 -1.03688613e-02 -1.15277218e-02\n", - " -4.73860680e-04 1.04218733e-03 4.24548006e-03 3.50359431e-03\n", - " -1.29856959e-01 1.34938598e-01 -3.02679911e-02 -2.39217840e-02\n", - " -2.36590131e-04 -4.36108152e-04 4.19709226e-03 -8.65293201e-03\n", - " 1.99613022e-03 1.79194030e-03 2.84497248e-04 1.43997103e-03\n", - " -1.53440237e-01 -2.64659454e-03 2.34050487e-04 -2.68558436e-03\n", - " -2.99103018e-02 -4.54764161e-03 7.06061395e-03 -2.53863144e-03\n", - " 8.90936144e-03 -7.19320118e-01 7.57683277e-01 -1.85208302e-02\n", - " -1.54839545e-01 -2.54138529e-01 8.30988749e-04 -8.85983836e-03\n", - " -9.01097711e-03 -7.00991787e-03 -1.82072073e-01 -5.14741063e-01\n", - " -2.93075689e-03 2.55425880e-03 1.37590896e-03 5.38261468e-03\n", - " -5.74133135e-02 -4.50886175e-04 1.39132570e-02 -9.26930178e-03\n", - " 3.89014231e-03 3.58247235e-02 1.38020525e-02 4.48753638e-03\n", - " 4.69827838e-03 5.32380529e-02 7.67468300e-04 -2.27806643e-02\n", - " 9.79826669e-04 3.29421629e-04 2.56255385e-03 -3.15385172e-04\n", - " 1.13730943e-02 5.02255885e-03 7.63128162e-04 
-4.30183439e-03\n", - " -1.41088907e-02 -7.07946122e-02 2.18413552e-04 -4.30437940e-04\n", - " 5.93306264e-03 3.88289336e-03 -6.69274572e-03 -1.05123809e-02\n", - " 7.17154052e-03 9.30194370e-03 -2.66307388e-02 -2.35042372e-03\n", - " -3.61418119e-03 -1.88636947e-02 4.10996377e-03 1.86230207e-03\n", - " -7.77591905e-03 1.07999649e-02 -2.15348396e-02 -1.56054425e-03\n", - " -4.75367473e-04 -2.42964807e-03 1.37075689e-03 -1.18554395e-03\n", - " 1.96172502e-02 8.72136280e-03 -2.54987436e-03 -1.78763457e-03\n", - " 1.48834437e-01 4.15487972e-04 -8.82838969e-04 -4.85490542e-04\n", - " 9.73013118e-02 1.01735163e-02 9.76046920e-03 7.66289607e-03\n", - " 3.93118672e-02 5.41610224e-03 -7.19898380e-03 -4.61950190e-02\n", - " 6.28079474e-03 2.30385065e-02 -1.32811114e-01 2.61072395e-03\n", - " 2.72905454e-03 -8.26253928e-03 2.76575685e-02 -1.16535993e-02\n", - " 7.09296510e-05 -5.02431765e-03 -2.00841855e-02 -9.82477888e-03\n", - " 1.99634713e-04 -2.33941106e-03 -1.01937279e-02 -6.17030673e-02\n", - " 5.41278534e-03 9.48928879e-04 9.36821289e-03 -7.82263931e-03\n", - " -1.20594129e-02 -6.56401785e-03 -8.18305537e-02 8.73102434e-03\n", - " -2.41522095e-03 6.06243312e-03 -2.66978621e-01 8.72417178e-04\n", - " 8.10213108e-03 -1.89128786e-01 8.86955822e-04 1.45062711e-02\n", - " -4.65695048e-03 -3.56003083e-03 -1.77745167e-02 -3.33940163e-02\n", - " 1.01557758e-04 -8.14760383e-03 8.52145813e-03 -9.25995596e-03\n", - " 4.04966250e-03 3.44780415e-01 -4.55286279e-02 1.27975168e-02\n", - " 5.34113357e-03 -1.58847857e-03 4.65576863e-03 -3.07517336e-03\n", - " 2.26003435e-02 -3.24756862e-03 7.61093199e-02 -8.04630481e-03\n", - " 1.46187656e-02 -4.15891828e-03 -7.69484183e-03 -6.56060642e-03\n", - " 6.32394385e-03 1.89167308e-03 -1.65223163e-02 -7.15268105e-02\n", - " -5.13067655e-02 8.95772595e-03 3.47553840e-04 3.91185429e-04\n", - " 1.50305599e-01 5.39071718e-03 -3.56106623e-03 -1.07512353e-02\n", - " -1.70928031e-01 2.04306114e-02 2.14800192e-03 -1.82585061e-01\n", - " 
-4.20546830e-02 1.20053962e-02 -6.52526272e-04 1.29553266e-02\n", - " 2.05104008e-01 -3.85842402e-03 8.15556012e-03 -6.55666053e-01\n", - " -6.22088835e-03 1.99010246e-03 -1.32145118e-02 1.12704304e-03]]\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], - "source": [ - "embeddings = helix_mrna.get_embeddings(processed_train_data)\n", - "embeddings = embeddings[:, -2, :]\n", - "print(embeddings.shape)\n", - "print(embeddings[:1])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Fine-tuning the model on our data\n", - "- This is a regression task and so our output is 1 continuous value" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:helical.models.helix_mrna.model:Helix-mRNA initialized successfully.\n" - ] - } - ], - "source": [ - "helix_mrna_fine_tuning_model = HelixmRNAFineTuningModel(helix_mrna_config=helix_mrna_config, fine_tuning_head=\"regression\", output_size=1)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Our training data is already processed since the standard Helix-mRNA model and fine-tuning model take the same input!\n", - "- We process our eval and test data" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "processed_eval_data = helix_mrna_fine_tuning_model.process_data(eval_data[\"Sequence\"].to_list())\n", - "processed_test_data = helix_mrna_fine_tuning_model.process_data(test_data[\"Sequence\"].to_list())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Run fine-tuning on the model for this small sample of data" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:helical.models.helix_mrna.fine_tuning_model:Unfreezing the 
last 2 layers of the Helix_mRNA model.\n", - "INFO:helical.models.helix_mrna.fine_tuning_model:Starting Fine-Tuning\n", - "Fine-Tuning: epoch 1/5: 100%|██████████| 20/20 [00:00<00:00, 49.31it/s, loss=86.9]\n", - "Fine-Tuning Validation: 100%|██████████| 4/4 [00:00<00:00, 56.86it/s, val_loss=88.8]\n", - "Fine-Tuning: epoch 2/5: 100%|██████████| 20/20 [00:00<00:00, 59.90it/s, loss=80.6]\n", - "Fine-Tuning Validation: 100%|██████████| 4/4 [00:00<00:00, 53.01it/s, val_loss=83.3]\n", - "Fine-Tuning: epoch 3/5: 100%|██████████| 20/20 [00:00<00:00, 51.32it/s, loss=74.9]\n", - "Fine-Tuning Validation: 100%|██████████| 4/4 [00:00<00:00, 167.73it/s, val_loss=76.6]\n", - "Fine-Tuning: epoch 4/5: 100%|██████████| 20/20 [00:00<00:00, 49.26it/s, loss=67.3]\n", - "Fine-Tuning Validation: 100%|██████████| 4/4 [00:00<00:00, 54.98it/s, val_loss=67.4]\n", - "Fine-Tuning: epoch 5/5: 100%|██████████| 20/20 [00:00<00:00, 60.40it/s, loss=59.8]\n", - "Fine-Tuning Validation: 100%|██████████| 4/4 [00:00<00:00, 50.23it/s, val_loss=61.1]\n", - "INFO:helical.models.helix_mrna.fine_tuning_model:Fine-Tuning Complete. 
Epochs: 5\n" - ] - } - ], - "source": [ - "helix_mrna_fine_tuning_model.train(train_dataset=processed_train_data, \n", - " train_labels=train_data[\"Value\"].to_numpy().reshape(-1, 1),\n", - " validation_dataset=processed_eval_data, \n", - " validation_labels= eval_data[\"Value\"].to_numpy().reshape(-1, 1),\n", - " epochs=5,\n", - " loss_function=torch.nn.MSELoss(),\n", - " trainable_layers=2)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Get outputs from our model on the test data" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Generating outputs: 100%|██████████| 2/2 [00:00<00:00, 46.17it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[2.682281 ]\n", - " [2.4182768]\n", - " [2.4362845]\n", - " [2.6120207]\n", - " [2.6543183]\n", - " [2.6988027]\n", - " [2.671821 ]\n", - " [2.144202 ]\n", - " [2.6866376]\n", - " [2.6734226]]\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], - "source": [ - "outputs = helix_mrna_fine_tuning_model.get_outputs(processed_test_data)\n", - "print(outputs)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "helical-package", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.8" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} + "nbformat": 4, + "nbformat_minor": 2 + } + \ No newline at end of file diff --git a/helical/__init__.py b/helical/__init__.py index e5fa83ab..65619bb3 100644 --- a/helical/__init__.py +++ b/helical/__init__.py @@ -3,10 +3,6 @@ logging.captureWarnings(True) -class InfoAndErrorFilter(logging.Filter): - def filter(self, record): - return 
record.levelno in (logging.INFO, logging.ERROR) - for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) @@ -16,8 +12,6 @@ def filter(self, record): handler = logging.StreamHandler() handler.setLevel(logging.INFO) -handler.addFilter(InfoAndErrorFilter()) - formatter = logging.Formatter('%(levelname)s:%(name)s:%(message)s') handler.setFormatter(formatter) @@ -41,4 +35,4 @@ def filter(self, record): from .models.caduceus import Caduceus, CaduceusConfig, CaduceusFineTuningModel except: LOGGER = logging.getLogger(__name__) - LOGGER.info("Caduceus not available: If you want to use this model, ensure you have a CUDA GPU and have installed the optional helical[mamba-ssm] dependencies.") \ No newline at end of file + LOGGER.info("Caduceus not available: If you want to use this model, ensure you have a CUDA GPU and have installed the optional helical[mamba-ssm] dependencies.") diff --git a/mkdocs.yml b/mkdocs.yml index 3e228d6f..cac2519b 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -4,17 +4,7 @@ repo_url: https://github.com/helicalAI/helical repo_name: helicalAI/helical copyright: Helical Team
Copyright © 2024 nav: - - RNA Models: - - Helix-mRNA: - - Model Card: ./model_cards/helix_mrna.md - - Config: ./configs/helix_mrna_config.md - - Model: ./models/helix_mrna.md - - Fine-Tuning Model: ./fine_tuning_models/helix_mrna_fine_tune.md - - Mamba2-mRNA: - - Model Card: ./model_cards/mamba2_mrna.md - - Config: ./configs/mamba2_mrna_config.md - - Model: ./models/mamba2_mrna.md - - Fine-Tuning Model: ./fine_tuning_models/mamba2_mrna_fine_tune.md + - Single-Cell Models: - Geneformer: - Model Card: ./model_cards/geneformer.md - Config: ./configs/geneformer_config.md @@ -30,7 +20,18 @@ nav: - Config: ./configs/uce_config.md - Model: ./models/uce.md - Fine-Tuning Model: ./fine_tuning_models/uce_fine_tune.md - - DNA Models: + - RNA Sequence Models: + - Helix-mRNA: + - Model Card: ./model_cards/helix_mrna.md + - Config: ./configs/helix_mrna_config.md + - Model: ./models/helix_mrna.md + - Fine-Tuning Model: ./fine_tuning_models/helix_mrna_fine_tune.md + - Mamba2-mRNA: + - Model Card: ./model_cards/mamba2_mrna.md + - Config: ./configs/mamba2_mrna_config.md + - Model: ./models/mamba2_mrna.md + - Fine-Tuning Model: ./fine_tuning_models/mamba2_mrna_fine_tune.md + - DNA Sequence Models: - HyenaDNA: - Model Card: ./model_cards/hyenadna.md - Config: ./configs/hyenadna_config.md @@ -41,8 +42,8 @@ nav: - Config: ./configs/caduceus_config.md - Model: ./models/caduceus.md - Fine-Tuning Model: ./fine_tuning_models/caduceus_fine_tune.md - - Helical Base Models: ./models/base_models.md - - Fine-Tuning Heads: ./models/fine_tuning_heads.md + - Helical Base Models: ./models/base_models.md + - Fine-Tuning Heads: ./models/fine_tuning_heads.md - Example Notebooks: - Quick-Start-Tutorial: ./notebooks/Quick-Start-Tutorial.ipynb - Helix-mRNA: ./notebooks/Helix-mRNA.ipynb From e900c4b8b21a119dd8916d4ab04d88f187fed483 Mon Sep 17 00:00:00 2001 From: Giovanni Ortolani Date: Thu, 30 Jan 2025 08:55:19 +0000 Subject: [PATCH 6/9] Add hash for Geneformer v2 with 33 layers. 
(#179) Co-authored-by: giogix2 --- helical/utils/downloader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/helical/utils/downloader.py b/helical/utils/downloader.py index 35c115b2..d4a143c7 100644 --- a/helical/utils/downloader.py +++ b/helical/utils/downloader.py @@ -17,6 +17,7 @@ HASH_DICT = { 'uce/4layer_model.torch': '16430370e0d672c8db6e275440e7974d2fd0a21f29aa9299e141085f82a5a886', + 'uce/33l_8ep_1024t_1280.torch': 'aa6457a0eb2e91d8382d96fb455456e40a9423a00509ea296079a75b1a9390c0', 'uce/all_tokens.torch': 'e3e3ad03a9f8fdca8babec5b0c72f7f4043a4bec2e3eb009b8fe1b28d984c93a', 'uce/species_chrom.csv': '7f5d32e6adcc3786c613043a4de8e2a47187935cfb9a1d3fcf7373eb50caebf7', 'uce/species_offsets.pkl': 'abda5b2bc4018187e408623b292686a061912f449daceb4c9c9603caf0d62538', From 93e6596468491dcb9925014df04e4236026d6b61 Mon Sep 17 00:00:00 2001 From: Benoit Putzeys Date: Tue, 4 Feb 2025 11:22:43 +0100 Subject: [PATCH 7/9] Bump causal-conv1d --- README.md | 3 ++- pyproject.toml | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 458cac33..b0c9d32a 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,8 @@ or in case you're installing from the Helical repo cloned locally: pip install .[mamba-ssm] ``` -Note: make sure your machine has GPU(s) and Cuda installed. Currently this is a requirement for the packages mamba-ssm and causal-conv1d. +Note: make sure your machine has GPU(s) and Cuda installed. Currently this is a requirement for the packages mamba-ssm and causal-conv1d. +Also make sure `torch` is already installed. 
### Singularity (Optional) If you desire to run your code in a singularity file, you can use the [singularity.def](./singularity.def) file and build an apptainer with it: diff --git a/pyproject.toml b/pyproject.toml index 91630042..88d9cab9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ dependencies = [ [project.optional-dependencies] mamba-ssm = [ 'mamba-ssm==2.2.4', - 'causal-conv1d==1.4.0', + 'causal-conv1d==1.5.0.post8', ] [tool.hatch.metadata] From 5bcd4f16e132f4f9d7bd7be12fbeffd7bfc59b74 Mon Sep 17 00:00:00 2001 From: Benoit Putzeys Date: Tue, 4 Feb 2025 11:23:41 +0100 Subject: [PATCH 8/9] Update version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 88d9cab9..9ce6aca1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "helical" -version = "0.0.1a20" +version = "0.0.1a21" authors = [ { name="Helical Team", email="support@helical-ai.com" }, ] From 5fd39eb2faeaaf29af105ae1cd8f8d40c46e96d5 Mon Sep 17 00:00:00 2001 From: Benoit Putzeys Date: Tue, 4 Feb 2025 11:38:54 +0100 Subject: [PATCH 9/9] Clarify installation steps for mamba-ssm --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index b0c9d32a..3dd75048 100644 --- a/README.md +++ b/README.md @@ -72,8 +72,9 @@ or in case you're installing from the Helical repo cloned locally: pip install .[mamba-ssm] ``` -Note: make sure your machine has GPU(s) and Cuda installed. Currently this is a requirement for the packages mamba-ssm and causal-conv1d. -Also make sure `torch` is already installed. +Note: +- Make sure your machine has GPU(s) and Cuda installed. Currently this is a requirement for the packages mamba-ssm and causal-conv1d. +- The package `causal_conv1d` requires `torch` to be installed already. First installing `helical` separately (without `[mamba-ssm]`) will install `torch` for you. 
A second installation (with `[mamba-ssm]`), installs the packages correctly. ### Singularity (Optional) If you desire to run your code in a singularity file, you can use the [singularity.def](./singularity.def) file and build an apptainer with it: