|
2 | 2 |
|
3 | 3 | import contextlib |
4 | 4 | import logging |
| 5 | +import re |
5 | 6 | import typing as t |
6 | 7 |
|
7 | 8 | from sqlglot import exp |
|
24 | 25 | set_catalog, |
25 | 26 | ) |
26 | 27 | from sqlmesh.core.schema_diff import SchemaDiffer |
| 28 | +from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker |
27 | 29 | from sqlmesh.utils import optional_import, get_source_columns_to_types |
28 | 30 | from sqlmesh.utils.errors import SQLMeshError |
29 | 31 | from sqlmesh.utils.pandas import columns_to_types_from_dtypes |
@@ -72,6 +74,7 @@ class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixi |
72 | 74 | ) |
73 | 75 | MANAGED_TABLE_KIND = "DYNAMIC TABLE" |
74 | 76 | SNOWPARK = "snowpark" |
| 77 | + SUPPORTS_QUERY_EXECUTION_TRACKING = True |
75 | 78 |
|
76 | 79 | @contextlib.contextmanager |
77 | 80 | def session(self, properties: SessionProperties) -> t.Iterator[None]: |
@@ -664,3 +667,33 @@ def close(self) -> t.Any: |
664 | 667 | self._connection_pool.set_attribute(self.SNOWPARK, None) |
665 | 668 |
|
666 | 669 | return super().close() |
| 670 | + |
| 671 | + def _record_execution_stats( |
| 672 | + self, sql: str, rowcount: t.Optional[int] = None, bytes_processed: t.Optional[int] = None |
| 673 | + ) -> None: |
| 674 | + """Snowflake does not report row counts for CTAS like other DML operations. |
| 675 | +
|
| 676 | + They neither report the sentinel value -1 nor do they report 0 rows. Instead, they return a single data row |
| 677 | + containing the string "Table <table_name> successfully created." and a row count of 1. |
| 678 | +
|
| 679 | + We do not want to record the row count of 1 for CTAS operations, so we check for that data pattern and return |
| 680 | + early if it is detected. |
| 681 | +
|
| 682 | + Regex explanation - Snowflake identifiers may be: |
| 683 | + - An unquoted contiguous set of [a-zA-Z0-9_$] characters |
| 684 | + - A double-quoted string that may contain spaces and nested double-quotes represented by `""` |
| 685 | + - Example: " my ""table"" name " |
| 686 | + - Pattern: "(?:[^"]|"")+" |
| 687 | + - ?: is a non-capturing group |
| 688 | + - [^"] matches any single character except a double-quote |
| 689 | + - "" matches two sequential double-quotes |
| 690 | + """ |
| 691 | + if rowcount == 1: |
| 692 | + results = self.cursor.fetchall() |
| 693 | + if results and len(results) == 1: |
| 694 | + is_ctas = re.match( |
| 695 | + r'Table ([a-zA-Z0-9_$]+|"(?:[^"]|"")+") successfully created\.', results[0][0] |
| 696 | + ) |
| 697 | + if is_ctas: |
| 698 | + return |
| 699 | + QueryExecutionTracker.record_execution(sql, rowcount, bytes_processed) |
0 commit comments