Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,14 @@ def execution_queries(self) -> Dict["PlanQueryType", List["Query"]]:
compiler = PlanCompiler(self)
return compiler.compile()

def get_execution_queries_without_cte(self) -> Dict["PlanQueryType", List["Query"]]:
    """Compile this plan's execution queries with CTE optimization skipped.

    Used for retry when CTE-optimized SQL fails at the server.

    Returns:
        The result of ``PlanCompiler(self).compile(skip_cte_optimization=True)``,
        i.e. a mapping from plan query type to the list of queries to run.
    """
    # Imported at call time rather than at module level (presumably to avoid
    # a circular import between the analyzer and compiler packages -- confirm).
    from snowflake.snowpark._internal.compiler.plan_compiler import PlanCompiler

    return PlanCompiler(self).compile(skip_cte_optimization=True)

@property
def children_plan_nodes(self) -> List[Union["Selectable", "SnowflakePlan"]]:
"""
Expand Down
9 changes: 7 additions & 2 deletions src/snowflake/snowpark/_internal/compiler/plan_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ def should_start_query_compilation(self) -> bool:
)
)

def compile(self) -> Dict[PlanQueryType, List[Query]]:
def compile(
self, *, skip_cte_optimization: bool = False
) -> Dict[PlanQueryType, List[Query]]:
# initialize the queries with the original queries without optimization
final_plan = self._plan
queries = {
Expand Down Expand Up @@ -107,7 +109,10 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]:
# 3. apply each optimizations if needed
# CTE optimization
with measure_time() as cte_time:
if session.cte_optimization_enabled:
if (
session.cte_optimization_enabled
and not skip_cte_optimization
):
repeated_subquery_eliminator = RepeatedSubqueryElimination(
logical_plans, query_generator
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ class CompilationStageTelemetryField(Enum):
# keys for repeated subquery elimination
CTE_NODE_CREATED = "cte_node_created"

# keys for CTE execution retry
CTE_ERROR_COUNT = "cte_error_count"
CTE_QUERY_ID = "cte_query_id"
RETRY_QUERY_ID = "retry_query_id"

# keys for large query breakdown
BREAKDOWN_SUMMARY = "breakdown_summary"
COMPLEXITY_SCORE_AFTER_CTE_OPTIMIZATION = "complexity_score_after_cte_optimization"
Expand Down
139 changes: 135 additions & 4 deletions src/snowflake/snowpark/_internal/server_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
result_set_to_rows,
unwrap_stage_location_single_quote,
)
from snowflake.snowpark import context
from snowflake.snowpark.async_job import AsyncJob, _AsyncResultType
from snowflake.snowpark.query_history import QueryListener, QueryRecord
from snowflake.snowpark.row import Row
Expand Down Expand Up @@ -696,12 +697,145 @@ def get_result_set(
],
Union[List[ResultMetadata], List["ResultMetadataV2"]],
]:
action_id = plan.session._generate_new_action_id()
session = plan.session
action_id = session._generate_new_action_id()
plan_queries = plan.execution_queries
result, result_meta = None, None
statement_params = kwargs.get("_statement_params", None) or {}
statement_params["_PLAN_UUID"] = plan.uuid
kwargs["_statement_params"] = statement_params

should_retry_without_cte = (
block # Snowpark Connect does not use async queries. Async queries retry tracked by SNOW-3293313
and context._is_snowpark_connect_compatible_mode
and session.cte_optimization_enabled
)

try:
result, result_meta = self._execute_queries(
plan,
plan_queries,
action_id,
to_pandas=to_pandas,
to_iter=to_iter,
block=block,
data_type=data_type,
log_on_exception=log_on_exception,
case_sensitive=case_sensitive,
ignore_results=ignore_results,
to_arrow=to_arrow,
**kwargs,
)
except ProgrammingError as cte_error:
if not should_retry_without_cte:
raise

# Skip retry if CTE optimization didn't affect the SQL.
# Check for the Snowpark CTE prefix to cover both bare SELECTs
# and DML with embedded CTEs (e.g. INSERT INTO t WITH ...).
main_queries = plan_queries[PlanQueryType.QUERIES]
if not any("SNOWPARK_TEMP_CTE_" in q.sql for q in main_queries):
raise

unoptimized_plan_queries = plan.get_execution_queries_without_cte()

logger.debug(
"CTE-optimized query failed with ProgrammingError: %s. "
"Retrying without CTE optimization.",
cte_error,
)

# Retry safety: CTE errors are compilation-time rejections (no
# partial DML side effects). Setup DDLs are idempotent, post-
# actions run in a finally block (cleanup on failure), and
# compile() generates fresh temp object names each call.
try:
retry_action_id = session._generate_new_action_id()
result, result_meta = self._execute_queries(
plan,
unoptimized_plan_queries,
retry_action_id,
to_pandas=to_pandas,
to_iter=to_iter,
block=block,
data_type=data_type,
log_on_exception=log_on_exception,
case_sensitive=case_sensitive,
ignore_results=ignore_results,
to_arrow=to_arrow,
**kwargs,
)

except ProgrammingError:
raise cte_error
except Exception as retry_error:
# Log both errors for debugging
logger.error(
"Retry without CTE optimization failed with different error: %s. "
"Original CTE error: %s",
retry_error,
cte_error,
)
raise # Raise the actual retry error, not the original

else:
cte_query_id = getattr(cte_error, "sfqid", None)
retry_query_id = (
result.get("sfqid") if isinstance(result, dict) else None
)

self._telemetry_client.send_cte_execution_retry_telemetry(
session_id=self.get_session_id(),
plan_uuid=plan.uuid,
error_message=str(cte_error),
api_calls=plan.api_calls,
cte_query_id=cte_query_id,
retry_query_id=retry_query_id,
)

if context._cte_error_threshold > 0:
cte_disabled = False
with session._lock:
session._cte_error_count += 1
if session._cte_error_count >= context._cte_error_threshold:
session._cte_optimization_enabled = False
cte_disabled = True

if cte_disabled:

logger.warning(
"CTE optimization has caused %d execution failures. "
"Auto-disabling CTE optimization for "
"the remainder of this session to avoid further "
"performance impact.",
context._cte_error_threshold,
)
self._telemetry_client.send_cte_optimization_auto_disabled_telemetry(
session_id=self.get_session_id(),
cte_error_count=context._cte_error_threshold,
)

if result is None:
raise SnowparkClientExceptionMessages.SQL_LAST_QUERY_RETURN_RESULTSET()

return result, result_meta

def _execute_queries(
self,
plan: SnowflakePlan,
plan_queries,
action_id: int,
to_pandas: bool = False,
to_iter: bool = False,
block: bool = True,
data_type: _AsyncResultType = _AsyncResultType.ROW,
log_on_exception: bool = False,
case_sensitive: bool = True,
ignore_results: bool = False,
to_arrow: bool = False,
**kwargs,
) -> Tuple[Optional[Any], Optional[Any]]:
result, result_meta = None, None
try:
main_queries = plan_queries[PlanQueryType.QUERIES]
post_actions = plan_queries[PlanQueryType.POST_ACTIONS]
Expand Down Expand Up @@ -810,9 +944,6 @@ def get_result_set(
data=get_plan_telemetry_metrics(plan),
)

if result is None:
raise SnowparkClientExceptionMessages.SQL_LAST_QUERY_RETURN_RESULTSET()

return result, result_meta

def get_result_and_metadata(
Expand Down
42 changes: 42 additions & 0 deletions src/snowflake/snowpark/_internal/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class TelemetryField(Enum):
TYPE_CURSOR_CREATED = "snowpark_cursor_created"
TYPE_SQL_SIMPLIFIER_ENABLED = "snowpark_sql_simplifier_enabled"
TYPE_CTE_OPTIMIZATION_ENABLED = "snowpark_cte_optimization_enabled"
TYPE_CTE_EXECUTION_RETRY = "snowpark_cte_execution_retry"
TYPE_CTE_OPTIMIZATION_AUTO_DISABLED = "snowpark_cte_optimization_auto_disabled"
# telemetry for optimization that eliminates the extra cast expression generated for expressions
TYPE_ELIMINATE_NUMERIC_SQL_VALUE_CAST_ENABLED = (
"snowpark_eliminate_numeric_sql_value_cast_enabled"
Expand Down Expand Up @@ -660,6 +662,46 @@ def send_cte_optimization_telemetry(self, session_id: str) -> None:
}
self.send(message)

def send_cte_execution_retry_telemetry(
    self,
    session_id: str,
    plan_uuid: str,
    error_message: str,
    api_calls: Optional[List[Dict[str, Any]]],
    cte_query_id: Optional[str],
    retry_query_id: Optional[str],
) -> None:
    """Send a telemetry event describing a CTE execution retry.

    Args:
        session_id: Identifier of the session the retry happened in.
        plan_uuid: UUID of the plan whose execution was retried.
        error_message: Text of the error raised by the CTE-optimized query.
        api_calls: API calls recorded on the plan, if any.
        cte_query_id: Query id of the failed CTE-optimized query, if known.
        retry_query_id: Query id of the retried query, if known.
    """
    # Start from the common envelope for this event type.
    envelope = {
        **self._create_basic_telemetry_data(
            TelemetryField.TYPE_CTE_EXECUTION_RETRY.value
        )
    }
    data_key = TelemetryField.KEY_DATA.value

    # Event-specific payload, filled in the same key order as before.
    retry_details = {}
    retry_details[TelemetryField.SESSION_ID.value] = session_id
    retry_details[CompilationStageTelemetryField.PLAN_UUID.value] = plan_uuid
    retry_details[CompilationStageTelemetryField.ERROR_MESSAGE.value] = error_message
    retry_details[TelemetryField.KEY_API_CALLS.value] = api_calls
    retry_details[CompilationStageTelemetryField.CTE_QUERY_ID.value] = cte_query_id
    retry_details[CompilationStageTelemetryField.RETRY_QUERY_ID.value] = retry_query_id

    envelope[data_key] = retry_details
    self.send(envelope)

def send_cte_optimization_auto_disabled_telemetry(
    self,
    session_id: str,
    cte_error_count: int,
) -> None:
    """Send a telemetry event reporting that CTE optimization was auto-disabled.

    Args:
        session_id: Identifier of the session in which it was disabled.
        cte_error_count: CTE error count to report with the event.
    """
    # Start from the common envelope for this event type.
    envelope = {
        **self._create_basic_telemetry_data(
            TelemetryField.TYPE_CTE_OPTIMIZATION_AUTO_DISABLED.value
        )
    }
    data_key = TelemetryField.KEY_DATA.value

    # Event-specific payload.
    disable_details = {}
    disable_details[TelemetryField.SESSION_ID.value] = session_id
    disable_details[CompilationStageTelemetryField.CTE_ERROR_COUNT.value] = cte_error_count

    envelope[data_key] = disable_details
    self.send(envelope)

def send_eliminate_numeric_sql_value_cast_telemetry(
self, session_id: str, value: bool
) -> None:
Expand Down
2 changes: 2 additions & 0 deletions src/snowflake/snowpark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
# Internal-only global flag: indicates whether the API code being executed is
# compatible with snowflake.snowpark_connect.
_is_snowpark_connect_compatible_mode = False

# Number of times CTE optimization can fail before it is automatically disabled
# for the remainder of the session. Set to 0 to turn off auto-disabling.
_cte_error_threshold = 3

# The following are internal-only global flags, used to enable development features.
_enable_dataframe_trace_on_error = False
_debug_eager_schema_validation = False
Expand Down
1 change: 1 addition & 0 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,7 @@ def __init__(
self._cte_optimization_enabled: bool = self.is_feature_enabled_for_version(
_PYTHON_SNOWPARK_USE_CTE_OPTIMIZATION_VERSION
)
self._cte_error_count: int = 0
self._use_logical_type_for_create_df: bool = (
self._conn._get_client_side_session_parameter(
_PYTHON_SNOWPARK_USE_LOGICAL_TYPE_FOR_CREATE_DATAFRAME_STRING, True
Expand Down
Loading
Loading