1+
2+ import os
3+ import re
4+ import logging
5+ from sqlalchemy .ext .asyncio import AsyncSession
6+ from xml .etree import ElementTree as ET
7+ from typing import Dict , Any , List , Optional
8+ from sqlalchemy import text
9+ from sqlalchemy .ext .asyncio import AsyncSession
10+ from fastapi .encoders import jsonable_encoder
11+ from typing import Dict , Any , List , Optional , Union
12+
13+ logging .basicConfig (level = logging .DEBUG )
14+ logger = logging .getLogger (__name__ )
15+
16+ class SqlXmlExecutor :
17+ def __init__ (self , db : AsyncSession , mapper_dir : str = "mapper" ):
18+ self .db = db
19+ self .queries = self .load_queries (mapper_dir )
20+
21+ def load_queries (self , dir_path : str ) -> Dict [str , Dict [str , str ]]:
22+ queries = {}
23+ for filename in os .listdir (dir_path ):
24+ if filename .endswith ('.xml' ):
25+ module = filename .split ('.' )[0 ]
26+ file_path = os .path .join (dir_path , filename )
27+ tree = ET .parse (file_path )
28+ root = tree .getroot ()
29+ queries [module ] = {}
30+ for query in root .findall ('query' ):
31+ query_id = query .get ('id' )
32+ # 提取整个 <query> 标签内的完整内容(含子标签)
33+ query_text = self ._get_full_query_text (query ).strip ()
34+ queries [module ][query_id ] = query_text
35+ return queries
36+
37+ def _get_full_query_text (self , element ):
38+ """
39+ 递归获取元素及其所有子元素的文本内容
40+ """
41+ text = element .text or ""
42+ for child in element :
43+ text += self ._get_full_query_text (child )
44+ text += element .tail or ""
45+ return text
46+
47+ def parse_xml_query (self , xml_query : str , params : dict ) -> str :
48+ wrapped = f"<root>{ xml_query } </root>"
49+ try :
50+ root = ET .fromstring (wrapped )
51+ except ET .ParseError as e :
52+ raise ValueError (f"XML 解析失败: { e } " )
53+
54+ def process_node (node ):
55+ sql_parts = []
56+ for child in node :
57+ if child .tag == "if" :
58+ condition = child .attrib ["test" ]
59+ if eval_condition (condition , params ):
60+ content = child .text .strip () if child .text else ""
61+ sql_parts .append (content )
62+ elif child .tag == "where" :
63+ where_sql = process_node (child )
64+ if where_sql :
65+ sql_parts .append ("WHERE " + where_sql )
66+ elif child .tag == "choose" :
67+ for when in child .findall ("when" ):
68+ cond = when .attrib ["test" ]
69+ if eval_condition (cond , params ):
70+ content = when .text .strip () if when .text else ""
71+ sql_parts .append (content )
72+ break
73+ else :
74+ inner = process_node (child )
75+ if inner :
76+ sql_parts .append (inner )
77+ return "\n " .join (sql_parts )
78+
79+ def eval_condition (condition : str , params : dict ) -> bool :
80+ return condition in params and params [condition ] is not None
81+
82+ raw_sql = re .sub (r'\s+AND\s' , '\n AND ' , process_node (root ), flags = re .IGNORECASE ).strip ()
83+ return raw_sql .replace (">" , ">" ).replace ("<" , "<" )
84+
85+ async def execute (
86+ self ,
87+ module : str ,
88+ query_id : str ,
89+ params : Optional [Dict [str , Any ]] = None ,
90+ single_row : bool = False ,
91+ v_return_obj : bool = True ,
92+ schema : Any = None
93+ ) -> Union [List [Dict ], Dict , None ]:
94+ if module not in self .queries or query_id not in self .queries [module ]:
95+ raise ValueError (f"Query ID '{ query_id } ' not found in module '{ module } '" )
96+
97+ raw_xml = self .queries [module ][query_id ]
98+
99+ # 如果没有 <if>、<where> 等标签,直接执行原始 SQL
100+ if "<if" not in raw_xml and "<where" not in raw_xml :
101+ final_sql = raw_xml .replace (">" , ">" ).replace ("<" , "<" )
102+
103+ # 🔍 打印 SQL 和参数
104+ logger .info (f"[SQL Query] Module: { module } , Query ID: { query_id } " )
105+ logger .info (f"Final SQL:\n { final_sql } " )
106+ logger .info (f"Params: { params } " )
107+
108+ result = await self .db .execute (text (final_sql ), params or {})
109+ rows = result .mappings ().all ()
110+ if not rows :
111+ return None
112+
113+ data = [dict (row ) for row in rows ]
114+ if v_return_obj and schema :
115+ data = [schema (** item ) for item in data ]
116+ return data [0 ] if single_row else data
117+
118+ # 否则才走 XML 动态解析逻辑(如果需要的话)
119+ final_sql = self .parse_xml_query (raw_xml , params or {})
120+
121+ # 🔍 打印解析后的 SQL 和参数
122+ logger .info (f"[SQL Query] Module: { module } , Query ID: { query_id } " )
123+ logger .info (f"Parsed SQL:\n { final_sql } " )
124+ logger .info (f"Params: { params } " )
125+
126+ result = await self .db .execute (text (final_sql ), params or {})
127+ rows = result .mappings ().all ()
128+
129+ if not rows :
130+ return None
131+
132+ data = [dict (row ) for row in rows ]
133+ if v_return_obj and schema :
134+ data = [schema (** item ) for item in data ]
135+ return data [0 ] if single_row else data
0 commit comments