-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathoffline_app.py
More file actions
294 lines (239 loc) · 11.2 KB
/
offline_app.py
File metadata and controls
294 lines (239 loc) · 11.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
import json
import pymysql
import openai
import re
import sys
import os
from typing import List, Dict, Any, Tuple
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import jieba
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import EmbeddingRetriever
from haystack.schema import Document
# Configure console output encoding so non-ASCII (Chinese) text prints correctly.
sys.stdout.reconfigure(encoding='utf-8')
sys.stderr.reconfigure(encoding='utf-8')
# Raise jieba's log level (20 == logging.INFO) to reduce its console output.
jieba.setLogLevel(20)  # set to INFO level to reduce output
# A TF-IDF based retriever, optimized for Chinese-language scenarios
class TfidfRetriever:
    """TF-IDF retriever over SQL-template descriptions, tuned for Chinese text.

    Descriptions are segmented with jieba before vectorization so that
    TfidfVectorizer (which tokenizes on whitespace) can handle Chinese.
    """

    def __init__(self, documents):
        """Build the TF-IDF index.

        Args:
            documents: list of dicts, each carrying 'description' and 'sql' keys.
        """
        self.documents = documents
        self.descriptions = [entry['description'] for entry in documents]
        # Pre-segment every description with jieba (precise mode) and join the
        # tokens with spaces so the vectorizer sees word boundaries.
        self.processed_docs = [
            " ".join(jieba.cut(text, cut_all=False)) for text in self.descriptions
        ]
        # Fit the TF-IDF vectorizer on the segmented corpus.
        self.vectorizer = TfidfVectorizer()
        self.tfidf_matrix = self.vectorizer.fit_transform(self.processed_docs)

    def retrieve(self, query, top_k=1):
        """Return the top_k documents most similar to *query*.

        Args:
            query: natural-language query string (Chinese supported).
            top_k: number of results to return.

        Returns:
            List of dicts with 'description', 'sql' and 'score' (cosine
            similarity) keys, ordered best-first.
        """
        # Segment the query exactly like the corpus was segmented.
        tokenized_query = " ".join(jieba.cut(query, cut_all=False))
        query_vector = self.vectorizer.transform([tokenized_query])
        # Cosine similarity of the query against every indexed document.
        scores = cosine_similarity(query_vector, self.tfidf_matrix).flatten()
        # Indices of the highest-scoring documents, best first.
        best_indices = scores.argsort()[-top_k:][::-1]
        return [
            {
                "description": self.descriptions[idx],
                "sql": self.documents[idx]['sql'],
                "score": scores[idx],
            }
            for idx in best_indices
        ]
# Load the NL->SQL template records from disk.
with open('./nlsql.json', 'r', encoding='utf-8') as f:
    items = json.load(f)
print(f"加载了 {len(items)} 条SQL模板记录")
# Build the TF-IDF retriever over the loaded templates.
tfidf_retriever = TfidfRetriever(items)
# Initialize the document store (the Haystack implementation is kept as an
# optional alternative retrieval path).
document_store = InMemoryDocumentStore(embedding_dim=384)  # bge-small model outputs 384-dim embeddings
# Build Haystack Document objects from the templates; the SQL travels in meta.
documents = []
for item in items:
    doc = Document(
        content=item['description'],
        meta={"sql": item["sql"]}
    )
    documents.append(doc)
# Write all documents into the document store.
document_store.write_documents(documents)
# Only enable Haystack retrieval if the local embedding model exists on disk
# (this script is meant to run offline, so no model download is attempted).
local_model_path = "./models/sentence-transformers/all-MiniLM-L6-v2"
use_haystack = os.path.exists(local_model_path)
if use_haystack:
    # Initialize the retriever with the local embedding model (CPU only).
    haystack_retriever = EmbeddingRetriever(
        document_store=document_store,
        embedding_model=local_model_path,
        model_format="sentence_transformers",
        use_gpu=False
    )
    # Compute and persist embeddings for every stored document.
    document_store.update_embeddings(haystack_retriever)
    print("Haystack向量检索器初始化完成")
else:
    print("未找到本地模型,将仅使用TF-IDF检索")
