Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions core/ast/node.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime
from typing import List, Set, Optional, Union
from typing import List, Set, Optional, Tuple, Union
from abc import ABC

from .enums import NodeType, JoinType, SortOrder
Expand Down Expand Up @@ -101,18 +101,20 @@ def __hash__(self):

class LiteralNode(Node):
    """Literal value node, optionally carrying an alias."""

    def __init__(self, _value: str|int|float|bool|datetime|None, _alias: Optional[str] = None, **kwargs):
        """Create a literal node.

        Args:
            _value: The literal value stored on the node.
            _alias: Optional alias for the literal; defaults to ``None``.
            **kwargs: Forwarded unchanged to ``Node.__init__``.
        """
        super().__init__(NodeType.LITERAL, **kwargs)
        self.value = _value
        self.alias = _alias

    def __eq__(self, other):
        """Equal only to another LiteralNode with the same base state, value and alias."""
        if not isinstance(other, LiteralNode):
            return False
        return (super().__eq__(other) and
                self.value == other.value and
                self.alias == other.alias)

    def __hash__(self):
        """Hash over the same fields ``__eq__`` compares, so the two stay consistent."""
        return hash((super().__hash__(), self.value, self.alias))

class DataTypeNode(Node):
"""SQL data type node used in CAST expressions (e.g. TEXT, DATE, INTEGER)"""
Expand Down Expand Up @@ -249,24 +251,31 @@ def __hash__(self):

class JoinNode(Node):
    """JOIN clause node"""

    def __init__(self, _left_table: Union['TableNode', 'JoinNode', 'SubqueryNode'], _right_table: Union['TableNode', 'SubqueryNode'], _join_type: JoinType = JoinType.INNER, _on_condition: Optional['Node'] = None, _using: Optional[List['Node']] = None, **kwargs):
        """Create a join between two sources.

        Args:
            _left_table: Left-hand source; may itself be a JoinNode.
            _right_table: Right-hand source.
            _join_type: Kind of join; defaults to ``JoinType.INNER``.
            _on_condition: Optional ON condition node.
            _using: Optional USING column nodes.
            **kwargs: Forwarded unchanged to ``Node.__init__``.

        The children list is ``[left, right]``, followed by the ON condition
        when present, followed by the USING columns when present.
        """
        # Materialise once so a one-shot iterable feeds both `children` and
        # `self.using` with the same elements.
        using_columns = list(_using) if _using else []

        children = [_left_table, _right_table]
        if _on_condition:
            children.append(_on_condition)
        children.extend(using_columns)
        super().__init__(NodeType.JOIN, children=children, **kwargs)

        self.left_table = _left_table
        self.right_table = _right_table
        self.join_type = _join_type
        self.on_condition = _on_condition
        # An absent or empty USING list is normalised to None.
        self.using = using_columns or None

    def __eq__(self, other):
        """Equal only to another JoinNode with the same base state, join type and USING list."""
        if not isinstance(other, JoinNode):
            return False
        return (super().__eq__(other) and
                self.join_type == other.join_type and
                self.using == other.using)

    def __hash__(self):
        """Hash over the base hash, join type and USING columns (as a tuple)."""
        # Lists are unhashable, so the USING columns are frozen into a tuple.
        using_key = tuple(self.using) if self.using else ()
        return hash((super().__hash__(), self.join_type, using_key))

# ============================================================================
# Query Structure Nodes
Expand Down Expand Up @@ -463,4 +472,4 @@ def __eq__(self, other):
return super().__eq__(other) and self.whens == other.whens and self.else_val == other.else_val

def __hash__(self):
return hash((super().__hash__(), tuple(self.whens), self.else_val))
return hash((super().__hash__(), tuple(self.whens), self.else_val))
65 changes: 32 additions & 33 deletions core/query_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,19 +104,10 @@ def format_select(select_node: SelectNode) -> dict:

items = []
for child in children:
if child.type == NodeType.COLUMN:
if child.alias:
items.append({'name': child.alias, 'value': format_expression(child)})
else:
items.append({'value': format_expression(child)})
elif child.type == NodeType.FUNCTION:
func_expr = format_expression(child)
if hasattr(child, 'alias') and child.alias:
items.append({'name': child.alias, 'value': func_expr})
else:
items.append({'value': func_expr})
else:
items.append({'value': format_expression(child)})
item = {'value': format_expression(child)}
if hasattr(child, 'alias') and child.alias:
item['name'] = child.alias
items.append(item)

