diff --git a/bnd/pipeline/nwbtools/pycontrol_interface.py b/bnd/pipeline/nwbtools/pycontrol_interface.py index 0b195ea..4d443ce 100644 --- a/bnd/pipeline/nwbtools/pycontrol_interface.py +++ b/bnd/pipeline/nwbtools/pycontrol_interface.py @@ -83,21 +83,41 @@ def adjust_timestamps(self, start_time: int) -> None: for k in self.session.analog_data.keys(): self.session.analog_data[k][:, 0] -= start_time - def _get_pos_timestamps(self) -> np.ndarray: - time_x = self.session.analog_data["MotSen1-X"][:, 0] - time_y = self.session.analog_data["MotSen1-Y"][:, 0] - - assert len(time_x) == len(time_y) - assert np.all(time_x == time_y) - - return time_x + def _get_pos_timestamps_data(self, motion_sensor: str) -> np.ndarray: + """Get motion sensor data and timestamps + + Parameters + ---------- + motion_sensor : str + Name of motion sensor x or y (e.g., "MotSen1-X" or "MotSen1-Y") + + Returns + ------- + time : np.ndarray + data : np.ndarray + """ + if motion_sensor not in ["MotSen1-X", "MotSen1-Y"]: + raise ValueError( + f"motion sensor: {motion_sensor} not a valid option (['MotSen1-X', 'MotSen1-Y'])" + ) - def _get_pos_data(self) -> np.ndarray: - data_x = self.session.analog_data["MotSen1-X"][:, 1] - data_y = self.session.analog_data["MotSen1-Y"][:, 1] - pos_data = np.stack([data_x, data_y]).T + time, data = self.session.analog_data[f"{motion_sensor}"][:, [0, 1]].T + return time, data - return pos_data + def _get_spatial_series(self, motion_sensor: str): + if motion_sensor not in ["MotSen1-X", "MotSen1-Y"]: + raise ValueError( + f"motion sensor: {motion_sensor} not a valid option (['MotSen1-X', 'MotSen1-Y'])" + ) + time, data = self._get_pos_timestamps_data(motion_sensor) + spatial_series_obj = SpatialSeries( + name=f"{motion_sensor}", + description=f"Ball position as measured by PyControl ({motion_sensor})", + data=data, + timestamps=time.astype(float), + reference_frame="(0,0) is what?", # TODO + ) + return spatial_series_obj def _add_to_behavior_module(self, beh_obj, 
nwbfile: NWBFile) -> None: # behavior_module = nwbfile.processing.get( @@ -114,15 +134,19 @@ def _add_to_behavior_module(self, beh_obj, nwbfile: NWBFile) -> None: behavior_module.add(beh_obj) def add_position(self, nwbfile: NWBFile) -> None: - spatial_series_obj = SpatialSeries( - name="Ball position", - description="(x,y) position as measured by PyControl", - data=self._get_pos_data(), - timestamps=self._get_pos_timestamps().astype(float), - reference_frame="(0,0) is what?", # TODO - ) - self._add_to_behavior_module(Position(spatial_series=spatial_series_obj), nwbfile) + spatial_series_obj_motion_sensor_x = self._get_spatial_series("MotSen1-X") + spatial_series_obj_motion_sensor_y = self._get_spatial_series("MotSen1-Y") + + self._add_to_behavior_module( + Position( + spatial_series=[ + spatial_series_obj_motion_sensor_x, + spatial_series_obj_motion_sensor_y, + ] + ), + nwbfile, + ) def add_print_events(self, nwbfile: NWBFile): print_events = BehavioralEvents(name="print_events") @@ -202,7 +226,7 @@ def add_to_nwbfile(self, nwbfile: NWBFile, metadata: DeepDict) -> None: try: self.add_position(nwbfile) except Exception as e: - logger.warning(f"Error parsing motion sensores: {e}") + logger.warning(f"Error parsing motion sensors: {e}") def get_metadata(self) -> DeepDict: metadata = DeepDict() diff --git a/bnd/pipeline/pyaldata.py b/bnd/pipeline/pyaldata.py index 0f192ab..41c66eb 100644 --- a/bnd/pipeline/pyaldata.py +++ b/bnd/pipeline/pyaldata.py @@ -273,20 +273,18 @@ def _parse_spatial_series(spatial_series: SpatialSeries) -> pd.DataFrame: pd.DataFrame : Contains x, y, z and timestamp """ + df = pd.DataFrame() - if spatial_series.data[:].shape[1] == 2: - colnames = ["x", "y"] - elif spatial_series.data[:].shape[1] == 3: - colnames = ["x", "y", "z"] - else: - raise ValueError( - f"Shape {spatial_series.data[:].shape} is not supported by pynwb. 
" - f"Please provide a valid SpatialSeries object" - ) + if spatial_series.data.ndim == 1: + df["data"] = spatial_series.data[:] - df = pd.DataFrame() - for i, col in enumerate(colnames): - df[col] = spatial_series.data[:, i] + else: + if spatial_series.data[:].shape[1] > 3: + raise ValueError( + f"Shape {spatial_series.data[:].shape} is not supported by pynwb. " + f"Please provide a valid SpatialSeries object" + ) + raise NotImplementedError("We currently only support 1d spatial series.") df["timestamps"] = spatial_series.timestamps[:] @@ -448,9 +446,7 @@ def parse_nwb_pycontrol_events(self) -> None: # Behavioural event dont have values but print events do so we need to # stay consistent with dimension df_behav_events["event"] = behav_events_time_series.data[:] - df_behav_events["value"] = np.full( - behav_events_time_series.data[:].shape[0], np.nan - ) + df_behav_events["value"] = np.full(behav_events_time_series.data[:].shape[0], np.nan) df_behav_events["timestamp"] = behav_events_time_series.timestamps[:] # Then make dataframe with print events, and a df for each print event @@ -466,9 +462,7 @@ def parse_nwb_pycontrol_events(self) -> None: tmp_df["value"] = print_events_time_series[print_event].data[:] tmp_df["timestamp"] = print_events_time_series[print_event].timestamps[:] - df_print_events = pd.concat( - [df_print_events, tmp_df], axis=0, ignore_index=True - ) + df_print_events = pd.concat([df_print_events, tmp_df], axis=0, ignore_index=True) # Concatenate both dataframes df_events = pd.concat([df_behav_events, df_print_events], axis=0, ignore_index=True) @@ -490,15 +484,21 @@ def try_to_parse_motion_sensors(self) -> None: logger.warning("No motion data available") return - ball_position_spatial_series = self.behavior["Position"].spatial_series[ - "Ball position" - ] - self.pycontrol_motion_sensors = _parse_spatial_series(ball_position_spatial_series) + self.pycontrol_motion_sensors = {} + + for spatial_series_key, spatial_series in self.behavior[ + 
"Position" + ].spatial_series.items(): + # These keys will normally be MotSen1-X or MotSen1-Y + self.pycontrol_motion_sensors[f"{spatial_series_key.replace('-', '_')}"] = ( + _parse_spatial_series(spatial_series) + ) + return def try_parsing_anipose_output(self): """ - Add anipose data (xyz) or angle to instance as a dictionary with each keypoint + Add anipose data (xyz) or angle to instance as a dictionary with each keypoint Returns ------- @@ -554,9 +554,9 @@ def add_pycontrol_states_to_df(self): self.pycontrol_states.start_time.values[:] / 1000 / self.bin_size ).astype(int) self.pyaldata_df["idx_trial_end"] = ( - np.floor( - self.pycontrol_states.stop_time.values[:] / 1000 / self.bin_size - ).astype(int) + np.floor(self.pycontrol_states.stop_time.values[:] / 1000 / self.bin_size).astype( + int + ) - 1 ) self.pyaldata_df["trial_name"] = self.pycontrol_states.state_name[:] @@ -594,26 +594,31 @@ def add_pycontrol_events_to_df(self): return def add_motion_sensor_data_to_df(self): + if hasattr(self, "pycontrol_motion_sensors"): - # Bin timestamps - self.pycontrol_motion_sensors["timestamp_idx"] = np.floor( - self.pycontrol_motion_sensors.timestamps.values[:] / 1000 / self.bin_size - ).astype(int) + for mot_sens_key, mot_sens in self.pycontrol_motion_sensors.items(): + # Bin timestamps + self.pycontrol_motion_sensors[f"{mot_sens_key}"]["timestamp_idx"] = np.floor( + mot_sens.timestamps.values[:] / 1000 / self.bin_size + ).astype(int) - # Add columns - self.pyaldata_df["motion_sensor_xy"] = np.nan + # Create column + self.pyaldata_df[f"values_{mot_sens_key}"] = np.nan + self.pyaldata_df[f"idx_{mot_sens_key}"] = np.nan + + # Add data in relevant rows (states) + self.pyaldata_df = _add_data_to_trial( + df_to_add_to=self.pyaldata_df, + new_data_column=f"values_{mot_sens_key}", + df_to_add_from=mot_sens, + columns_to_read_from="data", # TODO extend this when there are two columns + timestamp_column=f"idx_{mot_sens_key}", + ) - # Add data - self.pyaldata_df = 
_add_data_to_trial( - df_to_add_to=self.pyaldata_df, - new_data_column="motion_sensor_xy", - df_to_add_from=self.pycontrol_motion_sensors, - columns_to_read_from=["x", "y"], - timestamp_column=None, - ) return def add_anipose_data_to_df(self): + # TODO if hasattr(self, "anipose_data"): for anipose_key, anipose_value in self.anipose_data.items(): # Bin timestamps @@ -702,9 +707,11 @@ def purge_nan_columns(self, column_subset="values_") -> None: def _is_empty_array_or_nans(value): if isinstance(value, np.ndarray): - if value.ndim != 0 and all(np.isnan(item) for item in value): + if value.ndim == 0 and np.isnan(value.item()): return True - elif value.ndim == 0 and not np.isnan(value.item()): + elif value.ndim > 0 and all(np.isnan(item) for item in value): + return True + else: return False elif value == np.nan: return True @@ -712,6 +719,7 @@ def _is_empty_array_or_nans(value): return False for col_name in columns_to_select: + if self.pyaldata_df[col_name].apply(_is_empty_array_or_nans).all(): self.pyaldata_df.drop(col_name, axis=1, inplace=True) @@ -812,9 +820,7 @@ def _partition_and_save_to_mat(self): return else: # Partition array - logger.info( - f"Session ({nbytes / 2**30:.2f} GB) exceeds matlab 5 format (2 GB) " - ) + logger.info(f"Session ({nbytes / 2**30:.2f} GB) exceeds matlab 5 format (2 GB) ") logger.info(f"Partitioning array into {num_partitions} chunks...") partition_sizes = [ @@ -865,9 +871,7 @@ def save(self): logger.info("Please enter 'y' for yes or 'n' for no.") else: self._partition_and_save_to_mat() - logger.info( - f"Saved pyaldata file(s) in {self.nwbfile_path.parent.name} session" - ) + logger.info(f"Saved pyaldata file(s) in {self.nwbfile_path.parent.name} session") return