diff --git a/CHANGELOG.md b/CHANGELOG.md
index 807e08bfc6..eea6446a48 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,7 @@
 #### Bug Fixes
 
 - Fixed a Snowflake platform compatibility issue (SNOW-3259059) where `concat(lit('"'), ...)` could lose the leading quote through some `EXCEPT` / chained set-operation plans by lowering that literal to `CHR(34)` in generated SQL.
+- Fixed a bug where chained `DataFrame.filter()` calls with raw SQL text containing `OR` produced incorrect results.
 
 #### Improvements
 
diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py
index 2571145058..5c76b81be7 100644
--- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py
+++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py
@@ -798,10 +798,13 @@ def unary_expression_extractor(
                 expr.is_permissive,
             )
         else:
+            child_sql = self.analyze(
+                child, df_aliased_col_name_to_real_col_name, parse_local_name
+            )
+            if isinstance(child, UnresolvedAttribute) and child.is_sql_text:
+                child_sql = f"({child_sql})"
             return unary_expression(
-                self.analyze(
-                    child, df_aliased_col_name_to_real_col_name, parse_local_name
-                ),
+                child_sql,
                 expr.sql_operator,
                 expr.operator_first,
             )
@@ -830,6 +833,10 @@ def binary_operator_extractor(
         right_sql_expr = self.analyze(
             right, df_aliased_col_name_to_real_col_name, parse_local_name
         )
+        if isinstance(left, UnresolvedAttribute) and left.is_sql_text:
+            left_sql_expr = f"({left_sql_expr})"
+        if isinstance(right, UnresolvedAttribute) and right.is_sql_text:
+            right_sql_expr = f"({right_sql_expr})"
         if isinstance(expr, BinaryArithmeticExpression):
             return binary_arithmetic_expression(
                 expr.sql_operator,
diff --git a/tests/integ/test_column.py b/tests/integ/test_column.py
index d515fb35dd..747bdb5143 100644
--- a/tests/integ/test_column.py
+++ b/tests/integ/test_column.py
@@ -14,6 +14,7 @@
     col,
     lit,
     parse_json,
+    sql_expr,
     try_parse_json,
     second,
     to_timestamp,
@@ -331,6 +332,67 @@ def test_when_accept_sql_expr(session):
     ).collect() == [Row(5), Row(None), Row(6), Row(None), Row(5)]
 
 
+@pytest.mark.xfail(
+    "config.getoption('local_testing_mode', default=False)",
+    reason="SQL expr is not supported in Local Testing",
+    run=False,
+)
+def test_sql_text_parenthesis_wrapped(session):
+    df = session.create_dataframe(
+        [
+            ("E", "XX", "EQ"),
+            ("OTHER", "EQ", "EQ"),
+            ("E", "XX", "OTHER"),
+            ("OTHER", "XX", "OTHER"),
+        ],
+        schema=["a", "b", "c"],
+    )
+
+    Utils.check_answer(
+        df.filter("a = 'E' or b = 'EQ'").filter("c = 'EQ'"),
+        df.filter((col("a") == "E") | (col("b") == "EQ")).filter(col("c") == "EQ"),
+    )
+
+    Utils.check_answer(
+        df.filter("a = 'E' or b = 'EQ'").filter("c = 'EQ'"),
+        df.filter(sql_expr("a = 'E' or b = 'EQ'")).filter(sql_expr("c = 'EQ'")),
+    )
+
+    Utils.check_answer(
+        df.filter("a = 'E' or b = 'EQ'").filter("c = 'EQ' or b = 'XX'"),
+        df.filter((col("a") == "E") | (col("b") == "EQ")).filter(
+            (col("c") == "EQ") | (col("b") == "XX")
+        ),
+    )
+
+    df2 = session.create_dataframe(
+        [("E", "XX"), ("OTHER", "EQ"), ("OTHER", "XX")],
+        schema=["a", "b"],
+    )
+    Utils.check_answer(
+        df2.filter(~sql_expr("a = 'E' or b = 'EQ'")),
+        [Row("OTHER", "XX")],
+    )
+
+    df3 = session.create_dataframe(
+        [[1, 2], [3, 4], [5, 6]],
+        schema=["a", "b"],
+    )
+    Utils.check_answer(
+        df3.filter(sql_expr("a > 2")),
+        df3.filter(col("a") > 2),
+    )
+    Utils.check_answer(
+        df3.select(sql_expr("a + b")),
+        [Row(3), Row(7), Row(11)],
+        sort=False,
+    )
+    Utils.check_answer(
+        df3.filter("a > 2 and b < 6"),
+        [Row(3, 4)],
+    )
+
+
 @pytest.mark.skipif(
     "config.getoption('local_testing_mode', default=False)",
     reason="SNOW-1358930 TODO: Decimal should not be casted to int64",