Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ jobs:
python-version: '3.10'

- name: Install linters
run: pip install autopep8 flake8
run: pip install autopep8 flake8 isort

- name: Check import sorting with isort
run: isort --check-only --diff eks tests
# Reads config from [tool.isort] in pyproject.toml

- name: Check formatting with autopep8
run: autopep8 --diff --recursive --exit-code eks tests
Expand All @@ -40,6 +44,7 @@ jobs:
echo ""
echo "To fix formatting issues locally, run:"
echo " autopep8 --in-place --recursive eks tests"
echo " isort eks tests"
echo ""
echo "To check for flake8 errors locally, run:"
echo " flake8 eks tests --select=E9,F63,F7,F82"
Expand Down
2 changes: 1 addition & 1 deletion eks/ibl_paw_multicam_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def fit_eks_multicam_ibl_paw(
)

# run eks
camera_dfs, smooth_params_final = ensemble_kalman_smoother_multicam(
camera_dfs, smooth_params_final, df_3d = ensemble_kalman_smoother_multicam(
marker_array=marker_array,
keypoint_names=bodypart_list,
smooth_param=smooth_param,
Expand Down
34 changes: 28 additions & 6 deletions eks/multicam_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def fit_eks_mirrored_multicam(
marker_array = input_dfs_to_markerArray(camera_model_dfs, bodypart_list, camera_names)

# Run the ensemble Kalman smoother for multi-camera data
camera_dfs, smooth_params_final = ensemble_kalman_smoother_multicam(
camera_dfs, smooth_params_final, df_3d = ensemble_kalman_smoother_multicam(
marker_array=marker_array,
keypoint_names=bodypart_list,
smooth_param=smooth_param,
Expand Down Expand Up @@ -135,7 +135,8 @@ def fit_eks_multicam(
inflate_vars: bool = False,
verbose: bool = False,
n_latent: int = 3,
calibration: str | None = None
calibration: str | None = None,
save_3d_outputs: bool = True,
) -> tuple:
"""
Fit the Ensemble Kalman Smoother for un-mirrored multi-camera data.
Expand All @@ -155,13 +156,15 @@ def fit_eks_multicam(
verbose: True to print out details
n_latent: number of dimensions to keep from PCA
calibration: path to the .toml calibration file for nonlinear projection
save_3d_outputs: if True and calibration is not None, save 3D latents to CSV

Returns:
tuple:
camera_dfs (list): List of Output Dataframes
s_finals (list): List of optimized smoothing factors for each keypoint.
input_dfs (list): List of input DataFrames for plotting.
bodypart_list (list): List of body parts used.
df_3d (pd.DataFrame): DataFrame with 3D latent states and posterior variances.

"""
# Load and format input files
Expand All @@ -177,7 +180,7 @@ def fit_eks_multicam(
marker_array = input_dfs_to_markerArray(input_dfs_list, bodypart_list, camera_names)

# Run the ensemble Kalman smoother for multi-camera data
camera_dfs, smooth_params_final = ensemble_kalman_smoother_multicam(
camera_dfs, smooth_params_final, df_3d = ensemble_kalman_smoother_multicam(
marker_array=marker_array,
keypoint_names=bodypart_list,
smooth_param=smooth_param,
Expand All @@ -196,7 +199,9 @@ def fit_eks_multicam(
for c, camera in enumerate(camera_names):
save_filename = f'multicam_{camera}_results.csv'
camera_dfs[c].to_csv(os.path.join(save_dir, save_filename))
return camera_dfs, smooth_params_final, input_dfs_list, bodypart_list
if save_3d_outputs and calibration is not None:
df_3d.to_csv(os.path.join(save_dir, 'multicam_3d_results.csv'))
return camera_dfs, smooth_params_final, input_dfs_list, bodypart_list, df_3d


@typechecked
Expand Down Expand Up @@ -246,7 +251,7 @@ def ensemble_kalman_smoother_multicam(
camgroup: loaded calibration file for nonlinear projection

Returns:
tuple: Dataframes with smoothed predictions, final smoothing parameters.
tuple: Dataframes with smoothed predictions, final smoothing parameters, 3D latent df.
"""

M, V, T, K, _ = marker_array.shape # n_models, n_cameras, n_timesteps, n_keypoints, (n_coords)
Expand Down Expand Up @@ -420,7 +425,24 @@ def ensemble_kalman_smoother_multicam(
camera_df = pd.DataFrame(camera_arr.T, columns=pdindex)
camera_dfs.append(camera_df)

return camera_dfs, s_finals
# Build 3D latent dataframe from Kalman smoother outputs
labels_3d = ['x', 'y', 'z', 'x_posterior_var', 'y_posterior_var', 'z_posterior_var']
pdindex_3d = make_dlc_pandas_index(keypoint_names, labels=labels_3d)
arr_3d = []
for k in range(K):
ms_k = np.array(ms[k]) # (T, 3)
Vs_k = np.array(Vs[k]) # (T, 3, 3)
arr_3d.extend([
ms_k[:, 0],
ms_k[:, 1],
ms_k[:, 2],
Vs_k[:, 0, 0],
Vs_k[:, 1, 1],
Vs_k[:, 2, 2],
])
df_3d = pd.DataFrame(np.asarray(arr_3d).T, columns=pdindex_3d)

return camera_dfs, s_finals, df_3d


def initialize_kalman_filter_pca(
Expand Down
2 changes: 1 addition & 1 deletion scripts/multicam_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
calibration = args.calibration

# Fit EKS using the provided input data
camera_dfs, s_finals, input_dfs, bodypart_list = fit_eks_multicam(
camera_dfs, s_finals, input_dfs, bodypart_list, df_3d = fit_eks_multicam(
input_source=input_source,
save_dir=save_dir,
bodypart_list=bodypart_list,
Expand Down
18 changes: 11 additions & 7 deletions tests/test_multicam_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA

from eks.marker_array import MarkerArray
Expand Down Expand Up @@ -44,7 +45,7 @@ def test_ensemble_kalman_smoother_multicam():
# ---------------------------------------------------
# Run the smoother
# ---------------------------------------------------
camera_dfs, smooth_params_final = ensemble_kalman_smoother_multicam(
camera_dfs, smooth_params_final, df_3d = ensemble_kalman_smoother_multicam(
marker_array=marker_array,
keypoint_names=keypoint_names,
smooth_param=smooth_param,
Expand All @@ -63,11 +64,12 @@ def test_ensemble_kalman_smoother_multicam():
assert smooth_params_final[k] == smooth_param, \
f"Expected smooth_param_final to match input smooth_param ({smooth_param}), " \
f"got {smooth_params_final}"
assert isinstance(df_3d, pd.DataFrame), "Expected output to be a pandas dataframe"

# ---------------------------------------------------
# Run with variance inflation
# ---------------------------------------------------
camera_dfs, smooth_params_final = ensemble_kalman_smoother_multicam(
camera_dfs, smooth_params_final, df_3d = ensemble_kalman_smoother_multicam(
marker_array=marker_array,
keypoint_names=keypoint_names,
smooth_param=smooth_param,
Expand All @@ -88,11 +90,12 @@ def test_ensemble_kalman_smoother_multicam():
assert smooth_params_final[k] == smooth_param, \
f"Expected smooth_param_final to match input smooth_param ({smooth_param}), " \
f"got {smooth_params_final}"
assert isinstance(df_3d, pd.DataFrame), "Expected output to be a pandas dataframe"

# ---------------------------------------------------
# Run with variance inflation + more maha kwargs
# ---------------------------------------------------
camera_dfs, smooth_params_final = ensemble_kalman_smoother_multicam(
camera_dfs, smooth_params_final, df_3d = ensemble_kalman_smoother_multicam(
marker_array=marker_array,
keypoint_names=keypoint_names,
smooth_param=smooth_param,
Expand All @@ -115,6 +118,7 @@ def test_ensemble_kalman_smoother_multicam():
assert smooth_params_final[k] == smooth_param, \
f"Expected smooth_param_final to match input smooth_param ({smooth_param}), " \
f"got {smooth_params_final}"
assert isinstance(df_3d, pd.DataFrame), "Expected output to be a pandas dataframe"

# ---------------------------------------------------
# Run with variance inflation + more maha kwargs
Expand All @@ -136,7 +140,7 @@ def test_ensemble_kalman_smoother_multicam():
markers_array[..., 2] = 1.0
marker_array = MarkerArray(markers_array, data_fields=data_fields)
# run with variance inflation
camera_dfs, smooth_params_final = ensemble_kalman_smoother_multicam(
camera_dfs, smooth_params_final, df_3d = ensemble_kalman_smoother_multicam(
marker_array=marker_array,
keypoint_names=keypoint_names,
smooth_param=smooth_param,
Expand Down Expand Up @@ -174,7 +178,7 @@ def test_ensemble_kalman_smoother_multicam_no_smooth_param():
s_frames = None

# Run the smoother without providing smooth_param
camera_dfs, smooth_params_final = ensemble_kalman_smoother_multicam(
camera_dfs, smooth_params_final, df_3d = ensemble_kalman_smoother_multicam(
marker_array=markerArray,
keypoint_names=keypoint_names,
smooth_param=None,
Expand Down Expand Up @@ -208,8 +212,8 @@ def test_ensemble_kalman_smoother_multicam_n_latent():
quantile_keep_pca = 90
s_frames = None

for n_latent in [2, 3, 5]: # Test different PCA dimensions
camera_dfs, _ = ensemble_kalman_smoother_multicam(
for n_latent in [3, 4, 5]: # Test different PCA dimensions
camera_dfs, _, _ = ensemble_kalman_smoother_multicam(
marker_array=markerArray,
keypoint_names=keypoint_names,
smooth_param=1, # Fixed smooth_param to speed up test
Expand Down