# Hybrid query function
def query_sql(nl_query: str, top_k: int = 1) -> List[Dict[str, str]]:
    """Retrieve the SQL templates most relevant to a natural-language query.

    Always runs TF-IDF retrieval; when a local embedding model is available
    (``use_haystack``), also runs Haystack embedding retrieval, merges both
    result sets, de-duplicates by description keeping the higher score, and
    returns the top_k results sorted by score (descending).

    Args:
        nl_query: natural-language description of the desired query.
        top_k: maximum number of results to return.

    Returns:
        List of dicts with 'description', 'sql' and 'score' keys.
    """
    print(f"\n执行查询: {nl_query}")
    # TF-IDF retrieval (always available).
    tfidf_results = tfidf_retriever.retrieve(nl_query, top_k=top_k)
    if use_haystack:
        # Embedding-based retrieval via Haystack.
        haystack_docs = haystack_retriever.retrieve(nl_query, top_k=top_k)
        haystack_results = []
        for doc in haystack_docs:
            haystack_results.append({
                "description": doc.content,
                "sql": doc.meta.get("sql", ""),
                # FIX: doc.score may be None (Haystack allows unscored docs);
                # fall back to 0.0 so the sort below never compares None
                # against a float and raises TypeError.
                "score": doc.score if doc.score is not None else 0.0
            })
        # Merge results from both methods and re-rank by score.
        # Simple strategy: concatenate, then sort on the (method-specific)
        # scores — NOTE(review): TF-IDF and embedding scores are on different
        # scales, so this ranking is heuristic.
        all_results = tfidf_results + haystack_results
        # De-duplicate by description, keeping the higher-scoring entry.
        unique_results = {}
        for res in all_results:
            existing = unique_results.get(res["description"])
            if existing is None or res["score"] > existing["score"]:
                unique_results[res["description"]] = res
        all_results = list(unique_results.values())
        all_results.sort(key=lambda x: x["score"], reverse=True)
        # Keep only the top_k results.
        results = all_results[:top_k]
    else:
        # No Haystack available: use the TF-IDF results directly.
        results = tfidf_results
    # Print each retrieved result with its similarity score.
    for i, res in enumerate(results):
        print(f"结果 {i+1},相似度分数: {res.get('score', 0):.4f}")
        print(f"描述: {res['description']}")
        print(f"SQL: {res['sql']}\n")
    return results
# Rewrite SQL with an OpenAI-compatible API (works with Ollama)
def openai_rewrite_sql(user_query: str, retrieved_sql: str, model="deepseek-r1-distill-qwen-32b-dzhbqijn",
                       api_key='none', base_url='http://192.168.0.7:11434/v1') -> str:
    """Rewrite a retrieved SQL template to match the user's query intent.

    Sends a few-shot prompt to an OpenAI-compatible chat endpoint (e.g. a
    local Ollama server), then aggressively cleans the response: strips
    deepseek-style ``</think>`` reasoning, markdown fences, and "SQL:"
    prefixes, and validates that a SELECT statement remains.

    Args:
        user_query: the user's natural-language request.
        retrieved_sql: the template SQL to be adapted.
        model: chat model name on the endpoint.
        api_key: API key for the endpoint (Ollama ignores it; previously
            hard-coded, now parameterized with the same default).
        base_url: endpoint base URL (previously hard-coded, now
            parameterized with the same default).

    Returns:
        The rewritten SQL, or ``retrieved_sql`` unchanged if the API call
        fails or no valid SQL could be extracted (offline fallback).
    """
    try:
        prompt = f"""
你是SQL专家。你需要修改下面的SQL来满足用户的查询需求。
用户查询: {user_query}
原始SQL: {retrieved_sql}
具体要求:
1. 对SQL进行必要修改,使其符合用户的查询意图(如更改地区、时间等参数)
2. 保持SQL的整体结构不变
3. 只返回修改后的SQL语句,不要有任何解释、注释、思考过程或其他内容
4. 不要使用markdown格式
5. 不要添加"SELECT"、"SQL:"等前缀
6. 确保SQL语法正确
7. 禁止输出任何思考过程或推理分析
下面是几个示例:
示例1:
用户查询: 查询北京2024年3月的案件明细
原始SQL: select a.* from t_case_management a where substring(a.crime_date,1,4) ='2025' and substring(a.crime_date,6,2) = '02' and a.reporting_unit like '%石家庄%'
修改后SQL: select a.* from t_case_management a where substring(a.crime_date,1,4) ='2024' and substring(a.crime_date,6,2) = '03' and a.reporting_unit like '%北京%'
示例2:
用户查询: 查询上海2023年10月至12月的案件数量
原始SQL: select b.reporting_unit, count(*) sl from (select a.* from t_case_management a where substring(a.crime_date,1,4) ='2025' and substring(a.crime_date,6,2) = '02' and a.reporting_unit like '%石家庄%') b group by b.reporting_unit
修改后SQL: select b.reporting_unit, count(*) sl from (select a.* from t_case_management a where substring(a.crime_date,1,4) ='2023' and substring(a.crime_date,6,2) in ('10','11','12') and a.reporting_unit like '%上海%') b group by b.reporting_unit
现在,请根据用户查询修改原始SQL。直接输出修改后的SQL,不要有任何思考过程:
"""
        client = openai.OpenAI(
            api_key=api_key,
            base_url=base_url
        )
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "system", "content": "你是一个SQL专家,只输出SQL语句,不要有任何思考过程或解释。"},
                      {"role": "user", "content": prompt}],
            temperature=0.1,  # low temperature for more deterministic output
            max_tokens=500
        )
        # Read the response text and trim surrounding whitespace.
        sql = response.choices[0].message.content.strip()
        # Check for a </think> marker (deepseek models emit their reasoning
        # before it); keep only what follows the last marker.
        if "</think>" in sql:
            sql = sql.split("</think>")[-1].strip()
            # If no usable SQL remains after splitting, fall back to a regex
            # extraction of the last complete SELECT statement.
            if not sql or not re.search(r'select\s+.+\s+from\s+', sql, re.IGNORECASE):
                sql_pattern = r'select\s+.+?from\s+.+?(?:where\s+.+?)?(?:group\s+by\s+.+?)?(?:order\s+by\s+.+?)?(?:limit\s+\d+)?'
                matches = list(re.finditer(sql_pattern, sql, re.IGNORECASE | re.DOTALL))
                if matches:
                    # Use the last matching SQL statement.
                    sql = matches[-1].group(0)
                else:
                    print("未能提取出SQL语句,使用原始SQL")
                    return retrieved_sql
        # Strip possible markdown fences and "SQL:" prefixes from the result.
        sql = re.sub(r'^```sql\s*', '', sql)
        sql = re.sub(r'\s*```$', '', sql)
        sql = re.sub(r'^SQL:\s*', '', sql, flags=re.IGNORECASE)
        # If the remaining statement is not a complete SELECT, keep the original.
        if not re.search(r'select\s+.+\s+from\s+', sql, re.IGNORECASE):
            print("提取的SQL语句不完整,使用原始SQL")
            return retrieved_sql
        return sql.strip()
    except Exception as e:
        print(f"大模型API调用失败: {e}")
        print("离线环境下使用原始SQL查询")
        # Offline fallback: return the original SQL unchanged.
        return retrieved_sql
