Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
04cc5fb
Bump notebook from 7.1.3 to 7.2.2
dependabot[bot] Jan 22, 2025
c259e1a
Bump transformers from 4.39.2 to 4.48.0
dependabot[bot] Feb 11, 2025
886e104
Bump certifi from 2024.2.2 to 2024.7.4
dependabot[bot] Feb 12, 2025
c56a280
Bump jinja2 from 3.1.3 to 3.1.6
dependabot[bot] Mar 6, 2025
375ada2
Bump ray from 2.10.0 to 2.43.0
dependabot[bot] Mar 6, 2025
9fe2041
Bump pytorch-lightning from 2.2.4 to 2.4.0
dependabot[bot] Mar 21, 2025
aa8894e
Merge pull request #16 from microsoft/dependabot/pip/pytorch-lightning-2.4.0
v-mahughes Jun 9, 2025
a37be56
Merge pull request #15 from microsoft/dependabot/pip/ray-2.43.0
v-mahughes Jun 9, 2025
3657e81
Merge pull request #14 from microsoft/dependabot/pip/jinja2-3.1.6
v-mahughes Jun 9, 2025
97ff437
Merge pull request #13 from microsoft/dependabot/pip/certifi-2024.7.4
v-mahughes Jun 9, 2025
432fabd
Merge pull request #12 from microsoft/dependabot/pip/transformers-4.48.0
v-mahughes Jun 9, 2025
7b71717
Merge pull request #11 from microsoft/dependabot/pip/notebook-7.2.2
v-mahughes Jun 9, 2025
8b2f875
Bump requests from 2.31.0 to 2.32.4
dependabot[bot] Jun 10, 2025
df6e33f
Merge pull request #23 from microsoft/dependabot/pip/requests-2.32.4
agdenadel Jun 10, 2025
7d4b1b1
Add pretrained pca eval + plotting
agdenadel Jun 12, 2025
e0f8d3f
Fix typos
Jun 25, 2025
e358333
Merge pull request #32 from microsoft/perturb-eval
agdenadel Jun 25, 2025
1ba56f4
fix typo in pretrained pca zero shot model evaluator
Jun 25, 2025
b4b3264
add pre-trained pca as an option to zero shot integration script
Jun 25, 2025
f48216f
Merge pull request #33 from microsoft/pca_foundation
agdenadel Jun 25, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions eval/zero_shot_classification.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""zero_shot_classification.py evaluates the performance of a pre-trained model
on an unseen dataset without fine-tuning."""
at classifying an unseen dataset without fine-tuning."""
import string
import random
from collections import defaultdict
Expand All @@ -11,11 +11,18 @@
import scanpy as sc
import anndata as ad
from evaluation_utils import prep_for_evaluation
from zero_shot_model_evaluators import SSLZeroShotEvaluator, SCVIZeroShotEvaluator
from zero_shot_model_evaluators import GeneformerZeroShotEvaluator


from zero_shot_model_evaluators import VariableGeneZeroShotEvaluator
from zero_shot_model_evaluators import PrincipalComponentsZeroShotEvaluator
from model_loaders import load_scvi_model, load_ssl_model

from zero_shot_model_evaluators import SCVIZeroShotEvaluator
from zero_shot_model_evaluators import SSLZeroShotEvaluator
from zero_shot_model_evaluators import GeneformerZeroShotEvaluator
from zero_shot_model_evaluators import PretrainedPrincipalComponentsZeroShotEvaluator


from model_loaders import load_scvi_model, load_ssl_model, load_pca_model
from model_loaders import load_geneformer_model, get_ssl_checkpoint_file


Expand Down Expand Up @@ -59,9 +66,12 @@ def get_classification_metrics_df(train_adata,
random_string = ''.join(random.choices(
string.ascii_uppercase + string.digits, k=10))
tmp_output_dir = Path(
f"tmp_zero_shot_integration_geneformer_{random_string}")
f"tmp_zero_shot_classification_geneformer_{random_string}")
zero_shot_evaluator = GeneformerZeroShotEvaluator(
geneformer_model, var_file, dict_dir, tmp_output_dir)
elif method == "PretrainedPCA": # todo test this
pca_model = load_pca_model(downsampling_method, percentage, seed, model_directory)
zero_shot_evaluator = PretrainedPrincipalComponentsZeroShotEvaluator(pca_model)