select_key = 'select_distinct' if select_node.distinct else 'select'
result[select_key] = items
Expand Down Expand Up @@ -172,47 +163,43 @@ def format_from(from_node: FromNode):

def format_join(join_node: JoinNode) -> list:
    """Format a JOIN node

    Returns a list whose leading entries describe the left side (a formatted
    source, or the flattened entries of a nested join) and whose last entry
    is a dict for this join: the join keyword mapped to the formatted right
    source, plus ``'on'`` and/or ``'using'`` when present.
    """
    # Read the named attributes rather than positional children: the ON
    # condition is optional, so a child index cannot tell it apart from a
    # USING column.
    left_node = join_node.left_table
    right_node = join_node.right_table
    join_condition = join_node.on_condition
    using_columns = join_node.using

    result = []

    if left_node.type == NodeType.JOIN:
        # Nested join on the left: flatten its entries into this list.
        result.extend(format_join(left_node))
    else:
        result.append(format_source(left_node))

    join_type_map = {
        JoinType.JOIN: 'join',
        JoinType.INNER: 'inner join',
        JoinType.LEFT: 'left join',
        JoinType.RIGHT: 'right join',
        JoinType.FULL: 'full join',
        JoinType.CROSS: 'cross join',
        JoinType.NATURAL: 'natural join',
    }

    # Join types missing from the map fall back to a plain 'join'.
    join_key = join_type_map.get(join_node.join_type, 'join')
    join_dict = {join_key: format_source(right_node)}

    if join_condition:
        join_dict['on'] = format_expression(join_condition)

    if using_columns:
        # A single USING column is emitted bare; several are emitted as a list.
        if len(using_columns) == 1:
            join_dict['using'] = format_expression(using_columns[0])
        else:
            join_dict['using'] = [format_expression(col) for col in using_columns]

    result.append(join_dict)

    return result


Expand Down Expand Up @@ -401,5 +388,17 @@ def format_expression(node: Node):
unit = node.unit.name.lower()
return {'interval': [value, unit]}

elif node.type == NodeType.VAR:
return node.name

elif node.type == NodeType.VARSET:
return node.name

elif node.type == NodeType.QUERY:
return ast_to_json(node)

elif node.type == NodeType.COMPOUND_QUERY:
return compound_to_mosql_json(node)

else:
raise ValueError(f"Unsupported node type in expression: {node.type}")
raise ValueError(f"Unsupported node type in expression: {node.type}")
20 changes: 18 additions & 2 deletions core/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
)
# TODO: implement ElementVariableNode, SetVariableNode
from core.ast.enums import JoinType, SortOrder
from typing import List, Optional
import mo_sql_parsing as mosql
import json

Expand Down Expand Up @@ -133,8 +134,21 @@ def _append_source(node: Node, alias):
if 'on' in item:
on_condition = self.parse_expression(item['on'], aliases)

using_columns: Optional[List[Node]] = None
if 'using' in item:
using_value = item['using']
if isinstance(using_value, list):
using_columns = [
ColumnNode(str(c)) if not isinstance(c, dict) else self.parse_expression(c, aliases)
for c in using_value
]
elif isinstance(using_value, dict):
using_columns = [self.parse_expression(using_value, aliases)]
else:
using_columns = [ColumnNode(str(using_value))]

join_type = self.parse_join_type(join_key)
join_node = JoinNode(left_source, right_source, join_type, on_condition)
join_node = JoinNode(left_source, right_source, join_type, on_condition, using_columns)
left_source = join_node

elif 'value' in item:
Expand Down Expand Up @@ -592,7 +606,9 @@ def parse_join_type(join_key: str) -> JoinType:
"""Extract JoinType from mo_sql_parsing join key."""
key_lower = join_key.lower().replace(' ', '_')

if 'inner' in key_lower:
if 'natural' in key_lower:
return JoinType.NATURAL
Comment thread
colinthebomb1 marked this conversation as resolved.
elif 'inner' in key_lower:
return JoinType.INNER
elif 'left' in key_lower:
return JoinType.LEFT
Expand Down
Loading
Loading