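"""Evaluation harness for the NL-to-SQL pipeline.

Reads {"NL": ..., "Query": ...} pairs from a JSON file, runs each
natural-language question through the RAG + LLM stages in
single_pipeline, and reports how many generated SQL queries exactly
match the expected ones after normalization.
"""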
import os
import json
import time
import threading

import single_pipeline
import RAG.embedding_creator as search_api

# Note: defined but currently unused; __main__ reads sql_file_path instead.
PATH_TO_SQL_QUERY_TESTING_FILE = "/Users/shivammishra/Desktop/Adobe/train_generate_task.json"
def start_single_pipeline(nl_query):
    """
    Run the full pipeline for a single NL query.
    Returns the generated SQL query as a string.
    """
    print(f"Processing query: {nl_query}")
    start_time = time.time()

    # Getting Relevant Tables [RAG]
    relevant_tables = search_api.search_tables(nl_query)
    # Getting Relevant Tables [LLM]
    relevant_tables, table_agent_token_count = single_pipeline.table_agent(nl_query, relevant_tables)
    # Getting Relevant Columns [LLM]
    relevant_schema, column_agent_token_count = single_pipeline.prune_agent(nl_query, relevant_tables)
    # Getting Relevant Sample Queries [RAG]
    relevant_sample_queries = search_api.search_similar_query(nl_query)
    # Getting SQL Query [LLM]
    sql_query, sql_agent_token_count = single_pipeline.final_sql_query_generator(nl_query, relevant_schema, relevant_sample_queries)

    print("-----------")
    print("Table Creation Token Count:", table_agent_token_count)
    print("Column Pruning Token Count:", column_agent_token_count)
    print("SQL Generation Token Count:", sql_agent_token_count)
    print("-----------")

    file_path = "result.json"
    # Read existing data
    if os.path.exists(file_path) and os.path.getsize(file_path) > 0:
        with open(file_path, "r") as f:
            try:
                data = json.load(f)
            except json.JSONDecodeError:
                data = []
    else:
        data = []
    # Append new entry
    data.append({"NL": nl_query, "Query": sql_query})
    # Write back to file
    with open(file_path, "w") as f:
        json.dump(data, f, indent=4)

    end_time = time.time()
    print("-----------Performance Metrics-----------")
    print(f"Time taken: {end_time - start_time:.2f} seconds")
    print(f"Total Tokens: {table_agent_token_count + column_agent_token_count + sql_agent_token_count}")
    print(f"SQL Query: {sql_query}")
    print("----------------------------")
    return sql_query
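
# result.json accumulates one entry per processed query, mirroring the
# input format, e.g.:
# [
#     {"NL": "...question...", "Query": "...generated SQL..."}
# ]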

def normalize_string(s):
    """Lowercase and strip leading/trailing whitespace."""
    return s.lower().strip()


def check_query_match(expected_query, generated_query):
    """Compare normalized queries for an exact match."""
    return normalize_string(expected_query) == normalize_string(generated_query)
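
# Exact-match scoring: normalization only lowercases and trims, so two
# semantically equivalent queries that differ in internal spacing or
# aliasing still count as a mismatch, e.g.:
#   check_query_match("SELECT * FROM t", "select * from t ")  -> True
#   check_query_match("SELECT * FROM t", "SELECT  *  FROM t") -> False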

def process_single_json_object(json_obj):
    """Process a single JSON object and return whether the generated query matches."""
    nl_query = json_obj.get("NL")
    expected_query = json_obj.get("Query")
    generated_query = start_single_pipeline(nl_query)
    print("-" * 30)
    print(f"Expected: {expected_query}")
    print(f"Generated: {generated_query}")
    print("-" * 30)
    return check_query_match(expected_query, generated_query)


def process_json_objects(file_path):
    """Read a JSON file and return the list of objects."""
    with open(file_path, "r", encoding="utf-8") as file:
        return json.load(file)
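
# Expected file shape (inferred from the "NL"/"Query" keys used above):
# [
#     {"NL": "natural-language question", "Query": "SELECT ..."},
#     ...
# ]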

def process_queries_multithreaded(file_path, max_threads=16):
    """Process queries using multithreading.

    Note: start_single_pipeline's read-modify-write of result.json is not
    protected by the lock below, so concurrent workers can race on that file.
    """
    json_objects = process_json_objects(file_path)
    threads = []
    successful_matches = 0
    unsuccessful_matches = 0
    lock = threading.Lock()

    def worker(json_obj):
        nonlocal successful_matches, unsuccessful_matches
        if process_single_json_object(json_obj):
            with lock:
                successful_matches += 1
        else:
            with lock:
                unsuccessful_matches += 1

    for json_obj in json_objects:
        thread = threading.Thread(target=worker, args=(json_obj,))
        threads.append(thread)
        thread.start()
        # Once max_threads are running, wait for the whole batch before starting more.
        if len(threads) >= max_threads:
            for t in threads:
                t.join()
            threads = []
    # Wait for any remaining threads in the final partial batch.
    for t in threads:
        t.join()

    print("\n=== Multithreaded Processing ===")
    print(f"Successful Matches: {successful_matches}")
    print(f"Unsuccessful Matches: {unsuccessful_matches}")

def process_queries_linear(file_path):
    """Process queries one by one in a normal linear way."""
    json_objects = process_json_objects(file_path)
    successful_matches = 0
    unsuccessful_matches = 0
    for json_obj in json_objects:
        if process_single_json_object(json_obj):
            successful_matches += 1
        else:
            unsuccessful_matches += 1
        # Print live updates
        print(f"Processed: {successful_matches + unsuccessful_matches}, "
              f"Successful: {successful_matches}, Unsuccessful: {unsuccessful_matches}")

    print("\n=== Linear Processing ===")
    print(f"Successful Matches: {successful_matches}")
    print(f"Unsuccessful Matches: {unsuccessful_matches}")

if __name__ == "__main__":
    sql_file_path = 'other/sample_submission_generate_task.json'
    # print("\nRunning Multithreaded Processing...")
    # process_queries_multithreaded(sql_file_path)
    print("\nRunning Linear Processing...")
    process_queries_linear(sql_file_path)