Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion beneuro_pose_estimation/anipose/aniposeTools.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,9 +758,10 @@ def run_pose_test(session, test_name = None, cameras=params.default_cameras, for
logger.info("Creating test videos...")
tests_dir = tools.create_test_videos(session, cameras, duration_seconds,
force_new=force_new_videos, start_frame=start_frame)
test_dir = tests_dir / test_name

if test_name is None:
test_name = session + "_test"
test_dir = tests_dir / test_name
# 2. Run 2D predictions on test videos
logger.info("Running 2D predictions...")
sleapTools.get_2Dpredictions(session, cameras, test_name = test_name)
Expand Down Expand Up @@ -822,6 +823,10 @@ def run_pose_test(session, test_name = None, cameras=params.default_cameras, for
except Exception as e:
logger.error(f"Error deleting {tri_file}: {e}")
logging.info(f"Pose estimation completed for {session}.")
return test_dir




except Exception as e:
logger.error(f"Error in pose test for {session}: {e}")
Expand Down
21 changes: 16 additions & 5 deletions beneuro_pose_estimation/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,12 @@ def pose_test(
start_frame: Optional[int] = typer.Option(
None,
"--start-frame", "-s",
help="Frame number to start from. If not specified, uses frame 0."
help="Frame number to start from."
),
duration: Optional[int] = typer.Option(
10,
5,
"--duration", "-d",
help="Duration in seconds. If not specified, uses 100 frames."
help="Duration in seconds."
)
):
"""
Expand All @@ -201,16 +201,27 @@ def pose_test(
if they already exist.
"""
from beneuro_pose_estimation.anipose.aniposeTools import run_pose_test
from beneuro_pose_estimation.evaluation import create_interactive_3d_animation

run_pose_test(
test_dir = run_pose_test(
session=session,
test_name=test_name,
cameras=cameras or params.default_cameras,
force_new_videos=force_new,
start_frame=start_frame,
duration_seconds=duration,
)

csv_path = test_dir/f"{session}_3dpts_angles.csv"
html_path = test_dir / f"{session}_3d_animation.html"

create_interactive_3d_animation(
csv_filepath = str(csv_path),
output_html = str(html_path),
body_parts = params.body_parts,
frame_start = None,
frame_end = None,
)

return

@app.command()
Expand Down
190 changes: 190 additions & 0 deletions beneuro_pose_estimation/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
import json
import seaborn as sns
import sleap
import plotly.express as px

import plotly.graph_objects as go
config = _load_config()
logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -873,6 +876,193 @@ def update(frame):

return anim


def create_interactive_3d_animation(
csv_filepath,
output_html="3d_animation.html",
body_parts=None,
constraints=None,
frame_start=None,
frame_end=None,
height=800,
width=800,
fps=100
):
"""
Build an interactive 3D animation (rotatable!) from a CSV of 3D keypoints,
including skeleton edges drawn between connected keypoints.

- csv must have a 'fnum' column plus for each keypoint columns
'<part>_x','<part>_y','<part>_z'.
- If body_parts is None, will autodetect all bases ending in '_x'.
- constraints: list of [i,j] pairs indexing into body_parts,
defaults to params.constraints.
"""
csv_filepath = Path(csv_filepath)
df = pd.read_csv(csv_filepath)

# optional frame filtering
if frame_start is not None or frame_end is not None:
lo = frame_start or df["fnum"].min()
hi = frame_end or df["fnum"].max() + 1
df = df[(df["fnum"] >= lo) & (df["fnum"] < hi)]

# detect keypoints
if body_parts is None:
body_parts = sorted({col[:-2] for col in df.columns if col.endswith("_x")})

# default skeleton
if constraints is None:
constraints = params.constraints

# melt into long form
long = []
for bp in body_parts:
long.append(
df[["fnum", f"{bp}_x", f"{bp}_y", f"{bp}_z"]]
.rename(columns={f"{bp}_x":"x", f"{bp}_y":"y", f"{bp}_z":"z"})
.assign(keypoint=bp)
)
long_df = pd.concat(long, axis=0)

# build animated 3D scatter
fig = px.scatter_3d(
long_df,
x="x", y="y", z="z",
color="keypoint",
animation_frame="fnum",
height=height,
width=width,
title="3D Pose Animation",
labels={"fnum":"Frame"}
)
fig.update_traces(marker=dict(size=5))
fig.layout.updatemenus[0].buttons[0].args[1]["frame"]["duration"] = 1000 / fps

# --- STEP 1: add one empty line‐trace per constraint into fig.data
n_kp = len(fig.data)
for _ in constraints:
fig.add_trace(
go.Scatter3d(
x=[], y=[], z=[],
mode="lines",
line=dict(color="black", width=2),
showlegend=False
)
)
# now fig.data length = n_kp + len(constraints)

# --- STEP 2: for each frame, extend its data tuple by those same traces
for frame in fig.frames:
orig = list(frame.data) # the n_kp scatter traces
placeholders = fig.data[n_kp:] # the newly appended line traces
frame.data = tuple(orig + list(placeholders))

# now fill in each line‐trace for this frame
fnum = int(frame.name)
sub = long_df[long_df["fnum"] == fnum]
coords = {r.keypoint:(r.x, r.y, r.z) for r in sub.itertuples(index=False)}
# for each constraint, set the x/y/z on the placeholder trace
for idx, (i, j) in enumerate(constraints):
p1 = coords.get(body_parts[i])
p2 = coords.get(body_parts[j])
if p1 and p2:
xs, ys, zs = [p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]]
else:
xs = ys = zs = []
# placeholder trace lives at position n_kp + idx
tr = frame.data[n_kp + idx]
tr.x = xs
tr.y = ys
tr.z = zs

# --- STEP 3: likewise update the initial (static) traces so lines show before animating
init = list(fig.data)
init_sub = long_df[long_df["fnum"] == int(fig.frames[0].name)]
init_coords = {r.keypoint:(r.x, r.y, r.z) for r in init_sub.itertuples(index=False)}
for idx, (i, j) in enumerate(constraints):
p1 = init_coords.get(body_parts[i])
p2 = init_coords.get(body_parts[j])
if p1 and p2:
xs, ys, zs = [p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]]
else:
xs = ys = zs = []
trace = init[n_kp + idx]
trace.x = xs
trace.y = ys
trace.z = zs

fig.data = tuple(init)

# save
output_html_name = Path(output_html).name
fig.write_html(output_html, include_plotlyjs="cdn")
logger.info(f"Interactive HTML written to: {output_html}")
logger.info(f"=============================================================\nRun on your local machine:\n\n1. scp user-name@bn-brainX.bg.ic.ac.uk:{output_html} . (to save it locally in the current directory)\n\n2. open {output_html_name} (to visualize the animation in your browser)")

return fig
def create_interactive_3d_animation_old(
csv_filepath,
output_html="3d_animation.html",
body_parts=None,
frame_start=None,
frame_end=None,
height=800,
width=800,
):
"""
Build an interactive 3D animation (rotatable!) from a CSV of 3D keypoints.

- csv must have a 'fnum' column plus for each keypoint columns
'<part>_x','<part>_y','<part>_z'.
- If body_parts is None, will autodetect all bases ending in '_x'.

Writes out `output_html` which you can open in your browser.
"""
csv_filepath = Path(csv_filepath)
df = pd.read_csv(csv_filepath)

# filter frames
if frame_start is not None or frame_end is not None:
frame_start = frame_start or 0
frame_end = frame_end or len(df)
df = df[(df["fnum"] >= frame_start) & (df["fnum"] < frame_end)]

# detect keypoints
if body_parts is None:
# find all columns ending in '_x'
body_parts = sorted({col[:-2] for col in df.columns if col.endswith("_x")})

# melt into long form
long = []
for bp in body_parts:
long.append(
df[["fnum", f"{bp}_x", f"{bp}_y", f"{bp}_z"]]
.rename(columns={f"{bp}_x":"x", f"{bp}_y":"y", f"{bp}_z":"z"})
.assign(keypoint=bp)
)
long_df = pd.concat(long, axis=0)

# build animated 3D scatter
fig = px.scatter_3d(
long_df,
x="x", y="y", z="z",
color="keypoint",
animation_frame="fnum",
height=height,
width=width,
title="3D Pose Animation",
labels={"fnum":"Frame"}
)
fig.update_traces(marker=dict(size=5))
# tighten up the sliders/buttons
fig.layout.updatemenus[0].buttons[0].args[1]["frame"]["duration"] = 1000/30 # default ~30 FPS

# save out
fig.write_html(output_html, include_plotlyjs="cdn")
logger.info(f"Interactive HTML written to: {output_html}")
return fig

def plot_reprojection_errors(session_name, test_dir, bins=50):
all_errors = get_reprojection_errors(session_name, test_dir)
flat_errors = all_errors.flatten()
Expand Down
Loading