From 0fc1a95424ce79c6bcf8ec2a2f528fd5fc46d8fd Mon Sep 17 00:00:00 2001 From: Lightning Sagar Date: Tue, 21 Apr 2026 18:55:06 +0530 Subject: [PATCH] fix: replace implicit SQLAlchemy select coercions in visitors --- api/app/sta2rest/visitors.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/api/app/sta2rest/visitors.py b/api/app/sta2rest/visitors.py index abf44466..8e2f29fd 100644 --- a/api/app/sta2rest/visitors.py +++ b/api/app/sta2rest/visitors.py @@ -418,8 +418,9 @@ def visit_ExpandNode(self, node: ExpandNode, parent=None): for relationship in join_relationships: sub_query = sub_query.join(relationship) + sub_query_from = sub_query.subquery() columns_to_select = [] - for column in sub_query.columns: + for column in sub_query_from.c: if column.name not in labels: columns_to_select.append(column) else: @@ -439,7 +440,7 @@ def visit_ExpandNode(self, node: ExpandNode, parent=None): ) if columns_to_select is not None: - sub_query = select(*columns_to_select).select_from(sub_query) + sub_query = select(*columns_to_select).select_from(sub_query_from) expand_queries.append( [ @@ -799,7 +800,7 @@ def visit_QueryNode(self, node: QueryNode): limited_count_query_str = str( select(func.count()) .select_from( - query_estimate_count.limit(COUNT_ESTIMATE_THRESHOLD) + query_estimate_count.limit(COUNT_ESTIMATE_THRESHOLD).subquery() ) .compile( dialect=engine.dialect, @@ -827,8 +828,9 @@ def visit_QueryNode(self, node: QueryNode): top_value -= 1 main_query = main_query.limit(top_value).offset(skip_value) + main_query_subquery = main_query.subquery("main_query_source") columns_to_select = [] - for column in main_query.columns: + for column in main_query_subquery.c: if column.name not in labels: columns_to_select.append(column) else: @@ -850,11 +852,11 @@ def visit_QueryNode(self, node: QueryNode): if columns_to_select is not None: main_query = ( select(*columns_to_select) - .select_from(main_query) + .select_from(main_query_subquery) .alias("main_query") ) else: - main_query = main_query.alias("main_query") + main_query = main_query_subquery.alias("main_query") if result_format == "DataArray": if not node.expand: @@ -913,9 +915,10 @@ def visit_QueryNode(self, node: QueryNode): value = select_query[0].name else: value = select_query[0].right + main_query_json = main_query.subquery("main_query_json") main_query = select( - main_query.c.json.op("->")(text(f"'{value}'")) - ).select_from(main_query) + main_query_json.c.json.op("->")(text(f"'{value}'")) + ).select_from(main_query_json) main_query_str = str( main_query.compile(