-
Notifications
You must be signed in to change notification settings - Fork 54
Replace scv.pl.scatter and scvelo paga with scanpy equivalents #1302
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
401b33c
5b368cf
126f248
922ec88
cda964c
c79458c
10705c7
0d1cc36
70ad98c
6f24709
d0f292d
dff84b5
9637459
dc27682
23f073b
64382c4
3ad5381
4c8907b
e5d5808
de6acd6
4fc55ac
7fe58c1
fe3e173
be689eb
ae2fc30
59c9079
a09fde7
6bd2fd5
84230d5
9f22ff9
04758a6
997f986
fa40fa6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,10 +4,11 @@ | |
| from collections.abc import Sequence | ||
| from typing import Any, Literal | ||
|
|
||
| import matplotlib.pyplot as plt | ||
| import numpy as np | ||
| import pandas as pd | ||
| import scanpy as sc | ||
| import scipy.sparse as sp | ||
| import scvelo as scv | ||
| from anndata import AnnData | ||
| from matplotlib.colors import to_hex | ||
| from pandas.api.types import infer_dtype | ||
|
|
@@ -25,6 +26,7 @@ | |
| _convert_to_categorical_series, | ||
| _merge_categorical_series, | ||
| _unique_order_preserving, | ||
| save_fig, | ||
| ) | ||
| from cellrank.estimators._base_estimator import BaseEstimator | ||
| from cellrank.estimators.mixins._utils import ( | ||
|
|
@@ -35,6 +37,7 @@ | |
| shadow, | ||
| ) | ||
| from cellrank.kernels._base_kernel import KernelExpression | ||
| from cellrank.pl._utils import _plot_color_gradients, _plot_time_scatter | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
| __all__ = ["TermStatesEstimator"] | ||
|
|
@@ -306,6 +309,7 @@ def plot_macrostates( | |
| discrete: bool = True, | ||
| mode: Literal["embedding", "time"] = PlotMode.EMBEDDING, | ||
| time_key: str = "latent_time", | ||
| basis: str = "umap", | ||
| same_plot: bool = True, | ||
| title: str | Sequence[str] | None = None, | ||
| cmap: str = "viridis", | ||
|
|
@@ -332,6 +336,8 @@ def plot_macrostates( | |
| Whether to plot the probabilities in an embedding or along the pseudotime. | ||
| time_key | ||
| Key in :attr:`~anndata.AnnData.obs` where pseudotime is stored. Only used when ``mode = {m.TIME!r}``. | ||
| basis | ||
| Key in :attr:`~anndata.AnnData.obsm` for the embedding to use, e.g. ``'umap'`` or ``'tsne'``. | ||
| title | ||
| Title of the plot. | ||
| same_plot | ||
|
|
@@ -340,7 +346,7 @@ def plot_macrostates( | |
| cmap | ||
| Colormap for continuous annotations. | ||
| kwargs | ||
| Keyword arguments for :func:`~scvelo.pl.scatter`. | ||
| Keyword arguments for :func:`~scanpy.pl.embedding`. | ||
|
|
||
| Returns | ||
| ------- | ||
|
|
@@ -377,6 +383,7 @@ def plot_macrostates( | |
| _title=name, | ||
| states=states, | ||
| color=color, | ||
| basis=basis, | ||
| same_plot=same_plot, | ||
| title=title, | ||
| cmap=cmap, | ||
|
|
@@ -390,6 +397,7 @@ def plot_macrostates( | |
| color=color, | ||
| mode=mode, | ||
| time_key=time_key, | ||
| basis=basis, | ||
| same_plot=same_plot, | ||
| title=title, | ||
| cmap=cmap, | ||
|
|
@@ -403,6 +411,7 @@ def _plot_discrete( | |
| _title: str | None = None, | ||
| states: str | Sequence[str] | None = None, | ||
| color: str | None = None, | ||
| basis: str = "umap", | ||
| title: str | Sequence[str] | None = None, | ||
| same_plot: bool = True, | ||
| cmap: str = "viridis", | ||
|
|
@@ -434,20 +443,24 @@ def _plot_discrete( | |
|
|
||
| same_plot = same_plot or len(names) == 1 | ||
| kwargs.setdefault("legend_loc", "on data") | ||
| kwargs["color_map"] = cmap | ||
| # scvelo compat: "right" means "right margin" in scanpy | ||
| if kwargs.get("legend_loc") == "right": | ||
| kwargs["legend_loc"] = "right margin" | ||
| kwargs.pop("color_map", None) | ||
| kwargs.pop("dpi", None) # handled at figure level, not by sc.pl.embedding | ||
| save = kwargs.pop("save", None) | ||
| show = kwargs.pop("show", None) | ||
| kwargs["cmap"] = cmap | ||
| basis = kwargs.pop("basis", basis) | ||
| size = kwargs.get("size", 120_000 / self.adata.n_obs) | ||
|
|
||
| # fmt: off | ||
| with RandomKeys(self.adata, n=1 if same_plot else len(states), where="obs") as keys: | ||
| if same_plot: | ||
| outline = _data.cat.categories.to_list() | ||
| _data = _data.cat.add_categories(["nan"]).fillna("nan") | ||
| states.append("nan") | ||
| color_mapper["nan"] = "#dedede" | ||
| self.adata.obs[keys[0]] = _data | ||
| self.adata.uns[f"{keys[0]}_colors"] = [color_mapper[name] for name in states] | ||
| title = _title if title is None else title | ||
| else: | ||
| outline = None | ||
| for key, cat in zip(keys, states): | ||
| self.adata.obs[key] = _data.cat.set_categories([cat]) | ||
| self.adata.uns[f"{key}_colors"] = [color_mapper[cat]] | ||
|
|
@@ -456,13 +469,43 @@ def _plot_discrete( | |
| if isinstance(title, str): | ||
| title = [title] | ||
|
|
||
| scv.pl.scatter( | ||
| kwargs.setdefault("na_color", "#dedede") | ||
| kwargs.setdefault("na_in_legend", False) | ||
| axes = sc.pl.embedding( | ||
| self.adata, | ||
| basis=basis, | ||
| color=color + keys, | ||
| title=color + title, | ||
| add_outline=outline, | ||
| show=False, | ||
| return_fig=False, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| # Overlay state cells with outlines so they appear on top of NaN cells | ||
| if same_plot: | ||
| axes_list = [axes] if not isinstance(axes, list | np.ndarray) else list(np.ravel(axes)) | ||
| mask = _data.notna() | ||
| if mask.any(): | ||
| adata_sub = self.adata[mask].copy() | ||
| for ax, key in zip(axes_list[len(color):], keys): | ||
| ax_title = ax.get_title() | ||
| sc.pl.embedding( | ||
| adata_sub, | ||
| basis=basis, | ||
| color=key, | ||
| add_outline=True, | ||
| show=False, | ||
| return_fig=False, | ||
| ax=ax, | ||
| legend_loc="none", | ||
| size=size, | ||
| ) | ||
|
Comment on lines
+474
to
+502
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The states are plotted in the background, and all cells not assigned to a state on top and included in the legend as nan. Using the CellRank pseudotime protocol: I suggest we try not using |
||
| ax.set_title(ax_title) | ||
|
|
||
| if save is not None: | ||
| save_fig(plt.gcf(), save) | ||
| if show is True or (show is None and save is None): | ||
| plt.show() | ||
| # fmt: on | ||
|
|
||
| def _plot_continuous( | ||
|
|
@@ -474,6 +517,7 @@ def _plot_continuous( | |
| color: str | None = None, | ||
| mode: Literal["embedding", "time"] = PlotMode.EMBEDDING, | ||
| time_key: str = "latent_time", | ||
| basis: str = "umap", | ||
| title: str | Sequence[str] | None = None, | ||
| same_plot: bool = True, | ||
| cmap: str = "viridis", | ||
|
|
@@ -487,7 +531,6 @@ def _plot_continuous( | |
| states = _data.names | ||
| if not len(states): | ||
| raise ValueError("No lineages have been selected.") | ||
| is_singleton = _data.shape[1] == 1 | ||
| _data = _data[states].copy() | ||
|
|
||
| if mode == "time" and same_plot: | ||
|
|
@@ -510,6 +553,10 @@ def _plot_continuous( | |
| # fmt: off | ||
| color = [] if color is None else (color,) if isinstance(color, str) else color | ||
| color = _unique_order_preserving(color) | ||
| basis = kwargs.pop("basis", basis) | ||
| kwargs.pop("color_map", None) | ||
| save = kwargs.pop("save", None) | ||
| show = kwargs.pop("show", None) | ||
|
|
||
| if mode == PlotMode.TIME: | ||
| kwargs.setdefault("legend_loc", "best") | ||
|
|
@@ -525,41 +572,47 @@ def _plot_continuous( | |
| if len(color) and len(color) not in (1, _data_X.shape[1]): | ||
| raise ValueError(f"Expected `color` to be of length `1` or `{_data_X.shape[1]}`, " | ||
| f"found `{len(color)}`.") | ||
| kwargs["x"] = self.adata.obs[time_key] | ||
| kwargs["y"] = list(_data_X.T) | ||
| kwargs["color"] = color if len(color) else None | ||
| kwargs["xlabel"] = [time_key] * len(states) | ||
| kwargs["ylabel"] = ["probability"] * len(states) | ||
| _plot_time_scatter( | ||
| self.adata, self.adata.obs[time_key].values, list(_data_X.T), | ||
| color=color if len(color) else None, | ||
| title=title, xlabel=time_key, ylabel="probability", cmap=cmap, | ||
| save=save, show=show, **kwargs, | ||
| ) | ||
| elif mode == PlotMode.EMBEDDING: | ||
| kwargs.setdefault("legend_loc", "on data") | ||
| # scvelo compat: "right" means "right margin" in scanpy | ||
| if kwargs.get("legend_loc") == "right": | ||
| kwargs["legend_loc"] = "right margin" | ||
|
|
||
| if same_plot: | ||
| if color: | ||
| # https://github.com/theislab/scvelo/issues/673 | ||
| logger.warning("Ignoring `color` when `mode='embedding'` and `same_plot=True`") | ||
| title = [_title] if title is None else title | ||
| kwargs["color_gradients"] = _data | ||
| _plot_color_gradients(self.adata, _data, basis=basis, title=title, | ||
| save=save, show=show, **kwargs) | ||
| else: | ||
| kwargs.pop("dpi", None) # handled at figure level, not by sc.pl.embedding | ||
| title = [f"{_title} {state}" for state in states] if title is None else title | ||
| if isinstance(title, str): | ||
| title = [title] | ||
| title = color + title | ||
| kwargs["color"] = color + list(_data_X.T) | ||
| # Store probability arrays as temp obs columns (scanpy requires column names) | ||
| with RandomKeys(self.adata, n=_data_X.shape[1], where="obs") as prob_keys: | ||
| for key, col in zip(prob_keys, _data_X.T): | ||
| self.adata.obs[key] = col | ||
| sc.pl.embedding( | ||
| self.adata, basis=basis, color=color + list(prob_keys), | ||
| title=title, cmap=cmap, show=False, **kwargs, | ||
| ) | ||
| if save is not None: | ||
| save_fig(plt.gcf(), save) | ||
| if show is True or (show is None and save is None): | ||
| plt.show() | ||
| else: | ||
| raise NotImplementedError(f"Mode `{mode}` is not yet implemented.") | ||
| # fmt: on | ||
|
|
||
| # e.g. a stationary distribution | ||
| if is_singleton and not np.allclose(_data_X, 1.0): | ||
| kwargs.setdefault("perc", [0, 95]) | ||
| _ = kwargs.pop("color_gradients", None) | ||
|
|
||
| scv.pl.scatter( | ||
| self.adata, | ||
| title=title, | ||
| color_map=cmap, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| def _set_categorical_labels( | ||
| self, | ||
| categories: pd.Series | dict[str, Any], | ||
|
|
||


Uh oh!
There was an error while loading. Please reload this page.