Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 8 additions & 11 deletions examples/fashion-mnist-parallel-coords-6d.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def _():
from wigglystuff import ParallelCoordinates

from anywidget_vector import VectorSpace

return PCA, ParallelCoordinates, VectorSpace, fetch_openml, mo, np, pl, plt


Expand Down Expand Up @@ -244,10 +245,12 @@ def _(LABEL_COLORS, ParallelCoordinates, df, mo, vs):
_indices = sorted(int(sid.split("_")[1]) for sid in _selected)
_filtered_df = df[_indices]
_par = ParallelCoordinates(_filtered_df, color_by="label", color_map=LABEL_COLORS)
mo.vstack([
mo.md(f"**{len(_indices)}** points selected in 3D view"),
mo.ui.anywidget(_par),
])
mo.vstack(
[
mo.md(f"**{len(_indices)}** points selected in 3D view"),
mo.ui.anywidget(_par),
]
)
else:
mo.md("*Lasso or box-select points in the 3D view to filter*")
return
Expand All @@ -258,13 +261,7 @@ def _(LABEL_COLORS, idx, images, label_names, labels, np, plt, vs, widget):
_selected_ids = set(vs.widget.selected_points or [])
_filtered = widget.widget.filtered_indices

if _selected_ids:
_show = [
_i for _i in range(len(idx))
if f"p_{_i}" in _selected_ids
][:10]
else:
_show = list(_filtered[:10])
_show = [_i for _i in range(len(idx)) if f"p_{_i}" in _selected_ids][:10] if _selected_ids else list(_filtered[:10])

_sample_idx = np.array(_show) if len(_show) > 0 else np.array([0])

Expand Down
Loading