Skip to content

Commit 5d7f21d

Browse files
authored
Chore: Enable tests for native DataFrame implementations (#4406)
1 parent ea11179 commit 5d7f21d

22 files changed

Lines changed: 638 additions & 634 deletions

.circleci/continue_config.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,6 @@ workflows:
281281
- trino
282282
- spark
283283
- clickhouse
284-
- clickhouse-cluster
285284
- risingwave
286285
- engine_tests_cloud:
287286
name: cloud_engine_<< matrix.engine >>

Makefile

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -126,54 +126,51 @@ engine-%-down:
126126
##################
127127

128128
clickhouse-test: engine-clickhouse-up
129-
pytest -n auto -x -m "clickhouse" --retries 3 --junitxml=test-results/junit-clickhouse.xml
130-
131-
clickhouse-cluster-test: engine-clickhouse-up
132-
pytest -n auto -x -m "clickhouse_cluster" --retries 3 --junitxml=test-results/junit-clickhouse-cluster.xml
129+
pytest -n auto -m "clickhouse" --retries 3 --junitxml=test-results/junit-clickhouse.xml
133130

134131
duckdb-test: engine-duckdb-install
135-
pytest -n auto -x -m "duckdb" --retries 3 --junitxml=test-results/junit-duckdb.xml
132+
pytest -n auto -m "duckdb" --retries 3 --junitxml=test-results/junit-duckdb.xml
136133

137134
mssql-test: engine-mssql-up
138-
pytest -n auto -x -m "mssql" --retries 3 --junitxml=test-results/junit-mssql.xml
135+
pytest -n auto -m "mssql" --retries 3 --junitxml=test-results/junit-mssql.xml
139136

140137
mysql-test: engine-mysql-up
141-
pytest -n auto -x -m "mysql" --retries 3 --junitxml=test-results/junit-mysql.xml
138+
pytest -n auto -m "mysql" --retries 3 --junitxml=test-results/junit-mysql.xml
142139

143140
postgres-test: engine-postgres-up
144-
pytest -n auto -x -m "postgres" --retries 3 --junitxml=test-results/junit-postgres.xml
141+
pytest -n auto -m "postgres" --retries 3 --junitxml=test-results/junit-postgres.xml
145142

146143
spark-test: engine-spark-up
147-
pytest -n auto -x -m "spark or pyspark" --retries 3 --junitxml=test-results/junit-spark.xml
144+
pytest -n auto -m "spark" --retries 3 --junitxml=test-results/junit-spark.xml
148145

149146
trino-test: engine-trino-up
150-
pytest -n auto -x -m "trino or trino_iceberg or trino_delta or trino_nessie" --retries 3 --junitxml=test-results/junit-trino.xml
147+
pytest -n auto -m "trino" --retries 3 --junitxml=test-results/junit-trino.xml
151148

152149
risingwave-test: engine-risingwave-up
153-
pytest -n auto -x -m "risingwave" --retries 3 --junitxml=test-results/junit-risingwave.xml
150+
pytest -n auto -m "risingwave" --retries 3 --junitxml=test-results/junit-risingwave.xml
154151

155152
#################
156153
# Cloud Engines #
157154
#################
158155

159156
snowflake-test: guard-SNOWFLAKE_ACCOUNT guard-SNOWFLAKE_WAREHOUSE guard-SNOWFLAKE_DATABASE guard-SNOWFLAKE_USER guard-SNOWFLAKE_PASSWORD engine-snowflake-install
160-
pytest -n auto -x -m "snowflake" --retries 3 --junitxml=test-results/junit-snowflake.xml
157+
pytest -n auto -m "snowflake" --retries 3 --junitxml=test-results/junit-snowflake.xml
161158

162159
bigquery-test: guard-BIGQUERY_KEYFILE engine-bigquery-install
163-
pytest -n auto -x -m "bigquery" --retries 3 --junitxml=test-results/junit-bigquery.xml
160+
pytest -n auto -m "bigquery" --retries 3 --junitxml=test-results/junit-bigquery.xml
164161

165162
databricks-test: guard-DATABRICKS_CATALOG guard-DATABRICKS_SERVER_HOSTNAME guard-DATABRICKS_HTTP_PATH guard-DATABRICKS_ACCESS_TOKEN guard-DATABRICKS_CONNECT_VERSION engine-databricks-install
166163
pip install 'databricks-connect==${DATABRICKS_CONNECT_VERSION}'
167-
pytest -n auto -x -m "databricks" --retries 3 --junitxml=test-results/junit-databricks.xml
164+
pytest -n auto -m "databricks" --retries 3 --junitxml=test-results/junit-databricks.xml
168165

169166
redshift-test: guard-REDSHIFT_HOST guard-REDSHIFT_USER guard-REDSHIFT_PASSWORD guard-REDSHIFT_DATABASE engine-redshift-install
170-
pytest -n auto -x -m "redshift" --retries 3 --junitxml=test-results/junit-redshift.xml
167+
pytest -n auto -m "redshift" --retries 3 --junitxml=test-results/junit-redshift.xml
171168

172169
clickhouse-cloud-test: guard-CLICKHOUSE_CLOUD_HOST guard-CLICKHOUSE_CLOUD_USERNAME guard-CLICKHOUSE_CLOUD_PASSWORD engine-clickhouse-install
173170
pytest -n 1 -m "clickhouse_cloud" --retries 3 --junitxml=test-results/junit-clickhouse-cloud.xml
174171

175172
athena-test: guard-AWS_ACCESS_KEY_ID guard-AWS_SECRET_ACCESS_KEY guard-ATHENA_S3_WAREHOUSE_LOCATION engine-athena-install
176-
pytest -n auto -x -m "athena" --retries 3 --retry-delay 10 --junitxml=test-results/junit-athena.xml
173+
pytest -n auto -m "athena" --retries 3 --retry-delay 10 --junitxml=test-results/junit-athena.xml
177174

178175
vscode_settings:
179176
mkdir -p .vscode

pytest.ini

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,11 @@ markers =
1616
github: test for Github CI/CD bot
1717
jupyter: tests for Jupyter integration
1818
web: tests for web UI
19-
spark_pyspark: test for Spark with PySpark dependency
2019
# Engine Adapters
2120
engine: test all engine adapters
2221
athena: test for Athena
2322
bigquery: test for BigQuery
24-
clickhouse: test for Clickhouse (standalone mode)
25-
clickhouse_cluster: test for Clickhouse (cluster mode)
23+
clickhouse: test for Clickhouse (standalone mode / cluster mode)
2624
clickhouse_cloud: test for Clickhouse (cloud mode)
2725
databricks: test for Databricks
2826
duckdb: test for DuckDB
@@ -33,9 +31,8 @@ markers =
3331
redshift: test for Redshift
3432
snowflake: test for Snowflake
3533
spark: test for Spark
36-
trino: test for Trino (Hive connector)
37-
trino_iceberg: test for Trino (Iceberg connector)
38-
trino_delta: test for Trino (Delta connector)
34+
trino: test for Trino (all connectors)
35+
risingwave: test for Risingwave
3936
addopts = -n 0 --dist=loadgroup
4037

4138
asyncio_default_fixture_loop_scope = session

sqlmesh/core/engine_adapter/base.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,8 @@ def create_view(
10081008
if materialized_properties and not materialized:
10091009
raise SQLMeshError("Materialized properties are only supported for materialized views")
10101010

1011+
query_or_df = self._native_df_to_pandas_df(query_or_df)
1012+
10111013
if isinstance(query_or_df, pd.DataFrame):
10121014
values: t.List[t.Tuple[t.Any, ...]] = list(
10131015
query_or_df.itertuples(index=False, name=None)
@@ -2002,6 +2004,19 @@ def _fetch_native_df(
20022004
self.execute(query, quote_identifiers=quote_identifiers)
20032005
return self.cursor.fetchdf()
20042006

2007+
def _native_df_to_pandas_df(
2008+
self,
2009+
query_or_df: QueryOrDF,
2010+
) -> t.Union[Query, pd.DataFrame]:
2011+
"""
2012+
Take a "native" DataFrame (eg Pyspark, Bigframe, Snowpark etc) and convert it to Pandas
2013+
"""
2014+
if isinstance(query_or_df, (exp.Query, exp.DerivedTable, pd.DataFrame)):
2015+
return query_or_df
2016+
2017+
# EngineAdapter subclasses that have native DataFrame types should override this
2018+
raise NotImplementedError(f"Unable to convert {type(query_or_df)} to Pandas")
2019+
20052020
def fetchdf(
20062021
self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
20072022
) -> pd.DataFrame:

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from sqlmesh.utils import optional_import
2727
from sqlmesh.utils.date import to_datetime
2828
from sqlmesh.utils.errors import SQLMeshError
29+
from sqlmesh.utils.pandas import columns_to_types_from_dtypes
2930

3031
if t.TYPE_CHECKING:
3132
from google.api_core.retry import Retry
@@ -1107,6 +1108,39 @@ def _normalize_decimal_value(self, col: exp.Expression, precision: int) -> exp.E
11071108
def _normalize_nested_value(self, col: exp.Expression) -> exp.Expression:
11081109
return exp.func("TO_JSON_STRING", col, dialect=self.dialect)
11091110

1111+
@t.overload
1112+
def _columns_to_types(
1113+
self, query_or_df: DF, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None
1114+
) -> t.Dict[str, exp.DataType]: ...
1115+
1116+
@t.overload
1117+
def _columns_to_types(
1118+
self, query_or_df: Query, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None
1119+
) -> t.Optional[t.Dict[str, exp.DataType]]: ...
1120+
1121+
def _columns_to_types(
1122+
self, query_or_df: QueryOrDF, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None
1123+
) -> t.Optional[t.Dict[str, exp.DataType]]:
1124+
if (
1125+
not columns_to_types
1126+
and bigframes
1127+
and isinstance(query_or_df, bigframes.dataframe.DataFrame)
1128+
):
1129+
# using dry_run=True attempts to prevent the DataFrame from being materialized just to read the column types from it
1130+
dtypes = query_or_df.to_pandas(dry_run=True).columnDtypes
1131+
return columns_to_types_from_dtypes(dtypes.items())
1132+
1133+
return super()._columns_to_types(query_or_df, columns_to_types)
1134+
1135+
def _native_df_to_pandas_df(
1136+
self,
1137+
query_or_df: QueryOrDF,
1138+
) -> t.Union[Query, pd.DataFrame]:
1139+
if bigframes and isinstance(query_or_df, bigframes.dataframe.DataFrame):
1140+
return query_or_df.to_pandas()
1141+
1142+
return super()._native_df_to_pandas_df(query_or_df)
1143+
11101144
@property
11111145
def _query_data(self) -> t.Any:
11121146
return self._connection_pool.get_attribute("query_data")

sqlmesh/core/engine_adapter/snowflake.py

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from sqlmesh.core.schema_diff import SchemaDiffer
2828
from sqlmesh.utils import optional_import
2929
from sqlmesh.utils.errors import SQLMeshError
30+
from sqlmesh.utils.pandas import columns_to_types_from_dtypes
3031

3132
logger = logging.getLogger(__name__)
3233
snowpark = optional_import("snowflake.snowpark")
@@ -298,13 +299,23 @@ def query_factory() -> Query:
298299

299300
if is_snowpark_dataframe:
300301
temp_table.set("catalog", database)
301-
df_renamed = df.rename(
302-
{
303-
col: exp.to_identifier(col).sql(dialect=self.dialect, identify=True)
304-
for col in columns_to_types
305-
}
306-
) # type: ignore
307-
df_renamed.createOrReplaceTempView(
302+
303+
# only quote columns if they aren't already quoted
304+
# if the Snowpark dataframe was created from a Pandas dataframe via snowpark.create_dataframe(pandas_df),
305+
# then they will be quoted already. But if the Snowpark dataframe was created manually by the user, then the
306+
# columns may not be quoted
307+
columns_already_quoted = all(
308+
col.startswith('"') and col.endswith('"') for col in df.columns
309+
)
310+
local_df = df
311+
if not columns_already_quoted:
312+
local_df = df.rename(
313+
{
314+
col: exp.to_identifier(col).sql(dialect=self.dialect, identify=True)
315+
for col in columns_to_types
316+
}
317+
) # type: ignore
318+
local_df.createOrReplaceTempView(
308319
temp_table.sql(dialect=self.dialect, identify=True)
309320
) # type: ignore
310321
elif isinstance(df, pd.DataFrame):
@@ -356,6 +367,11 @@ def query_factory() -> Query:
356367

357368
def cleanup() -> None:
358369
if is_snowpark_dataframe:
370+
if hasattr(df, "table_name"):
371+
if isinstance(df.table_name, str):
372+
# created by the Snowpark library if the Snowpark DataFrame was created from a Pandas DataFrame
373+
# (if the Snowpark DataFrame was created via native means then there is no 'table_name' property and no temp table)
374+
self.drop_table(df.table_name)
359375
self.drop_view(temp_table)
360376
else:
361377
self.drop_table(temp_table)
@@ -381,6 +397,15 @@ def _fetch_native_df(
381397
columns = self.cursor._result_set.batches[0].column_names
382398
return pd.DataFrame([dict(zip(columns, row)) for row in rows])
383399

400+
def _native_df_to_pandas_df(
401+
self,
402+
query_or_df: QueryOrDF,
403+
) -> t.Union[Query, pd.DataFrame]:
404+
if snowpark and isinstance(query_or_df, snowpark.DataFrame):
405+
return query_or_df.to_pandas()
406+
407+
return super()._native_df_to_pandas_df(query_or_df)
408+
384409
def _get_data_objects(
385410
self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
386411
) -> t.List[DataObject]:
@@ -426,6 +451,10 @@ def _get_data_objects(
426451
if object_names:
427452
query = query.where(exp.column("TABLE_NAME").isin(*object_names))
428453

454+
# exclude SNOWPARK_TEMP_TABLE tables that are managed by the Snowpark library and are an implementation
455+
# detail of dealing with DataFrames
456+
query = query.where(exp.column("TABLE_NAME").like("SNOWPARK_TEMP_TABLE%").not_())
457+
429458
df = self.fetchdf(query, quote_identifiers=True)
430459
if df.empty:
431460
return []
@@ -537,3 +566,21 @@ def clone_table(
537566
clone_kwargs=clone_kwargs,
538567
**kwargs,
539568
)
569+
570+
@t.overload
571+
def _columns_to_types(
572+
self, query_or_df: DF, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None
573+
) -> t.Dict[str, exp.DataType]: ...
574+
575+
@t.overload
576+
def _columns_to_types(
577+
self, query_or_df: Query, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None
578+
) -> t.Optional[t.Dict[str, exp.DataType]]: ...
579+
580+
def _columns_to_types(
581+
self, query_or_df: QueryOrDF, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None
582+
) -> t.Optional[t.Dict[str, exp.DataType]]:
583+
if not columns_to_types and snowpark and isinstance(query_or_df, snowpark.DataFrame):
584+
return columns_to_types_from_dtypes(query_or_df.sample(n=1).to_pandas().dtypes.items())
585+
586+
return super()._columns_to_types(query_or_df, columns_to_types)

sqlmesh/core/engine_adapter/spark.py

Lines changed: 6 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -382,51 +382,14 @@ def create_state_table(
382382
partitioned_by=[exp.column(x) for x in primary_key] if primary_key else None,
383383
)
384384

385-
def create_view(
385+
def _native_df_to_pandas_df(
386386
self,
387-
view_name: TableName,
388387
query_or_df: QueryOrDF,
389-
columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
390-
replace: bool = True,
391-
materialized: bool = False,
392-
materialized_properties: t.Optional[t.Dict[str, t.Any]] = None,
393-
table_description: t.Optional[str] = None,
394-
column_descriptions: t.Optional[t.Dict[str, str]] = None,
395-
view_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
396-
**create_kwargs: t.Any,
397-
) -> None:
398-
"""Create a view with a query or dataframe.
399-
400-
If a dataframe is passed in, it will be converted into a literal values statement.
401-
This should only be done if the dataframe is very small!
402-
403-
Args:
404-
view_name: The view name.
405-
query_or_df: A query or dataframe.
406-
columns_to_types: Columns to use in the view statement.
407-
replace: Whether or not to replace an existing view - defaults to True.
408-
materialized: Whether or not the view should be materialized - defaults to False.
409-
materialized_properties: Optional materialized view properties to add to the view.
410-
table_description: Optional table description from MODEL DDL.
411-
column_descriptions: Optional column descriptions from model query.
412-
create_kwargs: Additional kwargs to pass into the Create expression
413-
view_properties: Optional view properties to add to the view.
414-
"""
415-
pyspark_df = self.try_get_pyspark_df(query_or_df)
416-
if pyspark_df:
417-
query_or_df = pyspark_df.toPandas()
418-
super().create_view(
419-
view_name,
420-
query_or_df,
421-
columns_to_types,
422-
replace,
423-
materialized,
424-
materialized_properties,
425-
table_description,
426-
column_descriptions,
427-
view_properties=view_properties,
428-
**create_kwargs,
429-
)
388+
) -> t.Union[Query, pd.DataFrame]:
389+
if pyspark_df := self.try_get_pyspark_df(query_or_df):
390+
return pyspark_df.toPandas()
391+
392+
return super()._native_df_to_pandas_df(query_or_df)
430393

431394
def _create_table(
432395
self,

sqlmesh/utils/pandas.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,20 @@
2727
pd.Float32Dtype(): exp.DataType.build("float"),
2828
pd.Float64Dtype(): exp.DataType.build("double"),
2929
pd.StringDtype(): exp.DataType.build("text"), # type: ignore
30+
pd.StringDtype("pyarrow"): exp.DataType.build("text"),
3031
pd.BooleanDtype(): exp.DataType.build("boolean"),
3132
}
3233

3334

3435
def columns_to_types_from_df(df: pd.DataFrame) -> t.Dict[str, exp.DataType]:
36+
return columns_to_types_from_dtypes(df.dtypes.items())
37+
38+
39+
def columns_to_types_from_dtypes(
40+
dtypes: t.Iterable[t.Tuple[t.Hashable, t.Any]],
41+
) -> t.Dict[str, exp.DataType]:
3542
result = {}
36-
for column_name, column_type in df.dtypes.items():
43+
for column_name, column_type in dtypes:
3744
exp_type: t.Optional[exp.DataType] = None
3845
if hasattr(pd, "DatetimeTZDtype") and isinstance(column_type, pd.DatetimeTZDtype):
3946
exp_type = exp.DataType.build("timestamptz")

0 commit comments

Comments
 (0)