Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
401b33c
Replace scv.pl.scatter and scvelo.plotting.paga with scanpy equivalents
Marius1311 Feb 24, 2026
5b368cf
Address PR review: move helpers, drop perc, add outline, document tra…
Marius1311 Feb 24, 2026
126f248
Fix CI: handle save/show ourselves instead of delegating to scanpy
Marius1311 Feb 24, 2026
922ec88
Remove scVelo references from tests, rename ground truth files
Marius1311 Feb 24, 2026
cda964c
Add scVelo as a test dependency
Marius1311 Feb 24, 2026
c79458c
Regenerate random walk ground truth after PR 1 rebase
Marius1311 Feb 24, 2026
10705c7
Remove from __future__ import annotations from PR 1 files
Marius1311 Feb 24, 2026
0d1cc36
Fix discrete state plot z-order: draw states on top with outlines
Marius1311 Feb 26, 2026
70ad98c
Merge branch 'main' into scvelo/replace-scatter-paga
Marius1311 Feb 26, 2026
6f24709
Fix random walk waypoint marker size (#1303)
Marius1311 Feb 26, 2026
d0f292d
Resolve mere conflict
Marius1311 Mar 2, 2026
dff84b5
Expose basis parameter in plot methods; fix duplicate logging
Marius1311 Mar 2, 2026
9637459
Fix logger stderr output causing red background in Jupyter
Marius1311 Mar 2, 2026
dc27682
Hide NA legend entry for background cells in discrete state plots
Marius1311 Mar 2, 2026
23f073b
Fix color gradient plots: legend placement, point size, dpi
Marius1311 Mar 2, 2026
64382c4
Use pairwise diverging colormaps for fate probability gradients
Marius1311 Mar 2, 2026
3ad5381
Respect legend_loc in circular_projection and time scatter
Marius1311 Mar 2, 2026
4c8907b
Fix variable shadowing in _plot_color_gradients
Marius1311 Mar 2, 2026
e5d5808
Deduplicate pyGAM RuntimeWarnings during gene trend fitting
Marius1311 Mar 2, 2026
de6acd6
Update teams page and maintainers section
Marius1311 Mar 3, 2026
4fc55ac
Merge branch 'main' into scvelo/replace-scatter-paga
Marius1311 Mar 3, 2026
7fe58c1
Update CytoTrace imputation scheme in docs
Marius1311 Mar 5, 2026
fe3e173
Remove internal pyGAM warning suppression
Marius1311 Mar 5, 2026
be689eb
Remove obsolete pyGAM DeprecationWarning filters
Marius1311 Mar 5, 2026
ae2fc30
Fix _fit_bulk not forwarding backend/show_progress_bar to parallelize
Marius1311 Mar 5, 2026
59c9079
Suppress pygam bare print("did not converge") in GAM.fit()
Marius1311 Mar 5, 2026
a09fde7
Revert "Update teams page and maintainers section"
Marius1311 Mar 9, 2026
6bd2fd5
Fix CytoTRACE tests: pass layer='Ms' explicitly
Marius1311 Mar 9, 2026
84230d5
Merge branch 'main' into scvelo/replace-scatter-paga
Marius1311 Mar 9, 2026
9f22ff9
Fix _row_normalize_connectivities to respect conn_key
Marius1311 Mar 9, 2026
04758a6
Remove all alias
Marius1311 Mar 9, 2026
997f986
Fix Schur verbosity test: silence CellRank logger during SLEPc check
Marius1311 Mar 9, 2026
fa40fa6
test: flush capsys before Schur verbosity assertion
Marius1311 Mar 9, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ test = [
"pytest-cov>=6",
"pytest-mock>=3.14",
"pytest-xdist",
"scvelo>=0.3",
]
docs = [
"furo>=2024.8.6",
Expand Down Expand Up @@ -246,5 +247,7 @@ ignore_roles = [
]

[tool.uv]
# Include test deps in `uv sync` so local dev has them (e.g. adjusttext, pytest)
default-groups = [ "dev", "test" ]
# pygpcca 1.0.4 incorrectly pins jinja2==3.0.3; override until next release
override-dependencies = [ "jinja2>=3.1" ]
3 changes: 2 additions & 1 deletion src/cellrank/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def _setup_logger() -> logging.Logger:
"""Set up the ``"cellrank"`` logger with a :class:`~rich.logging.RichHandler`."""
root = logging.getLogger(_LOGGER_NAME)
if not root.handlers:
console = Console(stderr=True, force_terminal=True)
console = Console(stderr=False)
if console.is_jupyter:
console.is_jupyter = False
handler = RichHandler(
Expand All @@ -32,6 +32,7 @@ def _setup_logger() -> logging.Logger:
)
root.addHandler(handler)
root.setLevel(logging.INFO)
root.propagate = False
return root


Expand Down
6 changes: 5 additions & 1 deletion src/cellrank/estimators/mixins/_fate_probabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ def plot_fate_probabilities(
color: str | None = None,
mode: Literal["embedding", "time"] = PlotMode.EMBEDDING,
time_key: str | None = None,
basis: str = "umap",
same_plot: bool = True,
title: str | Sequence[str] | None = None,
cmap: str = "viridis",
Expand All @@ -260,6 +261,8 @@ def plot_fate_probabilities(
Whether to plot the probabilities in an embedding or along the pseudotime.
time_key
Key in :attr:`~anndata.AnnData.obs` where pseudotime is stored. Only used when ``mode = 'time'``.
basis
Key in :attr:`~anndata.AnnData.obsm` for the embedding to use, e.g. ``'umap'`` or ``'tsne'``.
title
Title of the plot.
same_plot
Expand All @@ -268,7 +271,7 @@ def plot_fate_probabilities(
cmap
Colormap for continuous annotations.
kwargs
Keyword arguments for :func:`~scvelo.pl.scatter`.
Keyword arguments for :func:`~scanpy.pl.embedding`.

Returns
-------
Expand All @@ -285,6 +288,7 @@ def plot_fate_probabilities(
color=color,
mode=mode,
time_key=time_key,
basis=basis,
same_plot=same_plot,
title=title,
cmap=cmap,
Expand Down
11 changes: 7 additions & 4 deletions src/cellrank/estimators/mixins/_lineage_drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import numpy as np
import pandas as pd
import scanpy as sc
import scvelo as scv
from anndata import AnnData, Raw
from matplotlib import patheffects, rc_context
from matplotlib.axes import Axes
Expand Down Expand Up @@ -235,7 +234,7 @@ def plot_lineage_drivers(
with the actual values.
%(plotting)s
kwargs
Keyword arguments for :func:`~scvelo.pl.scatter`.
Keyword arguments for :func:`~scanpy.pl.embedding`.

Returns
-------
Expand Down Expand Up @@ -286,13 +285,17 @@ def prepare_format(
)
axes = np.ravel([axes])

basis = kwargs.pop("basis", "umap")
# scvelo compat: "right" means "right margin" in scanpy
if kwargs.get("legend_loc") == "right":
kwargs["legend_loc"] = "right margin"
_i = 0
for _i, (gene, ax) in enumerate(zip(genes.index, axes)):
data = genes.loc[gene]
scv.pl.scatter(
sc.pl.embedding(
self.adata,
basis=basis,
color=gene,
ncols=ncols,
use_raw=use_raw,
ax=ax,
show=False,
Expand Down
113 changes: 83 additions & 30 deletions src/cellrank/estimators/terminal_states/_term_states_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from collections.abc import Sequence
from typing import Any, Literal

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
import scvelo as scv
from anndata import AnnData
from matplotlib.colors import to_hex
from pandas.api.types import infer_dtype
Expand All @@ -25,6 +26,7 @@
_convert_to_categorical_series,
_merge_categorical_series,
_unique_order_preserving,
save_fig,
)
from cellrank.estimators._base_estimator import BaseEstimator
from cellrank.estimators.mixins._utils import (
Expand All @@ -35,6 +37,7 @@
shadow,
)
from cellrank.kernels._base_kernel import KernelExpression
from cellrank.pl._utils import _plot_color_gradients, _plot_time_scatter

logger = logging.getLogger(__name__)
__all__ = ["TermStatesEstimator"]
Expand Down Expand Up @@ -306,6 +309,7 @@ def plot_macrostates(
discrete: bool = True,
mode: Literal["embedding", "time"] = PlotMode.EMBEDDING,
time_key: str = "latent_time",
basis: str = "umap",
same_plot: bool = True,
title: str | Sequence[str] | None = None,
cmap: str = "viridis",
Expand All @@ -332,6 +336,8 @@ def plot_macrostates(
Whether to plot the probabilities in an embedding or along the pseudotime.
time_key
Key in :attr:`~anndata.AnnData.obs` where pseudotime is stored. Only used when ``mode = {m.TIME!r}``.
basis
Key in :attr:`~anndata.AnnData.obsm` for the embedding to use, e.g. ``'umap'`` or ``'tsne'``.
title
Title of the plot.
same_plot
Expand All @@ -340,7 +346,7 @@ def plot_macrostates(
cmap
Colormap for continuous annotations.
kwargs
Keyword arguments for :func:`~scvelo.pl.scatter`.
Keyword arguments for :func:`~scanpy.pl.embedding`.

Returns
-------
Expand Down Expand Up @@ -377,6 +383,7 @@ def plot_macrostates(
_title=name,
states=states,
color=color,
basis=basis,
same_plot=same_plot,
title=title,
cmap=cmap,
Expand All @@ -390,6 +397,7 @@ def plot_macrostates(
color=color,
mode=mode,
time_key=time_key,
basis=basis,
same_plot=same_plot,
title=title,
cmap=cmap,
Expand All @@ -403,6 +411,7 @@ def _plot_discrete(
_title: str | None = None,
states: str | Sequence[str] | None = None,
color: str | None = None,
basis: str = "umap",
title: str | Sequence[str] | None = None,
same_plot: bool = True,
cmap: str = "viridis",
Expand Down Expand Up @@ -434,20 +443,24 @@ def _plot_discrete(

same_plot = same_plot or len(names) == 1
kwargs.setdefault("legend_loc", "on data")
kwargs["color_map"] = cmap
# scvelo compat: "right" means "right margin" in scanpy
if kwargs.get("legend_loc") == "right":
kwargs["legend_loc"] = "right margin"
kwargs.pop("color_map", None)
kwargs.pop("dpi", None) # handled at figure level, not by sc.pl.embedding
save = kwargs.pop("save", None)
show = kwargs.pop("show", None)
kwargs["cmap"] = cmap
basis = kwargs.pop("basis", basis)
size = kwargs.get("size", 120_000 / self.adata.n_obs)

# fmt: off
with RandomKeys(self.adata, n=1 if same_plot else len(states), where="obs") as keys:
if same_plot:
outline = _data.cat.categories.to_list()
_data = _data.cat.add_categories(["nan"]).fillna("nan")
states.append("nan")
color_mapper["nan"] = "#dedede"
self.adata.obs[keys[0]] = _data
self.adata.uns[f"{keys[0]}_colors"] = [color_mapper[name] for name in states]
title = _title if title is None else title
else:
outline = None
for key, cat in zip(keys, states):
self.adata.obs[key] = _data.cat.set_categories([cat])
self.adata.uns[f"{key}_colors"] = [color_mapper[cat]]
Expand All @@ -456,13 +469,43 @@ def _plot_discrete(
if isinstance(title, str):
title = [title]

scv.pl.scatter(
kwargs.setdefault("na_color", "#dedede")
kwargs.setdefault("na_in_legend", False)
axes = sc.pl.embedding(
self.adata,
basis=basis,
color=color + keys,
title=color + title,
add_outline=outline,
show=False,
return_fig=False,
**kwargs,
)

# Overlay state cells with outlines so they appear on top of NaN cells
if same_plot:
axes_list = [axes] if not isinstance(axes, list | np.ndarray) else list(np.ravel(axes))
mask = _data.notna()
if mask.any():
adata_sub = self.adata[mask].copy()
for ax, key in zip(axes_list[len(color):], keys):
ax_title = ax.get_title()
sc.pl.embedding(
adata_sub,
basis=basis,
color=key,
add_outline=True,
show=False,
return_fig=False,
ax=ax,
legend_loc="none",
size=size,
)
Comment on lines +474 to +502
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The states are plotted in the background, and all cells not assigned to a state on top and included in the legend as nan. Using the CellRank pseudotime protocol:

Correct (previously)
Image

Incorrect (now)
Image

I suggest we try not using _add_outline_to_groups and call sc.pl.embedding first with the entire dataset, but no coloring, followed by calling sc.pl.embedding with the data subsetted to the states and plotting them with an outline.

ax.set_title(ax_title)

if save is not None:
save_fig(plt.gcf(), save)
if show is True or (show is None and save is None):
plt.show()
# fmt: on

def _plot_continuous(
Expand All @@ -474,6 +517,7 @@ def _plot_continuous(
color: str | None = None,
mode: Literal["embedding", "time"] = PlotMode.EMBEDDING,
time_key: str = "latent_time",
basis: str = "umap",
title: str | Sequence[str] | None = None,
same_plot: bool = True,
cmap: str = "viridis",
Expand All @@ -487,7 +531,6 @@ def _plot_continuous(
states = _data.names
if not len(states):
raise ValueError("No lineages have been selected.")
is_singleton = _data.shape[1] == 1
_data = _data[states].copy()

if mode == "time" and same_plot:
Expand All @@ -510,6 +553,10 @@ def _plot_continuous(
# fmt: off
color = [] if color is None else (color,) if isinstance(color, str) else color
color = _unique_order_preserving(color)
basis = kwargs.pop("basis", basis)
kwargs.pop("color_map", None)
save = kwargs.pop("save", None)
show = kwargs.pop("show", None)

if mode == PlotMode.TIME:
kwargs.setdefault("legend_loc", "best")
Expand All @@ -525,41 +572,47 @@ def _plot_continuous(
if len(color) and len(color) not in (1, _data_X.shape[1]):
raise ValueError(f"Expected `color` to be of length `1` or `{_data_X.shape[1]}`, "
f"found `{len(color)}`.")
kwargs["x"] = self.adata.obs[time_key]
kwargs["y"] = list(_data_X.T)
kwargs["color"] = color if len(color) else None
kwargs["xlabel"] = [time_key] * len(states)
kwargs["ylabel"] = ["probability"] * len(states)
_plot_time_scatter(
self.adata, self.adata.obs[time_key].values, list(_data_X.T),
color=color if len(color) else None,
title=title, xlabel=time_key, ylabel="probability", cmap=cmap,
save=save, show=show, **kwargs,
)
elif mode == PlotMode.EMBEDDING:
kwargs.setdefault("legend_loc", "on data")
# scvelo compat: "right" means "right margin" in scanpy
if kwargs.get("legend_loc") == "right":
kwargs["legend_loc"] = "right margin"

if same_plot:
if color:
# https://github.com/theislab/scvelo/issues/673
logger.warning("Ignoring `color` when `mode='embedding'` and `same_plot=True`")
title = [_title] if title is None else title
kwargs["color_gradients"] = _data
_plot_color_gradients(self.adata, _data, basis=basis, title=title,
save=save, show=show, **kwargs)
else:
kwargs.pop("dpi", None) # handled at figure level, not by sc.pl.embedding
title = [f"{_title} {state}" for state in states] if title is None else title
if isinstance(title, str):
title = [title]
title = color + title
kwargs["color"] = color + list(_data_X.T)
# Store probability arrays as temp obs columns (scanpy requires column names)
with RandomKeys(self.adata, n=_data_X.shape[1], where="obs") as prob_keys:
for key, col in zip(prob_keys, _data_X.T):
self.adata.obs[key] = col
sc.pl.embedding(
self.adata, basis=basis, color=color + list(prob_keys),
title=title, cmap=cmap, show=False, **kwargs,
)
if save is not None:
save_fig(plt.gcf(), save)
if show is True or (show is None and save is None):
plt.show()
else:
raise NotImplementedError(f"Mode `{mode}` is not yet implemented.")
# fmt: on

# e.g. a stationary distribution
if is_singleton and not np.allclose(_data_X, 1.0):
kwargs.setdefault("perc", [0, 95])
_ = kwargs.pop("color_gradients", None)

scv.pl.scatter(
self.adata,
title=title,
color_map=cmap,
**kwargs,
)

def _set_categorical_labels(
self,
categories: pd.Series | dict[str, Any],
Expand Down
Loading
Loading