From 97cd3c8e740dc7d0557938419c36de895a5974e4 Mon Sep 17 00:00:00 2001 From: ioanalzr Date: Tue, 27 May 2025 13:51:34 +0100 Subject: [PATCH 1/2] debug --- beneuro_pose_estimation/anipose/aniposeTools.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/beneuro_pose_estimation/anipose/aniposeTools.py b/beneuro_pose_estimation/anipose/aniposeTools.py index 819f451..e0cd125 100644 --- a/beneuro_pose_estimation/anipose/aniposeTools.py +++ b/beneuro_pose_estimation/anipose/aniposeTools.py @@ -758,9 +758,10 @@ def run_pose_test(session, test_name = None, cameras=params.default_cameras, for logger.info("Creating test videos...") tests_dir = tools.create_test_videos(session, cameras, duration_seconds, force_new=force_new_videos, start_frame=start_frame) - test_dir = tests_dir / test_name + if test_name is None: test_name = session + "_test" + test_dir = tests_dir / test_name # 2. Run 2D predictions on test videos logger.info("Running 2D predictions...") sleapTools.get_2Dpredictions(session, cameras, test_name = test_name) From 811ac8912c74fc9f5b6c0feca0d9678619b431ce Mon Sep 17 00:00:00 2001 From: ioanalzr Date: Tue, 27 May 2025 16:12:50 +0100 Subject: [PATCH 2/2] added interactive animation for eval --- .../anipose/aniposeTools.py | 4 + beneuro_pose_estimation/cli.py | 21 +- beneuro_pose_estimation/evaluation.py | 190 ++++++++++++++++++ 3 files changed, 210 insertions(+), 5 deletions(-) diff --git a/beneuro_pose_estimation/anipose/aniposeTools.py b/beneuro_pose_estimation/anipose/aniposeTools.py index e0cd125..f7dc5d9 100644 --- a/beneuro_pose_estimation/anipose/aniposeTools.py +++ b/beneuro_pose_estimation/anipose/aniposeTools.py @@ -823,6 +823,10 @@ def run_pose_test(session, test_name = None, cameras=params.default_cameras, for except Exception as e: logger.error(f"Error deleting {tri_file}: {e}") logging.info(f"Pose estimation completed for {session}.") + return test_dir + + + except Exception as e: logger.error(f"Error in pose test for {session}: {e}") diff --git a/beneuro_pose_estimation/cli.py b/beneuro_pose_estimation/cli.py index 8c6d89b..4112522 100644 --- a/beneuro_pose_estimation/cli.py +++ b/beneuro_pose_estimation/cli.py @@ -185,12 +185,12 @@ def pose_test( start_frame: Optional[int] = typer.Option( None, "--start-frame", "-s", - help="Frame number to start from. If not specified, uses frame 0." + help="Frame number to start from." ), duration: Optional[int] = typer.Option( - 10, + 5, "--duration", "-d", - help="Duration in seconds. If not specified, uses 100 frames." + help="Duration in seconds." ) ): """ @@ -201,8 +201,9 @@ def pose_test( if they already exist. """ from beneuro_pose_estimation.anipose.aniposeTools import run_pose_test + from beneuro_pose_estimation.evaluation import create_interactive_3d_animation - run_pose_test( + test_dir = run_pose_test( session=session, test_name=test_name, cameras=cameras or params.default_cameras, @@ -210,7 +211,17 @@ def pose_test( start_frame=start_frame, duration_seconds=duration, ) - + csv_path = test_dir/f"{session}_3dpts_angles.csv" + html_path = test_dir / f"{session}_3d_animation.html" + + create_interactive_3d_animation( + csv_filepath = str(csv_path), + output_html = str(html_path), + body_parts = params.body_parts, + frame_start = None, + frame_end = None, + ) + return @app.command() diff --git a/beneuro_pose_estimation/evaluation.py b/beneuro_pose_estimation/evaluation.py index cb33987..508f160 100644 --- a/beneuro_pose_estimation/evaluation.py +++ b/beneuro_pose_estimation/evaluation.py @@ -24,6 +24,9 @@ import json import seaborn as sns import sleap +import plotly.express as px + +import plotly.graph_objects as go config = _load_config() logger = logging.getLogger(__name__) @@ -873,6 +876,193 @@ def update(frame): return anim + +def create_interactive_3d_animation( + csv_filepath, + output_html="3d_animation.html", + body_parts=None, + constraints=None, + frame_start=None, + frame_end=None, + height=800, + width=800, + fps=100 +): + """ + Build an interactive 3D animation (rotatable!) from a CSV of 3D keypoints, + including skeleton edges drawn between connected keypoints. + + - csv must have a 'fnum' column plus for each keypoint columns + '_x','_y','_z'. + - If body_parts is None, will autodetect all bases ending in '_x'. + - constraints: list of [i,j] pairs indexing into body_parts, + defaults to params.constraints. + """ + csv_filepath = Path(csv_filepath) + df = pd.read_csv(csv_filepath) + + # optional frame filtering + if frame_start is not None or frame_end is not None: + lo = frame_start or df["fnum"].min() + hi = frame_end or df["fnum"].max() + 1 + df = df[(df["fnum"] >= lo) & (df["fnum"] < hi)] + + # detect keypoints + if body_parts is None: + body_parts = sorted({col[:-2] for col in df.columns if col.endswith("_x")}) + + # default skeleton + if constraints is None: + constraints = params.constraints + + # melt into long form + long = [] + for bp in body_parts: + long.append( + df[["fnum", f"{bp}_x", f"{bp}_y", f"{bp}_z"]] + .rename(columns={f"{bp}_x":"x", f"{bp}_y":"y", f"{bp}_z":"z"}) + .assign(keypoint=bp) + ) + long_df = pd.concat(long, axis=0) + + # build animated 3D scatter + fig = px.scatter_3d( + long_df, + x="x", y="y", z="z", + color="keypoint", + animation_frame="fnum", + height=height, + width=width, + title="3D Pose Animation", + labels={"fnum":"Frame"} + ) + fig.update_traces(marker=dict(size=5)) + fig.layout.updatemenus[0].buttons[0].args[1]["frame"]["duration"] = 1000 / fps + + # --- STEP 1: add one empty line‐trace per constraint into fig.data + n_kp = len(fig.data) + for _ in constraints: + fig.add_trace( + go.Scatter3d( + x=[], y=[], z=[], + mode="lines", + line=dict(color="black", width=2), + showlegend=False + ) + ) + # now fig.data length = n_kp + len(constraints) + + # --- STEP 2: for each frame, extend its data tuple by those same traces + for frame in fig.frames: + orig = list(frame.data) # the n_kp scatter traces + placeholders = fig.data[n_kp:] # the newly appended line traces + frame.data = tuple(orig + list(placeholders)) + + # now fill in each line‐trace for this frame + fnum = int(frame.name) + sub = long_df[long_df["fnum"] == fnum] + coords = {r.keypoint:(r.x, r.y, r.z) for r in sub.itertuples(index=False)} + # for each constraint, set the x/y/z on the placeholder trace + for idx, (i, j) in enumerate(constraints): + p1 = coords.get(body_parts[i]) + p2 = coords.get(body_parts[j]) + if p1 and p2: + xs, ys, zs = [p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]] + else: + xs = ys = zs = [] + # placeholder trace lives at position n_kp + idx + tr = frame.data[n_kp + idx] + tr.x = xs + tr.y = ys + tr.z = zs + + # --- STEP 3: likewise update the initial (static) traces so lines show before animating + init = list(fig.data) + init_sub = long_df[long_df["fnum"] == int(fig.frames[0].name)] + init_coords = {r.keypoint:(r.x, r.y, r.z) for r in init_sub.itertuples(index=False)} + for idx, (i, j) in enumerate(constraints): + p1 = init_coords.get(body_parts[i]) + p2 = init_coords.get(body_parts[j]) + if p1 and p2: + xs, ys, zs = [p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]] + else: + xs = ys = zs = [] + trace = init[n_kp + idx] + trace.x = xs + trace.y = ys + trace.z = zs + + fig.data = tuple(init) + + # save + output_html_name = Path(output_html).name + fig.write_html(output_html, include_plotlyjs="cdn") + logger.info(f"Interactive HTML written to: {output_html}") + logger.info(f"=============================================================\nRun on your local machine:\n\n1. scp user-name@bn-brainX.bg.ic.ac.uk:{output_html} . (to save it locally in the current directory)\n\n2. open {output_html_name} (to visualize the animation in your browser)") + + return fig +def create_interactive_3d_animation_old( + csv_filepath, + output_html="3d_animation.html", + body_parts=None, + frame_start=None, + frame_end=None, + height=800, + width=800, +): + """ + Build an interactive 3D animation (rotatable!) from a CSV of 3D keypoints. + + - csv must have a 'fnum' column plus for each keypoint columns + '_x','_y','_z'. + - If body_parts is None, will autodetect all bases ending in '_x'. + + Writes out `output_html` which you can open in your browser. + """ + csv_filepath = Path(csv_filepath) + df = pd.read_csv(csv_filepath) + + # filter frames + if frame_start is not None or frame_end is not None: + frame_start = frame_start or 0 + frame_end = frame_end or len(df) + df = df[(df["fnum"] >= frame_start) & (df["fnum"] < frame_end)] + + # detect keypoints + if body_parts is None: + # find all columns ending in '_x' + body_parts = sorted({col[:-2] for col in df.columns if col.endswith("_x")}) + + # melt into long form + long = [] + for bp in body_parts: + long.append( + df[["fnum", f"{bp}_x", f"{bp}_y", f"{bp}_z"]] + .rename(columns={f"{bp}_x":"x", f"{bp}_y":"y", f"{bp}_z":"z"}) + .assign(keypoint=bp) + ) + long_df = pd.concat(long, axis=0) + + # build animated 3D scatter + fig = px.scatter_3d( + long_df, + x="x", y="y", z="z", + color="keypoint", + animation_frame="fnum", + height=height, + width=width, + title="3D Pose Animation", + labels={"fnum":"Frame"} + ) + fig.update_traces(marker=dict(size=5)) + # tighten up the sliders/buttons + fig.layout.updatemenus[0].buttons[0].args[1]["frame"]["duration"] = 1000/30 # default ~30 FPS + + # save out + fig.write_html(output_html, include_plotlyjs="cdn") + logger.info(f"Interactive HTML written to: {output_html}") + return fig + def plot_reprojection_errors(session_name, test_dir, bins=50): all_errors = get_reprojection_errors(session_name, test_dir) flat_errors = all_errors.flatten()