-
Notifications
You must be signed in to change notification settings - Fork 146
SNOW-2203826: Loosen flattening rules for sort and filter #4026
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
b694e2f
a25942b
1df1e96
1104361
d67f5df
43c3316
3f7a98e
8035696
c00e140
4c75cb2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,7 @@ | |
| Sequence, | ||
| Set, | ||
| Union, | ||
| Literal, | ||
| ) | ||
|
|
||
| import snowflake.snowpark._internal.utils | ||
|
|
@@ -86,6 +87,7 @@ | |
| is_sql_select_statement, | ||
| ExprAliasUpdateDict, | ||
| ) | ||
| import snowflake.snowpark.context as context | ||
|
|
||
| # Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable | ||
| # Python 3.9 can use both | ||
|
|
@@ -1386,9 +1388,9 @@ def select(self, cols: List[Expression]) -> "SelectStatement": | |
| ): | ||
| # TODO: Clean up, this entire if case is parameter protection | ||
| can_be_flattened = False | ||
| elif (self.where or self.order_by or self.limit_) and has_data_generator_exp( | ||
| cols | ||
| ): | ||
| elif ( | ||
| self.where or self.order_by or self.limit_ | ||
| ) and has_data_generator_or_window_function_exp(cols): | ||
| can_be_flattened = False | ||
| elif self.where and ( | ||
| (subquery_dependent_columns := derive_dependent_columns(self.where)) | ||
|
|
@@ -1399,6 +1401,20 @@ def select(self, cols: List[Expression]) -> "SelectStatement": | |
| subquery_dependent_columns & new_column_states.active_columns | ||
| ) | ||
| ) | ||
| or ( | ||
| # unflattenable condition: dropped column is used in subquery WHERE clause and dropped column status is NEW or CHANGED in the subquery | ||
| # reason: we should not flatten because the dropped column is not available in the new query, leading to WHERE clause error | ||
| # sample query: 'select "b" from (select "a" as "c", "b" from table where "c" > 1)' can not be flatten to 'select "b" from table where "c" > 1' | ||
| context._is_snowpark_connect_compatible_mode | ||
| and new_column_states.dropped_columns | ||
| and any( | ||
| self.column_states[_col].change_state | ||
| in (ColumnChangeState.NEW, ColumnChangeState.CHANGED_EXP) | ||
| for _col in ( | ||
| subquery_dependent_columns & new_column_states.dropped_columns | ||
| ) | ||
| ) | ||
| ) | ||
| ): | ||
| can_be_flattened = False | ||
| elif self.order_by and ( | ||
|
|
@@ -1411,6 +1427,20 @@ def select(self, cols: List[Expression]) -> "SelectStatement": | |
| subquery_dependent_columns & new_column_states.active_columns | ||
| ) | ||
| ) | ||
| or ( | ||
| # unflattenable condition: dropped column is used in subquery ORDER BY clause and dropped column status is NEW or CHANGED in the subquery | ||
| # reason: we should not flatten because the dropped column is not available in the new query, leading to ORDER BY clause error | ||
| # sample query: 'select "b" from (select "a" as "c", "b" order by "c")' can not be flatten to 'select "b" from table order by "c"' | ||
| context._is_snowpark_connect_compatible_mode | ||
| and new_column_states.dropped_columns | ||
| and any( | ||
| self.column_states[_col].change_state | ||
| in (ColumnChangeState.NEW, ColumnChangeState.CHANGED_EXP) | ||
| for _col in ( | ||
| subquery_dependent_columns & new_column_states.dropped_columns | ||
| ) | ||
| ) | ||
| ) | ||
| ): | ||
| can_be_flattened = False | ||
| elif self.distinct_: | ||
|
|
@@ -1478,12 +1508,17 @@ def select(self, cols: List[Expression]) -> "SelectStatement": | |
| return new | ||
|
|
||
| def filter(self, col: Expression) -> "SelectStatement": | ||
| self._session._retrieve_aggregation_function_list() | ||
| can_be_flattened = ( | ||
| (not self.flatten_disabled) | ||
| and can_clause_dependent_columns_flatten( | ||
| derive_dependent_columns(col), self.column_states | ||
| derive_dependent_columns(col), self.column_states, "filter" | ||
| ) | ||
| and not has_data_generator_exp(self.projection) | ||
| and not has_data_generator_or_window_function_exp(self.projection) | ||
| and not ( | ||
| context._is_snowpark_connect_compatible_mode | ||
| and has_aggregation_function_exp(self.projection) | ||
| ) # sum(col) as new_col, new_col can not be flattened in where clause | ||
| and not (self.order_by and self.limit_ is not None) | ||
| ) | ||
| if can_be_flattened: | ||
|
|
@@ -1518,9 +1553,12 @@ def sort(self, cols: List[Expression]) -> "SelectStatement": | |
| and (not self.limit_) | ||
| and (not self.offset) | ||
| and can_clause_dependent_columns_flatten( | ||
| derive_dependent_columns(*cols), self.column_states | ||
| derive_dependent_columns(*cols), self.column_states, "sort" | ||
| ) | ||
| and not has_data_generator_exp(self.projection) | ||
| # we do not check aggregation function here like filter | ||
| # in the case when aggregation function is in the projection | ||
| # order by is evaluated after aggregation, row info are not taken in the calculation | ||
| ) | ||
| if can_be_flattened: | ||
| new = copy(self) | ||
|
|
@@ -1557,7 +1595,7 @@ def distinct(self) -> "SelectStatement": | |
| # .order_by(col1).select(col2).distinct() cannot be flattened because | ||
| # SELECT DISTINCT B FROM TABLE ORDER BY A is not valid SQL | ||
| and (not (self.order_by and self.has_projection)) | ||
| and not has_data_generator_exp(self.projection) | ||
| and not has_data_generator_or_window_function_exp(self.projection) | ||
| ) | ||
| if can_be_flattened: | ||
| new = copy(self) | ||
|
|
@@ -2048,7 +2086,12 @@ def can_projection_dependent_columns_be_flattened( | |
| def can_clause_dependent_columns_flatten( | ||
| dependent_columns: Optional[AbstractSet[str]], | ||
| subquery_column_states: ColumnStateDict, | ||
| clause: Literal["filter", "sort"], | ||
| ) -> bool: | ||
| assert clause in ( | ||
| "filter", | ||
| "sort", | ||
| ), f"Invalid clause called in can_clause_dependent_columns_flatten: {clause}" | ||
| if dependent_columns == COLUMN_DEPENDENCY_DOLLAR: | ||
| return False | ||
| elif ( | ||
|
|
@@ -2063,15 +2106,18 @@ def can_clause_dependent_columns_flatten( | |
| dc_state = subquery_column_states.get(dc) | ||
| if dc_state: | ||
| if dc_state.change_state == ColumnChangeState.CHANGED_EXP: | ||
| return False | ||
| if ( | ||
| clause == "filter" | ||
| ): # where can not be flattened because 'where' is evaluated before projection, flattening leads to wrong result | ||
| # df.select((col('a') + 1).alias('a')).filter(col('a') > 5) -- this should be applied to the new 'a', flattening will use the old 'a' to evaluated | ||
| return False | ||
| else: # clause == 'sort' | ||
| # df.select((col('a') + 1).alias('a')).sort(col('a')) -- this is valid to flatten because 'order by' is evaluated after projection | ||
| # however, if the order by is a data generator, it should not be flattened because generator is evaluated dynamically according to the order. | ||
| return context._is_snowpark_connect_compatible_mode | ||
| elif dc_state.change_state == ColumnChangeState.NEW: | ||
| # Most of the time this can be flattened. But if a new column uses window function and this column | ||
| # is used in a clause, the sql doesn't work in Snowflake. | ||
| # For instance `select a, rank() over(order by b) as d from test_table where d = 1` doesn't work. | ||
| # But `select a, b as d from test_table where d = 1` works | ||
| # We can inspect whether the referenced new column uses window function. Here we are being | ||
| # conservative for now to not flatten the SQL. | ||
| return False | ||
| return context._is_snowpark_connect_compatible_mode | ||
|
|
||
| return True | ||
|
|
||
|
|
||
|
|
@@ -2295,18 +2341,89 @@ def derive_column_states_from_subquery( | |
| return column_states | ||
|
|
||
|
|
||
| def has_data_generator_exp(expressions: Optional[List["Expression"]]) -> bool: | ||
| def _check_expressions_for_types( | ||
| expressions: Optional[List["Expression"]], | ||
| check_data_gen: bool = False, | ||
| check_window: bool = False, | ||
| check_aggregation: bool = False, | ||
| ) -> bool: | ||
| """Efficiently check if expressions contain specific types in a single pass. | ||
|
|
||
| Args: | ||
| expressions: List of expressions to check | ||
| check_data_gen: Check for data generator functions | ||
| check_window: Check for window functions | ||
| check_aggregation: Check for aggregation functions | ||
|
|
||
| Returns: | ||
| True if any requested type is found | ||
| """ | ||
| if expressions is None: | ||
| return False | ||
|
|
||
| for exp in expressions: | ||
| if isinstance(exp, WindowExpression): | ||
| if exp is None: | ||
| continue | ||
|
|
||
| # Check window functions | ||
| if check_window and isinstance(exp, WindowExpression): | ||
| return True | ||
| if isinstance(exp, FunctionExpression) and ( | ||
| exp.is_data_generator | ||
| or exp.name.lower() in SEQUENCE_DEPENDENT_DATA_GENERATION | ||
|
|
||
| # Check data generators (including window in non-connect mode) | ||
| if check_data_gen: | ||
| # In non-connect mode, windows are treated as data generators | ||
| if not context._is_snowpark_connect_compatible_mode and isinstance( | ||
| exp, WindowExpression | ||
| ): | ||
| return True | ||
|
Comment on lines
+2374
to
+2378
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is confusing imo. It would be better to remove this and just do the following in |
||
| # Check actual data generator functions | ||
| if isinstance(exp, FunctionExpression) and ( | ||
| exp.is_data_generator | ||
| or exp.name.lower() in SEQUENCE_DEPENDENT_DATA_GENERATION | ||
| ): | ||
| # https://docs.snowflake.com/en/sql-reference/functions-data-generation | ||
| return True | ||
|
|
||
| # Check aggregation functions | ||
| if check_aggregation and isinstance(exp, FunctionExpression): | ||
| if exp.name.lower() in context._aggregation_function_set: | ||
| return True | ||
|
|
||
| # Recursively check children | ||
| if _check_expressions_for_types( | ||
| exp.children, check_data_gen, check_window, check_aggregation | ||
| ): | ||
| # https://docs.snowflake.com/en/sql-reference/functions-data-generation | ||
| return True | ||
| if exp is not None and has_data_generator_exp(exp.children): | ||
| return True | ||
|
|
||
| return False | ||
|
|
||
|
|
||
| def has_data_generator_exp(expressions: Optional[List["Expression"]]) -> bool: | ||
| """Check if expressions contain data generator functions. | ||
|
|
||
| Note: | ||
| In non-connect mode, check_data_gen check both data generator and window expressions for backward compatibility. | ||
| In connect mode, check_data_gen only checks data generator expressions. | ||
| """ | ||
| return _check_expressions_for_types(expressions, check_data_gen=True) | ||
|
|
||
|
|
||
| def has_data_generator_or_window_function_exp( | ||
| expressions: Optional[List["Expression"]], | ||
| ) -> bool: | ||
| """Check if expressions contain data generators or window functions. | ||
|
|
||
| Optimized to do a single pass checking both types simultaneously. | ||
| """ | ||
| if not context._is_snowpark_connect_compatible_mode: | ||
| # In non-connect mode, windows are already treated as data generators | ||
| return _check_expressions_for_types(expressions, check_data_gen=True) | ||
| # In connect mode, check both in a single pass | ||
| return _check_expressions_for_types( | ||
| expressions, check_data_gen=True, check_window=True | ||
| ) | ||
|
|
||
|
|
||
| def has_aggregation_function_exp(expressions: Optional[List["Expression"]]) -> bool: | ||
| """Check if expressions contain aggregation functions.""" | ||
| return _check_expressions_for_types(expressions, check_aggregation=True) | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -5029,6 +5029,35 @@ def _execute_sproc_internal( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Note the collect is implicit within the stored procedure call, so should not emit_ast here. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return df.collect(statement_params=statement_params, _emit_ast=False)[0][0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _retrieve_aggregation_function_list(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Retrieve the list of aggregation functions which will later be used in sql simplifier.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| not context._is_snowpark_connect_compatible_mode | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| or context._aggregation_function_set | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| retrieved_set = set() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for sql in [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """select function_name from information_schema.functions where is_aggregate = 'YES'""", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+5043
to
+5044
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need both of these? For udfs? |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| retrieved_set.update({r[0].lower() for r in self.sql(sql).collect()}) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except BaseException as e: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _logger.debug( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "Unable to get aggregation functions from the database: %s", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| e, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # we raise error here as a pessimistic tactics | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # the reason is that if we fail to retrieve the aggregation function list, we have empty set | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # the simplifier will flatten the query which contains aggregation functions leading to incorrect results | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with context._aggregation_function_set_lock: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| context._aggregation_function_set.update(retrieved_set) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+5032
to
+5059
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Race condition in Issue: Multiple threads calling
Fix: def _retrieve_aggregation_function_list(self) -> None:
if not context._is_snowpark_connect_compatible_mode:
return
with context._aggregation_function_set_lock:
# Re-check inside the lock
if context._aggregation_function_set:
return
retrieved_set = set()
for sql in [...]:
# ... query logic ...
context._aggregation_function_set.update(retrieved_set)
Suggested change
Spotted by Graphite Agent |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def directory(self, stage_name: str, _emit_ast: bool = True) -> DataFrame: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns a DataFrame representing the results of a directory table query on the specified stage. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to confirm, are all of these changes are behing the
_is_snowpark_connect_compatible_modeflag?