From ff6a3b04b9020dc07f7807c86b54121a54f0ecf4 Mon Sep 17 00:00:00 2001 From: May Liu Date: Thu, 26 Mar 2026 13:46:14 -0700 Subject: [PATCH] Fix chained filter() produced incorrect query result due to not wrapping sql text in parenthesis --- CHANGELOG.md | 2 + .../snowpark/_internal/analyzer/analyzer.py | 13 +++- tests/integ/test_column.py | 62 +++++++++++++++++++ 3 files changed, 74 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5db5b56cc0..7517668609 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ #### Bug Fixes +- Fixed a bug where chained `DataFrame.filter()` calls with raw SQL text containing `OR` produced incorrect results. + #### Improvements - Restored the following query improvements that were reverted in 1.47.0 due to bugs: diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index 2571145058..5c76b81be7 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -798,10 +798,13 @@ def unary_expression_extractor( expr.is_permissive, ) else: + child_sql = self.analyze( + child, df_aliased_col_name_to_real_col_name, parse_local_name + ) + if isinstance(child, UnresolvedAttribute) and child.is_sql_text: + child_sql = f"({child_sql})" return unary_expression( - self.analyze( - child, df_aliased_col_name_to_real_col_name, parse_local_name - ), + child_sql, expr.sql_operator, expr.operator_first, ) @@ -830,6 +833,10 @@ def binary_operator_extractor( right_sql_expr = self.analyze( right, df_aliased_col_name_to_real_col_name, parse_local_name ) + if isinstance(left, UnresolvedAttribute) and left.is_sql_text: + left_sql_expr = f"({left_sql_expr})" + if isinstance(right, UnresolvedAttribute) and right.is_sql_text: + right_sql_expr = f"({right_sql_expr})" if isinstance(expr, BinaryArithmeticExpression): return binary_arithmetic_expression( expr.sql_operator, diff --git a/tests/integ/test_column.py b/tests/integ/test_column.py index d515fb35dd..747bdb5143 100644 --- a/tests/integ/test_column.py +++ b/tests/integ/test_column.py @@ -14,6 +14,7 @@ col, lit, parse_json, + sql_expr, try_parse_json, second, to_timestamp, @@ -331,6 +332,67 @@ def test_when_accept_sql_expr(session): ).collect() == [Row(5), Row(None), Row(6), Row(None), Row(5)] +@pytest.mark.xfail( + "config.getoption('local_testing_mode', default=False)", + reason="SQL expr is not supported in Local Testing", + run=False, +) +def test_sql_text_parenthesis_wrapped(session): + df = session.create_dataframe( + [ + ("E", "XX", "EQ"), + ("OTHER", "EQ", "EQ"), + ("E", "XX", "OTHER"), + ("OTHER", "XX", "OTHER"), + ], + schema=["a", "b", "c"], + ) + + Utils.check_answer( + df.filter("a = 'E' or b = 'EQ'").filter("c = 'EQ'"), + df.filter((col("a") == "E") | (col("b") == "EQ")).filter(col("c") == "EQ"), + ) + + Utils.check_answer( + df.filter("a = 'E' or b = 'EQ'").filter("c = 'EQ'"), + df.filter(sql_expr("a = 'E' or b = 'EQ'")).filter(sql_expr("c = 'EQ'")), + ) + + Utils.check_answer( + df.filter("a = 'E' or b = 'EQ'").filter("c = 'EQ' or b = 'XX'"), + df.filter((col("a") == "E") | (col("b") == "EQ")).filter( + (col("c") == "EQ") | (col("b") == "XX") + ), + ) + + df2 = session.create_dataframe( + [("E", "XX"), ("OTHER", "EQ"), ("OTHER", "XX")], + schema=["a", "b"], + ) + Utils.check_answer( + df2.filter(~sql_expr("a = 'E' or b = 'EQ'")), + [Row("OTHER", "XX")], + ) + + df3 = session.create_dataframe( + [[1, 2], [3, 4], [5, 6]], + schema=["a", "b"], + ) + Utils.check_answer( + df3.filter(sql_expr("a > 2")), + df3.filter(col("a") > 2), + ) + Utils.check_answer( + df3.select(sql_expr("a + b")), + [Row(3), Row(7), Row(11)], + sort=False, + ) + Utils.check_answer( + df3.filter("a > 2 and b < 6"), + [Row(3, 4)], + ) + + @pytest.mark.skipif( "config.getoption('local_testing_mode', default=False)", reason="SNOW-1358930 TODO: Decimal should not be casted to int64",