-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path main.py
More file actions
121 lines (99 loc) · 4.64 KB
/
main.py
File metadata and controls
121 lines (99 loc) · 4.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import sys
import cv2
import time
import json
import numpy as np
from scipy.optimize import linear_sum_assignment
from tensorflow import keras
import settings
from track import Track
from detection import detect
def min_cost_matching(cost_matrix) -> tuple:
    """Solve the minimum-cost assignment problem for *cost_matrix*.

    Delegates to scipy's linear_sum_assignment, which implements a
    Jonker-Volgenant-style solver — faster in practice than a pure
    Hungarian algorithm.

    Returns:
        A tuple ``(track_indices, detection_indices)`` of ``np.ndarray``,
        pairing row ``track_indices[k]`` with column ``detection_indices[k]``.
    """
    row_indices, col_indices = linear_sum_assignment(cost_matrix)
    return row_indices, col_indices
def _match_cascade(tracks, detections):
    """Associate detections with tracks, preferring recently-seen tracks.

    Runs a matching cascade over track age (0 .. MAX_AGE-1): for each age
    bucket a cost matrix (appearance + motion, weight=0.5) and a gate
    matrix are built, then a minimum-cost assignment is computed. An
    assignment is accepted only when gate * cost > 0 (original gating
    rule). Detections claimed at a younger age are removed before older
    tracks compete for them.

    Returns:
        (matches, unmatched): ``matches`` maps Track -> matched detection;
        ``unmatched`` is the list of detections no track claimed.
    """
    matches = {}
    unmatched = detections
    for age in range(settings.MAX_AGE):
        tracks_of_age = [t for t in tracks if t.age == age]
        cost_matrix = np.zeros((len(tracks_of_age), len(unmatched)))
        gate_matrix = np.zeros(cost_matrix.shape)
        for i, track in enumerate(tracks_of_age):
            for j, detection in enumerate(unmatched):
                cost_matrix[i, j] = track.calculate_cost(detection, weight=0.5)
                gate_matrix[i, j] = track.calculate_gate(detection)
        for i, j in zip(*min_cost_matching(cost_matrix)):
            if gate_matrix[i, j] * cost_matrix[i, j] > 0:
                matches[tracks_of_age[i]] = unmatched[j]
        # Matched detections must not be offered to older tracks.
        unmatched = [d for d in unmatched if d not in matches.values()]
    return matches, unmatched


def _draw_track(frame, track, matched):
    """Draw *track* on *frame*; red if matched this frame (when enabled).

    When SHOW_KALMAN_FILTER is set, also draws the Kalman-predicted box
    in blue. The movement model state is [cx, cy, aspect*h?, h, ...] —
    width is reconstructed as x[2] * x[3] (aspect ratio times height);
    TODO confirm against the Track implementation.
    """
    if settings.SHOW_KALMAN_FILTER:
        predicted_width = track.movement_model.x[2] * track.movement_model.x[3]
        predicted_height = track.movement_model.x[3]
        predicted_x = track.movement_model.x[0] - predicted_width // 2
        predicted_y = track.movement_model.x[1] - predicted_height // 2
        cv2.rectangle(
            frame,
            (int(predicted_x), int(predicted_y)),
            (int(predicted_x + predicted_width), int(predicted_y + predicted_height)),
            (255, 0, 0), 1
        )
    # Red for a track matched on this frame (if SHOW_UPDATES), else white.
    if settings.SHOW_UPDATES and matched:
        track.draw_rect(frame, (0, 0, 255))
    else:
        track.draw_rect(frame, (255, 255, 255))


def main(file_path: str) -> None:
    """Run multi-object tracking over the video at *file_path*.

    Detections come from a YOLO network loaded through OpenCV's DNN
    module; appearance descriptors from a headless ImageNet ResNet50.
    The annotated video is written to settings.OUTPUT_PATH; per-frame
    track rectangles optionally to settings.JSON_OUTPUT_PATH.
    """
    video = cv2.VideoCapture(file_path)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    frame_size = (int(video.get(cv2.CAP_PROP_FRAME_WIDTH)),
                  int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    output = cv2.VideoWriter(settings.OUTPUT_PATH, fourcc, 24, frame_size)
    tracks = []
    net = cv2.dnn.readNet(settings.YOLO_WEIGHTS_PATH, settings.YOLO_CONFIG_PATH)  # YOLO used for detection
    resnet = keras.applications.ResNet50(weights='imagenet', include_top=False)  # ResNet for descriptors
    frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    json_output = [{} for _ in range(frame_count)]
    current_frame = 0
    try:
        while video.isOpened():
            ret, frame = video.read()
            # Guard against frame_count == 0 (some containers report no
            # frame count) to avoid ZeroDivisionError.
            percentage = current_frame / max(frame_count, 1)
            sys.stdout.write('\rFrame: {current_frame}/{frame_count} ({percentage:.2%}) [{progress}]'.format(
                current_frame=current_frame,
                frame_count=frame_count,
                percentage=percentage,
                progress=('#' * int(percentage * 20)).ljust(20, '-')
            ))
            sys.stdout.flush()
            # BUG FIX: the original never checked `ret`, so a failed read
            # passed frame=None into detect().
            if not ret or percentage >= 1:
                break
            detected_tracks = detect(frame, net)
            for track in detected_tracks:
                track.update_descriptors(frame, resnet)
            matches, unmatched_detections = _match_cascade(tracks, detected_tracks)
            # Unclaimed detections start new tracks.
            tracks.extend(unmatched_detections)
            # BUG FIX: the original called tracks.remove() while iterating
            # `tracks`, which skips the element after each removal, and it
            # still updated the removed track. Rebuild the list instead,
            # so aged-out tracks are dropped before being updated/drawn.
            tracks = [t for t in tracks if t.age < settings.MAX_AGE]
            for track in tracks:
                # matches.get() yields None for unmatched tracks, which
                # Track.update treats as a missed detection.
                track.update(matches.get(track), frame, resnet)
                if settings.SAVE_TO_JSON:
                    json_output[current_frame][track.identifier] = list(track.rect)
            for track in tracks:
                # Only draw tracks confirmed by enough updates.
                if track.update_count > settings.SHOW_MIN_UPDATES:
                    _draw_track(frame, track, track in matches)
            output.write(frame)
            if settings.SHOW:
                # showing the video frame
                cv2.imshow('frame', frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
            current_frame += 1
        if settings.SAVE_TO_JSON:
            with open(settings.JSON_OUTPUT_PATH, 'w', encoding='utf-8') as f:
                json.dump(json_output, f, ensure_ascii=False, indent=4)
        print()
    finally:
        # BUG FIX: the original released only the capture; release the
        # writer (finalizes the output container) and close any windows.
        video.release()
        output.release()
        cv2.destroyAllWindows()
# Entry point: expects the input video path as the first CLI argument.
if __name__ == '__main__':
    main(sys.argv[1])