# MySQL query function
def query_mysql(sql: str, host="192.168.0.7", port=3306, user="baizh", password="bai12345", database="baizh") -> Tuple[List[str], List[Tuple]]:
    """Execute *sql* against a MySQL database and return (column names, rows).

    The connection is always closed, even if execution raises.
    """
    connection = pymysql.connect(
        host=host,
        port=port,
        user=user,
        password=password,
        database=database,
        charset="utf8mb4",
    )
    try:
        with connection.cursor() as cur:
            cur.execute(sql)
            rows = cur.fetchall()
            column_names = [col[0] for col in cur.description]
        return column_names, rows
    finally:
        connection.close()
# Interactive REPL: read a natural-language query, retrieve the closest SQL
# template, let the LLM adapt it, then execute it against MySQL.
if __name__ == "__main__":
    try:
        print("====== 自然语言SQL检索系统 ======")
        print("已使用TF-IDF + jieba分词优化中文检索")
        while True:
            try:
                q = input("\n请输入案件查询描述 (输入q退出):")
                if q.lower() == 'q':
                    break
                results = query_sql(q, top_k=2)  # retrieve the 2 most relevant results
                if not results:
                    print("未找到相关的SQL模板")
                    continue
                # Use the single most relevant result.
                res = results[0]
                print(f"选择最相关描述: {res['description']}")
                try:
                    # Rewrite the SQL via the OpenAI-compatible API.
                    new_sql = openai_rewrite_sql(q, res['sql'], model="qwen2.5:14b")
                    print(f"\n大模型智能改写后SQL:\n{new_sql}")
                    # Execute the SQL and print the result rows.
                    try:
                        columns, result = query_mysql(new_sql)
                        print("\n查询结果:")
                        print(columns)
                        for row in result:
                            print(row)
                    except Exception as e:
                        print(f"SQL执行出错: {e}")
                except KeyboardInterrupt:
                    print("\n用户中断了操作")
                    continue
            except EOFError:
                print("\n检测到输入中断,程序退出")
                break
            except KeyboardInterrupt:
                print("\n用户中断了程序")
                break
    except Exception as e:
        print(f"程序执行出错: {e}")