Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 102 additions & 11 deletions src/graph/agent/llm_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,19 @@ def optimize_sql_node(state: SQLState) -> SQLState:
" 5) 输出必须是 Calcite 可解析的 SQL。\n"
"\n"
"【输出格式要求】\n"
"请先给出结构化的“分析”部分,逐条说明你如何利用 EXPLAIN 与统计信息进行决策,至少包含:\n"
" - 来自 EXPLAIN 的证据(例如:using_filesort/using_temporary、rows/filtered/attached_condition、possible_keys/used_key/idx 覆盖、Join 类型、驱动表选择等);\n"
" - 来自统计信息的证据(例如:表行数、索引存在性、列选择性/基数、外键/唯一约束等);\n"
" - 每条改写的预期效果与其如何降低 query_cost(如提升过滤选择性、减少回表、避免临时表、移除排序、减少扫描范围等)。\n"
"然后再输出优化后的 SQL 代码块,使用 ```sql ... ```。\n"
"请使用以下格式输出:\n"
"<thinking>\n"
"在这里详细分析EXPLAIN结果和统计信息,逐条说明优化决策:\n"
"- 来自 EXPLAIN 的证据(例如:using_filesort/using_temporary、rows/filtered/attached_condition、possible_keys/used_key/idx 覆盖、Join 类型、驱动表选择等);\n"
"- 来自统计信息的证据(例如:表行数、索引存在性、列选择性/基数、外键/唯一约束等);\n"
"- 每条改写的预期效果与其如何降低 query_cost(如提升过滤选择性、减少回表、避免临时表、移除排序、减少扫描范围等)。\n"
"</thinking>\n"
"<answer>\n"
"优化后的 SQL:\n"
"```sql\n"
"优化后的SQL代码\n"
"```\n"
"</answer>\n"
),
},
{
Expand All @@ -79,11 +87,44 @@ def optimize_sql_node(state: SQLState) -> SQLState:
]
try:
llm = state.get("llm")
content = llm.chat(messages)
explanation = _extract_explanation_from_text(content)
optimized_sql = _extract_sql_from_text(content)
state["rewrite_explanation"] = explanation
state["optimized_sql"] = optimized_sql

# 检查是否是思考模型
settings = state.get("settings")
if settings and "reasoner" in settings.model.lower():
# 使用思考模型的新方法
final_answer, reasoning = llm.chat_with_reasoning(messages)
state["llm_reasoning"] = reasoning
state["rewrite_explanation"] = reasoning
optimized_sql = _extract_sql_from_text(final_answer)
state["optimized_sql"] = optimized_sql

# 将推理过程添加到 messages 中,以便在 LangGraph Studio 中显示
if reasoning:
state.setdefault("messages", []).append({
"role": "assistant",
"content": f"🧠 **SQL优化推理过程**\n\n{reasoning}\n\n---\n\n💡 **优化后的SQL**\n\n```sql\n{optimized_sql}\n```"
})

# 添加调试信息到状态中,确保在Studio中可见
state["debug_info"] = {
"model_type": "reasoner",
"reasoning_extracted": bool(reasoning),
"final_answer_length": len(final_answer) if final_answer else 0
}
else:
# 使用传统方法
content = llm.chat(messages)
explanation = _extract_explanation_from_text(content)
optimized_sql = _extract_sql_from_text(content)
state["rewrite_explanation"] = explanation
state["optimized_sql"] = optimized_sql

# 添加调试信息
state["debug_info"] = {
"model_type": "standard",
"content_length": len(content) if content else 0
}

# state.setdefault("history", []).append("[optimize_sql] 已生成候选改写 SQL与改写说明")
except Exception as e:
# state.setdefault("history", []).append(f"[optimize_sql] 生成候选改写失败:{str(e)}")
Expand Down Expand Up @@ -268,7 +309,57 @@ def generate_optimization_plans(state: SQLState) -> SQLState:

try:
llm = state.get("llm")
content = llm.chat(messages)

# 检查是否是思考模型 - 直接从配置获取
from src.config import get_settings
settings = get_settings()
if settings and "reasoner" in settings.model.lower():
# 使用流式思考模型
reasoning = ""
final_answer = ""

# 流式接收并累积推理过程
for chunk in llm.stream_with_reasoning(messages):
chunk_type = chunk.get("type")

