diff --git a/examples/fashion-mnist-parallel-coords-6d.py b/examples/fashion-mnist-parallel-coords-6d.py index 64d51c2..6a691f9 100644 --- a/examples/fashion-mnist-parallel-coords-6d.py +++ b/examples/fashion-mnist-parallel-coords-6d.py @@ -34,6 +34,7 @@ def _(): from wigglystuff import ParallelCoordinates from anywidget_vector import VectorSpace + return PCA, ParallelCoordinates, VectorSpace, fetch_openml, mo, np, pl, plt @@ -244,10 +245,12 @@ def _(LABEL_COLORS, ParallelCoordinates, df, mo, vs): _indices = sorted(int(sid.split("_")[1]) for sid in _selected) _filtered_df = df[_indices] _par = ParallelCoordinates(_filtered_df, color_by="label", color_map=LABEL_COLORS) - mo.vstack([ - mo.md(f"**{len(_indices)}** points selected in 3D view"), - mo.ui.anywidget(_par), - ]) + mo.vstack( + [ + mo.md(f"**{len(_indices)}** points selected in 3D view"), + mo.ui.anywidget(_par), + ] + ) else: mo.md("*Lasso or box-select points in the 3D view to filter*") return @@ -258,13 +261,7 @@ def _(LABEL_COLORS, idx, images, label_names, labels, np, plt, vs, widget): _selected_ids = set(vs.widget.selected_points or []) _filtered = widget.widget.filtered_indices - if _selected_ids: - _show = [ - _i for _i in range(len(idx)) - if f"p_{_i}" in _selected_ids - ][:10] - else: - _show = list(_filtered[:10]) + _show = [_i for _i in range(len(idx)) if f"p_{_i}" in _selected_ids][:10] if _selected_ids else list(_filtered[:10]) _sample_idx = np.array(_show) if len(_show) > 0 else np.array([0])