-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathiORG_stimulus_alignment.py
More file actions
182 lines (150 loc) · 7.63 KB
/
iORG_stimulus_alignment.py
File metadata and controls
182 lines (150 loc) · 7.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
"""
Author: Brea Brennan
Date: 12/12/2024
Description: New implementation for the alignment of the stimulus in animal iORG videos. This class will be used to find
the minimum number of frames before, during and post stimulus presentation to align all videos in the collected dataset.
Takes inspiration from the implementation of the old method initially written and posted on the AOIP's GitHub
account: https://github.com/AOIPLab/Animal_iOrg_Frame_Extractor
"""
import cv2
import os
import numpy as np
import pandas as pd
from prettytable import PrettyTable
class IORGStimulusAlignment:
    """
    Align stimulus presentation across animal iORG videos.

    Scans each "visible" video for brightness jumps that mark stimulus onset
    and offset, records the per-video pre/during/post stimulus frame counts to
    a CSV, and then crops every stim video so the stimulus starts on the same
    frame index across the dataset. Control videos (no detected stimulus) are
    copied through unmodified.
    """

    def __init__(self):
        # Height/width (pixels) and frame rate of the most recently loaded video.
        self.h = None
        self.frate = None
        self.w = None
        # Directory containing the dataset videos (set by the caller).
        self.data_path = None
        # List of visible-channel video file names (set by the caller).
        self.vis_videos = None

    def get_video_data(self, vid_name):
        """
        Load a video's frames with OpenCV.

        :param vid_name: File name of the video (relative to self.data_path).
        :return: uint8 numpy array of shape (height, width, num_frames)
                 holding the first color channel of every decoded frame.
        """
        capture = cv2.VideoCapture(os.path.join(self.data_path, vid_name))
        try:
            self.w = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
            self.h = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
            num_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
            self.frate = capture.get(cv2.CAP_PROP_FPS)
            video_data = np.zeros((self.h, self.w, num_frames), dtype=np.uint8)
            c = 0
            # FIX: the original ignored the read() success flag and indexed the
            # returned frame unconditionally, which crashes with a TypeError on
            # a None frame if decoding ends before CAP_PROP_FRAME_COUNT frames.
            # Stop early (and truncate the array) instead.
            while c < num_frames:
                ret, frame = capture.read()
                if not ret or frame is None:
                    video_data = video_data[:, :, :c]
                    break
                # Keep only the first channel; iORG videos are grayscale.
                video_data[:, :, c] = frame[..., 0]
                c += 1
        finally:
            # Release the capture even if reading raises.
            capture.release()
        return video_data

    def find_frames_for_stimulus_alignment(self):
        """
        Determine the pre/during/post stimulus frame counts for every video.

        A stimulus transition is declared whenever the mean intensity of two
        consecutive frames differs by more than 1 gray level. Videos without a
        complete onset/offset pair are labeled "Control". The per-video frame
        counts are written to dataset_frame_information.csv in self.data_path.

        :return: None
        """
        # Table of per-video frame counts, used to classify control vs. stim
        # videos and to align all stim videos' stimulus presentations.
        frame_data = PrettyTable(["Vis Name", "Pre Stim Frames", "During Stim Frames", "Post Stim Frames", "VidType"])
        for vid_name in self.vis_videos:
            video_data = self.get_video_data(vid_name)
            num_frames = video_data.shape[-1]
            first = 1        # 1 until the first (onset) transition is found
            progression = 0  # 1 = the next frame pair is already accounted for
            start = None     # index of the first stimulus frame
            finish = None    # index of the first frame after the stimulus
            for j in range(num_frames - 1):
                # Skip the comparison immediately after a detected transition.
                if progression == 1:
                    progression = 0
                    continue
                # Mean-intensity jump between consecutive frames marks a
                # stimulus on/off transition.
                m1 = np.mean(video_data[:, :, j])
                m2 = np.mean(video_data[:, :, j + 1])
                if np.abs(m2 - m1) > 1:  # stimulus transition detected
                    if first == 1:
                        start = j + 1   # stimulus begins on the next frame
                        first = 0
                    else:
                        finish = j + 1  # stimulus ended before the next frame
                    progression = 1
            if finish is None or start is None:
                # No complete on/off pair detected: treat as a control video.
                frame_data.add_row([vid_name, num_frames, 0, 0, "Control"])
            else:
                stim_frames = (finish - start) + 1
                pre_stim_frames = start
                # FIX: the original read the leaked loop variable j here; use
                # the last frame index explicitly (same value, no name leak).
                post_stim_frames = (num_frames - 1) - finish
                frame_data.add_row([vid_name, pre_stim_frames, stim_frames, post_stim_frames, "Stim"])
        # Save the dataset frame information.
        print("All videos and frame data determined. Saving csv.......")
        with open(os.path.join(self.data_path, "dataset_frame_information.csv"), 'w', newline='') as csvfile:
            csvfile.write(frame_data.get_csv_string())

    def _save_gray_video(self, file_name, frames):
        """
        Write a sequence of 2-D uint8 frames as a grayscale (Y800) video.

        Uses the frame rate and dimensions captured by the most recent
        get_video_data() call.

        :param file_name: Full output path for the video.
        :param frames: Iterable of (h, w) arrays, one per frame.
        :return: None
        """
        code = cv2.VideoWriter.fourcc(*'Y800')
        output = cv2.VideoWriter(file_name, code, self.frate, (self.w, self.h), isColor=False)
        for frame in frames:
            output.write(frame.astype(np.uint8))
        output.release()

    def align_stimulus_across_videos(self, video_names):
        """
        Crop every stim video so stimulus onset lands on the same frame index.

        Reads the counts written by find_frames_for_stimulus_alignment(),
        trims each stim video's pre-stimulus section down to the dataset
        minimum, and writes the results (controls are copied unmodified) to a
        "Stim_Aligned" subdirectory of self.data_path.

        :param video_names: The list of videos to align, in the same order as
                            the rows of dataset_frame_information.csv.
        :return: None
        """
        csv_data = pd.read_csv(os.path.join(self.data_path, "dataset_frame_information.csv")).values
        # Minimum pre-stimulus length across the dataset. Control rows report
        # their full frame count in this column, so they never set the minimum
        # as long as at least one stim video is present.
        min_pre_stim_frames = csv_data[:, 1].min()
        # FIX: build the output path portably instead of concatenating
        # hard-coded backslashes, and tolerate the directory already existing.
        new_path = os.path.join(self.data_path, "Stim_Aligned")
        os.makedirs(new_path, exist_ok=True)
        for i, vid_name in enumerate(video_names):
            video_data = self.get_video_data(vid_name)
            file_name = os.path.join(new_path, vid_name)
            if csv_data[i, -1] == "Control":
                # Controls have no stimulus to align: copy every frame as-is.
                self._save_gray_video(file_name, (video_data[:, :, j] for j in range(video_data.shape[-1])))
                print("Video " + vid_name + " have been combined back together and saved!")
                continue
            # Drop leading frames so this video keeps exactly
            # min_pre_stim_frames frames before stimulus onset, then keep the
            # stimulus frames and everything after them.
            first_kept = csv_data[i, 1] - min_pre_stim_frames
            # FIX: the original stopped the post-stimulus copy at
            # shape[-1] - (pre + during), which truncated — or, when the
            # stimulus ended past the video midpoint, emptied — the post
            # segment; copy through the end of the video instead.
            cropped_data = [video_data[:, :, f] for f in range(first_kept, video_data.shape[-1])]
            self._save_gray_video(file_name, cropped_data)
            print("Video " + vid_name + " have been combined back together and saved!")