@@ -114,6 +114,59 @@ def __eq__(self, other):
114114 def __hash__ (self ):
115115 return hash ((super ().__hash__ (), self .value ))
116116
117+ class DataTypeNode (Node ):
118+ """SQL data type node used in CAST expressions (e.g. TEXT, DATE, INTEGER)"""
119+ def __init__ (self , _name : str , ** kwargs ):
120+ super ().__init__ (NodeType .DATA_TYPE , ** kwargs )
121+ self .name = _name
122+
123+ def __eq__ (self , other ):
124+ if not isinstance (other , DataTypeNode ):
125+ return False
126+ return super ().__eq__ (other ) and self .name == other .name
127+
128+ def __hash__ (self ):
129+ return hash ((super ().__hash__ (), self .name ))
130+
131+
132+ class TimeUnitNode (Node ):
133+ """SQL time unit node used in INTERVAL and temporal functions (e.g. DAY, MONTH, SECOND)"""
134+ def __init__ (self , _name : str , ** kwargs ):
135+ super ().__init__ (NodeType .TIME_UNIT , ** kwargs )
136+ self .name = _name
137+
138+ def __eq__ (self , other ):
139+ if not isinstance (other , TimeUnitNode ):
140+ return False
141+ return super ().__eq__ (other ) and self .name == other .name
142+
143+ def __hash__ (self ):
144+ return hash ((super ().__hash__ (), self .name ))
145+
146+ class ListNode (Node ):
147+ """A list of nodes, e.g. the right-hand side of an IN expression"""
148+ def __init__ (self , _items : List [Node ], ** kwargs ):
149+ super ().__init__ (NodeType .LIST , children = _items , ** kwargs )
150+
151+ class IntervalNode (Node ):
152+ def __init__ (self , _value , _unit : TimeUnitNode , ** kwargs ):
153+ # Include the value in children when it is itself a Node, so that
154+ # generic traversals/formatters that walk via `children` see it.
155+ if isinstance (_value , Node ):
156+ children = [_value , _unit ]
157+ else :
158+ children = [_unit ]
159+ super ().__init__ (NodeType .INTERVAL , children = children , ** kwargs )
160+ self .value = _value
161+ self .unit = _unit
162+
163+ def __eq__ (self , other ):
164+ if not isinstance (other , IntervalNode ):
165+ return False
166+ return super ().__eq__ (other ) and self .value == other .value and self .unit == other .unit
167+
168+ def __hash__ (self ):
169+ return hash ((super ().__hash__ (), self .value , self .unit ))
117170
118171class VarNode (Node ):
119172 """VarSQL variable node"""
@@ -192,9 +245,22 @@ def __hash__(self):
192245# ============================================================================
193246
194247class SelectNode (Node ):
195- """SELECT clause node"""
196- def __init__ (self , _items : List ['Node' ], ** kwargs ):
197- super ().__init__ (NodeType .SELECT , children = _items , ** kwargs )
248+ """SELECT clause node. _distinct_on is the list of expressions for DISTINCT ON (e.g. ListNode of columns)."""
249+ def __init__ (self , _items : List ['Node' ], _distinct : bool = False , _distinct_on : Optional ['Node' ] = None , ** kwargs ):
250+ children = list (_items )
251+ if _distinct_on is not None :
252+ children .append (_distinct_on )
253+ super ().__init__ (NodeType .SELECT , children = children , ** kwargs )
254+ self .distinct = _distinct
255+ self .distinct_on = _distinct_on
256+
257+ def __eq__ (self , other ):
258+ if not isinstance (other , SelectNode ):
259+ return False
260+ return super ().__eq__ (other ) and self .distinct == other .distinct and self .distinct_on == other .distinct_on
261+
262+ def __hash__ (self ):
263+ return hash ((super ().__hash__ (), self .distinct , self .distinct_on ))
198264
199265
200266# TODO - confine the valid NodeTypes as children of FromNode
@@ -304,4 +370,38 @@ def __init__(self,
304370 children .append (_limit )
305371 if _offset :
306372 children .append (_offset )
307- super ().__init__ (NodeType .QUERY , children = children , ** kwargs )
373+ super ().__init__ (NodeType .QUERY , children = children , ** kwargs )
374+
375+ class WhenThenNode (Node ):
376+ """Single WHEN ... THEN ... branch of a CASE expression"""
377+ def __init__ (self , _when : Node , _then : Node , ** kwargs ):
378+ super ().__init__ (NodeType .WHEN_THEN , children = [_when , _then ], ** kwargs )
379+ self .when = _when
380+ self .then = _then
381+
382+ def __eq__ (self , other ):
383+ if not isinstance (other , WhenThenNode ):
384+ return False
385+ return super ().__eq__ (other ) and self .when == other .when and self .then == other .then
386+
387+ def __hash__ (self ):
388+ return hash ((super ().__hash__ (), self .when , self .then ))
389+
390+
391+ class CaseNode (Node ):
392+ """SQL CASE WHEN ... THEN ... ELSE ... END expression"""
393+ def __init__ (self , _whens : List [WhenThenNode ], _else : Optional [Node ] = None , ** kwargs ):
394+ children : List [Node ] = list (_whens )
395+ if _else is not None :
396+ children .append (_else )
397+ super ().__init__ (NodeType .CASE , children = children , ** kwargs )
398+ self .whens = _whens
399+ self .else_val = _else
400+
401+ def __eq__ (self , other ):
402+ if not isinstance (other , CaseNode ):
403+ return False
404+ return super ().__eq__ (other ) and self .whens == other .whens and self .else_val == other .else_val
405+
406+ def __hash__ (self ):
407+ return hash ((super ().__hash__ (), tuple (self .whens ), self .else_val ))
0 commit comments