Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#### Bug Fixes

- Fixed a Snowflake platform compatibility issue (SNOW-3259059) where `concat(lit('"'), ...)` could lose the leading quote through some `EXCEPT` / chained set-operation plans by lowering that literal to `CHR(34)` in generated SQL.
- Fixed a bug where chained `DataFrame.filter()` calls with raw SQL text containing `OR` produced incorrect results.

#### Improvements

Expand Down
13 changes: 10 additions & 3 deletions src/snowflake/snowpark/_internal/analyzer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,10 +798,13 @@ def unary_expression_extractor(
expr.is_permissive,
)
else:
child_sql = self.analyze(
child, df_aliased_col_name_to_real_col_name, parse_local_name
)
if isinstance(child, UnresolvedAttribute) and child.is_sql_text:
child_sql = f"({child_sql})"
return unary_expression(
self.analyze(
child, df_aliased_col_name_to_real_col_name, parse_local_name
),
child_sql,
expr.sql_operator,
expr.operator_first,
)
Expand Down Expand Up @@ -830,6 +833,10 @@ def binary_operator_extractor(
right_sql_expr = self.analyze(
right, df_aliased_col_name_to_real_col_name, parse_local_name
)
if isinstance(left, UnresolvedAttribute) and left.is_sql_text:
left_sql_expr = f"({left_sql_expr})"
if isinstance(right, UnresolvedAttribute) and right.is_sql_text:
right_sql_expr = f"({right_sql_expr})"
if isinstance(expr, BinaryArithmeticExpression):
return binary_arithmetic_expression(
expr.sql_operator,
Expand Down
62 changes: 62 additions & 0 deletions tests/integ/test_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
col,
lit,
parse_json,
sql_expr,
try_parse_json,
second,
to_timestamp,
Expand Down Expand Up @@ -331,6 +332,67 @@ def test_when_accept_sql_expr(session):
).collect() == [Row(5), Row(None), Row(6), Row(None), Row(5)]


@pytest.mark.xfail(
"config.getoption('local_testing_mode', default=False)",
reason="SQL expr is not supported in Local Testing",
run=False,
)
def test_sql_text_parenthesis_wrapped(session):
df = session.create_dataframe(
[
("E", "XX", "EQ"),
("OTHER", "EQ", "EQ"),
("E", "XX", "OTHER"),
("OTHER", "XX", "OTHER"),
],
schema=["a", "b", "c"],
)

Utils.check_answer(
df.filter("a = 'E' or b = 'EQ'").filter("c = 'EQ'"),
df.filter((col("a") == "E") | (col("b") == "EQ")).filter(col("c") == "EQ"),
)

Utils.check_answer(
df.filter("a = 'E' or b = 'EQ'").filter("c = 'EQ'"),
df.filter(sql_expr("a = 'E' or b = 'EQ'")).filter(sql_expr("c = 'EQ'")),
)

Utils.check_answer(
df.filter("a = 'E' or b = 'EQ'").filter("c = 'EQ' or b = 'XX'"),
df.filter((col("a") == "E") | (col("b") == "EQ")).filter(
(col("c") == "EQ") | (col("b") == "XX")
),
)

df2 = session.create_dataframe(
[("E", "XX"), ("OTHER", "EQ"), ("OTHER", "XX")],
schema=["a", "b"],
)
Utils.check_answer(
df2.filter(~sql_expr("a = 'E' or b = 'EQ'")),
[Row("OTHER", "XX")],
)

df3 = session.create_dataframe(
[[1, 2], [3, 4], [5, 6]],
schema=["a", "b"],
)
Utils.check_answer(
df3.filter(sql_expr("a > 2")),
df3.filter(col("a") > 2),
)
Utils.check_answer(
df3.select(sql_expr("a + b")),
[Row(3), Row(7), Row(11)],
sort=False,
)
Utils.check_answer(
df3.filter("a > 2 and b < 6"),
[Row(3, 4)],
)


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="SNOW-1358930 TODO: Decimal should not be casted to int64",
Expand Down
Loading