if chunk_type == "reasoning":
reasoning = chunk.get("content", "")
# 实时更新推理内容到state
state["generate_plans_reasoning"] = reasoning

elif chunk_type == "answer":
final_answer = chunk.get("content", "")

elif chunk_type == "done":
reasoning = chunk.get("reasoning", "")
final_answer = chunk.get("answer", "")
break

state["generate_plans_reasoning"] = reasoning
content = final_answer

# 将完整的推理过程添加到 messages 中
if reasoning:
state.setdefault("messages", []).append({
"role": "assistant",
"content": f"🧠 **推理过程分析**\n\n{reasoning}\n\n---\n\n💡 **最终优化方案**\n\n{final_answer}"
})

# 添加调试信息
state["debug_info"] = {
"model_type": "reasoner",
"reasoning_extracted": bool(reasoning),
"final_answer_length": len(final_answer) if final_answer else 0
}
else:
# 使用传统方法
content = llm.chat(messages)

# 添加调试信息
state["debug_info"] = {
"model_type": "standard",
"content_length": len(content) if content else 0
}

# 提取JSON部分
import re
Expand Down
1 change: 1 addition & 0 deletions src/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def build_sqlopt_graph() -> CompiledStateGraph[Any, Any, Any, Any]:
graph.add_node("input", input_node)
graph.add_node("get_plan", get_query_plan_node)
graph.add_node("get_stats", get_stats_node)
# 使用带流式推理的generate_plans节点
graph.add_node("generate_plans", generate_optimization_plans)
graph.add_node("check_equivalence", equivalence_check_node)
graph.add_node("get_costs", get_costs_node)
Expand Down
3 changes: 3 additions & 0 deletions src/graph/state/sqlstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class SQLState(MessagesState, total=False):

optimized_sql: str
rewrite_explanation: str
llm_reasoning: str # 添加推理过程字段
generate_plans_reasoning: str # 添加生成方案的推理过程字段
debug_info: Dict[str, Any] # 添加调试信息字段
equivalence: bool
cost_before: Optional[float]
cost_after: Optional[float]
Expand Down
190 changes: 188 additions & 2 deletions src/llm/langchain_llm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import List, Dict, Optional, TypedDict
from typing import List, Dict, Optional, TypedDict, Generator, Tuple
from src.config import get_settings
from src.llm.client import LLMClient
from langchain.chat_models import init_chat_model
import json
import re
import openai


