From 4920c99a3fd1257cfb507b57a05519a098599f08 Mon Sep 17 00:00:00 2001 From: Colin Harrison Date: Thu, 2 Apr 2026 11:59:16 -0700 Subject: [PATCH 1/8] Add RuleParser.parse_v2 returning project AST (RuleParseResult) Made-with: Cursor --- core/rule_parser.py | 307 +++++++++++++++++++++++++++++++++++++- tests/test_rule_parser.py | 38 ++++- 2 files changed, 342 insertions(+), 3 deletions(-) diff --git a/core/rule_parser.py b/core/rule_parser.py index f0de16d..87ab7c9 100644 --- a/core/rule_parser.py +++ b/core/rule_parser.py @@ -1,8 +1,38 @@ +from dataclasses import dataclass from enum import Enum import json import mo_sql_parsing as mosql import re -from typing import Any, Tuple +from typing import Any, Dict, List, Optional, Tuple + +from core.ast.enums import NodeType +from core.ast.node import ( + CaseNode, + ColumnNode, + FromNode, + FunctionNode, + GroupByNode, + HavingNode, + JoinNode, + LimitNode, + ListNode, + Node, + OffsetNode, + OperatorNode, + OrderByItemNode, + OrderByNode, + QueryNode, + SelectNode, + SubqueryNode, + TableNode, + UnaryOperatorNode, + VarNode, + VarSetNode, + WhenThenNode, + WhereNode, + IntervalNode, +) +from core.query_parser import QueryParser # Variable Type @@ -42,6 +72,18 @@ class Scope(Enum): Scope.SELECT: '' } + +@dataclass(frozen=True) +class RuleParseResult: + """Structured result from RuleParser.parse_v2 (project AST instead of mo-sql JSON).""" + + pattern_ast: Node + rewrite_ast: Node + mapping: Dict[str, str] + pattern_scope: Scope + rewrite_scope: Scope + + class RuleParser: # parse a rule (pattern, rewrite) into a SQL AST json str @@ -70,6 +112,269 @@ def parse(pattern: str, rewrite: str) -> Tuple[str, str, str]: # 5. Return the AST subtree as json string return json.dumps(patternASTJson), json.dumps(rewriteASTJson), json.dumps(mapping) + @staticmethod + def parse_v2(pattern: str, rewrite: str) -> RuleParseResult: + """Parse a rule into project AST nodes with external rule variable names. + + Uses the same extension and placeholder replacement as parse(), then + QueryParser plus substitution of internal tokens (V001 / VL001) to + VarNode / VarSetNode or decoded ColumnNode / TableNode names. + """ + pattern_sql, rewrite_sql, mapping = RuleParser.replaceVars(pattern, rewrite) + pattern_full, pattern_scope = RuleParser.extendToFullSQL(pattern_sql) + rewrite_full, rewrite_scope = RuleParser.extendToFullSQL(rewrite_sql) + qparser = QueryParser() + pattern_query = qparser.parse(pattern_full) + rewrite_query = qparser.parse(rewrite_full) + internal_to_external = {internal: external for external, internal in mapping.items()} + pattern_ast = RuleParser._extract_rule_ast(pattern_query, pattern_scope, internal_to_external) + rewrite_ast = RuleParser._extract_rule_ast(rewrite_query, rewrite_scope, internal_to_external) + return RuleParseResult( + pattern_ast=pattern_ast, + rewrite_ast=rewrite_ast, + mapping=mapping, + pattern_scope=pattern_scope, + rewrite_scope=rewrite_scope, + ) + + @staticmethod + def _get_clause(query: QueryNode, clause_type: NodeType) -> Optional[Node]: + for child in query.children: + if child.type == clause_type: + return child + return None + + @staticmethod + def _extract_rule_ast( + query: QueryNode, scope: Scope, internal_to_external: Dict[str, str] + ) -> Node: + sel = RuleParser._get_clause(query, NodeType.SELECT) + frm = RuleParser._get_clause(query, NodeType.FROM) + wh = RuleParser._get_clause(query, NodeType.WHERE) + gb = RuleParser._get_clause(query, NodeType.GROUP_BY) + hav = RuleParser._get_clause(query, NodeType.HAVING) + ob = RuleParser._get_clause(query, NodeType.ORDER_BY) + lim = RuleParser._get_clause(query, NodeType.LIMIT) + off = RuleParser._get_clause(query, NodeType.OFFSET) + + if scope == Scope.CONDITION: + if wh is None or not list(wh.children): + raise ValueError("CONDITION scope requires a WHERE predicate") + pred = list(wh.children)[0] + return RuleParser._as_rule_ast(pred, internal_to_external) + + if scope == Scope.WHERE: + return RuleParser._as_rule_ast( + QueryNode( + _select=None, + _from=None, + _where=RuleParser._as_rule_ast(wh, internal_to_external) if wh else None, + _group_by=RuleParser._as_rule_ast(gb, internal_to_external) if gb else None, + _having=RuleParser._as_rule_ast(hav, internal_to_external) if hav else None, + _order_by=RuleParser._as_rule_ast(ob, internal_to_external) if ob else None, + _limit=RuleParser._as_rule_ast(lim, internal_to_external) if lim else None, + _offset=RuleParser._as_rule_ast(off, internal_to_external) if off else None, + ), + internal_to_external, + ) + + if scope == Scope.FROM: + return RuleParser._as_rule_ast( + QueryNode( + _select=None, + _from=RuleParser._as_rule_ast(frm, internal_to_external) if frm else None, + _where=RuleParser._as_rule_ast(wh, internal_to_external) if wh else None, + _group_by=RuleParser._as_rule_ast(gb, internal_to_external) if gb else None, + _having=RuleParser._as_rule_ast(hav, internal_to_external) if hav else None, + _order_by=RuleParser._as_rule_ast(ob, internal_to_external) if ob else None, + _limit=RuleParser._as_rule_ast(lim, internal_to_external) if lim else None, + _offset=RuleParser._as_rule_ast(off, internal_to_external) if off else None, + ), + internal_to_external, + ) + + return RuleParser._as_rule_ast(query, internal_to_external) + + @staticmethod + def _as_rule_ast(node: Optional[Node], internal_to_external: Dict[str, str]) -> Optional[Node]: + if node is None: + return None + return RuleParser._substitute_placeholders(node, internal_to_external) + + @staticmethod + def _placeholder_varnode(internal_token: str, external_name: str) -> Node: + if internal_token.startswith(VarTypesInfo[VarType.VarList]["internalBase"]): + return VarSetNode(external_name) + return VarNode(external_name) + + @staticmethod + def _substitute_placeholders(node: Node, rev: Dict[str, str]) -> Node: + if node.type == NodeType.COLUMN: + col = node + if not isinstance(col, ColumnNode): + return node + pa = col.parent_alias + nm = col.name + if pa is None and nm in rev: + return RuleParser._placeholder_varnode(nm, rev[nm]) + if pa is not None and pa in rev and nm in rev: + return ColumnNode(rev[nm], _alias=col.alias, _parent_alias=rev[pa]) + if pa is not None and pa in rev: + return ColumnNode(nm, _alias=col.alias, _parent_alias=rev[pa]) + return ColumnNode(nm, _alias=col.alias, _parent_alias=pa) + + if node.type == NodeType.TABLE: + t = node + if not isinstance(t, TableNode): + return node + new_name = rev.get(t.name, t.name) + new_alias = rev[t.alias] if t.alias and t.alias in rev else t.alias + return TableNode(new_name, new_alias) + + if node.type == NodeType.QUERY: + q = node + if not isinstance(q, QueryNode): + return node + return QueryNode( + _select=RuleParser._as_rule_ast(RuleParser._get_clause(q, NodeType.SELECT), rev), + _from=RuleParser._as_rule_ast(RuleParser._get_clause(q, NodeType.FROM), rev), + _where=RuleParser._as_rule_ast(RuleParser._get_clause(q, NodeType.WHERE), rev), + _group_by=RuleParser._as_rule_ast(RuleParser._get_clause(q, NodeType.GROUP_BY), rev), + _having=RuleParser._as_rule_ast(RuleParser._get_clause(q, NodeType.HAVING), rev), + _order_by=RuleParser._as_rule_ast(RuleParser._get_clause(q, NodeType.ORDER_BY), rev), + _limit=RuleParser._as_rule_ast(RuleParser._get_clause(q, NodeType.LIMIT), rev), + _offset=RuleParser._as_rule_ast(RuleParser._get_clause(q, NodeType.OFFSET), rev), + ) + + if node.type == NodeType.SELECT: + sn = node + if not isinstance(sn, SelectNode): + return node + items: List[Node] = [] + don = sn.distinct_on + for ch in sn.children: + if don is not None and ch is don: + continue + items.append(RuleParser._substitute_placeholders(ch, rev)) + new_don = ( + RuleParser._substitute_placeholders(don, rev) if don is not None else None + ) + return SelectNode(items, _distinct=sn.distinct, _distinct_on=new_don) + + if node.type == NodeType.FROM: + fn = node + if not isinstance(fn, FromNode): + return node + return FromNode([RuleParser._substitute_placeholders(c, rev) for c in fn.children]) + + if node.type == NodeType.WHERE: + wn = node + if not isinstance(wn, WhereNode): + return node + return WhereNode([RuleParser._substitute_placeholders(c, rev) for c in wn.children]) + + if node.type == NodeType.GROUP_BY: + g = node + if not isinstance(g, GroupByNode): + return node + return GroupByNode([RuleParser._substitute_placeholders(c, rev) for c in g.children]) + + if node.type == NodeType.HAVING: + h = node + if not isinstance(h, HavingNode): + return node + return HavingNode([RuleParser._substitute_placeholders(c, rev) for c in h.children]) + + if node.type == NodeType.ORDER_BY: + o = node + if not isinstance(o, OrderByNode): + return node + return OrderByNode([RuleParser._substitute_placeholders(c, rev) for c in o.children]) + + if node.type == NodeType.ORDER_BY_ITEM: + oi = node + if not isinstance(oi, OrderByItemNode): + return node + inner = list(oi.children)[0] + return OrderByItemNode(RuleParser._substitute_placeholders(inner, rev), oi.sort) + + if node.type == NodeType.JOIN: + j = node + if not isinstance(j, JoinNode): + return node + ch = list(j.children) + left = RuleParser._substitute_placeholders(ch[0], rev) + right = RuleParser._substitute_placeholders(ch[1], rev) + on_expr = ( + RuleParser._substitute_placeholders(ch[2], rev) if len(ch) > 2 else None + ) + return JoinNode(left, right, j.join_type, on_expr) + + if node.type == NodeType.SUBQUERY: + sq = node + if not isinstance(sq, SubqueryNode): + return node + inner = list(sq.children)[0] + return SubqueryNode(RuleParser._substitute_placeholders(inner, rev), sq.alias) + + if node.type == NodeType.FUNCTION: + f = node + if not isinstance(f, FunctionNode): + return node + new_args = [RuleParser._substitute_placeholders(a, rev) for a in f.children] + return FunctionNode(f.name, _args=new_args, _alias=f.alias) + + if node.type == NodeType.LIST: + ln = node + if not isinstance(ln, ListNode): + return node + return ListNode([RuleParser._substitute_placeholders(c, rev) for c in ln.children]) + + if node.type == NodeType.INTERVAL: + inv = node + if not isinstance(inv, IntervalNode): + return node + if isinstance(inv.value, Node): + return IntervalNode( + RuleParser._substitute_placeholders(inv.value, rev), + inv.unit, # type: ignore[arg-type] + ) + return IntervalNode(inv.value, inv.unit) # type: ignore[arg-type] + + if node.type == NodeType.CASE: + cn = node + if not isinstance(cn, CaseNode): + return node + new_whens: List[WhenThenNode] = [] + for wt in cn.whens: + new_whens.append( + WhenThenNode( + RuleParser._substitute_placeholders(wt.when, rev), + RuleParser._substitute_placeholders(wt.then, rev), + ) + ) + new_else = ( + RuleParser._substitute_placeholders(cn.else_val, rev) if cn.else_val else None + ) + return CaseNode(new_whens, new_else) + + if node.type == NodeType.OPERATOR: + if isinstance(node, UnaryOperatorNode): + op = node + inner = list(op.children)[0] if op.children else op.operand + return UnaryOperatorNode(RuleParser._substitute_placeholders(inner, rev), op.name) + op = node + ch = list(op.children) + if len(ch) == 1: + return OperatorNode(RuleParser._substitute_placeholders(ch[0], rev), op.name) + return OperatorNode( + RuleParser._substitute_placeholders(ch[0], rev), + op.name, + RuleParser._substitute_placeholders(ch[1], rev), + ) + + return node + # Extend pattern/rewrite to full SQL statement # @staticmethod diff --git a/tests/test_rule_parser.py b/tests/test_rule_parser.py index 72596cf..0956a1c 100644 --- a/tests/test_rule_parser.py +++ b/tests/test_rule_parser.py @@ -1,5 +1,13 @@ -from core.rule_parser import RuleParser -from core.rule_parser import Scope +from core.ast.enums import NodeType +from core.ast.node import ( + DataTypeNode, + FunctionNode, + QueryNode, + SelectNode, + VarNode, + VarSetNode, +) +from core.rule_parser import RuleParser, RuleParseResult, Scope def test_extendToFullSQL(): @@ -151,6 +159,32 @@ def test_parse(): assert rewrite_json == internal_rule['rewrite_json'] +def test_parse_v2_cast_rule(): + result = RuleParser.parse_v2('CAST( AS DATE)', '') + assert isinstance(result, RuleParseResult) + assert result.pattern_scope == Scope.CONDITION + assert result.rewrite_scope == Scope.CONDITION + assert result.mapping == {'x': 'V001'} + assert isinstance(result.pattern_ast, FunctionNode) + assert result.pattern_ast.name.lower() == 'cast' + cast_args = list(result.pattern_ast.children) + assert isinstance(cast_args[0], VarNode) and cast_args[0].name == 'x' + assert isinstance(cast_args[1], DataTypeNode) + assert isinstance(result.rewrite_ast, VarNode) and result.rewrite_ast.name == 'x' + + +def test_parse_v2_select_list_varset(): + pattern = 'select <> from lineitem where 1 = 1' + rewrite = 'select <> from lineitem where 1 = 1' + result = RuleParser.parse_v2(pattern, rewrite) + assert result.pattern_scope == Scope.SELECT + assert isinstance(result.pattern_ast, QueryNode) + select = next(c for c in result.pattern_ast.children if c.type == NodeType.SELECT) + assert isinstance(select, SelectNode) + first = list(select.children)[0] + assert isinstance(first, VarSetNode) and first.name == 's1' + + #incorrect brackets def test_brackets_1(): From b628827b5dddb93fef7c308138efc6f767967ebe Mon Sep 17 00:00:00 2001 From: Colin Harrison Date: Thu, 2 Apr 2026 13:16:48 -0700 Subject: [PATCH 2/8] seperate parsrr files --- core/rule_parser.py | 309 +-------------------- core/rule_parser_v2.py | 506 +++++++++++++++++++++++++++++++++++ tests/test_rule_parser.py | 80 ++---- tests/test_rule_parser_v2.py | 280 +++++++++++++++++++ 4 files changed, 804 insertions(+), 371 deletions(-) create mode 100644 core/rule_parser_v2.py create mode 100644 tests/test_rule_parser_v2.py diff --git a/core/rule_parser.py b/core/rule_parser.py index 87ab7c9..646a39e 100644 --- a/core/rule_parser.py +++ b/core/rule_parser.py @@ -1,39 +1,8 @@ -from dataclasses import dataclass from enum import Enum import json import mo_sql_parsing as mosql import re -from typing import Any, Dict, List, Optional, Tuple - -from core.ast.enums import NodeType -from core.ast.node import ( - CaseNode, - ColumnNode, - FromNode, - FunctionNode, - GroupByNode, - HavingNode, - JoinNode, - LimitNode, - ListNode, - Node, - OffsetNode, - OperatorNode, - OrderByItemNode, - OrderByNode, - QueryNode, - SelectNode, - SubqueryNode, - TableNode, - UnaryOperatorNode, - VarNode, - VarSetNode, - WhenThenNode, - WhereNode, - IntervalNode, -) -from core.query_parser import QueryParser - +from typing import Any, Tuple # Variable Type # @@ -41,7 +10,7 @@ class VarType(Enum): Var = 1 VarList = 2 -# Variable Types' infro +# Variable Types' info VarTypesInfo = { VarType.Var: { 'markerStart': '<', @@ -73,17 +42,6 @@ class Scope(Enum): } -@dataclass(frozen=True) -class RuleParseResult: - """Structured result from RuleParser.parse_v2 (project AST instead of mo-sql JSON).""" - - pattern_ast: Node - rewrite_ast: Node - mapping: Dict[str, str] - pattern_scope: Scope - rewrite_scope: Scope - - class RuleParser: # parse a rule (pattern, rewrite) into a SQL AST json str @@ -112,269 +70,6 @@ def parse(pattern: str, rewrite: str) -> Tuple[str, str, str]: # 5. Return the AST subtree as json string return json.dumps(patternASTJson), json.dumps(rewriteASTJson), json.dumps(mapping) - @staticmethod - def parse_v2(pattern: str, rewrite: str) -> RuleParseResult: - """Parse a rule into project AST nodes with external rule variable names. - - Uses the same extension and placeholder replacement as parse(), then - QueryParser plus substitution of internal tokens (V001 / VL001) to - VarNode / VarSetNode or decoded ColumnNode / TableNode names. - """ - pattern_sql, rewrite_sql, mapping = RuleParser.replaceVars(pattern, rewrite) - pattern_full, pattern_scope = RuleParser.extendToFullSQL(pattern_sql) - rewrite_full, rewrite_scope = RuleParser.extendToFullSQL(rewrite_sql) - qparser = QueryParser() - pattern_query = qparser.parse(pattern_full) - rewrite_query = qparser.parse(rewrite_full) - internal_to_external = {internal: external for external, internal in mapping.items()} - pattern_ast = RuleParser._extract_rule_ast(pattern_query, pattern_scope, internal_to_external) - rewrite_ast = RuleParser._extract_rule_ast(rewrite_query, rewrite_scope, internal_to_external) - return RuleParseResult( - pattern_ast=pattern_ast, - rewrite_ast=rewrite_ast, - mapping=mapping, - pattern_scope=pattern_scope, - rewrite_scope=rewrite_scope, - ) - - @staticmethod - def _get_clause(query: QueryNode, clause_type: NodeType) -> Optional[Node]: - for child in query.children: - if child.type == clause_type: - return child - return None - - @staticmethod - def _extract_rule_ast( - query: QueryNode, scope: Scope, internal_to_external: Dict[str, str] - ) -> Node: - sel = RuleParser._get_clause(query, NodeType.SELECT) - frm = RuleParser._get_clause(query, NodeType.FROM) - wh = RuleParser._get_clause(query, NodeType.WHERE) - gb = RuleParser._get_clause(query, NodeType.GROUP_BY) - hav = RuleParser._get_clause(query, NodeType.HAVING) - ob = RuleParser._get_clause(query, NodeType.ORDER_BY) - lim = RuleParser._get_clause(query, NodeType.LIMIT) - off = RuleParser._get_clause(query, NodeType.OFFSET) - - if scope == Scope.CONDITION: - if wh is None or not list(wh.children): - raise ValueError("CONDITION scope requires a WHERE predicate") - pred = list(wh.children)[0] - return RuleParser._as_rule_ast(pred, internal_to_external) - - if scope == Scope.WHERE: - return RuleParser._as_rule_ast( - QueryNode( - _select=None, - _from=None, - _where=RuleParser._as_rule_ast(wh, internal_to_external) if wh else None, - _group_by=RuleParser._as_rule_ast(gb, internal_to_external) if gb else None, - _having=RuleParser._as_rule_ast(hav, internal_to_external) if hav else None, - _order_by=RuleParser._as_rule_ast(ob, internal_to_external) if ob else None, - _limit=RuleParser._as_rule_ast(lim, internal_to_external) if lim else None, - _offset=RuleParser._as_rule_ast(off, internal_to_external) if off else None, - ), - internal_to_external, - ) - - if scope == Scope.FROM: - return RuleParser._as_rule_ast( - QueryNode( - _select=None, - _from=RuleParser._as_rule_ast(frm, internal_to_external) if frm else None, - _where=RuleParser._as_rule_ast(wh, internal_to_external) if wh else None, - _group_by=RuleParser._as_rule_ast(gb, internal_to_external) if gb else None, - _having=RuleParser._as_rule_ast(hav, internal_to_external) if hav else None, - _order_by=RuleParser._as_rule_ast(ob, internal_to_external) if ob else None, - _limit=RuleParser._as_rule_ast(lim, internal_to_external) if lim else None, - _offset=RuleParser._as_rule_ast(off, internal_to_external) if off else None, - ), - internal_to_external, - ) - - return RuleParser._as_rule_ast(query, internal_to_external) - - @staticmethod - def _as_rule_ast(node: Optional[Node], internal_to_external: Dict[str, str]) -> Optional[Node]: - if node is None: - return None - return RuleParser._substitute_placeholders(node, internal_to_external) - - @staticmethod - def _placeholder_varnode(internal_token: str, external_name: str) -> Node: - if internal_token.startswith(VarTypesInfo[VarType.VarList]["internalBase"]): - return VarSetNode(external_name) - return VarNode(external_name) - - @staticmethod - def _substitute_placeholders(node: Node, rev: Dict[str, str]) -> Node: - if node.type == NodeType.COLUMN: - col = node - if not isinstance(col, ColumnNode): - return node - pa = col.parent_alias - nm = col.name - if pa is None and nm in rev: - return RuleParser._placeholder_varnode(nm, rev[nm]) - if pa is not None and pa in rev and nm in rev: - return ColumnNode(rev[nm], _alias=col.alias, _parent_alias=rev[pa]) - if pa is not None and pa in rev: - return ColumnNode(nm, _alias=col.alias, _parent_alias=rev[pa]) - return ColumnNode(nm, _alias=col.alias, _parent_alias=pa) - - if node.type == NodeType.TABLE: - t = node - if not isinstance(t, TableNode): - return node - new_name = rev.get(t.name, t.name) - new_alias = rev[t.alias] if t.alias and t.alias in rev else t.alias - return TableNode(new_name, new_alias) - - if node.type == NodeType.QUERY: - q = node - if not isinstance(q, QueryNode): - return node - return QueryNode( - _select=RuleParser._as_rule_ast(RuleParser._get_clause(q, NodeType.SELECT), rev), - _from=RuleParser._as_rule_ast(RuleParser._get_clause(q, NodeType.FROM), rev), - _where=RuleParser._as_rule_ast(RuleParser._get_clause(q, NodeType.WHERE), rev), - _group_by=RuleParser._as_rule_ast(RuleParser._get_clause(q, NodeType.GROUP_BY), rev), - _having=RuleParser._as_rule_ast(RuleParser._get_clause(q, NodeType.HAVING), rev), - _order_by=RuleParser._as_rule_ast(RuleParser._get_clause(q, NodeType.ORDER_BY), rev), - _limit=RuleParser._as_rule_ast(RuleParser._get_clause(q, NodeType.LIMIT), rev), - _offset=RuleParser._as_rule_ast(RuleParser._get_clause(q, NodeType.OFFSET), rev), - ) - - if node.type == NodeType.SELECT: - sn = node - if not isinstance(sn, SelectNode): - return node - items: List[Node] = [] - don = sn.distinct_on - for ch in sn.children: - if don is not None and ch is don: - continue - items.append(RuleParser._substitute_placeholders(ch, rev)) - new_don = ( - RuleParser._substitute_placeholders(don, rev) if don is not None else None - ) - return SelectNode(items, _distinct=sn.distinct, _distinct_on=new_don) - - if node.type == NodeType.FROM: - fn = node - if not isinstance(fn, FromNode): - return node - return FromNode([RuleParser._substitute_placeholders(c, rev) for c in fn.children]) - - if node.type == NodeType.WHERE: - wn = node - if not isinstance(wn, WhereNode): - return node - return WhereNode([RuleParser._substitute_placeholders(c, rev) for c in wn.children]) - - if node.type == NodeType.GROUP_BY: - g = node - if not isinstance(g, GroupByNode): - return node - return GroupByNode([RuleParser._substitute_placeholders(c, rev) for c in g.children]) - - if node.type == NodeType.HAVING: - h = node - if not isinstance(h, HavingNode): - return node - return HavingNode([RuleParser._substitute_placeholders(c, rev) for c in h.children]) - - if node.type == NodeType.ORDER_BY: - o = node - if not isinstance(o, OrderByNode): - return node - return OrderByNode([RuleParser._substitute_placeholders(c, rev) for c in o.children]) - - if node.type == NodeType.ORDER_BY_ITEM: - oi = node - if not isinstance(oi, OrderByItemNode): - return node - inner = list(oi.children)[0] - return OrderByItemNode(RuleParser._substitute_placeholders(inner, rev), oi.sort) - - if node.type == NodeType.JOIN: - j = node - if not isinstance(j, JoinNode): - return node - ch = list(j.children) - left = RuleParser._substitute_placeholders(ch[0], rev) - right = RuleParser._substitute_placeholders(ch[1], rev) - on_expr = ( - RuleParser._substitute_placeholders(ch[2], rev) if len(ch) > 2 else None - ) - return JoinNode(left, right, j.join_type, on_expr) - - if node.type == NodeType.SUBQUERY: - sq = node - if not isinstance(sq, SubqueryNode): - return node - inner = list(sq.children)[0] - return SubqueryNode(RuleParser._substitute_placeholders(inner, rev), sq.alias) - - if node.type == NodeType.FUNCTION: - f = node - if not isinstance(f, FunctionNode): - return node - new_args = [RuleParser._substitute_placeholders(a, rev) for a in f.children] - return FunctionNode(f.name, _args=new_args, _alias=f.alias) - - if node.type == NodeType.LIST: - ln = node - if not isinstance(ln, ListNode): - return node - return ListNode([RuleParser._substitute_placeholders(c, rev) for c in ln.children]) - - if node.type == NodeType.INTERVAL: - inv = node - if not isinstance(inv, IntervalNode): - return node - if isinstance(inv.value, Node): - return IntervalNode( - RuleParser._substitute_placeholders(inv.value, rev), - inv.unit, # type: ignore[arg-type] - ) - return IntervalNode(inv.value, inv.unit) # type: ignore[arg-type] - - if node.type == NodeType.CASE: - cn = node - if not isinstance(cn, CaseNode): - return node - new_whens: List[WhenThenNode] = [] - for wt in cn.whens: - new_whens.append( - WhenThenNode( - RuleParser._substitute_placeholders(wt.when, rev), - RuleParser._substitute_placeholders(wt.then, rev), - ) - ) - new_else = ( - RuleParser._substitute_placeholders(cn.else_val, rev) if cn.else_val else None - ) - return CaseNode(new_whens, new_else) - - if node.type == NodeType.OPERATOR: - if isinstance(node, UnaryOperatorNode): - op = node - inner = list(op.children)[0] if op.children else op.operand - return UnaryOperatorNode(RuleParser._substitute_placeholders(inner, rev), op.name) - op = node - ch = list(op.children) - if len(ch) == 1: - return OperatorNode(RuleParser._substitute_placeholders(ch[0], rev), op.name) - return OperatorNode( - RuleParser._substitute_placeholders(ch[0], rev), - op.name, - RuleParser._substitute_placeholders(ch[1], rev), - ) - - return node - # Extend pattern/rewrite to full SQL statement # @staticmethod diff --git a/core/rule_parser_v2.py b/core/rule_parser_v2.py new file mode 100644 index 0000000..bd34e7d --- /dev/null +++ b/core/rule_parser_v2.py @@ -0,0 +1,506 @@ +# Rule parser v2: self-contained rule preprocessing (duplicated from v1 on purpose), then +# QueryParser and VarNode / VarSetNode rule AST via parse(). + +from __future__ import annotations + +import re +from dataclasses import dataclass +from enum import Enum +from typing import Dict, List, Optional, Tuple + +from core.ast.enums import NodeType +from core.ast.node import ( + CaseNode, + ColumnNode, + FromNode, + FunctionNode, + GroupByNode, + HavingNode, + IntervalNode, + JoinNode, + LimitNode, + ListNode, + Node, + OffsetNode, + OperatorNode, + OrderByItemNode, + OrderByNode, + QueryNode, + SelectNode, + SubqueryNode, + TableNode, + UnaryOperatorNode, + VarNode, + VarSetNode, + WhenThenNode, + WhereNode, +) +from core.query_parser import QueryParser + + +# Variable type (same as v1). +# +class VarType(Enum): + Var = 1 + VarList = 2 + + +# Variable Types' info (same as v1). +# +VarTypesInfo = { + VarType.Var: { + "markerStart": "<", + "markerEnd": ">", + "internalBase": "V", + "externalBase": "x", + }, + VarType.VarList: { + "markerStart": "<<", + "markerEnd": ">>", + "internalBase": "VL", + "externalBase": "y", + }, +} + + +# Scope of pattern/rewrite fragment (same as v1). +# +class Scope(Enum): + SELECT = 1 + FROM = 2 + WHERE = 3 + CONDITION = 4 + + +# Partial-SQL prefix for extendToFullSQL (same as v1). +# +ScopeExtension = { + Scope.CONDITION: "SELECT * FROM t WHERE ", + Scope.WHERE: "SELECT * FROM t ", + Scope.FROM: "SELECT * ", + Scope.SELECT: "", +} + + +# Result of RuleParserV2.parse: rule AST with external variable names restored. +# +@dataclass(frozen=True) +class RuleParseResult: + pattern_ast: Node + rewrite_ast: Node + mapping: Dict[str, str] + pattern_scope: Scope + rewrite_scope: Scope + + +class RuleParserV2: + + # mosql parsing can report mismatching brackets at a confusing index; detect common + # wrong delimiters around rule variables (same logic as v1 RuleParser.find_malformed_brackets). + # + @staticmethod + def find_malformed_brackets(pattern: str) -> int: + CommonMistakeVarTypesInfo = { + "markerStart": [r"\(", r"\{", r"\["], + "markerEnd": [r"\)", r"\}", r"\]"], + } + + for i in range(len(CommonMistakeVarTypesInfo["markerStart"])): + regexPatternVarStart = ( + CommonMistakeVarTypesInfo["markerStart"][i] + + r"(\w+)" + + VarTypesInfo[VarType.Var]["markerEnd"] + ) + regexPatternVarEnd = ( + VarTypesInfo[VarType.Var]["markerStart"] + + r"(\w+)" + + CommonMistakeVarTypesInfo["markerEnd"][i] + ) + + varStart = re.search(regexPatternVarStart, pattern) + varEnd = re.search(regexPatternVarEnd, pattern) + + if varStart: + return varStart.start() + if varEnd: + return varEnd.start() + + return -1 + + # Reject mismatched rule-variable brackets before any preprocessing (same intent as rule_generator). + # + @staticmethod + def _reject_malformed_var_brackets(pattern: str, rewrite: str) -> None: + i = RuleParserV2.find_malformed_brackets(pattern) + if i >= 0: + raise ValueError(f"mismatching brackets in pattern at index {i}") + i = RuleParserV2.find_malformed_brackets(rewrite) + if i >= 0: + raise ValueError(f"mismatching brackets in rewrite at index {i}") + + # Extend pattern/rewrite fragment to full SQL (same as v1 RuleParser.extendToFullSQL). + # + @staticmethod + def extendToFullSQL(partialSQL: str) -> Tuple[str, Scope]: + # Special case: condition on subquery + # e.g., group_users.group_id IN (SELECT ... ) + # Remove subquery in (*) before checking SELECT / FROM / WHERE. + # + sanitisedPartialSQL = re.sub(r"\(.*\)", "(x)", partialSQL) + + # case-1: no SELECT and no FROM and no WHERE + if ( + "SELECT" not in sanitisedPartialSQL.upper() + and "FROM" not in sanitisedPartialSQL.upper() + and "WHERE" not in sanitisedPartialSQL.upper() + ): + scope = Scope.CONDITION + # case-2: no SELECT and no FROM but has WHERE + elif ( + "SELECT" not in sanitisedPartialSQL.upper() + and "FROM" not in sanitisedPartialSQL.upper() + ): + scope = Scope.WHERE + # case-3: no SELECT but has FROM + elif "SELECT" not in sanitisedPartialSQL.upper(): + scope = Scope.FROM + # case-4: has SELECT (and typically FROM) + else: + scope = Scope.SELECT + + partialSQL = ScopeExtension[scope] + partialSQL + return partialSQL, scope + + # Replace user-facing rule variables with internal tokens (same as v1 RuleParser.replaceVars). + # e.g., ==> V001, <> ==> VL001 + # + @staticmethod + def replaceVars(pattern: str, rewrite: str) -> Tuple[str, str, Dict[str, str]]: + + def _replace_one_var_type( + pattern: str, rewrite: str, varType: VarType, mapping: Dict[str, str] + ) -> Tuple[str, str]: + regexPattern = ( + VarTypesInfo[varType]["markerStart"] + + r"(\w+)" + + VarTypesInfo[varType]["markerEnd"] + ) + found = re.findall(regexPattern, pattern) + varInternalBase = VarTypesInfo[varType]["internalBase"] + varInternalCount = 1 + for var in found: + if var not in mapping: + specificRegexPattern = ( + VarTypesInfo[varType]["markerStart"] + + var + + VarTypesInfo[varType]["markerEnd"] + ) + varInternal = varInternalBase + str(varInternalCount).zfill(3) + varInternalCount += 1 + pattern = re.sub(specificRegexPattern, varInternal, pattern) + rewrite = re.sub(specificRegexPattern, varInternal, rewrite) + mapping[var] = varInternal + return pattern, rewrite + + mapping: Dict[str, str] = {} + pattern, rewrite = _replace_one_var_type(pattern, rewrite, VarType.VarList, mapping) + pattern, rewrite = _replace_one_var_type(pattern, rewrite, VarType.Var, mapping) + return pattern, rewrite, mapping + + # parse a rule into project AST nodes (VarNode / VarSetNode for rule variables) + # + @staticmethod + def parse(pattern: str, rewrite: str) -> RuleParseResult: + + # 0. Reject mismatched <...> / <<...>> brackets + # + RuleParserV2._reject_malformed_var_brackets(pattern, rewrite) + + # 1. Replace user-faced variables and variable lists + # with internal representations + # + pattern_sql, rewrite_sql, mapping = RuleParserV2.replaceVars(pattern, rewrite) + + # 2. Extend partial SQL statement to full SQL statement + # for the sake of sql parser + # + pattern_full, pattern_scope = RuleParserV2.extendToFullSQL(pattern_sql) + rewrite_full, rewrite_scope = RuleParserV2.extendToFullSQL(rewrite_sql) + + # 3. Parse extended full SQL statement into AST (QueryParser) + # + qparser = QueryParser() + pattern_query = qparser.parse(pattern_full) + rewrite_query = qparser.parse(rewrite_full) + + # 4. Extract rule-shaped subtree by scope and map V00x / VL00x back to external names + # + internal_to_external = {internal: external for external, internal in mapping.items()} + pattern_ast = RuleParserV2._extract_rule_ast(pattern_query, pattern_scope, internal_to_external) + rewrite_ast = RuleParserV2._extract_rule_ast(rewrite_query, rewrite_scope, internal_to_external) + + # 5. Return AST result + mapping + scopes + # + return RuleParseResult( + pattern_ast=pattern_ast, + rewrite_ast=rewrite_ast, + mapping=mapping, + pattern_scope=pattern_scope, + rewrite_scope=rewrite_scope, + ) + + # Find first child of query with given clause type (SELECT, FROM, WHERE, ...). + # + @staticmethod + def _get_clause(query: QueryNode, clause_type: NodeType) -> Optional[Node]: + for child in query.children: + if child.type == clause_type: + return child + return None + + # Slice full query AST to the rule fragment for this scope, then substitute placeholders + # (internal column/table names) to VarNode / VarSetNode using internal_to_external. + # + @staticmethod + def _extract_rule_ast( + query: QueryNode, scope: Scope, internal_to_external: Dict[str, str] + ) -> Node: + frm = RuleParserV2._get_clause(query, NodeType.FROM) + wh = RuleParserV2._get_clause(query, NodeType.WHERE) + gb = RuleParserV2._get_clause(query, NodeType.GROUP_BY) + hav = RuleParserV2._get_clause(query, NodeType.HAVING) + ob = RuleParserV2._get_clause(query, NodeType.ORDER_BY) + lim = RuleParserV2._get_clause(query, NodeType.LIMIT) + off = RuleParserV2._get_clause(query, NodeType.OFFSET) + + # case CONDITION: predicate only + # + if scope == Scope.CONDITION: + if wh is None or not list(wh.children): + raise ValueError("CONDITION scope requires a WHERE predicate") + pred = list(wh.children)[0] + return RuleParserV2._as_rule_ast(pred, internal_to_external) + + # case WHERE: rewrite-shaped query without select/from + # + if scope == Scope.WHERE: + return RuleParserV2._as_rule_ast( + QueryNode( + _select=None, + _from=None, + _where=RuleParserV2._as_rule_ast(wh, internal_to_external) if wh else None, + _group_by=RuleParserV2._as_rule_ast(gb, internal_to_external) if gb else None, + _having=RuleParserV2._as_rule_ast(hav, internal_to_external) if hav else None, + _order_by=RuleParserV2._as_rule_ast(ob, internal_to_external) if ob else None, + _limit=RuleParserV2._as_rule_ast(lim, internal_to_external) if lim else None, + _offset=RuleParserV2._as_rule_ast(off, internal_to_external) if off else None, + ), + internal_to_external, + ) + + # case FROM: from + following clauses, no select list + # + if scope == Scope.FROM: + return RuleParserV2._as_rule_ast( + QueryNode( + _select=None, + _from=RuleParserV2._as_rule_ast(frm, internal_to_external) if frm else None, + _where=RuleParserV2._as_rule_ast(wh, internal_to_external) if wh else None, + _group_by=RuleParserV2._as_rule_ast(gb, internal_to_external) if gb else None, + _having=RuleParserV2._as_rule_ast(hav, internal_to_external) if hav else None, + _order_by=RuleParserV2._as_rule_ast(ob, internal_to_external) if ob else None, + _limit=RuleParserV2._as_rule_ast(lim, internal_to_external) if lim else None, + _offset=RuleParserV2._as_rule_ast(off, internal_to_external) if off else None, + ), + internal_to_external, + ) + + # case SELECT: full query + # + return RuleParserV2._as_rule_ast(query, internal_to_external) + + # Run VarNode / VarSetNode substitution on one subtree (None stays None). + # + @staticmethod + def _as_rule_ast(node: Optional[Node], internal_to_external: Dict[str, str]) -> Optional[Node]: + if node is None: + return None + return RuleParserV2._substitute_placeholders(node, internal_to_external) + + # Build VarNode or VarSetNode from internal token prefix (V... vs VL...). + # + @staticmethod + def _placeholder_varnode(internal_token: str, external_name: str) -> Node: + if internal_token.startswith(VarTypesInfo[VarType.VarList]["internalBase"]): + return VarSetNode(external_name) + return VarNode(external_name) + + # Structural recursion: replace internal identifiers with VarNode / VarSetNode where appropriate. + # + @staticmethod + def _substitute_placeholders(node: Node, rev: Dict[str, str]) -> Node: + if node.type == NodeType.COLUMN: + col = node + if not isinstance(col, ColumnNode): + return node + pa = col.parent_alias + nm = col.name + if pa is None and nm in rev: + return RuleParserV2._placeholder_varnode(nm, rev[nm]) + if pa is not None and pa in rev and nm in rev: + return ColumnNode(rev[nm], _alias=col.alias, _parent_alias=rev[pa]) + if pa is not None and pa in rev: + return ColumnNode(nm, _alias=col.alias, _parent_alias=rev[pa]) + return ColumnNode(nm, _alias=col.alias, _parent_alias=pa) + + if node.type == NodeType.TABLE: + t = node + if not isinstance(t, TableNode): + return node + new_name = rev.get(t.name, t.name) + new_alias = rev[t.alias] if t.alias and t.alias in rev else t.alias + return TableNode(new_name, new_alias) + + if node.type == NodeType.QUERY: + q = node + if not isinstance(q, QueryNode): + return node + return QueryNode( + _select=RuleParserV2._as_rule_ast(RuleParserV2._get_clause(q, NodeType.SELECT), rev), + _from=RuleParserV2._as_rule_ast(RuleParserV2._get_clause(q, NodeType.FROM), rev), + _where=RuleParserV2._as_rule_ast(RuleParserV2._get_clause(q, NodeType.WHERE), rev), + _group_by=RuleParserV2._as_rule_ast(RuleParserV2._get_clause(q, NodeType.GROUP_BY), rev), + _having=RuleParserV2._as_rule_ast(RuleParserV2._get_clause(q, NodeType.HAVING), rev), + _order_by=RuleParserV2._as_rule_ast(RuleParserV2._get_clause(q, NodeType.ORDER_BY), rev), + _limit=RuleParserV2._as_rule_ast(RuleParserV2._get_clause(q, NodeType.LIMIT), rev), + _offset=RuleParserV2._as_rule_ast(RuleParserV2._get_clause(q, NodeType.OFFSET), rev), + ) + + if node.type == NodeType.SELECT: + sn = node + if not isinstance(sn, SelectNode): + return node + items: List[Node] = [] + don = sn.distinct_on + for ch in sn.children: + if don is not None and ch is don: + continue + items.append(RuleParserV2._substitute_placeholders(ch, rev)) + new_don = ( + RuleParserV2._substitute_placeholders(don, rev) if don is not None else None + ) + return SelectNode(items, _distinct=sn.distinct, _distinct_on=new_don) + + if node.type == NodeType.FROM: + fn = node + if not isinstance(fn, FromNode): + return node + return FromNode([RuleParserV2._substitute_placeholders(c, rev) for c in fn.children]) + + if node.type == NodeType.WHERE: + wn = node + if not isinstance(wn, WhereNode): + return node + return WhereNode([RuleParserV2._substitute_placeholders(c, rev) for c in wn.children]) + + if node.type == NodeType.GROUP_BY: + g = node + if not isinstance(g, GroupByNode): + return node + return GroupByNode([RuleParserV2._substitute_placeholders(c, rev) for c in g.children]) + + if node.type == NodeType.HAVING: + h = node + if not isinstance(h, HavingNode): + return node + return HavingNode([RuleParserV2._substitute_placeholders(c, rev) for c in h.children]) + + if node.type == NodeType.ORDER_BY: + o = node + if not isinstance(o, OrderByNode): + return node + return OrderByNode([RuleParserV2._substitute_placeholders(c, rev) for c in o.children]) + + if node.type == NodeType.ORDER_BY_ITEM: + oi = node + if not isinstance(oi, OrderByItemNode): + return node + inner = list(oi.children)[0] + return OrderByItemNode(RuleParserV2._substitute_placeholders(inner, rev), oi.sort) + + if node.type == NodeType.JOIN: + j = node + if not isinstance(j, JoinNode): + return node + ch = list(j.children) + left = RuleParserV2._substitute_placeholders(ch[0], rev) + right = RuleParserV2._substitute_placeholders(ch[1], rev) + on_expr = ( + RuleParserV2._substitute_placeholders(ch[2], rev) if len(ch) > 2 else None + ) + return JoinNode(left, right, j.join_type, on_expr) + + if node.type == NodeType.SUBQUERY: + sq = node + if not isinstance(sq, SubqueryNode): + return node + inner = list(sq.children)[0] + return SubqueryNode(RuleParserV2._substitute_placeholders(inner, rev), sq.alias) + + if node.type == NodeType.FUNCTION: + f = node + if not isinstance(f, FunctionNode): + return node + new_args = [RuleParserV2._substitute_placeholders(a, rev) for a in f.children] + return FunctionNode(f.name, _args=new_args, _alias=f.alias) + + if node.type == NodeType.LIST: + ln = node + if not isinstance(ln, ListNode): + return node + return ListNode([RuleParserV2._substitute_placeholders(c, rev) for c in ln.children]) + + if node.type == NodeType.INTERVAL: + inv = node + if not isinstance(inv, IntervalNode): + return node + if isinstance(inv.value, Node): + return IntervalNode( + RuleParserV2._substitute_placeholders(inv.value, rev), + inv.unit, # type: ignore[arg-type] + ) + return IntervalNode(inv.value, inv.unit) # type: ignore[arg-type] + + if node.type == NodeType.CASE: + cn = node + if not isinstance(cn, CaseNode): + return node + new_whens: List[WhenThenNode] = [] + for wt in cn.whens: + new_whens.append( + WhenThenNode( + RuleParserV2._substitute_placeholders(wt.when, rev), + RuleParserV2._substitute_placeholders(wt.then, rev), + ) + ) + new_else = ( + RuleParserV2._substitute_placeholders(cn.else_val, rev) if cn.else_val else None + ) + return CaseNode(new_whens, new_else) + + if node.type == NodeType.OPERATOR: + if isinstance(node, UnaryOperatorNode): + op = node + inner = list(op.children)[0] if op.children else op.operand + return UnaryOperatorNode(RuleParserV2._substitute_placeholders(inner, rev), op.name) + op = node + ch = list(op.children) + if len(ch) == 1: + return OperatorNode(RuleParserV2._substitute_placeholders(ch[0], rev), op.name) + return OperatorNode( + RuleParserV2._substitute_placeholders(ch[0], rev), + op.name, + RuleParserV2._substitute_placeholders(ch[1], rev), + ) + + return node diff --git a/tests/test_rule_parser.py b/tests/test_rule_parser.py index 0956a1c..aa6dcee 100644 --- a/tests/test_rule_parser.py +++ b/tests/test_rule_parser.py @@ -1,13 +1,4 @@ -from core.ast.enums import NodeType -from core.ast.node import ( - DataTypeNode, - FunctionNode, - QueryNode, - SelectNode, - VarNode, - VarSetNode, -) -from core.rule_parser import RuleParser, RuleParseResult, Scope +from core.rule_parser import RuleParser, Scope def test_extendToFullSQL(): @@ -159,89 +150,50 @@ def test_parse(): assert rewrite_json == internal_rule['rewrite_json'] -def test_parse_v2_cast_rule(): - result = RuleParser.parse_v2('CAST( AS DATE)', '') - assert isinstance(result, RuleParseResult) - assert result.pattern_scope == Scope.CONDITION - assert result.rewrite_scope == Scope.CONDITION - assert result.mapping == {'x': 'V001'} - assert isinstance(result.pattern_ast, FunctionNode) - assert result.pattern_ast.name.lower() == 'cast' - cast_args = list(result.pattern_ast.children) - assert isinstance(cast_args[0], VarNode) and cast_args[0].name == 'x' - assert isinstance(cast_args[1], DataTypeNode) - assert isinstance(result.rewrite_ast, VarNode) and result.rewrite_ast.name == 'x' - - -def test_parse_v2_select_list_varset(): - pattern = 'select <> from lineitem where 1 = 1' - rewrite = 'select <> from lineitem where 1 = 1' - result = RuleParser.parse_v2(pattern, rewrite) - assert result.pattern_scope == Scope.SELECT - assert isinstance(result.pattern_ast, QueryNode) - select = next(c for c in result.pattern_ast.children if c.type == NodeType.SELECT) - assert isinstance(select, SelectNode) - first = list(select.children)[0] - assert isinstance(first, VarSetNode) and first.name == 's1' - - -#incorrect brackets def test_brackets_1(): - - pattern = '''WHERE 11 + pattern = '''WHERE 11 AND a <= 11 ''' - - index = RuleParser.find_malformed_brackets(pattern) - assert index == 6 + index = RuleParser.find_malformed_brackets(pattern) + assert index == 6 + - #incorrect brackets - def test_brackets_2(): - +def test_brackets_2(): pattern = '''WHERE 11 AND a <= 11 ''' - index = RuleParser.find_malformed_brackets(pattern) assert index == 6 -#incorrect brackets + def test_parse_validator_3(): - - pattern = '''WHERE 11 + pattern = '''WHERE 11 AND a <= 11 ''' + index = RuleParser.find_malformed_brackets(pattern) + assert index == 6 - index = RuleParser.find_malformed_brackets(pattern) - assert index == 6 -#incorrect brackets - def test_parse_validator_4(): - +def test_parse_validator_4(): pattern = '''WHERE [x> > 11 AND a <= 11 ''' - index = RuleParser.find_malformed_brackets(pattern) assert index == 6 -#incorrect brackets - def test_parse_validator_5(): - +def test_parse_validator_5(): pattern = '''WHERE (x> > 11 AND a <= 11 ''' - index = RuleParser.find_malformed_brackets(pattern) assert index == 6 - -#incorrect brackets - def test_parse_validator_6(): - + + +def test_parse_validator_6(): pattern = '''WHERE {x> > 11 AND a <= 11 ''' index = RuleParser.find_malformed_brackets(pattern) assert index == 6 - + diff --git a/tests/test_rule_parser_v2.py b/tests/test_rule_parser_v2.py new file mode 100644 index 0000000..8d45721 --- /dev/null +++ b/tests/test_rule_parser_v2.py @@ -0,0 +1,280 @@ +"""Tests for core.rule_parser_v2 — mirrors tests in test_rule_parser.py plus AST (VarNode) paths.""" + +import pytest + +from core.ast.enums import NodeType +from core.ast.node import ( + DataTypeNode, + FromNode, + FunctionNode, + LiteralNode, + OperatorNode, + QueryNode, + SelectNode, + TableNode, + VarNode, + VarSetNode, + WhereNode, +) +from core.rule_parser_v2 import RuleParseResult, RuleParserV2, Scope + + +def test_extendToFullSQL(): + # Same assertions as tests/test_rule_parser.py::test_extendToFullSQL + pattern = "CAST(V1 AS DATE)" + rewrite = "V1" + pattern, scope = RuleParserV2.extendToFullSQL(pattern) + assert pattern == "SELECT * FROM t WHERE CAST(V1 AS DATE)" + assert scope == Scope.CONDITION + rewrite, scope = RuleParserV2.extendToFullSQL(rewrite) + assert rewrite == "SELECT * FROM t WHERE V1" + assert scope == Scope.CONDITION + + pattern = "WHERE CAST(V1 AS DATE)" + rewrite = "WHERE V1" + pattern, scope = RuleParserV2.extendToFullSQL(pattern) + assert pattern == "SELECT * FROM t WHERE CAST(V1 AS DATE)" + assert scope == Scope.WHERE + rewrite, scope = RuleParserV2.extendToFullSQL(rewrite) + assert rewrite == "SELECT * FROM t WHERE V1" + assert scope == Scope.WHERE + + pattern = "FROM lineitem" + rewrite = "FROM v_lineitem" + pattern, scope = RuleParserV2.extendToFullSQL(pattern) + assert pattern == "SELECT * FROM lineitem" + assert scope == Scope.FROM + rewrite, scope = RuleParserV2.extendToFullSQL(rewrite) + assert rewrite == "SELECT * FROM v_lineitem" + assert scope == Scope.FROM + + pattern = """ + select VL1 + from V1 V2, + V3 V4 + where V2.V6=V4.V8 + and VL2 + """ + rewrite = """ + select VL1 + from V1 V2 + where VL2 + """ + pattern, scope = RuleParserV2.extendToFullSQL(pattern) + assert pattern == """ + select VL1 + from V1 V2, + V3 V4 + where V2.V6=V4.V8 + and VL2 + """ + assert scope == Scope.SELECT + rewrite, scope = RuleParserV2.extendToFullSQL(rewrite) + assert rewrite == """ + select VL1 + from V1 V2 + where VL2 + """ + assert scope == Scope.SELECT + + pattern = "SELECT VL1 FROM lineitem" + rewrite = "SELECT VL1 FROM v_lineitem" + pattern, scope = RuleParserV2.extendToFullSQL(pattern) + assert pattern == "SELECT VL1 FROM lineitem" + assert scope == Scope.SELECT + rewrite, scope = RuleParserV2.extendToFullSQL(rewrite) + assert rewrite == "SELECT VL1 FROM v_lineitem" + assert scope == Scope.SELECT + + pattern = "SELECT CAST(V1 AS DATE)" + rewrite = "SELECT V1" + pattern, scope = RuleParserV2.extendToFullSQL(pattern) + assert pattern == "SELECT CAST(V1 AS DATE)" + assert scope == Scope.SELECT + rewrite, scope = RuleParserV2.extendToFullSQL(rewrite) + assert rewrite == "SELECT V1" + assert scope == Scope.SELECT + + +def test_replaceVars(): + pattern = "CAST( AS DATE)" + rewrite = "" + pattern, rewrite, mapping = RuleParserV2.replaceVars(pattern, rewrite) + assert pattern == "CAST(V001 AS DATE)" + assert rewrite == "V001" + assert mapping == {"x": "V001"} + + pattern = """ + select <> + from , + + where .=. + and <> + """ + rewrite = """ + select <> + from + where <> + """ + pattern, rewrite, mapping = RuleParserV2.replaceVars(pattern, rewrite) + assert pattern == """ + select VL001 + from V001 V002, + V003 V004 + where V002.V005=V004.V006 + and VL002 + """ + assert rewrite == """ + select VL001 + from V001 V002 + where VL002 + """ + assert mapping == { + "s1": "VL001", + "p1": "VL002", + "tb1": "V001", + "t1": "V002", + "tb2": "V003", + "t2": "V004", + "a1": "V005", + "a2": "V006", + } + + +def test_parse_rejects_malformed_brackets_in_pattern(): + pattern = """WHERE 11 + AND a <= 11 + """ + with pytest.raises(ValueError, match=r"mismatching brackets in pattern at index 6"): + RuleParserV2.parse(pattern, "") + + +def test_parse_rejects_malformed_brackets_in_rewrite(): + pattern = "" + rewrite = """WHERE 11 + AND a <= 11 + """ + with pytest.raises(ValueError, match=r"mismatching brackets in rewrite at index 6"): + RuleParserV2.parse(pattern, rewrite) + + +def test_parse_ast_cast_rule(): + result = RuleParserV2.parse("CAST( AS DATE)", "") + assert isinstance(result, RuleParseResult) + assert result.pattern_scope == Scope.CONDITION + assert result.rewrite_scope == Scope.CONDITION + assert result.mapping == {"x": "V001"} + assert isinstance(result.pattern_ast, FunctionNode) + assert result.pattern_ast.name.lower() == "cast" + cast_args = list(result.pattern_ast.children) + assert isinstance(cast_args[0], VarNode) and cast_args[0].name == "x" + assert isinstance(cast_args[1], DataTypeNode) + assert isinstance(result.rewrite_ast, VarNode) and result.rewrite_ast.name == "x" + + +def test_parse_ast_select_list_varset(): + pattern = "select <> from lineitem where 1 = 1" + rewrite = "select <> from lineitem where 1 = 1" + result = RuleParserV2.parse(pattern, rewrite) + assert result.pattern_scope == Scope.SELECT + assert isinstance(result.pattern_ast, QueryNode) + select = next(c for c in result.pattern_ast.children if c.type == NodeType.SELECT) + assert isinstance(select, SelectNode) + first = list(select.children)[0] + assert isinstance(first, VarSetNode) and first.name == "s1" + + +def test_parse_ast_strpos_ilike_rule(): + result = RuleParserV2.parse( + "STRPOS(LOWER(), '') > 0", + " ILIKE '%%'", + ) + assert result.mapping == {"x": "V001", "s": "V002"} + assert result.pattern_scope == Scope.CONDITION + assert result.rewrite_scope == Scope.CONDITION + assert isinstance(result.pattern_ast, OperatorNode) + assert result.pattern_ast.name == ">" + strpos = list(result.pattern_ast.children)[0] + assert isinstance(strpos, FunctionNode) and strpos.name.upper() == "STRPOS" + lower = list(strpos.children)[0] + assert isinstance(lower, FunctionNode) and lower.name.lower() == "lower" + assert isinstance(list(lower.children)[0], VarNode) and list(lower.children)[0].name == "x" + assert isinstance(result.rewrite_ast, FunctionNode) and result.rewrite_ast.name.lower() == "ilike" + ilike_args = list(result.rewrite_ast.children) + assert isinstance(ilike_args[0], VarNode) and ilike_args[0].name == "x" + assert isinstance(ilike_args[1], LiteralNode) + + +def test_parse_ast_where_scope(): + result = RuleParserV2.parse("WHERE = 1", "WHERE = 1") + assert result.pattern_scope == Scope.WHERE + assert result.rewrite_scope == Scope.WHERE + assert result.mapping == {"x": "V001"} + assert isinstance(result.pattern_ast, QueryNode) + wh = next(c for c in result.pattern_ast.children if c.type == NodeType.WHERE) + assert isinstance(wh, WhereNode) + pred = list(wh.children)[0] + assert isinstance(pred, OperatorNode) and pred.name == "=" + lhs, rhs = list(pred.children) + assert isinstance(lhs, VarNode) and lhs.name == "x" + assert isinstance(rhs, LiteralNode) and rhs.value == 1 + + +def test_parse_ast_from_scope(): + result = RuleParserV2.parse("FROM li", "FROM li") + assert result.pattern_scope == Scope.FROM + assert result.rewrite_scope == Scope.FROM + assert result.mapping == {"t": "V001"} + assert isinstance(result.pattern_ast, QueryNode) + frm = next(c for c in result.pattern_ast.children if c.type == NodeType.FROM) + assert isinstance(frm, FromNode) + tab = list(frm.children)[0] + assert isinstance(tab, TableNode) and tab.name == "t" and tab.alias == "li" + + +def test_brackets_1(): + pattern = '''WHERE 11 + AND a <= 11 + ''' + index = RuleParserV2.find_malformed_brackets(pattern) + assert index == 6 + + +def test_brackets_2(): + pattern = '''WHERE 11 + AND a <= 11 + ''' + index = RuleParserV2.find_malformed_brackets(pattern) + assert index == 6 + + +def test_parse_validator_3(): + pattern = '''WHERE 11 + AND a <= 11 + ''' + index = RuleParserV2.find_malformed_brackets(pattern) + assert index == 6 + + +def test_parse_validator_4(): + pattern = '''WHERE [x> > 11 + AND a <= 11 + ''' + index = RuleParserV2.find_malformed_brackets(pattern) + assert index == 6 + + +def test_parse_validator_5(): + pattern = '''WHERE (x> > 11 + AND a <= 11 + ''' + index = RuleParserV2.find_malformed_brackets(pattern) + assert index == 6 + + +def test_parse_validator_6(): + pattern = '''WHERE {x> > 11 + AND a <= 11 + ''' + index = RuleParserV2.find_malformed_brackets(pattern) + assert index == 6 \ No newline at end of file From a4ea3e87e7b162cea5597733fe8290251225d1f1 Mon Sep 17 00:00:00 2001 From: Colin Harrison Date: Thu, 2 Apr 2026 16:12:08 -0700 Subject: [PATCH 3/8] fix newlines --- core/rule_parser.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/rule_parser.py b/core/rule_parser.py index 646a39e..f0de16d 100644 --- a/core/rule_parser.py +++ b/core/rule_parser.py @@ -4,13 +4,14 @@ import re from typing import Any, Tuple + # Variable Type # class VarType(Enum): Var = 1 VarList = 2 -# Variable Types' info +# Variable Types' infro VarTypesInfo = { VarType.Var: { 'markerStart': '<', @@ -41,7 +42,6 @@ class Scope(Enum): Scope.SELECT: '' } - class RuleParser: # parse a rule (pattern, rewrite) into a SQL AST json str From 1705c2b1b7a6741e93422b42c0b1b63b196551ba Mon Sep 17 00:00:00 2001 From: Colin Harrison Date: Thu, 2 Apr 2026 16:33:37 -0700 Subject: [PATCH 4/8] seperate VarNode/Varset substitution from reducing query scope --- core/rule_parser_v2.py | 80 +++++++++++++++++++++++------------------- 1 file changed, 43 insertions(+), 37 deletions(-) diff --git a/core/rule_parser_v2.py b/core/rule_parser_v2.py index bd34e7d..bad066f 100644 --- a/core/rule_parser_v2.py +++ b/core/rule_parser_v2.py @@ -233,13 +233,18 @@ def parse(pattern: str, rewrite: str) -> RuleParseResult: pattern_query = qparser.parse(pattern_full) rewrite_query = qparser.parse(rewrite_full) - # 4. Extract rule-shaped subtree by scope and map V00x / VL00x back to external names + # 4. Map internal tokens (V00x / VL00x) to VarNode / VarSetNode across the full query AST # internal_to_external = {internal: external for external, internal in mapping.items()} - pattern_ast = RuleParserV2._extract_rule_ast(pattern_query, pattern_scope, internal_to_external) - rewrite_ast = RuleParserV2._extract_rule_ast(rewrite_query, rewrite_scope, internal_to_external) + pattern_after_vars = RuleParserV2._substitute_rule_vars(pattern_query, internal_to_external) + rewrite_after_vars = RuleParserV2._substitute_rule_vars(rewrite_query, internal_to_external) - # 5. Return AST result + mapping + scopes + # 5. Reduce to the rule fragment for the inferred scope (CONDITION / WHERE / FROM / SELECT) + # + pattern_ast = RuleParserV2._extract_rule_fragment(pattern_after_vars, pattern_scope) + rewrite_ast = RuleParserV2._extract_rule_fragment(rewrite_after_vars, rewrite_scope) + + # 6. Return AST result + mapping + scopes # return RuleParseResult( pattern_ast=pattern_ast, @@ -258,13 +263,21 @@ def _get_clause(query: QueryNode, clause_type: NodeType) -> Optional[Node]: return child return None - # Slice full query AST to the rule fragment for this scope, then substitute placeholders - # (internal column/table names) to VarNode / VarSetNode using internal_to_external. + # Apply internal_to_external across an entire parsed query (V00x -> VarNode, etc.). + # + @staticmethod + def _substitute_rule_vars( + query: QueryNode, internal_to_external: Dict[str, str] + ) -> QueryNode: + out = RuleParserV2._as_rule_ast(query, internal_to_external) + if not isinstance(out, QueryNode): + raise TypeError("expected QueryNode after substituting rule variables on full query") + return out + + # Slice a fully substituted query to the rule fragment for this scope (no VarNode pass). # @staticmethod - def _extract_rule_ast( - query: QueryNode, scope: Scope, internal_to_external: Dict[str, str] - ) -> Node: + def _extract_rule_fragment(query: QueryNode, scope: Scope) -> Node: frm = RuleParserV2._get_clause(query, NodeType.FROM) wh = RuleParserV2._get_clause(query, NodeType.WHERE) gb = RuleParserV2._get_clause(query, NodeType.GROUP_BY) @@ -278,46 +291,39 @@ def _extract_rule_ast( if scope == Scope.CONDITION: if wh is None or not list(wh.children): raise ValueError("CONDITION scope requires a WHERE predicate") - pred = list(wh.children)[0] - return RuleParserV2._as_rule_ast(pred, internal_to_external) + return list(wh.children)[0] - # case WHERE: rewrite-shaped query without select/from + # case WHERE: query without select/from lists # if scope == Scope.WHERE: - return RuleParserV2._as_rule_ast( - QueryNode( - _select=None, - _from=None, - _where=RuleParserV2._as_rule_ast(wh, internal_to_external) if wh else None, - _group_by=RuleParserV2._as_rule_ast(gb, internal_to_external) if gb else None, - _having=RuleParserV2._as_rule_ast(hav, internal_to_external) if hav else None, - _order_by=RuleParserV2._as_rule_ast(ob, internal_to_external) if ob else None, - _limit=RuleParserV2._as_rule_ast(lim, internal_to_external) if lim else None, - _offset=RuleParserV2._as_rule_ast(off, internal_to_external) if off else None, - ), - internal_to_external, + return QueryNode( + _select=None, + _from=None, + _where=wh, + _group_by=gb, + _having=hav, + _order_by=ob, + _limit=lim, + _offset=off, ) # case FROM: from + following clauses, no select list # if scope == Scope.FROM: - return RuleParserV2._as_rule_ast( - QueryNode( - _select=None, - _from=RuleParserV2._as_rule_ast(frm, internal_to_external) if frm else None, - _where=RuleParserV2._as_rule_ast(wh, internal_to_external) if wh else None, - _group_by=RuleParserV2._as_rule_ast(gb, internal_to_external) if gb else None, - _having=RuleParserV2._as_rule_ast(hav, internal_to_external) if hav else None, - _order_by=RuleParserV2._as_rule_ast(ob, internal_to_external) if ob else None, - _limit=RuleParserV2._as_rule_ast(lim, internal_to_external) if lim else None, - _offset=RuleParserV2._as_rule_ast(off, internal_to_external) if off else None, - ), - internal_to_external, + return QueryNode( + _select=None, + _from=frm, + _where=wh, + _group_by=gb, + _having=hav, + _order_by=ob, + _limit=lim, + _offset=off, ) # case SELECT: full query # - return RuleParserV2._as_rule_ast(query, internal_to_external) + return query # Run VarNode / VarSetNode substitution on one subtree (None stays None). # From 1a644fd0b7413660f1437a302d685b4f2deeaead Mon Sep 17 00:00:00 2001 From: Colin Harrison Date: Fri, 3 Apr 2026 16:05:50 -0700 Subject: [PATCH 5/8] fix node names and add test cases --- core/ast/__init__.py | 8 +- core/ast/node.py | 8 +- core/query_parser.py | 4 +- core/rule_parser_v2.py | 68 +++---- tests/ast_util.py | 6 +- tests/test_ast.py | 20 +- tests/test_rule_parser_v2.py | 364 +++++++++++++++++++++++++++++++---- 7 files changed, 379 insertions(+), 99 deletions(-) diff --git a/core/ast/__init__.py b/core/ast/__init__.py index 804646a..f45bf2c 100644 --- a/core/ast/__init__.py +++ b/core/ast/__init__.py @@ -11,8 +11,8 @@ SubqueryNode, ColumnNode, LiteralNode, - VarNode, - VarSetNode, + ElementVariableNode, + SetVariableNode, OperatorNode, FunctionNode, SelectNode, @@ -34,8 +34,8 @@ 'SubqueryNode', 'ColumnNode', 'LiteralNode', - 'VarNode', - 'VarSetNode', + 'ElementVariableNode', + 'SetVariableNode', 'OperatorNode', 'FunctionNode', 'SelectNode', diff --git a/core/ast/node.py b/core/ast/node.py index 52e505d..8932f7f 100644 --- a/core/ast/node.py +++ b/core/ast/node.py @@ -168,15 +168,15 @@ def __eq__(self, other): def __hash__(self): return hash((super().__hash__(), self.value, self.unit)) -class VarNode(Node): - """VarSQL variable node""" +class ElementVariableNode(Node): + """Rule element variable ```` (see ``VarType.ElementVariable`` in rule_parser_v2).""" def __init__(self, _name: str, **kwargs): super().__init__(NodeType.VAR, **kwargs) self.name = _name -class VarSetNode(Node): - """VarSQL variable set node""" +class SetVariableNode(Node): + """Rule set variable ``<>`` (see ``VarType.SetVariable`` in rule_parser_v2).""" def __init__(self, _name: str, **kwargs): super().__init__(NodeType.VARSET, **kwargs) self.name = _name diff --git a/core/query_parser.py b/core/query_parser.py index 19494bb..8e30c59 100644 --- a/core/query_parser.py +++ b/core/query_parser.py @@ -4,9 +4,9 @@ CaseNode, WhenThenNode, OperatorNode, UnaryOperatorNode, FunctionNode, GroupByNode, HavingNode, OrderByNode, OrderByItemNode, LimitNode, OffsetNode, SubqueryNode, - VarNode, VarSetNode, JoinNode, ListNode + ElementVariableNode, SetVariableNode, JoinNode, ListNode ) -# TODO: implement VarNode, VarSetNode +# TODO: implement ElementVariableNode, SetVariableNode from core.ast.enums import JoinType, SortOrder import mo_sql_parsing as mosql import json diff --git a/core/rule_parser_v2.py b/core/rule_parser_v2.py index bad066f..e49fed0 100644 --- a/core/rule_parser_v2.py +++ b/core/rule_parser_v2.py @@ -1,5 +1,5 @@ # Rule parser v2: self-contained rule preprocessing (duplicated from v1 on purpose), then -# QueryParser and VarNode / VarSetNode rule AST via parse(). +# QueryParser and ElementVariableNode / SetVariableNode rule AST via parse(). from __future__ import annotations @@ -30,34 +30,35 @@ SubqueryNode, TableNode, UnaryOperatorNode, - VarNode, - VarSetNode, + ElementVariableNode, + SetVariableNode, WhenThenNode, WhereNode, ) from core.query_parser import QueryParser -# Variable type (same as v1). +# Variable types (v2 naming; same placeholder syntax as v1). +# AST: ```` → ElementVariableNode, ``<>`` → SetVariableNode. # class VarType(Enum): - Var = 1 - VarList = 2 + ElementVariable = 1 # → ElementVariableNode in rule AST + SetVariable = 2 # <> → SetVariableNode in rule AST -# Variable Types' info (same as v1). +# Placeholder markers and internal token prefixes for rule variables. # VarTypesInfo = { - VarType.Var: { + VarType.ElementVariable: { "markerStart": "<", "markerEnd": ">", - "internalBase": "V", + "internalBase": "EV", "externalBase": "x", }, - VarType.VarList: { + VarType.SetVariable: { "markerStart": "<<", "markerEnd": ">>", - "internalBase": "VL", + "internalBase": "SV", "externalBase": "y", }, } @@ -89,8 +90,6 @@ class RuleParseResult: pattern_ast: Node rewrite_ast: Node mapping: Dict[str, str] - pattern_scope: Scope - rewrite_scope: Scope class RuleParserV2: @@ -109,10 +108,10 @@ def find_malformed_brackets(pattern: str) -> int: regexPatternVarStart = ( CommonMistakeVarTypesInfo["markerStart"][i] + r"(\w+)" - + VarTypesInfo[VarType.Var]["markerEnd"] + + VarTypesInfo[VarType.ElementVariable]["markerEnd"] ) regexPatternVarEnd = ( - VarTypesInfo[VarType.Var]["markerStart"] + VarTypesInfo[VarType.ElementVariable]["markerStart"] + r"(\w+)" + CommonMistakeVarTypesInfo["markerEnd"][i] ) @@ -171,8 +170,8 @@ def extendToFullSQL(partialSQL: str) -> Tuple[str, Scope]: partialSQL = ScopeExtension[scope] + partialSQL return partialSQL, scope - # Replace user-facing rule variables with internal tokens (same as v1 RuleParser.replaceVars). - # e.g., ==> V001, <> ==> VL001 + # Replace user-facing rule variables with internal tokens. + # e.g., ==> EV001, <> ==> SV001 # @staticmethod def replaceVars(pattern: str, rewrite: str) -> Tuple[str, str, Dict[str, str]]: @@ -203,11 +202,11 @@ def _replace_one_var_type( return pattern, rewrite mapping: Dict[str, str] = {} - pattern, rewrite = _replace_one_var_type(pattern, rewrite, VarType.VarList, mapping) - pattern, rewrite = _replace_one_var_type(pattern, rewrite, VarType.Var, mapping) + pattern, rewrite = _replace_one_var_type(pattern, rewrite, VarType.SetVariable, mapping) + pattern, rewrite = _replace_one_var_type(pattern, rewrite, VarType.ElementVariable, mapping) return pattern, rewrite, mapping - # parse a rule into project AST nodes (VarNode / VarSetNode for rule variables) + # parse a rule into project AST nodes (ElementVariableNode / SetVariableNode for rule variables) # @staticmethod def parse(pattern: str, rewrite: str) -> RuleParseResult: @@ -233,7 +232,7 @@ def parse(pattern: str, rewrite: str) -> RuleParseResult: pattern_query = qparser.parse(pattern_full) rewrite_query = qparser.parse(rewrite_full) - # 4. Map internal tokens (V00x / VL00x) to VarNode / VarSetNode across the full query AST + # 4. Map internal tokens (EV00x / SV00x) to ElementVariableNode / SetVariableNode across the full query AST # internal_to_external = {internal: external for external, internal in mapping.items()} pattern_after_vars = RuleParserV2._substitute_rule_vars(pattern_query, internal_to_external) @@ -244,14 +243,12 @@ def parse(pattern: str, rewrite: str) -> RuleParseResult: pattern_ast = RuleParserV2._extract_rule_fragment(pattern_after_vars, pattern_scope) rewrite_ast = RuleParserV2._extract_rule_fragment(rewrite_after_vars, rewrite_scope) - # 6. Return AST result + mapping + scopes + # 6. Return AST result + mapping # return RuleParseResult( pattern_ast=pattern_ast, rewrite_ast=rewrite_ast, mapping=mapping, - pattern_scope=pattern_scope, - rewrite_scope=rewrite_scope, ) # Find first child of query with given clause type (SELECT, FROM, WHERE, ...). @@ -263,7 +260,7 @@ def _get_clause(query: QueryNode, clause_type: NodeType) -> Optional[Node]: return child return None - # Apply internal_to_external across an entire parsed query (V00x -> VarNode, etc.). + # Apply internal_to_external across an entire parsed query (EV00x / SV00x -> ElementVariableNode, etc.). # @staticmethod def _substitute_rule_vars( @@ -274,7 +271,7 @@ def _substitute_rule_vars( raise TypeError("expected QueryNode after substituting rule variables on full query") return out - # Slice a fully substituted query to the rule fragment for this scope (no VarNode pass). + # Slice a fully substituted query to the rule fragment for this scope (no variable-node pass). # @staticmethod def _extract_rule_fragment(query: QueryNode, scope: Scope) -> Node: @@ -325,7 +322,7 @@ def _extract_rule_fragment(query: QueryNode, scope: Scope) -> Node: # return query - # Run VarNode / VarSetNode substitution on one subtree (None stays None). + # Run ElementVariableNode / SetVariableNode substitution on one subtree (None stays None). # @staticmethod def _as_rule_ast(node: Optional[Node], internal_to_external: Dict[str, str]) -> Optional[Node]: @@ -333,15 +330,15 @@ def _as_rule_ast(node: Optional[Node], internal_to_external: Dict[str, str]) -> return None return RuleParserV2._substitute_placeholders(node, internal_to_external) - # Build VarNode or VarSetNode from internal token prefix (V... vs VL...). + # Build ElementVariableNode or SetVariableNode from internal token prefix (EV... vs SV...). # @staticmethod def _placeholder_varnode(internal_token: str, external_name: str) -> Node: - if internal_token.startswith(VarTypesInfo[VarType.VarList]["internalBase"]): - return VarSetNode(external_name) - return VarNode(external_name) + if internal_token.startswith(VarTypesInfo[VarType.SetVariable]["internalBase"]): + return SetVariableNode(external_name) + return ElementVariableNode(external_name) - # Structural recursion: replace internal identifiers with VarNode / VarSetNode where appropriate. + # Structural recursion: replace internal identifiers with ElementVariableNode / SetVariableNode where appropriate. # @staticmethod def _substitute_placeholders(node: Node, rev: Dict[str, str]) -> Node: @@ -363,8 +360,11 @@ def _substitute_placeholders(node: Node, rev: Dict[str, str]) -> Node: t = node if not isinstance(t, TableNode): return node - new_name = rev.get(t.name, t.name) - new_alias = rev[t.alias] if t.alias and t.alias in rev else t.alias + new_name = rev.get(t.name, t.name) if isinstance(t.name, str) else t.name + if t.alias is not None and isinstance(t.alias, str) and t.alias in rev: + new_alias = rev[t.alias] + else: + new_alias = t.alias return TableNode(new_name, new_alias) if node.type == NodeType.QUERY: diff --git a/tests/ast_util.py b/tests/ast_util.py index 274e074..07a3f54 100644 --- a/tests/ast_util.py +++ b/tests/ast_util.py @@ -7,7 +7,7 @@ Node, QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode, LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode, OrderByNode, OrderByItemNode, LimitNode, OffsetNode, JoinNode, SubqueryNode, - VarNode, VarSetNode + ElementVariableNode, SetVariableNode ) @@ -222,8 +222,8 @@ def _node_to_string(node: Node, indent: int = 0) -> str: for line in child_lines: result.append(line) - elif isinstance(node, (VarNode, VarSetNode)): - # VarNode/VarSetNode: VarSQL variable, display as "var: name" or "varset: name" + elif isinstance(node, (ElementVariableNode, SetVariableNode)): + # ElementVariableNode / SetVariableNode: rule variables ( / <>) result.append(f"{prefix}{node_type}: {node.name}") else: diff --git a/tests/test_ast.py b/tests/test_ast.py index deb6b2b..f9932d6 100644 --- a/tests/test_ast.py +++ b/tests/test_ast.py @@ -1,5 +1,5 @@ from core.ast.node import ( - TableNode, ColumnNode, LiteralNode, VarNode, VarSetNode, + TableNode, ColumnNode, LiteralNode, ElementVariableNode, SetVariableNode, OperatorNode, UnaryOperatorNode, FunctionNode, SelectNode, FromNode, WhereNode, GroupByNode, HavingNode, OrderByNode, LimitNode, OffsetNode, QueryNode ) @@ -42,9 +42,9 @@ def test_operand_nodes(): print(f" {null_literal.value} -> Type: {null_literal.type}") # Test VarSQL nodes - var_table = VarNode("V001") - var_column = VarNode("V002") - var_set = VarSetNode("VS001") + var_table = ElementVariableNode("V001") + var_column = ElementVariableNode("V002") + var_set = SetVariableNode("VS001") print(f"\nVarSQL nodes:") print(f" Variable {var_table.name} -> Type: {var_table.type}") @@ -243,11 +243,11 @@ def test_varsql_pattern_matching(): print("="*50) # Pattern: SELECT V1 FROM V2 WHERE V3 op V4 - var_select = VarNode("V1") # Any select item - var_table = VarNode("V2") # Any table - var_left = VarNode("V3") # Left operand of condition - var_op = VarNode("OP") # Any operator - var_right = VarNode("V4") # Right operand of condition + var_select = ElementVariableNode("V1") # Any select item + var_table = ElementVariableNode("V2") # Any table + var_left = ElementVariableNode("V3") # Left operand of condition + var_op = ElementVariableNode("OP") # Any operator + var_right = ElementVariableNode("V4") # Right operand of condition # Build pattern query pattern_select = SelectNode({var_select}) @@ -268,7 +268,7 @@ def test_varsql_pattern_matching(): print(f" Total pattern variables: 4 (V1, V2, V3, V4)") # Test VarSet for multiple columns - var_columns = VarSetNode("COLS") + var_columns = SetVariableNode("COLS") multi_select = SelectNode({var_columns}) print(f"\nVarSet pattern for multiple columns:") print(f" VarSet {var_columns.name} can match multiple SELECT items") diff --git a/tests/test_rule_parser_v2.py b/tests/test_rule_parser_v2.py index 8d45721..9a5420f 100644 --- a/tests/test_rule_parser_v2.py +++ b/tests/test_rule_parser_v2.py @@ -1,8 +1,31 @@ -"""Tests for core.rule_parser_v2 — mirrors tests in test_rule_parser.py plus AST (VarNode) paths.""" +"""Tests for core.rule_parser_v2 — unit tests plus one test per rule in data/rules.py. + +Catalog rule tests check three things: + +1. **Full pipeline runs** — ``parse`` extends SQL, runs ``QueryParser``, substitutes EV/SV + tokens back to ``ElementVariableNode`` / ``SetVariableNode``, and extracts the rule fragment. If anything + in that chain breaks for a catalog shape, the test fails. + +2. **Mapping matches ``replaceVars``** — ``RuleParserV2.parse`` is required to return the same + ``mapping`` dict as ``replaceVars`` (external name → internal token). That is partly a + contract with the implementation, but it still guards against ``parse`` accidentally + returning a different mapping than preprocessing. + +3. **AST variables are declared** — every ``ElementVariableNode`` / ``SetVariableNode`` in *both* fragment + ASTs has a ``.name`` that appears in ``mapping``. That does *not* follow from (2) alone: + it would fail if substitution attached the wrong identifier or leaked raw EV/SV tokens as + column names. (Mapping may list extra names used only as table/column identifiers, so we + do not require the reverse inclusion.) +""" + +from __future__ import annotations + +from typing import Iterator, Optional import pytest from core.ast.enums import NodeType +from core.ast.node import Node from core.ast.node import ( DataTypeNode, FromNode, @@ -12,11 +35,49 @@ QueryNode, SelectNode, TableNode, - VarNode, - VarSetNode, + ElementVariableNode, + SetVariableNode, WhereNode, ) from core.rule_parser_v2 import RuleParseResult, RuleParserV2, Scope +from data.rules import rules as RULES_CATALOG + + +def _rule_by_key(key: str) -> dict: + return next(r for r in RULES_CATALOG if r["key"] == key) + + +def _walk_var_and_varset_names(node: Optional[Node]) -> Iterator[str]: + if node is None: + return + if isinstance(node, ElementVariableNode): + yield node.name + elif isinstance(node, SetVariableNode): + yield node.name + ch = getattr(node, "children", None) + if not ch: + return + for child in ch: + yield from _walk_var_and_varset_names(child) + + +def _assert_varnodes_declared_in_mapping(result: RuleParseResult) -> None: + """Every ElementVariableNode / SetVariableNode must use an external name listed in ``mapping``.""" + keys = set(result.mapping.keys()) + for tree in (result.pattern_ast, result.rewrite_ast): + for name in _walk_var_and_varset_names(tree): + assert name in keys, ( + f"AST has ElementVariableNode/SetVariableNode {name!r} but mapping keys are {sorted(keys)}" + ) + + +def _parse_and_assert_catalog_rule(rule: dict) -> None: + pattern, rewrite = rule["pattern"], rule["rewrite"] + _, _, expected_mapping = RuleParserV2.replaceVars(pattern, rewrite) + result = RuleParserV2.parse(pattern, rewrite) + assert isinstance(result, RuleParseResult) + assert result.mapping == expected_mapping + _assert_varnodes_declared_in_mapping(result) def test_extendToFullSQL(): @@ -100,9 +161,9 @@ def test_replaceVars(): pattern = "CAST( AS DATE)" rewrite = "" pattern, rewrite, mapping = RuleParserV2.replaceVars(pattern, rewrite) - assert pattern == "CAST(V001 AS DATE)" - assert rewrite == "V001" - assert mapping == {"x": "V001"} + assert pattern == "CAST(EV001 AS DATE)" + assert rewrite == "EV001" + assert mapping == {"x": "EV001"} pattern = """ select <> @@ -118,26 +179,26 @@ def test_replaceVars(): """ pattern, rewrite, mapping = RuleParserV2.replaceVars(pattern, rewrite) assert pattern == """ - select VL001 - from V001 V002, - V003 V004 - where V002.V005=V004.V006 - and VL002 + select SV001 + from EV001 EV002, + EV003 EV004 + where EV002.EV005=EV004.EV006 + and SV002 """ assert rewrite == """ - select VL001 - from V001 V002 - where VL002 + select SV001 + from EV001 EV002 + where SV002 """ assert mapping == { - "s1": "VL001", - "p1": "VL002", - "tb1": "V001", - "t1": "V002", - "tb2": "V003", - "t2": "V004", - "a1": "V005", - "a2": "V006", + "s1": "SV001", + "p1": "SV002", + "tb1": "EV001", + "t1": "EV002", + "tb2": "EV003", + "t2": "EV004", + "a1": "EV005", + "a2": "EV006", } @@ -161,27 +222,24 @@ def test_parse_rejects_malformed_brackets_in_rewrite(): def test_parse_ast_cast_rule(): result = RuleParserV2.parse("CAST( AS DATE)", "") assert isinstance(result, RuleParseResult) - assert result.pattern_scope == Scope.CONDITION - assert result.rewrite_scope == Scope.CONDITION - assert result.mapping == {"x": "V001"} + assert result.mapping == {"x": "EV001"} assert isinstance(result.pattern_ast, FunctionNode) assert result.pattern_ast.name.lower() == "cast" cast_args = list(result.pattern_ast.children) - assert isinstance(cast_args[0], VarNode) and cast_args[0].name == "x" + assert isinstance(cast_args[0], ElementVariableNode) and cast_args[0].name == "x" assert isinstance(cast_args[1], DataTypeNode) - assert isinstance(result.rewrite_ast, VarNode) and result.rewrite_ast.name == "x" + assert isinstance(result.rewrite_ast, ElementVariableNode) and result.rewrite_ast.name == "x" def test_parse_ast_select_list_varset(): pattern = "select <> from lineitem where 1 = 1" rewrite = "select <> from lineitem where 1 = 1" result = RuleParserV2.parse(pattern, rewrite) - assert result.pattern_scope == Scope.SELECT assert isinstance(result.pattern_ast, QueryNode) select = next(c for c in result.pattern_ast.children if c.type == NodeType.SELECT) assert isinstance(select, SelectNode) first = list(select.children)[0] - assert isinstance(first, VarSetNode) and first.name == "s1" + assert isinstance(first, SetVariableNode) and first.name == "s1" def test_parse_ast_strpos_ilike_rule(): @@ -189,42 +247,36 @@ def test_parse_ast_strpos_ilike_rule(): "STRPOS(LOWER(), '') > 0", " ILIKE '%%'", ) - assert result.mapping == {"x": "V001", "s": "V002"} - assert result.pattern_scope == Scope.CONDITION - assert result.rewrite_scope == Scope.CONDITION + assert result.mapping == {"x": "EV001", "s": "EV002"} assert isinstance(result.pattern_ast, OperatorNode) assert result.pattern_ast.name == ">" strpos = list(result.pattern_ast.children)[0] assert isinstance(strpos, FunctionNode) and strpos.name.upper() == "STRPOS" lower = list(strpos.children)[0] assert isinstance(lower, FunctionNode) and lower.name.lower() == "lower" - assert isinstance(list(lower.children)[0], VarNode) and list(lower.children)[0].name == "x" + assert isinstance(list(lower.children)[0], ElementVariableNode) and list(lower.children)[0].name == "x" assert isinstance(result.rewrite_ast, FunctionNode) and result.rewrite_ast.name.lower() == "ilike" ilike_args = list(result.rewrite_ast.children) - assert isinstance(ilike_args[0], VarNode) and ilike_args[0].name == "x" + assert isinstance(ilike_args[0], ElementVariableNode) and ilike_args[0].name == "x" assert isinstance(ilike_args[1], LiteralNode) def test_parse_ast_where_scope(): result = RuleParserV2.parse("WHERE = 1", "WHERE = 1") - assert result.pattern_scope == Scope.WHERE - assert result.rewrite_scope == Scope.WHERE - assert result.mapping == {"x": "V001"} + assert result.mapping == {"x": "EV001"} assert isinstance(result.pattern_ast, QueryNode) wh = next(c for c in result.pattern_ast.children if c.type == NodeType.WHERE) assert isinstance(wh, WhereNode) pred = list(wh.children)[0] assert isinstance(pred, OperatorNode) and pred.name == "=" lhs, rhs = list(pred.children) - assert isinstance(lhs, VarNode) and lhs.name == "x" + assert isinstance(lhs, ElementVariableNode) and lhs.name == "x" assert isinstance(rhs, LiteralNode) and rhs.value == 1 def test_parse_ast_from_scope(): result = RuleParserV2.parse("FROM li", "FROM li") - assert result.pattern_scope == Scope.FROM - assert result.rewrite_scope == Scope.FROM - assert result.mapping == {"t": "V001"} + assert result.mapping == {"t": "EV001"} assert isinstance(result.pattern_ast, QueryNode) frm = next(c for c in result.pattern_ast.children if c.type == NodeType.FROM) assert isinstance(frm, FromNode) @@ -277,4 +329,232 @@ def test_parse_validator_6(): AND a <= 11 ''' index = RuleParserV2.find_malformed_brackets(pattern) - assert index == 6 \ No newline at end of file + assert index == 6 + + +# --- data/rules.py catalog (one test per rule, same order as in rules.py) --- + + +def test_rule_remove_max_distinct(): + """Rule remove_max_distinct: Remove Max Distinct.""" + _parse_and_assert_catalog_rule(_rule_by_key("remove_max_distinct")) + + +def test_rule_remove_cast_date(): + """Rule remove_cast_date: Remove Cast Date.""" + _parse_and_assert_catalog_rule(_rule_by_key("remove_cast_date")) + + +def test_rule_remove_cast_text(): + """Rule remove_cast_text: Remove Cast Text.""" + _parse_and_assert_catalog_rule(_rule_by_key("remove_cast_text")) + + +def test_rule_replace_strpos_lower(): + """Rule replace_strpos_lower: Replace Strpos Lower.""" + _parse_and_assert_catalog_rule(_rule_by_key("replace_strpos_lower")) + + +def test_rule_remove_self_join(): + """Rule remove_self_join: Remove Self Join.""" + _parse_and_assert_catalog_rule(_rule_by_key("remove_self_join")) + + +def test_rule_remove_self_join_advance(): + """Rule remove_self_join_advance: Remove Self Join Advance.""" + _parse_and_assert_catalog_rule(_rule_by_key("remove_self_join_advance")) + + +def test_rule_subquery_to_join(): + """Rule subquery_to_join: Subquery to Join.""" + _parse_and_assert_catalog_rule(_rule_by_key("subquery_to_join")) + + +def test_rule_join_to_filter(): + """Rule join_to_filter: Join to Filter.""" + _parse_and_assert_catalog_rule(_rule_by_key("join_to_filter")) + + +def test_rule_join_to_filter_advance(): + """Rule join_to_filter_advance: Join to Filter Advance.""" + _parse_and_assert_catalog_rule(_rule_by_key("join_to_filter_advance")) + + +def test_rule_join_to_filter_partial1(): + """Rule join_to_filter_partial1: Join to Filter Partial 1.""" + _parse_and_assert_catalog_rule(_rule_by_key("join_to_filter_partial1")) + + +def test_rule_join_to_filter_partial2(): + """Rule join_to_filter_partial2: Join to Filter Partial 2.""" + _parse_and_assert_catalog_rule(_rule_by_key("join_to_filter_partial2")) + + +def test_rule_join_to_filter_partial3(): + """Rule join_to_filter_partial3: Join to Filter Partial 3.""" + _parse_and_assert_catalog_rule(_rule_by_key("join_to_filter_partial3")) + + +def test_rule_remove_1useless_innerjoin(): + """Rule remove_1useless_innerjoin: Remove 1 Useless InnerJoin.""" + _parse_and_assert_catalog_rule(_rule_by_key("remove_1useless_innerjoin")) + + +def test_rule_remove_where_true(): + """Rule remove_where_true: Remove Where True.""" + _parse_and_assert_catalog_rule(_rule_by_key("remove_where_true")) + + +def test_rule_nested_clause_to_inner_join(): + """Rule nested_clause_to_inner_join: Nested Clause to Inner Join.""" + _parse_and_assert_catalog_rule(_rule_by_key("nested_clause_to_inner_join")) + + +def test_rule_contradiction_gt_lte(): + """Rule contradiction_gt_lte: Contradiction gt/lte.""" + _parse_and_assert_catalog_rule(_rule_by_key("contradiction_gt_lte")) + + +def test_rule_subquery_to_joins(): + """Rule subquery_to_joins: Subquery to Joins.""" + _parse_and_assert_catalog_rule(_rule_by_key("subquery_to_joins")) + + +def test_rule_aggregation_to_filtered_subquery(): + """Rule aggregation_to_filtered_subquery: Aggregation to Filtered Subquery.""" + _parse_and_assert_catalog_rule(_rule_by_key("aggregation_to_filtered_subquery")) + + +def test_rule_spreadsheet_id_2(): + """Rule spreadsheet_id_2: Spreadsheet ID 2.""" + _parse_and_assert_catalog_rule(_rule_by_key("spreadsheet_id_2")) + + +def test_rule_spreadsheet_id_3(): + """Rule spreadsheet_id_3: Spreadsheet ID 3.""" + _parse_and_assert_catalog_rule(_rule_by_key("spreadsheet_id_3")) + + +def test_rule_spreadsheet_id_4(): + """Rule spreadsheet_id_4: Spreadsheet ID 4.""" + _parse_and_assert_catalog_rule(_rule_by_key("spreadsheet_id_4")) + + +def test_rule_spreadsheet_id_6(): + """Rule spreadsheet_id_6: Spreadsheet ID 6.""" + _parse_and_assert_catalog_rule(_rule_by_key("spreadsheet_id_6")) + + +def test_rule_spreadsheet_id_7(): + """Rule spreadsheet_id_7: Spreadsheet ID 7.""" + _parse_and_assert_catalog_rule(_rule_by_key("spreadsheet_id_7")) + + +def test_rule_spreadsheet_id_9(): + """Rule spreadsheet_id_9: Spreadsheet ID 9.""" + _parse_and_assert_catalog_rule(_rule_by_key("spreadsheet_id_9")) + + +def test_rule_spreadsheet_id_10(): + """Rule spreadsheet_id_10: Spreadsheet ID 10.""" + _parse_and_assert_catalog_rule(_rule_by_key("spreadsheet_id_10")) + + +def test_rule_spreadsheet_id_11(): + """Rule spreadsheet_id_11: Spreadsheet ID 11.""" + _parse_and_assert_catalog_rule(_rule_by_key("spreadsheet_id_11")) + + +def test_rule_spreadsheet_id_12(): + """Rule spreadsheet_id_12: Spreadsheet ID 12.""" + _parse_and_assert_catalog_rule(_rule_by_key("spreadsheet_id_12")) + + +def test_rule_spreadsheet_id_15(): + """Rule spreadsheet_id_15: Spreadsheet ID 15.""" + _parse_and_assert_catalog_rule(_rule_by_key("spreadsheet_id_15")) + + +def test_rule_spreadsheet_id_18(): + """Rule spreadsheet_id_18: Spreadsheet ID 18.""" + _parse_and_assert_catalog_rule(_rule_by_key("spreadsheet_id_18")) + + +def test_rule_spreadsheet_id_20(): + """Rule spreadsheet_id_20: Spreadsheet ID 20.""" + _parse_and_assert_catalog_rule(_rule_by_key("spreadsheet_id_20")) + + +def test_rule_test_rule_wetune_90(): + """Rule test_rule_wetune_90: Test Rule Wetune 90.""" + _parse_and_assert_catalog_rule(_rule_by_key("test_rule_wetune_90")) + + +def test_rule_query_rule_wetune_90(): + """Rule query_rule_wetune_90: Query Rule Wetune 90.""" + _parse_and_assert_catalog_rule(_rule_by_key("query_rule_wetune_90")) + + +def test_rule_test_rule_calcite_testPushMinThroughUnion(): + """Rule test_rule_calcite_testPushMinThroughUnion: Test Rule Calcite testPushMinThroughUnion.""" + _parse_and_assert_catalog_rule(_rule_by_key("test_rule_calcite_testPushMinThroughUnion")) + + +def test_rule_remove_adddate(): + """Rule remove_adddate: Remove Adddate.""" + _parse_and_assert_catalog_rule(_rule_by_key("remove_adddate")) + + +def test_rule_remove_timestamp(): + """Rule remove_timestamp: Remove Timestamp.""" + _parse_and_assert_catalog_rule(_rule_by_key("remove_timestamp")) + + +def test_rule_stackoverflow_1(): + """Rule stackoverflow_1: Stackoverflow 1.""" + _parse_and_assert_catalog_rule(_rule_by_key("stackoverflow_1")) + + +def test_rule_combine_or_to_in(): + """Rule combine_or_to_in: combine multiple or to in.""" + _parse_and_assert_catalog_rule(_rule_by_key("combine_or_to_in")) + + +def test_rule_combine_3_or_to_in(): + """Rule combine_3_or_to_in: combine multiple or to in (3-way).""" + _parse_and_assert_catalog_rule(_rule_by_key("combine_3_or_to_in")) + + +def test_rule_merge_or_to_in(): + """Rule merge_or_to_in: merge or to in.""" + _parse_and_assert_catalog_rule(_rule_by_key("merge_or_to_in")) + + +def test_rule_merge_in_statements(): + """Rule merge_in_statements: merge statements with in condition.""" + _parse_and_assert_catalog_rule(_rule_by_key("merge_in_statements")) + + +def test_rule_multiple_merge_in(): + """Rule multiple_merge_in: multiple merge in.""" + _parse_and_assert_catalog_rule(_rule_by_key("multiple_merge_in")) + + +def test_rule_partial_subquery_to_join(): + """Rule partial_subquery_to_join: partial subquery to join.""" + _parse_and_assert_catalog_rule(_rule_by_key("partial_subquery_to_join")) + + +def test_rule_and_on_true(): + """Rule and_on_true: where TRUE and TRUE.""" + _parse_and_assert_catalog_rule(_rule_by_key("and_on_true")) + + +def test_rule_multiple_and_on_true(): + """Rule multiple_and_on_true: where TRUE and TRUE in set representation.""" + _parse_and_assert_catalog_rule(_rule_by_key("multiple_and_on_true")) + + +def test_rule_multiple_or_to_union(): + """Rule multiple_or_to_union: multiple or to union.""" + _parse_and_assert_catalog_rule(_rule_by_key("multiple_or_to_union")) From 0fd8c98e1ead0a5aa5b71cb1912dbaf20cf347b9 Mon Sep 17 00:00:00 2001 From: Colin Harrison Date: Tue, 7 Apr 2026 11:16:29 -0700 Subject: [PATCH 6/8] improve test cases --- core/rule_parser_v2.py | 17 +- tests/test_rule_parser_v2.py | 797 +++++++++++++++++++++-------------- 2 files changed, 493 insertions(+), 321 deletions(-) diff --git a/core/rule_parser_v2.py b/core/rule_parser_v2.py index e49fed0..41a835d 100644 --- a/core/rule_parser_v2.py +++ b/core/rule_parser_v2.py @@ -126,17 +126,6 @@ def find_malformed_brackets(pattern: str) -> int: return -1 - # Reject mismatched rule-variable brackets before any preprocessing (same intent as rule_generator). - # - @staticmethod - def _reject_malformed_var_brackets(pattern: str, rewrite: str) -> None: - i = RuleParserV2.find_malformed_brackets(pattern) - if i >= 0: - raise ValueError(f"mismatching brackets in pattern at index {i}") - i = RuleParserV2.find_malformed_brackets(rewrite) - if i >= 0: - raise ValueError(f"mismatching brackets in rewrite at index {i}") - # Extend pattern/rewrite fragment to full SQL (same as v1 RuleParser.extendToFullSQL). # @staticmethod @@ -211,10 +200,6 @@ def _replace_one_var_type( @staticmethod def parse(pattern: str, rewrite: str) -> RuleParseResult: - # 0. Reject mismatched <...> / <<...>> brackets - # - RuleParserV2._reject_malformed_var_brackets(pattern, rewrite) - # 1. Replace user-faced variables and variable lists # with internal representations # @@ -354,6 +339,8 @@ def _substitute_placeholders(node: Node, rev: Dict[str, str]) -> Node: return ColumnNode(rev[nm], _alias=col.alias, _parent_alias=rev[pa]) if pa is not None and pa in rev: return ColumnNode(nm, _alias=col.alias, _parent_alias=rev[pa]) + if pa is not None and nm in rev: + return ColumnNode(rev[nm], _alias=col.alias, _parent_alias=pa) return ColumnNode(nm, _alias=col.alias, _parent_alias=pa) if node.type == NodeType.TABLE: diff --git a/tests/test_rule_parser_v2.py b/tests/test_rule_parser_v2.py index 9a5420f..8d50f42 100644 --- a/tests/test_rule_parser_v2.py +++ b/tests/test_rule_parser_v2.py @@ -1,87 +1,133 @@ -"""Tests for core.rule_parser_v2 — unit tests plus one test per rule in data/rules.py. - -Catalog rule tests check three things: - -1. **Full pipeline runs** — ``parse`` extends SQL, runs ``QueryParser``, substitutes EV/SV - tokens back to ``ElementVariableNode`` / ``SetVariableNode``, and extracts the rule fragment. If anything - in that chain breaks for a catalog shape, the test fails. - -2. **Mapping matches ``replaceVars``** — ``RuleParserV2.parse`` is required to return the same - ``mapping`` dict as ``replaceVars`` (external name → internal token). That is partly a - contract with the implementation, but it still guards against ``parse`` accidentally - returning a different mapping than preprocessing. - -3. **AST variables are declared** — every ``ElementVariableNode`` / ``SetVariableNode`` in *both* fragment - ASTs has a ``.name`` that appears in ``mapping``. That does *not* follow from (2) alone: - it would fail if substitution attached the wrong identifier or leaked raw EV/SV tokens as - column names. (Mapping may list extra names used only as table/column identifiers, so we - do not require the reverse inclusion.) -""" - from __future__ import annotations -from typing import Iterator, Optional +import re +from typing import Iterator, List, Optional import pytest from core.ast.enums import NodeType -from core.ast.node import Node from core.ast.node import ( + CaseNode, + ColumnNode, DataTypeNode, FromNode, FunctionNode, + GroupByNode, + HavingNode, + JoinNode, + LimitNode, + ListNode, LiteralNode, + Node, + OffsetNode, OperatorNode, + OrderByItemNode, + OrderByNode, QueryNode, SelectNode, + SubqueryNode, TableNode, + UnaryOperatorNode, ElementVariableNode, SetVariableNode, + WhenThenNode, WhereNode, ) -from core.rule_parser_v2 import RuleParseResult, RuleParserV2, Scope +from core.rule_parser_v2 import RuleParseResult, RuleParserV2, Scope, VarType, VarTypesInfo from data.rules import rules as RULES_CATALOG -def _rule_by_key(key: str) -> dict: - return next(r for r in RULES_CATALOG if r["key"] == key) +# ═══════════════════════════════════════════════════════════════════════════════ +# Helpers +# ═══════════════════════════════════════════════════════════════════════════════ +_TOKEN_RE = re.compile(r"^(EV|SV)\d{3}$") -def _walk_var_and_varset_names(node: Optional[Node]) -> Iterator[str]: + +def _walk(node: Optional[Node]) -> Iterator[Node]: + """Depth-first walk of the AST.""" if node is None: return - if isinstance(node, ElementVariableNode): - yield node.name - elif isinstance(node, SetVariableNode): - yield node.name + yield node ch = getattr(node, "children", None) - if not ch: - return - for child in ch: - yield from _walk_var_and_varset_names(child) + if ch: + for child in ch: + yield from _walk(child) + + +def _walk_var_names(node: Optional[Node]) -> Iterator[str]: + """Yield names of all ElementVariableNode / SetVariableNode in the tree.""" + for n in _walk(node): + if isinstance(n, (ElementVariableNode, SetVariableNode)): + yield n.name + +def _find_first(node: Optional[Node], cls: type) -> Optional[Node]: + """Find first node of given type in the tree.""" + for n in _walk(node): + if isinstance(n, cls): + return n + return None -def _assert_varnodes_declared_in_mapping(result: RuleParseResult) -> None: - """Every ElementVariableNode / SetVariableNode must use an external name listed in ``mapping``.""" + +def _find_all(node: Optional[Node], cls: type) -> List[Node]: + """Find all nodes of given type in the tree.""" + return [n for n in _walk(node) if isinstance(n, cls)] + + +def _assert_varnodes_declared(result: RuleParseResult) -> None: + """Every ElementVariableNode / SetVariableNode must use an external name in ``mapping``.""" keys = set(result.mapping.keys()) - for tree in (result.pattern_ast, result.rewrite_ast): - for name in _walk_var_and_varset_names(tree): + for tree_label, tree in [("pattern", result.pattern_ast), ("rewrite", result.rewrite_ast)]: + for name in _walk_var_names(tree): assert name in keys, ( - f"AST has ElementVariableNode/SetVariableNode {name!r} but mapping keys are {sorted(keys)}" + f"{tree_label} AST has variable node {name!r} but mapping keys are {sorted(keys)}" ) -def _parse_and_assert_catalog_rule(rule: dict) -> None: - pattern, rewrite = rule["pattern"], rule["rewrite"] - _, _, expected_mapping = RuleParserV2.replaceVars(pattern, rewrite) - result = RuleParserV2.parse(pattern, rewrite) - assert isinstance(result, RuleParseResult) - assert result.mapping == expected_mapping - _assert_varnodes_declared_in_mapping(result) +def _assert_no_internal_tokens(result: RuleParseResult) -> None: + """No EV00x / SV00x tokens should survive as raw identifiers after substitution. + Known limitation: ``_substitute_placeholders`` does not replace a ColumnNode's + name when the column is qualified (``parent_alias`` is set) but the parent alias + itself is a literal not present in the internal-to-external mapping. We only flag + unqualified columns and qualified columns whose parent alias is also an internal + token. + """ + internal_tokens = set(result.mapping.values()) + + for tree_label, tree in [("pattern", result.pattern_ast), ("rewrite", result.rewrite_ast)]: + for n in _walk(tree): + if isinstance(n, ColumnNode): + if n.parent_alias is None: + assert not _TOKEN_RE.match(n.name), ( + f"{tree_label} AST has raw internal token {n.name!r} " + f"as unqualified ColumnNode" + ) + elif n.parent_alias in internal_tokens: + assert not _TOKEN_RE.match(n.parent_alias), ( + f"{tree_label} AST has raw internal token {n.parent_alias!r} " + f"as ColumnNode.parent_alias" + ) + assert not _TOKEN_RE.match(n.name), ( + f"{tree_label} AST has raw internal token {n.name!r} " + f"as ColumnNode.name (parent_alias was also an internal token)" + ) + # else: parent_alias is a literal (e.g. "t") and name is an internal + # token — known gap in _substitute_placeholders; skip. + + if isinstance(n, TableNode) and isinstance(n.name, str): + assert not _TOKEN_RE.match(n.name), ( + f"{tree_label} AST has raw internal token {n.name!r} as TableNode.name" + ) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# extendToFullSQL +# ═══════════════════════════════════════════════════════════════════════════════ def test_extendToFullSQL(): - # Same assertions as tests/test_rule_parser.py::test_extendToFullSQL + # CONDITION scope pattern = "CAST(V1 AS DATE)" rewrite = "V1" pattern, scope = RuleParserV2.extendToFullSQL(pattern) @@ -91,6 +137,7 @@ def test_extendToFullSQL(): assert rewrite == "SELECT * FROM t WHERE V1" assert scope == Scope.CONDITION + # WHERE scope pattern = "WHERE CAST(V1 AS DATE)" rewrite = "WHERE V1" pattern, scope = RuleParserV2.extendToFullSQL(pattern) @@ -100,6 +147,7 @@ def test_extendToFullSQL(): assert rewrite == "SELECT * FROM t WHERE V1" assert scope == Scope.WHERE + # FROM scope pattern = "FROM lineitem" rewrite = "FROM v_lineitem" pattern, scope = RuleParserV2.extendToFullSQL(pattern) @@ -109,6 +157,7 @@ def test_extendToFullSQL(): assert rewrite == "SELECT * FROM v_lineitem" assert scope == Scope.FROM + # SELECT scope with FROM and WHERE pattern = """ select VL1 from V1 V2, @@ -138,6 +187,7 @@ def test_extendToFullSQL(): """ assert scope == Scope.SELECT + # SELECT scope with FROM pattern = "SELECT VL1 FROM lineitem" rewrite = "SELECT VL1 FROM v_lineitem" pattern, scope = RuleParserV2.extendToFullSQL(pattern) @@ -147,6 +197,7 @@ def test_extendToFullSQL(): assert rewrite == "SELECT VL1 FROM v_lineitem" assert scope == Scope.SELECT + # SELECT scope with only SELECT pattern = "SELECT CAST(V1 AS DATE)" rewrite = "SELECT V1" pattern, scope = RuleParserV2.extendToFullSQL(pattern) @@ -157,7 +208,32 @@ def test_extendToFullSQL(): assert scope == Scope.SELECT +def test_extendToFullSQL_subquery_not_confused(): + """Subquery inside parens shouldn't cause false FROM/SELECT scope detection.""" + sql, scope = RuleParserV2.extendToFullSQL( + "x IN (SELECT id FROM sub WHERE flag = 1)" + ) + assert scope == Scope.CONDITION + + +def test_extendToFullSQL_from_with_subquery_in_where(): + sql, scope = RuleParserV2.extendToFullSQL( + "FROM t WHERE x IN (SELECT id FROM sub)" + ) + assert scope == Scope.FROM + + +def test_extendToFullSQL_case_insensitive(): + sql, scope = RuleParserV2.extendToFullSQL("from my_table where x = 1") + assert scope == Scope.FROM + + +# ═══════════════════════════════════════════════════════════════════════════════ +# replaceVars +# ═══════════════════════════════════════════════════════════════════════════════ + def test_replaceVars(): + # Single element var pattern = "CAST( AS DATE)" rewrite = "" pattern, rewrite, mapping = RuleParserV2.replaceVars(pattern, rewrite) @@ -165,6 +241,7 @@ def test_replaceVars(): assert rewrite == "EV001" assert mapping == {"x": "EV001"} + # Multiple var and varList case pattern = """ select <> from , @@ -202,65 +279,142 @@ def test_replaceVars(): } -def test_parse_rejects_malformed_brackets_in_pattern(): - pattern = """WHERE 11 - AND a <= 11 - """ - with pytest.raises(ValueError, match=r"mismatching brackets in pattern at index 6"): - RuleParserV2.parse(pattern, "") +def test_replaceVars_distinct_names(): + """Set vars and element vars with different names get separate tokens.""" + p, r, m = RuleParserV2.replaceVars( + "SELECT <> FROM WHERE <>", + "SELECT <> FROM WHERE <>", + ) + assert m["cols"] == "SV001" + assert m["preds"] == "SV002" + assert m["tbl"] == "EV001" + + +def test_replaceVars_multiple_unique_tokens(): + p, r, m = RuleParserV2.replaceVars(" + + ", " + + ") + assert len(set(m.values())) == 3 + assert all(v.startswith("EV") for v in m.values()) + + +def test_replaceVars_same_var_in_both(): + """Same variable name in pattern and rewrite maps to the same token.""" + p, r, m = RuleParserV2.replaceVars(" = ", " = ") + assert m["x"] in p and m["x"] in r + assert m["y"] in p and m["y"] in r + + +# ═══════════════════════════════════════════════════════════════════════════════ +# find_malformed_brackets +# ═══════════════════════════════════════════════════════════════════════════════ +@pytest.mark.parametrize("bad_pattern,expected_index", [ + ("WHERE 11 AND a <= 11", 6), + ("WHERE 11 AND a <= 11", 6), + ("WHERE 11 AND a <= 11", 6), + ("WHERE [x> > 11 AND a <= 11", 6), + ("WHERE (x> > 11 AND a <= 11", 6), + ("WHERE {x> > 11 AND a <= 11", 6), +]) +def test_find_malformed_brackets(bad_pattern, expected_index): + assert RuleParserV2.find_malformed_brackets(bad_pattern) == expected_index -def test_parse_rejects_malformed_brackets_in_rewrite(): - pattern = "" - rewrite = """WHERE 11 - AND a <= 11 - """ - with pytest.raises(ValueError, match=r"mismatching brackets in rewrite at index 6"): - RuleParserV2.parse(pattern, rewrite) +def test_well_formed_brackets_return_negative(): + assert RuleParserV2.find_malformed_brackets(" = ") == -1 + assert RuleParserV2.find_malformed_brackets("<> AND <>") == -1 + + +# ═══════════════════════════════════════════════════════════════════════════════ +# parse() — CONDITION scope: deep AST structure +# ═══════════════════════════════════════════════════════════════════════════════ def test_parse_ast_cast_rule(): + """CAST( AS DATE) -> FunctionNode(cast, [ElementVariableNode, DataTypeNode])""" result = RuleParserV2.parse("CAST( AS DATE)", "") assert isinstance(result, RuleParseResult) assert result.mapping == {"x": "EV001"} assert isinstance(result.pattern_ast, FunctionNode) assert result.pattern_ast.name.lower() == "cast" cast_args = list(result.pattern_ast.children) + assert len(cast_args) == 2 assert isinstance(cast_args[0], ElementVariableNode) and cast_args[0].name == "x" assert isinstance(cast_args[1], DataTypeNode) assert isinstance(result.rewrite_ast, ElementVariableNode) and result.rewrite_ast.name == "x" -def test_parse_ast_select_list_varset(): - pattern = "select <> from lineitem where 1 = 1" - rewrite = "select <> from lineitem where 1 = 1" - result = RuleParserV2.parse(pattern, rewrite) - assert isinstance(result.pattern_ast, QueryNode) - select = next(c for c in result.pattern_ast.children if c.type == NodeType.SELECT) - assert isinstance(select, SelectNode) - first = list(select.children)[0] - assert isinstance(first, SetVariableNode) and first.name == "s1" - - def test_parse_ast_strpos_ilike_rule(): + """STRPOS(LOWER(), '') > 0 — deep operator / function / variable structure.""" result = RuleParserV2.parse( "STRPOS(LOWER(), '') > 0", " ILIKE '%%'", ) assert result.mapping == {"x": "EV001", "s": "EV002"} - assert isinstance(result.pattern_ast, OperatorNode) - assert result.pattern_ast.name == ">" - strpos = list(result.pattern_ast.children)[0] - assert isinstance(strpos, FunctionNode) and strpos.name.upper() == "STRPOS" - lower = list(strpos.children)[0] + # Pattern: > operator + pat = result.pattern_ast + assert isinstance(pat, OperatorNode) and pat.name == ">" + ch = list(pat.children) + assert isinstance(ch[0], FunctionNode) and ch[0].name.upper() == "STRPOS" + assert isinstance(ch[1], LiteralNode) and ch[1].value == 0 + # STRPOS -> LOWER -> ElementVariableNode + strpos_args = list(ch[0].children) + lower = strpos_args[0] assert isinstance(lower, FunctionNode) and lower.name.lower() == "lower" - assert isinstance(list(lower.children)[0], ElementVariableNode) and list(lower.children)[0].name == "x" - assert isinstance(result.rewrite_ast, FunctionNode) and result.rewrite_ast.name.lower() == "ilike" - ilike_args = list(result.rewrite_ast.children) + assert isinstance(list(lower.children)[0], ElementVariableNode) + assert list(lower.children)[0].name == "x" + assert isinstance(strpos_args[1], LiteralNode) + # Rewrite: ILIKE + rew = result.rewrite_ast + assert isinstance(rew, FunctionNode) and rew.name.lower() == "ilike" + ilike_args = list(rew.children) assert isinstance(ilike_args[0], ElementVariableNode) and ilike_args[0].name == "x" assert isinstance(ilike_args[1], LiteralNode) +def test_parse_ast_max_distinct(): + """MAX(DISTINCT ) -> MAX()""" + result = RuleParserV2.parse("MAX(DISTINCT )", "MAX()") + assert isinstance(result.pattern_ast, FunctionNode) and result.pattern_ast.name.lower() == "max" + assert isinstance(result.rewrite_ast, FunctionNode) and result.rewrite_ast.name.lower() == "max" + assert "x" in list(_walk_var_names(result.pattern_ast)) + assert "x" in list(_walk_var_names(result.rewrite_ast)) + + +def test_parse_ast_contradiction(): + """ > AND <= -> FALSE""" + result = RuleParserV2.parse(" > AND <= ", "FALSE") + assert isinstance(result.pattern_ast, OperatorNode) and result.pattern_ast.name.lower() == "and" + _assert_varnodes_declared(result) + _assert_no_internal_tokens(result) + + +def test_parse_ast_combine_or_to_in(): + """ = OR = -> IN (, )""" + result = RuleParserV2.parse(" = OR = ", " IN (, )") + assert isinstance(result.pattern_ast, OperatorNode) + assert isinstance(result.rewrite_ast, (OperatorNode, FunctionNode)) + _assert_varnodes_declared(result) + _assert_no_internal_tokens(result) + + +def test_parse_ast_or_to_case(): + """OR chain -> CASE WHEN — verifies CaseNode with 3 whens + else.""" + result = RuleParserV2.parse( + " OR OR ", + "1 = CASE WHEN THEN 1 WHEN THEN 1 WHEN THEN 1 ELSE 0 END", + ) + _assert_varnodes_declared(result) + _assert_no_internal_tokens(result) + case_nodes = _find_all(result.rewrite_ast, CaseNode) + assert len(case_nodes) >= 1, "Rewrite should contain a CaseNode" + case = case_nodes[0] + assert len(case.whens) == 3 + assert case.else_val is not None + + +# ═══════════════════════════════════════════════════════════════════════════════ +# parse() — WHERE scope +# ═══════════════════════════════════════════════════════════════════════════════ + def test_parse_ast_where_scope(): result = RuleParserV2.parse("WHERE = 1", "WHERE = 1") assert result.mapping == {"x": "EV001"} @@ -274,6 +428,18 @@ def test_parse_ast_where_scope(): assert isinstance(rhs, LiteralNode) and rhs.value == 1 +def test_parse_where_scope_strips_select_and_from(): + """WHERE scope extraction should produce no SelectNode or FromNode.""" + result = RuleParserV2.parse("WHERE > ", "WHERE > ") + assert isinstance(result.pattern_ast, QueryNode) + assert _find_first(result.pattern_ast, SelectNode) is None + assert _find_first(result.pattern_ast, FromNode) is None + + +# ═══════════════════════════════════════════════════════════════════════════════ +# parse() — FROM scope +# ═══════════════════════════════════════════════════════════════════════════════ + def test_parse_ast_from_scope(): result = RuleParserV2.parse("FROM li", "FROM li") assert result.mapping == {"t": "EV001"} @@ -284,277 +450,296 @@ def test_parse_ast_from_scope(): assert isinstance(tab, TableNode) and tab.name == "t" and tab.alias == "li" -def test_brackets_1(): - pattern = '''WHERE 11 - AND a <= 11 - ''' - index = RuleParserV2.find_malformed_brackets(pattern) - assert index == 6 - - -def test_brackets_2(): - pattern = '''WHERE 11 - AND a <= 11 - ''' - index = RuleParserV2.find_malformed_brackets(pattern) - assert index == 6 - - -def test_parse_validator_3(): - pattern = '''WHERE 11 - AND a <= 11 - ''' - index = RuleParserV2.find_malformed_brackets(pattern) - assert index == 6 - - -def test_parse_validator_4(): - pattern = '''WHERE [x> > 11 - AND a <= 11 - ''' - index = RuleParserV2.find_malformed_brackets(pattern) - assert index == 6 - - -def test_parse_validator_5(): - pattern = '''WHERE (x> > 11 - AND a <= 11 - ''' - index = RuleParserV2.find_malformed_brackets(pattern) - assert index == 6 - - -def test_parse_validator_6(): - pattern = '''WHERE {x> > 11 - AND a <= 11 - ''' - index = RuleParserV2.find_malformed_brackets(pattern) - assert index == 6 - - -# --- data/rules.py catalog (one test per rule, same order as in rules.py) --- - - -def test_rule_remove_max_distinct(): - """Rule remove_max_distinct: Remove Max Distinct.""" - _parse_and_assert_catalog_rule(_rule_by_key("remove_max_distinct")) - - -def test_rule_remove_cast_date(): - """Rule remove_cast_date: Remove Cast Date.""" - _parse_and_assert_catalog_rule(_rule_by_key("remove_cast_date")) - - -def test_rule_remove_cast_text(): - """Rule remove_cast_text: Remove Cast Text.""" - _parse_and_assert_catalog_rule(_rule_by_key("remove_cast_text")) - - -def test_rule_replace_strpos_lower(): - """Rule replace_strpos_lower: Replace Strpos Lower.""" - _parse_and_assert_catalog_rule(_rule_by_key("replace_strpos_lower")) - - -def test_rule_remove_self_join(): - """Rule remove_self_join: Remove Self Join.""" - _parse_and_assert_catalog_rule(_rule_by_key("remove_self_join")) - - -def test_rule_remove_self_join_advance(): - """Rule remove_self_join_advance: Remove Self Join Advance.""" - _parse_and_assert_catalog_rule(_rule_by_key("remove_self_join_advance")) - - -def test_rule_subquery_to_join(): - """Rule subquery_to_join: Subquery to Join.""" - _parse_and_assert_catalog_rule(_rule_by_key("subquery_to_join")) - - -def test_rule_join_to_filter(): - """Rule join_to_filter: Join to Filter.""" - _parse_and_assert_catalog_rule(_rule_by_key("join_to_filter")) - - -def test_rule_join_to_filter_advance(): - """Rule join_to_filter_advance: Join to Filter Advance.""" - _parse_and_assert_catalog_rule(_rule_by_key("join_to_filter_advance")) - - -def test_rule_join_to_filter_partial1(): - """Rule join_to_filter_partial1: Join to Filter Partial 1.""" - _parse_and_assert_catalog_rule(_rule_by_key("join_to_filter_partial1")) - - -def test_rule_join_to_filter_partial2(): - """Rule join_to_filter_partial2: Join to Filter Partial 2.""" - _parse_and_assert_catalog_rule(_rule_by_key("join_to_filter_partial2")) - - -def test_rule_join_to_filter_partial3(): - """Rule join_to_filter_partial3: Join to Filter Partial 3.""" - _parse_and_assert_catalog_rule(_rule_by_key("join_to_filter_partial3")) - - -def test_rule_remove_1useless_innerjoin(): - """Rule remove_1useless_innerjoin: Remove 1 Useless InnerJoin.""" - _parse_and_assert_catalog_rule(_rule_by_key("remove_1useless_innerjoin")) - - -def test_rule_remove_where_true(): - """Rule remove_where_true: Remove Where True.""" - _parse_and_assert_catalog_rule(_rule_by_key("remove_where_true")) - - -def test_rule_nested_clause_to_inner_join(): - """Rule nested_clause_to_inner_join: Nested Clause to Inner Join.""" - _parse_and_assert_catalog_rule(_rule_by_key("nested_clause_to_inner_join")) - - -def test_rule_contradiction_gt_lte(): - """Rule contradiction_gt_lte: Contradiction gt/lte.""" - _parse_and_assert_catalog_rule(_rule_by_key("contradiction_gt_lte")) - - -def test_rule_subquery_to_joins(): - """Rule subquery_to_joins: Subquery to Joins.""" - _parse_and_assert_catalog_rule(_rule_by_key("subquery_to_joins")) - - -def test_rule_aggregation_to_filtered_subquery(): - """Rule aggregation_to_filtered_subquery: Aggregation to Filtered Subquery.""" - _parse_and_assert_catalog_rule(_rule_by_key("aggregation_to_filtered_subquery")) - - -def test_rule_spreadsheet_id_2(): - """Rule spreadsheet_id_2: Spreadsheet ID 2.""" - _parse_and_assert_catalog_rule(_rule_by_key("spreadsheet_id_2")) - - -def test_rule_spreadsheet_id_3(): - """Rule spreadsheet_id_3: Spreadsheet ID 3.""" - _parse_and_assert_catalog_rule(_rule_by_key("spreadsheet_id_3")) - - -def test_rule_spreadsheet_id_4(): - """Rule spreadsheet_id_4: Spreadsheet ID 4.""" - _parse_and_assert_catalog_rule(_rule_by_key("spreadsheet_id_4")) - +def test_parse_from_scope_strips_select(): + """FROM scope extraction should produce no SelectNode.""" + result = RuleParserV2.parse("FROM ", "FROM ") + assert isinstance(result.pattern_ast, QueryNode) + assert _find_first(result.pattern_ast, SelectNode) is None -def test_rule_spreadsheet_id_6(): - """Rule spreadsheet_id_6: Spreadsheet ID 6.""" - _parse_and_assert_catalog_rule(_rule_by_key("spreadsheet_id_6")) +def test_parse_from_scope_with_where(): + """FROM WHERE ... — pattern keeps WHERE, rewrite without WHERE drops it.""" + result = RuleParserV2.parse( + "FROM WHERE > - 2", + "FROM ", + ) + assert isinstance(result.pattern_ast, QueryNode) + assert _find_first(result.pattern_ast, FromNode) is not None + assert _find_first(result.pattern_ast, WhereNode) is not None -def test_rule_spreadsheet_id_7(): - """Rule spreadsheet_id_7: Spreadsheet ID 7.""" - _parse_and_assert_catalog_rule(_rule_by_key("spreadsheet_id_7")) +def test_parse_from_scope_with_join(): + """FROM with INNER JOIN should produce JoinNode.""" + result = RuleParserV2.parse( + "FROM INNER JOIN ON . = .", + "FROM INNER JOIN ON . = .", + ) + assert isinstance(result.pattern_ast, QueryNode) + assert len(_find_all(result.pattern_ast, JoinNode)) >= 1 -def test_rule_spreadsheet_id_9(): - """Rule spreadsheet_id_9: Spreadsheet ID 9.""" - _parse_and_assert_catalog_rule(_rule_by_key("spreadsheet_id_9")) +# ═══════════════════════════════════════════════════════════════════════════════ +# parse() — SELECT scope: complex rules +# ═══════════════════════════════════════════════════════════════════════════════ -def test_rule_spreadsheet_id_10(): - """Rule spreadsheet_id_10: Spreadsheet ID 10.""" - _parse_and_assert_catalog_rule(_rule_by_key("spreadsheet_id_10")) +def test_parse_ast_select_list_varset(): + """SetVariableNode in the SELECT list.""" + result = RuleParserV2.parse( + "select <> from lineitem where 1 = 1", + "select <> from lineitem where 1 = 1", + ) + assert isinstance(result.pattern_ast, QueryNode) + select = next(c for c in result.pattern_ast.children if c.type == NodeType.SELECT) + assert isinstance(select, SelectNode) + first = list(select.children)[0] + assert isinstance(first, SetVariableNode) and first.name == "s1" -def test_rule_spreadsheet_id_11(): - """Rule spreadsheet_id_11: Spreadsheet ID 11.""" - _parse_and_assert_catalog_rule(_rule_by_key("spreadsheet_id_11")) +def test_parse_self_join_rule(): + """Remove Self Join: 2 tables in pattern, 1 in rewrite, SetVariableNodes present.""" + result = RuleParserV2.parse( + """select <> + from , + where .=. and <>""", + """select <> + from + where 1=1 and <>""", + ) + _assert_varnodes_declared(result) + _assert_no_internal_tokens(result) + assert len(_find_all(result.pattern_ast, TableNode)) >= 2 + assert len(_find_all(result.rewrite_ast, TableNode)) >= 1 + pat_svs = [n for n in _walk(result.pattern_ast) if isinstance(n, SetVariableNode)] + assert len(pat_svs) >= 2 # s1 and p1 -def test_rule_spreadsheet_id_12(): - """Rule spreadsheet_id_12: Spreadsheet ID 12.""" - _parse_and_assert_catalog_rule(_rule_by_key("spreadsheet_id_12")) +def test_parse_subquery_to_join_rule(): + """IN (SELECT ...) pattern has SubqueryNode; comma-join rewrite does not.""" + result = RuleParserV2.parse( + """select <> from + where in (select from where <>) + and <>""", + """select distinct <> from , + where . = . + and <> and <>""", + ) + _assert_varnodes_declared(result) + _assert_no_internal_tokens(result) + assert len(_find_all(result.pattern_ast, SubqueryNode)) >= 1 + assert len(_find_all(result.rewrite_ast, SubqueryNode)) == 0 -def test_rule_spreadsheet_id_15(): - """Rule spreadsheet_id_15: Spreadsheet ID 15.""" - _parse_and_assert_catalog_rule(_rule_by_key("spreadsheet_id_15")) +def test_parse_join_to_filter_rule(): + """Double INNER JOIN pattern has more JoinNodes than single INNER JOIN rewrite.""" + result = RuleParserV2.parse( + """select <> + from + inner join on . = . + inner join on . = . + where . = and <>""", + """select <> + from + inner join on . = . + where . = and <>""", + ) + _assert_varnodes_declared(result) + _assert_no_internal_tokens(result) + assert len(_find_all(result.pattern_ast, JoinNode)) > len( + _find_all(result.rewrite_ast, JoinNode) + ) -def test_rule_spreadsheet_id_18(): - """Rule spreadsheet_id_18: Spreadsheet ID 18.""" - _parse_and_assert_catalog_rule(_rule_by_key("spreadsheet_id_18")) +def test_parse_distinct_on(): + """DISTINCT ON should be preserved in the SelectNode.""" + result = RuleParserV2.parse( + "SELECT DISTINCT ON () , FROM ", + "SELECT , FROM ", + ) + _assert_varnodes_declared(result) + pat_sel = _find_first(result.pattern_ast, SelectNode) + assert pat_sel is not None + assert pat_sel.distinct or getattr(pat_sel, "distinct_on", None) is not None -def test_rule_spreadsheet_id_20(): - """Rule spreadsheet_id_20: Spreadsheet ID 20.""" - _parse_and_assert_catalog_rule(_rule_by_key("spreadsheet_id_20")) +def test_parse_order_by_and_limit(): + """ORDER BY and LIMIT should produce their respective node types.""" + result = RuleParserV2.parse( + "SELECT FROM ORDER BY ASC LIMIT ", + "SELECT FROM ORDER BY ASC LIMIT ", + ) + _assert_varnodes_declared(result) + assert _find_first(result.pattern_ast, OrderByNode) is not None + assert _find_first(result.pattern_ast, OrderByItemNode) is not None + assert _find_first(result.pattern_ast, LimitNode) is not None -def test_rule_test_rule_wetune_90(): - """Rule test_rule_wetune_90: Test Rule Wetune 90.""" - _parse_and_assert_catalog_rule(_rule_by_key("test_rule_wetune_90")) +def test_parse_distinct_to_group_by(): + """SELECT DISTINCT -> GROUP BY rewrite.""" + result = RuleParserV2.parse( + "SELECT DISTINCT <> FROM <> WHERE <>", + "SELECT <> FROM <> WHERE <> GROUP BY <>", + ) + _assert_varnodes_declared(result) + _assert_no_internal_tokens(result) + assert _find_first(result.rewrite_ast, GroupByNode) is not None + pat_sel = _find_first(result.pattern_ast, SelectNode) + if pat_sel is not None: + assert pat_sel.distinct is True -def test_rule_query_rule_wetune_90(): - """Rule query_rule_wetune_90: Query Rule Wetune 90.""" - _parse_and_assert_catalog_rule(_rule_by_key("query_rule_wetune_90")) +def test_parse_set_variable_in_select_and_where(): + """SetVariableNode should appear in both SELECT and WHERE.""" + result = RuleParserV2.parse( + "SELECT <> FROM tbl WHERE <>", + "SELECT <> FROM tbl WHERE <>", + ) + sv_names = {n.name for n in _walk(result.pattern_ast) if isinstance(n, SetVariableNode)} + assert "cols" in sv_names + assert "preds" in sv_names -def test_rule_test_rule_calcite_testPushMinThroughUnion(): - """Rule test_rule_calcite_testPushMinThroughUnion: Test Rule Calcite testPushMinThroughUnion.""" - _parse_and_assert_catalog_rule(_rule_by_key("test_rule_calcite_testPushMinThroughUnion")) +# ═══════════════════════════════════════════════════════════════════════════════ +# Column + parent_alias substitution +# ═══════════════════════════════════════════════════════════════════════════════ +def test_qualified_column_both_parts_substituted(): + """. — both parent_alias and name should become external names.""" + result = RuleParserV2.parse(". = 1", ". = 1") + _assert_varnodes_declared(result) + _assert_no_internal_tokens(result) + cols = _find_all(result.pattern_ast, ColumnNode) + qualified = [c for c in cols if c.parent_alias is not None] + assert len(qualified) >= 1 + for c in qualified: + assert c.parent_alias in result.mapping + assert c.name in result.mapping -def test_rule_remove_adddate(): - """Rule remove_adddate: Remove Adddate.""" - _parse_and_assert_catalog_rule(_rule_by_key("remove_adddate")) +def test_qualified_column_only_parent_alias_is_var(): + """.fixed_col — only the alias is a variable; the column name is a literal.""" + result = RuleParserV2.parse(".created_at = 1", ".created_at = 1") + _assert_no_internal_tokens(result) + cols = _find_all(result.pattern_ast, ColumnNode) + qualified = [c for c in cols if c.parent_alias is not None] + assert len(qualified) >= 1 -def test_rule_remove_timestamp(): - """Rule remove_timestamp: Remove Timestamp.""" - _parse_and_assert_catalog_rule(_rule_by_key("remove_timestamp")) +# ═══════════════════════════════════════════════════════════════════════════════ +# Error paths +# ═══════════════════════════════════════════════════════════════════════════════ -def test_rule_stackoverflow_1(): - """Rule stackoverflow_1: Stackoverflow 1.""" - _parse_and_assert_catalog_rule(_rule_by_key("stackoverflow_1")) +def test_invalid_sql_raises(): + """Completely invalid SQL should raise during parse.""" + with pytest.raises(Exception): + RuleParserV2.parse("!!NOT_VALID_SQL!!", "") -def test_rule_combine_or_to_in(): - """Rule combine_or_to_in: combine multiple or to in.""" - _parse_and_assert_catalog_rule(_rule_by_key("combine_or_to_in")) +def test_deeply_nested_parens(): + """Deeply nested expressions should not confuse scope detection.""" + result = RuleParserV2.parse( + "((( + ) * ) > 0)", + "( + ) * > 0", + ) + _assert_varnodes_declared(result) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# No internal token leak — parametrized across shapes +# ═══════════════════════════════════════════════════════════════════════════════ + +@pytest.mark.parametrize("pattern,rewrite", [ + ("CAST( AS DATE)", ""), + ("STRPOS(LOWER(), '') > 0", " ILIKE '%%'"), + ("MAX(DISTINCT )", "MAX()"), + (" = OR = ", " IN (, )"), + ("WHERE = 1", "WHERE = 1"), + ("FROM ", "FROM "), +]) +def test_no_internal_tokens_survive(pattern, rewrite): + result = RuleParserV2.parse(pattern, rewrite) + _assert_no_internal_tokens(result) + _assert_varnodes_declared(result) -def test_rule_combine_3_or_to_in(): - """Rule combine_3_or_to_in: combine multiple or to in (3-way).""" - _parse_and_assert_catalog_rule(_rule_by_key("combine_3_or_to_in")) +# ═══════════════════════════════════════════════════════════════════════════════ +# Mapping consistency — parse() vs replaceVars() +# ═══════════════════════════════════════════════════════════════════════════════ +@pytest.mark.parametrize("pattern,rewrite", [ + ("CAST( AS DATE)", ""), + ("STRPOS(LOWER(), '') > 0", " ILIKE '%%'"), + ("SELECT <> FROM WHERE <>", "SELECT <> FROM WHERE <>"), +]) +def test_parse_mapping_matches_replaceVars(pattern, rewrite): + _, _, expected = RuleParserV2.replaceVars(pattern, rewrite) + result = RuleParserV2.parse(pattern, rewrite) + assert result.mapping == expected -def test_rule_merge_or_to_in(): - """Rule merge_or_to_in: merge or to in.""" - _parse_and_assert_catalog_rule(_rule_by_key("merge_or_to_in")) +# ═══════════════════════════════════════════════════════════════════════════════ +# Variable coverage +# ═══════════════════════════════════════════════════════════════════════════════ -def test_rule_merge_in_statements(): - """Rule merge_in_statements: merge statements with in condition.""" - _parse_and_assert_catalog_rule(_rule_by_key("merge_in_statements")) +def test_rewrite_vars_subset_of_pattern(): + """For simple rules, rewrite variables are a subset of pattern variables.""" + result = RuleParserV2.parse("CAST( AS DATE)", "") + assert set(_walk_var_names(result.rewrite_ast)) <= set(_walk_var_names(result.pattern_ast)) -def test_rule_multiple_merge_in(): - """Rule multiple_merge_in: multiple merge in.""" - _parse_and_assert_catalog_rule(_rule_by_key("multiple_merge_in")) +def test_identity_rule_same_vars(): + """An identity rewrite has the same variable set in both trees.""" + result = RuleParserV2.parse(" = ", " = ") + assert set(_walk_var_names(result.pattern_ast)) == set(_walk_var_names(result.rewrite_ast)) -def test_rule_partial_subquery_to_join(): - """Rule partial_subquery_to_join: partial subquery to join.""" - _parse_and_assert_catalog_rule(_rule_by_key("partial_subquery_to_join")) +# ═══════════════════════════════════════════════════════════════════════════════ +# VarType / VarTypesInfo metadata +# ═══════════════════════════════════════════════════════════════════════════════ +def test_element_var_markers(): + info = VarTypesInfo[VarType.ElementVariable] + assert info["markerStart"] == "<" + assert info["markerEnd"] == ">" + assert info["internalBase"] == "EV" -def test_rule_and_on_true(): - """Rule and_on_true: where TRUE and TRUE.""" - _parse_and_assert_catalog_rule(_rule_by_key("and_on_true")) +def test_set_var_markers(): + info = VarTypesInfo[VarType.SetVariable] + assert info["markerStart"] == "<<" + assert info["markerEnd"] == ">>" + assert info["internalBase"] == "SV" -def test_rule_multiple_and_on_true(): - """Rule multiple_and_on_true: where TRUE and TRUE in set representation.""" - _parse_and_assert_catalog_rule(_rule_by_key("multiple_and_on_true")) +# ═══════════════════════════════════════════════════════════════════════════════ +# data/rules.py catalog — parametrized over all rules +# ═══════════════════════════════════════════════════════════════════════════════ -def test_rule_multiple_or_to_union(): - """Rule multiple_or_to_union: multiple or to union.""" - _parse_and_assert_catalog_rule(_rule_by_key("multiple_or_to_union")) +@pytest.mark.parametrize( + "rule", + RULES_CATALOG, + ids=[r["key"] for r in RULES_CATALOG], +) +class TestCatalogRules: + + def test_parse_succeeds(self, rule): + """Full parse pipeline completes without error.""" + result = RuleParserV2.parse(rule["pattern"], rule["rewrite"]) + assert isinstance(result, RuleParseResult) + assert result.pattern_ast is not None + assert result.rewrite_ast is not None + + def test_mapping_matches_replaceVars(self, rule): + """parse() returns the same mapping as replaceVars().""" + _, _, expected = RuleParserV2.replaceVars(rule["pattern"], rule["rewrite"]) + result = RuleParserV2.parse(rule["pattern"], rule["rewrite"]) + assert result.mapping == expected + + def test_varnodes_declared_in_mapping(self, rule): + """Every variable node in the AST uses an external name present in mapping.""" + result = RuleParserV2.parse(rule["pattern"], rule["rewrite"]) + _assert_varnodes_declared(result) + + def test_no_internal_tokens_leak(self, rule): + """No EV00x / SV00x tokens survive as raw identifiers.""" + result = RuleParserV2.parse(rule["pattern"], rule["rewrite"]) + _assert_no_internal_tokens(result) \ No newline at end of file From efda13fe663d18175922dedf58710d6052a878e5 Mon Sep 17 00:00:00 2001 From: Colin Harrison Date: Wed, 8 Apr 2026 11:49:25 -0700 Subject: [PATCH 7/8] address comments --- core/ast/__init__.py | 1 + core/ast/node.py | 16 +++++++++++++++ core/rule_parser_v2.py | 32 ++++++++++++++++++++++++++++++ tests/test_rule_parser_v2.py | 38 +++++++++++++++++------------------- 4 files changed, 67 insertions(+), 20 deletions(-) diff --git a/core/ast/__init__.py b/core/ast/__init__.py index f45bf2c..1474504 100644 --- a/core/ast/__init__.py +++ b/core/ast/__init__.py @@ -22,6 +22,7 @@ GroupByNode, HavingNode, OrderByNode, + OrderByItemNode, LimitNode, OffsetNode, QueryNode diff --git a/core/ast/node.py b/core/ast/node.py index 8932f7f..03fb78d 100644 --- a/core/ast/node.py +++ b/core/ast/node.py @@ -174,6 +174,14 @@ def __init__(self, _name: str, **kwargs): super().__init__(NodeType.VAR, **kwargs) self.name = _name + def __eq__(self, other): + if not isinstance(other, ElementVariableNode): + return False + return super().__eq__(other) and self.name == other.name + + def __hash__(self): + return hash((super().__hash__(), self.name)) + class SetVariableNode(Node): """Rule set variable ``<>`` (see ``VarType.SetVariable`` in rule_parser_v2).""" @@ -181,6 +189,14 @@ def __init__(self, _name: str, **kwargs): super().__init__(NodeType.VARSET, **kwargs) self.name = _name + def __eq__(self, other): + if not isinstance(other, SetVariableNode): + return False + return super().__eq__(other) and self.name == other.name + + def __hash__(self): + return hash((super().__hash__(), self.name)) + class OperatorNode(Node): """Operator node""" diff --git a/core/rule_parser_v2.py b/core/rule_parser_v2.py index 41a835d..c0e13c2 100644 --- a/core/rule_parser_v2.py +++ b/core/rule_parser_v2.py @@ -18,6 +18,7 @@ HavingNode, IntervalNode, JoinNode, + LiteralNode, LimitNode, ListNode, Node, @@ -327,6 +328,13 @@ def _placeholder_varnode(internal_token: str, external_name: str) -> Node: # @staticmethod def _substitute_placeholders(node: Node, rev: Dict[str, str]) -> Node: + def _replace_internal_in_string(s: str) -> str: + # Replace EV00x / SV00x occurrences inside strings (e.g., '%EV001%'). + out = s + for internal, external in rev.items(): + out = out.replace(internal, external) + return out + if node.type == NodeType.COLUMN: col = node if not isinstance(col, ColumnNode): @@ -354,6 +362,14 @@ def _substitute_placeholders(node: Node, rev: Dict[str, str]) -> Node: new_alias = t.alias return TableNode(new_name, new_alias) + if node.type == NodeType.LITERAL: + lit = node + if not isinstance(lit, LiteralNode): + return node + if isinstance(lit.value, str): + return LiteralNode(_replace_internal_in_string(lit.value)) + return LiteralNode(lit.value) + if node.type == NodeType.QUERY: q = node if not isinstance(q, QueryNode): @@ -414,6 +430,22 @@ def _substitute_placeholders(node: Node, rev: Dict[str, str]) -> Node: return node return OrderByNode([RuleParserV2._substitute_placeholders(c, rev) for c in o.children]) + if node.type == NodeType.LIMIT: + lim = node + if not isinstance(lim, LimitNode): + return node + if isinstance(lim.limit, str): + return LimitNode(_replace_internal_in_string(lim.limit)) + return LimitNode(lim.limit) + + if node.type == NodeType.OFFSET: + off = node + if not isinstance(off, OffsetNode): + return node + if isinstance(off.offset, str): + return OffsetNode(_replace_internal_in_string(off.offset)) + return OffsetNode(off.offset) + if node.type == NodeType.ORDER_BY_ITEM: oi = node if not isinstance(oi, OrderByItemNode): diff --git a/tests/test_rule_parser_v2.py b/tests/test_rule_parser_v2.py index 8d50f42..c06a561 100644 --- a/tests/test_rule_parser_v2.py +++ b/tests/test_rule_parser_v2.py @@ -86,35 +86,20 @@ def _assert_varnodes_declared(result: RuleParseResult) -> None: def _assert_no_internal_tokens(result: RuleParseResult) -> None: - """No EV00x / SV00x tokens should survive as raw identifiers after substitution. - - Known limitation: ``_substitute_placeholders`` does not replace a ColumnNode's - name when the column is qualified (``parent_alias`` is set) but the parent alias - itself is a literal not present in the internal-to-external mapping. We only flag - unqualified columns and qualified columns whose parent alias is also an internal - token. - """ + """No EV00x / SV00x tokens should survive in identifier-bearing AST fields.""" internal_tokens = set(result.mapping.values()) for tree_label, tree in [("pattern", result.pattern_ast), ("rewrite", result.rewrite_ast)]: for n in _walk(tree): if isinstance(n, ColumnNode): - if n.parent_alias is None: - assert not _TOKEN_RE.match(n.name), ( - f"{tree_label} AST has raw internal token {n.name!r} " - f"as unqualified ColumnNode" - ) - elif n.parent_alias in internal_tokens: + assert not _TOKEN_RE.match(n.name), ( + f"{tree_label} AST has raw internal token {n.name!r} as ColumnNode.name" + ) + if n.parent_alias in internal_tokens: assert not _TOKEN_RE.match(n.parent_alias), ( f"{tree_label} AST has raw internal token {n.parent_alias!r} " f"as ColumnNode.parent_alias" ) - assert not _TOKEN_RE.match(n.name), ( - f"{tree_label} AST has raw internal token {n.name!r} " - f"as ColumnNode.name (parent_alias was also an internal token)" - ) - # else: parent_alias is a literal (e.g. "t") and name is an internal - # token — known gap in _substitute_placeholders; skip. if isinstance(n, TableNode) and isinstance(n.name, str): assert not _TOKEN_RE.match(n.name), ( @@ -368,6 +353,19 @@ def test_parse_ast_strpos_ilike_rule(): ilike_args = list(rew.children) assert isinstance(ilike_args[0], ElementVariableNode) and ilike_args[0].name == "x" assert isinstance(ilike_args[1], LiteralNode) + assert ilike_args[1].value == "%s%" + + +def test_substitute_placeholders_limit_offset_string_tokens(): + """Directly exercise LIMIT/OFFSET token replacement for string payloads.""" + lim = RuleParserV2._substitute_placeholders( # type: ignore[arg-type] + LimitNode("EV001"), {"EV001": "x"} + ) + off = RuleParserV2._substitute_placeholders( # type: ignore[arg-type] + OffsetNode("EV002"), {"EV002": "y"} + ) + assert isinstance(lim, LimitNode) and lim.limit == "x" + assert isinstance(off, OffsetNode) and off.offset == "y" def test_parse_ast_max_distinct(): From cd75120c7b750f791952525b78837cfcd6ca9a77 Mon Sep 17 00:00:00 2001 From: Colin Harrison Date: Wed, 8 Apr 2026 12:15:03 -0700 Subject: [PATCH 8/8] address comments --- core/rule_parser_v2.py | 16 ++++++++++------ tests/test_rule_parser_v2.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/core/rule_parser_v2.py b/core/rule_parser_v2.py index c0e13c2..8bf392d 100644 --- a/core/rule_parser_v2.py +++ b/core/rule_parser_v2.py @@ -341,15 +341,17 @@ def _replace_internal_in_string(s: str) -> str: return node pa = col.parent_alias nm = col.name + new_alias = _replace_internal_in_string(col.alias) if isinstance(col.alias, str) else col.alias + new_pa = _replace_internal_in_string(pa) if isinstance(pa, str) else pa if pa is None and nm in rev: return RuleParserV2._placeholder_varnode(nm, rev[nm]) if pa is not None and pa in rev and nm in rev: - return ColumnNode(rev[nm], _alias=col.alias, _parent_alias=rev[pa]) + return ColumnNode(rev[nm], _alias=new_alias, _parent_alias=rev[pa]) if pa is not None and pa in rev: - return ColumnNode(nm, _alias=col.alias, _parent_alias=rev[pa]) + return ColumnNode(nm, _alias=new_alias, _parent_alias=rev[pa]) if pa is not None and nm in rev: - return ColumnNode(rev[nm], _alias=col.alias, _parent_alias=pa) - return ColumnNode(nm, _alias=col.alias, _parent_alias=pa) + return ColumnNode(rev[nm], _alias=new_alias, _parent_alias=new_pa) + return ColumnNode(nm, _alias=new_alias, _parent_alias=new_pa) if node.type == NodeType.TABLE: t = node @@ -470,14 +472,16 @@ def _replace_internal_in_string(s: str) -> str: if not isinstance(sq, SubqueryNode): return node inner = list(sq.children)[0] - return SubqueryNode(RuleParserV2._substitute_placeholders(inner, rev), sq.alias) + alias = _replace_internal_in_string(sq.alias) if isinstance(sq.alias, str) else sq.alias + return SubqueryNode(RuleParserV2._substitute_placeholders(inner, rev), alias) if node.type == NodeType.FUNCTION: f = node if not isinstance(f, FunctionNode): return node new_args = [RuleParserV2._substitute_placeholders(a, rev) for a in f.children] - return FunctionNode(f.name, _args=new_args, _alias=f.alias) + alias = _replace_internal_in_string(f.alias) if isinstance(f.alias, str) else f.alias + return FunctionNode(f.name, _args=new_args, _alias=alias) if node.type == NodeType.LIST: ln = node diff --git a/tests/test_rule_parser_v2.py b/tests/test_rule_parser_v2.py index c06a561..a25ce79 100644 --- a/tests/test_rule_parser_v2.py +++ b/tests/test_rule_parser_v2.py @@ -95,6 +95,10 @@ def _assert_no_internal_tokens(result: RuleParseResult) -> None: assert not _TOKEN_RE.match(n.name), ( f"{tree_label} AST has raw internal token {n.name!r} as ColumnNode.name" ) + if isinstance(n.alias, str): + assert not _TOKEN_RE.match(n.alias), ( + f"{tree_label} AST has raw internal token {n.alias!r} as ColumnNode.alias" + ) if n.parent_alias in internal_tokens: assert not _TOKEN_RE.match(n.parent_alias), ( f"{tree_label} AST has raw internal token {n.parent_alias!r} " @@ -105,6 +109,20 @@ def _assert_no_internal_tokens(result: RuleParseResult) -> None: assert not _TOKEN_RE.match(n.name), ( f"{tree_label} AST has raw internal token {n.name!r} as TableNode.name" ) + if isinstance(n.alias, str): + assert not _TOKEN_RE.match(n.alias), ( + f"{tree_label} AST has raw internal token {n.alias!r} as TableNode.alias" + ) + + if isinstance(n, SubqueryNode) and isinstance(n.alias, str): + assert not _TOKEN_RE.match(n.alias), ( + f"{tree_label} AST has raw internal token {n.alias!r} as SubqueryNode.alias" + ) + + if isinstance(n, FunctionNode) and isinstance(n.alias, str): + assert not _TOKEN_RE.match(n.alias), ( + f"{tree_label} AST has raw internal token {n.alias!r} as FunctionNode.alias" + ) # ═══════════════════════════════════════════════════════════════════════════════ @@ -368,6 +386,19 @@ def test_substitute_placeholders_limit_offset_string_tokens(): assert isinstance(off, OffsetNode) and off.offset == "y" +def test_parse_substitutes_alias_fields(): + """Column/function/subquery aliases should not leak EV/SV internal tokens.""" + result = RuleParserV2.parse( + "SELECT SUM() AS , t.c AS FROM (SELECT FROM ) AS , t", + "SELECT SUM() AS , t.c AS FROM (SELECT FROM ) AS , t", + ) + assert isinstance(result, RuleParseResult) + assert result.mapping["f_alias"].startswith("EV") + assert result.mapping["c_alias"].startswith("EV") + assert result.mapping["sq_alias"].startswith("EV") + _assert_no_internal_tokens(result) + + def test_parse_ast_max_distinct(): """MAX(DISTINCT ) -> MAX()""" result = RuleParserV2.parse("MAX(DISTINCT )", "MAX()")