Skip to content
Draft

update #4131

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 27 additions & 10 deletions src/snowflake/snowpark/_internal/compiler/plan_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import copy
import logging
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
get_complexity_score,
Expand Down Expand Up @@ -53,41 +53,58 @@ class PlanCompiler:
def __init__(self, plan: SnowflakePlan) -> None:
    """Bind this compiler to the snowflake plan it will compile."""
    self._plan = plan

def should_start_query_compilation(
    self, cte_enabled: Optional[bool] = None
) -> bool:
    """Decide whether the query compilation stage should run for this plan.

    Compilation runs only when all of the following hold:
      1) the connection is a real server connection (not a mock),
      2) a source logical plan is attached to the current snowflake plan,
      3) the session has the query compilation stage enabled,
      4) at least one optimization is on (CTE optimization — possibly
         overridden by ``cte_enabled`` — or large query breakdown).

    Parameters
    ----------
    cte_enabled : Optional[bool]
        When provided, overrides ``session.cte_optimization_enabled`` for
        this check only. Pass ``False`` to test whether compilation should
        proceed with CTE optimization explicitly disabled (e.g. retry path).

    Returns
    -------
    True if optimization should be applied. Otherwise, return False.
    """
    session = self._plan.session
    # Resolve the effective CTE flag: an explicit argument wins over the
    # session-level setting.
    if cte_enabled is None:
        cte_enabled = session.cte_optimization_enabled
    # Guard clauses mirror conditions 1-3 above.
    if isinstance(session._conn, MockServerConnection):
        return False
    if self._plan.source_plan is None:
        return False
    if not session._query_compilation_stage_enabled:
        return False
    # Condition 4: at least one optimization must be enabled.
    return cte_enabled or session.large_query_breakdown_enabled

def compile(self) -> Dict[PlanQueryType, List[Query]]:
def compile(
self, cte_enabled: Optional[bool] = None
) -> Dict[PlanQueryType, List[Query]]:
# initialize the queries with the original queries without optimization
final_plan = self._plan
queries = {
PlanQueryType.QUERIES: final_plan.queries,
PlanQueryType.POST_ACTIONS: final_plan.post_actions,
}

if self.should_start_query_compilation():
if self.should_start_query_compilation(cte_enabled=cte_enabled):
session = self._plan.session
effective_cte = (
cte_enabled
if cte_enabled is not None
else session.cte_optimization_enabled
)
try:
with measure_time() as total_time:
# preparation for compilation
Expand All @@ -107,7 +124,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]:
# 3. apply each optimizations if needed
# CTE optimization
with measure_time() as cte_time:
if session.cte_optimization_enabled:
if effective_cte:
repeated_subquery_eliminator = RepeatedSubqueryElimination(
logical_plans, query_generator
)
Expand Down Expand Up @@ -151,7 +168,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]:

