Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 46 additions & 22 deletions bnd/pipeline/nwbtools/pycontrol_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,21 +83,41 @@ def adjust_timestamps(self, start_time: int) -> None:
for k in self.session.analog_data.keys():
self.session.analog_data[k][:, 0] -= start_time

def _get_pos_timestamps(self) -> np.ndarray:
time_x = self.session.analog_data["MotSen1-X"][:, 0]
time_y = self.session.analog_data["MotSen1-Y"][:, 0]

assert len(time_x) == len(time_y)
assert np.all(time_x == time_y)

return time_x
def _get_pos_timestamps_data(self, motion_sensor: str) -> np.ndarray:
"""Get motion sensore data and timestamptes

Parameters
----------
motion_sensor : str
Name of motion sensore x or y (e.g., "MotSen1-X" or ""MotSen1-y")

Returns
-------
time : np.ndarray
data : np.ndarray
"""
if motion_sensor not in ["MotSen1-X", "MotSen1-Y"]:
raise ValueError(
f"motion sensor: {motion_sensor} not a valid option (['MotSen1-X', 'MotSen1-Y'])"
)

def _get_pos_data(self) -> np.ndarray:
data_x = self.session.analog_data["MotSen1-X"][:, 1]
data_y = self.session.analog_data["MotSen1-Y"][:, 1]
pos_data = np.stack([data_x, data_y]).T
time, data = self.session.analog_data[f"{motion_sensor}"][:, [0, 1]].T
return time, data

return pos_data
def _get_spatial_series(self, motion_sensor: str):
if motion_sensor not in ["MotSen1-X", "MotSen1-Y"]:
raise ValueError(
f"motion sensor: {motion_sensor} not a valid option (['MotSen1-X', 'MotSen1-Y'])"
)
time, data = self._get_pos_timestamps_data(motion_sensor)
spatial_series_obj = SpatialSeries(
name=f"{motion_sensor}",
description=f"Ball position as measured by PyControl ({motion_sensor})",
data=data,
timestamps=time.astype(float),
reference_frame="(0,0) is what?", # TODO
)
return spatial_series_obj

def _add_to_behavior_module(self, beh_obj, nwbfile: NWBFile) -> None:
# behavior_module = nwbfile.processing.get(
Expand All @@ -114,15 +134,19 @@ def _add_to_behavior_module(self, beh_obj, nwbfile: NWBFile) -> None:
behavior_module.add(beh_obj)

def add_position(self, nwbfile: NWBFile) -> None:
spatial_series_obj = SpatialSeries(
name="Ball position",
description="(x,y) position as measured by PyControl",
data=self._get_pos_data(),
timestamps=self._get_pos_timestamps().astype(float),
reference_frame="(0,0) is what?", # TODO
)

self._add_to_behavior_module(Position(spatial_series=spatial_series_obj), nwbfile)
spatial_series_obj_motion_sensor_x = self._get_spatial_series("MotSen1-X")
spatial_series_obj_motion_sensor_y = self._get_spatial_series("MotSen1-Y")

self._add_to_behavior_module(
Position(
spatial_series=[
spatial_series_obj_motion_sensor_x,
spatial_series_obj_motion_sensor_y,
]
),
nwbfile,
)

def add_print_events(self, nwbfile: NWBFile):
print_events = BehavioralEvents(name="print_events")
Expand Down Expand Up @@ -202,7 +226,7 @@ def add_to_nwbfile(self, nwbfile: NWBFile, metadata: DeepDict) -> None:
try:
self.add_position(nwbfile)
except Exception as e:
logger.warning(f"Error parsing motion sensores: {e}")
logger.warning(f"Error parsing motion sensors: {e}")

def get_metadata(self) -> DeepDict:
metadata = DeepDict()
Expand Down
100 changes: 52 additions & 48 deletions bnd/pipeline/pyaldata.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,20 +273,18 @@ def _parse_spatial_series(spatial_series: SpatialSeries) -> pd.DataFrame:
pd.DataFrame :
Contains x, y, z and timestamp
"""
df = pd.DataFrame()

if spatial_series.data[:].shape[1] == 2:
colnames = ["x", "y"]
elif spatial_series.data[:].shape[1] == 3:
colnames = ["x", "y", "z"]
else:
raise ValueError(
f"Shape {spatial_series.data[:].shape} is not supported by pynwb. "
f"Please provide a valid SpatialSeries object"
)
if spatial_series.data.ndim == 1:
df["data"] = spatial_series.data[:]

df = pd.DataFrame()
for i, col in enumerate(colnames):
df[col] = spatial_series.data[:, i]
else:
if spatial_series.data[:].shape[1] > 3:
raise ValueError(
f"Shape {spatial_series.data[:].shape} is not supported by pynwb. "
f"Please provide a valid SpatialSeries object"
)
raise NotImplementedError("We currently only support 1d spatial series.")

df["timestamps"] = spatial_series.timestamps[:]

Expand Down Expand Up @@ -448,9 +446,7 @@ def parse_nwb_pycontrol_events(self) -> None:
# Behavioural event dont have values but print events do so we need to
# stay consistent with dimension
df_behav_events["event"] = behav_events_time_series.data[:]
df_behav_events["value"] = np.full(
behav_events_time_series.data[:].shape[0], np.nan
)
df_behav_events["value"] = np.full(behav_events_time_series.data[:].shape[0], np.nan)
df_behav_events["timestamp"] = behav_events_time_series.timestamps[:]

# Then make dataframe with print events, and a df for each print event
Expand All @@ -466,9 +462,7 @@ def parse_nwb_pycontrol_events(self) -> None:
tmp_df["value"] = print_events_time_series[print_event].data[:]
tmp_df["timestamp"] = print_events_time_series[print_event].timestamps[:]

df_print_events = pd.concat(
[df_print_events, tmp_df], axis=0, ignore_index=True
)
df_print_events = pd.concat([df_print_events, tmp_df], axis=0, ignore_index=True)

# Concatenate both dataframes
df_events = pd.concat([df_behav_events, df_print_events], axis=0, ignore_index=True)
Expand All @@ -490,15 +484,21 @@ def try_to_parse_motion_sensors(self) -> None:
logger.warning("No motion data available")
return

ball_position_spatial_series = self.behavior["Position"].spatial_series[
"Ball position"
]
self.pycontrol_motion_sensors = _parse_spatial_series(ball_position_spatial_series)
self.pycontrol_motion_sensors = {}

for spatial_series_key, spatial_series in self.behavior[
"Position"
].spatial_series.items():
# These keys will normally be MotSen1-X or MotSen1-Y
self.pycontrol_motion_sensors[f"{spatial_series_key.replace('-', '_')}"] = (
_parse_spatial_series(spatial_series)
)

return

def try_parsing_anipose_output(self):
"""
Add anipose data (xyz) or angle to instance as a dictionary with each keypoint
Add anipose data (xyz) or angle to instance as a dictionary with each keypo int

