This repository was archived by the owner on Nov 18, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path extract_table_image.py
More file actions
333 lines (260 loc) · 11.1 KB
/
extract_table_image.py
File metadata and controls
333 lines (260 loc) · 11.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import cv2
import pytesseract
from operator import itemgetter
import json
import os
import argparse
import time
def img2text(img, x, y, w, h):
    '''
    Function: translate image into texts
    Input: original image, and location of a text box (x, y, w, h)
    Output: extracted text of the box, collapsed onto a single line
    '''
    # Crop the box with a small margin (3 px before, 6 px after) so glyphs that
    # touch the box edge are not clipped. The start indices are clamped at 0:
    # a negative start would wrap around in array slicing and produce an empty
    # or wrong crop for boxes near the top/left edge of the image.
    top = max(y - 3, 0)
    left = max(x - 3, 0)
    roi = img[top:(y + h + 6), left:(x + w + 6)]
    # change the 'lang' here for different traineddata
    text = pytesseract.image_to_string(roi, lang='eng', config='--psm 6 --oem 3').strip()
    # multi-line cell contents are joined with spaces
    return text.replace("\n", " ")
def rm_lines(img):
    '''
    Function: remove the horizontal and vertical ruling lines of a table image
    Input: original image (3-channel BGR; converted with cv2.COLOR_BGR2GRAY)
    Output: grayscale image in which the detected ruling lines are set to white
    '''
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # Invert so dark ink becomes bright, then threshold locally (block size 35,
    # C = -5) to obtain a binary map of ink pixels.
    binary = cv2.adaptiveThreshold(~gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 35, -5)
    # binary = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 15, -5)
    rows, cols = binary.shape

    def line_mask(ksize):
        # Erode once with a long, 1-px-thick kernel so only ink runs at least
        # that long survive, then dilate twice to restore and thicken them.
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, ksize)
        eroded = cv2.erode(binary, kernel, iterations=1)
        return cv2.dilate(eroded, kernel, iterations=2)

    # Horizontal lines: kernel width is 1/40 of the image width.
    # Vertical lines: kernel height is 1/20 of the image height.
    # max(..., 1) keeps the kernel non-empty on very small images, where the
    # integer division would otherwise yield a zero-sized structuring element.
    horizontal = line_mask((max(cols // 40, 1), 1))
    vertical = line_mask((1, max(rows // 20, 1)))
    # merge two groups of lines
    merge = cv2.add(horizontal, vertical)
    # uncomment the next line to save image with detected lines
    # cv2.imwrite("lines.jpg", merge)
    # saturating add: pixels lying on a detected line become 255 (white)
    after = cv2.add(gray, merge)
    # uncomment the next line to save borderless table images
    # cv2.imwrite("borderless.jpg", after)
    return after
def find_cells(img):
    '''
    Function: find cells in table images and sort them from top-left to bottom-right
    Input: original image (3-channel BGR)
    Output: ordered table cells as (x, y, w, h) tuples, the original image padded
            with a 10-px white border (used for drawing boxes), and the
            binarized line-free image (used for OCR)
    '''
    # NOTE(review): cell coordinates below are relative to the un-padded image,
    # so boxes later drawn on `added` appear shifted by the 10-px border.
    added = cv2.copyMakeBorder(img, 10, 10, 10, 10, cv2.BORDER_CONSTANT, value=[255, 255, 255])
    size = added.shape
    gray = rm_lines(img)
    _, thresh = cv2.threshold(gray, 190, 255, cv2.THRESH_BINARY)
    # thresh2 = cv2.adaptiveThreshold(gray,255,cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY,10,0)
    # uncomment the next line to save binary tables
    # cv2.imwrite("thresh.jpg", thresh)
    rows, cols = thresh.shape
    scale = 150  # the larger, the smaller the rectangles
    # The kernel size and the morphology iterations may need tuning for the
    # image size. max(..., 1) keeps the kernel width non-zero for images
    # narrower than `scale` px (the height already has "+ 1" for that reason).
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (max(cols // scale, 1), rows // scale + 1))
    eroded = cv2.morphologyEx(thresh, cv2.MORPH_GRADIENT, kernel, iterations=3)
    eroded = cv2.bitwise_not(eroded)
    # uncomment the next line to save images after morphology processing
    # cv2.imwrite("eroded.jpg", eroded)
    # first remove a few pixels and then add white borders before finding contours
    eroded = eroded[10:(rows - 10), 10:(cols - 10)]
    eroded = cv2.copyMakeBorder(eroded, 10, 10, 10, 10, cv2.BORDER_CONSTANT, value=[255, 255, 255])
    # OpenCV 3.x returns (image, contours, hierarchy) and 4.x returns
    # (contours, hierarchy); [-2] selects the contour list in both.
    contours = cv2.findContours(eroded, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    cells = []
    for c in contours:
        x, y, w, h = cv2.boundingRect(c)
        # case 1: eliminate rectangles that are too thin (might be lines)
        if w > h * 30 or h > w * 30:
            continue
        # case 2: remove a box similar to the whole image
        if (w > size[1] * 0.8) and (h > size[0] * 0.8):
            continue
        # case 3: eliminate small boxes that could be noise.
        # A constant area threshold does not work on images that are too large
        # or too small; a proportional alternative would be to skip when
        # (size[0] * size[1]) / area > 20000.
        if cv2.contourArea(c) < 200:
            continue
        cells.append((x, y, w, h))
    # Sort by y, then x. Cells of the same visual row may differ in y by a
    # single pixel, which breaks their left-to-right order here; rows are
    # therefore re-sorted by x when the table is assembled.
    cells.sort(key=itemgetter(1, 0))
    return cells, added, thresh
def cell2table(cells, added, thresh, target_dir, pmc):
    '''
    Function: group sorted cells into table rows and OCR every cell
    Input: cells sorted top-to-bottom as (x, y, w, h), image to draw the boxes
           on, binarized image to OCR, output directory and image id (the
           last two are only used by the commented-out debug image dump)
    Output: list of rows, each a list of cell texts ordered left to right
    '''
    box_color = (0, 255, 0)  # box color
    grouped = []
    current = []
    final_index = len(cells) - 1
    for idx, cell in enumerate(cells):
        x, y, w, h = cell
        cv2.rectangle(added, (x, y), (x + w, y + h), box_color, 1)
        current.append(cell)
        # A row ends at the final cell, or when the next cell starts below this
        # cell's bottom edge (5 px of tolerance for rows that are very close).
        if idx == final_index or cells[idx + 1][1] > y + h - 5:
            grouped.append(current)
            current = []
    for current in grouped:
        # order each row left to right, then replace boxes by their OCR text
        current.sort(key=itemgetter(0))
        for pos, (x, y, w, h) in enumerate(current):
            current[pos] = img2text(thresh, x, y, w, h)
    # uncomment to save the cell detection result image
    # cv2.imwrite(target_dir + '/' + "{}_result.jpg".format(pmc), added)
    return grouped
def _split_caption_footer(table_row):
    '''
    Function: separate the caption (leading single-cell rows) and the footer
              (trailing single-cell rows) from the table body
    Input: table text saved line by line
    Output: (identifier, title, footer, body rows)
    '''
    n_rows = len(table_row)
    identifier = ''
    title = ''
    footer = ''
    # Leading single-cell rows are joined into the caption, e.g. "Table 1. Title".
    n_head = 0
    while n_head < n_rows and len(table_row[n_head]) == 1:
        n_head += 1
    if n_head:
        caption = ' '.join(''.join(r) for r in table_row[:n_head])
        pos = caption.lower().find('table')
        if pos >= 0:
            # "Table 1. Title" -> identifier "Table 1", title "Title"
            identifier = caption[pos:pos + 7]
            title = caption[pos + 9:].strip()
        else:
            # no "table" label found: keep the whole caption as the title
            title = caption.strip()
    # Trailing single-cell rows not already consumed by the caption form the footer.
    n_tail = 0
    while n_rows - n_tail > n_head and len(table_row[n_rows - 1 - n_tail]) == 1:
        n_tail += 1
    if n_tail:
        footer = ' '.join(''.join(r) for r in table_row[n_rows - n_tail:]).strip()
    return identifier, title, footer, table_row[n_head:n_rows - n_tail]


def _build_table(identifier, title, footer, body):
    '''
    Function: group body rows under the column header and named sections
    Input: caption/footer strings and the body rows (first row = column names)
    Output: table dict, or {} when the body has no non-blank row
    '''
    table = {}
    sections = []
    cur_section = {}
    pre_header = []
    pre_superrow = None
    cur_header = ''
    cur_superrow = ''
    last = len(body) - 1
    for i, row in enumerate(body):
        if i == 0:
            # the first body row holds the column names
            cur_header = row
        elif len(row) == 1 and i != last:
            # a single-cell row inside the body starts a new section
            cur_superrow = row
        # skip blank rows (rarely happen)
        if not any(cell not in ('', 'None') for cell in row):
            continue
        if cur_header != pre_header:
            sections = []
            pre_superrow = None
            table = {'identifier': identifier,
                     'title': title,
                     'columns': cur_header,
                     'section': sections,
                     'footer': footer}
        else:
            table['section'] = sections
        if cur_superrow != pre_superrow:
            # the header / section-name row itself is not stored as a result
            cur_section = {'section_name': cur_superrow, 'results': []}
            sections.append(cur_section)
        else:
            cur_section['results'].append(row)
        pre_header = cur_header
        pre_superrow = cur_superrow
    return table


def text2json(table_row):
    '''
    Function: convert the table text into a JSON-serializable structure
    Input: table text saved line by line (list of rows, each a list of strings)
    Output: {'tables': {...}} with identifier, title, columns, sections and
            footer, or {'tables': {}} when no body row is present
    '''
    identifier, title, footer, body = _split_caption_footer(table_row)
    return {'tables': _build_table(identifier, title, footer, body)}
if __name__ == "__main__":
    # time_start = time.time()
    parser = argparse.ArgumentParser(prog='PROG')
    group = parser.add_mutually_exclusive_group()
    group.add_argument('-f', '--imgPath', type=str, help='File path of table images')
    group.add_argument('-b', '--base_dir', type=str, help='Base directory of table files')
    # required: every output path below is built from it
    parser.add_argument('-t', '--target_dir', type=str, required=True, help='Target directory of JSON output')
    parser.add_argument("-c", "--config", type=str, help="filepath for configuration JSON file")
    args = parser.parse_args()
    imgPath = args.imgPath
    base_dir = args.base_dir
    target_dir = args.target_dir
    config_path = args.config
    # The configuration is optional; it is loaded when given but nothing
    # below reads it yet.
    config = None
    if config_path:
        with open(config_path, 'rb') as f:
            config = json.load(f)
    try:
        os.makedirs(target_dir, exist_ok=True)
    except OSError as err:
        raise FileNotFoundError('Target filepath does not exist') from err
    if base_dir:
        # -b/--base_dir is accepted by the parser but has no implementation
        print('WARNING: --base_dir is not implemented; nothing was processed.')
    if imgPath:
        for img_file in sorted(os.listdir(imgPath)):
            print('READING: ', img_file)
            # file name without its extension identifies the output file
            pmc = os.path.splitext(img_file)[0]
            img = cv2.imread(os.path.join(imgPath, img_file))
            if img is None:
                # cv2.imread returns None for directories and non-image files
                print('SKIPPING (not a readable image): ', img_file)
                continue
            cells, added, thresh = find_cells(img)
            table_row = cell2table(cells, added, thresh, target_dir, pmc)
            table_json = text2json(table_row)
            out_path = os.path.join(target_dir, "{}_table_image.json".format(pmc))
            with open(out_path, "w", encoding="utf-8") as outfile:
                json.dump(table_json, outfile, indent=4, separators=(',', ': '))
    # time_end = time.time()
    # print('Total time: ', time_end - time_start)