# log telemetry data
summary_value = {
TelemetryField.CTE_OPTIMIZATION_ENABLED.value: session.cte_optimization_enabled,
TelemetryField.CTE_OPTIMIZATION_ENABLED.value: effective_cte,
TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: session.large_query_breakdown_enabled,
CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: session.large_query_breakdown_complexity_bounds,
CompilationStageTelemetryField.TIME_TAKEN_FOR_COMPILATION.value: total_time(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ class CompilationStageTelemetryField(Enum):
# keys for repeated subquery elimination
CTE_NODE_CREATED = "cte_node_created"

# Event types / categories emitted when a query that was compiled with CTE
# optimization fails at execution time and is retried without it.
# NOTE(review): the two CAT_* values share TYPE_CTE_EXECUTION_FALLBACK as
# their message type and are distinguished by the category field.
TYPE_CTE_EXECUTION_FALLBACK = "snowpark_cte_execution_fallback"
CAT_CTE_EXECUTION_FALLBACK = "cte_execution_fallback"
CAT_CTE_AUTO_DISABLED = "cte_auto_disabled"

# Payload keys for CTE execution fallback events: the failed query id, the
# outcome of the non-CTE retry, and the running per-session fallback count.
CTE_FALLBACK_SFQID = "sfqid"
CTE_FALLBACK_RETRY_SUCCEEDED = "retry_succeeded"
CTE_FALLBACK_COUNT = "fallback_count"

# keys for large query breakdown
BREAKDOWN_SUMMARY = "breakdown_summary"
COMPLEXITY_SCORE_AFTER_CTE_OPTIMIZATION = "complexity_score_after_cte_optimization"
Expand Down
135 changes: 112 additions & 23 deletions src/snowflake/snowpark/_internal/server_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
SnowflakePlan,
)
from snowflake.snowpark._internal.ast.utils import DATAFRAME_AST_PARAMETER
from snowflake.snowpark._internal.compiler.plan_compiler import PlanCompiler
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark._internal.telemetry import (
TelemetryClient,
Expand Down Expand Up @@ -81,6 +82,23 @@

logger = getLogger(__name__)

# Once a session has fallen back from CTE execution this many times, CTE
# optimization is switched off for the remainder of the session. Private
# module constant so the threshold can be tuned in exactly one place
# without touching session state.
_CTE_FALLBACK_AUTO_DISABLE_THRESHOLD = 3


def _should_retry_cte_error(err: ProgrammingError) -> bool:
    """Decide whether *err* justifies re-running the query with CTE
    optimization disabled.

    Currently every ``ProgrammingError`` qualifies. The ``err.errno`` and
    ``err.sqlstate`` attributes are deliberately kept reachable here so the
    check can be narrowed to specific error codes once production data
    identifies them.
    """
    return True


# parameters needed for usage tracking
PARAM_APPLICATION = "application"
PARAM_INTERNAL_APPLICATION_NAME = "internal_application_name"
Expand Down Expand Up @@ -756,39 +774,110 @@ def get_result_set(
else:
dataframe_ast = None

for i, query in enumerate(main_queries):
if isinstance(query, BatchInsertQuery):
self.run_batch_insert(query.sql, query.rows, **kwargs)
else:
is_last = i == len(main_queries) - 1 and not block
final_query = query.sql
for holder, id_ in placeholders.items():
final_query = final_query.replace(holder, id_)
if i == len(main_queries) - 1 and dataframe_ast:
kwargs[DATAFRAME_AST_PARAMETER] = dataframe_ast
is_final_query = i == len(main_queries) - 1
# CTE retry is eligible when the sync path is used, CTE was
# applied to this plan, and there is exactly one main query
# (guarantees no partial side-effects that would make a full
# restart unsafe). Multi-query retry support is a TODO.
cte_retry_eligible = (
block
and plan.session.cte_optimization_enabled
and len(main_queries) == 1
)
try:
for i, query in enumerate(main_queries):
if isinstance(query, BatchInsertQuery):
self.run_batch_insert(query.sql, query.rows, **kwargs)
else:
is_last = i == len(main_queries) - 1 and not block
final_query = query.sql
for holder, id_ in placeholders.items():
final_query = final_query.replace(holder, id_)
if i == len(main_queries) - 1 and dataframe_ast:
kwargs[DATAFRAME_AST_PARAMETER] = dataframe_ast
is_final_query = i == len(main_queries) - 1
result = self.run_query(
final_query,
to_pandas,
to_iter and is_final_query,
is_ddl_on_temp_object=query.is_ddl_on_temp_object,
block=not is_last,
data_type=data_type,
async_job_plan=plan,
log_on_exception=log_on_exception,
case_sensitive=case_sensitive,
params=query.params,
ignore_results=ignore_results,
async_post_actions=post_actions,
to_arrow=to_arrow and is_final_query,
**kwargs,
)
placeholders[query.query_id_place_holder] = (
result["sfqid"] if not is_last else result.query_id
)
result_meta = get_new_description(self._cursor)
if action_id < plan.session._last_canceled_id:
raise SnowparkClientExceptionMessages.SERVER_QUERY_IS_CANCELLED()
except ProgrammingError as exec_err:
if not (cte_retry_eligible and _should_retry_cte_error(exec_err)):
raise

# --- CTE execution fallback ---
# Recompile the plan with CTE optimization disabled and
# re-execute. We never mutate session state here; the
# cte_enabled override is local to this compile() call.
sfqid = getattr(exec_err, "sfqid", None)
retry_succeeded = False
try:
retry_plan_queries = PlanCompiler(plan).compile(
cte_enabled=False
)
retry_main = retry_plan_queries[PlanQueryType.QUERIES]
retry_query = retry_main[0]
result = self.run_query(
final_query,
retry_query.sql,
to_pandas,
to_iter and is_final_query,
is_ddl_on_temp_object=query.is_ddl_on_temp_object,
block=not is_last,
to_iter,
is_ddl_on_temp_object=retry_query.is_ddl_on_temp_object,
block=True,
data_type=data_type,
async_job_plan=plan,
log_on_exception=log_on_exception,
case_sensitive=case_sensitive,
params=query.params,
params=retry_query.params,
ignore_results=ignore_results,
async_post_actions=post_actions,
to_arrow=to_arrow and is_final_query,
to_arrow=to_arrow,
**kwargs,
)
placeholders[query.query_id_place_holder] = (
result["sfqid"] if not is_last else result.query_id
)
result_meta = get_new_description(self._cursor)
if action_id < plan.session._last_canceled_id:
raise SnowparkClientExceptionMessages.SERVER_QUERY_IS_CANCELLED()
# Switch post_actions to the retry compilation so the
# finally block cleans up the correct temp objects.
post_actions = retry_plan_queries[PlanQueryType.POST_ACTIONS]
retry_succeeded = True
except Exception:
pass # fall through to telemetry then re-raise

session = plan.session
session._cte_optimization_fallback_count += 1
fallback_count = session._cte_optimization_fallback_count
self._telemetry_client.send_cte_execution_fallback_telemetry(
session_id=session.session_id,
plan_uuid=plan.uuid,
sfqid=sfqid,
error_type=type(exec_err).__name__,
error_message=str(exec_err),
api_calls=plan.api_calls,
retry_succeeded=retry_succeeded,
fallback_count=fallback_count,
)
if fallback_count >= _CTE_FALLBACK_AUTO_DISABLE_THRESHOLD:
session._cte_optimization_enabled = False
self._telemetry_client.send_cte_auto_disabled_telemetry(
session_id=session.session_id,
fallback_count=fallback_count,
)

if not retry_succeeded:
raise exec_err
finally:
# delete created tmp object
if block:
Expand Down
46 changes: 46 additions & 0 deletions src/snowflake/snowpark/_internal/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,52 @@ def send_query_compilation_stage_failed_telemetry(
}
self.send(message)

def send_cte_execution_fallback_telemetry(
    self,
    session_id: int,
    plan_uuid: str,
    sfqid: Optional[str],
    error_type: str,
    error_message: str,
    api_calls: List[Dict],
    retry_succeeded: bool,
    fallback_count: int,
) -> None:
    """Send one telemetry event describing a CTE execution fallback.

    Records which query failed (``sfqid``), the error that triggered the
    fallback, whether the non-CTE retry succeeded, and the session's running
    fallback count.
    """
    # Build the event payload first, then wrap it in the standard envelope.
    data = {
        TelemetryField.SESSION_ID.value: session_id,
        TelemetryField.KEY_CATEGORY.value: CompilationStageTelemetryField.CAT_CTE_EXECUTION_FALLBACK.value,
        CompilationStageTelemetryField.PLAN_UUID.value: plan_uuid,
        CompilationStageTelemetryField.CTE_FALLBACK_SFQID.value: sfqid,
        CompilationStageTelemetryField.ERROR_TYPE.value: error_type,
        CompilationStageTelemetryField.ERROR_MESSAGE.value: error_message,
        TelemetryField.KEY_API_CALLS.value: api_calls,
        CompilationStageTelemetryField.CTE_FALLBACK_RETRY_SUCCEEDED.value: retry_succeeded,
        CompilationStageTelemetryField.CTE_FALLBACK_COUNT.value: fallback_count,
    }
    message = {
        **self._create_basic_telemetry_data(
            CompilationStageTelemetryField.TYPE_CTE_EXECUTION_FALLBACK.value
        ),
        TelemetryField.KEY_DATA.value: data,
    }
    self.send(message)

def send_cte_auto_disabled_telemetry(
    self,
    session_id: int,
    fallback_count: int,
) -> None:
    """Send one telemetry event recording that CTE optimization was
    automatically disabled for the session after repeated fallbacks.
    """
    payload = {
        TelemetryField.SESSION_ID.value: session_id,
        # Same event type as the per-fallback event; the category marks it
        # as the terminal auto-disable.
        TelemetryField.KEY_CATEGORY.value: CompilationStageTelemetryField.CAT_CTE_AUTO_DISABLED.value,
        CompilationStageTelemetryField.CTE_FALLBACK_COUNT.value: fallback_count,
    }
    message = {
        **self._create_basic_telemetry_data(
            CompilationStageTelemetryField.TYPE_CTE_EXECUTION_FALLBACK.value
        ),
        TelemetryField.KEY_DATA.value: payload,
    }
    self.send(message)

def send_plan_metrics_telemetry(
self, session_id: int, data: Dict[str, Any]
) -> None:
Expand Down
3 changes: 3 additions & 0 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,9 @@ def __init__(
self._cte_optimization_enabled: bool = self.is_feature_enabled_for_version(
_PYTHON_SNOWPARK_USE_CTE_OPTIMIZATION_VERSION
)
# Counts how many times CTE execution has fallen back to a non-CTE retry.
# Used to auto-disable CTE optimization when the threshold is exceeded.
self._cte_optimization_fallback_count: int = 0
self._use_logical_type_for_create_df: bool = (
self._conn._get_client_side_session_parameter(
_PYTHON_SNOWPARK_USE_LOGICAL_TYPE_FOR_CREATE_DATAFRAME_STRING, True
Expand Down
Loading
Loading