From f21e1f5ec4396e32bd585c096aa4e556b5df9a61 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Thu, 26 Mar 2026 11:20:49 -0700 Subject: [PATCH 1/6] retry upon cte error --- .../_internal/analyzer/snowflake_plan.py | 8 + .../_internal/compiler/plan_compiler.py | 9 +- .../_internal/compiler/telemetry_constants.py | 5 + .../snowpark/_internal/server_connection.py | 130 +++++++++- src/snowflake/snowpark/_internal/telemetry.py | 42 ++++ src/snowflake/snowpark/context.py | 2 + src/snowflake/snowpark/session.py | 1 + tests/integ/test_cte.py | 222 ++++++++++++++++++ 8 files changed, 413 insertions(+), 6 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 067e3d5f06..185ed5a13f 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -524,6 +524,14 @@ def execution_queries(self) -> Dict["PlanQueryType", List["Query"]]: compiler = PlanCompiler(self) return compiler.compile() + def get_execution_queries_without_cte(self) -> Dict["PlanQueryType", List["Query"]]: + """Compile execution queries with CTE optimization skipped. 
+ Used for retry when CTE-optimized SQL fails at the server.""" + from snowflake.snowpark._internal.compiler.plan_compiler import PlanCompiler + + compiler = PlanCompiler(self) + return compiler.compile(skip_cte_optimization=True) + @property def children_plan_nodes(self) -> List[Union["Selectable", "SnowflakePlan"]]: """ diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py index e02a495554..e0d938f2eb 100644 --- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py +++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py @@ -78,7 +78,9 @@ def should_start_query_compilation(self) -> bool: ) ) - def compile(self) -> Dict[PlanQueryType, List[Query]]: + def compile( + self, *, skip_cte_optimization: bool = False + ) -> Dict[PlanQueryType, List[Query]]: # initialize the queries with the original queries without optimization final_plan = self._plan queries = { @@ -107,7 +109,10 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: # 3. 
apply each optimizations if needed # CTE optimization with measure_time() as cte_time: - if session.cte_optimization_enabled: + if ( + session.cte_optimization_enabled + and not skip_cte_optimization + ): repeated_subquery_eliminator = RepeatedSubqueryElimination( logical_plans, query_generator ) diff --git a/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py b/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py index 326828f29a..19b757ec8f 100644 --- a/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py +++ b/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py @@ -45,6 +45,11 @@ class CompilationStageTelemetryField(Enum): # keys for repeated subquery elimination CTE_NODE_CREATED = "cte_node_created" + # keys for CTE execution retry + CTE_ERROR_COUNT = "cte_error_count" + CTE_QUERY_ID = "cte_query_id" + RETRY_QUERY_ID = "retry_query_id" + # keys for large query breakdown BREAKDOWN_SUMMARY = "breakdown_summary" COMPLEXITY_SCORE_AFTER_CTE_OPTIMIZATION = "complexity_score_after_cte_optimization" diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index e0d4b63693..81ed5fdaa9 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -69,6 +69,7 @@ result_set_to_rows, unwrap_stage_location_single_quote, ) +from snowflake.snowpark import context from snowflake.snowpark.async_job import AsyncJob, _AsyncResultType from snowflake.snowpark.query_history import QueryListener, QueryRecord from snowflake.snowpark.row import Row @@ -696,12 +697,136 @@ def get_result_set( ], Union[List[ResultMetadata], List["ResultMetadataV2"]], ]: - action_id = plan.session._generate_new_action_id() + session = plan.session + action_id = session._generate_new_action_id() plan_queries = plan.execution_queries result, result_meta = None, None statement_params = kwargs.get("_statement_params", None) or {} 
statement_params["_PLAN_UUID"] = plan.uuid kwargs["_statement_params"] = statement_params + + should_retry_without_cte = ( + block # Snowpark Connect does not use async queries. Async queries retry tracked by SNOW-3293313 + and context._is_snowpark_connect_compatible_mode + and session.cte_optimization_enabled + ) + + try: + result, result_meta = self._execute_queries( + plan, + plan_queries, + action_id, + to_pandas=to_pandas, + to_iter=to_iter, + block=block, + data_type=data_type, + log_on_exception=log_on_exception, + case_sensitive=case_sensitive, + ignore_results=ignore_results, + to_arrow=to_arrow, + **kwargs, + ) + except ProgrammingError as cte_error: + if not should_retry_without_cte: + raise + + # CTE definitions are only prepended to the last query in the plan + # (see compile_queries_for_plan in compiler/utils.py), so checking + # the last query is sufficient to determine if CTE rewriting occurred. + main_queries = plan_queries[PlanQueryType.QUERIES] + if not main_queries or not main_queries[-1].sql.strip().startswith("WITH "): + raise + + unoptimized_plan_queries = plan.get_execution_queries_without_cte() + + logger.debug( + "CTE-optimized query failed with ProgrammingError: %s. " + "Retrying without CTE optimization.", + cte_error, + ) + + # Re-running _execute_queries is safe: setup queries preceding the + # main query are Snowpark-generated temp object DDL (e.g., CREATE + # TEMP FILE FORMAT IF NOT EXISTS) that are idempotent. CTE rewriting + # only affects the last query, so setup queries are identical across + # both plans. 
+ try: + retry_action_id = session._generate_new_action_id() + result, result_meta = self._execute_queries( + plan, + unoptimized_plan_queries, + retry_action_id, + to_pandas=to_pandas, + to_iter=to_iter, + block=block, + data_type=data_type, + log_on_exception=log_on_exception, + case_sensitive=case_sensitive, + ignore_results=ignore_results, + to_arrow=to_arrow, + **kwargs, + ) + except Exception: + raise cte_error + else: + with session._lock: + session._cte_error_count += 1 + current_count = session._cte_error_count + + cte_query_id = getattr(cte_error, "sfqid", None) + retry_query_id = ( + result.get("sfqid") if isinstance(result, dict) else None + ) + + self._telemetry_client.send_cte_execution_retry_telemetry( + session_id=self.get_session_id(), + plan_uuid=plan.uuid, + error_message=str(cte_error), + api_calls=plan.api_calls, + cte_query_id=cte_query_id, + retry_query_id=retry_query_id, + ) + + if ( + context._cte_error_threshold > 0 + and current_count >= context._cte_error_threshold + ): + logger.warning( + "CTE optimization has caused %d execution failures " + "(threshold=%d). 
Auto-disabling CTE optimization for " + "the remainder of this session to avoid further " + "performance impact and pipeline failures.", + current_count, + context._cte_error_threshold, + ) + with session._lock: + session._cte_optimization_enabled = False + self._telemetry_client.send_cte_optimization_auto_disabled_telemetry( + session_id=self.get_session_id(), + cte_error_count=current_count, + ) + + if result is None: + raise SnowparkClientExceptionMessages.SQL_LAST_QUERY_RETURN_RESULTSET() + + return result, result_meta + + def _execute_queries( + self, + plan: SnowflakePlan, + plan_queries, + action_id: int, + to_pandas: bool = False, + to_iter: bool = False, + block: bool = True, + data_type: _AsyncResultType = _AsyncResultType.ROW, + log_on_exception: bool = False, + case_sensitive: bool = True, + ignore_results: bool = False, + to_arrow: bool = False, + **kwargs, + ) -> Tuple[Optional[Any], Optional[Any]]: + result, result_meta = None, None try: main_queries = plan_queries[PlanQueryType.QUERIES] post_actions = plan_queries[PlanQueryType.POST_ACTIONS] @@ -810,9 +935,6 @@ def get_result_set( data=get_plan_telemetry_metrics(plan), ) - if result is None: - raise SnowparkClientExceptionMessages.SQL_LAST_QUERY_RETURN_RESULTSET() - return result, result_meta def get_result_and_metadata( diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py index 0eab8a0717..80f16d395e 100644 --- a/src/snowflake/snowpark/_internal/telemetry.py +++ b/src/snowflake/snowpark/_internal/telemetry.py @@ -63,6 +63,8 @@ class TelemetryField(Enum): TYPE_CURSOR_CREATED = "snowpark_cursor_created" TYPE_SQL_SIMPLIFIER_ENABLED = "snowpark_sql_simplifier_enabled" TYPE_CTE_OPTIMIZATION_ENABLED = "snowpark_cte_optimization_enabled" + TYPE_CTE_EXECUTION_RETRY = "snowpark_cte_execution_retry" + TYPE_CTE_OPTIMIZATION_AUTO_DISABLED = "snowpark_cte_optimization_auto_disabled" # telemetry for optimization that eliminates the extra cast expression 
generated for expressions TYPE_ELIMINATE_NUMERIC_SQL_VALUE_CAST_ENABLED = ( "snowpark_eliminate_numeric_sql_value_cast_enabled" @@ -660,6 +662,46 @@ def send_cte_optimization_telemetry(self, session_id: str) -> None: } self.send(message) + def send_cte_execution_retry_telemetry( + self, + session_id: str, + plan_uuid: str, + error_message: str, + api_calls: Optional[List[Dict[str, Any]]], + cte_query_id: Optional[str], + retry_query_id: Optional[str], + ) -> None: + message = { + **self._create_basic_telemetry_data( + TelemetryField.TYPE_CTE_EXECUTION_RETRY.value + ), + TelemetryField.KEY_DATA.value: { + TelemetryField.SESSION_ID.value: session_id, + CompilationStageTelemetryField.PLAN_UUID.value: plan_uuid, + CompilationStageTelemetryField.ERROR_MESSAGE.value: error_message, + TelemetryField.KEY_API_CALLS.value: api_calls, + CompilationStageTelemetryField.CTE_QUERY_ID.value: cte_query_id, + CompilationStageTelemetryField.RETRY_QUERY_ID.value: retry_query_id, + }, + } + self.send(message) + + def send_cte_optimization_auto_disabled_telemetry( + self, + session_id: str, + cte_error_count: int, + ) -> None: + message = { + **self._create_basic_telemetry_data( + TelemetryField.TYPE_CTE_OPTIMIZATION_AUTO_DISABLED.value + ), + TelemetryField.KEY_DATA.value: { + TelemetryField.SESSION_ID.value: session_id, + CompilationStageTelemetryField.CTE_ERROR_COUNT.value: cte_error_count, + }, + } + self.send(message) + def send_eliminate_numeric_sql_value_cast_telemetry( self, session_id: str, value: bool ) -> None: diff --git a/src/snowflake/snowpark/context.py b/src/snowflake/snowpark/context.py index c21ddec5eb..2a0ad2fdf3 100644 --- a/src/snowflake/snowpark/context.py +++ b/src/snowflake/snowpark/context.py @@ -32,6 +32,8 @@ # This is an internal-only global flag, used to determine whether the api code which will be executed is compatible with snowflake.snowpark_connect _is_snowpark_connect_compatible_mode = False +_cte_error_threshold = 3 # 0 to disable auto-cte-disable, 
otherwise the number of times CTE optimization can fail before it is automatically disabled for the remainder of the session. + # Following are internal-only global flags, used to enable development features. _enable_dataframe_trace_on_error = False _debug_eager_schema_validation = False diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index f80255d00d..7c8e1eae93 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -668,6 +668,7 @@ def __init__( self._cte_optimization_enabled: bool = self.is_feature_enabled_for_version( _PYTHON_SNOWPARK_USE_CTE_OPTIMIZATION_VERSION ) + self._cte_error_count: int = 0 self._use_logical_type_for_create_df: bool = ( self._conn._get_client_side_session_parameter( _PYTHON_SNOWPARK_USE_LOGICAL_TYPE_FOR_CREATE_DATAFRAME_STRING, True diff --git a/tests/integ/test_cte.py b/tests/integ/test_cte.py index 1adcbe3558..b693325926 100644 --- a/tests/integ/test_cte.py +++ b/tests/integ/test_cte.py @@ -4,10 +4,12 @@ import re import tracemalloc +from contextlib import contextmanager from unittest import mock import pytest +from snowflake.connector.errors import ProgrammingError from snowflake.connector.options import installed_pandas from snowflake.snowpark import Window from snowflake.snowpark._internal.analyzer import analyzer @@ -34,6 +36,7 @@ TimestampType, ) import snowflake.snowpark.context as context +from snowflake.snowpark.exceptions import SnowparkSQLException from tests.integ.scala.test_dataframe_reader_suite import get_reader from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker from tests.utils import IS_IN_STORED_PROC_LOCALFS, TestFiles, Utils @@ -1352,3 +1355,222 @@ def access_queries_and_schema(df): df_union = create_cte_dataframe() with SqlCounter(query_count=0, describe_count=expected_describe_count): access_queries_and_schema(df_union) + + +# --------------------------------------------------------------------------- +# CTE retry integration 
tests +# --------------------------------------------------------------------------- + + +@contextmanager +def _patch_run_query_fail_on_cte( + session, + error_msg="Simulated CTE error", + failed_sfqid="sim_cte_qid_001", + fail_all=False, +): + """Wrap ``run_query`` so the first call whose SQL starts with ``WITH `` + raises ``ProgrammingError``. All other calls (setup DDLs, retry queries) + pass through to the real implementation. + + When *fail_all* is True every call raises, which simulates the case where + the retry also fails. + """ + original_run_query = session._conn.run_query + cte_failed = [False] + + def side_effect(sql, *args, **kwargs): + if fail_all: + err = ProgrammingError(error_msg) + err.sfqid = failed_sfqid + raise err + if ( + not cte_failed[0] + and isinstance(sql, str) + and sql.strip().startswith("WITH ") + ): + cte_failed[0] = True + err = ProgrammingError(error_msg) + err.sfqid = failed_sfqid + raise err + return original_run_query(sql, *args, **kwargs) + + with mock.patch.object(session._conn, "run_query", side_effect=side_effect): + yield + + +def _build_cte_plan(session): + """Return a real plan whose last execution query uses CTE syntax.""" + df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + df_cte = df.union_all(df) + plan = df_cte._plan + queries = plan.execution_queries[PlanQueryType.QUERIES] + assert queries[-1].sql.strip().startswith("WITH "), "Plan should produce CTE SQL" + return plan + + +def _build_non_cte_plan(session): + """Return a real plan whose last execution query does NOT use CTE syntax.""" + df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + plan = df.select("a")._plan + queries = plan.execution_queries[PlanQueryType.QUERIES] + assert not queries[-1].sql.strip().startswith("WITH ") + return plan + + +class TestCteRetryIntegration: + """Integration tests for the CTE retry logic in + ``ServerConnection.get_result_set``. 
Each test uses a *real* Snowpark + session and real DataFrame operations to produce genuine CTE-optimized + plans. Only ``run_query`` is mocked to simulate server-side failure on + the CTE-bearing query.""" + + @pytest.fixture(autouse=True) + def cte_retry_setup(self, session): + original_count = session._cte_error_count + original_cte_enabled = session._cte_optimization_enabled + session._cte_error_count = 0 + session._cte_optimization_enabled = True + yield + session._cte_error_count = original_count + session._cte_optimization_enabled = original_cte_enabled + + def test_retry_succeeds_increments_counter(self, session): + plan = _build_cte_plan(session) + + with mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True): + with _patch_run_query_fail_on_cte(session): + session._conn.get_result_set(plan) + + assert session._cte_error_count == 1 + assert session._cte_optimization_enabled is True + + def test_retry_fails_reraises_original_error(self, session): + plan = _build_cte_plan(session) + + with mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True): + with _patch_run_query_fail_on_cte(session, fail_all=True): + with pytest.raises(SnowparkSQLException): + session._conn.get_result_set(plan) + + assert session._cte_error_count == 0 + + def test_no_retry_when_cte_disabled(self, session): + session._cte_optimization_enabled = False + plan = _build_non_cte_plan(session) + + with mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True): + with mock.patch.object( + session._conn, "run_query", side_effect=ProgrammingError("fail") + ): + with pytest.raises(SnowparkSQLException): + session._conn.get_result_set(plan) + + assert session._cte_error_count == 0 + + def test_no_retry_when_not_compatible_mode(self, session): + plan = _build_cte_plan(session) + + with mock.patch.object(context, "_is_snowpark_connect_compatible_mode", False): + with _patch_run_query_fail_on_cte(session): + with pytest.raises(SnowparkSQLException): 
+ session._conn.get_result_set(plan) + + assert session._cte_error_count == 0 + + def test_no_retry_on_async_queries(self, session): + plan = _build_cte_plan(session) + + with mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True): + with _patch_run_query_fail_on_cte(session): + with pytest.raises(SnowparkSQLException): + session._conn.get_result_set(plan, block=False) + + assert session._cte_error_count == 0 + + def test_auto_disable_cte_at_threshold(self, session): + session._cte_error_count = 2 + plan = _build_cte_plan(session) + + with mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True): + with _patch_run_query_fail_on_cte(session): + with mock.patch.object( + session._conn._telemetry_client, + "send_cte_optimization_auto_disabled_telemetry", + ) as mock_disable_telemetry: + session._conn.get_result_set(plan) + + assert session._cte_error_count == 3 + assert session._cte_optimization_enabled is False + mock_disable_telemetry.assert_called_once() + + def test_counter_below_threshold_does_not_disable(self, session): + session._cte_error_count = 0 + plan = _build_cte_plan(session) + + with mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True): + with _patch_run_query_fail_on_cte(session): + with mock.patch.object( + session._conn._telemetry_client, + "send_cte_optimization_auto_disabled_telemetry", + ) as mock_disable_telemetry: + session._conn.get_result_set(plan) + + assert session._cte_error_count == 1 + assert session._cte_optimization_enabled is True + mock_disable_telemetry.assert_not_called() + + def test_retry_telemetry_contains_reproduction_data(self, session): + plan = _build_cte_plan(session) + plan.api_calls = [ + {"name": "DataFrame.filter"}, + {"name": "DataFrame.collect"}, + ] + + with mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True): + with _patch_run_query_fail_on_cte( + session, + error_msg="CTE error: invalid identifier", + failed_sfqid="cte_qid_abc", + ): + with 
mock.patch.object( + session._conn._telemetry_client, + "send_cte_execution_retry_telemetry", + ) as mock_retry_telemetry: + session._conn.get_result_set(plan) + + mock_retry_telemetry.assert_called_once() + call_kwargs = mock_retry_telemetry.call_args.kwargs + assert call_kwargs["api_calls"] == plan.api_calls + assert call_kwargs["cte_query_id"] == "cte_qid_abc" + assert call_kwargs["retry_query_id"] is not None + assert "CTE error" in call_kwargs["error_message"] + assert call_kwargs["plan_uuid"] == plan.uuid + + def test_threshold_zero_disables_auto_disable(self, session): + session._cte_error_count = 100 + plan = _build_cte_plan(session) + + with mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True): + with mock.patch.object(context, "_cte_error_threshold", 0): + with _patch_run_query_fail_on_cte(session): + with mock.patch.object( + session._conn._telemetry_client, + "send_cte_optimization_auto_disabled_telemetry", + ) as mock_disable_telemetry: + session._conn.get_result_set(plan) + + assert session._cte_optimization_enabled is True + mock_disable_telemetry.assert_not_called() + + def test_no_retry_when_query_has_no_cte_prefix(self, session): + plan = _build_non_cte_plan(session) + + with mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True): + with mock.patch.object( + session._conn, "run_query", side_effect=ProgrammingError("fail") + ): + with pytest.raises(SnowparkSQLException): + session._conn.get_result_set(plan) + + assert session._cte_error_count == 0 From ab022cb84ce27635b6ddc25e3b6b6b1d0521f154 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Thu, 26 Mar 2026 13:27:37 -0700 Subject: [PATCH 2/6] update --- .../snowpark/_internal/server_connection.py | 40 ++- tests/integ/test_cte.py | 268 ++++++++++-------- 2 files changed, 169 insertions(+), 139 deletions(-) diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index 81ed5fdaa9..6acebd2448 100644 
--- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -769,10 +769,6 @@ def get_result_set( except Exception: raise cte_error else: - with session._lock: - session._cte_error_count += 1 - current_count = session._cte_error_count - cte_query_id = getattr(cte_error, "sfqid", None) retry_query_id = ( result.get("sfqid") if isinstance(result, dict) else None @@ -787,24 +783,26 @@ def get_result_set( retry_query_id=retry_query_id, ) - if ( - context._cte_error_threshold > 0 - and current_count >= context._cte_error_threshold - ): - logger.warning( - "CTE optimization has caused %d execution failures " - "(threshold=%d). Auto-disabling CTE optimization for " - "the remainder of this session to avoid further " - "performance impact and pipeline failures.", - current_count, - context._cte_error_threshold, - ) + if context._cte_error_threshold > 0: with session._lock: - session._cte_optimization_enabled = False - self._telemetry_client.send_cte_optimization_auto_disabled_telemetry( - session_id=self.get_session_id(), - cte_error_count=current_count, - ) + session._cte_error_count += 1 + current_count = session._cte_error_count + + if current_count >= context._cte_error_threshold: + logger.warning( + "CTE optimization has caused %d execution failures " + "(threshold=%d). 
Auto-disabling CTE optimization for " + "the remainder of this session to avoid further " + "performance impact.", + current_count, + context._cte_error_threshold, + ) + with session._lock: + session._cte_optimization_enabled = False + self._telemetry_client.send_cte_optimization_auto_disabled_telemetry( + session_id=self.get_session_id(), + cte_error_count=current_count, + ) if result is None: raise SnowparkClientExceptionMessages.SQL_LAST_QUERY_RETURN_RESULTSET() diff --git a/tests/integ/test_cte.py b/tests/integ/test_cte.py index b693325926..a4122d4477 100644 --- a/tests/integ/test_cte.py +++ b/tests/integ/test_cte.py @@ -13,7 +13,7 @@ from snowflake.connector.options import installed_pandas from snowflake.snowpark import Window from snowflake.snowpark._internal.analyzer import analyzer -from snowflake.snowpark._internal.analyzer.snowflake_plan import PlanQueryType +from snowflake.snowpark._internal.analyzer.snowflake_plan import PlanQueryType, Query from snowflake.snowpark._internal.utils import ( TEMP_OBJECT_NAME_PREFIX, TempObjectType, @@ -1418,159 +1418,191 @@ def _build_non_cte_plan(session): return plan -class TestCteRetryIntegration: - """Integration tests for the CTE retry logic in - ``ServerConnection.get_result_set``. Each test uses a *real* Snowpark - session and real DataFrame operations to produce genuine CTE-optimized - plans. 
Only ``run_query`` is mocked to simulate server-side failure on - the CTE-bearing query.""" - - @pytest.fixture(autouse=True) - def cte_retry_setup(self, session): - original_count = session._cte_error_count - original_cte_enabled = session._cte_optimization_enabled - session._cte_error_count = 0 - session._cte_optimization_enabled = True - yield - session._cte_error_count = original_count - session._cte_optimization_enabled = original_cte_enabled - - def test_retry_succeeds_increments_counter(self, session): - plan = _build_cte_plan(session) - - with mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True): - with _patch_run_query_fail_on_cte(session): - session._conn.get_result_set(plan) +def test_cte_retry_succeeds_increments_counter(session): + plan = _build_cte_plan(session) + with mock.patch.object( + context, "_is_snowpark_connect_compatible_mode", True + ), mock.patch.object(session, "_cte_error_count", 0), _patch_run_query_fail_on_cte( + session + ): + session._conn.get_result_set(plan) assert session._cte_error_count == 1 assert session._cte_optimization_enabled is True - def test_retry_fails_reraises_original_error(self, session): - plan = _build_cte_plan(session) - with mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True): - with _patch_run_query_fail_on_cte(session, fail_all=True): - with pytest.raises(SnowparkSQLException): - session._conn.get_result_set(plan) +def test_cte_retry_fails_reraises_original_error(session): + plan = _build_cte_plan(session) + with mock.patch.object( + context, "_is_snowpark_connect_compatible_mode", True + ), mock.patch.object(session, "_cte_error_count", 0), _patch_run_query_fail_on_cte( + session, fail_all=True + ): + with pytest.raises(SnowparkSQLException): + session._conn.get_result_set(plan) assert session._cte_error_count == 0 - def test_no_retry_when_cte_disabled(self, session): - session._cte_optimization_enabled = False - plan = _build_non_cte_plan(session) - with 
mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True): - with mock.patch.object( - session._conn, "run_query", side_effect=ProgrammingError("fail") - ): - with pytest.raises(SnowparkSQLException): - session._conn.get_result_set(plan) +def test_cte_retry_no_retry_when_cte_disabled(session): + plan = _build_non_cte_plan(session) + with mock.patch.object( + context, "_is_snowpark_connect_compatible_mode", True + ), mock.patch.object( + session, "_cte_optimization_enabled", False + ), mock.patch.object( + session, "_cte_error_count", 0 + ), mock.patch.object( + session._conn, "run_query", side_effect=ProgrammingError("fail") + ): + with pytest.raises(SnowparkSQLException): + session._conn.get_result_set(plan) assert session._cte_error_count == 0 - def test_no_retry_when_not_compatible_mode(self, session): - plan = _build_cte_plan(session) - with mock.patch.object(context, "_is_snowpark_connect_compatible_mode", False): - with _patch_run_query_fail_on_cte(session): - with pytest.raises(SnowparkSQLException): - session._conn.get_result_set(plan) +def test_cte_retry_no_retry_when_not_compatible_mode(session): + plan = _build_cte_plan(session) + with mock.patch.object( + context, "_is_snowpark_connect_compatible_mode", False + ), mock.patch.object(session, "_cte_error_count", 0), _patch_run_query_fail_on_cte( + session + ): + with pytest.raises(SnowparkSQLException): + session._conn.get_result_set(plan) assert session._cte_error_count == 0 - def test_no_retry_on_async_queries(self, session): - plan = _build_cte_plan(session) - with mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True): - with _patch_run_query_fail_on_cte(session): - with pytest.raises(SnowparkSQLException): - session._conn.get_result_set(plan, block=False) +def test_cte_retry_no_retry_on_async_queries(session): + plan = _build_cte_plan(session) + with mock.patch.object( + context, "_is_snowpark_connect_compatible_mode", True + ), mock.patch.object(session, 
"_cte_error_count", 0), _patch_run_query_fail_on_cte( + session + ): + with pytest.raises(SnowparkSQLException): + session._conn.get_result_set(plan, block=False) assert session._cte_error_count == 0 - def test_auto_disable_cte_at_threshold(self, session): - session._cte_error_count = 2 - plan = _build_cte_plan(session) - with mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True): - with _patch_run_query_fail_on_cte(session): - with mock.patch.object( - session._conn._telemetry_client, - "send_cte_optimization_auto_disabled_telemetry", - ) as mock_disable_telemetry: - session._conn.get_result_set(plan) +def test_cte_retry_auto_disable_at_threshold(session): + plan = _build_cte_plan(session) + with mock.patch.object( + context, "_is_snowpark_connect_compatible_mode", True + ), mock.patch.object(session, "_cte_error_count", 2), mock.patch.object( + session._conn._telemetry_client, + "send_cte_optimization_auto_disabled_telemetry", + ) as mock_disable_telemetry, _patch_run_query_fail_on_cte( + session + ): + session._conn.get_result_set(plan) assert session._cte_error_count == 3 assert session._cte_optimization_enabled is False - mock_disable_telemetry.assert_called_once() + mock_disable_telemetry.assert_called_once_with( + session_id=session._conn.get_session_id(), + cte_error_count=3, + ) - def test_counter_below_threshold_does_not_disable(self, session): - session._cte_error_count = 0 - plan = _build_cte_plan(session) - with mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True): - with _patch_run_query_fail_on_cte(session): - with mock.patch.object( - session._conn._telemetry_client, - "send_cte_optimization_auto_disabled_telemetry", - ) as mock_disable_telemetry: - session._conn.get_result_set(plan) +def test_cte_retry_counter_below_threshold_does_not_disable(session): + plan = _build_cte_plan(session) + with mock.patch.object( + context, "_is_snowpark_connect_compatible_mode", True + ), mock.patch.object(session, 
"_cte_error_count", 0), mock.patch.object( + session._conn._telemetry_client, + "send_cte_optimization_auto_disabled_telemetry", + ) as mock_disable_telemetry, _patch_run_query_fail_on_cte( + session + ): + session._conn.get_result_set(plan) assert session._cte_error_count == 1 assert session._cte_optimization_enabled is True mock_disable_telemetry.assert_not_called() - def test_retry_telemetry_contains_reproduction_data(self, session): - plan = _build_cte_plan(session) - plan.api_calls = [ - {"name": "DataFrame.filter"}, - {"name": "DataFrame.collect"}, - ] - with mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True): - with _patch_run_query_fail_on_cte( - session, - error_msg="CTE error: invalid identifier", - failed_sfqid="cte_qid_abc", - ): - with mock.patch.object( - session._conn._telemetry_client, - "send_cte_execution_retry_telemetry", - ) as mock_retry_telemetry: - session._conn.get_result_set(plan) - - mock_retry_telemetry.assert_called_once() - call_kwargs = mock_retry_telemetry.call_args.kwargs - assert call_kwargs["api_calls"] == plan.api_calls - assert call_kwargs["cte_query_id"] == "cte_qid_abc" - assert call_kwargs["retry_query_id"] is not None - assert "CTE error" in call_kwargs["error_message"] - assert call_kwargs["plan_uuid"] == plan.uuid - - def test_threshold_zero_disables_auto_disable(self, session): - session._cte_error_count = 100 - plan = _build_cte_plan(session) - - with mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True): - with mock.patch.object(context, "_cte_error_threshold", 0): - with _patch_run_query_fail_on_cte(session): - with mock.patch.object( - session._conn._telemetry_client, - "send_cte_optimization_auto_disabled_telemetry", - ) as mock_disable_telemetry: - session._conn.get_result_set(plan) +def test_cte_retry_telemetry_contains_reproduction_data(session): + """Both cte_query_id and retry_query_id should be real Snowflake + query IDs. 
We achieve this by injecting a genuinely broken CTE + query (which Snowflake rejects with a real sfqid) while keeping + the unoptimized queries valid (so the retry succeeds with a real + sfqid). api_calls come from the real plan (populated by production + code via add_api_call / set_api_call_source).""" + plan = _build_cte_plan(session) + real_api_calls = plan.api_calls + assert ( + len(real_api_calls) > 0 + ), "Plan should have api_calls from DataFrame operations" + + good_queries = plan.get_execution_queries_without_cte() + + broken_sql = "WITH __bad_cte AS (SELECT * FROM __nonexistent_table_xyz_99999) SELECT * FROM __bad_cte" + broken_plan_queries = { + PlanQueryType.QUERIES: [Query(broken_sql)], + PlanQueryType.POST_ACTIONS: [], + } + + with mock.patch.object( + context, "_is_snowpark_connect_compatible_mode", True + ), mock.patch.object(session, "_cte_error_count", 0), mock.patch.object( + type(plan), + "execution_queries", + new_callable=lambda: property(lambda self: broken_plan_queries), + ), mock.patch.object( + plan, "get_execution_queries_without_cte", return_value=good_queries + ), mock.patch.object( + session._conn._telemetry_client, + "send_cte_execution_retry_telemetry", + ) as mock_retry_telemetry: + session._conn.get_result_set(plan) + + mock_retry_telemetry.assert_called_once() + call_kwargs = mock_retry_telemetry.call_args.kwargs + assert call_kwargs["api_calls"] == real_api_calls + assert call_kwargs["plan_uuid"] == plan.uuid + assert "error_message" in call_kwargs and len(call_kwargs["error_message"]) > 0 + + cte_qid = call_kwargs["cte_query_id"] + assert ( + isinstance(cte_qid, str) and len(cte_qid) > 0 + ), f"Expected a real Snowflake query ID for the failed CTE query, got {cte_qid!r}" + + retry_qid = call_kwargs["retry_query_id"] + assert ( + isinstance(retry_qid, str) and len(retry_qid) > 0 + ), f"Expected a real Snowflake query ID for the retry query, got {retry_qid!r}" + + +def test_cte_retry_threshold_zero_disables_auto_disable(session): 
+ plan = _build_cte_plan(session) + with mock.patch.object( + context, "_is_snowpark_connect_compatible_mode", True + ), mock.patch.object(context, "_cte_error_threshold", 0), mock.patch.object( + session, "_cte_error_count", 100 + ), mock.patch.object( + session._conn._telemetry_client, + "send_cte_optimization_auto_disabled_telemetry", + ) as mock_disable_telemetry, _patch_run_query_fail_on_cte( + session + ): + session._conn.get_result_set(plan) + assert session._cte_error_count == 100 # counter unchanged assert session._cte_optimization_enabled is True mock_disable_telemetry.assert_not_called() - def test_no_retry_when_query_has_no_cte_prefix(self, session): - plan = _build_non_cte_plan(session) - with mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True): - with mock.patch.object( - session._conn, "run_query", side_effect=ProgrammingError("fail") - ): - with pytest.raises(SnowparkSQLException): - session._conn.get_result_set(plan) +def test_cte_retry_no_retry_when_query_has_no_cte_prefix(session): + plan = _build_non_cte_plan(session) + with mock.patch.object( + context, "_is_snowpark_connect_compatible_mode", True + ), mock.patch.object(session, "_cte_error_count", 0), mock.patch.object( + session._conn, "run_query", side_effect=ProgrammingError("fail") + ): + with pytest.raises(SnowparkSQLException): + session._conn.get_result_set(plan) assert session._cte_error_count == 0 From 1e36445c01eeba45cd5132f4ece88e8e94ddf42c Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Fri, 27 Mar 2026 16:53:11 -0700 Subject: [PATCH 3/6] update tests and fix bug --- .../snowpark/_internal/server_connection.py | 33 +-- tests/integ/test_cte.py | 273 +++++++++++++++++- 2 files changed, 275 insertions(+), 31 deletions(-) diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index 6acebd2448..889f7bed5c 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ 
b/src/snowflake/snowpark/_internal/server_connection.py @@ -730,11 +730,11 @@ def get_result_set( if not should_retry_without_cte: raise - # CTE definitions are only prepended to the last query in the plan - # (see compile_queries_for_plan in compiler/utils.py), so checking - # the last query is sufficient to determine if CTE rewriting occurred. + # Skip retry if CTE optimization didn't affect the SQL. + # Check for the Snowpark CTE prefix to cover both bare SELECTs + # and DML with embedded CTEs (e.g. INSERT INTO t WITH ...). main_queries = plan_queries[PlanQueryType.QUERIES] - if not main_queries or not main_queries[-1].sql.strip().startswith("WITH "): + if not any("SNOWPARK_TEMP_CTE_" in q.sql for q in main_queries): raise unoptimized_plan_queries = plan.get_execution_queries_without_cte() @@ -745,11 +745,10 @@ def get_result_set( cte_error, ) - # Re-running _execute_queries is safe: setup queries preceding the - # main query are Snowpark-generated temp object DDL (e.g., CREATE - # TEMP FILE FORMAT IF NOT EXISTS) that are idempotent. CTE rewriting - # only affects the last query, so setup queries are identical across - # both plans. + # Retry safety: CTE errors are compilation-time rejections (no + # partial DML side effects). Setup DDLs are idempotent, post- + # actions run in a finally block (cleanup on failure), and + # compile() generates fresh temp object names each call. try: retry_action_id = session._generate_new_action_id() result, result_meta = self._execute_queries( @@ -784,24 +783,24 @@ def get_result_set( ) if context._cte_error_threshold > 0: + cte_disabled = False with session._lock: session._cte_error_count += 1 - current_count = session._cte_error_count + if session._cte_error_count >= context._cte_error_threshold: + session._cte_optimization_enabled = False + cte_disabled = True - if current_count >= context._cte_error_threshold: + if cte_disabled: logger.warning( - "CTE optimization has caused %d execution failures " - "(threshold=%d). 
Auto-disabling CTE optimization for " + "CTE optimization has caused %d execution failures." + "Auto-disabling CTE optimization for " "the remainder of this session to avoid further " "performance impact.", - current_count, context._cte_error_threshold, ) - with session._lock: - session._cte_optimization_enabled = False self._telemetry_client.send_cte_optimization_auto_disabled_telemetry( session_id=self.get_session_id(), - cte_error_count=current_count, + cte_error_count=context._cte_error_threshold, ) if result is None: diff --git a/tests/integ/test_cte.py b/tests/integ/test_cte.py index a4122d4477..c10c1ef8a0 100644 --- a/tests/integ/test_cte.py +++ b/tests/integ/test_cte.py @@ -1369,9 +1369,9 @@ def _patch_run_query_fail_on_cte( failed_sfqid="sim_cte_qid_001", fail_all=False, ): - """Wrap ``run_query`` so the first call whose SQL starts with ``WITH `` - raises ``ProgrammingError``. All other calls (setup DDLs, retry queries) - pass through to the real implementation. + """Wrap ``run_query`` so the first call whose SQL contains a Snowpark CTE + identifier raises ``ProgrammingError``. All other calls (setup DDLs, + retry queries) pass through to the real implementation. When *fail_all* is True every call raises, which simulates the case where the retry also fails. 
@@ -1384,11 +1384,7 @@ def side_effect(sql, *args, **kwargs): err = ProgrammingError(error_msg) err.sfqid = failed_sfqid raise err - if ( - not cte_failed[0] - and isinstance(sql, str) - and sql.strip().startswith("WITH ") - ): + if not cte_failed[0] and isinstance(sql, str) and "SNOWPARK_TEMP_CTE_" in sql: cte_failed[0] = True err = ProgrammingError(error_msg) err.sfqid = failed_sfqid @@ -1400,21 +1396,25 @@ def side_effect(sql, *args, **kwargs): def _build_cte_plan(session): - """Return a real plan whose last execution query uses CTE syntax.""" + """Return a real plan whose execution queries use CTE syntax.""" df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) df_cte = df.union_all(df) plan = df_cte._plan queries = plan.execution_queries[PlanQueryType.QUERIES] - assert queries[-1].sql.strip().startswith("WITH "), "Plan should produce CTE SQL" + assert any( + "SNOWPARK_TEMP_CTE_" in q.sql for q in queries + ), "Plan should produce CTE SQL" return plan def _build_non_cte_plan(session): - """Return a real plan whose last execution query does NOT use CTE syntax.""" + """Return a real plan whose execution queries do NOT use CTE syntax.""" df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) plan = df.select("a")._plan queries = plan.execution_queries[PlanQueryType.QUERIES] - assert not queries[-1].sql.strip().startswith("WITH ") + assert not any( + "SNOWPARK_TEMP_CTE_" in q.sql for q in queries + ), "Plan should not produce CTE SQL" return plan @@ -1539,7 +1539,7 @@ def test_cte_retry_telemetry_contains_reproduction_data(session): good_queries = plan.get_execution_queries_without_cte() - broken_sql = "WITH __bad_cte AS (SELECT * FROM __nonexistent_table_xyz_99999) SELECT * FROM __bad_cte" + broken_sql = "WITH SNOWPARK_TEMP_CTE_bad AS (SELECT * FROM __nonexistent_table_xyz_99999) SELECT * FROM SNOWPARK_TEMP_CTE_bad" broken_plan_queries = { PlanQueryType.QUERIES: [Query(broken_sql)], PlanQueryType.POST_ACTIONS: [], @@ -1595,7 +1595,7 @@ def 
test_cte_retry_threshold_zero_disables_auto_disable(session): mock_disable_telemetry.assert_not_called() -def test_cte_retry_no_retry_when_query_has_no_cte_prefix(session): +def test_cte_retry_no_retry_when_plan_has_no_referenced_ctes(session): plan = _build_non_cte_plan(session) with mock.patch.object( @@ -1606,3 +1606,248 @@ def test_cte_retry_no_retry_when_query_has_no_cte_prefix(session): with pytest.raises(SnowparkSQLException): session._conn.get_result_set(plan) assert session._cte_error_count == 0 + + +def _save_as_table_cte_retry_helper(session, table_type, large_data=False): + """Shared logic for save_as_table CTE retry tests. + + Calls ``df.write.save_as_table`` through the public API with a + CTE-optimized DataFrame. A mock makes the first CTE-bearing query + fail with ``ProgrammingError``; the retry (unoptimized SQL) succeeds. + Every SQL statement is recorded for inspection. + + When *large_data* is True, uses >= ARRAY_BIND_THRESHOLD cells so that + Snowpark generates temp table + file format setup/post-actions. + + Returns (executed_sqls, table_name). Caller is responsible for + dropping the table. 
+ """ + table_name = random_name_for_temp_object(TempObjectType.TABLE) + if large_data: + rows = [[i, i + 1] for i in range(256)] + else: + rows = [[1, 2], [3, 4]] + df = session.create_dataframe(rows, schema=["a", "b"]) + df_cte = df.union_all(df) + + original_run_query = session._conn.run_query + executed_sqls = [] + cte_failed = [False] + + def recording_side_effect(sql, *args, **kwargs): + executed_sqls.append(sql) + if not cte_failed[0] and isinstance(sql, str) and "SNOWPARK_TEMP_CTE_" in sql: + cte_failed[0] = True + err = ProgrammingError("Simulated CTE error") + err.sfqid = "sim_cte_qid_save" + raise err + return original_run_query(sql, *args, **kwargs) + + with mock.patch.object( + context, "_is_snowpark_connect_compatible_mode", True + ), mock.patch.object(session, "_cte_error_count", 0), mock.patch.object( + session._conn, "run_query", side_effect=recording_side_effect + ): + df_cte.write.save_as_table( + table_name, + mode="append", + table_type=table_type, + _emit_ast=False, + ) + assert session._cte_error_count == 1 + + return executed_sqls, table_name + + +def _assert_save_as_table_cte_retry( + session, + executed_sqls, + table_name, + expect_temporary, + expected_row_count=4, +): + """Shared assertions for save_as_table CTE retry tests. + + Inspects *executed_sqls* by content (not position) because the public + API may prepend extra queries like ``SHOW TABLES LIKE ...``. + """ + inserts = [s for s in executed_sqls if s.strip().upper().startswith("INSERT")] + creates = [ + s + for s in executed_sqls + if s.strip().upper().startswith("CREATE") and table_name.upper() in s.upper() + ] + + # Exactly one CTE INSERT (first attempt, fails) and one plain INSERT (retry). 
+ cte_inserts = [s for s in inserts if "SNOWPARK_TEMP_CTE_" in s] + plain_inserts = [s for s in inserts if "SNOWPARK_TEMP_CTE_" not in s] + assert len(cte_inserts) == 1, f"Expected 1 CTE INSERT, got {len(cte_inserts)}" + assert ( + len(plain_inserts) == 1 + ), f"Expected 1 plain INSERT (retry), got {len(plain_inserts)}" + + cte_insert = cte_inserts[0] + retry_insert = plain_inserts[0] + + # CTE is embedded mid-statement, not at the start. + assert not cte_insert.strip().startswith( + "WITH" + ), "CTE should be embedded inside INSERT, not at the start" + + # CTE identifiers from the first attempt must not appear in the retry. + cte_ids = set(re.findall(r"SNOWPARK_TEMP_CTE_\w+", cte_insert)) + assert len(cte_ids) > 0, "First attempt should have CTE identifiers" + for cte_id in cte_ids: + assert ( + cte_id not in retry_insert + ), f"Retry INSERT should not reference CTE identifier {cte_id}" + + # Two CREATE TABLE statements (one per attempt), both idempotent. + assert ( + len(creates) >= 2 + ), f"Expected at least 2 CREATE TABLE statements, got {len(creates)}" + for ddl in creates: + upper = ddl.upper() + assert ( + "IF NOT EXISTS" in upper + ), f"CREATE TABLE should use IF NOT EXISTS: {ddl[:120]}" + + if expect_temporary: + assert "TEMPORARY" in creates[0].upper() + else: + assert ( + "TEMPORARY" not in creates[0].upper() + ), "Should create a permanent table, not temporary" + + # The CTE INSERT must come before the retry INSERT. + cte_idx = executed_sqls.index(cte_insert) + retry_idx = executed_sqls.index(retry_insert) + assert cte_idx < retry_idx, "CTE INSERT must precede the retry INSERT" + + # Verify the table was populated by the retry. + result = session.sql(f"SELECT COUNT(*) AS cnt FROM {table_name}").collect() + assert result[0]["CNT"] == expected_row_count + + +def test_cte_retry_save_as_table_with_embedded_cte(session): + """save_as_table (APPEND, new temporary table) with CTE-optimized source. 
+ + The CTE is embedded inside INSERT (not at the start). Verifies the + retry detects it, retries without CTE, and populates the table. + """ + table_name = None + try: + executed_sqls, table_name = _save_as_table_cte_retry_helper( + session, table_type="temporary" + ) + _assert_save_as_table_cte_retry( + session, executed_sqls, table_name, expect_temporary=True + ) + finally: + if table_name: + Utils.drop_table(session, table_name) + + +def test_cte_retry_save_as_table_permanent_new_table(session): + """save_as_table (APPEND, new permanent table) with CTE-optimized source. + + Same as the temporary-table variant but targets a permanent table. + Snowpark auto-creates it with CREATE TABLE IF NOT EXISTS (no TEMPORARY). + """ + table_name = None + try: + executed_sqls, table_name = _save_as_table_cte_retry_helper( + session, table_type="" + ) + _assert_save_as_table_cte_retry( + session, executed_sqls, table_name, expect_temporary=False + ) + finally: + if table_name: + Utils.drop_table(session, table_name) + + +def test_cte_retry_save_as_table_large_data_post_actions(session): + """save_as_table with large data (>= ARRAY_BIND_THRESHOLD cells). + + Large datasets force Snowpark through the ``large_local_relation_plan`` + path which generates setup queries (CREATE OR REPLACE SCOPED TEMP TABLE, + batch INSERT) and a post-action (DROP TABLE IF EXISTS) for the staging + table. This test verifies that: + + 1. Post-action DROPs from the failed CTE attempt execute before the + retry begins (cleanup from ``_execute_queries``'s finally block). + 2. Post-action DROPs from the retry also execute. + 3. The staging table names differ between the two compilation passes + (thread-safe session generates fresh random names). + 4. The target table is correctly populated despite the retry. 
+ """ + table_name = None + try: + executed_sqls, table_name = _save_as_table_cte_retry_helper( + session, table_type="temporary", large_data=True + ) + + _assert_save_as_table_cte_retry( + session, + executed_sqls, + table_name, + expect_temporary=True, + expected_row_count=512, # 256 rows * 2 (union_all) + ) + + # -- Post-action assertions -- + + # Staging table DROPs (exclude the target table). + staging_drops = [ + s + for s in executed_sqls + if s.strip().upper().startswith("DROP") + and "TABLE" in s.upper() + and table_name.upper() not in s.upper() + ] + assert len(staging_drops) >= 2, ( + f"Expected at least 2 DROP TABLE for staging tables (one from " + f"the failed CTE attempt, one from the retry), got " + f"{len(staging_drops)}: {[d[:100] for d in staging_drops]}" + ) + + # The first staging DROP must appear before the retry's main INSERT. + plain_inserts = [ + s + for s in executed_sqls + if s.strip().upper().startswith("INSERT") and "SNOWPARK_TEMP_CTE_" not in s + ] + retry_insert_idx = executed_sqls.index(plain_inserts[0]) + first_staging_drop_idx = executed_sqls.index(staging_drops[0]) + assert first_staging_drop_idx < retry_insert_idx, ( + "Post-action DROP from the failed CTE attempt should run " + "before the retry INSERT" + ) + + # Extract staging table names from CREATE statements (exclude + # the target table) and verify they differ between attempts. 
+ staging_creates = [ + s + for s in executed_sqls + if s.strip().upper().startswith("CREATE") + and TEMP_OBJECT_NAME_PREFIX in s + and table_name.upper() not in s.upper() + ] + staging_names = set() + for ddl in staging_creates: + match = re.search(rf"({re.escape(TEMP_OBJECT_NAME_PREFIX)}\w+)", ddl) + if match: + staging_names.add(match.group(1)) + assert len(staging_creates) >= 2, ( + f"Expected at least 2 staging CREATE statements, got " + f"{len(staging_creates)}" + ) + assert len(staging_names) > 1, ( + f"Staging temp table names should differ between attempts " + f"(thread-safe session generates fresh names), but got: " + f"{staging_names}" + ) + finally: + if table_name: + Utils.drop_table(session, table_name) From 05cb04896cde23c0c64ede405ffd7aec7ce929ae Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Fri, 27 Mar 2026 17:03:26 -0700 Subject: [PATCH 4/6] Update src/snowflake/snowpark/_internal/server_connection.py Co-authored-by: graphite-app[bot] <96075541+graphite-app[bot]@users.noreply.github.com> --- src/snowflake/snowpark/_internal/server_connection.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index 889f7bed5c..f8d2dc2c72 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -791,12 +791,14 @@ def get_result_set( cte_disabled = True if cte_disabled: + logger.warning( - "CTE optimization has caused %d execution failures." + "CTE optimization has caused %d execution failures. 
" "Auto-disabling CTE optimization for " "the remainder of this session to avoid further " "performance impact.", context._cte_error_threshold, + ) self._telemetry_client.send_cte_optimization_auto_disabled_telemetry( session_id=self.get_session_id(), From a8b238b811f329b904c6fb5a3635086bd049a73c Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Fri, 27 Mar 2026 17:05:43 -0700 Subject: [PATCH 5/6] Update src/snowflake/snowpark/_internal/server_connection.py Co-authored-by: graphite-app[bot] <96075541+graphite-app[bot]@users.noreply.github.com> --- .../snowpark/_internal/server_connection.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index f8d2dc2c72..fb0dcfb675 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -765,8 +765,19 @@ def get_result_set( to_arrow=to_arrow, **kwargs, ) - except Exception: + + except ProgrammingError: raise cte_error + except Exception as retry_error: + # Log both errors for debugging + logger.error( + "Retry without CTE optimization failed with different error: %s. 
" + "Original CTE error: %s", + retry_error, + cte_error, + ) + raise # Raise the actual retry error, not the original + else: cte_query_id = getattr(cte_error, "sfqid", None) retry_query_id = ( From 1d847d054708f24f23ba25cd56c1d60fca4e00e6 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Fri, 27 Mar 2026 17:23:25 -0700 Subject: [PATCH 6/6] lint --- src/snowflake/snowpark/_internal/server_connection.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index fb0dcfb675..3422854f7f 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -809,7 +809,6 @@ def get_result_set( "the remainder of this session to avoid further " "performance impact.", context._cte_error_threshold, - ) self._telemetry_client.send_cte_optimization_auto_disabled_telemetry( session_id=self.get_session_id(),