Returns
-------
Expand Down Expand Up @@ -554,9 +554,9 @@ def add_pycontrol_states_to_df(self):
self.pycontrol_states.start_time.values[:] / 1000 / self.bin_size
).astype(int)
self.pyaldata_df["idx_trial_end"] = (
np.floor(
self.pycontrol_states.stop_time.values[:] / 1000 / self.bin_size
).astype(int)
np.floor(self.pycontrol_states.stop_time.values[:] / 1000 / self.bin_size).astype(
int
)
- 1
)
self.pyaldata_df["trial_name"] = self.pycontrol_states.state_name[:]
Expand Down Expand Up @@ -594,26 +594,31 @@ def add_pycontrol_events_to_df(self):
return

def add_motion_sensor_data_to_df(self):

if hasattr(self, "pycontrol_motion_sensors"):
# Bin timestamps
self.pycontrol_motion_sensors["timestamp_idx"] = np.floor(
self.pycontrol_motion_sensors.timestamps.values[:] / 1000 / self.bin_size
).astype(int)
for mot_sens_key, mot_sens in self.pycontrol_motion_sensors.items():
# Bin timestamps
self.pycontrol_motion_sensors[f"{mot_sens_key}"]["timestamp_idx"] = np.floor(
mot_sens.timestamps.values[:] / 1000 / self.bin_size
).astype(int)

# Add columns
self.pyaldata_df["motion_sensor_xy"] = np.nan
# Create column
self.pyaldata_df[f"values_{mot_sens_key}"] = np.nan
self.pyaldata_df[f"idx_{mot_sens_key}"] = np.nan

# Add data in relevant rows (states)
self.pyaldata_df = _add_data_to_trial(
df_to_add_to=self.pyaldata_df,
new_data_column=f"values_{mot_sens_key}",
df_to_add_from=mot_sens,
columns_to_read_from="data", # TODO extend this when there are two columns
timestamp_column=f"idx_{mot_sens_key}",
)

# Add data
self.pyaldata_df = _add_data_to_trial(
df_to_add_to=self.pyaldata_df,
new_data_column="motion_sensor_xy",
df_to_add_from=self.pycontrol_motion_sensors,
columns_to_read_from=["x", "y"],
timestamp_column=None,
)
return

def add_anipose_data_to_df(self):
# TODO
if hasattr(self, "anipose_data"):
for anipose_key, anipose_value in self.anipose_data.items():
# Bin timestamps
Expand Down Expand Up @@ -702,16 +707,19 @@ def purge_nan_columns(self, column_subset="values_") -> None:

def _is_empty_array_or_nans(value):
if isinstance(value, np.ndarray):
if value.ndim != 0 and all(np.isnan(item) for item in value):
if value.ndim == 0 and np.isnan(value.item()):
return True
elif value.ndim == 0 and not np.isnan(value.item()):
elif value.ndim > 0 and all(np.isnan(item) for item in value):
return True
else:
return False
elif value == np.nan:
return True
else:
return False

for col_name in columns_to_select:

if self.pyaldata_df[col_name].apply(_is_empty_array_or_nans).all():
self.pyaldata_df.drop(col_name, axis=1, inplace=True)

Expand Down Expand Up @@ -812,9 +820,7 @@ def _partition_and_save_to_mat(self):
return
else:
# Partition array
logger.info(
f"Session ({nbytes / 2**30:.2f} GB) exceeds matlab 5 format (2 GB) "
)
logger.info(f"Session ({nbytes / 2**30:.2f} GB) exceeds matlab 5 format (2 GB) ")

logger.info(f"Partitioning array into {num_partitions} chunks...")
partition_sizes = [
Expand Down Expand Up @@ -865,9 +871,7 @@ def save(self):
logger.info("Please enter 'y' for yes or 'n' for no.")
else:
self._partition_and_save_to_mat()
logger.info(
f"Saved pyaldata file(s) in {self.nwbfile_path.parent.name} session"
)
logger.info(f"Saved pyaldata file(s) in {self.nwbfile_path.parent.name} session")
return


Expand Down
Loading