classification_metrics = zero_shot_evaluator.evaluate_classification(
train_adata, test_adata, cell_type_col)
Expand Down Expand Up @@ -128,7 +138,7 @@ def main():

new_adata = prep_for_evaluation(adata, formatted_h5ad_file, var_file)

if method == "SSL":
if method == "SSL" or method == "PretrainedPCA":
print("processing anndata")
sc.pp.normalize_per_cell(new_adata, counts_per_cell_after=1e4)
sc.pp.log1p(new_adata)
Expand Down
8 changes: 7 additions & 1 deletion eval/zero_shot_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@
from zero_shot_model_evaluators import VariableGeneZeroShotEvaluator
from zero_shot_model_evaluators import PrincipalComponentsZeroShotEvaluator

from zero_shot_model_evaluators import SSLZeroShotEvaluator, SCVIZeroShotEvaluator
from zero_shot_model_evaluators import SCVIZeroShotEvaluator
from zero_shot_model_evaluators import SSLZeroShotEvaluator
from zero_shot_model_evaluators import GeneformerZeroShotEvaluator
from zero_shot_model_evaluators import PretrainedPrincipalComponentsZeroShotEvaluator


from evaluation_utils import prep_for_evaluation

Expand Down Expand Up @@ -61,6 +64,9 @@ def get_scib_metrics_df(adata,
f"tmp_zero_shot_integration_geneformer_{random_string}")
zero_shot_evaluator = GeneformerZeroShotEvaluator(
geneformer_model, var_file, dict_dir, tmp_output_dir)
elif method == "PretrainedPCA": # todo test this
pca_model = load_pca_model(downsampling_method, percentage, seed, model_directory)
zero_shot_evaluator = PretrainedPrincipalComponentsZeroShotEvaluator(pca_model)

scib_metrics = zero_shot_evaluator.evaluate_integration(
adata, batch_col=batch_col, label_col=label_col)
Expand Down
2 changes: 1 addition & 1 deletion eval/zero_shot_model_evaluators.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def __init__(self, model):
self.embedding_name = "X_Pretrained_PCA"
self.model = model
def get_embeddings(self, adata):
return adata.X @ model
return adata.X @ self.model


class SSLZeroShotEvaluator(ZeroShotEvaluator):
Expand Down
139 changes: 109 additions & 30 deletions plotting/lineplots.ipynb

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ betterproto==1.2.5
biothings-client==0.3.1
bleach==6.1.0
blosc2==2.5.1
certifi==2024.2.2
certifi==2024.7.4
cffi==1.16.0
charset-normalizer==3.3.2
chex==0.1.86
Expand Down Expand Up @@ -88,7 +88,7 @@ isoduration==20.11.0
jax==0.4.28
jaxlib==0.4.28
jedi==0.19.1
Jinja2==3.1.3
Jinja2==3.1.6
joblib==1.4.2
json5==0.9.25
jsonpointer==2.4
Expand Down Expand Up @@ -141,7 +141,7 @@ ndindex==1.8
nest-asyncio==1.6.0
networkx==3.4.2
ninja==1.11.1.1
notebook==7.1.3
notebook==7.2.2
notebook_shim==0.2.4
npy-append-array==0.9.16
numba==0.59.1
Expand Down Expand Up @@ -212,7 +212,7 @@ pyro-api==0.1.2
pyro-ppl==1.9.0
python-dateutil==2.9.0.post0
python-json-logger==2.0.7
pytorch-lightning==2.2.4
pytorch-lightning==2.4.0
pytorch-tabnet==4.1.0
pytz==2024.2
pyudorandom==1.0.0
Expand All @@ -221,11 +221,11 @@ pyzmq==26.0.3
qtconsole==5.5.2
QtPy==2.4.1
rapids-dask-dependency==24.2.0
ray==2.10.0
ray==2.43.0
rdkit==2023.9.6
referencing==0.34.0
regex==2023.12.25
requests==2.31.0
requests==2.32.4
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.7.1
Expand Down Expand Up @@ -276,7 +276,7 @@ torchvision==0.15.2
tornado==6.4
tqdm==4.66.5
traitlets==5.14.3
transformers==4.39.2
transformers==4.48.0
triton==2.0.0
types-python-dateutil==2.9.0.20240316
typing_extensions==4.10.0
Expand Down
Loading