diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 067e3d5f06..185ed5a13f 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -524,6 +524,14 @@ def execution_queries(self) -> Dict["PlanQueryType", List["Query"]]: compiler = PlanCompiler(self) return compiler.compile() + def get_execution_queries_without_cte(self) -> Dict["PlanQueryType", List["Query"]]: + """Compile execution queries with CTE optimization skipped. + Used for retry when CTE-optimized SQL fails at the server.""" + from snowflake.snowpark._internal.compiler.plan_compiler import PlanCompiler + + compiler = PlanCompiler(self) + return compiler.compile(skip_cte_optimization=True) + @property def children_plan_nodes(self) -> List[Union["Selectable", "SnowflakePlan"]]: """ diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py index e02a495554..e0d938f2eb 100644 --- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py +++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py @@ -78,7 +78,9 @@ def should_start_query_compilation(self) -> bool: ) ) - def compile(self) -> Dict[PlanQueryType, List[Query]]: + def compile( + self, *, skip_cte_optimization: bool = False + ) -> Dict[PlanQueryType, List[Query]]: # initialize the queries with the original queries without optimization final_plan = self._plan queries = { @@ -107,7 +109,10 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: # 3. 
apply each optimizations if needed # CTE optimization with measure_time() as cte_time: - if session.cte_optimization_enabled: + if ( + session.cte_optimization_enabled + and not skip_cte_optimization + ): repeated_subquery_eliminator = RepeatedSubqueryElimination( logical_plans, query_generator ) diff --git a/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py b/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py index 326828f29a..19b757ec8f 100644 --- a/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py +++ b/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py @@ -45,6 +45,11 @@ class CompilationStageTelemetryField(Enum): # keys for repeated subquery elimination CTE_NODE_CREATED = "cte_node_created" + # keys for CTE execution retry + CTE_ERROR_COUNT = "cte_error_count" + CTE_QUERY_ID = "cte_query_id" + RETRY_QUERY_ID = "retry_query_id" + # keys for large query breakdown BREAKDOWN_SUMMARY = "breakdown_summary" COMPLEXITY_SCORE_AFTER_CTE_OPTIMIZATION = "complexity_score_after_cte_optimization" diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index e0d4b63693..3422854f7f 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -69,6 +69,7 @@ result_set_to_rows, unwrap_stage_location_single_quote, ) +from snowflake.snowpark import context from snowflake.snowpark.async_job import AsyncJob, _AsyncResultType from snowflake.snowpark.query_history import QueryListener, QueryRecord from snowflake.snowpark.row import Row @@ -696,12 +697,145 @@ def get_result_set( ], Union[List[ResultMetadata], List["ResultMetadataV2"]], ]: - action_id = plan.session._generate_new_action_id() + session = plan.session + action_id = session._generate_new_action_id() plan_queries = plan.execution_queries result, result_meta = None, None statement_params = kwargs.get("_statement_params", None) or {} 
statement_params["_PLAN_UUID"] = plan.uuid kwargs["_statement_params"] = statement_params + + should_retry_without_cte = ( + block # Snowpark Connect does not use async queries. Async queries retry tracked by SNOW-3293313 + and context._is_snowpark_connect_compatible_mode + and session.cte_optimization_enabled + ) + + try: + result, result_meta = self._execute_queries( + plan, + plan_queries, + action_id, + to_pandas=to_pandas, + to_iter=to_iter, + block=block, + data_type=data_type, + log_on_exception=log_on_exception, + case_sensitive=case_sensitive, + ignore_results=ignore_results, + to_arrow=to_arrow, + **kwargs, + ) + except ProgrammingError as cte_error: + if not should_retry_without_cte: + raise + + # Skip retry if CTE optimization didn't affect the SQL. + # Check for the Snowpark CTE prefix to cover both bare SELECTs + # and DML with embedded CTEs (e.g. INSERT INTO t WITH ...). + main_queries = plan_queries[PlanQueryType.QUERIES] + if not any("SNOWPARK_TEMP_CTE_" in q.sql for q in main_queries): + raise + + unoptimized_plan_queries = plan.get_execution_queries_without_cte() + + logger.debug( + "CTE-optimized query failed with ProgrammingError: %s. " + "Retrying without CTE optimization.", + cte_error, + ) + + # Retry safety: CTE errors are compilation-time rejections (no + # partial DML side effects). Setup DDLs are idempotent, post- + # actions run in a finally block (cleanup on failure), and + # compile() generates fresh temp object names each call. 
+ try: + retry_action_id = session._generate_new_action_id() + result, result_meta = self._execute_queries( + plan, + unoptimized_plan_queries, + retry_action_id, + to_pandas=to_pandas, + to_iter=to_iter, + block=block, + data_type=data_type, + log_on_exception=log_on_exception, + case_sensitive=case_sensitive, + ignore_results=ignore_results, + to_arrow=to_arrow, + **kwargs, + ) + + except ProgrammingError: + raise cte_error + except Exception as retry_error: + # Log both errors for debugging + logger.error( + "Retry without CTE optimization failed with different error: %s. " + "Original CTE error: %s", + retry_error, + cte_error, + ) + raise # Raise the actual retry error, not the original + + else: + cte_query_id = getattr(cte_error, "sfqid", None) + retry_query_id = ( + result.get("sfqid") if isinstance(result, dict) else None + ) + + self._telemetry_client.send_cte_execution_retry_telemetry( + session_id=self.get_session_id(), + plan_uuid=plan.uuid, + error_message=str(cte_error), + api_calls=plan.api_calls, + cte_query_id=cte_query_id, + retry_query_id=retry_query_id, + ) + + if context._cte_error_threshold > 0: + cte_disabled = False + with session._lock: + session._cte_error_count += 1 + if session._cte_error_count >= context._cte_error_threshold: + session._cte_optimization_enabled = False + cte_disabled = True + + if cte_disabled: + + logger.warning( + "CTE optimization has caused %d execution failures. 
" + "Auto-disabling CTE optimization for " + "the remainder of this session to avoid further " + "performance impact.", + context._cte_error_threshold, + ) + self._telemetry_client.send_cte_optimization_auto_disabled_telemetry( + session_id=self.get_session_id(), + cte_error_count=context._cte_error_threshold, + ) + + if result is None: + raise SnowparkClientExceptionMessages.SQL_LAST_QUERY_RETURN_RESULTSET() + + return result, result_meta + + def _execute_queries( + self, + plan: SnowflakePlan, + plan_queries, + action_id: int, + to_pandas: bool = False, + to_iter: bool = False, + block: bool = True, + data_type: _AsyncResultType = _AsyncResultType.ROW, + log_on_exception: bool = False, + case_sensitive: bool = True, + ignore_results: bool = False, + to_arrow: bool = False, + **kwargs, + ) -> Tuple[Optional[Any], Optional[Any]]: + result, result_meta = None, None try: main_queries = plan_queries[PlanQueryType.QUERIES] post_actions = plan_queries[PlanQueryType.POST_ACTIONS] @@ -810,9 +944,6 @@ def get_result_set( data=get_plan_telemetry_metrics(plan), ) - if result is None: - raise SnowparkClientExceptionMessages.SQL_LAST_QUERY_RETURN_RESULTSET() - return result, result_meta def get_result_and_metadata( diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py index 0eab8a0717..80f16d395e 100644 --- a/src/snowflake/snowpark/_internal/telemetry.py +++ b/src/snowflake/snowpark/_internal/telemetry.py @@ -63,6 +63,8 @@ class TelemetryField(Enum): TYPE_CURSOR_CREATED = "snowpark_cursor_created" TYPE_SQL_SIMPLIFIER_ENABLED = "snowpark_sql_simplifier_enabled" TYPE_CTE_OPTIMIZATION_ENABLED = "snowpark_cte_optimization_enabled" + TYPE_CTE_EXECUTION_RETRY = "snowpark_cte_execution_retry" + TYPE_CTE_OPTIMIZATION_AUTO_DISABLED = "snowpark_cte_optimization_auto_disabled" # telemetry for optimization that eliminates the extra cast expression generated for expressions TYPE_ELIMINATE_NUMERIC_SQL_VALUE_CAST_ENABLED = ( 
"snowpark_eliminate_numeric_sql_value_cast_enabled" @@ -660,6 +662,46 @@ def send_cte_optimization_telemetry(self, session_id: str) -> None: } self.send(message) + def send_cte_execution_retry_telemetry( + self, + session_id: str, + plan_uuid: str, + error_message: str, + api_calls: Optional[List[Dict[str, Any]]], + cte_query_id: Optional[str], + retry_query_id: Optional[str], + ) -> None: + message = { + **self._create_basic_telemetry_data( + TelemetryField.TYPE_CTE_EXECUTION_RETRY.value + ), + TelemetryField.KEY_DATA.value: { + TelemetryField.SESSION_ID.value: session_id, + CompilationStageTelemetryField.PLAN_UUID.value: plan_uuid, + CompilationStageTelemetryField.ERROR_MESSAGE.value: error_message, + TelemetryField.KEY_API_CALLS.value: api_calls, + CompilationStageTelemetryField.CTE_QUERY_ID.value: cte_query_id, + CompilationStageTelemetryField.RETRY_QUERY_ID.value: retry_query_id, + }, + } + self.send(message) + + def send_cte_optimization_auto_disabled_telemetry( + self, + session_id: str, + cte_error_count: int, + ) -> None: + message = { + **self._create_basic_telemetry_data( + TelemetryField.TYPE_CTE_OPTIMIZATION_AUTO_DISABLED.value + ), + TelemetryField.KEY_DATA.value: { + TelemetryField.SESSION_ID.value: session_id, + CompilationStageTelemetryField.CTE_ERROR_COUNT.value: cte_error_count, + }, + } + self.send(message) + def send_eliminate_numeric_sql_value_cast_telemetry( self, session_id: str, value: bool ) -> None: diff --git a/src/snowflake/snowpark/context.py b/src/snowflake/snowpark/context.py index c21ddec5eb..2a0ad2fdf3 100644 --- a/src/snowflake/snowpark/context.py +++ b/src/snowflake/snowpark/context.py @@ -32,6 +32,8 @@ # This is an internal-only global flag, used to determine whether the api code which will be executed is compatible with snowflake.snowpark_connect _is_snowpark_connect_compatible_mode = False +_cte_error_threshold = 3 # 0 to disable auto-cte-disable, otherwise the number of times CTE optimization can fail before it is 
automatically disabled for the remainder of the session. + # Following are internal-only global flags, used to enable development features. _enable_dataframe_trace_on_error = False _debug_eager_schema_validation = False diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index f80255d00d..7c8e1eae93 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -668,6 +668,7 @@ def __init__( self._cte_optimization_enabled: bool = self.is_feature_enabled_for_version( _PYTHON_SNOWPARK_USE_CTE_OPTIMIZATION_VERSION ) + self._cte_error_count: int = 0 self._use_logical_type_for_create_df: bool = ( self._conn._get_client_side_session_parameter( _PYTHON_SNOWPARK_USE_LOGICAL_TYPE_FOR_CREATE_DATAFRAME_STRING, True diff --git a/tests/integ/test_cte.py b/tests/integ/test_cte.py index 1adcbe3558..c10c1ef8a0 100644 --- a/tests/integ/test_cte.py +++ b/tests/integ/test_cte.py @@ -4,14 +4,16 @@ import re import tracemalloc +from contextlib import contextmanager from unittest import mock import pytest +from snowflake.connector.errors import ProgrammingError from snowflake.connector.options import installed_pandas from snowflake.snowpark import Window from snowflake.snowpark._internal.analyzer import analyzer -from snowflake.snowpark._internal.analyzer.snowflake_plan import PlanQueryType +from snowflake.snowpark._internal.analyzer.snowflake_plan import PlanQueryType, Query from snowflake.snowpark._internal.utils import ( TEMP_OBJECT_NAME_PREFIX, TempObjectType, @@ -34,6 +36,7 @@ TimestampType, ) import snowflake.snowpark.context as context +from snowflake.snowpark.exceptions import SnowparkSQLException from tests.integ.scala.test_dataframe_reader_suite import get_reader from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker from tests.utils import IS_IN_STORED_PROC_LOCALFS, TestFiles, Utils @@ -1352,3 +1355,499 @@ def access_queries_and_schema(df): df_union = create_cte_dataframe() with 
SqlCounter(query_count=0, describe_count=expected_describe_count): access_queries_and_schema(df_union) + + +# --------------------------------------------------------------------------- +# CTE retry integration tests +# --------------------------------------------------------------------------- + + +@contextmanager +def _patch_run_query_fail_on_cte( + session, + error_msg="Simulated CTE error", + failed_sfqid="sim_cte_qid_001", + fail_all=False, +): + """Wrap ``run_query`` so the first call whose SQL contains a Snowpark CTE + identifier raises ``ProgrammingError``. All other calls (setup DDLs, + retry queries) pass through to the real implementation. + + When *fail_all* is True every call raises, which simulates the case where + the retry also fails. + """ + original_run_query = session._conn.run_query + cte_failed = [False] + + def side_effect(sql, *args, **kwargs): + if fail_all: + err = ProgrammingError(error_msg) + err.sfqid = failed_sfqid + raise err + if not cte_failed[0] and isinstance(sql, str) and "SNOWPARK_TEMP_CTE_" in sql: + cte_failed[0] = True + err = ProgrammingError(error_msg) + err.sfqid = failed_sfqid + raise err + return original_run_query(sql, *args, **kwargs) + + with mock.patch.object(session._conn, "run_query", side_effect=side_effect): + yield + + +def _build_cte_plan(session): + """Return a real plan whose execution queries use CTE syntax.""" + df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + df_cte = df.union_all(df) + plan = df_cte._plan + queries = plan.execution_queries[PlanQueryType.QUERIES] + assert any( + "SNOWPARK_TEMP_CTE_" in q.sql for q in queries + ), "Plan should produce CTE SQL" + return plan + + +def _build_non_cte_plan(session): + """Return a real plan whose execution queries do NOT use CTE syntax.""" + df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + plan = df.select("a")._plan + queries = plan.execution_queries[PlanQueryType.QUERIES] + assert not any( + "SNOWPARK_TEMP_CTE_" in q.sql 
for q in queries + ), "Plan should not produce CTE SQL" + return plan + + +def test_cte_retry_succeeds_increments_counter(session): + plan = _build_cte_plan(session) + + with mock.patch.object( + context, "_is_snowpark_connect_compatible_mode", True + ), mock.patch.object(session, "_cte_error_count", 0), _patch_run_query_fail_on_cte( + session + ): + session._conn.get_result_set(plan) + assert session._cte_error_count == 1 + assert session._cte_optimization_enabled is True + + +def test_cte_retry_fails_reraises_original_error(session): + plan = _build_cte_plan(session) + + with mock.patch.object( + context, "_is_snowpark_connect_compatible_mode", True + ), mock.patch.object(session, "_cte_error_count", 0), _patch_run_query_fail_on_cte( + session, fail_all=True + ): + with pytest.raises(SnowparkSQLException): + session._conn.get_result_set(plan) + assert session._cte_error_count == 0 + + +def test_cte_retry_no_retry_when_cte_disabled(session): + plan = _build_non_cte_plan(session) + + with mock.patch.object( + context, "_is_snowpark_connect_compatible_mode", True + ), mock.patch.object( + session, "_cte_optimization_enabled", False + ), mock.patch.object( + session, "_cte_error_count", 0 + ), mock.patch.object( + session._conn, "run_query", side_effect=ProgrammingError("fail") + ): + with pytest.raises(SnowparkSQLException): + session._conn.get_result_set(plan) + assert session._cte_error_count == 0 + + +def test_cte_retry_no_retry_when_not_compatible_mode(session): + plan = _build_cte_plan(session) + + with mock.patch.object( + context, "_is_snowpark_connect_compatible_mode", False + ), mock.patch.object(session, "_cte_error_count", 0), _patch_run_query_fail_on_cte( + session + ): + with pytest.raises(SnowparkSQLException): + session._conn.get_result_set(plan) + assert session._cte_error_count == 0 + + +def test_cte_retry_no_retry_on_async_queries(session): + plan = _build_cte_plan(session) + + with mock.patch.object( + context, 
"_is_snowpark_connect_compatible_mode", True + ), mock.patch.object(session, "_cte_error_count", 0), _patch_run_query_fail_on_cte( + session + ): + with pytest.raises(SnowparkSQLException): + session._conn.get_result_set(plan, block=False) + assert session._cte_error_count == 0 + + +def test_cte_retry_auto_disable_at_threshold(session): + plan = _build_cte_plan(session) + + with mock.patch.object( + context, "_is_snowpark_connect_compatible_mode", True + ), mock.patch.object(session, "_cte_error_count", 2), mock.patch.object( + session._conn._telemetry_client, + "send_cte_optimization_auto_disabled_telemetry", + ) as mock_disable_telemetry, _patch_run_query_fail_on_cte( + session + ): + session._conn.get_result_set(plan) + assert session._cte_error_count == 3 + assert session._cte_optimization_enabled is False + mock_disable_telemetry.assert_called_once_with( + session_id=session._conn.get_session_id(), + cte_error_count=3, + ) + + +def test_cte_retry_counter_below_threshold_does_not_disable(session): + plan = _build_cte_plan(session) + + with mock.patch.object( + context, "_is_snowpark_connect_compatible_mode", True + ), mock.patch.object(session, "_cte_error_count", 0), mock.patch.object( + session._conn._telemetry_client, + "send_cte_optimization_auto_disabled_telemetry", + ) as mock_disable_telemetry, _patch_run_query_fail_on_cte( + session + ): + session._conn.get_result_set(plan) + assert session._cte_error_count == 1 + assert session._cte_optimization_enabled is True + mock_disable_telemetry.assert_not_called() + + +def test_cte_retry_telemetry_contains_reproduction_data(session): + """Both cte_query_id and retry_query_id should be real Snowflake + query IDs. We achieve this by injecting a genuinely broken CTE + query (which Snowflake rejects with a real sfqid) while keeping + the unoptimized queries valid (so the retry succeeds with a real + sfqid). 
api_calls come from the real plan (populated by production + code via add_api_call / set_api_call_source).""" + plan = _build_cte_plan(session) + real_api_calls = plan.api_calls + assert ( + len(real_api_calls) > 0 + ), "Plan should have api_calls from DataFrame operations" + + good_queries = plan.get_execution_queries_without_cte() + + broken_sql = "WITH SNOWPARK_TEMP_CTE_bad AS (SELECT * FROM __nonexistent_table_xyz_99999) SELECT * FROM SNOWPARK_TEMP_CTE_bad" + broken_plan_queries = { + PlanQueryType.QUERIES: [Query(broken_sql)], + PlanQueryType.POST_ACTIONS: [], + } + + with mock.patch.object( + context, "_is_snowpark_connect_compatible_mode", True + ), mock.patch.object(session, "_cte_error_count", 0), mock.patch.object( + type(plan), + "execution_queries", + new_callable=lambda: property(lambda self: broken_plan_queries), + ), mock.patch.object( + plan, "get_execution_queries_without_cte", return_value=good_queries + ), mock.patch.object( + session._conn._telemetry_client, + "send_cte_execution_retry_telemetry", + ) as mock_retry_telemetry: + session._conn.get_result_set(plan) + + mock_retry_telemetry.assert_called_once() + call_kwargs = mock_retry_telemetry.call_args.kwargs + assert call_kwargs["api_calls"] == real_api_calls + assert call_kwargs["plan_uuid"] == plan.uuid + assert "error_message" in call_kwargs and len(call_kwargs["error_message"]) > 0 + + cte_qid = call_kwargs["cte_query_id"] + assert ( + isinstance(cte_qid, str) and len(cte_qid) > 0 + ), f"Expected a real Snowflake query ID for the failed CTE query, got {cte_qid!r}" + + retry_qid = call_kwargs["retry_query_id"] + assert ( + isinstance(retry_qid, str) and len(retry_qid) > 0 + ), f"Expected a real Snowflake query ID for the retry query, got {retry_qid!r}" + + +def test_cte_retry_threshold_zero_disables_auto_disable(session): + plan = _build_cte_plan(session) + + with mock.patch.object( + context, "_is_snowpark_connect_compatible_mode", True + ), mock.patch.object(context, 
"_cte_error_threshold", 0), mock.patch.object( + session, "_cte_error_count", 100 + ), mock.patch.object( + session._conn._telemetry_client, + "send_cte_optimization_auto_disabled_telemetry", + ) as mock_disable_telemetry, _patch_run_query_fail_on_cte( + session + ): + session._conn.get_result_set(plan) + assert session._cte_error_count == 100 # counter unchanged + assert session._cte_optimization_enabled is True + mock_disable_telemetry.assert_not_called() + + +def test_cte_retry_no_retry_when_plan_has_no_referenced_ctes(session): + plan = _build_non_cte_plan(session) + + with mock.patch.object( + context, "_is_snowpark_connect_compatible_mode", True + ), mock.patch.object(session, "_cte_error_count", 0), mock.patch.object( + session._conn, "run_query", side_effect=ProgrammingError("fail") + ): + with pytest.raises(SnowparkSQLException): + session._conn.get_result_set(plan) + assert session._cte_error_count == 0 + + +def _save_as_table_cte_retry_helper(session, table_type, large_data=False): + """Shared logic for save_as_table CTE retry tests. + + Calls ``df.write.save_as_table`` through the public API with a + CTE-optimized DataFrame. A mock makes the first CTE-bearing query + fail with ``ProgrammingError``; the retry (unoptimized SQL) succeeds. + Every SQL statement is recorded for inspection. + + When *large_data* is True, uses >= ARRAY_BIND_THRESHOLD cells so that + Snowpark generates temp table + file format setup/post-actions. + + Returns (executed_sqls, table_name). Caller is responsible for + dropping the table. 
+ """ + table_name = random_name_for_temp_object(TempObjectType.TABLE) + if large_data: + rows = [[i, i + 1] for i in range(256)] + else: + rows = [[1, 2], [3, 4]] + df = session.create_dataframe(rows, schema=["a", "b"]) + df_cte = df.union_all(df) + + original_run_query = session._conn.run_query + executed_sqls = [] + cte_failed = [False] + + def recording_side_effect(sql, *args, **kwargs): + executed_sqls.append(sql) + if not cte_failed[0] and isinstance(sql, str) and "SNOWPARK_TEMP_CTE_" in sql: + cte_failed[0] = True + err = ProgrammingError("Simulated CTE error") + err.sfqid = "sim_cte_qid_save" + raise err + return original_run_query(sql, *args, **kwargs) + + with mock.patch.object( + context, "_is_snowpark_connect_compatible_mode", True + ), mock.patch.object(session, "_cte_error_count", 0), mock.patch.object( + session._conn, "run_query", side_effect=recording_side_effect + ): + df_cte.write.save_as_table( + table_name, + mode="append", + table_type=table_type, + _emit_ast=False, + ) + assert session._cte_error_count == 1 + + return executed_sqls, table_name + + +def _assert_save_as_table_cte_retry( + session, + executed_sqls, + table_name, + expect_temporary, + expected_row_count=4, +): + """Shared assertions for save_as_table CTE retry tests. + + Inspects *executed_sqls* by content (not position) because the public + API may prepend extra queries like ``SHOW TABLES LIKE ...``. + """ + inserts = [s for s in executed_sqls if s.strip().upper().startswith("INSERT")] + creates = [ + s + for s in executed_sqls + if s.strip().upper().startswith("CREATE") and table_name.upper() in s.upper() + ] + + # Exactly one CTE INSERT (first attempt, fails) and one plain INSERT (retry). 
+ cte_inserts = [s for s in inserts if "SNOWPARK_TEMP_CTE_" in s] + plain_inserts = [s for s in inserts if "SNOWPARK_TEMP_CTE_" not in s] + assert len(cte_inserts) == 1, f"Expected 1 CTE INSERT, got {len(cte_inserts)}" + assert ( + len(plain_inserts) == 1 + ), f"Expected 1 plain INSERT (retry), got {len(plain_inserts)}" + + cte_insert = cte_inserts[0] + retry_insert = plain_inserts[0] + + # CTE is embedded mid-statement, not at the start. + assert not cte_insert.strip().startswith( + "WITH" + ), "CTE should be embedded inside INSERT, not at the start" + + # CTE identifiers from the first attempt must not appear in the retry. + cte_ids = set(re.findall(r"SNOWPARK_TEMP_CTE_\w+", cte_insert)) + assert len(cte_ids) > 0, "First attempt should have CTE identifiers" + for cte_id in cte_ids: + assert ( + cte_id not in retry_insert + ), f"Retry INSERT should not reference CTE identifier {cte_id}" + + # Two CREATE TABLE statements (one per attempt), both idempotent. + assert ( + len(creates) >= 2 + ), f"Expected at least 2 CREATE TABLE statements, got {len(creates)}" + for ddl in creates: + upper = ddl.upper() + assert ( + "IF NOT EXISTS" in upper + ), f"CREATE TABLE should use IF NOT EXISTS: {ddl[:120]}" + + if expect_temporary: + assert "TEMPORARY" in creates[0].upper() + else: + assert ( + "TEMPORARY" not in creates[0].upper() + ), "Should create a permanent table, not temporary" + + # The CTE INSERT must come before the retry INSERT. + cte_idx = executed_sqls.index(cte_insert) + retry_idx = executed_sqls.index(retry_insert) + assert cte_idx < retry_idx, "CTE INSERT must precede the retry INSERT" + + # Verify the table was populated by the retry. + result = session.sql(f"SELECT COUNT(*) AS cnt FROM {table_name}").collect() + assert result[0]["CNT"] == expected_row_count + + +def test_cte_retry_save_as_table_with_embedded_cte(session): + """save_as_table (APPEND, new temporary table) with CTE-optimized source. 
+ + The CTE is embedded inside INSERT (not at the start). Verifies the + retry detects it, retries without CTE, and populates the table. + """ + table_name = None + try: + executed_sqls, table_name = _save_as_table_cte_retry_helper( + session, table_type="temporary" + ) + _assert_save_as_table_cte_retry( + session, executed_sqls, table_name, expect_temporary=True + ) + finally: + if table_name: + Utils.drop_table(session, table_name) + + +def test_cte_retry_save_as_table_permanent_new_table(session): + """save_as_table (APPEND, new permanent table) with CTE-optimized source. + + Same as the temporary-table variant but targets a permanent table. + Snowpark auto-creates it with CREATE TABLE IF NOT EXISTS (no TEMPORARY). + """ + table_name = None + try: + executed_sqls, table_name = _save_as_table_cte_retry_helper( + session, table_type="" + ) + _assert_save_as_table_cte_retry( + session, executed_sqls, table_name, expect_temporary=False + ) + finally: + if table_name: + Utils.drop_table(session, table_name) + + +def test_cte_retry_save_as_table_large_data_post_actions(session): + """save_as_table with large data (>= ARRAY_BIND_THRESHOLD cells). + + Large datasets force Snowpark through the ``large_local_relation_plan`` + path which generates setup queries (CREATE OR REPLACE SCOPED TEMP TABLE, + batch INSERT) and a post-action (DROP TABLE IF EXISTS) for the staging + table. This test verifies that: + + 1. Post-action DROPs from the failed CTE attempt execute before the + retry begins (cleanup from ``_execute_queries``'s finally block). + 2. Post-action DROPs from the retry also execute. + 3. The staging table names differ between the two compilation passes + (thread-safe session generates fresh random names). + 4. The target table is correctly populated despite the retry. 
+ """ + table_name = None + try: + executed_sqls, table_name = _save_as_table_cte_retry_helper( + session, table_type="temporary", large_data=True + ) + + _assert_save_as_table_cte_retry( + session, + executed_sqls, + table_name, + expect_temporary=True, + expected_row_count=512, # 256 rows * 2 (union_all) + ) + + # -- Post-action assertions -- + + # Staging table DROPs (exclude the target table). + staging_drops = [ + s + for s in executed_sqls + if s.strip().upper().startswith("DROP") + and "TABLE" in s.upper() + and table_name.upper() not in s.upper() + ] + assert len(staging_drops) >= 2, ( + f"Expected at least 2 DROP TABLE for staging tables (one from " + f"the failed CTE attempt, one from the retry), got " + f"{len(staging_drops)}: {[d[:100] for d in staging_drops]}" + ) + + # The first staging DROP must appear before the retry's main INSERT. + plain_inserts = [ + s + for s in executed_sqls + if s.strip().upper().startswith("INSERT") and "SNOWPARK_TEMP_CTE_" not in s + ] + retry_insert_idx = executed_sqls.index(plain_inserts[0]) + first_staging_drop_idx = executed_sqls.index(staging_drops[0]) + assert first_staging_drop_idx < retry_insert_idx, ( + "Post-action DROP from the failed CTE attempt should run " + "before the retry INSERT" + ) + + # Extract staging table names from CREATE statements (exclude + # the target table) and verify they differ between attempts. 
+ staging_creates = [ + s + for s in executed_sqls + if s.strip().upper().startswith("CREATE") + and TEMP_OBJECT_NAME_PREFIX in s + and table_name.upper() not in s.upper() + ] + staging_names = set() + for ddl in staging_creates: + match = re.search(rf"({re.escape(TEMP_OBJECT_NAME_PREFIX)}\w+)", ddl) + if match: + staging_names.add(match.group(1)) + assert len(staging_creates) >= 2, ( + f"Expected at least 2 staging CREATE statements, got " + f"{len(staging_creates)}" + ) + assert len(staging_names) > 1, ( + f"Staging temp table names should differ between attempts " + f"(thread-safe session generates fresh names), but got: " + f"{staging_names}" + ) + finally: + if table_name: + Utils.drop_table(session, table_name)