Skip to content

Commit 736d011

Browse files
author
Qiushi Bai
committed
Adding an ast_util.py to visualize ASTs.
1 parent e6befe1 commit 736d011

3 files changed

Lines changed: 2079 additions & 1649 deletions

File tree

pytest.ini

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
[pytest]
2+
# Show log messages (INFO and above) live on the console during test execution
3+
log_cli = true
4+
log_cli_level = INFO
5+
log_file_level = INFO
6+
log_level = INFO
7+
8+
# Verbose test names and short tracebacks (add -s to also show print output)
9+
addopts = -v --tb=short

tests/ast_util.py

Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
"""
2+
Utility functions for visualizing and working with AST structures.
3+
"""
4+
import textwrap
5+
import sqlparse
6+
from core.ast.node import (
7+
Node, QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode,
8+
LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode,
9+
OrderByNode, OrderByItemNode, LimitNode, OffsetNode, JoinNode, SubqueryNode,
10+
VarNode, VarSetNode
11+
)
12+
13+
14+
def _beautify_sql(sql: str) -> str:
    """
    Pretty-print a SQL query string using the sqlparse library.

    Args:
        sql: Raw SQL query string

    Returns:
        The string returned by ``sqlparse.format`` with re-indentation
        enabled and keywords converted to upper case
    """
    # Formatting options are collected in one place so they are easy to audit.
    format_options = {
        "reindent": True,
        "keyword_case": "upper",
    }
    return sqlparse.format(sql, **format_options)
34+
35+
36+
def _node_to_string(node: Node, indent: int = 0) -> str:
    """
    Convert an AST node to a tree-formatted string representation.

    The node is rendered as one header line, followed by the recursively
    rendered lines of its children at ``indent + 1``. Every line starts with
    ``"| "`` repeated ``indent`` times and then ``"+- "``. The leading label of
    each header is taken from ``node.type`` (its ``.value`` when present,
    otherwise ``str(node.type)``); the labels shown below are the expected
    values of that attribute.

    Translation rules, checked in this order:

    - TableNode: "table: name [alias]"
        - [alias]: only when the alias is truthy (e.g., "employees [e]")

    - ColumnNode: "column: name (parent_alias) as alias"
        - (parent_alias): only when set (e.g., "salary (e)")
        - as alias: only when set (e.g., "as emp_count")

    - LiteralNode: "literal: value"

    - FunctionNode: "function: name as alias"
        - as alias: only when set
        - children: function arguments rendered as child nodes

    - OperatorNode: "operator: op_name"
        - children: operands rendered as child nodes
        - a child that is a Python ``list`` (e.g., the value list of IN) is
          rendered as a "values:" line with the list items one level deeper

    - JoinNode: "join: join_type"
        - children: left table, right table, then the ON condition if truthy

    - OrderByItemNode: "order_by_item: sort_order"
        - children: the expression being sorted

    - SelectNode, FromNode, WhereNode, GroupByNode, HavingNode, OrderByNode:
      the type label only, followed by their children

    - LimitNode, OffsetNode: "limit: value" / "offset: value"

    - QueryNode: the literal text "query", followed by its children

    - SubqueryNode: "subquery [alias]", followed by its children

    - VarNode, VarSetNode: "<type>: name"

    - anything else: the type label only, followed by its children

    Args:
        node: AST node to convert
        indent: Current indentation level

    Returns:
        String representation of the node in tree format, lines joined by
        newlines (no trailing newline)
    """
    prefix = "| " * indent + "+- "
    node_type = _enum_text(node.type)
    lines = []

    # NOTE: the isinstance checks keep their original order on purpose; the
    # class hierarchy of the node types is not visible from this module, so
    # reordering could change which branch a subclass lands in.
    if isinstance(node, TableNode):
        alias_str = f" [{node.alias}]" if node.alias else ""
        lines.append(f"{prefix}{node_type}: {node.name}{alias_str}")

    elif isinstance(node, ColumnNode):
        parent_alias = f" ({node.parent_alias})" if node.parent_alias else ""
        alias_str = f" as {node.alias}" if node.alias else ""
        lines.append(f"{prefix}{node_type}: {node.name}{parent_alias}{alias_str}")

    elif isinstance(node, LiteralNode):
        lines.append(f"{prefix}{node_type}: {node.value}")

    elif isinstance(node, FunctionNode):
        alias_str = f" as {node.alias}" if node.alias else ""
        lines.append(f"{prefix}{node_type}: {node.name}{alias_str}")
        lines.extend(_render_nodes(node.children, indent + 1))

    elif isinstance(node, OperatorNode):
        lines.append(f"{prefix}{node_type}: {node.name}")
        for child in node.children or ():
            if isinstance(child, list):
                # A list operand (e.g., the right-hand side of IN) gets its
                # own "values:" line with the items nested beneath it.
                list_prefix = "| " * (indent + 1) + "+- "
                lines.append(f"{list_prefix}values:")
                lines.extend(_render_nodes(child, indent + 2))
            else:
                lines.extend(_render_nodes([child], indent + 1))

    elif isinstance(node, JoinNode):
        lines.append(f"{prefix}{node_type}: {_enum_text(node.join_type)}")
        operands = [node.left_table, node.right_table]
        if node.on_condition:
            operands.append(node.on_condition)
        lines.extend(_render_nodes(operands, indent + 1))

    elif isinstance(node, OrderByItemNode):
        lines.append(f"{prefix}{node_type}: {_enum_text(node.sort)}")
        lines.extend(_render_nodes(node.children, indent + 1))

    elif isinstance(node, (SelectNode, FromNode, WhereNode, GroupByNode, HavingNode, OrderByNode)):
        lines.append(f"{prefix}{node_type}")
        lines.extend(_render_nodes(node.children, indent + 1))

    elif isinstance(node, (LimitNode, OffsetNode)):
        value = node.limit if isinstance(node, LimitNode) else node.offset
        lines.append(f"{prefix}{node_type}: {value}")

    elif isinstance(node, QueryNode):
        # The header is the fixed word "query" rather than the type label.
        lines.append(f"{prefix}query")
        lines.extend(_render_nodes(node.children, indent + 1))

    elif isinstance(node, SubqueryNode):
        alias_str = f" [{node.alias}]" if node.alias else ""
        lines.append(f"{prefix}{node_type}{alias_str}")
        lines.extend(_render_nodes(node.children, indent + 1))

    elif isinstance(node, (VarNode, VarSetNode)):
        lines.append(f"{prefix}{node_type}: {node.name}")

    else:
        # Fallback for node types without a dedicated rule.
        lines.append(f"{prefix}{node_type}")
        lines.extend(_render_nodes(node.children, indent + 1))

    return '\n'.join(lines)


