diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py index e02a495554..5cd20c5510 100644 --- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py +++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py @@ -4,7 +4,7 @@ import copy import logging -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( get_complexity_score, @@ -53,7 +53,9 @@ class PlanCompiler: def __init__(self, plan: SnowflakePlan) -> None: self._plan = plan - def should_start_query_compilation(self) -> bool: + def should_start_query_compilation( + self, cte_enabled: Optional[bool] = None + ) -> bool: """ Whether optimization should be applied to the plan or not. Optimization can be applied if @@ -61,6 +63,12 @@ def should_start_query_compilation(self) -> bool: 2) the query compilation stage is enabled 3) optimizations are enabled in the current session, such as cte_optimization_enabled + Parameters + ---------- + cte_enabled : Optional[bool] + When provided, overrides ``session.cte_optimization_enabled`` for this + check only. Pass ``False`` to test whether compilation should proceed + with CTE optimization explicitly disabled (e.g. retry path). Returns ------- @@ -68,17 +76,21 @@ def should_start_query_compilation(self) -> bool: """ current_session = self._plan.session + effective_cte = ( + cte_enabled + if cte_enabled is not None + else current_session.cte_optimization_enabled + ) return ( not isinstance(current_session._conn, MockServerConnection) and (self._plan.source_plan is not None) and current_session._query_compilation_stage_enabled - and ( - current_session.cte_optimization_enabled - or current_session.large_query_breakdown_enabled - ) + and (effective_cte or current_session.large_query_breakdown_enabled) ) - def compile(self) -> Dict[PlanQueryType, List[Query]]: + def compile( + self, cte_enabled: Optional[bool] = None + ) -> Dict[PlanQueryType, List[Query]]: # initialize the queries with the original queries without optimization final_plan = self._plan queries = { @@ -86,8 +98,13 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: PlanQueryType.POST_ACTIONS: final_plan.post_actions, } - if self.should_start_query_compilation(): + if self.should_start_query_compilation(cte_enabled=cte_enabled): session = self._plan.session + effective_cte = ( + cte_enabled + if cte_enabled is not None + else session.cte_optimization_enabled + ) try: with measure_time() as total_time: # preparation for compilation @@ -107,7 +124,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: # 3. apply each optimizations if needed # CTE optimization with measure_time() as cte_time: - if session.cte_optimization_enabled: + if effective_cte: repeated_subquery_eliminator = RepeatedSubqueryElimination( logical_plans, query_generator ) @@ -151,7 +168,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: # log telemetry data summary_value = { - TelemetryField.CTE_OPTIMIZATION_ENABLED.value: session.cte_optimization_enabled, + TelemetryField.CTE_OPTIMIZATION_ENABLED.value: effective_cte, TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: session.large_query_breakdown_enabled, CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: session.large_query_breakdown_complexity_bounds, CompilationStageTelemetryField.TIME_TAKEN_FOR_COMPILATION.value: total_time(), diff --git a/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py b/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py index 326828f29a..f1f6c74bc6 100644 --- a/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py +++ b/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py @@ -45,6 +45,16 @@ class CompilationStageTelemetryField(Enum): # keys for repeated subquery elimination CTE_NODE_CREATED = "cte_node_created" + # types / categories for CTE execution fallback + TYPE_CTE_EXECUTION_FALLBACK = "snowpark_cte_execution_fallback" + CAT_CTE_EXECUTION_FALLBACK = "cte_execution_fallback" + CAT_CTE_AUTO_DISABLED = "cte_auto_disabled" + + # keys for CTE execution fallback events + CTE_FALLBACK_SFQID = "sfqid" + CTE_FALLBACK_RETRY_SUCCEEDED = "retry_succeeded" + CTE_FALLBACK_COUNT = "fallback_count" + # keys for large query breakdown BREAKDOWN_SUMMARY = "breakdown_summary" COMPLEXITY_SCORE_AFTER_CTE_OPTIMIZATION = "complexity_score_after_cte_optimization" diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index e0d4b63693..1d1ed6cac8 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -48,6 +48,7 @@ SnowflakePlan, ) from snowflake.snowpark._internal.ast.utils import DATAFRAME_AST_PARAMETER +from snowflake.snowpark._internal.compiler.plan_compiler import PlanCompiler from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages from snowflake.snowpark._internal.telemetry import ( TelemetryClient, @@ -81,6 +82,23 @@ logger = getLogger(__name__) +# Number of CTE execution fallbacks after which CTE optimization is permanently +# disabled for the session. Kept as a private module variable so it can be +# adjusted in one place without touching session state. +_CTE_FALLBACK_AUTO_DISABLE_THRESHOLD = 3 + + +def _should_retry_cte_error(err: ProgrammingError) -> bool: + """Return True if *err* warrants re-executing the query with CTE optimization + disabled. + + Currently accepts any ``ProgrammingError``. The ``err.errno`` and + ``err.sqlstate`` attributes are intentionally available here for targeted + filtering once specific error codes are identified from production data. + """ + return True + + # parameters needed for usage tracking PARAM_APPLICATION = "application" PARAM_INTERNAL_APPLICATION_NAME = "internal_application_name" @@ -756,39 +774,110 @@ def get_result_set( else: dataframe_ast = None - for i, query in enumerate(main_queries): - if isinstance(query, BatchInsertQuery): - self.run_batch_insert(query.sql, query.rows, **kwargs) - else: - is_last = i == len(main_queries) - 1 and not block - final_query = query.sql - for holder, id_ in placeholders.items(): - final_query = final_query.replace(holder, id_) - if i == len(main_queries) - 1 and dataframe_ast: - kwargs[DATAFRAME_AST_PARAMETER] = dataframe_ast - is_final_query = i == len(main_queries) - 1 + # CTE retry is eligible when the sync path is used, CTE was + # applied to this plan, and there is exactly one main query + # (guarantees no partial side-effects that would make a full + # restart unsafe). Multi-query retry support is a TODO. + cte_retry_eligible = ( + block + and plan.session.cte_optimization_enabled + and len(main_queries) == 1 + ) + try: + for i, query in enumerate(main_queries): + if isinstance(query, BatchInsertQuery): + self.run_batch_insert(query.sql, query.rows, **kwargs) + else: + is_last = i == len(main_queries) - 1 and not block + final_query = query.sql + for holder, id_ in placeholders.items(): + final_query = final_query.replace(holder, id_) + if i == len(main_queries) - 1 and dataframe_ast: + kwargs[DATAFRAME_AST_PARAMETER] = dataframe_ast + is_final_query = i == len(main_queries) - 1 + result = self.run_query( + final_query, + to_pandas, + to_iter and is_final_query, + is_ddl_on_temp_object=query.is_ddl_on_temp_object, + block=not is_last, + data_type=data_type, + async_job_plan=plan, + log_on_exception=log_on_exception, + case_sensitive=case_sensitive, + params=query.params, + ignore_results=ignore_results, + async_post_actions=post_actions, + to_arrow=to_arrow and is_final_query, + **kwargs, + ) + placeholders[query.query_id_place_holder] = ( + result["sfqid"] if not is_last else result.query_id + ) + result_meta = get_new_description(self._cursor) + if action_id < plan.session._last_canceled_id: + raise SnowparkClientExceptionMessages.SERVER_QUERY_IS_CANCELLED() + except ProgrammingError as exec_err: + if not (cte_retry_eligible and _should_retry_cte_error(exec_err)): + raise + + # --- CTE execution fallback --- + # Recompile the plan with CTE optimization disabled and + # re-execute. We never mutate session state here; the + # cte_enabled override is local to this compile() call. + sfqid = getattr(exec_err, "sfqid", None) + retry_succeeded = False + try: + retry_plan_queries = PlanCompiler(plan).compile( + cte_enabled=False + ) + retry_main = retry_plan_queries[PlanQueryType.QUERIES] + retry_query = retry_main[0] result = self.run_query( - final_query, + retry_query.sql, to_pandas, - to_iter and is_final_query, - is_ddl_on_temp_object=query.is_ddl_on_temp_object, - block=not is_last, + to_iter, + is_ddl_on_temp_object=retry_query.is_ddl_on_temp_object, + block=True, data_type=data_type, async_job_plan=plan, log_on_exception=log_on_exception, case_sensitive=case_sensitive, - params=query.params, + params=retry_query.params, ignore_results=ignore_results, - async_post_actions=post_actions, - to_arrow=to_arrow and is_final_query, + to_arrow=to_arrow, **kwargs, ) - placeholders[query.query_id_place_holder] = ( - result["sfqid"] if not is_last else result.query_id - ) result_meta = get_new_description(self._cursor) - if action_id < plan.session._last_canceled_id: - raise SnowparkClientExceptionMessages.SERVER_QUERY_IS_CANCELLED() + # Switch post_actions to the retry compilation so the + # finally block cleans up the correct temp objects. + post_actions = retry_plan_queries[PlanQueryType.POST_ACTIONS] + retry_succeeded = True + except Exception: + pass # fall through to telemetry then re-raise + + session = plan.session + session._cte_optimization_fallback_count += 1 + fallback_count = session._cte_optimization_fallback_count + self._telemetry_client.send_cte_execution_fallback_telemetry( + session_id=session.session_id, + plan_uuid=plan.uuid, + sfqid=sfqid, + error_type=type(exec_err).__name__, + error_message=str(exec_err), + api_calls=plan.api_calls, + retry_succeeded=retry_succeeded, + fallback_count=fallback_count, + ) + if fallback_count >= _CTE_FALLBACK_AUTO_DISABLE_THRESHOLD: + session._cte_optimization_enabled = False + self._telemetry_client.send_cte_auto_disabled_telemetry( + session_id=session.session_id, + fallback_count=fallback_count, + ) + + if not retry_succeeded: + raise exec_err finally: # delete created tmp object if block: diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py index 0eab8a0717..e339acfbbd 100644 --- a/src/snowflake/snowpark/_internal/telemetry.py +++ b/src/snowflake/snowpark/_internal/telemetry.py @@ -738,6 +738,52 @@ def send_query_compilation_stage_failed_telemetry( } self.send(message) + def send_cte_execution_fallback_telemetry( + self, + session_id: int, + plan_uuid: str, + sfqid: Optional[str], + error_type: str, + error_message: str, + api_calls: List[Dict], + retry_succeeded: bool, + fallback_count: int, + ) -> None: + message = { + **self._create_basic_telemetry_data( + CompilationStageTelemetryField.TYPE_CTE_EXECUTION_FALLBACK.value + ), + TelemetryField.KEY_DATA.value: { + TelemetryField.SESSION_ID.value: session_id, + TelemetryField.KEY_CATEGORY.value: CompilationStageTelemetryField.CAT_CTE_EXECUTION_FALLBACK.value, + CompilationStageTelemetryField.PLAN_UUID.value: plan_uuid, + CompilationStageTelemetryField.CTE_FALLBACK_SFQID.value: sfqid, + CompilationStageTelemetryField.ERROR_TYPE.value: error_type, + CompilationStageTelemetryField.ERROR_MESSAGE.value: error_message, + TelemetryField.KEY_API_CALLS.value: api_calls, + CompilationStageTelemetryField.CTE_FALLBACK_RETRY_SUCCEEDED.value: retry_succeeded, + CompilationStageTelemetryField.CTE_FALLBACK_COUNT.value: fallback_count, + }, + } + self.send(message) + + def send_cte_auto_disabled_telemetry( + self, + session_id: int, + fallback_count: int, + ) -> None: + message = { + **self._create_basic_telemetry_data( + CompilationStageTelemetryField.TYPE_CTE_EXECUTION_FALLBACK.value + ), + TelemetryField.KEY_DATA.value: { + TelemetryField.SESSION_ID.value: session_id, + TelemetryField.KEY_CATEGORY.value: CompilationStageTelemetryField.CAT_CTE_AUTO_DISABLED.value, + CompilationStageTelemetryField.CTE_FALLBACK_COUNT.value: fallback_count, + }, + } + self.send(message) + def send_plan_metrics_telemetry( self, session_id: int, data: Dict[str, Any] ) -> None: diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index ad583da1a7..c202b50f0b 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -662,6 +662,9 @@ def __init__( self._cte_optimization_enabled: bool = self.is_feature_enabled_for_version( _PYTHON_SNOWPARK_USE_CTE_OPTIMIZATION_VERSION ) + # Counts how many times CTE execution has fallen back to a non-CTE retry. + # Used to auto-disable CTE optimization when the threshold is exceeded. + self._cte_optimization_fallback_count: int = 0 self._use_logical_type_for_create_df: bool = ( self._conn._get_client_side_session_parameter( _PYTHON_SNOWPARK_USE_LOGICAL_TYPE_FOR_CREATE_DATAFRAME_STRING, True diff --git a/tests/integ/test_cte.py b/tests/integ/test_cte.py index 1adcbe3558..7ebdb15928 100644 --- a/tests/integ/test_cte.py +++ b/tests/integ/test_cte.py @@ -8,8 +8,9 @@ import pytest +from snowflake.connector.errors import ProgrammingError from snowflake.connector.options import installed_pandas -from snowflake.snowpark import Window +from snowflake.snowpark import Row, Window from snowflake.snowpark._internal.analyzer import analyzer from snowflake.snowpark._internal.analyzer.snowflake_plan import PlanQueryType from snowflake.snowpark._internal.utils import ( @@ -1317,6 +1318,105 @@ def test_table_select_cte(session): ) +def test_cte_execution_fallback_retries_on_programming_error(session): + """When a ProgrammingError is raised during execution and CTE is enabled, + the query should be retried with CTE disabled. The result should be + identical to what a non-CTE run would produce.""" + df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + # Create a plan that triggers CTE optimization (duplicate subtree). + df_union = df.union_all(df) + + original_run_query = session._conn.run_query + call_count = [0] + + def run_query_first_call_fails(sql, *args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + err = ProgrammingError("Simulated CTE execution error") + err.sfqid = "simulated_sfqid" + raise err + return original_run_query(sql, *args, **kwargs) + + with mock.patch.object( + session._conn, "run_query", side_effect=run_query_first_call_fails + ): + # Should not raise: the retry with CTE off must succeed. + result = df_union.collect() + + assert result == [Row(A=1, B=2), Row(A=3, B=4), Row(A=1, B=2), Row(A=3, B=4)] + assert call_count[0] == 2 # first call failed, second (retry) succeeded + assert session._cte_optimization_fallback_count == 1 + + +def test_cte_execution_fallback_auto_disables_after_threshold(session): + """After _CTE_FALLBACK_AUTO_DISABLE_THRESHOLD consecutive fallbacks, CTE + optimization is permanently disabled for the session.""" + from snowflake.snowpark._internal.server_connection import ( + _CTE_FALLBACK_AUTO_DISABLE_THRESHOLD, + ) + + df = session.create_dataframe([[1]], schema=["a"]) + df_union = df.union_all(df) + + original_run_query = session._conn.run_query + # Track which call within each collect() is the first (the CTE attempt). + call_count_per_collect = [0] + + def always_fail_first_call(sql, *args, **kwargs): + call_count_per_collect[0] += 1 + if call_count_per_collect[0] == 1: + err = ProgrammingError("Simulated CTE error") + err.sfqid = "sfqid" + raise err + return original_run_query(sql, *args, **kwargs) + + # Trigger fallback _CTE_FALLBACK_AUTO_DISABLE_THRESHOLD times. + for _ in range(_CTE_FALLBACK_AUTO_DISABLE_THRESHOLD): + call_count_per_collect[0] = 0 + with mock.patch.object( + session._conn, "run_query", side_effect=always_fail_first_call + ): + df_union.collect() + + assert session._cte_optimization_enabled is False + assert ( + session._cte_optimization_fallback_count == _CTE_FALLBACK_AUTO_DISABLE_THRESHOLD + ) + + +def test_cte_execution_fallback_telemetry_sent(session): + """Verify that send_cte_execution_fallback_telemetry is called with the + correct fields when a fallback occurs.""" + df = session.create_dataframe([[1, 2]], schema=["a", "b"]) + df_union = df.union_all(df) + + original_run_query = session._conn.run_query + call_count = [0] + + def fail_once(sql, *args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + err = ProgrammingError("CTE error for telemetry test") + err.sfqid = "telemetry_sfqid" + raise err + return original_run_query(sql, *args, **kwargs) + + with mock.patch.object( + session._conn, "run_query", side_effect=fail_once + ), mock.patch.object( + session._conn._telemetry_client, + "send_cte_execution_fallback_telemetry", + ) as mock_telemetry: + df_union.collect() + + mock_telemetry.assert_called_once() + _, kwargs = mock_telemetry.call_args + assert kwargs["sfqid"] == "telemetry_sfqid" + assert kwargs["retry_succeeded"] is True + assert kwargs["fallback_count"] == 1 + assert kwargs["error_type"] == "ProgrammingError" + + @pytest.mark.parametrize( "reduce_describe_enabled,expected_describe_counts", [ diff --git a/tests/unit/test_cte_retry.py b/tests/unit/test_cte_retry.py new file mode 100644 index 0000000000..97788c69ba --- /dev/null +++ b/tests/unit/test_cte_retry.py @@ -0,0 +1,512 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +"""Unit tests for the CTE execution fallback (retry) mechanism. + +These tests are pure-unit: no Snowflake connection is required. All I/O and +compilation are mocked so the tests run without any external dependencies. +""" + +from unittest import mock +from unittest.mock import MagicMock, patch + +import pytest + +from snowflake.connector import SnowflakeConnection +from snowflake.connector.cursor import SnowflakeCursor +from snowflake.connector.errors import ProgrammingError +from snowflake.snowpark._internal.analyzer.snowflake_plan import PlanQueryType, Query +from snowflake.snowpark._internal.server_connection import ( + ServerConnection, + _CTE_FALLBACK_AUTO_DISABLE_THRESHOLD, + _should_retry_cte_error, +) +from snowflake.snowpark.session import _PYTHON_SNOWPARK_GENERATE_MULTILINE_QUERIES + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_query(sql="SELECT 1 AS v"): + q = MagicMock(spec=Query) + q.sql = sql + q.params = [] + q.query_id_place_holder = "__QUERY_ID_PLACEHOLDER__" + q.is_ddl_on_temp_object = False + q.temp_obj_name_placeholder = None + return q + + +def _make_plan_queries(query=None, post_actions=None): + return { + PlanQueryType.QUERIES: [query or _make_query()], + PlanQueryType.POST_ACTIONS: post_actions or [], + } + + +def _make_mock_session(cte_enabled=True, fallback_count=0): + session = MagicMock() + session.cte_optimization_enabled = cte_enabled + session._cte_optimization_enabled = cte_enabled + session._cte_optimization_fallback_count = fallback_count + session._last_canceled_id = -1 + session._generate_new_action_id.return_value = 0 + session.session_id = "fake_session_id" + session._collect_snowflake_plan_telemetry_at_critical_path = False + return session + + +def _make_mock_plan(session=None, cte_queries=None, post_actions=None): + session = session or _make_mock_session() + plan = MagicMock() + plan.session = session + plan.uuid = "fake_plan_uuid" + plan.api_calls = [] + plan.execution_queries = _make_plan_queries( + query=cte_queries[0] if cte_queries else None, + post_actions=post_actions, + ) + return plan + + +def _make_server_connection(): + fake_conn = mock.create_autospec(SnowflakeConnection) + fake_conn._conn = MagicMock() + fake_conn._telemetry = None + fake_conn._session_parameters = {_PYTHON_SNOWPARK_GENERATE_MULTILINE_QUERIES: True} + fake_conn.cursor.return_value = mock.create_autospec(SnowflakeCursor) + fake_conn.is_closed.return_value = False + sc = ServerConnection({}, fake_conn) + sc._telemetry_client = MagicMock() + return sc + + +def _run_query_success(sql="SELECT 1 AS v"): + """Return value for a successful blocking run_query call.""" + return {"sfqid": "fake_sfqid", "data": [[1]]} + + +# --------------------------------------------------------------------------- +# Tests: _should_retry_cte_error +# --------------------------------------------------------------------------- + + +def test_should_retry_cte_error_true_for_any_programming_error(): + assert _should_retry_cte_error(ProgrammingError("any error")) is True + + +def test_should_retry_cte_error_exposes_errno_and_sqlstate(): + """errno / sqlstate are accessible for future targeted filtering.""" + err = ProgrammingError("msg", errno=1234, sqlstate="42000") + assert err.errno == 1234 + assert err.sqlstate == "42000" + # Still returns True — no filtering yet. + assert _should_retry_cte_error(err) is True + + +# --------------------------------------------------------------------------- +# Tests: _CTE_FALLBACK_AUTO_DISABLE_THRESHOLD module constant +# --------------------------------------------------------------------------- + + +def test_auto_disable_threshold_is_positive_int(): + assert isinstance(_CTE_FALLBACK_AUTO_DISABLE_THRESHOLD, int) + assert _CTE_FALLBACK_AUTO_DISABLE_THRESHOLD > 0 + + +# --------------------------------------------------------------------------- +# Tests: get_result_set — retry eligibility guards +# --------------------------------------------------------------------------- + + +@patch( + "snowflake.snowpark._internal.server_connection.get_new_description", + return_value=[], +) +def test_no_retry_when_no_error(mock_desc): + """Happy path: successful execution never touches the retry path.""" + sc = _make_server_connection() + plan = _make_mock_plan() + success_result = _run_query_success() + + with patch.object(sc, "run_query", return_value=success_result): + result, _ = sc.get_result_set(plan, block=True) + + assert result == success_result + sc._telemetry_client.send_cte_execution_fallback_telemetry.assert_not_called() + sc._telemetry_client.send_cte_auto_disabled_telemetry.assert_not_called() + + +@patch( + "snowflake.snowpark._internal.server_connection.get_new_description", + return_value=[], +) +def test_no_retry_when_block_false(mock_desc): + """block=False (async) is not eligible for retry.""" + sc = _make_server_connection() + cte_query = _make_query("WITH cte AS (SELECT 1) SELECT * FROM cte") + session = _make_mock_session(cte_enabled=True) + plan = _make_mock_plan(session=session, cte_queries=[cte_query]) + + prog_err = ProgrammingError("CTE error") + prog_err.sfqid = "bad_sfqid" + + with patch.object(sc, "run_query", side_effect=prog_err): + with pytest.raises(ProgrammingError, match="CTE error"): + sc.get_result_set(plan, block=False) + + sc._telemetry_client.send_cte_execution_fallback_telemetry.assert_not_called() + + +@patch( + "snowflake.snowpark._internal.server_connection.get_new_description", + return_value=[], +) +def test_no_retry_when_cte_disabled(mock_desc): + """CTE optimization was off — no fallback attempt.""" + sc = _make_server_connection() + session = _make_mock_session(cte_enabled=False) + plan = _make_mock_plan(session=session) + + prog_err = ProgrammingError("some SQL error") + + with patch.object(sc, "run_query", side_effect=prog_err): + with pytest.raises(ProgrammingError): + sc.get_result_set(plan, block=True) + + sc._telemetry_client.send_cte_execution_fallback_telemetry.assert_not_called() + + +@patch( + "snowflake.snowpark._internal.server_connection.get_new_description", + return_value=[], +) +def test_no_retry_when_multiple_main_queries(mock_desc): + """Multi-query plans are not retried (partial side-effects risk).""" + sc = _make_server_connection() + session = _make_mock_session(cte_enabled=True) + + q1 = _make_query("CREATE TABLE t AS SELECT 1") + q2 = _make_query("SELECT * FROM t") + plan = MagicMock() + plan.session = session + plan.uuid = "uuid" + plan.api_calls = [] + plan.execution_queries = { + PlanQueryType.QUERIES: [q1, q2], + PlanQueryType.POST_ACTIONS: [], + } + + prog_err = ProgrammingError("CTE error") + + with patch.object(sc, "run_query", side_effect=prog_err): + with pytest.raises(ProgrammingError): + sc.get_result_set(plan, block=True) + + sc._telemetry_client.send_cte_execution_fallback_telemetry.assert_not_called() + + +# --------------------------------------------------------------------------- +# Tests: get_result_set — successful retry +# --------------------------------------------------------------------------- + + +@patch( + "snowflake.snowpark._internal.server_connection.get_new_description", + return_value=[], +) +def test_retry_succeeds_returns_non_cte_result(mock_desc): + """When the CTE query fails but the non-CTE retry succeeds, the retry + result is returned and the original error is suppressed.""" + sc = _make_server_connection() + session = _make_mock_session(cte_enabled=True) + cte_query = _make_query("WITH cte AS (SELECT 1) SELECT * FROM cte") + plan = _make_mock_plan(session=session, cte_queries=[cte_query]) + + prog_err = ProgrammingError("CTE error") + prog_err.sfqid = "failed_sfqid" + retry_result = {"sfqid": "retry_sfqid", "data": [[1]]} + non_cte_query = _make_query("SELECT 1") # no CTE + + retry_plan_queries = { + PlanQueryType.QUERIES: [non_cte_query], + PlanQueryType.POST_ACTIONS: [], + } + + with patch.object(sc, "run_query", side_effect=[prog_err, retry_result]), patch( + "snowflake.snowpark._internal.server_connection.PlanCompiler" + ) as MockCompiler: + MockCompiler.return_value.compile.return_value = retry_plan_queries + result, _ = sc.get_result_set(plan, block=True) + + assert result == retry_result + MockCompiler.assert_called_once_with(plan) + MockCompiler.return_value.compile.assert_called_once_with(cte_enabled=False) + + +@patch( + "snowflake.snowpark._internal.server_connection.get_new_description", + return_value=[], +) +def test_retry_increments_fallback_count(mock_desc): + sc = _make_server_connection() + session = _make_mock_session(cte_enabled=True, fallback_count=0) + plan = _make_mock_plan(session=session) + + prog_err = ProgrammingError("CTE error") + prog_err.sfqid = "sfqid1" + retry_result = _run_query_success() + + retry_plan_queries = _make_plan_queries() + + with patch.object(sc, "run_query", side_effect=[prog_err, retry_result]), patch( + "snowflake.snowpark._internal.server_connection.PlanCompiler" + ) as MockCompiler: + MockCompiler.return_value.compile.return_value = retry_plan_queries + sc.get_result_set(plan, block=True) + + assert session._cte_optimization_fallback_count == 1 + + +@patch( + "snowflake.snowpark._internal.server_connection.get_new_description", + return_value=[], +) +def test_retry_sends_fallback_telemetry_on_success(mock_desc): + sc = _make_server_connection() + session = _make_mock_session(cte_enabled=True, fallback_count=0) + plan = _make_mock_plan(session=session) + + prog_err = ProgrammingError("CTE error") + prog_err.sfqid = "bad_sfqid" + retry_result = _run_query_success() + retry_plan_queries = _make_plan_queries() + + with patch.object(sc, "run_query", side_effect=[prog_err, retry_result]), patch( + "snowflake.snowpark._internal.server_connection.PlanCompiler" + ) as MockCompiler: + MockCompiler.return_value.compile.return_value = retry_plan_queries + sc.get_result_set(plan, block=True) + + sc._telemetry_client.send_cte_execution_fallback_telemetry.assert_called_once_with( + session_id="fake_session_id", + plan_uuid="fake_plan_uuid", + sfqid="bad_sfqid", + error_type="ProgrammingError", + error_message=str(prog_err), + api_calls=[], + retry_succeeded=True, + fallback_count=1, + ) + + +# --------------------------------------------------------------------------- +# Tests: get_result_set — failed retry +# --------------------------------------------------------------------------- + + +@patch( + "snowflake.snowpark._internal.server_connection.get_new_description", + return_value=[], +) +def test_retry_fails_reraises_original_error(mock_desc): + """If the non-CTE retry also fails, the ORIGINAL error is re-raised (not + the retry error), preserving the original sfqid for debugging.""" + sc = _make_server_connection() + session = _make_mock_session(cte_enabled=True) + plan = _make_mock_plan(session=session) + + original_err = ProgrammingError("original CTE error") + original_err.sfqid = "original_sfqid" + retry_err = ProgrammingError("retry also failed") + + retry_plan_queries = _make_plan_queries() + + with patch.object(sc, "run_query", side_effect=[original_err, retry_err]), patch( + "snowflake.snowpark._internal.server_connection.PlanCompiler" + ) as MockCompiler: + MockCompiler.return_value.compile.return_value = retry_plan_queries + with pytest.raises(ProgrammingError) as exc_info: + sc.get_result_set(plan, block=True) + + assert exc_info.value is original_err + + +@patch( + "snowflake.snowpark._internal.server_connection.get_new_description", + return_value=[], +) +def test_retry_fails_sends_fallback_telemetry_with_retry_succeeded_false(mock_desc): + sc = _make_server_connection() + session = _make_mock_session(cte_enabled=True, fallback_count=0) + plan = _make_mock_plan(session=session) + + original_err = ProgrammingError("original error") + original_err.sfqid = "orig_sfqid" + retry_err = ProgrammingError("retry error") + retry_plan_queries = _make_plan_queries() + + with patch.object(sc, "run_query", side_effect=[original_err, retry_err]), patch( + "snowflake.snowpark._internal.server_connection.PlanCompiler" + ) as MockCompiler: + MockCompiler.return_value.compile.return_value = retry_plan_queries + with pytest.raises(ProgrammingError): + sc.get_result_set(plan, block=True) + + sc._telemetry_client.send_cte_execution_fallback_telemetry.assert_called_once() + _, kwargs = sc._telemetry_client.send_cte_execution_fallback_telemetry.call_args + assert kwargs["retry_succeeded"] is False + assert kwargs["sfqid"] == "orig_sfqid" + + +# --------------------------------------------------------------------------- +# Tests: auto-disable +# --------------------------------------------------------------------------- + + +@patch( + "snowflake.snowpark._internal.server_connection.get_new_description", + return_value=[], +) +def test_auto_disable_triggers_at_threshold(mock_desc): + """After _CTE_FALLBACK_AUTO_DISABLE_THRESHOLD successful retries the + session's CTE flag is permanently turned off.""" + sc = _make_server_connection() + # Start just below the threshold so this call pushes it over. + session = _make_mock_session( + cte_enabled=True, + fallback_count=_CTE_FALLBACK_AUTO_DISABLE_THRESHOLD - 1, + ) + plan = _make_mock_plan(session=session) + + prog_err = ProgrammingError("CTE error") + prog_err.sfqid = "sfqid" + retry_result = _run_query_success() + retry_plan_queries = _make_plan_queries() + + with patch.object(sc, "run_query", side_effect=[prog_err, retry_result]), patch( + "snowflake.snowpark._internal.server_connection.PlanCompiler" + ) as MockCompiler: + MockCompiler.return_value.compile.return_value = retry_plan_queries + sc.get_result_set(plan, block=True) + + assert session._cte_optimization_enabled is False + + +@patch( + "snowflake.snowpark._internal.server_connection.get_new_description", + return_value=[], +) +def test_auto_disable_sends_separate_telemetry_event(mock_desc): + sc = _make_server_connection() + session = _make_mock_session( + cte_enabled=True, + fallback_count=_CTE_FALLBACK_AUTO_DISABLE_THRESHOLD - 1, + ) + plan = _make_mock_plan(session=session) + + prog_err = ProgrammingError("CTE error") + prog_err.sfqid = "sfqid" + retry_result = _run_query_success() + retry_plan_queries = _make_plan_queries() + + with patch.object(sc, "run_query", side_effect=[prog_err, retry_result]), patch( + "snowflake.snowpark._internal.server_connection.PlanCompiler" + ) as MockCompiler: + MockCompiler.return_value.compile.return_value = retry_plan_queries + sc.get_result_set(plan, block=True) + + # auto-disable event is a separate call from the per-fallback event + sc._telemetry_client.send_cte_execution_fallback_telemetry.assert_called_once() + sc._telemetry_client.send_cte_auto_disabled_telemetry.assert_called_once_with( + session_id="fake_session_id", + fallback_count=_CTE_FALLBACK_AUTO_DISABLE_THRESHOLD, + ) + + +@patch( + "snowflake.snowpark._internal.server_connection.get_new_description", + return_value=[], +) +def test_no_auto_disable_below_threshold(mock_desc): + """fallback_count below threshold: CTE stays enabled.""" + sc = _make_server_connection() + session = _make_mock_session(cte_enabled=True, fallback_count=0) + plan = _make_mock_plan(session=session) + + prog_err = ProgrammingError("CTE error") + prog_err.sfqid = "sfqid" + retry_result = _run_query_success() + retry_plan_queries = _make_plan_queries() + + with patch.object(sc, "run_query", side_effect=[prog_err, retry_result]), patch( + "snowflake.snowpark._internal.server_connection.PlanCompiler" + ) as MockCompiler: + MockCompiler.return_value.compile.return_value = retry_plan_queries + sc.get_result_set(plan, block=True) + + # Only triggered when count reaches threshold (3), count is now 1. + assert session._cte_optimization_enabled is True + sc._telemetry_client.send_cte_auto_disabled_telemetry.assert_not_called() + + +# --------------------------------------------------------------------------- +# Tests: PlanCompiler.compile cte_enabled override +# --------------------------------------------------------------------------- + + +def test_plan_compiler_cte_enabled_false_skips_cte_step(): + """compile(cte_enabled=False) must not apply CTE optimization even when + session.cte_optimization_enabled is True.""" + from snowflake.snowpark._internal.compiler.plan_compiler import PlanCompiler + + plan = MagicMock() + plan.session.cte_optimization_enabled = True + plan.session._cte_optimization_enabled = True + plan.session.large_query_breakdown_enabled = False + plan.session._query_compilation_stage_enabled = True + + compiler = PlanCompiler(plan) + # When should_start_query_compilation is called with cte_enabled=False and + # large_query_breakdown_enabled=False, it must return False. + assert compiler.should_start_query_compilation(cte_enabled=False) is False + + +def test_plan_compiler_cte_enabled_none_uses_session_setting(): + """compile() with no override reads from session as before.""" + from snowflake.snowpark._internal.compiler.plan_compiler import PlanCompiler + + plan = MagicMock() + plan.session.cte_optimization_enabled = True + plan.session._cte_optimization_enabled = True + plan.session.large_query_breakdown_enabled = False + plan.session._query_compilation_stage_enabled = True + plan.source_plan = MagicMock() + + plan.session._conn.__class__ = object # not MockServerConnection + + compiler = PlanCompiler(plan) + # cte_enabled=None => falls back to session.cte_optimization_enabled=True + # source_plan is set and query_compilation_stage_enabled=True => True + assert compiler.should_start_query_compilation(cte_enabled=None) is True + + +def test_plan_compiler_cte_enabled_false_with_lqb_still_compiles(): + """If LQB is enabled, compilation still proceeds even with cte_enabled=False + because large_query_breakdown_enabled stays True.""" + from snowflake.snowpark._internal.compiler.plan_compiler import PlanCompiler + + plan = MagicMock() + plan.session.cte_optimization_enabled = True + plan.session._cte_optimization_enabled = True + plan.session.large_query_breakdown_enabled = True + plan.session._query_compilation_stage_enabled = True + plan.source_plan = MagicMock() + + plan.session._conn.__class__ = object # not MockServerConnection + + compiler = PlanCompiler(plan) + assert compiler.should_start_query_compilation(cte_enabled=False) is True