diff --git a/src/graph/agent/llm_nodes.py b/src/graph/agent/llm_nodes.py
index e314ed8..d2f3cb0 100644
--- a/src/graph/agent/llm_nodes.py
+++ b/src/graph/agent/llm_nodes.py
@@ -59,11 +59,19 @@ def optimize_sql_node(state: SQLState) -> SQLState:
" 5) 输出必须是 Calcite 可解析的 SQL。\n"
"\n"
"【输出格式要求】\n"
- "请先给出结构化的“分析”部分,逐条说明你如何利用 EXPLAIN 与统计信息进行决策,至少包含:\n"
- " - 来自 EXPLAIN 的证据(例如:using_filesort/using_temporary、rows/filtered/attached_condition、possible_keys/used_key/idx 覆盖、Join 类型、驱动表选择等);\n"
- " - 来自统计信息的证据(例如:表行数、索引存在性、列选择性/基数、外键/唯一约束等);\n"
- " - 每条改写的预期效果与其如何降低 query_cost(如提升过滤选择性、减少回表、避免临时表、移除排序、减少扫描范围等)。\n"
- "然后再输出优化后的 SQL 代码块,使用 ```sql ... ```。\n"
+ "请使用以下格式输出:\n"
+ "\n"
+ "在这里详细分析EXPLAIN结果和统计信息,逐条说明优化决策:\n"
+ "- 来自 EXPLAIN 的证据(例如:using_filesort/using_temporary、rows/filtered/attached_condition、possible_keys/used_key/idx 覆盖、Join 类型、驱动表选择等);\n"
+ "- 来自统计信息的证据(例如:表行数、索引存在性、列选择性/基数、外键/唯一约束等);\n"
+ "- 每条改写的预期效果与其如何降低 query_cost(如提升过滤选择性、减少回表、避免临时表、移除排序、减少扫描范围等)。\n"
+ "\n"
+ "\n"
+ "优化后的 SQL:\n"
+ "```sql\n"
+ "优化后的SQL代码\n"
+ "```\n"
+ "\n"
),
},
{
@@ -79,11 +87,44 @@ def optimize_sql_node(state: SQLState) -> SQLState:
]
try:
llm = state.get("llm")
- content = llm.chat(messages)
- explanation = _extract_explanation_from_text(content)
- optimized_sql = _extract_sql_from_text(content)
- state["rewrite_explanation"] = explanation
- state["optimized_sql"] = optimized_sql
+
+ # 检查是否是思考模型
+ settings = state.get("settings")
+ if settings and "reasoner" in settings.model.lower():
+ # 使用思考模型的新方法
+ final_answer, reasoning = llm.chat_with_reasoning(messages)
+ state["llm_reasoning"] = reasoning
+ state["rewrite_explanation"] = reasoning
+ optimized_sql = _extract_sql_from_text(final_answer)
+ state["optimized_sql"] = optimized_sql
+
+ # 将推理过程添加到 messages 中,以便在 LangGraph Studio 中显示
+ if reasoning:
+ state.setdefault("messages", []).append({
+ "role": "assistant",
+ "content": f"🧠 **SQL优化推理过程**\n\n{reasoning}\n\n---\n\n💡 **优化后的SQL**\n\n```sql\n{optimized_sql}\n```"
+ })
+
+ # 添加调试信息到状态中,确保在Studio中可见
+ state["debug_info"] = {
+ "model_type": "reasoner",
+ "reasoning_extracted": bool(reasoning),
+ "final_answer_length": len(final_answer) if final_answer else 0
+ }
+ else:
+ # 使用传统方法
+ content = llm.chat(messages)
+ explanation = _extract_explanation_from_text(content)
+ optimized_sql = _extract_sql_from_text(content)
+ state["rewrite_explanation"] = explanation
+ state["optimized_sql"] = optimized_sql
+
+ # 添加调试信息
+ state["debug_info"] = {
+ "model_type": "standard",
+ "content_length": len(content) if content else 0
+ }
+
# state.setdefault("history", []).append("[optimize_sql] 已生成候选改写 SQL与改写说明")
except Exception as e:
# state.setdefault("history", []).append(f"[optimize_sql] 生成候选改写失败:{str(e)}")
@@ -268,7 +309,57 @@ def generate_optimization_plans(state: SQLState) -> SQLState:
try:
llm = state.get("llm")
- content = llm.chat(messages)
+
+ # 检查是否是思考模型 - 直接从配置获取
+ from src.config import get_settings
+ settings = get_settings()
+ if settings and "reasoner" in settings.model.lower():
+ # 使用流式思考模型
+ reasoning = ""
+ final_answer = ""
+
+ # 流式接收并累积推理过程
+ for chunk in llm.stream_with_reasoning(messages):
+ chunk_type = chunk.get("type")
+
+ if chunk_type == "reasoning":
+ reasoning = chunk.get("content", "")
+ # 实时更新推理内容到state
+ state["generate_plans_reasoning"] = reasoning
+
+ elif chunk_type == "answer":
+ final_answer = chunk.get("content", "")
+
+ elif chunk_type == "done":
+ reasoning = chunk.get("reasoning", "")
+ final_answer = chunk.get("answer", "")
+ break
+
+ state["generate_plans_reasoning"] = reasoning
+ content = final_answer
+
+ # 将完整的推理过程添加到 messages 中
+ if reasoning:
+ state.setdefault("messages", []).append({
+ "role": "assistant",
+ "content": f"🧠 **推理过程分析**\n\n{reasoning}\n\n---\n\n💡 **最终优化方案**\n\n{final_answer}"
+ })
+
+ # 添加调试信息
+ state["debug_info"] = {
+ "model_type": "reasoner",
+ "reasoning_extracted": bool(reasoning),
+ "final_answer_length": len(final_answer) if final_answer else 0
+ }
+ else:
+ # 使用传统方法
+ content = llm.chat(messages)
+
+ # 添加调试信息
+ state["debug_info"] = {
+ "model_type": "standard",
+ "content_length": len(content) if content else 0
+ }
# 提取JSON部分
import re
diff --git a/src/graph/graph.py b/src/graph/graph.py
index 95e19c0..697d3a5 100644
--- a/src/graph/graph.py
+++ b/src/graph/graph.py
@@ -203,6 +203,7 @@ def build_sqlopt_graph() -> CompiledStateGraph[Any, Any, Any, Any]:
graph.add_node("input", input_node)
graph.add_node("get_plan", get_query_plan_node)
graph.add_node("get_stats", get_stats_node)
+ # Register the generate_plans node (it streams model reasoning when a reasoner model is configured)
graph.add_node("generate_plans", generate_optimization_plans)
graph.add_node("check_equivalence", equivalence_check_node)
graph.add_node("get_costs", get_costs_node)
diff --git a/src/graph/state/sqlstate.py b/src/graph/state/sqlstate.py
index b40be48..d86422e 100644
--- a/src/graph/state/sqlstate.py
+++ b/src/graph/state/sqlstate.py
@@ -36,6 +36,9 @@ class SQLState(MessagesState, total=False):
optimized_sql: str
rewrite_explanation: str
+ llm_reasoning: str # 添加推理过程字段
+ generate_plans_reasoning: str # 添加生成方案的推理过程字段
+ debug_info: Dict[str, Any] # 添加调试信息字段
equivalence: bool
cost_before: Optional[float]
cost_after: Optional[float]
diff --git a/src/llm/langchain_llm.py b/src/llm/langchain_llm.py
index 7b6ceba..6caca60 100644
--- a/src/llm/langchain_llm.py
+++ b/src/llm/langchain_llm.py
@@ -1,7 +1,10 @@
-from typing import List, Dict, Optional, TypedDict
+from typing import List, Dict, Optional, TypedDict, Generator, Tuple
from src.config import get_settings
from src.llm.client import LLMClient
from langchain.chat_models import init_chat_model
+import json
+import re
+import openai
class LangchainLLMClient(LLMClient):
@@ -40,6 +43,123 @@ def chat(
content = resp.content if resp.content else ""
return content or ""
+    def chat_with_reasoning(
+        self,
+        messages: List[Dict[str, str]],
+        temperature: float = 0.2,
+        max_tokens: Optional[int] = None,
+        state: Optional[TypedDict] = None,
+    ) -> Tuple[str, str]:
+        """Call the reasoning model once and split its reply.
+
+        Args:
+            messages: Chat messages passed unchanged to the model.
+            temperature: Sampling temperature bound onto the call.
+            max_tokens: Completion limit; only bound when not None.
+            state: Unused here; kept so existing callers keep working.
+        Returns:
+            ``(final_answer, reasoning)``; either may be an empty string.
+        Raises:
+            RuntimeError: If the API key is missing or the call fails.
+        """
+        settings = get_settings()
+        if not settings.api_key or settings.api_key == "EMPTY_KEY":
+            raise RuntimeError("OPENAI_API_KEY 未设置,请先在环境变量或 .env 中配置。")
+        # bind() returns a new runnable instead of mutating self._llm, so the
+        # bound object must be the one invoked or the options are dropped.
+        bind_kwargs = {"temperature": temperature}
+        if max_tokens is not None:
+            bind_kwargs["max_tokens"] = max_tokens
+        try:
+            resp = self._llm.bind(**bind_kwargs).invoke(messages)
+        except Exception as e:
+            raise RuntimeError(f"调用 LLM 出错:{e}") from e
+        content = resp.content if resp.content else ""
+        # Prefer an explicit reasoning_content field: first as an attribute,
+        # then inside additional_kwargs; otherwise parse it out of the text.
+        if getattr(resp, "reasoning_content", None):
+            return content, resp.reasoning_content
+        extras = getattr(resp, "additional_kwargs", None) or {}
+        if "reasoning_content" in extras:
+            return content, extras["reasoning_content"] or ""
+        reasoning_content, final_answer = self._parse_reasoning_output(content)
+        return final_answer, reasoning_content
+
+    def _parse_reasoning_output(self, content: str) -> Tuple[str, str]:
+        """Split raw model text into ``(reasoning, final_answer)``.
+
+        Tries explicit markers first, then falls back to treating the last
+        blank-line-separated paragraph as the answer. Text that cannot be
+        split is returned whole as the answer with empty reasoning.
+        """
+        if not content:
+            return "", ""
+
+        # Every pattern needs literal delimiters and an end anchor: a pattern
+        # made only of lazy groups matches the empty string at position 0 and
+        # would silently discard the whole reply.
+        patterns = [
+            # NOTE(review): <think> tags assumed (DeepSeek-R1 style) - confirm.
+            r'<think>(.*?)</think>\s*(.*)$',
+            # "思考: ... 答案: ..." (thinking / answer)
+            r'思考:\s*(.*?)\s*答案:\s*(.*?)$',
+            # "推理过程: ... 最终答案: ..." (reasoning process / final answer)
+            r'推理过程:\s*(.*?)\s*最终答案:\s*(.*?)$',
+            # "分析: ... 结论: ..." (analysis / conclusion)
+            r'分析:\s*(.*?)\s*结论:\s*(.*?)$',
+        ]
+
+        for pattern in patterns:
+            match = re.search(pattern, content, re.DOTALL | re.IGNORECASE)
+            if match:
+                return match.group(1).strip(), match.group(2).strip()
+
+        # No marker found: assume the last paragraph is the answer and
+        # everything before it is reasoning.
+        paragraphs = content.split('\n\n')
+        if len(paragraphs) >= 2:
+            return '\n\n'.join(paragraphs[:-1]).strip(), paragraphs[-1].strip()
+
+        return "", content
+
+    def chat_stream_with_reasoning(
+        self,
+        messages: List[Dict[str, str]],
+        temperature: float = 0.2,
+        max_tokens: Optional[int] = None,
+        state: Optional[TypedDict] = None,
+    ) -> Generator[Dict[str, str], None, None]:
+        """Stream the model reply, tagging each chunk by its markers.
+
+        Yields ``{"type": "reasoning" | "answer", "content": chunk_text}``
+        for every non-empty chunk. Classification is per chunk: a chunk
+        with no marker is reported as reasoning. ``state`` is unused.
+        Raises:
+            RuntimeError: If the API key is missing or streaming fails.
+        """
+        settings = get_settings()
+        if not settings.api_key or settings.api_key == "EMPTY_KEY":
+            raise RuntimeError("OPENAI_API_KEY 未设置,请先在环境变量或 .env 中配置。")
+        # bind() returns a new runnable; stream from it so the options apply.
+        bind_kwargs = {"temperature": temperature}
+        if max_tokens is not None:
+            bind_kwargs["max_tokens"] = max_tokens
+        # An empty-string marker is contained in every chunk, which made the
+        # answer branch unreachable, so each marker must be non-empty.
+        # NOTE(review): <think> tags assumed (DeepSeek-R1 style) - confirm.
+        reasoning_marks = ('<think>', '思考:', '推理')
+        answer_marks = ('</think>', '答案:', '结论:')
+        try:
+            for chunk in self._llm.bind(**bind_kwargs).stream(messages):
+                content = getattr(chunk, 'content', None)
+                if not content:
+                    continue
+                answer_only = (any(m in content for m in answer_marks)
+                               and not any(m in content for m in reasoning_marks))
+                yield {"type": "answer" if answer_only else "reasoning", "content": content}
+        except Exception as e:
+            raise RuntimeError(f"流式调用 LLM 出错:{e}") from e
+
async def chat_async(
self,
messages: List[Dict[str, str]],
@@ -62,4 +182,70 @@ async def chat_async(
except Exception as e:
raise RuntimeError(f"调用 LLM 出错:{e}") from e
content = resp.content if resp.content else ""
- return content or ""
\ No newline at end of file
+ return content or ""
+ def stream_with_reasoning(
+ self,
+ messages: List[Dict[str, str]],
+ temperature: float = 0.2,
+ max_tokens: Optional[int] = None,
+ ) -> Generator[Dict[str, str], None, None]:
+ """
+ Stream a reasoning-model reply through the OpenAI client, yielding each delta as it arrives.
+ Yields {"type": "reasoning"|"answer", "delta": "...", "content": "<text so far>"}, then one {"type": "done", "reasoning": "...", "answer": "..."}.
+ """
+ settings = get_settings()
+ if not settings.api_key or settings.api_key == "EMPTY_KEY":
+ raise RuntimeError("OPENAI_API_KEY 未设置,请先在环境变量或 .env 中配置。")
+
+ full_reasoning = ""
+ full_answer = ""
+
+ try:
+ # Use the OpenAI client directly in order to read reasoning_content
+ client = openai.OpenAI(
+ api_key=settings.api_key,
+ base_url=settings.base_url
+ )
+
+ response = client.chat.completions.create(
+ model=settings.model,
+ messages=messages,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ stream=True
+ )
+
+ # Consume the streamed response
+ for chunk in response:
+ if chunk.choices and len(chunk.choices) > 0:
+ delta_obj = chunk.choices[0].delta
+
+ # reasoning_content delta (per the original note, specific to DeepSeek Reasoner)
+ if hasattr(delta_obj, 'reasoning_content') and delta_obj.reasoning_content:
+ delta = delta_obj.reasoning_content
+ full_reasoning += delta
+ yield {
+ "type": "reasoning",
+ "delta": delta,
+ "content": full_reasoning
+ }
+
+ # Regular content delta
+ if hasattr(delta_obj, 'content') and delta_obj.content:
+ delta = delta_obj.content
+ full_answer += delta
+ yield {
+ "type": "answer",
+ "delta": delta,
+ "content": full_answer
+ }
+
+ # Final event carrying the full accumulated reasoning and answer
+ yield {
+ "type": "done",
+ "reasoning": full_reasoning,
+ "answer": full_answer
+ }
+
+ except Exception as e:
+ raise RuntimeError(f"流式调用 LLM 出错:{e}") from e
diff --git a/test_e2e.py b/test_e2e.py
new file mode 100755
index 0000000..def6e07
--- /dev/null
+++ b/test_e2e.py
@@ -0,0 +1,330 @@
+#!/usr/bin/env python3
+"""
+End-to-end test script that drives the complete LangGraph flow.
+Uses the TPC-H schema and TPC-H Query 1.
+"""
+import sys
+import os
+
+# Put this script's directory on sys.path so the src package can be imported
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+from src.graph.graph import build_sqlopt_graph
+from src.graph.state import SQLState
+from src.config import get_settings
+from src.llm.langchain_llm import LangchainLLMClient
+import json
+
+# TPC-H Schema
+TPCH_SCHEMA = """
+CREATE TABLE NATION (
+ N_NATIONKEY INTEGER NOT NULL,
+ N_NAME CHAR(25) NOT NULL,
+ N_REGIONKEY INTEGER NOT NULL,
+ N_COMMENT VARCHAR(152)
+);
+
+CREATE TABLE REGION (
+ R_REGIONKEY INTEGER NOT NULL,
+ R_NAME CHAR(25) NOT NULL,
+ R_COMMENT VARCHAR(152)
+);
+
+CREATE TABLE PART (
+ P_PARTKEY INTEGER NOT NULL,
+ P_NAME VARCHAR(55) NOT NULL,
+ P_MFGR CHAR(25) NOT NULL,
+ P_BRAND CHAR(10) NOT NULL,
+ P_TYPE VARCHAR(25) NOT NULL,
+ P_SIZE INTEGER NOT NULL,
+ P_CONTAINER CHAR(10) NOT NULL,
+ P_RETAILPRICE DECIMAL(15,2) NOT NULL,
+ P_COMMENT VARCHAR(23) NOT NULL
+);
+
+CREATE TABLE SUPPLIER (
+ S_SUPPKEY INTEGER NOT NULL,
+ S_NAME CHAR(25) NOT NULL,
+ S_ADDRESS VARCHAR(40) NOT NULL,
+ S_NATIONKEY INTEGER NOT NULL,
+ S_PHONE CHAR(15) NOT NULL,
+ S_ACCTBAL DECIMAL(15,2) NOT NULL,
+ S_COMMENT VARCHAR(101) NOT NULL
+);
+
+CREATE TABLE PARTSUPP (
+ PS_PARTKEY INTEGER NOT NULL,
+ PS_SUPPKEY INTEGER NOT NULL,
+ PS_AVAILQTY INTEGER NOT NULL,
+ PS_SUPPLYCOST DECIMAL(15,2) NOT NULL,
+ PS_COMMENT VARCHAR(199) NOT NULL
+);
+
+CREATE TABLE CUSTOMER (
+ C_CUSTKEY INTEGER NOT NULL,
+ C_NAME VARCHAR(25) NOT NULL,
+ C_ADDRESS VARCHAR(40) NOT NULL,
+ C_NATIONKEY INTEGER NOT NULL,
+ C_PHONE CHAR(15) NOT NULL,
+ C_ACCTBAL DECIMAL(15,2) NOT NULL,
+ C_MKTSEGMENT CHAR(10) NOT NULL,
+ C_COMMENT VARCHAR(117) NOT NULL
+);
+
+CREATE TABLE ORDERS (
+ O_ORDERKEY INTEGER NOT NULL,
+ O_CUSTKEY INTEGER NOT NULL,
+ O_ORDERSTATUS CHAR(1) NOT NULL,
+ O_TOTALPRICE DECIMAL(15,2) NOT NULL,
+ O_ORDERDATE DATE NOT NULL,
+ O_ORDERPRIORITY CHAR(15) NOT NULL,
+ O_CLERK CHAR(15) NOT NULL,
+ O_SHIPPRIORITY INTEGER NOT NULL,
+ O_COMMENT VARCHAR(79) NOT NULL
+);
+
+CREATE TABLE LINEITEM (
+ L_ORDERKEY INTEGER NOT NULL,
+ L_PARTKEY INTEGER NOT NULL,
+ L_SUPPKEY INTEGER NOT NULL,
+ L_LINENUMBER INTEGER NOT NULL,
+ L_QUANTITY DECIMAL(15,2) NOT NULL,
+ L_EXTENDEDPRICE DECIMAL(15,2) NOT NULL,
+ L_DISCOUNT DECIMAL(15,2) NOT NULL,
+ L_TAX DECIMAL(15,2) NOT NULL,
+ L_RETURNFLAG CHAR(1) NOT NULL,
+ L_LINESTATUS CHAR(1) NOT NULL,
+ L_SHIPDATE DATE NOT NULL,
+ L_COMMITDATE DATE NOT NULL,
+ L_RECEIPTDATE DATE NOT NULL,
+ L_SHIPINSTRUCT CHAR(25) NOT NULL,
+ L_SHIPMODE CHAR(10) NOT NULL,
+ L_COMMENT VARCHAR(44) NOT NULL
+);
+"""
+
+# TPC-H Query 1
+TPCH_QUERY_1 = """
+SELECT
+ l_returnflag,
+ l_linestatus,
+ SUM(l_quantity) AS sum_qty,
+ SUM(l_extendedprice) AS sum_base_price,
+ SUM(l_extendedprice * (1 - l_discount)) AS sum_disc_price,
+ SUM(l_extendedprice * (1 - l_discount) * (1 + l_tax)) AS sum_charge,
+ AVG(l_quantity) AS avg_qty,
+ AVG(l_extendedprice) AS avg_price,
+ AVG(l_discount) AS avg_disc,
+ COUNT(*) AS count_order
+FROM lineitem
+WHERE l_shipdate <= DATE '1998-12-01' - INTERVAL 90 DAY
+GROUP BY l_returnflag, l_linestatus
+ORDER BY l_returnflag, l_linestatus;
+"""
+
+
+def print_section(title):
+ """Print a section title framed by two 100-character '=' rules."""
+ print("\n" + "=" * 100)
+ print(f" {title}")
+ print("=" * 100)
+
+
+def print_subsection(title):
+ """Print a subsection title framed by two 100-character '-' rules."""
+ print("\n" + "-" * 100)
+ print(f" {title}")
+ print("-" * 100)
+
+
+def test_e2e_sql_optimization():
+ """Run the SQL optimization graph end to end on TPC-H Query 1 and print each step."""
+
+ print_section("🚀 端到端测试 - TPC-H Query 1 优化")
+
+ # 1. Load settings
+ print_subsection("1️⃣ 初始化配置")
+ settings = get_settings()
+ print(f"✅ 模型: {settings.model}")
+ print(f"✅ Base URL: {settings.base_url}")
+ print(f"✅ 数据库: {settings.mysql_host}:{settings.mysql_port}/{settings.mysql_database}")
+
+ # 2. Create the LLM client
+ print_subsection("2️⃣ 创建LLM客户端")
+ llm = LangchainLLMClient(
+ model=settings.model,
+ base_url=settings.base_url,
+ api_key=settings.api_key
+ )
+ print(f"✅ LLM客户端创建成功")
+
+ # 3. Build the graph
+ print_subsection("3️⃣ 构建SQL优化图")
+ graph = build_sqlopt_graph()
+ print(f"✅ Graph构建成功")
+
+ # 4. Prepare the input state
+ print_subsection("4️⃣ 准备输入状态")
+ print(f"\n📋 Schema:")
+ print(TPCH_SCHEMA[:500] + "...\n")
+
+ print(f"📝 SQL Query:")
+ print(TPCH_QUERY_1)
+
+ initial_state = SQLState(
+ sql=TPCH_QUERY_1.strip(),
+ db_schema=TPCH_SCHEMA.strip(),
+ max_iterations=2,
+ llm=llm,
+ settings=settings,
+ iteration_count=0,
+ messages=[],
+ optimization_plans=[],
+ current_plan_index=0
+ )
+ print(f"\n✅ 初始状态准备完成")
+
+ # 5. Run the optimization flow
+ print_section("5️⃣ 执行优化流程")
+
+ try:
+ print("\n🔄 开始执行Graph...\n")
+
+ # Stream the graph; each event maps a node name to that node's output
+ step_count = 0
+ for event in graph.stream(initial_state):
+ step_count += 1
+
+ # Print per-node information
+ for node_name, node_output in event.items():
+ print(f"\n🔹 步骤 {step_count}: 节点 [{node_name}]")
+
+ # Show the key fields for each known node
+ if node_name == "input":
+ print(f" ├─ SQL长度: {len(node_output.get('sql', ''))} 字符")
+ print(f" └─ Schema长度: {len(node_output.get('db_schema', ''))} 字符")
+
+ elif node_name == "get_plan":
+ plan = node_output.get("plan", "")
+ print(f" ├─ 执行计划已获取")
+ print(f" └─ 计划长度: {len(plan)} 字符")
+ if plan:
+ print(f"\n 📊 执行计划预览:")
+ print(" " + "\n ".join(plan[:300].split("\n")))
+ if len(plan) > 300:
+ print(" ...")
+
+ elif node_name == "get_stats":
+ stats = node_output.get("stats", {})
+ print(f" ├─ 统计信息已获取")
+ if stats.get("tables"):
+ print(f" ├─ 涉及表: {', '.join(stats['tables'].keys())}")
+ if stats.get("collection_success"):
+ print(f" └─ ✅ 统计信息收集成功")
+ else:
+ print(f" └─ ⚠️ 统计信息收集部分失败")
+
+ elif node_name == "generate_plans":
+ plans = node_output.get("optimization_plans", [])
+ reasoning = node_output.get("generate_plans_reasoning", "")
+
+ print(f" ├─ 生成了 {len(plans)} 个优化方案")
+
+ if reasoning:
+ print(f" ├─ 🧠 推理过程长度: {len(reasoning)} 字符")
+ print(f"\n 🧠 推理过程预览:")
+ print(" " + "\n ".join(reasoning[:500].split("\n")))
+ if len(reasoning) > 500:
+ print(" ...")
+
+ print(f"\n 💡 优化方案:")
+ for i, plan in enumerate(plans, 1):
+ print(f"\n 方案 {i}:")
+ print(f" ├─ 理由: {plan.get('reason', '')[:100]}...")
+ opt_sql = plan.get('optimized_sql', '')
+ print(f" └─ SQL: {opt_sql[:100]}...")
+
+ elif node_name == "check_equivalence":
+ eq = node_output.get("equivalence", False)
+ idx = node_output.get("current_plan_index", 0)
+ print(f" ├─ 方案 {idx + 1} 等价性检查")
+ print(f" └─ {'✅ 等价' if eq else '❌ 不等价'}")
+
+ elif node_name == "get_costs":
+ cost_before = node_output.get("cost_before")
+ cost_after = node_output.get("cost_after")
+ idx = node_output.get("current_plan_index", 0)
+ print(f" ├─ 方案 {idx + 1} 成本估算")
+ if cost_before is not None and cost_after is not None:
+ print(f" ├─ 优化前成本: {cost_before}")
+ print(f" ├─ 优化后成本: {cost_after}")
+ if cost_before > 0:
+ improvement = ((cost_before - cost_after) / cost_before) * 100
+ print(f" └─ 改善: {improvement:.2f}%")
+ else:
+ print(f" └─ ⚠️ 成本估算失败或跳过")
+
+ elif node_name == "next_plan":
+ idx = node_output.get("current_plan_index", 0)
+ print(f" └─ 切换到方案 {idx + 1}")
+
+ elif node_name == "report":
+ final_report = node_output.get("final_report", "")
+ print(f" ├─ 📄 最终报告已生成")
+ print(f" └─ 报告长度: {len(final_report)} 字符")
+
+ print(f"\n✅ Graph执行完成!共 {step_count} 个步骤\n")
+
+ except Exception as e:
+ print(f"\n❌ 执行失败: {e}")
+ import traceback
+ traceback.print_exc()
+ return
+
+ # 6. Show the final result
+ print_section("6️⃣ 最终结果")
+
+ # Last streamed event. NOTE(review): 'event' is unbound here if graph.stream() yielded nothing - confirm that cannot happen.
+ final_state = event
+ if final_state:
+ # Output of the last node (expected to be the report node)
+ for node_name, state in final_state.items():
+ if node_name == "report":
+
+ # Show the optimization plans
+ plans = state.get("optimization_plans", [])
+ if plans:
+ print_subsection("📊 优化方案详情")
+ for i, plan in enumerate(plans, 1):
+ print(f"\n方案 {i}:")
+ print(f" 理由: {plan.get('reason', 'N/A')}")
+ print(f" 等价性: {'✅ 通过' if plan.get('equivalence') else '❌ 未通过'}")
+ print(f" 成本: {plan.get('cost', 'N/A')}")
+ print(f" SQL:")
+ print(" " + "\n ".join(plan.get('optimized_sql', '').split("\n")))
+
+ # Show the reasoning
+ reasoning = state.get("generate_plans_reasoning", "")
+ if reasoning:
+ print_subsection("🧠 完整推理过程")
+ print(reasoning)
+
+ # Show the final report
+ final_report = state.get("final_report", "")
+ if final_report:
+ print_subsection("📄 最终报告")
+ print(final_report)
+
+ print_section("✅ 测试完成")
+
+
+if __name__ == "__main__":
+ try:
+ test_e2e_sql_optimization()
+ except KeyboardInterrupt:
+ print("\n\n⚠️ 测试被用户中断")
+ except Exception as e:
+ print(f"\n\n❌ 测试失败: {e}")
+ import traceback
+ traceback.print_exc()
+
diff --git a/test_reasoning_simple.py b/test_reasoning_simple.py
new file mode 100644
index 0000000..9d2fe39
--- /dev/null
+++ b/test_reasoning_simple.py
@@ -0,0 +1,103 @@
+#!/usr/bin/env python3
+"""
+Simple smoke test for the streaming reasoning API.
+"""
+import sys
+import os
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+from src.llm.langchain_llm import LangchainLLMClient
+from src.config import get_settings
+
+def test_streaming_reasoning():
+ """Stream one prompt through stream_with_reasoning and print per-chunk statistics."""
+ print("=" * 80)
+ print("🧪 测试流式推理功能")
+ print("=" * 80)
+
+ # Load settings and create the client
+ settings = get_settings()
+ print(f"\n✅ 模型: {settings.model}")
+
+ llm = LangchainLLMClient(
+ model=settings.model,
+ base_url=settings.base_url,
+ api_key=settings.api_key
+ )
+ print(f"✅ LLM客户端创建成功\n")
+
+ # A simple reasoning task
+ messages = [
+ {
+ "role": "user",
+ "content": "请分析一下为什么这个SQL查询可能会慢,并给出优化建议:SELECT * FROM orders WHERE order_date > '2024-01-01' ORDER BY total_amount DESC"
+ }
+ ]
+
+ print("🔄 开始流式推理...\n")
+ print("-" * 80)
+
+ try:
+ reasoning_length = 0
+ answer_length = 0
+ chunk_count = 0
+
+ for chunk in llm.stream_with_reasoning(messages):
+ chunk_count += 1
+ chunk_type = chunk.get("type")
+
+ if chunk_type == "reasoning":
+ delta = chunk.get("delta", "")
+ if delta:
+ print(f"🧠 [推理chunk {chunk_count}] {len(delta)} 字符", flush=True)
+ reasoning_length += len(delta)
+ # Preview the first 50 characters
+ preview = delta[:50].replace("\n", " ")
+ print(f" 预览: {preview}...", flush=True)
+
+ elif chunk_type == "answer":
+ delta = chunk.get("delta", "")
+ if delta:
+ print(f"💡 [回答chunk {chunk_count}] {len(delta)} 字符", flush=True)
+ answer_length += len(delta)
+ # Preview the first 50 characters
+ preview = delta[:50].replace("\n", " ")
+ print(f" 预览: {preview}...", flush=True)
+
+ elif chunk_type == "done":
+ print(f"\n✅ 完成!", flush=True)
+ print(f" 总推理长度: {reasoning_length} 字符", flush=True)
+ print(f" 总回答长度: {answer_length} 字符", flush=True)
+ print(f" 总chunk数: {chunk_count}", flush=True)
+
+ full_reasoning = chunk.get("reasoning", "")
+ full_answer = chunk.get("answer", "")
+
+ print("\n" + "=" * 80)
+ print("📋 完整推理过程:")
+ print("=" * 80)
+ print(full_reasoning[:500] + "..." if len(full_reasoning) > 500 else full_reasoning)
+
+ print("\n" + "=" * 80)
+ print("💡 完整回答:")
+ print("=" * 80)
+ print(full_answer[:500] + "..." if len(full_answer) > 500 else full_answer)
+ break
+
+ print("\n" + "=" * 80)
+ print("✅ 测试成功!")
+ print("=" * 80)
+
+ except Exception as e:
+ print(f"\n❌ 测试失败: {e}", flush=True)
+ import traceback
+ traceback.print_exc()
+
+if __name__ == "__main__":
+ test_streaming_reasoning()
+
+
+
+
+
+