diff --git a/README.md b/README.md index 3dd75048..b28cfd37 100644 --- a/README.md +++ b/README.md @@ -120,6 +120,7 @@ Within the `examples/notebooks` folder, open the notebook of your choice. We rec |[Cell-Type-Annotation.ipynb](./examples/notebooks/Cell-Type-Annotation.ipynb)|An example how to do probing with scGPT by training a neural network to predict cell type annotations.|[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/helicalAI/helical/blob/main/examples/notebooks/Cell-Type-Annotation.ipynb) | |[Cell-Type-Classification-Fine-Tuning.ipynb](./examples/notebooks/Cell-Type-Classification-Fine-Tuning.ipynb)|An example how to fine-tune different models on classification tasks.|[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/helicalAI/helical/blob/main/examples/notebooks/Cell-Type-Classification-Fine-Tuning.ipynb) | |[HyenaDNA-Fine-Tuning.ipynb](./examples/notebooks/HyenaDNA-Fine-Tuning.ipynb)|An example of how to fine-tune the HyenaDNA model on downstream benchmarks.|[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/helicalAI/helical/blob/main/examples/notebooks/HyenaDNA-Fine-Tuning.ipynb) | +|[Cell-Gene-Cls-embedding-generation.ipynb](./examples/notebooks/Cell-Gene-Cls-embedding-generation.ipynb)|A notebook explaining the different embedding modes of single cell RNA models.|[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/helicalAI/helical/blob/main/examples/notebooks/Cell-Gene-Cls-embedding-generation.ipynb) | | Coming Soon | New models such as SCimilarity, scVI; benchmarking scripts; new use cases; others | ## Stuck somewhere ? Other ideas ? 
@@ -179,4 +180,3 @@ Please use this BibTeX to cite this repository in your publications: url = {https://doi.org/10.5281/zenodo.13135902} } ``` - diff --git a/docs/index.md b/docs/index.md index a8032986..5b3ac1d1 100644 --- a/docs/index.md +++ b/docs/index.md @@ -98,6 +98,7 @@ Within the `example/notebooks` folder, open the notebook of your choice. We reco |[Cell-Type-Annotation.ipynb](./notebooks/Cell-Type-Annotation.ipynb)|An example how to do probing with scGPT by training a neural network to predict cell type annotations.|[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/helicalAI/helical/blob/main/examples/notebooks/Cell-Type-Annotation.ipynb) | |[Cell-Type-Classification-Fine-Tuning.ipynb](./notebooks/Cell-Type-Classification-Fine-Tuning.ipynb)|An example how to fine-tune different models on classification tasks.|[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/helicalAI/helical/blob/main/examples/notebooks/Cell-Type-Classification-Fine-Tuning.ipynb) | |[HyenaDNA-Fine-Tuning.ipynb](./notebooks/HyenaDNA-Fine-Tuning.ipynb)|An example of how to fine-tune the HyenaDNA model on downstream benchmarks.|[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/helicalAI/helical/blob/main/examples/notebooks/HyenaDNA-Fine-Tuning.ipynb) | +|[Cell-Gene-Cls-embedding-generation.ipynb](./notebooks/Cell-Gene-Cls-embedding-generation.ipynb)|A notebook explaining the different embedding modes of single cell RNA models.|[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/helicalAI/helical/blob/main/examples/notebooks/Cell-Gene-Cls-embedding-generation.ipynb) | | Coming Soon | New models such as SCimilarity, scVI; benchmarking scripts; new use cases; others | ## Stuck somewhere ? Other ideas ? 
diff --git a/docs/model_cards/geneformer.md b/docs/model_cards/geneformer.md index 4dae0cd8..a213e88d 100644 --- a/docs/model_cards/geneformer.md +++ b/docs/model_cards/geneformer.md @@ -192,16 +192,19 @@ import anndata as ad # Example configuration model_config = GeneformerConfig(model_name="gf-12L-95M-i4096", batch_size=10) -geneformer = Geneformer(model_config=model_config) +geneformer_v2 = Geneformer(model_config) # Example usage for base pretrained model -ann_data = ad.read_h5ad("general_dataset.h5ad") +ann_data = ad.read_h5ad("anndata_file.h5ad") dataset = geneformer_v2.process_data(ann_data) embeddings = geneformer_v2.get_embeddings(dataset) print("Base model embeddings shape:", embeddings.shape) # Example usage for cancer-tuned model -cancer_ann_data = ad.read_h5ad("cancer_dataset.h5ad") +model_config_cancer = GeneformerConfig(model_name="gf-12L-95M-i4096-CLcancer", batch_size=10) +geneformer_v2_cancer = Geneformer(model_config_cancer) + +cancer_ann_data = ad.read_h5ad("anndata_file.h5ad") cancer_dataset = geneformer_v2_cancer.process_data(cancer_ann_data) cancer_embeddings = geneformer_v2_cancer.get_embeddings(cancer_dataset) print("Cancer-tuned model embeddings shape:", cancer_embeddings.shape) @@ -211,43 +214,47 @@ print("Cancer-tuned model embeddings shape:", cancer_embeddings.shape) ```python from helical import GeneformerConfig, GeneformerFineTuningModel +import anndata as ad -# Prepare the data -ann_data = ad.read_h5ad("dataset.h5ad") - -# Get the desired label class -cell_types = list(ann_data.obs.cell_type) +# Load the data +ann_data = ad.read_h5ad("yolksac_human.h5ad") -# Create a dictionary mapping the classes to unique integers for training +# Get the column for fine-tuning +cell_types = list(ann_data.obs["cell_types"]) label_set = set(cell_types) -class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))])) -for i in range(len(cell_types)): - cell_types[i] = class_id_dict[cell_types[i]] +# Create a 
GeneformerConfig object +geneformer_config = GeneformerConfig(model_name="gf-12L-95M-i4096", batch_size=10) + +# Create a GeneformerFineTuningModel object +geneformer_fine_tune = GeneformerFineTuningModel(geneformer_config=geneformer_config, fine_tuning_head="classification", output_size=len(label_set)) -# Add this column to the Dataset +# Process the data +dataset = geneformer_fine_tune.process_data(ann_data[:10]) + +# Add column to the dataset dataset = dataset.add_column('cell_types', cell_types) -# Create the fine-tuning model -model_config = GeneformerConfig(model_name="gf-12L-95M-i4096", batch_size=10) -geneformer_fine_tune = GeneformerFineTuningModel( - geneformer_config=model_config, - fine_tuning_head="classification", - label="cell_types", - output_size=len(label_set) -) +# Create a dictionary to map cell types to ids +class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))])) + +def classes_to_ids(example): + example["cell_types"] = class_id_dict[example["cell_types"]] + return example -# Process the data for training -dataset = geneformer_fine_tune.process_data(ann_data) +# Convert cell types to ids +dataset = dataset.map(classes_to_ids, num_proc=1) -# Fine-tune -geneformer_fine_tune.train(train_dataset=dataset) +# Fine-tune the model +geneformer_fine_tune.train(train_dataset=dataset, label="cell_types") -# Get outputs of the fine-tuned model +# Get logits from the fine-tuned model outputs = geneformer_fine_tune.get_outputs(dataset) +print(outputs[:10]) -# Get the embeddings of the fine-tuned model +# Get embeddings from the fine-tuned model embeddings = geneformer_fine_tune.get_embeddings(dataset) +print(embeddings[:10]) ``` ## Contact diff --git a/docs/model_cards/helix_mrna.md b/docs/model_cards/helix_mrna.md index 58d97b88..95b4c3ec 100644 --- a/docs/model_cards/helix_mrna.md +++ b/docs/model_cards/helix_mrna.md @@ -101,7 +101,7 @@ import torch device = "cuda" if torch.cuda.is_available() else "cpu" -helix_mrna_config = 
HelimRNAConfig(batch_size=5, max_length=100, device=device) +helix_mrna_config = HelixmRNAConfig(batch_size=5, max_length=100, device=device) helix_mrna = HelixmRNA(configurer=helix_mrna_config) rna_sequences = ["EACUEGGG", "EACUEGGG", "EACUEGGG", "EACUEGGG", "EACUEGGG"] diff --git a/examples/fine_tune_models/fine_tune_geneformer.py b/examples/fine_tune_models/fine_tune_geneformer.py index 0715ef9a..064d2bd8 100644 --- a/examples/fine_tune_models/fine_tune_geneformer.py +++ b/examples/fine_tune_models/fine_tune_geneformer.py @@ -32,7 +32,7 @@ def classes_to_ids(example): dataset = dataset.map(classes_to_ids, num_proc=1) - geneformer_fine_tune.train(train_dataset=dataset) + geneformer_fine_tune.train(train_dataset=dataset, label="cell_types") outputs = geneformer_fine_tune.get_outputs(dataset) print(outputs) diff --git a/examples/notebooks/Cell-Gene-Cls-embedding-generation.ipynb b/examples/notebooks/Cell-Gene-Cls-embedding-generation.ipynb new file mode 100644 index 00000000..84db27f7 --- /dev/null +++ b/examples/notebooks/Cell-Gene-Cls-embedding-generation.ipynb @@ -0,0 +1,713 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Predictions using different embedding modes\n", + "\n", + "In this Notebook, we want to show the different embedding modes that are available for the different single cell RNA models, available in the package." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:py.warnings:/home/benoit/miniconda3/envs/helical-package/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "\n", + "INFO:datasets:PyTorch version 2.5.1 available.\n", + "INFO:datasets:Polars version 0.20.31 available.\n", + "WARNING:helical.models.scgpt.model_dir.multiomic_model:flash_attn is not installed.\n" + ] + } + ], + "source": [ + "from helical import scGPT, scGPTConfig\n", + "import torch\n", + "import anndata\n", + "from pathlib import Path\n", + "from helical.utils.downloader import Downloader\n", + "import os\n", + "from helical.constants.paths import CACHE_DIR_HELICAL" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We show the working principle using the scGPT model. Get the data if you don't have it already:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:helical.models.scgpt.model:Model finished initializing.\n", + "INFO:helical.models.scgpt.model:'scGPT' model is in 'eval' mode, on device 'cpu' with embedding mode 'cls'.\n" + ] + } + ], + "source": [ + "scgpt = scGPT()\n", + "path = Path.joinpath(CACHE_DIR_HELICAL, \"17_04_24_YolkSacRaw_F158_WE_annots.h5ad\")\n", + "if not os.path.exists(path):\n", + " downloader = Downloader()\n", + " downloader.download_via_name(\"17_04_24_YolkSacRaw_F158_WE_annots.h5ad\")\n", + " \n", + "data = anndata.read_h5ad(path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To explain the working principle of the different embedding modes, it is easier to simulate returned embeddings from the model.\n", + "We can do this in the following cell:\n", + "- we define a torch tensor, simulating the embeddings\n", + "- overwrite the `scgpt.model._encode` function to return those embeddings \n", + "- skip the `scgpt._normalize_embeddings` function by returning the input without modifying it \n" + ] + }, + { + "cell_type": "code", + 
"execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# Mock the method directly on the instance\n", + "mocked_embeddings = torch.tensor([\n", + " [[1.0, 1.0, 1.0, 1.0, 1.0], \n", + " [5.0, 5.0, 5.0, 5.0, 5.0], \n", + " [1.0, 2.0, 3.0, 2.0, 1.0], \n", + " [6.0, 6.0, 6.0, 6.0, 6.0]],\n", + " ])\n", + "scgpt.model._encode = lambda *args, **kwargs: mocked_embeddings\n", + "scgpt._normalize_embeddings = lambda x: x" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With this, we can run scGPT in the 3 different modes: `gene`, `cell` and `cls`.\n", + "\n", + "- The `gene` mode returns embeddings for every gene.\n", + "- The `cell` mode returns the average of the gene embeddings.\n", + "- The `cls` mode returns the `cls` specific row, returned by the model. It can be thought of as a summary of the observation.\n", + "\n", + "We run scGPT on a single observation / cell to explain the process." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:helical.models.scgpt.model:Processing data for scGPT.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:helical.models.scgpt.model:Filtering out 10801 genes to a total of 26517 genes with an ID in the scGPT vocabulary.\n", + "INFO:helical.models.scgpt.model:Successfully processed the data for scGPT.\n", + "INFO:helical.models.scgpt.model:Started getting embeddings:\n", + "Embedding cells: 100%|██████████| 1/1 [00:00<00:00, 8.50it/s]\n", + "INFO:helical.models.scgpt.model:Finished getting embeddings.\n", + "INFO:helical.models.scgpt.model:Started getting embeddings:\n", + "Embedding cells: 100%|██████████| 1/1 [00:00<00:00, 1123.88it/s]\n", + "INFO:helical.models.scgpt.model:Finished getting embeddings.\n", + "INFO:helical.models.scgpt.model:Started getting embeddings:\n", + "Embedding cells: 100%|██████████| 1/1 [00:00<00:00, 1723.92it/s]\n", + 
"INFO:helical.models.scgpt.model:Finished getting embeddings.\n" + ] + } + ], + "source": [ + "dataset = scgpt.process_data(data[0])\n", + "\n", + "scgpt.config[\"emb_mode\"] = \"gene\"\n", + "gene_embeddings = scgpt.get_embeddings(dataset)\n", + "\n", + "scgpt.config[\"emb_mode\"] = \"cell\"\n", + "cell_embeddings = scgpt.get_embeddings(dataset)\n", + "\n", + "scgpt.config[\"emb_mode\"] = \"cls\"\n", + "cls_embeddings = scgpt.get_embeddings(dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The gene embeddings return embeddings for every gene:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "SLC39A14 [5.0, 5.0, 5.0, 5.0, 5.0]\n", + "MPDU1 [1.0, 2.0, 3.0, 2.0, 1.0]\n", + "GPHN [6.0, 6.0, 6.0, 6.0, 6.0]\n", + "dtype: object" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gene_embeddings[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The cell embeddings hold the averages of the gene embeddings:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([4. , 4.3333335, 4.6666665, 4.3333335, 4. ],\n", + " dtype=float32)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cell_embeddings[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The cls embeddings correspond to the first row returned by the model.\n", + "\n", + "This means that scGPT in `cls` mode ignores the remaining 3 rows." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([1., 1., 1., 1., 1.], dtype=float32)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cls_embeddings[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can run this on real data too but the interpreation of this is harder to visualise:\n", + "\n", + "First, we remove our modified scGPT model and instantiate a new one." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:helical.models.scgpt.model:Model finished initializing.\n", + "INFO:helical.models.scgpt.model:'scGPT' model is in 'eval' mode, on device 'cuda' with embedding mode 'cls'.\n" + ] + } + ], + "source": [ + "del scgpt\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "scgpt = scGPT(configurer=scGPTConfig(device=device))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:helical.models.scgpt.model:Started getting embeddings:\n", + "Embedding cells: 100%|██████████| 1/1 [00:00<00:00, 5.18it/s]\n", + "INFO:helical.models.scgpt.model:Finished getting embeddings.\n" + ] + }, + { + "data": { + "text/plain": [ + "SLC39A14 [-0.0011799249, 0.0031951678, -0.0037296554, 0...\n", + "MPDU1 [-0.0013756858, 0.017062135, -0.007849643, 0.0...\n", + "GPHN [0.007110456, 0.025636358, 0.0028697518, 0.005...\n", + "AGFG2 [-0.008909806, 0.01073429, 0.006347002, 0.0071...\n", + "POLR3B [-0.012140153, 0.04901718, 0.02245722, 0.00043...\n", + " ... 
\n", + "TMEM258 [-0.0077039357, 0.017461302, 0.002785733, 0.01...\n", + "BNIP3L [-0.0103421565, 0.035706572, 0.011275602, 0.00...\n", + "KPNB1 [0.0004736521, 0.032073762, 0.0024564175, 0.00...\n", + "ZSWIM5 [-0.012645806, 0.048165236, 0.02488112, -0.006...\n", + "REPIN1 [0.00678998, 0.019529147, -0.0017630243, 0.001...\n", + "Length: 1199, dtype: object" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "scgpt.config[\"emb_mode\"] = \"gene\"\n", + "gene_embeddings = scgpt.get_embeddings(dataset)\n", + "gene_embeddings[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With real data, it is easier to analyse the output sizes:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of genes with embeddings: (1199,)\n", + "Embedding size per gene: (512,)\n" + ] + } + ], + "source": [ + "print(f\"Number of genes with embeddings: {gene_embeddings[0].shape}\")\n", + "print(f\"Embedding size per gene: {gene_embeddings[0][0].shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:helical.models.scgpt.model:Started getting embeddings:\n", + "Embedding cells: 100%|██████████| 1/1 [00:00<00:00, 116.03it/s]\n", + "INFO:helical.models.scgpt.model:Finished getting embeddings.\n" + ] + }, + { + "data": { + "text/plain": [ + "array([-8.15041643e-03, 2.63333023e-02, 7.81406183e-03, 9.81337484e-03,\n", + " 1.40524013e-02, -2.95165903e-03, -1.69922467e-02, -4.81104245e-03,\n", + " -7.60848820e-03, 4.59150635e-02, 5.87423565e-03, -5.45159075e-03,\n", + " 2.13732291e-02, 9.06534586e-03, 1.08101517e-02, -4.91651380e-03,\n", + " -1.33220525e-02, -2.16271300e-02, 1.46969380e-02, -2.12924127e-02,\n", + " -1.29918456e-02, -8.90347362e-03, -3.24050188e-02, 1.69960652e-02,\n", + 
" 6.72291825e-03, 3.31430845e-02, -3.28126512e-02, -2.09503025e-02,\n", + " -2.97727287e-02, 7.01213209e-03, 2.35989746e-02, -2.13149730e-02,\n", + " -2.99774809e-03, -1.64881814e-02, -1.13256490e-02, -3.82535718e-03,\n", + " -1.26695279e-02, 5.00416942e-03, -6.52590021e-02, -4.58240882e-03,\n", + " -2.73203477e-02, -5.88836940e-03, -2.24745702e-02, -5.89525886e-03,\n", + " -3.59893106e-02, -2.78945770e-02, 8.49726028e-04, 1.33508444e-02,\n", + " 1.13855442e-02, 2.42668390e-02, -1.94337759e-02, 1.32291475e-02,\n", + " -4.44257283e-04, -1.83102898e-02, -2.08448507e-02, 1.82826556e-02,\n", + " 1.19242212e-02, -3.09342373e-04, 4.26567607e-02, 1.78650152e-02,\n", + " -1.22072417e-02, 1.22858435e-02, 3.05658821e-02, 9.43962764e-03,\n", + " 4.51750346e-02, -2.96661500e-02, -5.75119108e-02, -5.81376860e-03,\n", + " -1.80639122e-02, -1.69001874e-02, 1.67023279e-02, -2.39931233e-03,\n", + " -1.83809933e-03, 3.20161544e-02, -3.58780213e-02, 2.30181478e-02,\n", + " -1.01301484e-02, -8.31659790e-03, 1.36515126e-02, -1.00036536e-03,\n", + " -7.31237512e-03, -2.01087706e-02, 6.31964812e-03, 5.51546691e-03,\n", + " 2.41016336e-02, 4.74413810e-03, -3.58009264e-02, 1.88941117e-02,\n", + " 3.33672091e-02, 2.22256090e-02, 1.24969874e-02, 3.52528156e-03,\n", + " 1.23718660e-02, 2.18623634e-02, -2.01625726e-03, 1.86573213e-03,\n", + " -3.18950601e-02, -9.83154459e-05, 2.58983634e-02, -1.17942384e-02,\n", + " -9.07449238e-03, 1.70814723e-03, 1.31330080e-02, 4.50423360e-02,\n", + " -5.97934751e-03, 8.15686863e-03, -7.17658550e-03, -3.37880524e-03,\n", + " -1.05222268e-02, 7.52747199e-03, -1.90593023e-02, 2.03801859e-02,\n", + " -2.10038084e-03, -2.45623812e-02, 3.24325897e-02, -4.20615524e-02,\n", + " -3.38341546e-04, 4.45089638e-02, -7.43741263e-03, 1.41986310e-02,\n", + " 2.59784493e-03, 1.08500132e-02, 3.10978964e-02, -7.94472266e-03,\n", + " 2.52118427e-03, -1.47453938e-02, -9.28190630e-03, 7.39593059e-03,\n", + " -1.47715509e-02, -9.86053608e-03, 3.98509763e-02, 3.56436428e-03,\n", 
+ " 8.54810979e-03, 5.99423714e-04, -1.20271575e-02, -8.72934796e-03,\n", + " -1.60508286e-02, 2.57443152e-02, 9.43687651e-03, 1.95260067e-02,\n", + " 2.37606536e-03, -1.94778237e-02, 2.16944255e-02, 2.05736980e-02,\n", + " -6.41872967e-03, -1.56051095e-03, -1.59282275e-02, 2.52879895e-02,\n", + " -1.93675328e-02, -4.43797708e-02, 1.20970234e-02, 4.58547957e-02,\n", + " 7.58533133e-03, 1.02704652e-02, -1.28064835e-02, 2.98559740e-02,\n", + " -2.03474611e-02, -5.76473475e-01, -5.13960654e-03, 2.43202727e-02,\n", + " 4.93220724e-02, 1.12590883e-02, -1.69760119e-02, -2.47351043e-02,\n", + " 2.91592497e-02, 2.48088595e-02, -3.54069695e-02, -3.00732683e-02,\n", + " -1.46070169e-02, 1.72611121e-02, 2.25844923e-02, -2.45836191e-02,\n", + " 1.40006086e-02, -4.39725034e-02, -2.84583420e-02, -2.03797631e-02,\n", + " 1.95964333e-02, 6.25951355e-03, -2.55475808e-02, 5.02686277e-02,\n", + " 2.04409156e-02, -3.24284844e-02, -3.34916124e-03, -3.73260230e-02,\n", + " 1.40493819e-02, -1.19160721e-02, -7.24646961e-03, 4.72128317e-02,\n", + " -6.35270076e-03, 3.85226980e-02, -8.69447645e-03, 1.94680113e-02,\n", + " 1.27684288e-02, 2.93238671e-03, 1.31395962e-02, -1.21938772e-02,\n", + " -8.36459640e-03, -5.38320187e-03, 3.49017006e-04, -2.18278822e-02,\n", + " -9.98797244e-04, 6.98351813e-03, 2.31984966e-02, 1.62911341e-02,\n", + " -1.99025515e-02, 8.43606051e-03, -1.17899561e-02, 2.57441169e-03,\n", + " -4.76897210e-02, 5.17554879e-02, -1.16848052e-02, 4.94676270e-03,\n", + " 1.86417084e-02, -2.96746138e-02, -4.41117473e-02, 2.91055930e-03,\n", + " 4.77770017e-03, -8.67339503e-03, -1.43343639e-02, -6.08797045e-03,\n", + " 3.59102711e-02, 5.66549180e-03, 1.50430361e-02, 4.81424406e-02,\n", + " -1.86657626e-02, 9.20644403e-03, 1.14493174e-02, 1.21861871e-03,\n", + " 1.30169988e-02, -3.25866719e-03, -1.95563585e-02, -1.34446248e-02,\n", + " -5.34193264e-03, 1.15789995e-02, -3.59822251e-03, -6.05917443e-03,\n", + " -5.88430315e-02, 6.70885667e-03, -2.40413994e-02, 1.33898249e-02,\n", + 
" -1.53990593e-02, -1.54478103e-02, 1.97891481e-02, 1.68796023e-03,\n", + " 4.15964276e-02, -7.24982703e-03, 5.10265715e-02, -6.14837324e-03,\n", + " -8.88798106e-03, -2.81753968e-02, -4.49333660e-04, -5.43319527e-03,\n", + " 2.33879238e-02, -5.35442606e-02, -4.07010689e-03, -1.10896630e-02,\n", + " 2.38413922e-02, -2.05681357e-03, -7.62468611e-04, -1.29195498e-02,\n", + " 3.25112402e-01, -2.09854450e-03, -1.52156102e-02, 2.79599242e-03,\n", + " -3.00400592e-02, 2.00533817e-04, 4.48551625e-02, -7.97613803e-03,\n", + " -1.24780852e-02, 2.27570944e-02, 3.98986116e-02, -3.02351173e-02,\n", + " 4.93414933e-03, -7.46077276e-04, -4.29602712e-03, 8.70624091e-04,\n", + " 2.66363956e-02, -8.91784951e-03, -5.09983748e-02, -3.93554047e-02,\n", + " -8.97834636e-03, 5.34497380e-01, 4.15710807e-02, -2.93631069e-02,\n", + " 1.20375818e-02, -1.28079718e-02, -1.42050358e-02, 3.97833362e-02,\n", + " 1.40107870e-02, -2.39920523e-03, -4.70105046e-03, 4.68210801e-02,\n", + " 1.54475784e-02, 5.61998133e-03, 2.43147742e-02, -2.24164352e-02,\n", + " 1.46648940e-02, 8.99495464e-03, -6.53423090e-03, 1.79887563e-02,\n", + " 1.85154285e-02, -3.26716118e-02, -1.27531597e-02, 5.07580070e-03,\n", + " -1.71143860e-02, 1.28918644e-02, -1.35411078e-03, 2.26482712e-02,\n", + " 1.18187880e-02, 1.74715593e-02, 6.35542581e-03, -7.73004442e-03,\n", + " 6.38830243e-03, 2.84507312e-02, 1.76105853e-02, -2.28062086e-02,\n", + " -1.62149465e-03, 6.47854060e-02, 1.48503073e-02, 1.37997037e-02,\n", + " 2.97018103e-02, 8.74831621e-03, -2.74825264e-02, -1.43635236e-02,\n", + " -2.59066490e-03, -8.15996062e-03, -1.75764989e-02, 2.93594282e-02,\n", + " -5.96561050e-03, 5.27633261e-03, -2.72250082e-02, 3.07980534e-02,\n", + " 8.68623145e-03, 1.27530685e-02, 9.36667155e-03, 1.43885938e-02,\n", + " -2.89556384e-02, 2.75708195e-02, -1.02682346e-02, 1.21141179e-02,\n", + " -3.35273594e-02, 2.95504229e-03, 1.72882657e-02, -6.86635673e-02,\n", + " -3.61121632e-02, -2.03203037e-02, -1.13087250e-02, -1.38313007e-02,\n", + 
" -2.70826034e-02, 2.37206947e-02, -2.15230137e-02, 4.15991666e-03,\n", + " 8.86991248e-03, 2.63496563e-02, 2.41908673e-02, -3.67345810e-02,\n", + " 1.83664281e-02, -4.39537130e-02, 5.98312495e-03, -3.49636190e-03,\n", + " 1.00252330e-02, 1.33393332e-02, 7.60935480e-03, 5.20363031e-03,\n", + " -1.65041406e-02, -4.52627009e-03, 3.55193093e-02, 8.90749320e-03,\n", + " 1.27909388e-02, 8.33811518e-03, 1.94471348e-02, 1.68389231e-02,\n", + " -1.93751212e-02, -1.10085038e-02, 1.83981564e-02, -2.40139961e-02,\n", + " -2.94874515e-02, 1.01125650e-02, -2.05271598e-02, -2.76341336e-03,\n", + " -1.32646821e-02, 3.76331359e-02, 8.15028697e-03, -3.10920514e-02,\n", + " -3.44311609e-03, -6.40069786e-03, -5.01741320e-02, -1.87176559e-02,\n", + " 8.99801496e-03, -2.24871212e-03, -3.84792278e-04, -3.24686430e-02,\n", + " 1.87653247e-02, -1.03269462e-02, -2.04273276e-02, 5.72628248e-03,\n", + " 3.20323631e-02, -5.93421690e-04, -3.81626301e-02, 1.03982529e-02,\n", + " 9.31399036e-03, -9.05160414e-05, 3.46498378e-02, -7.75304995e-03,\n", + " -9.04363301e-03, 2.21235268e-02, 6.51048403e-03, 2.03500427e-02,\n", + " 1.03350561e-02, 7.78120058e-03, 1.49649344e-02, 2.23369263e-02,\n", + " 2.59858426e-02, -5.30170510e-03, 2.96027455e-02, -2.79687159e-03,\n", + " -2.20495407e-02, 1.88293215e-02, 7.58425985e-03, -5.03022000e-02,\n", + " -8.56653601e-03, 2.05458421e-03, -2.22983491e-03, -1.31371897e-02,\n", + " 3.72169213e-03, -1.90856550e-02, -2.25026943e-02, 3.84237472e-04,\n", + " -8.41809437e-03, 3.35728265e-02, -2.40656547e-02, 1.65674407e-02,\n", + " 4.37575877e-02, -2.22611297e-02, 1.56481992e-02, -4.64896252e-03,\n", + " 1.13456706e-02, -3.59299663e-03, -2.44456939e-02, -5.62684610e-02,\n", + " -6.08444796e-04, -7.87207112e-03, -1.70890689e-02, -1.05043147e-02,\n", + " 2.19502654e-02, 1.09337000e-02, -2.74148956e-02, -3.31539623e-02,\n", + " -1.42244538e-02, 2.58499496e-02, -1.75755024e-02, -2.55133826e-02,\n", + " 7.51739135e-04, 2.48481100e-03, -3.04994378e-02, -1.40746264e-02,\n", + 
" 2.61744000e-02, 5.87598886e-03, 2.33898405e-02, -6.32737800e-02,\n", + " 1.78641230e-02, -1.75519601e-01, -8.24389141e-03, 1.98861826e-02,\n", + " 3.37476358e-02, 7.28844898e-03, 3.08266692e-02, 4.87708626e-03,\n", + " 2.09365971e-02, 1.73568614e-02, 3.69173177e-02, 2.79097166e-02,\n", + " 2.77034808e-02, -2.53460445e-02, -2.13546418e-02, -5.77083193e-02,\n", + " -1.14561366e-02, 2.31049191e-02, 3.03287655e-02, -3.81544698e-04,\n", + " -1.71796624e-02, -4.23317999e-02, 1.18717095e-02, 2.22866610e-02,\n", + " -1.19446928e-03, 1.10752694e-02, -5.63540356e-03, 2.86641587e-02,\n", + " 1.95889026e-02, 1.96824670e-02, -2.92297900e-02, 2.18246505e-02,\n", + " -2.48097051e-02, 1.63886491e-02, 3.13611962e-02, -1.68342353e-03,\n", + " 1.88188329e-02, 2.00219527e-02, -5.88387949e-03, 2.13814694e-02,\n", + " 7.82733504e-03, 7.68757053e-03, 3.46655548e-02, -1.79147162e-02,\n", + " 1.58434100e-02, -3.25256586e-02, 9.75492597e-03, 8.89630523e-03,\n", + " 3.20187919e-02, -2.86053754e-02, -4.22061840e-03, -1.45295085e-02],\n", + " dtype=float32)" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "scgpt.config[\"emb_mode\"] = \"cell\"\n", + "cell_embeddings = scgpt.get_embeddings(dataset)\n", + "cell_embeddings[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Embedding size per cell: (512,)\n" + ] + } + ], + "source": [ + "print(f\"Embedding size per cell: {cell_embeddings[0].shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:helical.models.scgpt.model:Started getting embeddings:\n", + "Embedding cells: 100%|██████████| 1/1 [00:00<00:00, 240.78it/s]\n", + "INFO:helical.models.scgpt.model:Finished getting embeddings.\n" + ] + }, + { + "data": { + "text/plain": [ + 
"array([-2.45153867e-02, 6.76829070e-02, 9.26359184e-03, -2.09791050e-03,\n", + " 2.24513449e-02, 2.08408223e-03, -2.53852215e-02, 3.11240065e-03,\n", + " -6.62306789e-03, 3.31380554e-02, 2.76659112e-02, 4.46426682e-03,\n", + " 3.08492407e-02, 1.56556666e-02, 1.73311401e-02, -1.20286504e-02,\n", + " -3.41191003e-03, -2.74797548e-02, -8.76120233e-04, -1.59723293e-02,\n", + " -1.63279437e-02, -1.06245605e-02, -1.66277196e-02, -2.04682187e-03,\n", + " 1.51277408e-02, 5.14600612e-02, -5.13950512e-02, -3.17132138e-02,\n", + " -2.95655672e-02, -2.13937908e-02, 1.59325860e-02, -2.51241624e-02,\n", + " 4.61029354e-03, -2.76504010e-02, -1.92681160e-02, -3.63738127e-02,\n", + " -6.18889090e-03, -4.07493208e-03, -7.18622655e-02, 4.02772194e-03,\n", + " -3.14314142e-02, -7.31843291e-03, -5.05978167e-02, -9.09360871e-03,\n", + " -2.41975486e-02, -1.72051415e-02, 7.02964189e-03, 3.86377797e-02,\n", + " 9.64887347e-03, 5.07224277e-02, -2.60307938e-02, 2.90858541e-02,\n", + " 7.36009097e-03, 1.74028252e-03, -2.98949555e-02, 2.33976301e-02,\n", + " 1.02939699e-02, 2.12073866e-02, 4.89693619e-02, 2.88790110e-02,\n", + " -2.77336799e-02, -4.46378300e-03, 4.08341996e-02, -3.96675081e-04,\n", + " 5.43113872e-02, -4.32825685e-02, -7.32045248e-02, 9.76623502e-03,\n", + " -1.32888434e-02, -3.96071039e-02, 7.86911510e-03, -1.11062010e-03,\n", + " 1.97083247e-03, 1.95169840e-02, -3.42093743e-02, 2.74168756e-02,\n", + " -3.95796215e-03, -7.52964709e-03, 2.74726804e-02, 2.59213755e-03,\n", + " -6.16539083e-03, -1.58587899e-02, 2.82052979e-02, -1.15956198e-02,\n", + " 3.54868509e-02, 2.10501440e-02, -2.49753688e-02, 3.49047594e-02,\n", + " 3.15196700e-02, 2.41361558e-02, -3.76607967e-03, 3.49559938e-03,\n", + " 1.93185955e-02, 3.04949172e-02, -1.16525106e-02, -3.90736852e-03,\n", + " -1.58282351e-02, 3.65261957e-02, 3.23544145e-02, -2.08612122e-02,\n", + " -1.35824289e-02, 1.10562518e-02, 2.50375569e-02, 7.57342279e-02,\n", + " -1.28078144e-02, 7.01878732e-03, -5.82913216e-03, 
-4.30742698e-03,\n", + " 4.81213778e-02, -6.27691951e-03, -2.15034243e-02, 2.16278862e-02,\n", + " -1.80496816e-02, -4.38741744e-02, 3.37660015e-02, -4.03557792e-02,\n", + " 2.39017084e-02, 4.97239716e-02, 8.25660117e-03, -1.15766805e-02,\n", + " 6.64573535e-03, 8.37536808e-03, 1.50651345e-02, -3.06401737e-02,\n", + " 1.06578423e-02, -2.15192046e-02, 1.51927005e-02, 7.00843893e-03,\n", + " -1.39160594e-02, -1.12494053e-02, 5.43792397e-02, 4.49348055e-02,\n", + " -8.21045507e-03, 6.14139019e-03, -1.22415740e-02, 4.16773022e-04,\n", + " -2.25790031e-02, 2.32806262e-02, -2.98567116e-03, -1.85264312e-02,\n", + " 1.34663479e-02, -2.40827929e-02, -4.38230578e-04, 2.45382376e-02,\n", + " -1.47257214e-02, -1.46871654e-03, -3.57116312e-02, 2.32473407e-02,\n", + " -2.97865532e-02, -5.13614155e-02, 1.58212744e-02, 8.06239173e-02,\n", + " -2.04761140e-02, -1.62395532e-04, -1.97993778e-02, 6.55988678e-02,\n", + " -3.79635133e-02, -4.76373076e-01, 1.76495910e-02, 1.76523570e-02,\n", + " 4.72312011e-02, 3.60119641e-02, -4.50604688e-03, -9.84678417e-03,\n", + " -1.48310815e-03, 6.49683643e-04, -5.05734533e-02, -2.23090481e-02,\n", + " -2.09893622e-02, 2.83944681e-02, 3.98665443e-02, -3.97819020e-02,\n", + " 3.02289277e-02, -6.41960055e-02, -3.12850736e-02, -2.14051344e-02,\n", + " 9.48422309e-03, 1.04411719e-02, -2.30573844e-02, 5.35256118e-02,\n", + " 1.43965846e-02, -4.82759178e-02, -4.62439191e-03, -6.90604970e-02,\n", + " 1.99532807e-02, -2.63340026e-03, -7.42220599e-03, 4.60068770e-02,\n", + " -2.90969908e-02, 2.18398310e-02, -2.34585330e-02, 2.26608873e-03,\n", + " -2.22239364e-03, -2.21185423e-02, 3.63707938e-03, -2.51304284e-02,\n", + " -1.48233669e-02, -1.08530521e-02, 1.45721203e-02, -2.17926800e-02,\n", + " -4.89135645e-03, 8.29203892e-03, 2.51028836e-02, 1.03409085e-02,\n", + " -1.78557765e-02, -6.04140945e-03, 2.05238699e-03, 2.82709692e-02,\n", + " -3.24503630e-02, 3.54559459e-02, -3.53872031e-02, 3.11379209e-02,\n", + " 4.96259928e-02, -8.66587460e-03, 
-7.14815855e-02, 1.15210470e-02,\n", + " 1.80784240e-02, -5.19741587e-02, -6.71254983e-03, -1.26365563e-02,\n", + " 4.98214737e-02, 8.07952322e-03, 2.27515530e-02, 6.07486628e-02,\n", + " -1.38171967e-02, 3.36158723e-02, 5.27171185e-03, -2.48884223e-02,\n", + " 2.67648492e-02, 7.27484655e-03, -7.95399770e-03, -3.81333083e-02,\n", + " 2.81804637e-03, 1.15901353e-02, -2.18091030e-02, -2.17718966e-02,\n", + " -4.97795902e-02, 4.76862397e-03, -2.01723278e-02, 2.65548769e-02,\n", + " -1.98825561e-02, -3.65659930e-02, 3.32224146e-02, -1.67506132e-02,\n", + " 4.91061732e-02, -1.20102270e-02, 2.07320806e-02, 6.19586110e-02,\n", + " 3.37009295e-03, -1.01545528e-02, -9.37982090e-03, 3.72043811e-03,\n", + " 3.52773704e-02, -5.10503836e-02, 1.51191065e-02, 1.35486033e-02,\n", + " 2.85783764e-02, -6.08509965e-03, 8.12581927e-03, -9.47200111e-04,\n", + " 4.36377168e-01, -1.17061613e-02, -3.76500525e-02, -1.73743330e-02,\n", + " -3.62554304e-02, -7.12229521e-04, 7.17211589e-02, -1.11871678e-03,\n", + " -2.81942077e-03, 1.81365814e-02, 4.89433147e-02, -5.35011478e-02,\n", + " -2.05936991e-02, -2.75716581e-03, -1.34341754e-02, 1.71495937e-02,\n", + " 3.41272503e-02, -9.98707488e-04, -4.40332443e-02, -2.61033233e-02,\n", + " -5.76740783e-03, 3.51253748e-01, 3.57882269e-02, 2.57680687e-04,\n", + " 8.88473634e-03, -4.53057215e-02, -3.88377905e-02, 4.76049073e-02,\n", + " -1.86286587e-02, -6.37506600e-03, 1.03429775e-03, 4.43534069e-02,\n", + " -3.41886049e-03, 1.73910931e-02, 4.96236533e-02, -3.81948426e-02,\n", + " 1.57124866e-02, 2.50780489e-03, -3.65860411e-03, 2.53409501e-02,\n", + " -7.24622048e-04, -1.44517347e-02, -1.94737352e-02, 3.76115702e-02,\n", + " -3.52155380e-02, 3.56680248e-03, -2.10319012e-02, 2.98389420e-02,\n", + " -5.44308778e-03, 1.40576540e-02, -2.62061512e-04, -1.53648760e-02,\n", + " 1.98135469e-02, 3.25694233e-02, 4.54310477e-02, -1.63885728e-02,\n", + " 1.20220520e-02, 6.35347366e-02, 1.08912708e-02, 2.02855766e-02,\n", + " 3.82428616e-02, 1.10052777e-02, 
-2.17193719e-02, -1.58026423e-02,\n", + " -2.14802809e-02, -5.12225088e-03, -2.51318403e-02, 6.36755824e-02,\n", + " -3.40835005e-02, 2.07607169e-03, -2.16116831e-02, 7.36416802e-02,\n", + " 1.74865592e-02, 3.75458077e-02, 4.12650825e-03, 8.52579810e-03,\n", + " -3.39522772e-02, 2.49971002e-02, -2.51762550e-02, 8.62706732e-03,\n", + " -4.32539880e-02, -8.96353833e-03, 6.73645409e-03, -7.29783550e-02,\n", + " -6.26313537e-02, -2.44746829e-04, -9.67285596e-03, -3.47111858e-02,\n", + " -1.16014006e-02, 2.85756849e-02, -2.09196750e-02, -1.40493680e-02,\n", + " 8.25099554e-03, 5.86548038e-02, 1.85851846e-02, -5.52713126e-02,\n", + " 3.80891152e-02, -6.55806288e-02, 3.01559316e-03, -1.53750554e-02,\n", + " -1.32344523e-02, -1.39216371e-02, 2.48366762e-02, 1.19781457e-02,\n", + " -3.60681303e-03, -8.88350792e-03, 2.99509112e-02, 1.41091915e-02,\n", + " 3.02721411e-02, 2.74510216e-02, 3.79114896e-02, 6.18577981e-03,\n", + " -2.02162806e-02, 8.86006933e-03, 4.37244959e-03, -1.69865005e-02,\n", + " -3.95388640e-02, -6.01971615e-03, -4.53112926e-03, -3.33280605e-03,\n", + " -7.93357007e-03, 6.15930259e-02, 7.47404993e-03, -5.24884239e-02,\n", + " -1.02607217e-02, -4.16327454e-02, -6.04979992e-02, -4.74545322e-02,\n", + " 5.26728295e-03, -1.57921184e-02, -4.90473490e-03, -1.73121970e-02,\n", + " -3.25186062e-03, -1.78076476e-02, -1.31681720e-02, -1.48400199e-02,\n", + " 3.45820636e-02, -3.79318222e-02, -3.98465209e-02, 5.30735124e-03,\n", + " 2.45902408e-02, 3.14300656e-02, 6.30108267e-02, -1.24083590e-02,\n", + " -1.89693309e-02, 2.83043850e-02, -1.92273594e-02, -4.13932558e-03,\n", + " 2.84970496e-02, 2.58654524e-02, 1.38802072e-02, 1.09579228e-03,\n", + " 4.01955470e-02, -1.92459077e-02, 6.25998676e-02, -3.43498192e-03,\n", + " -1.17322234e-02, 4.65216972e-02, -1.45426631e-04, -6.76989853e-02,\n", + " 3.45930792e-02, 1.78965426e-03, -4.68827598e-03, -2.11978052e-02,\n", + " 1.68629754e-02, -2.95592397e-02, -2.81799436e-02, 2.55590193e-02,\n", + " 1.55170374e-02, 
4.00625654e-02, -1.83338411e-02, 5.29153924e-03,\n", + " 2.13614181e-02, -2.45864578e-02, 2.74816044e-02, -5.72612137e-03,\n", + " 2.09344216e-02, -7.11009139e-03, -1.07827922e-02, -5.92843071e-02,\n", + " -2.96214614e-02, -6.93373266e-04, -3.61338742e-02, -2.20520645e-02,\n", + " 3.14995237e-02, -2.18518469e-02, -5.05348034e-02, -5.09402566e-02,\n", + " -2.13348819e-03, 1.54039720e-02, 3.19538489e-02, -4.11356539e-02,\n", + " 3.83283990e-03, -4.44084872e-03, -3.83663289e-02, -2.78404041e-04,\n", + " 4.50295173e-02, -1.53042013e-02, 2.65689809e-02, -5.76632135e-02,\n", + " 2.92779431e-02, -1.99816048e-01, -2.14710343e-03, 8.64214171e-03,\n", + " 1.44068627e-02, 2.82023028e-02, 3.13479416e-02, -2.64063105e-02,\n", + " 3.91461104e-02, 4.71042767e-02, 4.99161296e-02, 1.69883831e-03,\n", + " 4.15363535e-02, -2.07511596e-02, -2.40179505e-02, -7.79969022e-02,\n", + " 2.89676187e-04, 3.72988544e-02, 8.40481650e-03, 5.54896705e-03,\n", + " -3.84214185e-02, -1.18556013e-02, -1.60694565e-03, 3.53423283e-02,\n", + " -3.45797054e-02, 2.27767508e-02, -4.68104240e-03, 8.65800027e-03,\n", + " 3.38197835e-02, 1.60816684e-02, -3.27193998e-02, 1.10832201e-02,\n", + " -1.63125228e-02, 2.71654911e-02, 7.86161283e-04, -1.17324255e-02,\n", + " 3.07963714e-02, 2.13367324e-02, 3.76224564e-03, -2.13393662e-03,\n", + " 1.07512320e-03, 3.83881922e-03, 3.68292853e-02, -1.65097248e-02,\n", + " 2.31586695e-02, -5.19343726e-02, 4.10866551e-02, 5.31221693e-03,\n", + " 3.09123825e-02, -1.99824646e-02, -4.52331454e-02, -4.32595611e-03],\n", + " dtype=float32)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "scgpt.config[\"emb_mode\"] = \"cls\"\n", + "cls_embeddings = scgpt.get_embeddings(dataset)\n", + "cls_embeddings[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Embedding size per cls: (512,)\n" + ] + } + ], + 
"source": [ + "print(f\"Embedding size per cls: {cls_embeddings[0].shape}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "helical-package", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/helical/models/classification/__init__.py b/helical/models/classification/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/helical/models/classification/neural_network.py b/helical/models/classification/neural_network.py deleted file mode 100644 index 3db8295e..00000000 --- a/helical/models/classification/neural_network.py +++ /dev/null @@ -1,187 +0,0 @@ -from numpy import ndarray -from typing_extensions import Self -from sklearn.preprocessing import LabelEncoder -import numpy as np -from helical.models.base_models import BaseTaskModel -from pathlib import Path -import torch -import torch.nn as nn -import torch.optim as optim -from torch.utils.data import DataLoader, TensorDataset -import logging -from typing import Any - -LOGGER = logging.getLogger(__name__) -class NeuralNetwork(BaseTaskModel): - def __init__(self, loss: Any = nn.CrossEntropyLoss(), learning_rate: float = 0.001, epochs=10, batch_size=32) -> None: - """Initialize the neural network model. - - Parameters - ---------- - loss : Any, optional, default=nn.CrossEntropyLoss() - The loss function to use for training the neural network. - learning_rate : float - The learning rate of the neural network. - epochs : int - The number of epochs to train the neural network. - batch_size : int - The batch size to use for training the neural network. 
- """ - - self.learning_rate = learning_rate - self.epochs = epochs - self.batch_size = batch_size - self.loss_fn = loss - self.encoder = LabelEncoder() - - def compile(self, num_classes: int, input_shape: int) -> None: - """Compile a neural network. The model is a simple feedforward neural network with 2 hidden layers. - TODO - Add more flexibility to the model architecture. - - Parameters - ---------- - num_classes : int - The number of classes to predict. - input_shape : int - The input shape of the neural network. - """ - - self.num_classes = num_classes - self.input_shape = input_shape - self.model = nn.Sequential( - nn.Linear(input_shape, 256), - nn.ReLU(), - nn.Dropout(0.4), - nn.Linear(256, 64), - nn.ReLU(), - nn.Dropout(0.4), - nn.Linear(64, num_classes) - ) - - # Set optimizer and loss function - self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate) - self.loss_fn = nn.CrossEntropyLoss() - - def train(self, X_train: ndarray, y_train: ndarray, validation_data: tuple[ndarray, ndarray]) -> Self: - """Train the neural network on the training and validation data. - - Parameters - ---------- - X_train : ndarray - The training data features. - y_train : ndarray - The training data labels. - validation_data : tuple(ndarray, ndarray) - The validation data features and labels. - - Returns - ------- - The neural network instance. 
- """ - # Ensure model is in training mode - self.model.train() - - x_val, y_val = validation_data - self.encoder.fit_transform(np.concatenate((y_train, y_val), axis = 0)) - - y_train_encoded = self.encoder.transform(y_train) - y_val_encoded = self.encoder.transform(y_val) - - X_train_tensor = torch.tensor(X_train, dtype=torch.float32) - y_train_tensor = torch.tensor(y_train_encoded, dtype=torch.long) - X_val_tensor = torch.tensor(x_val, dtype=torch.float32) - y_val_tensor = torch.tensor(y_val_encoded, dtype=torch.long) - - train_dataset = TensorDataset(X_train_tensor, y_train_tensor) - val_dataset = TensorDataset(X_val_tensor, y_val_tensor) - - train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True) - val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False) - - # Training loop - for epoch in range(self.epochs): - for batch_X, batch_y in train_loader: - # Zero the parameter gradients - self.optimizer.zero_grad() - - # Forward pass - outputs = self.model(batch_X) - - # Compute loss - loss = self.loss_fn(outputs, batch_y) - - # Backward pass and optimize - loss.backward() - self.optimizer.step() - - # Validation phase (optional) - self.model.eval() - with torch.no_grad(): - val_losses = [] - for val_X, val_y in val_loader: - val_outputs = self.model(val_X) - val_loss = self.loss_fn(val_outputs, val_y) - val_losses.append(val_loss.item()) - - print(f"Epoch {epoch+1}, Validation Loss: {sum(val_losses)/len(val_losses)}") - - # Set back to training mode for next epoch - self.model.train() - return self - - def predict(self, x: ndarray) -> ndarray: - """Use the neural network to make predictions. - - Parameters - ---------- - x : ndarray - The data to make predictions upon. - - Returns - ------- - The prediction of the neural network. 
- """ - self.model.eval() - predictions_nn = self.model(torch.Tensor(x)) - y_pred = np.array(torch.argmax(predictions_nn, dim=1)) - return self.encoder.inverse_transform(y_pred) - - def save(self, path: str) -> None: - """Save the neural network model and its encoder to a directory. - Any missing parents of this path are created as needed. - - Parameters - ---------- - path : str - The path to the directory to save the model and the encoder. - """ - Path(path).mkdir(parents=True, exist_ok=True) - np.save(f"{path}/encoder", self.encoder.classes_) - torch.save(self.model, f"{path}/neural_network.pth") - - def load(self, path: str, classes: str) -> Self: - """Load the neural network model from a file. - - Parameters - ---------- - path : str - The path to load the model from. - classes : str - The path to classes used for encoding the labels. - - Returns - ------- - The neural network instance. - """ - # set to None, showing this is a loaded model and not trained - self.loss = None - self.learning_rate = None - self.epochs = None - self.batch_size = None - - self.encoder.classes_ = classes - # self.model = nn.Sequential() - self.model = torch.load(path) - self.model.eval() - - return self \ No newline at end of file diff --git a/helical/models/classification/svm.py b/helical/models/classification/svm.py deleted file mode 100644 index 0e5fbe83..00000000 --- a/helical/models/classification/svm.py +++ /dev/null @@ -1,93 +0,0 @@ -from helical.models.base_models import BaseTaskModel -from numpy import ndarray -from sklearn import svm -from typing_extensions import Self -from typing import Optional -import pickle -import os - -class SupportVectorMachine(BaseTaskModel): - def __init__(self, kernel='rbf', degree=3, C=1, decision_function_shape='ovr') -> None: - self.kernel = kernel - self.degree = degree - self.C = C - self.decision_function_shape = decision_function_shape - - def compile(self, num_classes: Optional[int] = None, input_shape: Optional[int] = None) -> None: - 
""" - Compile a SVM. The input parameters are not needed for the SVM model, providing them makes for cleaner code. - - Parameters - ---------- - num_classes : int, None - The number of classes to predict, default is None. - The SVM should find this automatically. - input_shape : int, None - The input shape of the neural network, default is None. - The SVM should find this automatically. - """ - self.svm_model = svm.SVC(kernel = self.kernel, - degree = self.degree, - C = self.C, - decision_function_shape = self.decision_function_shape) - - def train(self, X_train: ndarray, y_train: ndarray, **kwargs) -> Self: - """Train an SVM on the training. - - Parameters - ---------- - X_train : ndarray - The training data features. - y_train : ndarray - The training data labels. - - Returns - ------- - The neural network instance. - """ - self.svm_model.fit(X_train, y_train) - return self - - def predict(self, x: ndarray) -> ndarray: - """Use the SVM to make predictions. - - Parameters - ---------- - x : ndarray - The data to make predictions upon. - - Returns - ------- - The prediction of the SVM. - """ - return self.svm_model.predict(x) - - def save(self, path: str) -> None: - """Save the SVM model to a file. - - Parameters - ---------- - path : str - The path to save the model. - """ - os.makedirs(os.path.dirname(path), exist_ok=True) - file = f"{path}svm.h5" - with open(file, 'wb') as f: - pickle.dump(self.svm_model, f) - - def load(self, path: str) -> Self: - """Load the SVM model from a file. - - Parameters - ---------- - path : str - The path to load the model from. - - Returns - ------- - The SVM instance. 
- """ - with open(path, 'rb') as f: - self.svm_model = pickle.load(f) - return self - \ No newline at end of file diff --git a/helical/models/geneformer/fine_tuning_model.py b/helical/models/geneformer/fine_tuning_model.py index 45214495..f3fb46c5 100644 --- a/helical/models/geneformer/fine_tuning_model.py +++ b/helical/models/geneformer/fine_tuning_model.py @@ -22,43 +22,47 @@ class GeneformerFineTuningModel(HelicalBaseFineTuningModel, Geneformer): ---------- ```python from helical import GeneformerConfig, GeneformerFineTuningModel + import anndata as ad - # Prepare the data - ann_data = ad.read_h5ad("dataset.h5ad") + # Load the data + ann_data = ad.read_h5ad("/home/matthew/helical-dev/helical/yolksac_human.h5ad") - # Get the desired label class - cell_types = list(ann_data.obs.cell_type) - - # Create a dictionary mapping the classes to unique integers for training + # Get the column for fine-tuning + cell_types = list(ann_data.obs["cell_types"]) label_set = set(cell_types) - class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))])) - for i in range(len(cell_types)): - cell_types[i] = class_id_dict[cell_types[i]] + # Create a GeneformerConfig object + geneformer_config = GeneformerConfig(model_name="gf-12L-95M-i4096", batch_size=10) + + # Create a GeneformerFineTuningModel object + geneformer_fine_tune = GeneformerFineTuningModel(geneformer_config=geneformer_config, fine_tuning_head="classification", output_size=len(label_set)) - # Add this column to the Dataset + # Process the data + dataset = geneformer_fine_tune.process_data(ann_data[:10]) + + # Add column to the dataset dataset = dataset.add_column('cell_types', cell_types) - # Create the fine-tuning model - model_config = GeneformerConfig(model_name="gf-12L-95M-i4096", batch_size=10) - geneformer_fine_tune = GeneformerFineTuningModel( - geneformer_config=model_config, - fine_tuning_head="classification", - label="cell_types", - output_size=len(label_set) - ) + # Create a dictionary to map cell 
types to ids + class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))])) + + def classes_to_ids(example): + example["cell_types"] = class_id_dict[example["cell_types"]] + return example - # Process the data for training - dataset = geneformer_fine_tune.process_data(ann_data) + # Convert cell types to ids + dataset = dataset.map(classes_to_ids, num_proc=1) - # Fine-tune - geneformer_fine_tune.train(train_dataset=dataset) + # Fine-tune the model + geneformer_fine_tune.train(train_dataset=dataset, label="cell_types") - # Get outputs of the fine-tuned model + # Get logits from the fine-tuned model outputs = geneformer_fine_tune.get_outputs(dataset) + print(outputs[:10]) - # Get the embeddings of the fine-tuned model + # Get embeddings from the fine-tuned model embeddings = geneformer_fine_tune.get_embeddings(dataset) + print(embeddings[:10]) ``` Parameters diff --git a/helical/models/geneformer/model.py b/helical/models/geneformer/model.py index d088b05a..ecaaa89e 100644 --- a/helical/models/geneformer/model.py +++ b/helical/models/geneformer/model.py @@ -32,29 +32,27 @@ class Geneformer(HelicalRNAModel): Example ------- ```python - from helical.models import Geneformer, GeneformerConfig - import anndata as ad - - # For Version 2.0 - geneformer_config_v2 = GeneformerConfig(model_name="gf-12L-95M-i4096", batch_size=10) - geneformer_v2 = Geneformer(configurer=geneformer_config_v2) - - # You can use other model names in the config, such as: - # "gf-12L-30M-i2048" (Version 1.0) - # "gf-12L-95M-i4096-CLcancer" (Version 2.0, Cancer-tuned) - # "gf-20L-95M-i4096" (Version 2.0, 20-layer model) - - # Example usage for base pretrained model (for general transcriptomic analysis, v1 and v2) - ann_data = ad.read_h5ad("general_dataset.h5ad") - dataset = geneformer_v2.process_data(ann_data) - embeddings = geneformer_v2.get_embeddings(dataset) - print("Base model embeddings shape:", embeddings.shape) - - # Example usage for cancer-tuned model (for cancer-specific 
analysis) - cancer_ann_data = ad.read_h5ad("cancer_dataset.h5ad") - cancer_dataset = geneformer_v2_cancer.process_data(cancer_ann_data) - cancer_embeddings = geneformer_v2_cancer.get_embeddings(cancer_dataset) - print("Cancer-tuned model embeddings shape:", cancer_embeddings.shape) + from helical import Geneformer, GeneformerConfig + import anndata as ad + + # Example configuration + model_config = GeneformerConfig(model_name="gf-12L-95M-i4096", batch_size=10) + geneformer_v2 = Geneformer(model_config) + + # Example usage for base pretrained model + ann_data = ad.read_h5ad("anndata_file.h5ad") + dataset = geneformer_v2.process_data(ann_data) + embeddings = geneformer_v2.get_embeddings(dataset) + print("Base model embeddings shape:", embeddings.shape) + + # Example usage for cancer-tuned model + model_config_cancer = GeneformerConfig(model_name="gf-12L-95M-i4096-CLcancer", batch_size=10) + geneformer_v2_cancer = Geneformer(model_config) + + cancer_ann_data = ad.read_h5ad("anndata_file.h5ad") + cancer_dataset = geneformer_v2_cancer.process_data(cancer_ann_data) + cancer_embeddings = geneformer_v2_cancer.get_embeddings(cancer_dataset) + print("Cancer-tuned model embeddings shape:", cancer_embeddings.shape) ``` Parameters diff --git a/helical/models/genept/model.py b/helical/models/genept/model.py index 41203e14..4aede994 100644 --- a/helical/models/genept/model.py +++ b/helical/models/genept/model.py @@ -5,16 +5,16 @@ from helical.utils.downloader import Downloader from helical.models.genept.genept_config import GenePTConfig from helical.utils.mapping import map_ensembl_ids_to_gene_symbols -import logging import scanpy as sc import torch import json -import torch LOGGER = logging.getLogger(__name__) + + class GenePT(HelicalRNAModel): - """GenePT Model. - + """GenePT Model. 
+ ``` Parameters @@ -27,7 +27,9 @@ class GenePT(HelicalRNAModel): """ + default_configurer = GenePTConfig() + def __init__(self, configurer: GenePTConfig = default_configurer): super().__init__() self.configurer = configurer @@ -37,34 +39,35 @@ def __init__(self, configurer: GenePTConfig = default_configurer): for file in self.config["list_of_files_to_download"]: downloader.download_via_name(file) - with open(self.config['embeddings_path'],"r") as f: + with open(self.config["embeddings_path"], "r") as f: self.embeddings = json.load(f) LOGGER.info("GenePT initialized successfully.") - def process_data(self, - adata: AnnData, - gene_names: str = "index", - use_raw_counts: bool = True, - ) -> AnnData: + def process_data( + self, + adata: AnnData, + gene_names: str = "index", + use_raw_counts: bool = True, + ) -> AnnData: """ Processes the data for the GenePT model. Parameters ---------- adata : AnnData - The AnnData object containing the data to be processed. GenePT uses Ensembl IDs to identify genes - and currently supports only human genes. If the AnnData object already has an 'ensembl_id' column, + The AnnData object containing the data to be processed. GenePT uses Ensembl IDs to identify genes + and currently supports only human genes. If the AnnData object already has an 'ensembl_id' column, the mapping step can be skipped. gene_names : str, optional, default="index" - The column in `adata.var` that contains the gene names. If set to a value other than "ensembl_id", - the gene symbols in that column will be mapped to Ensembl IDs using the 'pyensembl' package, + The column in `adata.var` that contains the gene names. If set to a value other than "ensembl_id", + the gene symbols in that column will be mapped to Ensembl IDs using the 'pyensembl' package, which retrieves mappings from the Ensembl FTP server and loads them into a local database. - If set to "index", the index of the AnnData object will be used and mapped to Ensembl IDs. 
- If set to "ensembl_id", no mapping will occur. Special case: - If the index of `adata` already contains Ensembl IDs, setting this to "index" will result in - invalid mappings. In such cases, create a new column containing Ensembl IDs and pass "ensembl_id" + If the index of `adata` already contains Ensembl IDs, setting this to "index" will result in + invalid mappings. In such cases, create a new column containing Ensembl IDs and pass "ensembl_id" as the value of `gene_names`. use_raw_counts : bool, optional, default=True Determines whether raw counts should be used. @@ -74,32 +77,34 @@ def process_data(self, Dataset The tokenized dataset in the form of a Huggingface Dataset object. """ - LOGGER.info(f"Processing data for GenePT.") + LOGGER.info("Processing data for GenePT.") self.ensure_rna_data_validity(adata, gene_names, use_raw_counts) # map gene symbols to ensemble ids if provided if gene_names == "ensembl_id": - if (adata.var[gene_names].str.startswith("ENS").all()) or (adata.var[gene_names].str.startswith("None").any()): - message = "It seems an anndata with 'ensemble ids' and/or 'None' was passed. " \ - "Please set gene_names='ensembl_id' and remove 'None's to skip mapping." + if (adata.var[gene_names].str.startswith("ENS").all()) or ( + adata.var[gene_names].str.startswith("None").any() + ): + message = ( + "It seems an anndata with 'ensemble ids' and/or 'None' was passed. " + "Please set gene_names='ensembl_id' and remove 'None's to skip mapping." 
+ ) LOGGER.info(message) raise ValueError(message) adata = map_ensembl_ids_to_gene_symbols(adata, gene_names) - n_top_genes = 1000 - LOGGER.info(f"Filtering the top {n_top_genes} highly variable genes.") - sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes, flavor='seurat_v3') + sc.pp.highly_variable_genes(adata, flavor="seurat_v3") sc.pp.normalize_total(adata, target_sum=1e4) sc.pp.log1p(adata) - genes_names = adata.var_names[adata.var['highly_variable']].tolist() - adata = adata[:,genes_names] - - LOGGER.info(f"Successfully processed the data for GenePT.") + genes_names = adata.var_names[adata.var["highly_variable"]].tolist() + adata = adata[:, genes_names] + + LOGGER.info("Successfully processed the data for GenePT.") return adata - + def get_text_embeddings(self, dataset: AnnData) -> np.array: - """Gets the gene embeddings from the GenePT model + """Gets the gene embeddings from the GenePT model Parameters ---------- @@ -111,26 +116,30 @@ def get_text_embeddings(self, dataset: AnnData) -> np.array: np.array The gene embeddings in the form of a numpy array """ - # Generate a response + # Generate a response raw_embeddings = dataset.var_names weights = [] - count_missed = 0 gene_list = [] - for i,emb in enumerate(raw_embeddings): - gene = self.embeddings.get(emb.upper(),None) + count_missed = 0 + + for emb in raw_embeddings: + gene = self.embeddings.get(emb.upper(), None) if gene is not None: - weights.append(gene['embeddings']) + weights.append(gene["embeddings"]) gene_list.append(emb) else: count_missed += 1 - LOGGER.info("Couln't find {} genes in embeddings".format(count_missed)) + + LOGGER.info(f"Couln't find {count_missed} genes in embeddings") weights = torch.Tensor(weights) - embeddings = torch.matmul(torch.Tensor(dataset[:,gene_list].X.toarray()),weights) + embeddings = torch.matmul( + torch.Tensor(dataset[:, gene_list].X.toarray()), weights + ) return embeddings - + def get_embeddings(self, dataset: AnnData) -> torch.Tensor: - """Gets the gene 
embeddings from the GenePT model + """Gets the gene embeddings from the GenePT model Parameters ---------- @@ -143,7 +152,7 @@ def get_embeddings(self, dataset: AnnData) -> torch.Tensor: The gene embeddings in the form of a numpy array """ LOGGER.info(f"Inference started:") - # Generate a response + # Generate a response embeddings = self.get_text_embeddings(dataset) - embeddings = (embeddings/(np.linalg.norm(embeddings,axis=1)).reshape(-1,1)) - return embeddings \ No newline at end of file + embeddings = embeddings / (np.linalg.norm(embeddings, axis=1)).reshape(-1, 1) + return embeddings diff --git a/helical/models/helix_mrna/model.py b/helical/models/helix_mrna/model.py index 621b7cdf..93264b7d 100644 --- a/helical/models/helix_mrna/model.py +++ b/helical/models/helix_mrna/model.py @@ -30,7 +30,7 @@ class HelixmRNA(HelicalRNAModel): device = "cuda" if torch.cuda.is_available() else "cpu" - helix_mrna_config = HelimRNAConfig(batch_size=5, max_length=100, device=device) + helix_mrna_config = HelixmRNAConfig(batch_size=5, max_length=100, device=device) helix_mrna = HelixmRNA(configurer=helix_mrna_config) rna_sequences = ["EACUEGGG", "EACUEGGG", "EACUEGGG", "EACUEGGG", "EACUEGGG"] diff --git a/helical/utils/downloader.py b/helical/utils/downloader.py index d4a143c7..ea2c82c3 100644 --- a/helical/utils/downloader.py +++ b/helical/utils/downloader.py @@ -82,6 +82,8 @@ 'caduceus/caduceus-ps-4L-seqlen-1k-d256/config.json': '655d2c3a692ab35718cfe87ce98b14c4e57807f85089c2af81809873f356349f', 'genept/genept_embeddings/genept_embeddings.json': '54a58177e6f4cb9c2d98f39cb8c586bd347a526375eba861df15a3714f737ccc', + + '17_04_24_YolkSacRaw_F158_WE_annots.h5ad': '0585c186ef23951a538522dd6882492c2d5c165c615543fe01bf0d0daedc2f5a', } class Downloader(Logger): diff --git a/mkdocs.yml b/mkdocs.yml index cac2519b..dfd3b570 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -52,6 +52,7 @@ nav: - Geneformer-vs-UCE: ./notebooks/Geneformer-vs-UCE.ipynb - Hyena-DNA-Inference: 
./notebooks/Hyena-DNA-Inference.ipynb - HyenaDNA-Fine-Tuning: ./notebooks/HyenaDNA-Fine-Tuning.ipynb + - Cell-Gene-Cls-embedding-generation: ./notebooks/Cell-Gene-Cls-embedding-generation.ipynb theme: name: material diff --git a/pyproject.toml b/pyproject.toml index 9ce6aca1..5392551f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "helical" -version = "0.0.1a21" +version = "0.0.1a22" authors = [ { name="Helical Team", email="support@helical-ai.com" }, ] @@ -24,17 +24,17 @@ dependencies = [ 'pandas==2.2.2', 'anndata==0.11', 'numpy==1.26.4', - 'scikit-learn>=1.2.2', + 'scikit-learn>=1.5.0', 'scipy==1.13.1', 'gitpython==3.1.43', 'torch==2.5.1', 'torchvision==0.20.1', 'accelerate==0.29.3', - 'transformers==4.45.1', + 'transformers==4.48.0', 'loompy==3.0.7', 'scib==1.1.5', 'scikit-misc==0.3.1', - 'azure-identity==1.16.0', + 'azure-identity==1.16.1', 'azure-storage-blob==12.19.1', 'azure-core==1.30.1', 'einops==0.8.0', diff --git a/requirements.txt b/requirements.txt index 51e7a7dc..30c24ecf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,20 +1,20 @@ -requests==2.31.0 +requests==2.32.2 pandas==2.2.2 anndata==0.10.7 numpy==1.26.4 -scikit-learn==1.2.2 +scikit-learn==1.5.0 scipy==1.13.1 gitpython==3.1.43 torch>=2.0.0,<=2.3.0 torchvision>=0.15.0,<=0.18.0 accelerate==0.29.3 -transformers==4.45.1 +transformers==4.48.0 loompy==3.0.7 scib==1.1.5 scikit-misc==0.3.1 datasets==2.14.7 azure-storage-blob==12.19.1 -azure-identity==1.16.0 +azure-identity==1.16.1 azure-core==1.30.1 einops==0.8.0 omegaconf==2.3.0