Skip to content

Commit 4e7cd1d

Browse files
Merge pull request #269 from epflgraph/ocr-one-task-per-pdf-page
OCR: one task per PDF page
2 parents 029dc7f + 258cbbc commit 4e7cd1d

6 files changed

Lines changed: 88 additions & 85 deletions

File tree

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ cython_debug/
176176
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
177177
# and can be added to the global gitignore or merged into this file. For a more nuclear
178178
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
179-
#.idea/
179+
.idea/
180180

181181
# Abstra
182182
# Abstra is an AI-powered process automation framework.

graphai/celery/image/jobs.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
extract_slide_text_task,
1212
extract_slide_text_callback_task,
1313
convert_pdf_to_pages_task,
14+
fanout_pdf_ocr_task,
1415
extract_multi_image_text_task,
1516
collect_multi_image_ocr_task
1617
)
@@ -153,24 +154,17 @@ def ocr_job(
153154
# OCR computation job
154155
#####################
155156
if is_pdf(token):
156-
n_parallel = 8
157157
task_list = [
158158
convert_pdf_to_pages_task.s(token),
159-
group(
160-
extract_multi_image_text_task.s(
161-
i,
162-
n_parallel,
163-
method,
164-
google_api_token,
165-
openai_api_token,
166-
gemini_api_token,
167-
rcp_api_token,
168-
model_type,
169-
enable_tikz,
170-
)
171-
for i in range(n_parallel)
159+
fanout_pdf_ocr_task.s(
160+
method,
161+
google_api_token,
162+
openai_api_token,
163+
gemini_api_token,
164+
rcp_api_token,
165+
model_type,
166+
enable_tikz,
172167
),
173-
collect_multi_image_ocr_task.s()
174168
]
175169
else:
176170
task_list = [

graphai/celery/image/tasks.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from celery import shared_task
1+
from celery import shared_task, group, chord
22

33
from graphai.core.image.image import (
44
cache_lookup_retrieve_image_from_url,
@@ -146,6 +146,47 @@ def convert_pdf_to_pages_task(self, token):
146146
return break_pdf_into_images(token, self.file_manager)
147147

148148

149+
@shared_task(
150+
bind=True,
151+
autoretry_for=(Exception,),
152+
retry_backoff=True,
153+
retry_kwargs={"max_retries": 2},
154+
name="image.fanout_pdf_ocr_task",
155+
ignore_result=False,
156+
)
157+
def fanout_pdf_ocr_task(
158+
self,
159+
pages,
160+
method,
161+
google_api_token=None,
162+
openai_api_token=None,
163+
gemini_api_token=None,
164+
rcp_api_token=None,
165+
model_type=None,
166+
enable_tikz=False,
167+
):
168+
# Build one OCR task per page
169+
header = group(
170+
extract_multi_image_text_task.s(
171+
page,
172+
method,
173+
google_api_token,
174+
openai_api_token,
175+
gemini_api_token,
176+
rcp_api_token,
177+
model_type,
178+
enable_tikz,
179+
)
180+
for page in pages
181+
)
182+
183+
# When all pages are OCR'd, collect results
184+
callback = collect_multi_image_ocr_task.s()
185+
186+
# Replace this task with the chord so the outer chain waits properly
187+
raise self.replace(chord(header, callback))
188+
189+
149190
@shared_task(
150191
bind=True,
151192
autoretry_for=(Exception,),
@@ -156,9 +197,7 @@ def convert_pdf_to_pages_task(self, token):
156197
)
157198
def extract_multi_image_text_task(
158199
self,
159-
page_and_filename_list,
160-
i,
161-
n,
200+
page_and_filename,
162201
method="google",
163202
google_api_token=None,
164203
openai_api_token=None,
@@ -167,11 +206,9 @@ def extract_multi_image_text_task(
167206
model_type=None,
168207
enable_tikz=False,
169208
):
170-
print(f'Starting {extract_multi_image_text_task} task for page_and_filename_list {page_and_filename_list}, i {i} and n {n}')
209+
print(f'Starting {extract_multi_image_text_task} task for page_and_filename {page_and_filename}')
171210
return extract_multi_image_text(
172-
page_and_filename_list,
173-
i,
174-
n,
211+
page_and_filename,
175212
method,
176213
google_api_token,
177214
openai_api_token,

graphai/core/image/image.py

Lines changed: 32 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -210,30 +210,19 @@ def perform_ocr(
210210
model_type=None,
211211
enable_tikz=False,
212212
):
213-
ocr_colnames = get_ocr_colnames(method)
214-
215-
results = None
216-
language = None
213+
text = None
217214

218215
if method == 'tesseract':
219-
res = perform_tesseract_ocr(file_path, language='enfr')
216+
text = perform_tesseract_ocr(file_path, language='enfr')
220217

221-
if res:
222-
language = detect_text_language(res)
223-
results = [{'method': ocr_colnames[0], 'text': res}]
224218
elif method == 'google' and google_api_token:
225219
ocr_model = GoogleOCRModel(google_api_token)
226220
ocr_model.establish_connection()
227-
res1, res2 = ocr_model.perform_ocr(file_path)
221+
text1, text2 = ocr_model.perform_ocr(file_path)
222+
223+
# Since DTD usually performs better, method #1 is our point of reference for langdetect
224+
text = text1
228225

229-
if res1:
230-
# Since DTD usually performs better, method #1 is our point of reference for langdetect
231-
language = detect_text_language(res1)
232-
res_list = [res1]
233-
results = [
234-
{'method': ocr_colnames[i], 'text': res_list[i]}
235-
for i in range(len(res_list))
236-
]
237226
else:
238227
ocr_model = None
239228
if method == 'openai' and openai_api_token:
@@ -245,17 +234,14 @@ def perform_ocr(
245234

246235
if ocr_model:
247236
ocr_model.establish_connection()
248-
res = ocr_model.perform_ocr(
249-
file_path, model_type=model_type, enable_tikz=enable_tikz
250-
)
237+
text = ocr_model.perform_ocr(file_path, model_type=model_type, enable_tikz=enable_tikz)
251238

252-
if res:
253-
language = detect_text_language(res)
254-
results = [{'method': ocr_colnames[0], 'text': res}]
239+
if not text:
240+
text = ''
255241

256242
return {
257-
'results': results,
258-
'language': language,
243+
'results': [{'method': get_ocr_colnames(method)[0], 'text': text}],
244+
'language': detect_text_language(text),
259245
}
260246

261247

@@ -296,9 +282,7 @@ def extract_slide_text(
296282

297283

298284
def extract_multi_image_text(
299-
page_and_filename_list,
300-
i,
301-
n,
285+
page_and_filename,
302286
method="google",
303287
google_api_token=None,
304288
openai_api_token=None,
@@ -307,44 +291,33 @@ def extract_multi_image_text(
307291
model_type=None,
308292
enable_tikz=False,
309293
):
310-
# Extract subset of pages to process
311-
n_pages = len(page_and_filename_list)
312-
start_index = int(i / n * n_pages)
313-
end_index = int((i + 1) / n * n_pages)
314-
pages_to_handle = page_and_filename_list[start_index: end_index]
315-
316-
# Perform OCR on subset of pages
317-
results = list()
318-
for page in pages_to_handle:
319-
results.append(
320-
perform_ocr(
321-
page["filename"],
322-
method,
323-
google_api_token,
324-
openai_api_token,
325-
gemini_api_token,
326-
rcp_api_token,
327-
model_type,
328-
enable_tikz,
329-
)
330-
)
294+
# Perform OCR on page
295+
result = perform_ocr(
296+
page_and_filename["filename"],
297+
method,
298+
google_api_token,
299+
openai_api_token,
300+
gemini_api_token,
301+
rcp_api_token,
302+
model_type,
303+
enable_tikz,
304+
)
305+
306+
print(f"Performed OCR on page {page_and_filename['page']}. Result: {result}")
331307

332308
# Build result and return it
333309
return {
334-
'results': [
335-
{
336-
'page': pages_to_handle[i]['page'],
337-
'content': results[i]['results'][0]['text']
338-
}
339-
for i in range(len(results))
340-
],
341-
'language': get_most_common_element([result['language'] for result in results]),
342-
'method': get_most_common_element([result['results'][0]['method'] for result in results])
310+
'result': {
311+
'page': page_and_filename['page'],
312+
'content': result['results'][0]['text']
313+
},
314+
'language': result['language'],
315+
'method': result['results'][0]['method'],
343316
}
344317

345318

346319
def collect_multi_image_ocr(results):
347-
all_results = list(chain.from_iterable(result['results'] for result in results))
320+
all_results = [result['result'] for result in results]
348321
language = get_most_common_element([result['language'] for result in results])
349322
method = get_most_common_element([result['method'] for result in results])
350323
return {

graphai/core/image/ocr.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,6 @@ def perform_ocr(self, input_filename_with_path, model_type=None, **kwargs):
342342
response = self.model.chat.completions.create(model=model_type, messages=messages, response_format={"type": "json_object"})
343343
print(f'Got {response}')
344344
content = response.choices[0].message.content.strip()
345-
print(f'Got {content}')
346345

347346
# Strip thinking tokens
348347
thinking_tag = '</think>'

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ classifiers = [
1717
"Operating System :: OS Independent"
1818
]
1919
dependencies = [
20-
"loguru"
20+
"loguru",
2121
"numpy",
2222
"scipy",
2323
"pandas",

0 commit comments

Comments
 (0)