diff --git a/src/graph/agent/llm_nodes.py b/src/graph/agent/llm_nodes.py index e314ed8..d2f3cb0 100644 --- a/src/graph/agent/llm_nodes.py +++ b/src/graph/agent/llm_nodes.py @@ -59,11 +59,19 @@ def optimize_sql_node(state: SQLState) -> SQLState: " 5) 输出必须是 Calcite 可解析的 SQL。\n" "\n" "【输出格式要求】\n" - "请先给出结构化的“分析”部分,逐条说明你如何利用 EXPLAIN 与统计信息进行决策,至少包含:\n" - " - 来自 EXPLAIN 的证据(例如:using_filesort/using_temporary、rows/filtered/attached_condition、possible_keys/used_key/idx 覆盖、Join 类型、驱动表选择等);\n" - " - 来自统计信息的证据(例如:表行数、索引存在性、列选择性/基数、外键/唯一约束等);\n" - " - 每条改写的预期效果与其如何降低 query_cost(如提升过滤选择性、减少回表、避免临时表、移除排序、减少扫描范围等)。\n" - "然后再输出优化后的 SQL 代码块,使用 ```sql ... ```。\n" + "请使用以下格式输出:\n" + "\n" + "在这里详细分析EXPLAIN结果和统计信息,逐条说明优化决策:\n" + "- 来自 EXPLAIN 的证据(例如:using_filesort/using_temporary、rows/filtered/attached_condition、possible_keys/used_key/idx 覆盖、Join 类型、驱动表选择等);\n" + "- 来自统计信息的证据(例如:表行数、索引存在性、列选择性/基数、外键/唯一约束等);\n" + "- 每条改写的预期效果与其如何降低 query_cost(如提升过滤选择性、减少回表、避免临时表、移除排序、减少扫描范围等)。\n" + "\n" + "\n" + "优化后的 SQL:\n" + "```sql\n" + "优化后的SQL代码\n" + "```\n" + "\n" ), }, { @@ -79,11 +87,44 @@ def optimize_sql_node(state: SQLState) -> SQLState: ] try: llm = state.get("llm") - content = llm.chat(messages) - explanation = _extract_explanation_from_text(content) - optimized_sql = _extract_sql_from_text(content) - state["rewrite_explanation"] = explanation - state["optimized_sql"] = optimized_sql + + # 检查是否是思考模型 + settings = state.get("settings") + if settings and "reasoner" in settings.model.lower(): + # 使用思考模型的新方法 + final_answer, reasoning = llm.chat_with_reasoning(messages) + state["llm_reasoning"] = reasoning + state["rewrite_explanation"] = reasoning + optimized_sql = _extract_sql_from_text(final_answer) + state["optimized_sql"] = optimized_sql + + # 将推理过程添加到 messages 中,以便在 LangGraph Studio 中显示 + if reasoning: + state.setdefault("messages", []).append({ + "role": "assistant", + "content": f"🧠 **SQL优化推理过程**\n\n{reasoning}\n\n---\n\n💡 **优化后的SQL**\n\n```sql\n{optimized_sql}\n```" + }) + + # 添加调试信息到状态中,确保在Studio中可见 + state["debug_info"] = { + "model_type": "reasoner", + "reasoning_extracted": bool(reasoning), + "final_answer_length": len(final_answer) if final_answer else 0 + } + else: + # 使用传统方法 + content = llm.chat(messages) + explanation = _extract_explanation_from_text(content) + optimized_sql = _extract_sql_from_text(content) + state["rewrite_explanation"] = explanation + state["optimized_sql"] = optimized_sql + + # 添加调试信息 + state["debug_info"] = { + "model_type": "standard", + "content_length": len(content) if content else 0 + } + # state.setdefault("history", []).append("[optimize_sql] 已生成候选改写 SQL与改写说明") except Exception as e: # state.setdefault("history", []).append(f"[optimize_sql] 生成候选改写失败:{str(e)}") @@ -268,7 +309,57 @@ def generate_optimization_plans(state: SQLState) -> SQLState: try: llm = state.get("llm") - content = llm.chat(messages) + + # 检查是否是思考模型 - 直接从配置获取 + from src.config import get_settings + settings = get_settings() + if settings and "reasoner" in settings.model.lower(): + # 使用流式思考模型 + reasoning = "" + final_answer = "" + + # 流式接收并累积推理过程 + for chunk in llm.stream_with_reasoning(messages): + chunk_type = chunk.get("type") + + if chunk_type == "reasoning": + reasoning = chunk.get("content", "") + # 实时更新推理内容到state + state["generate_plans_reasoning"] = reasoning + + elif chunk_type == "answer": + final_answer = chunk.get("content", "") + + elif chunk_type == "done": + reasoning = chunk.get("reasoning", "") + final_answer = chunk.get("answer", "") + break + + state["generate_plans_reasoning"] = reasoning + content = final_answer + + # 将完整的推理过程添加到 messages 中 + if reasoning: + state.setdefault("messages", []).append({ + "role": "assistant", + "content": f"🧠 **推理过程分析**\n\n{reasoning}\n\n---\n\n💡 **最终优化方案**\n\n{final_answer}" + }) + + # 添加调试信息 + state["debug_info"] = { + "model_type": "reasoner", + "reasoning_extracted": bool(reasoning), + "final_answer_length": len(final_answer) if final_answer else 0 + } + else: + # 使用传统方法 + content = llm.chat(messages) + + # 添加调试信息 + state["debug_info"] = { + "model_type": "standard", + "content_length": len(content) if content else 0 + } # 提取JSON部分 import re diff --git a/src/graph/graph.py b/src/graph/graph.py index 95e19c0..697d3a5 100644 --- a/src/graph/graph.py +++ b/src/graph/graph.py @@ -203,6 +203,7 @@ def build_sqlopt_graph() -> CompiledStateGraph[Any, Any, Any, Any]: graph.add_node("input", input_node) graph.add_node("get_plan", get_query_plan_node) graph.add_node("get_stats", get_stats_node) + # 使用带流式推理的generate_plans节点 graph.add_node("generate_plans", generate_optimization_plans) graph.add_node("check_equivalence", equivalence_check_node) graph.add_node("get_costs", get_costs_node) diff --git a/src/graph/state/sqlstate.py b/src/graph/state/sqlstate.py index b40be48..d86422e 100644 --- a/src/graph/state/sqlstate.py +++ b/src/graph/state/sqlstate.py @@ -36,6 +36,9 @@ class SQLState(MessagesState, total=False): optimized_sql: str rewrite_explanation: str + llm_reasoning: str # 添加推理过程字段 + generate_plans_reasoning: str # 添加生成方案的推理过程字段 + debug_info: Dict[str, Any] # 添加调试信息字段 equivalence: bool cost_before: Optional[float] cost_after: Optional[float] diff --git a/src/llm/langchain_llm.py b/src/llm/langchain_llm.py index 7b6ceba..6caca60 100644 --- a/src/llm/langchain_llm.py +++ b/src/llm/langchain_llm.py @@ -1,7 +1,10 @@ -from typing import List, Dict, Optional, TypedDict +from typing import List, Dict, Optional, TypedDict, Generator, Tuple from src.config import get_settings from src.llm.client import LLMClient from langchain.chat_models import init_chat_model +import json +import re +import openai class LangchainLLMClient(LLMClient): @@ -40,6 +43,123 @@ def chat( content = resp.content if resp.content else "" return content or "" + def chat_with_reasoning( + self, + messages: List[Dict[str, str]], + temperature: float = 0.2, + max_tokens: Optional[int] = None, + state: Optional[TypedDict] = None, + ) -> Tuple[str, str]: + """ + 调用思考模型并返回最终答案和推理过程 + 返回: (最终答案, 推理过程) + """ + settings = get_settings() + if not settings.api_key or settings.api_key == "EMPTY_KEY": + raise RuntimeError("OPENAI_API_KEY 未设置,请先在环境变量或 .env 中配置。") + + self._llm.bind( + temperature=temperature, + max_tokens=max_tokens, + ) + try: + resp = self._llm.invoke(messages) + except Exception as e: + raise RuntimeError(f"调用 LLM 出错:{e}") from e + + # DeepSeek Reasoner 模型会在 reasoning_content 字段返回推理过程 + reasoning_content = "" + final_answer = "" + + # 检查是否有 reasoning_content 属性(DeepSeek Reasoner 特有) + if hasattr(resp, 'reasoning_content') and resp.reasoning_content: + reasoning_content = resp.reasoning_content + final_answer = resp.content if resp.content else "" + elif hasattr(resp, 'additional_kwargs') and 'reasoning_content' in resp.additional_kwargs: + reasoning_content = resp.additional_kwargs['reasoning_content'] + final_answer = resp.content if resp.content else "" + else: + # 如果没有 reasoning_content,尝试从 content 中解析 + content = resp.content if resp.content else "" + reasoning_content, final_answer = self._parse_reasoning_output(content) + + return final_answer, reasoning_content + + def _parse_reasoning_output(self, content: str) -> Tuple[str, str]: + """ + 解析思考模型的输出,提取推理过程和最终答案 + """ + if not content: + return "", "" + + # 尝试多种解析模式 + patterns = [ + # 模式1: ... ... + r'(.*?)\s*(.*?)', + # 模式2: 思考:... 答案:... + r'思考:\s*(.*?)\s*答案:\s*(.*?)$', + # 模式3: 推理过程:... 最终答案:... + r'推理过程:\s*(.*?)\s*最终答案:\s*(.*?)$', + # 模式4: 分析:... 结论:... + r'分析:\s*(.*?)\s*结论:\s*(.*?)$', + ] + + for pattern in patterns: + match = re.search(pattern, content, re.DOTALL | re.IGNORECASE) + if match: + reasoning = match.group(1).strip() + final_answer = match.group(2).strip() + return reasoning, final_answer + + # 如果没有找到特定模式,尝试按段落分割 + paragraphs = content.split('\n\n') + if len(paragraphs) >= 2: + # 假设最后一段是最终答案,前面的都是推理过程 + reasoning = '\n\n'.join(paragraphs[:-1]).strip() + final_answer = paragraphs[-1].strip() + return reasoning, final_answer + + # 如果无法分离,返回原内容作为答案 + return "", content + + def chat_stream_with_reasoning( + self, + messages: List[Dict[str, str]], + temperature: float = 0.2, + max_tokens: Optional[int] = None, + state: Optional[TypedDict] = None, + ) -> Generator[Dict[str, str], None, None]: + """ + 流式调用思考模型,实时返回推理过程 + 生成器返回: {"type": "reasoning"|"answer", "content": "..."} + """ + settings = get_settings() + if not settings.api_key or settings.api_key == "EMPTY_KEY": + raise RuntimeError("OPENAI_API_KEY 未设置,请先在环境变量或 .env 中配置。") + + self._llm.bind( + temperature=temperature, + max_tokens=max_tokens, + ) + + try: + # 使用流式调用 + for chunk in self._llm.stream(messages): + if hasattr(chunk, 'content') and chunk.content: + content = chunk.content + + # 尝试判断当前内容类型 + if '' in content or '思考:' in content or '推理' in content: + yield {"type": "reasoning", "content": content} + elif '' in content or '答案:' in content or '结论:' in content: + yield {"type": "answer", "content": content} + else: + # 默认作为推理过程 + yield {"type": "reasoning", "content": content} + + except Exception as e: + raise RuntimeError(f"流式调用 LLM 出错:{e}") from e + async def chat_async( self, messages: List[Dict[str, str]], @@ -62,4 +182,70 @@ async def chat_async( except Exception as e: raise RuntimeError(f"调用 LLM 出错:{e}") from e content = resp.content if resp.content else "" - return content or "" \ No newline at end of file + return content or "" + def stream_with_reasoning( + self, + messages: List[Dict[str, str]], + temperature: float = 0.2, + max_tokens: Optional[int] = None, + ) -> Generator[Dict[str, str], None, None]: + """ + 流式调用思考模型,实时返回推理过程(使用OpenAI API) + 生成器返回: {"type": "reasoning"|"answer"|"done", "content": "...", "delta": "..."} + """ + settings = get_settings() + if not settings.api_key or settings.api_key == "EMPTY_KEY": + raise RuntimeError("OPENAI_API_KEY 未设置,请先在环境变量或 .env 中配置。") + + full_reasoning = "" + full_answer = "" + + try: + # 直接使用OpenAI API以获取reasoning_content + client = openai.OpenAI( + api_key=settings.api_key, + base_url=settings.base_url + ) + + response = client.chat.completions.create( + model=settings.model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + stream=True + ) + + # 处理流式响应 + for chunk in response: + if chunk.choices and len(chunk.choices) > 0: + delta_obj = chunk.choices[0].delta + + # 检查reasoning_content (DeepSeek Reasoner特有) + if hasattr(delta_obj, 'reasoning_content') and delta_obj.reasoning_content: + delta = delta_obj.reasoning_content + full_reasoning += delta + yield { + "type": "reasoning", + "delta": delta, + "content": full_reasoning + } + + # 检查普通content + if hasattr(delta_obj, 'content') and delta_obj.content: + delta = delta_obj.content + full_answer += delta + yield { + "type": "answer", + "delta": delta, + "content": full_answer + } + + # 返回完成信号 + yield { + "type": "done", + "reasoning": full_reasoning, + "answer": full_answer + } + + except Exception as e: + raise RuntimeError(f"流式调用 LLM 出错:{e}") from e diff --git a/test_e2e.py b/test_e2e.py new file mode 100755 index 0000000..def6e07 --- /dev/null +++ b/test_e2e.py @@ -0,0 +1,330 @@ +#!/usr/bin/env python3 +""" +端到端测试脚本 - 模拟LangGraph完整流程 +使用TPC-H schema和查询 +""" +import sys +import os + +# 添加项目路径 +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from src.graph.graph import build_sqlopt_graph +from src.graph.state import SQLState +from src.config import get_settings +from src.llm.langchain_llm import LangchainLLMClient +import json + +# TPC-H Schema +TPCH_SCHEMA = """ +CREATE TABLE NATION ( + N_NATIONKEY INTEGER NOT NULL, + N_NAME CHAR(25) NOT NULL, + N_REGIONKEY INTEGER NOT NULL, + N_COMMENT VARCHAR(152) +); + +CREATE TABLE REGION ( + R_REGIONKEY INTEGER NOT NULL, + R_NAME CHAR(25) NOT NULL, + R_COMMENT VARCHAR(152) +); + +CREATE TABLE PART ( + P_PARTKEY INTEGER NOT NULL, + P_NAME VARCHAR(55) NOT NULL, + P_MFGR CHAR(25) NOT NULL, + P_BRAND CHAR(10) NOT NULL, + P_TYPE VARCHAR(25) NOT NULL, + P_SIZE INTEGER NOT NULL, + P_CONTAINER CHAR(10) NOT NULL, + P_RETAILPRICE DECIMAL(15,2) NOT NULL, + P_COMMENT VARCHAR(23) NOT NULL +); + +CREATE TABLE SUPPLIER ( + S_SUPPKEY INTEGER NOT NULL, + S_NAME CHAR(25) NOT NULL, + S_ADDRESS VARCHAR(40) NOT NULL, + S_NATIONKEY INTEGER NOT NULL, + S_PHONE CHAR(15) NOT NULL, + S_ACCTBAL DECIMAL(15,2) NOT NULL, + S_COMMENT VARCHAR(101) NOT NULL +); + +CREATE TABLE PARTSUPP ( + PS_PARTKEY INTEGER NOT NULL, + PS_SUPPKEY INTEGER NOT NULL, + PS_AVAILQTY INTEGER NOT NULL, + PS_SUPPLYCOST DECIMAL(15,2) NOT NULL, + PS_COMMENT VARCHAR(199) NOT NULL +); + +CREATE TABLE CUSTOMER ( + C_CUSTKEY INTEGER NOT NULL, + C_NAME VARCHAR(25) NOT NULL, + C_ADDRESS VARCHAR(40) NOT NULL, + C_NATIONKEY INTEGER NOT NULL, + C_PHONE CHAR(15) NOT NULL, + C_ACCTBAL DECIMAL(15,2) NOT NULL, + C_MKTSEGMENT CHAR(10) NOT NULL, + C_COMMENT VARCHAR(117) NOT NULL +); + +CREATE TABLE ORDERS ( + O_ORDERKEY INTEGER NOT NULL, + O_CUSTKEY INTEGER NOT NULL, + O_ORDERSTATUS CHAR(1) NOT NULL, + O_TOTALPRICE DECIMAL(15,2) NOT NULL, + O_ORDERDATE DATE NOT NULL, + O_ORDERPRIORITY CHAR(15) NOT NULL, + O_CLERK CHAR(15) NOT NULL, + O_SHIPPRIORITY INTEGER NOT NULL, + O_COMMENT VARCHAR(79) NOT NULL +); + +CREATE TABLE LINEITEM ( + L_ORDERKEY INTEGER NOT NULL, + L_PARTKEY INTEGER NOT NULL, + L_SUPPKEY INTEGER NOT NULL, + L_LINENUMBER INTEGER NOT NULL, + L_QUANTITY DECIMAL(15,2) NOT NULL, + L_EXTENDEDPRICE DECIMAL(15,2) NOT NULL, + L_DISCOUNT DECIMAL(15,2) NOT NULL, + L_TAX DECIMAL(15,2) NOT NULL, + L_RETURNFLAG CHAR(1) NOT NULL, + L_LINESTATUS CHAR(1) NOT NULL, + L_SHIPDATE DATE NOT NULL, + L_COMMITDATE DATE NOT NULL, + L_RECEIPTDATE DATE NOT NULL, + L_SHIPINSTRUCT CHAR(25) NOT NULL, + L_SHIPMODE CHAR(10) NOT NULL, + L_COMMENT VARCHAR(44) NOT NULL +); +""" + +# TPC-H Query 1 +TPCH_QUERY_1 = """ +SELECT + l_returnflag, + l_linestatus, + SUM(l_quantity) AS sum_qty, + SUM(l_extendedprice) AS sum_base_price, + SUM(l_extendedprice * (1 - l_discount)) AS sum_disc_price, + SUM(l_extendedprice * (1 - l_discount) * (1 + l_tax)) AS sum_charge, + AVG(l_quantity) AS avg_qty, + AVG(l_extendedprice) AS avg_price, + AVG(l_discount) AS avg_disc, + COUNT(*) AS count_order +FROM lineitem +WHERE l_shipdate <= DATE '1998-12-01' - INTERVAL 90 DAY +GROUP BY l_returnflag, l_linestatus +ORDER BY l_returnflag, l_linestatus; +""" + + +def print_section(title): + """打印章节标题""" + print("\n" + "=" * 100) + print(f" {title}") + print("=" * 100) + + +def print_subsection(title): + """打印子章节标题""" + print("\n" + "-" * 100) + print(f" {title}") + print("-" * 100) + + +def test_e2e_sql_optimization(): + """端到端测试SQL优化流程""" + + print_section("🚀 端到端测试 - TPC-H Query 1 优化") + + # 1. 初始化配置 + print_subsection("1️⃣ 初始化配置") + settings = get_settings() + print(f"✅ 模型: {settings.model}") + print(f"✅ Base URL: {settings.base_url}") + print(f"✅ 数据库: {settings.mysql_host}:{settings.mysql_port}/{settings.mysql_database}") + + # 2. 创建LLM客户端 + print_subsection("2️⃣ 创建LLM客户端") + llm = LangchainLLMClient( + model=settings.model, + base_url=settings.base_url, + api_key=settings.api_key + ) + print(f"✅ LLM客户端创建成功") + + # 3. 构建Graph + print_subsection("3️⃣ 构建SQL优化图") + graph = build_sqlopt_graph() + print(f"✅ Graph构建成功") + + # 4. 准备输入状态 + print_subsection("4️⃣ 准备输入状态") + print(f"\n📋 Schema:") + print(TPCH_SCHEMA[:500] + "...\n") + + print(f"📝 SQL Query:") + print(TPCH_QUERY_1) + + initial_state = SQLState( + sql=TPCH_QUERY_1.strip(), + db_schema=TPCH_SCHEMA.strip(), + max_iterations=2, + llm=llm, + settings=settings, + iteration_count=0, + messages=[], + optimization_plans=[], + current_plan_index=0 + ) + print(f"\n✅ 初始状态准备完成") + + # 5. 执行优化流程 + print_section("5️⃣ 执行优化流程") + + try: + print("\n🔄 开始执行Graph...\n") + + # 流式执行graph + step_count = 0 + for event in graph.stream(initial_state): + step_count += 1 + + # 打印节点信息 + for node_name, node_output in event.items(): + print(f"\n🔹 步骤 {step_count}: 节点 [{node_name}]") + + # 显示关键信息 + if node_name == "input": + print(f" ├─ SQL长度: {len(node_output.get('sql', ''))} 字符") + print(f" └─ Schema长度: {len(node_output.get('db_schema', ''))} 字符") + + elif node_name == "get_plan": + plan = node_output.get("plan", "") + print(f" ├─ 执行计划已获取") + print(f" └─ 计划长度: {len(plan)} 字符") + if plan: + print(f"\n 📊 执行计划预览:") + print(" " + "\n ".join(plan[:300].split("\n"))) + if len(plan) > 300: + print(" ...") + + elif node_name == "get_stats": + stats = node_output.get("stats", {}) + print(f" ├─ 统计信息已获取") + if stats.get("tables"): + print(f" ├─ 涉及表: {', '.join(stats['tables'].keys())}") + if stats.get("collection_success"): + print(f" └─ ✅ 统计信息收集成功") + else: + print(f" └─ ⚠️ 统计信息收集部分失败") + + elif node_name == "generate_plans": + plans = node_output.get("optimization_plans", []) + reasoning = node_output.get("generate_plans_reasoning", "") + + print(f" ├─ 生成了 {len(plans)} 个优化方案") + + if reasoning: + print(f" ├─ 🧠 推理过程长度: {len(reasoning)} 字符") + print(f"\n 🧠 推理过程预览:") + print(" " + "\n ".join(reasoning[:500].split("\n"))) + if len(reasoning) > 500: + print(" ...") + + print(f"\n 💡 优化方案:") + for i, plan in enumerate(plans, 1): + print(f"\n 方案 {i}:") + print(f" ├─ 理由: {plan.get('reason', '')[:100]}...") + opt_sql = plan.get('optimized_sql', '') + print(f" └─ SQL: {opt_sql[:100]}...") + + elif node_name == "check_equivalence": + eq = node_output.get("equivalence", False) + idx = node_output.get("current_plan_index", 0) + print(f" ├─ 方案 {idx + 1} 等价性检查") + print(f" └─ {'✅ 等价' if eq else '❌ 不等价'}") + + elif node_name == "get_costs": + cost_before = node_output.get("cost_before") + cost_after = node_output.get("cost_after") + idx = node_output.get("current_plan_index", 0) + print(f" ├─ 方案 {idx + 1} 成本估算") + if cost_before is not None and cost_after is not None: + print(f" ├─ 优化前成本: {cost_before}") + print(f" ├─ 优化后成本: {cost_after}") + if cost_before > 0: + improvement = ((cost_before - cost_after) / cost_before) * 100 + print(f" └─ 改善: {improvement:.2f}%") + else: + print(f" └─ ⚠️ 成本估算失败或跳过") + + elif node_name == "next_plan": + idx = node_output.get("current_plan_index", 0) + print(f" └─ 切换到方案 {idx + 1}") + + elif node_name == "report": + final_report = node_output.get("final_report", "") + print(f" ├─ 📄 最终报告已生成") + print(f" └─ 报告长度: {len(final_report)} 字符") + + print(f"\n✅ Graph执行完成!共 {step_count} 个步骤\n") + + except Exception as e: + print(f"\n❌ 执行失败: {e}") + import traceback + traceback.print_exc() + return + + # 6. 显示最终结果 + print_section("6️⃣ 最终结果") + + # 获取最终状态 + final_state = event + if final_state: + # 获取最后一个节点的输出(应该是report节点) + for node_name, state in final_state.items(): + if node_name == "report": + + # 显示优化方案 + plans = state.get("optimization_plans", []) + if plans: + print_subsection("📊 优化方案详情") + for i, plan in enumerate(plans, 1): + print(f"\n方案 {i}:") + print(f" 理由: {plan.get('reason', 'N/A')}") + print(f" 等价性: {'✅ 通过' if plan.get('equivalence') else '❌ 未通过'}") + print(f" 成本: {plan.get('cost', 'N/A')}") + print(f" SQL:") + print(" " + "\n ".join(plan.get('optimized_sql', '').split("\n"))) + + # 显示推理过程 + reasoning = state.get("generate_plans_reasoning", "") + if reasoning: + print_subsection("🧠 完整推理过程") + print(reasoning) + + # 显示最终报告 + final_report = state.get("final_report", "") + if final_report: + print_subsection("📄 最终报告") + print(final_report) + + print_section("✅ 测试完成") + + +if __name__ == "__main__": + try: + test_e2e_sql_optimization() + except KeyboardInterrupt: + print("\n\n⚠️ 测试被用户中断") + except Exception as e: + print(f"\n\n❌ 测试失败: {e}") + import traceback + traceback.print_exc() + diff --git a/test_reasoning_simple.py b/test_reasoning_simple.py new file mode 100644 index 0000000..9d2fe39 --- /dev/null +++ b/test_reasoning_simple.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 +""" +简单测试流式推理功能 +""" +import sys +import os +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from src.llm.langchain_llm import LangchainLLMClient +from src.config import get_settings + +def test_streaming_reasoning(): + """测试流式推理""" + print("=" * 80) + print("🧪 测试流式推理功能") + print("=" * 80) + + # 初始化 + settings = get_settings() + print(f"\n✅ 模型: {settings.model}") + + llm = LangchainLLMClient( + model=settings.model, + base_url=settings.base_url, + api_key=settings.api_key + ) + print(f"✅ LLM客户端创建成功\n") + + # 简单的推理任务 + messages = [ + { + "role": "user", + "content": "请分析一下为什么这个SQL查询可能会慢,并给出优化建议:SELECT * FROM orders WHERE order_date > '2024-01-01' ORDER BY total_amount DESC" + } + ] + + print("🔄 开始流式推理...\n") + print("-" * 80) + + try: + reasoning_length = 0 + answer_length = 0 + chunk_count = 0 + + for chunk in llm.stream_with_reasoning(messages): + chunk_count += 1 + chunk_type = chunk.get("type") + + if chunk_type == "reasoning": + delta = chunk.get("delta", "") + if delta: + print(f"🧠 [推理chunk {chunk_count}] {len(delta)} 字符", flush=True) + reasoning_length += len(delta) + # 打印前50个字符 + preview = delta[:50].replace("\n", " ") + print(f" 预览: {preview}...", flush=True) + + elif chunk_type == "answer": + delta = chunk.get("delta", "") + if delta: + print(f"💡 [回答chunk {chunk_count}] {len(delta)} 字符", flush=True) + answer_length += len(delta) + # 打印前50个字符 + preview = delta[:50].replace("\n", " ") + print(f" 预览: {preview}...", flush=True) + + elif chunk_type == "done": + print(f"\n✅ 完成!", flush=True) + print(f" 总推理长度: {reasoning_length} 字符", flush=True) + print(f" 总回答长度: {answer_length} 字符", flush=True) + print(f" 总chunk数: {chunk_count}", flush=True) + + full_reasoning = chunk.get("reasoning", "") + full_answer = chunk.get("answer", "") + + print("\n" + "=" * 80) + print("📋 完整推理过程:") + print("=" * 80) + print(full_reasoning[:500] + "..." if len(full_reasoning) > 500 else full_reasoning) + + print("\n" + "=" * 80) + print("💡 完整回答:") + print("=" * 80) + print(full_answer[:500] + "..." if len(full_answer) > 500 else full_answer) + break + + print("\n" + "=" * 80) + print("✅ 测试成功!") + print("=" * 80) + + except Exception as e: + print(f"\n❌ 测试失败: {e}", flush=True) + import traceback + traceback.print_exc() + +if __name__ == "__main__": + test_streaming_reasoning() + + + + + +