-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathImageTextComparator.py
More file actions
123 lines (109 loc) · 4.4 KB
/
ImageTextComparator.py
File metadata and controls
123 lines (109 loc) · 4.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import google.generativeai as genai
from llama_index.multi_modal_llms.gemini import GeminiMultiModal
from llama_index.core.multi_modal_llms.generic_utils import load_image_urls
from llama_index.core import SimpleDirectoryReader
import time
from threading import Thread
class ImageTextComparator:
    """Ask Gemini whether the images in a folder relate to a text file.

    Two strategies are available:

    * ``image_download_and_compare`` uploads the images through the
      ``google.generativeai`` File API and queries a chat session.
    * ``generate_response`` loads the images with LlamaIndex's
      ``SimpleDirectoryReader`` and queries ``GeminiMultiModal``.
    """

    # Supported image file suffixes (lower-case) and the MIME type that is
    # sent along with the upload.
    _IMAGE_MIME_TYPES = {
        ".png": "image/png",
        ".jpeg": "image/jpeg",
        ".jpg": "image/jpeg",
    }

    def __init__(self, api_key):
        """Configure the Gemini SDK and build the generative model.

        Args:
            api_key: Gemini API key. It is passed to ``genai.configure`` and
                kept so that ``generate_response`` can reuse it.
        """
        # Remember the key so no method needs a credential literal in source.
        self._api_key = api_key
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(
            model_name="gemini-1.5-flash",
            generation_config={
                "temperature": 1,
                "top_p": 0.95,
                "top_k": 64,
                "max_output_tokens": 8192,
                "response_mime_type": "text/plain",
            },
        )

    @classmethod
    def _image_mime_type(cls, file_name):
        """Return the MIME type for a supported image name, else ``None``."""
        lowered = file_name.lower()
        for suffix, mime_type in cls._IMAGE_MIME_TYPES.items():
            if lowered.endswith(suffix):
                return mime_type
        return None

    def upload_to_gemini(self, path, mime_type=None):
        """Upload the file at ``path`` to Gemini and return the file handle.

        Args:
            path: Local path of the file to upload.
            mime_type: Optional MIME type forwarded to ``genai.upload_file``.

        Returns:
            The object returned by ``genai.upload_file``.
        """
        file = genai.upload_file(path, mime_type=mime_type)
        print(f"Uploaded file '{file.display_name}' as: {file.uri}")
        return file

    def get_text_from_file(self, file_path):
        """Read and return the full contents of a UTF-8 text file.

        Args:
            file_path: Path of the text file to read.

        Returns:
            The file contents as a string.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
        """
        try:
            # Explicit encoding: the platform default differs between systems.
            with open(file_path, "r", encoding="utf-8") as handle:
                return handle.read()
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Text file not found: {file_path}") from err

    def _upload_images(self, image_folder_path):
        """Upload every supported image in the folder; skip failed uploads."""
        uploaded = []
        # Sorted so the upload order (and therefore the prompt) is stable.
        for img_name in sorted(os.listdir(image_folder_path)):
            img_path = os.path.join(image_folder_path, img_name)
            mime_type = self._image_mime_type(img_name)
            if mime_type is None or not os.path.isfile(img_path):
                continue
            try:
                uploaded.append(self.upload_to_gemini(img_path, mime_type=mime_type))
            except Exception as e:
                # Best effort: one failed upload must not abort the others.
                print(f"Error uploading file {img_path}: {e}")
        return uploaded

    def image_download_and_compare(self, text_file_path, image_folder_path):
        """Upload the folder's images and ask Gemini if they match the text.

        Args:
            text_file_path: Path of the text file to compare against.
            image_folder_path: Folder containing ``.png``/``.jpg``/``.jpeg``
                images.

        Returns:
            The model's textual answer, or ``None`` when the text file cannot
            be read, no image could be uploaded, or the chat request fails
            (a message is printed in each of those cases).

        Raises:
            FileNotFoundError: If ``image_folder_path`` does not exist.
        """
        # Read the text from the file.
        try:
            text = self.get_text_from_file(text_file_path)
        except Exception as e:
            print(f"Error reading text file: {e}")
            return None

        # Upload all images from the folder.
        if not os.path.exists(image_folder_path):
            raise FileNotFoundError(f"Image folder not found: {image_folder_path}")
        files = self._upload_images(image_folder_path)
        if not files:
            print("No valid images were uploaded.")
            return None

        # Define the prompt with the text and image URIs.
        prompt = (
            "I have the following text and images. Please determine if the images are related to the text and explain how. "
            f"Text: {text} "
            "Images: " + ", ".join([f.uri for f in files])
        )
        print(files)

        # Start the chat session with a custom prompt.
        try:
            chat_session = self.model.start_chat(
                history=[
                    {
                        "role": "user",
                        # The uploaded file handles are passed as parts so the
                        # model receives the images themselves; a URI written
                        # into the prompt text is only a string to the model.
                        "parts": [prompt, *files],
                    }
                ]
            )
            # Send a message and get the response.
            response = chat_session.send_message(
                "Are these images contextually related to the text?"
            )
            return response.text
        except Exception as e:
            print(f"Error during chat session: {e}")
            return None

    def generate_response(self, text_file_path, image_folder_path):
        """Compare text and images through LlamaIndex's ``GeminiMultiModal``.

        Args:
            text_file_path: Path of the text file to compare against.
            image_folder_path: Folder whose contents are loaded with
                ``SimpleDirectoryReader``.

        Returns:
            The model's textual answer (also printed), or ``None`` when the
            text file cannot be read.
        """
        # Load image documents from the local directory.
        image_documents = SimpleDirectoryReader(image_folder_path).load_data()
        mm_llm = GeminiMultiModal(
            model_name="models/gemini-1.5-flash",
            # Use the key given to the constructor, never a literal in source.
            api_key=self._api_key,
        )

        # Read the text from the file.
        try:
            text = self.get_text_from_file(text_file_path)
        except Exception as e:
            print(f"Error reading text file: {e}")
            return None

        prompt = (
            "I have the following text and images. Please determine if the images are related to the text and explain how. "
            f"Text: {text} "
        )
        response = mm_llm.complete(prompt=prompt, image_documents=image_documents).text
        print(f"Response: {response}")
        return response