From d5d8bcbd340a586302e5b5b9a57a2ea3cb9df3e1 Mon Sep 17 00:00:00 2001 From: Hans Yu Date: Wed, 29 Apr 2026 01:03:26 +0200 Subject: [PATCH] refactor(backend): Unpack lists with sqla.select(). --- superset/commands/importers/v1/utils.py | 2 +- superset/common/tags.py | 76 +++--- superset/connectors/sqla/models.py | 6 +- superset/extensions/metadb.py | 4 +- ...1-17_e96dbf2cfef0_datasource_cluster_fk.py | 4 +- ...a813e_add_tables_relation_to_row_level_.py | 6 +- ..._a9422eeaae74_new_dataset_models_take_2.py | 228 ++++++++---------- ...e3017c6_tagged_object_unique_constraint.py | 12 +- superset/models/helpers.py | 20 +- superset/utils/core.py | 4 +- .../unit_tests/db_engine_specs/test_mssql.py | 2 +- 11 files changed, 162 insertions(+), 202 deletions(-) diff --git a/superset/commands/importers/v1/utils.py b/superset/commands/importers/v1/utils.py index 26442342fd45..063f62b17319 100644 --- a/superset/commands/importers/v1/utils.py +++ b/superset/commands/importers/v1/utils.py @@ -361,7 +361,7 @@ def safe_insert_dashboard_chart_relationships( # Get existing relationships only for dashboards being updated dashboard_ids = {dashboard_id for dashboard_id, _ in dashboard_chart_ids} existing_relationships = db.session.execute( - select([dashboard_slices.c.dashboard_id, dashboard_slices.c.slice_id]).where( + select(dashboard_slices.c.dashboard_id, dashboard_slices.c.slice_id).where( dashboard_slices.c.dashboard_id.in_(dashboard_ids) ) ).fetchall() diff --git a/superset/common/tags.py b/superset/common/tags.py index ce5c5ab19574..b4898249b0e8 100644 --- a/superset/common/tags.py +++ b/superset/common/tags.py @@ -32,11 +32,9 @@ def add_types_to_charts( charts = ( select( - [ - tag.c.id.label("tag_id"), - slices.c.id.label("object_id"), - literal(ObjectType.chart.name).label("object_type"), - ] + tag.c.id.label("tag_id"), + slices.c.id.label("object_id"), + literal(ObjectType.chart.name).label("object_type"), ) .select_from( join( @@ -64,11 +62,9 @@ def add_types_to_dashboards( dashboards = ( select( - [ - tag.c.id.label("tag_id"), - dashboard_table.c.id.label("object_id"), - literal(ObjectType.dashboard.name).label("object_type"), - ] + tag.c.id.label("tag_id"), + dashboard_table.c.id.label("object_id"), + literal(ObjectType.dashboard.name).label("object_type"), ) .select_from( join( @@ -96,11 +92,9 @@ def add_types_to_saved_queries( saved_queries = ( select( - [ - tag.c.id.label("tag_id"), - saved_query.c.id.label("object_id"), - literal(ObjectType.query.name).label("object_type"), - ] + tag.c.id.label("tag_id"), + saved_query.c.id.label("object_id"), + literal(ObjectType.query.name).label("object_type"), ) .select_from( join( @@ -128,11 +122,9 @@ def add_types_to_datasets( datasets = ( select( - [ - tag.c.id.label("tag_id"), - tables.c.id.label("object_id"), - literal(ObjectType.dataset.name).label("object_type"), - ] + tag.c.id.label("tag_id"), + tables.c.id.label("object_id"), + literal(ObjectType.dataset.name).label("object_type"), ) .select_from( join( @@ -238,11 +230,9 @@ def add_owners_to_charts( charts = ( select( - [ - tag.c.id.label("tag_id"), - slices.c.id.label("object_id"), - literal(ObjectType.chart.name).label("object_type"), - ] + tag.c.id.label("tag_id"), + slices.c.id.label("object_id"), + literal(ObjectType.chart.name).label("object_type"), ) .select_from( join( @@ -274,11 +264,9 @@ def add_owners_to_dashboards( dashboards = ( select( - [ - tag.c.id.label("tag_id"), - dashboard_table.c.id.label("object_id"), - literal(ObjectType.dashboard.name).label("object_type"), - ] + tag.c.id.label("tag_id"), + dashboard_table.c.id.label("object_id"), + literal(ObjectType.dashboard.name).label("object_type"), ) .select_from( join( @@ -310,11 +298,9 @@ def add_owners_to_saved_queries( saved_queries = ( select( - [ - tag.c.id.label("tag_id"), - saved_query.c.id.label("object_id"), - literal(ObjectType.query.name).label("object_type"), - ] + tag.c.id.label("tag_id"), + saved_query.c.id.label("object_id"), + literal(ObjectType.query.name).label("object_type"), ) .select_from( join( @@ -346,11 +332,9 @@ def add_owners_to_datasets( datasets = ( select( - [ - tag.c.id.label("tag_id"), - tables.c.id.label("object_id"), - literal(ObjectType.dataset.name).label("object_type"), - ] + tag.c.id.label("tag_id"), + tables.c.id.label("object_id"), + literal(ObjectType.dataset.name).label("object_type"), ) .select_from( join( @@ -440,7 +424,7 @@ def add_owners(metadata: MetaData) -> None: columns = ["tag_id", "object_id", "object_type"] # create a custom tag for each user - ids = select([users.c.id]) + ids = select(users.c.id) insert = tag.insert() for (id_,) in db.session.execute(ids): with contextlib.suppress(IntegrityError): # already exists @@ -478,18 +462,16 @@ def add_favorites(metadata: MetaData) -> None: columns = ["tag_id", "object_id", "object_type"] # create a custom tag for each user - ids = select([users.c.id]) + ids = select(users.c.id) insert = tag.insert() for (id_,) in db.session.execute(ids): with contextlib.suppress(IntegrityError): # already exists db.session.execute(insert, name=f"favorited_by:{id_}", type=TagType.type) favstars = ( select( - [ - tag.c.id.label("tag_id"), - favstar.c.obj_id.label("object_id"), - func.lower(favstar.c.class_name).label("object_type"), - ] + tag.c.id.label("tag_id"), + favstar.c.obj_id.label("object_id"), + func.lower(favstar.c.class_name).label("object_type"), ) .select_from( join( diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index e6796923980f..031142ad6f44 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -1695,11 +1695,9 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals # for those we fall back to LIMIT 1. tbl, _unused_cte = self.get_from_clause(template_processor) if self.db_engine_spec.type_probe_needs_row: - qry = sa.select([sqla_column]).limit(1).select_from(tbl) + qry = sa.select(sqla_column).limit(1).select_from(tbl) else: - qry = ( - sa.select([sqla_column]).where(sa.false()).select_from(tbl) - ) + qry = sa.select(sqla_column).where(sa.false()).select_from(tbl) sql = self.database.compile_sqla_query( qry, catalog=self.catalog, diff --git a/superset/extensions/metadb.py b/superset/extensions/metadb.py index 375b4b87ce02..78274d52e609 100644 --- a/superset/extensions/metadb.py +++ b/superset/extensions/metadb.py @@ -361,7 +361,7 @@ def _build_sql( """ Build SQLAlchemy query object. """ - query = select([self._table]) + query = select(self._table) for column_name, filter_ in bounds.items(): column = self._table.c[column_name] @@ -445,7 +445,7 @@ def insert_row(self, row: Row) -> int: if self._rowid: return result.inserted_primary_key[0] - query = select([func.count()]).select_from(self._table) + query = select(func.count()).select_from(self._table) return connection.execute(query).scalar() @check_dml diff --git a/superset/migrations/versions/2020-01-08_01-17_e96dbf2cfef0_datasource_cluster_fk.py b/superset/migrations/versions/2020-01-08_01-17_e96dbf2cfef0_datasource_cluster_fk.py index 9682ee6c588d..bd3e89957c6c 100644 --- a/superset/migrations/versions/2020-01-08_01-17_e96dbf2cfef0_datasource_cluster_fk.py +++ b/superset/migrations/versions/2020-01-08_01-17_e96dbf2cfef0_datasource_cluster_fk.py @@ -49,7 +49,7 @@ def upgrade(): clusters = sa.Table("clusters", metadata, autoload=True) statement = datasources.update().values( - cluster_id=sa.select([clusters.c.id]) + cluster_id=sa.select(clusters.c.id) .where(datasources.c.cluster_name == clusters.c.cluster_name) .as_scalar() ) @@ -91,7 +91,7 @@ def downgrade(): clusters = sa.Table("clusters", metadata, autoload=True) statement = datasources.update().values( - cluster_name=sa.select([clusters.c.cluster_name]) + cluster_name=sa.select(clusters.c.cluster_name) .where(datasources.c.cluster_id == clusters.c.id) .as_scalar() ) diff --git a/superset/migrations/versions/2020-04-24_10-46_e557699a813e_add_tables_relation_to_row_level_.py b/superset/migrations/versions/2020-04-24_10-46_e557699a813e_add_tables_relation_to_row_level_.py index 1ca6edd7f328..6dae3192970f 100644 --- a/superset/migrations/versions/2020-04-24_10-46_e557699a813e_add_tables_relation_to_row_level_.py +++ b/superset/migrations/versions/2020-04-24_10-46_e557699a813e_add_tables_relation_to_row_level_.py @@ -49,7 +49,7 @@ def upgrade(): ) rlsf = sa.Table("row_level_security_filters", metadata, autoload=True) - filter_ids = sa.select([rlsf.c.id, rlsf.c.table_id]) + filter_ids = sa.select(rlsf.c.id, rlsf.c.table_id) for row in bind.execute(filter_ids): move_table_id = rls_filter_tables.insert().values( @@ -85,7 +85,7 @@ def downgrade(): rls_filter_tables = sa.Table("rls_filter_tables", metadata, autoload=True) rls_filter_roles = sa.Table("rls_filter_roles", metadata, autoload=True) - filter_tables = sa.select([rls_filter_tables.c.rls_filter_id]).group_by( + filter_tables = sa.select(rls_filter_tables.c.rls_filter_id).group_by( rls_filter_tables.c.rls_filter_id ) @@ -95,7 +95,7 @@ def downgrade(): filter_params = dict(bind.execute(filter_query).fetchone()) origin_id = filter_params.pop("id", None) table_ids = bind.execute( - sa.select([rls_filter_tables.c.table_id]).where( + sa.select(rls_filter_tables.c.table_id).where( rls_filter_tables.c.rls_filter_id == row["rls_filter_id"] ) ).fetchall() diff --git a/superset/migrations/versions/2022-04-01_14-38_a9422eeaae74_new_dataset_models_take_2.py b/superset/migrations/versions/2022-04-01_14-38_a9422eeaae74_new_dataset_models_take_2.py index efe5e20e6a0d..1af34b20beb3 100644 --- a/superset/migrations/versions/2022-04-01_14-38_a9422eeaae74_new_dataset_models_take_2.py +++ b/superset/migrations/versions/2022-04-01_14-38_a9422eeaae74_new_dataset_models_take_2.py @@ -336,23 +336,21 @@ def copy_tables(session: Session) -> None: insert_from_select( NewTable, select( - [ - # Tables need different uuid than datasets, since they are different - # entities. When INSERT FROM SELECT, we must provide a value for `uuid`, - # otherwise it'd use the default generated on Python side, which - # will cause duplicate values. They will be replaced by `assign_uuids` later. # noqa: E501 - SqlaTable.uuid, - SqlaTable.id.label("sqlatable_id"), - SqlaTable.created_on, - SqlaTable.changed_on, - SqlaTable.created_by_fk, - SqlaTable.changed_by_fk, - SqlaTable.table_name.label("name"), - SqlaTable.schema, - SqlaTable.database_id, - SqlaTable.is_managed_externally, - SqlaTable.external_url, - ] + # Tables need different uuid than datasets, since they are different + # entities. When INSERT FROM SELECT, we must provide a value for `uuid`, + # otherwise it'd use the default generated on Python side, which + # will cause duplicate values. They will be replaced by `assign_uuids` later. # noqa: E501 + SqlaTable.uuid, + SqlaTable.id.label("sqlatable_id"), + SqlaTable.created_on, + SqlaTable.changed_on, + SqlaTable.created_by_fk, + SqlaTable.changed_by_fk, + SqlaTable.table_name.label("name"), + SqlaTable.schema, + SqlaTable.database_id, + SqlaTable.is_managed_externally, + SqlaTable.external_url, ) # use an inner join to filter out only tables with valid database ids .select_from(sa.join(SqlaTable, Database, SqlaTable.database_id == Database.id)) @@ -369,20 +367,18 @@ def copy_datasets(session: Session) -> None: insert_from_select( NewDataset, select( - [ - SqlaTable.uuid, - SqlaTable.created_on, - SqlaTable.changed_on, - SqlaTable.created_by_fk, - SqlaTable.changed_by_fk, - SqlaTable.database_id, - SqlaTable.table_name.label("name"), - func.coalesce(SqlaTable.sql, SqlaTable.table_name).label("expression"), - is_physical_table.label("is_physical"), - SqlaTable.is_managed_externally, - SqlaTable.external_url, - SqlaTable.extra.label("extra_json"), - ] + SqlaTable.uuid, + SqlaTable.created_on, + SqlaTable.changed_on, + SqlaTable.created_by_fk, + SqlaTable.changed_by_fk, + SqlaTable.database_id, + SqlaTable.table_name.label("name"), + func.coalesce(SqlaTable.sql, SqlaTable.table_name).label("expression"), + is_physical_table.label("is_physical"), + SqlaTable.is_managed_externally, + SqlaTable.external_url, + SqlaTable.extra.label("extra_json"), ), ) @@ -390,7 +386,7 @@ def copy_datasets(session: Session) -> None: insert_from_select( dataset_user_association_table, select( - [NewDataset.id.label("dataset_id"), sqlatable_user_table.c.user_id] + NewDataset.id.label("dataset_id"), sqlatable_user_table.c.user_id ).select_from( sqlatable_user_table.join( SqlaTable, SqlaTable.id == sqlatable_user_table.c.table_id @@ -402,10 +398,8 @@ def copy_datasets(session: Session) -> None: insert_from_select( dataset_table_association_table, select( - [ - NewDataset.id.label("dataset_id"), - NewTable.id.label("table_id"), - ] + NewDataset.id.label("dataset_id"), + NewTable.id.label("table_id"), ).select_from( sa.join(SqlaTable, NewTable, NewTable.sqlatable_id == SqlaTable.id).join( NewDataset, NewDataset.uuid == SqlaTable.uuid @@ -423,25 +417,23 @@ def copy_columns(session: Session) -> None: insert_from_select( NewColumn, select( - [ - TableColumn.uuid, - TableColumn.created_on, - TableColumn.changed_on, - TableColumn.created_by_fk, - TableColumn.changed_by_fk, - TableColumn.groupby.label("is_dimensional"), - TableColumn.filterable.label("is_filterable"), - TableColumn.column_name.label("name"), - TableColumn.description, - func.coalesce(TableColumn.expression, TableColumn.column_name).label( - "expression" - ), - sa.literal(False).label("is_aggregation"), - is_physical_column.label("is_physical"), - func.coalesce(TableColumn.is_dttm, False).label("is_temporal"), - func.coalesce(TableColumn.type, UNKNOWN_TYPE).label("type"), - TableColumn.extra.label("extra_json"), - ] + TableColumn.uuid, + TableColumn.created_on, + TableColumn.changed_on, + TableColumn.created_by_fk, + TableColumn.changed_by_fk, + TableColumn.groupby.label("is_dimensional"), + TableColumn.filterable.label("is_filterable"), + TableColumn.column_name.label("name"), + TableColumn.description, + func.coalesce(TableColumn.expression, TableColumn.column_name).label( + "expression" + ), + sa.literal(False).label("is_aggregation"), + is_physical_column.label("is_physical"), + func.coalesce(TableColumn.is_dttm, False).label("is_temporal"), + func.coalesce(TableColumn.type, UNKNOWN_TYPE).label("type"), + TableColumn.extra.label("extra_json"), ).select_from(active_table_columns), ) @@ -452,10 +444,8 @@ def copy_columns(session: Session) -> None: insert_from_select( dataset_column_association_table, select( - [ - NewDataset.id.label("dataset_id"), - NewColumn.id.label("column_id"), - ], + NewDataset.id.label("dataset_id"), + NewColumn.id.label("column_id"), ).select_from( joined_columns_table.join(NewDataset, NewDataset.uuid == SqlaTable.uuid) ), @@ -472,33 +462,31 @@ def copy_metrics(session: Session) -> None: insert_from_select( NewColumn, select( - [ - SqlMetric.uuid, - SqlMetric.created_on, - SqlMetric.changed_on, - SqlMetric.created_by_fk, - SqlMetric.changed_by_fk, - SqlMetric.metric_name.label("name"), - SqlMetric.expression, - SqlMetric.description, - sa.literal(UNKNOWN_TYPE).label("type"), - ( - func.coalesce( - sa.func.lower(SqlMetric.metric_type).in_( - ADDITIVE_METRIC_TYPES_LOWER - ), - sa.literal(False), - ).label("is_additive") - ), - sa.literal(True).label("is_aggregation"), - # metrics are by default not filterable - sa.literal(False).label("is_filterable"), - sa.literal(False).label("is_dimensional"), - sa.literal(False).label("is_physical"), - sa.literal(False).label("is_temporal"), - SqlMetric.extra.label("extra_json"), - SqlMetric.warning_text, - ] + SqlMetric.uuid, + SqlMetric.created_on, + SqlMetric.changed_on, + SqlMetric.created_by_fk, + SqlMetric.changed_by_fk, + SqlMetric.metric_name.label("name"), + SqlMetric.expression, + SqlMetric.description, + sa.literal(UNKNOWN_TYPE).label("type"), + ( + func.coalesce( + sa.func.lower(SqlMetric.metric_type).in_( + ADDITIVE_METRIC_TYPES_LOWER + ), + sa.literal(False), + ).label("is_additive") + ), + sa.literal(True).label("is_aggregation"), + # metrics are by default not filterable + sa.literal(False).label("is_filterable"), + sa.literal(False).label("is_dimensional"), + sa.literal(False).label("is_physical"), + sa.literal(False).label("is_temporal"), + SqlMetric.extra.label("extra_json"), + SqlMetric.warning_text, ).select_from(active_metrics), ) @@ -506,10 +494,8 @@ def copy_metrics(session: Session) -> None: insert_from_select( dataset_column_association_table, select( - [ - NewDataset.id.label("dataset_id"), - NewColumn.id.label("column_id"), - ], + NewDataset.id.label("dataset_id"), + NewColumn.id.label("column_id"), ).select_from( active_metrics.join(NewDataset, NewDataset.uuid == SqlaTable.uuid).join( NewColumn, NewColumn.uuid == SqlMetric.uuid @@ -568,15 +554,13 @@ def print_update_count(): sqlalchemy_uri, ) in session.execute( select( - [ - NewDataset.database_id, - NewDataset.id.label("dataset_id"), - NewDataset.expression, - SqlaTable.extra, - NewDataset.is_physical, - SqlaTable.schema, - Database.sqlalchemy_uri, - ] + NewDataset.database_id, + NewDataset.id.label("dataset_id"), + NewDataset.expression, + SqlaTable.extra, + NewDataset.is_physical, + SqlaTable.schema, + Database.sqlalchemy_uri, ) .select_from(joined_tables) .offset(offset) @@ -725,29 +709,27 @@ def print_update_count(): query = ( select( # sorted alphabetically - [ - NewColumn.id.label("column_id"), - TableColumn.column_name, - NewColumn.changed_by_fk, - NewColumn.changed_on, - NewColumn.created_on, - NewColumn.description, - SqlMetric.d3format, - NewDataset.external_url, - NewColumn.extra_json, - NewColumn.is_dimensional, - NewColumn.is_filterable, - NewDataset.is_managed_externally, - NewColumn.is_physical, - SqlMetric.metric_type, - TableColumn.python_date_format, - Database.sqlalchemy_uri, - dataset_table_association_table.c.table_id, - func.coalesce( - TableColumn.verbose_name, SqlMetric.verbose_name - ).label("verbose_name"), - NewColumn.warning_text, - ] + NewColumn.id.label("column_id"), + TableColumn.column_name, + NewColumn.changed_by_fk, + NewColumn.changed_on, + NewColumn.created_on, + NewColumn.description, + SqlMetric.d3format, + NewDataset.external_url, + NewColumn.extra_json, + NewColumn.is_dimensional, + NewColumn.is_filterable, + NewDataset.is_managed_externally, + NewColumn.is_physical, + SqlMetric.metric_type, + TableColumn.python_date_format, + Database.sqlalchemy_uri, + dataset_table_association_table.c.table_id, + func.coalesce(TableColumn.verbose_name, SqlMetric.verbose_name).label( + "verbose_name" + ), + NewColumn.warning_text, ) .select_from(get_joined_tables(offset, limit)) .where( @@ -872,7 +854,7 @@ def print_update_count(): print(" Assign table column relations...") insert_from_select( table_column_association_table, - select([NewColumn.table_id, NewColumn.id.label("column_id")]) + select(NewColumn.table_id, NewColumn.id.label("column_id")) .select_from(NewColumn) .where(and_(NewColumn.is_physical, NewColumn.table_id.isnot(None))), ) diff --git a/superset/migrations/versions/2024-01-17_13-09_96164e3017c6_tagged_object_unique_constraint.py b/superset/migrations/versions/2024-01-17_13-09_96164e3017c6_tagged_object_unique_constraint.py index fd3d883c99eb..05716d92d4e0 100644 --- a/superset/migrations/versions/2024-01-17_13-09_96164e3017c6_tagged_object_unique_constraint.py +++ b/superset/migrations/versions/2024-01-17_13-09_96164e3017c6_tagged_object_unique_constraint.py @@ -59,12 +59,10 @@ def upgrade(): # Delete duplicates if any min_id_subquery = ( select( - [ - func.min(tagged_object_table.c.id).label("min_id"), - tagged_object_table.c.tag_id, - tagged_object_table.c.object_id, - tagged_object_table.c.object_type, - ] + func.min(tagged_object_table.c.id).label("min_id"), + tagged_object_table.c.tag_id, + tagged_object_table.c.object_id, + tagged_object_table.c.object_type, ) .group_by( tagged_object_table.c.tag_id, @@ -75,7 +73,7 @@ def upgrade(): ) delete_query = tagged_object_table.delete().where( - tagged_object_table.c.id.notin_(select([min_id_subquery.c.min_id])) + tagged_object_table.c.id.notin_(select(min_id_subquery.c.min_id)) ) bind.execute(delete_query) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index c48cba355aad..c08230dd7b77 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -2456,7 +2456,7 @@ def values_for_column( # pylint: disable=too-many-locals # automatically add a random alias to the projection because of the # call to DISTINCT; others will uppercase the column names. This # gives us a deterministic column name in the dataframe. - [target_col.get_sqla_col(template_processor=tp).label("column_values")] + target_col.get_sqla_col(template_processor=tp).label("column_values") ) .select_from(tbl) .distinct() @@ -2548,15 +2548,15 @@ def _build_validation_query( ) -> Select: """Build validation query based on expression type. Raises on error.""" if expression_type == SqlExpressionType.COLUMN: - return sa.select([sa.literal_column(expression).label("test_col")]) + return sa.select(sa.literal_column(expression).label("test_col")) elif expression_type == SqlExpressionType.METRIC: - return sa.select([sa.literal_column(expression).label("test_metric")]) + return sa.select(sa.literal_column(expression).label("test_metric")) elif expression_type == SqlExpressionType.WHERE: - return sa.select([sa.literal(1)]).where(sa.text(expression)) + return sa.select(sa.literal(1)).where(sa.text(expression)) elif expression_type == SqlExpressionType.HAVING: dummy_col = sa.literal("A").label("dummy") return ( - sa.select([dummy_col]) + sa.select(dummy_col) .group_by(sa.text("dummy")) .having(sa.text(expression)) ) @@ -2988,7 +2988,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma if not db_engine_spec.allows_hidden_orderby_agg: select_exprs = remove_duplicates(select_exprs + orderby_exprs) - qry = sa.select(select_exprs) + qry = sa.select(*select_exprs) if groupby_all_columns: qry = qry.group_by(*groupby_all_columns.values()) @@ -3312,7 +3312,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma inner_select_exprs.append(inner) inner_select_exprs += [inner_main_metric_expr] - subq = sa.select(inner_select_exprs).select_from(tbl) + subq = sa.select(*inner_select_exprs).select_from(tbl) inner_time_filter = [] if dttm_col and not db_engine_spec.time_groupby_inline: @@ -3377,7 +3377,7 @@ def _create_join_condition(col_name: str, expr: Any) -> Any: ) # Reconstruct query with modified expressions - qry = sa.select(select_exprs) + qry = sa.select(*select_exprs) if groupby_all_columns: qry = qry.group_by(*groupby_all_columns.values()) @@ -3451,7 +3451,7 @@ def _create_top_groups_condition(col_name: str, expr: Any) -> Any: ) # Reconstruct query with modified expressions - qry = sa.select(select_exprs) + qry = sa.select(*select_exprs) if groupby_all_columns: qry = qry.group_by(*groupby_all_columns.values()) @@ -3478,7 +3478,7 @@ def _create_top_groups_condition(col_name: str, expr: Any) -> Any: ) label = "rowcount" col = self.make_sqla_column_compatible(literal_column("COUNT(*)"), label) - qry = sa.select([col]).select_from(qry.alias("rowcount_qry")) + qry = sa.select(col).select_from(qry.alias("rowcount_qry")) labels_expected = [label] filter_columns = [flt.get("col") for flt in filter] if filter else [] diff --git a/superset/utils/core.py b/superset/utils/core.py index 23a3017bf2cb..0c7f5df6d33b 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -765,7 +765,7 @@ def ping_connection(connection: Connection, branch: bool) -> None: # run a SELECT 1. use a core select() so that # the SELECT of a scalar value without a table is # appropriately formatted for the backend - connection.scalar(select([1])) + connection.scalar(select(1)) except exc.DBAPIError as err: # catch SQLAlchemy's DBAPIError, which is a wrapper # for the DBAPI's exception. It includes a .connection_invalidated @@ -777,7 +777,7 @@ def ping_connection(connection: Connection, branch: bool) -> None: # itself and establish a new connection. The disconnect detection # here also causes the whole connection pool to be invalidated # so that all stale connections are discarded. - connection.scalar(select([1])) + connection.scalar(select(1)) else: raise finally: diff --git a/tests/unit_tests/db_engine_specs/test_mssql.py b/tests/unit_tests/db_engine_specs/test_mssql.py index e0ce5e1180c8..6e1decf63a0d 100644 --- a/tests/unit_tests/db_engine_specs/test_mssql.py +++ b/tests/unit_tests/db_engine_specs/test_mssql.py @@ -81,7 +81,7 @@ def test_where_clause_n_prefix() -> None: tbl = table("tbl") sel = ( - select([str_col, unicode_col]) + select(str_col, unicode_col) .select_from(tbl) .where(str_col == "abc") .where(unicode_col == "abc")