def _render_nodes(nodes, indent: int) -> list:
    """Render each node in *nodes* at *indent* and return all resulting lines; a falsy *nodes* yields no lines."""
    rendered = []
    for child in nodes or ():
        rendered.extend(_node_to_string(child, indent).split('\n'))
    return rendered


def _enum_text(value):
    """Return ``value.value`` when *value* has a ``value`` attribute, otherwise ``str(value)``."""
    return value.value if hasattr(value, 'value') else str(value)
239+
240+
241+
def visualize_ast(sql: str, ast: QueryNode, max_sql_width: int = 50) -> str:
    """
    Generate a side-by-side visualization of SQL query and AST structure.

    The SQL is beautified with ``_beautify_sql`` and shown in the left column;
    the AST tree produced by ``_node_to_string`` is shown in the right column.
    SQL lines longer than ``max_sql_width`` are wrapped with ``textwrap.fill``.
    Because words are never broken, a single token longer than
    ``max_sql_width`` can still make the SQL column wider than requested; the
    column is sized from the lines actually produced.

    Each column is at least as wide as its header label, so the
    "AST STRUCTURE" header always lines up with the AST column and the
    ``=`` rules always span the header, even for very short SQL or ASTs.

    Args:
        sql: SQL query string to visualize
        ast: QueryNode representing the parsed AST
        max_sql_width: Maximum width for SQL column before wrapping (default: 50)

    Returns:
        Formatted string with SQL on the left and AST tree on the right,
        framed by a header and ``=`` rule lines
    """
    sql_header = "SQL QUERY"
    ast_header = "AST STRUCTURE"
    padding = 3  # Space between columns

    # Beautify the SQL, then wrap any line that is too long for the column.
    wrapped_sql_lines = []
    for line in _beautify_sql(sql).split('\n'):
        if len(line) > max_sql_width:
            wrapped = textwrap.fill(
                line,
                width=max_sql_width,
                subsequent_indent=' ',  # Indent continuation lines
                break_long_words=False,
                break_on_hyphens=False
            )
            wrapped_sql_lines.extend(wrapped.split('\n'))
        else:
            wrapped_sql_lines.append(line)

    # Convert AST to tree format
    ast_lines = _node_to_string(ast).split('\n')

    # Column widths: the widest content line, but never narrower than the
    # column's own header label (otherwise the header row would be misaligned
    # with the rows below it).
    actual_sql_width = max([len(sql_header), *map(len, wrapped_sql_lines)])
    max_ast_width = max([len(ast_header), *map(len, ast_lines)])
    total_width = actual_sql_width + padding + max_ast_width

    gap = ' ' * padding
    rule = "=" * total_width

    result = [
        rule,
        f"{sql_header:<{actual_sql_width}}{gap}{ast_header}",
        rule,
    ]

    # Merge lines side-by-side; the shorter column is padded with blank rows.
    max_lines = max(len(wrapped_sql_lines), len(ast_lines))
    for i in range(max_lines):
        sql_line = wrapped_sql_lines[i] if i < len(wrapped_sql_lines) else ""
        ast_line = ast_lines[i] if i < len(ast_lines) else ""
        # Pad SQL line to match column width
        result.append(f"{sql_line:<{actual_sql_width}}{gap}{ast_line}")

    result.append(rule)

    return '\n'.join(result)

0 commit comments

Comments
 (0)