diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index edb7036..c0d5e07 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -22,7 +22,11 @@ jobs: python-version: '3.10' - name: Install linters - run: pip install autopep8 flake8 + run: pip install autopep8 flake8 isort + + - name: Check import sorting with isort + run: isort --check-only --diff eks tests + # Reads config from [tool.isort] in pyproject.toml - name: Check formatting with autopep8 run: autopep8 --diff --recursive --exit-code eks tests @@ -40,6 +44,7 @@ jobs: echo "" echo "To fix formatting issues locally, run:" echo " autopep8 --in-place --recursive eks tests" + echo " isort eks tests" echo "" echo "To check for flake8 errors locally, run:" echo " flake8 eks tests --select=E9,F63,F7,F82" diff --git a/eks/ibl_paw_multicam_smoother.py b/eks/ibl_paw_multicam_smoother.py index faca898..c2c1d63 100644 --- a/eks/ibl_paw_multicam_smoother.py +++ b/eks/ibl_paw_multicam_smoother.py @@ -174,7 +174,7 @@ def fit_eks_multicam_ibl_paw( ) # run eks - camera_dfs, smooth_params_final = ensemble_kalman_smoother_multicam( + camera_dfs, smooth_params_final, df_3d = ensemble_kalman_smoother_multicam( marker_array=marker_array, keypoint_names=bodypart_list, smooth_param=smooth_param, diff --git a/eks/multicam_smoother.py b/eks/multicam_smoother.py index 4f0c4fb..7ad122a 100644 --- a/eks/multicam_smoother.py +++ b/eks/multicam_smoother.py @@ -93,7 +93,7 @@ def fit_eks_mirrored_multicam( marker_array = input_dfs_to_markerArray(camera_model_dfs, bodypart_list, camera_names) # Run the ensemble Kalman smoother for multi-camera data - camera_dfs, smooth_params_final = ensemble_kalman_smoother_multicam( + camera_dfs, smooth_params_final, df_3d = ensemble_kalman_smoother_multicam( marker_array=marker_array, keypoint_names=bodypart_list, smooth_param=smooth_param, @@ -135,7 +135,8 @@ def fit_eks_multicam( inflate_vars: bool = False, verbose: bool = False, n_latent: int = 3, - calibration: str | None = None + calibration: str | None = None, + save_3d_outputs: bool = True, ) -> tuple: """ Fit the Ensemble Kalman Smoother for un-mirrored multi-camera data. @@ -155,6 +156,7 @@ def fit_eks_multicam( verbose: True to print out details n_latent: number of dimensions to keep from PCA calibration: path to the .toml calibration file for nonlinear projection + save_3d_outputs: if True and calibration is not None, save 3D latents to CSV Returns: tuple: @@ -162,6 +164,7 @@ def fit_eks_multicam( s_finals (list): List of optimized smoothing factors for each keypoint. input_dfs (list): List of input DataFrames for plotting. bodypart_list (list): List of body parts used. + df_3d (pd.DataFrame): DataFrame with 3D latent states and posterior variances. """ # Load and format input files @@ -177,7 +180,7 @@ def fit_eks_multicam( marker_array = input_dfs_to_markerArray(input_dfs_list, bodypart_list, camera_names) # Run the ensemble Kalman smoother for multi-camera data - camera_dfs, smooth_params_final = ensemble_kalman_smoother_multicam( + camera_dfs, smooth_params_final, df_3d = ensemble_kalman_smoother_multicam( marker_array=marker_array, keypoint_names=bodypart_list, smooth_param=smooth_param, @@ -196,7 +199,9 @@ def fit_eks_multicam( for c, camera in enumerate(camera_names): save_filename = f'multicam_{camera}_results.csv' camera_dfs[c].to_csv(os.path.join(save_dir, save_filename)) - return camera_dfs, smooth_params_final, input_dfs_list, bodypart_list + if save_3d_outputs and calibration is not None: + df_3d.to_csv(os.path.join(save_dir, 'multicam_3d_results.csv')) + return camera_dfs, smooth_params_final, input_dfs_list, bodypart_list, df_3d @typechecked @@ -246,7 +251,7 @@ def ensemble_kalman_smoother_multicam( camgroup: loaded calibration file for nonlinear projection Returns: - tuple: Dataframes with smoothed predictions, final smoothing parameters. + tuple: Dataframes with smoothed predictions, final smoothing parameters, 3D latent df. """ M, V, T, K, _ = marker_array.shape # n_models, n_cameras, n_timesteps, n_keypoints, (n_coords) @@ -420,7 +425,24 @@ def ensemble_kalman_smoother_multicam( camera_df = pd.DataFrame(camera_arr.T, columns=pdindex) camera_dfs.append(camera_df) - return camera_dfs, s_finals + # Build 3D latent dataframe from Kalman smoother outputs + labels_3d = ['x', 'y', 'z', 'x_posterior_var', 'y_posterior_var', 'z_posterior_var'] + pdindex_3d = make_dlc_pandas_index(keypoint_names, labels=labels_3d) + arr_3d = [] + for k in range(K): + ms_k = np.array(ms[k]) # (T, 3) + Vs_k = np.array(Vs[k]) # (T, 3, 3) + arr_3d.extend([ + ms_k[:, 0], + ms_k[:, 1], + ms_k[:, 2], + Vs_k[:, 0, 0], + Vs_k[:, 1, 1], + Vs_k[:, 2, 2], + ]) + df_3d = pd.DataFrame(np.asarray(arr_3d).T, columns=pdindex_3d) + + return camera_dfs, s_finals, df_3d def initialize_kalman_filter_pca( diff --git a/scripts/multicam_example.py b/scripts/multicam_example.py index 74cdad4..ecbda09 100644 --- a/scripts/multicam_example.py +++ b/scripts/multicam_example.py @@ -30,7 +30,7 @@ calibration = args.calibration # Fit EKS using the provided input data -camera_dfs, s_finals, input_dfs, bodypart_list = fit_eks_multicam( +camera_dfs, s_finals, input_dfs, bodypart_list, df_3d = fit_eks_multicam( input_source=input_source, save_dir=save_dir, bodypart_list=bodypart_list, diff --git a/tests/test_multicam_smoother.py b/tests/test_multicam_smoother.py index 34f7041..bcaf518 100644 --- a/tests/test_multicam_smoother.py +++ b/tests/test_multicam_smoother.py @@ -4,6 +4,7 @@ import jax import jax.numpy as jnp import numpy as np +import pandas as pd from sklearn.decomposition import PCA from eks.marker_array import MarkerArray @@ -44,7 +45,7 @@ def test_ensemble_kalman_smoother_multicam(): # --------------------------------------------------- # Run the smoother # --------------------------------------------------- - camera_dfs, smooth_params_final = ensemble_kalman_smoother_multicam( + camera_dfs, smooth_params_final, df_3d = ensemble_kalman_smoother_multicam( marker_array=marker_array, keypoint_names=keypoint_names, smooth_param=smooth_param, @@ -63,11 +64,12 @@ def test_ensemble_kalman_smoother_multicam(): assert smooth_params_final[k] == smooth_param, \ f"Expected smooth_param_final to match input smooth_param ({smooth_param}), " \ f"got {smooth_params_final}" + assert isinstance(df_3d, pd.DataFrame), "Expected output to be a pandas dataframe" # --------------------------------------------------- # Run with variance inflation # --------------------------------------------------- - camera_dfs, smooth_params_final = ensemble_kalman_smoother_multicam( + camera_dfs, smooth_params_final, df_3d = ensemble_kalman_smoother_multicam( marker_array=marker_array, keypoint_names=keypoint_names, smooth_param=smooth_param, @@ -88,11 +90,12 @@ def test_ensemble_kalman_smoother_multicam(): assert smooth_params_final[k] == smooth_param, \ f"Expected smooth_param_final to match input smooth_param ({smooth_param}), " \ f"got {smooth_params_final}" + assert isinstance(df_3d, pd.DataFrame), "Expected output to be a pandas dataframe" # --------------------------------------------------- # Run with variance inflation + more maha kwargs # --------------------------------------------------- - camera_dfs, smooth_params_final = ensemble_kalman_smoother_multicam( + camera_dfs, smooth_params_final, df_3d = ensemble_kalman_smoother_multicam( marker_array=marker_array, keypoint_names=keypoint_names, smooth_param=smooth_param, @@ -115,6 +118,7 @@ def test_ensemble_kalman_smoother_multicam(): assert smooth_params_final[k] == smooth_param, \ f"Expected smooth_param_final to match input smooth_param ({smooth_param}), " \ f"got {smooth_params_final}" + assert isinstance(df_3d, pd.DataFrame), "Expected output to be a pandas dataframe" # --------------------------------------------------- # Run with variance inflation + more maha kwargs @@ -136,7 +140,7 @@ def test_ensemble_kalman_smoother_multicam(): markers_array[..., 2] = 1.0 marker_array = MarkerArray(markers_array, data_fields=data_fields) # run with variance inflation - camera_dfs, smooth_params_final = ensemble_kalman_smoother_multicam( + camera_dfs, smooth_params_final, df_3d = ensemble_kalman_smoother_multicam( marker_array=marker_array, keypoint_names=keypoint_names, smooth_param=smooth_param, @@ -174,7 +178,7 @@ def test_ensemble_kalman_smoother_multicam_no_smooth_param(): s_frames = None # Run the smoother without providing smooth_param - camera_dfs, smooth_params_final = ensemble_kalman_smoother_multicam( + camera_dfs, smooth_params_final, df_3d = ensemble_kalman_smoother_multicam( marker_array=markerArray, keypoint_names=keypoint_names, smooth_param=None, @@ -208,8 +212,8 @@ def test_ensemble_kalman_smoother_multicam_n_latent(): quantile_keep_pca = 90 s_frames = None - for n_latent in [2, 3, 5]: # Test different PCA dimensions - camera_dfs, _ = ensemble_kalman_smoother_multicam( + for n_latent in [3, 4, 5]: # Test different PCA dimensions + camera_dfs, _, _ = ensemble_kalman_smoother_multicam( marker_array=markerArray, keypoint_names=keypoint_names, smooth_param=1, # Fixed smooth_param to speed up test