-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path preprocess_data.py
More file actions
35 lines (23 loc) · 833 Bytes
/
preprocess_data.py
File metadata and controls
35 lines (23 loc) · 833 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
#!/usr/bin/python2
import os
import glob
import numpy as np
import cv2
import time
from tqdm import tqdm
total_labels = 3  # width of the one-hot label vectors written to the .npz


def _read_gray(path):
    """Load the image at *path* with OpenCV and return it as one channel.

    Raises:
        IOError: if OpenCV cannot read the file.
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread reports missing/unreadable files by returning None
        # instead of raising; fail here with the offending path.
        raise IOError("could not read image: %s" % path)
    # NOTE(review): cv2.imread yields BGR, so COLOR_BGR2GRAY is the textbook
    # conversion. RGB2GRAY is kept so pixel values match previously generated
    # data (and presumably the code that consumes it) -- confirm before changing.
    return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)


def _label_index(path, n_labels=total_labels):
    """Return the class index encoded in *path*.

    The label is the token just before the extension when the full path is
    split on '.', e.g. "pi/data/combined/frame.4.png" -> 4. When there are
    exactly 3 classes that raw id is floor-divided by 2 (4 -> 2).

    Raises:
        ValueError: if the token is not an integer, or the resulting index
            is not in ``range(n_labels)``.
        IndexError: if *path* contains no '.'.
    """
    raw = int(path.split('.')[-2], 10)
    # Presumably the raw ids come in pairs per class when n_labels == 3
    # (0-1, 2-3, 4-5) -- TODO confirm against the data collection code.
    div = 2 if n_labels == 3 else 1
    # '//' gives the same integer result on Python 2 and 3; '/' would produce
    # a float on Python 3 and break numpy indexing.
    index = raw // div
    if not 0 <= index < n_labels:
        # Without this check a negative id would silently wrap around when
        # used to index the identity matrix.
        raise ValueError("label %d out of range for %s" % (raw, path))
    return index


def build_dataset(paths, n_labels=total_labels, read=None):
    """Load labelled frames into flattened-pixel and one-hot label arrays.

    Args:
        paths: iterable of image paths named like ``<name>.<label>.<ext>``
            (may be wrapped in a progress bar such as tqdm).
        n_labels: number of classes, i.e. the one-hot width.
        read: callable mapping a path to a 2-D image array; defaults to
            ``_read_gray`` (OpenCV). Injectable so the function can be
            exercised without image files.

    Returns:
        (frames, labels): ``frames`` is float32 with shape (N, H*W), one
        flattened image per row; ``labels`` is float64 with shape
        (N, n_labels), one one-hot row per image.

    Raises:
        ValueError: if *paths* is empty, the images differ in pixel count,
            or a label is out of range.
    """
    if read is None:
        read = _read_gray
    onehot = np.identity(n_labels)
    frames = []
    labels = []
    # Collect rows in lists and stack once. Repeated np.vstack copied the
    # whole array every iteration (quadratic), and the old zero "seed" row
    # ended up in the saved data as a bogus blank sample with no label.
    for path in paths:
        frames.append(np.asarray(read(path)).reshape(-1).astype(np.float32))
        labels.append(onehot[_label_index(path, n_labels)])
    if not frames:
        raise ValueError("no images to process")
    # np.stack raises ValueError if the flattened frames differ in length.
    return np.stack(frames), np.stack(labels)


def main():
    """Convert every labelled PNG under pi/data/combined into one .npz file.

    The archive is written to train_data/<unix-timestamp>.npz with arrays
    ``data`` (frames) and ``labels``.
    """
    pattern = "pi/data/combined/*.png"
    img_list = glob.glob(pattern)
    if not img_list:
        raise SystemExit("no images match %s" % pattern)
    frames, labels = build_dataset(tqdm(img_list), total_labels)
    out_dir = 'train_data'
    # np.savez does not create missing directories.
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    out_path = os.path.join(out_dir, str(int(time.time())) + '.npz')
    np.savez(out_path, data=frames, labels=labels)


if __name__ == "__main__":
    main()