class LangchainLLMClient(LLMClient):
Expand Down Expand Up @@ -40,6 +43,123 @@ def chat(
content = resp.content if resp.content else ""
return content or ""

def chat_with_reasoning(
self,
messages: List[Dict[str, str]],
temperature: float = 0.2,
max_tokens: Optional[int] = None,
state: Optional[TypedDict] = None,
) -> Tuple[str, str]:
"""
调用思考模型并返回最终答案和推理过程
返回: (最终答案, 推理过程)
"""
settings = get_settings()
if not settings.api_key or settings.api_key == "EMPTY_KEY":
raise RuntimeError("OPENAI_API_KEY 未设置,请先在环境变量或 .env 中配置。")

self._llm.bind(
temperature=temperature,
max_tokens=max_tokens,
)
try:
resp = self._llm.invoke(messages)
except Exception as e:
raise RuntimeError(f"调用 LLM 出错:{e}") from e

# DeepSeek Reasoner 模型会在 reasoning_content 字段返回推理过程
reasoning_content = ""
final_answer = ""

# 检查是否有 reasoning_content 属性(DeepSeek Reasoner 特有)
if hasattr(resp, 'reasoning_content') and resp.reasoning_content:
reasoning_content = resp.reasoning_content
final_answer = resp.content if resp.content else ""
elif hasattr(resp, 'additional_kwargs') and 'reasoning_content' in resp.additional_kwargs:
reasoning_content = resp.additional_kwargs['reasoning_content']
final_answer = resp.content if resp.content else ""
else:
# 如果没有 reasoning_content,尝试从 content 中解析
content = resp.content if resp.content else ""
reasoning_content, final_answer = self._parse_reasoning_output(content)

return final_answer, reasoning_content

def _parse_reasoning_output(self, content: str) -> Tuple[str, str]:
"""
解析思考模型的输出,提取推理过程和最终答案
"""
if not content:
return "", ""

# 尝试多种解析模式
patterns = [
# 模式1: <thinking>...</thinking> <answer>...</answer>
r'<thinking>(.*?)</thinking>\s*<answer>(.*?)</answer>',
# 模式2: 思考:... 答案:...
r'思考:\s*(.*?)\s*答案:\s*(.*?)$',
# 模式3: 推理过程:... 最终答案:...
r'推理过程:\s*(.*?)\s*最终答案:\s*(.*?)$',
# 模式4: 分析:... 结论:...
r'分析:\s*(.*?)\s*结论:\s*(.*?)$',
]

for pattern in patterns:
match = re.search(pattern, content, re.DOTALL | re.IGNORECASE)
if match:
reasoning = match.group(1).strip()
final_answer = match.group(2).strip()
return reasoning, final_answer

# 如果没有找到特定模式,尝试按段落分割
paragraphs = content.split('\n\n')
if len(paragraphs) >= 2:
# 假设最后一段是最终答案,前面的都是推理过程
reasoning = '\n\n'.join(paragraphs[:-1]).strip()
final_answer = paragraphs[-1].strip()
return reasoning, final_answer

# 如果无法分离,返回原内容作为答案
return "", content

def chat_stream_with_reasoning(
self,
messages: List[Dict[str, str]],
temperature: float = 0.2,
max_tokens: Optional[int] = None,
state: Optional[TypedDict] = None,
) -> Generator[Dict[str, str], None, None]:
"""
流式调用思考模型,实时返回推理过程
生成器返回: {"type": "reasoning"|"answer", "content": "..."}
"""
settings = get_settings()
if not settings.api_key or settings.api_key == "EMPTY_KEY":
raise RuntimeError("OPENAI_API_KEY 未设置,请先在环境变量或 .env 中配置。")

self._llm.bind(
temperature=temperature,
max_tokens=max_tokens,
)

try:
# 使用流式调用
for chunk in self._llm.stream(messages):
if hasattr(chunk, 'content') and chunk.content:
content = chunk.content

# 尝试判断当前内容类型
if '<thinking>' in content or '思考:' in content or '推理' in content:
yield {"type": "reasoning", "content": content}
elif '<answer>' in content or '答案:' in content or '结论:' in content:
yield {"type": "answer", "content": content}
else:
# 默认作为推理过程
yield {"type": "reasoning", "content": content}

except Exception as e:
raise RuntimeError(f"流式调用 LLM 出错:{e}") from e

async def chat_async(
self,
messages: List[Dict[str, str]],
Expand All @@ -62,4 +182,70 @@ async def chat_async(
except Exception as e:
raise RuntimeError(f"调用 LLM 出错:{e}") from e
content = resp.content if resp.content else ""
return content or ""
return content or ""
def stream_with_reasoning(
self,
messages: List[Dict[str, str]],
temperature: float = 0.2,
max_tokens: Optional[int] = None,
) -> Generator[Dict[str, str], None, None]:
"""
流式调用思考模型,实时返回推理过程(使用OpenAI API)
生成器返回: {"type": "reasoning"|"answer"|"done", "content": "...", "delta": "..."}
"""
settings = get_settings()
if not settings.api_key or settings.api_key == "EMPTY_KEY":
raise RuntimeError("OPENAI_API_KEY 未设置,请先在环境变量或 .env 中配置。")

full_reasoning = ""
full_answer = ""

try:
# 直接使用OpenAI API以获取reasoning_content
client = openai.OpenAI(
api_key=settings.api_key,
base_url=settings.base_url
)

response = client.chat.completions.create(
model=settings.model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
stream=True
)

# 处理流式响应
for chunk in response:
if chunk.choices and len(chunk.choices) > 0:
delta_obj = chunk.choices[0].delta

# 检查reasoning_content (DeepSeek Reasoner特有)
if hasattr(delta_obj, 'reasoning_content') and delta_obj.reasoning_content:
delta = delta_obj.reasoning_content
full_reasoning += delta
yield {
"type": "reasoning",
"delta": delta,
"content": full_reasoning
}

# 检查普通content
if hasattr(delta_obj, 'content') and delta_obj.content:
delta = delta_obj.content
full_answer += delta
yield {
"type": "answer",
"delta": delta,
"content": full_answer
}

# 返回完成信号
yield {
"type": "done",
"reasoning": full_reasoning,
"answer": full_answer
}

except Exception as e:
raise RuntimeError(f"流式调用 LLM 出错:{e}") from e
Loading