Skip to content

Commit 5a749c7

Browse files
colinthebomb1, Copilot, and baiqiushi
authored
[Refactor] Add new AST node types and resolve AST TODOs (#99)
* fix ast TODOs with node.py changes
* fix tests
* comment out test
* Update data/asts.py (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
* Update core/ast/node.py (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
* Update core/ast/node.py (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
* fix node inconsistencies
* Update core/ast/node.py (Co-authored-by: QIUSHI BAI <baiqiushi@gmail.com>)
* address comments
--------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: QIUSHI BAI <baiqiushi@gmail.com>
1 parent 25b6710 commit 5a749c7

4 files changed

Lines changed: 160 additions & 82 deletions

File tree

core/ast/enums.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ class NodeType(Enum):
1212
SUBQUERY = "subquery"
1313
COLUMN = "column"
1414
LITERAL = "literal"
15+
DATA_TYPE = "data_type"
16+
TIME_UNIT = "time_unit"
17+
LIST = "list"
18+
INTERVAL = "interval"
19+
1520
# VarSQL specific
1621
VAR = "var"
1722
VARSET = "varset"
@@ -32,6 +37,8 @@ class NodeType(Enum):
3237
LIMIT = "limit"
3338
OFFSET = "offset"
3439
QUERY = "query"
40+
CASE = "case"
41+
WHEN_THEN = "when_then"
3542

3643
# ============================================================================
3744
# Join Type Enumeration

core/ast/node.py

Lines changed: 104 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,59 @@ def __eq__(self, other):
114114
def __hash__(self):
115115
return hash((super().__hash__(), self.value))
116116

117+
class DataTypeNode(Node):
    """SQL data type node used in CAST expressions (e.g. TEXT, DATE, INTEGER).

    Args:
        _name: Name of the data type; exposed as ``self.name``.
        **kwargs: Passed through unchanged to ``Node.__init__``.
    """

    def __init__(self, _name: str, **kwargs):
        super().__init__(NodeType.DATA_TYPE, **kwargs)
        self.name = _name

    def __eq__(self, other):
        # Only another DataTypeNode can compare equal. The base-class
        # comparison runs first; the name is checked only if it succeeds.
        if isinstance(other, DataTypeNode):
            return super().__eq__(other) and self.name == other.name
        return False

    def __hash__(self):
        # Combine the base hash with the one extra field compared in __eq__.
        base_hash = super().__hash__()
        return hash((base_hash, self.name))
130+
131+
132+
class TimeUnitNode(Node):
    """SQL time unit node used in INTERVAL and temporal functions (e.g. DAY, MONTH, SECOND).

    Args:
        _name: Name of the time unit; exposed as ``self.name``.
        **kwargs: Passed through unchanged to ``Node.__init__``.
    """

    def __init__(self, _name: str, **kwargs):
        super().__init__(NodeType.TIME_UNIT, **kwargs)
        self.name = _name

    def __eq__(self, other):
        # Only another TimeUnitNode can compare equal. The base-class
        # comparison runs first; the name is checked only if it succeeds.
        if isinstance(other, TimeUnitNode):
            return super().__eq__(other) and self.name == other.name
        return False

    def __hash__(self):
        # Combine the base hash with the one extra field compared in __eq__.
        base_hash = super().__hash__()
        return hash((base_hash, self.name))
145+
146+
class ListNode(Node):
    """A list of nodes, e.g. the right-hand side of an IN expression.

    Args:
        _items: Member nodes; handed to ``Node.__init__`` as ``children``.
        **kwargs: Passed through unchanged to ``Node.__init__``.
    """

    def __init__(self, _items: List[Node], **kwargs):
        super().__init__(
            NodeType.LIST,
            children=_items,
            **kwargs,
        )
150+
151+
class IntervalNode(Node):
    """Interval made of a value and a time unit.

    Args:
        _value: The interval amount. May be a ``Node`` or a plain value;
            stored as ``self.value`` either way.
        _unit: The ``TimeUnitNode`` for the interval; stored as ``self.unit``.
        **kwargs: Passed through unchanged to ``Node.__init__``.

    ``children`` is ``[_value, _unit]`` when the value is itself a ``Node``
    and ``[_unit]`` otherwise.
    """

    def __init__(self, _value, _unit: TimeUnitNode, **kwargs):
        # A Node-valued amount is exposed through `children` so that generic
        # traversals/formatters that walk the tree reach it; a plain value
        # lives only on `self.value`.
        children = [_value, _unit] if isinstance(_value, Node) else [_unit]
        super().__init__(NodeType.INTERVAL, children=children, **kwargs)
        self.value = _value
        self.unit = _unit

    def __eq__(self, other):
        if isinstance(other, IntervalNode):
            return (
                super().__eq__(other)
                and self.value == other.value
                and self.unit == other.unit
            )
        return False

    def __hash__(self):
        # Hash the same fields that __eq__ compares.
        base_hash = super().__hash__()
        return hash((base_hash, self.value, self.unit))
117170

118171
class VarNode(Node):
119172
"""VarSQL variable node"""
@@ -192,9 +245,22 @@ def __hash__(self):
192245
# ============================================================================
193246

194247
class SelectNode(Node):
    """SELECT clause node.

    Args:
        _items: The selected expressions.
        _distinct: Flag stored as ``self.distinct``; defaults to ``False``.
        _distinct_on: Expressions for DISTINCT ON (e.g. a ListNode of
            columns), or ``None`` when absent. Stored as ``self.distinct_on``.
        **kwargs: Passed through unchanged to ``Node.__init__``.

    ``children`` holds a copy of ``_items`` followed by ``_distinct_on``
    when it is not ``None``.
    """

    def __init__(self, _items: List['Node'], _distinct: bool = False, _distinct_on: Optional['Node'] = None, **kwargs):
        children = [*_items]
        # NOTE(review): the DISTINCT ON node shares `children` with the select
        # items; code that treats every child as a select item presumably has
        # to account for that - confirm against consumers.
        if _distinct_on is not None:
            children.append(_distinct_on)
        super().__init__(NodeType.SELECT, children=children, **kwargs)
        self.distinct = _distinct
        self.distinct_on = _distinct_on

    def __eq__(self, other):
        if isinstance(other, SelectNode):
            return (
                super().__eq__(other)
                and self.distinct == other.distinct
                and self.distinct_on == other.distinct_on
            )
        return False

    def __hash__(self):
        # Hash the same fields that __eq__ compares.
        base_hash = super().__hash__()
        return hash((base_hash, self.distinct, self.distinct_on))
198264

199265

200266
# TODO - confine the valid NodeTypes as children of FromNode
@@ -304,4 +370,38 @@ def __init__(self,
304370
children.append(_limit)
305371
if _offset:
306372
children.append(_offset)
307-
super().__init__(NodeType.QUERY, children=children, **kwargs)
373+
super().__init__(NodeType.QUERY, children=children, **kwargs)
374+
375+
class WhenThenNode(Node):
    """Single WHEN ... THEN ... branch of a CASE expression.

    Args:
        _when: Node for the WHEN part; stored as ``self.when``.
        _then: Node for the THEN part; stored as ``self.then``.
        **kwargs: Passed through unchanged to ``Node.__init__``.

    ``children`` is ``[_when, _then]``.
    """

    def __init__(self, _when: Node, _then: Node, **kwargs):
        branch_parts = [_when, _then]
        super().__init__(NodeType.WHEN_THEN, children=branch_parts, **kwargs)
        self.when = _when
        self.then = _then

    def __eq__(self, other):
        if isinstance(other, WhenThenNode):
            return (
                super().__eq__(other)
                and self.when == other.when
                and self.then == other.then
            )
        return False

    def __hash__(self):
        # Hash the same fields that __eq__ compares.
        base_hash = super().__hash__()
        return hash((base_hash, self.when, self.then))
389+
390+
391+
class CaseNode(Node):
    """SQL CASE WHEN ... THEN ... ELSE ... END expression.

    Args:
        _whens: The WHEN/THEN branches, in order. The sequence is copied, so
            later changes to the caller's list do not affect this node.
        _else: Optional node for the ELSE part; stored as ``self.else_val``.
        **kwargs: Passed through unchanged to ``Node.__init__``.

    ``children`` holds the branches followed by the ELSE node when present.
    """

    def __init__(self, _whens: List[WhenThenNode], _else: Optional[Node] = None, **kwargs):
        # Snapshot the branches exactly once and use that snapshot for both
        # `children` and `self.whens`. Keeping the caller's own object would
        # let an outside mutation desynchronise `whens` from `children` and
        # change this node's hash after it has been placed in a set/dict;
        # a one-shot iterable would also be left exhausted for `self.whens`.
        whens: List[WhenThenNode] = list(_whens)
        children: List[Node] = list(whens)
        if _else is not None:
            children.append(_else)
        super().__init__(NodeType.CASE, children=children, **kwargs)
        self.whens = whens
        self.else_val = _else

    def __eq__(self, other):
        if not isinstance(other, CaseNode):
            return False
        return super().__eq__(other) and self.whens == other.whens and self.else_val == other.else_val

    def __hash__(self):
        # Lists are unhashable, so the branches are hashed as a tuple.
        return hash((super().__hash__(), tuple(self.whens), self.else_val))

0 commit comments

Comments
 (0)