diff --git a/superset/__init__.py b/superset/__init__.py index 91fcbf5ad8fa..e3ab53c46b33 100644 --- a/superset/__init__.py +++ b/superset/__init__.py @@ -14,10 +14,22 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import flask_appbuilder from werkzeug.local import LocalProxy -from superset.app import create_app # noqa: F401 -from superset.extensions import ( +# SQLAlchemy 2.0 enables "Annotated Declarative" mapping, which inspects class +# attribute type annotations and requires mapped attributes to use ``Mapped[...]``. +# Superset's models (and Flask-AppBuilder mixins) still carry legacy 1.x style +# annotations that are not wrapped in ``Mapped[...]``. Setting ``__allow_unmapped__`` +# on the shared declarative base preserves the legacy behavior so those annotations +# are ignored by the ORM. This must run before any model class is defined (i.e. +# before importing ``superset.app``), since the annotation check happens at class +# creation time. Models can be migrated incrementally to the typed ``Mapped[...]`` +# form. +flask_appbuilder.Model.__allow_unmapped__ = True + +from superset.app import create_app # noqa: E402, F401 +from superset.extensions import ( # noqa: E402 appbuilder, # noqa: F401 cache_manager, db, # noqa: F401 @@ -28,7 +40,7 @@ security_manager, # noqa: F401 talisman, # noqa: F401 ) -from superset.security import SupersetSecurityManager # noqa: F401 +from superset.security import SupersetSecurityManager # noqa: E402, F401 # All of the fields located here should be considered legacy. The correct way to # declare "global" dependencies is to define it in extensions.py, diff --git a/superset/commands/dataset/importers/v1/utils.py b/superset/commands/dataset/importers/v1/utils.py index 652111ec892c..f3d751193f6e 100644 --- a/superset/commands/dataset/importers/v1/utils.py +++ b/superset/commands/dataset/importers/v1/utils.py @@ -28,7 +28,7 @@ from pandas.errors import OutOfBoundsDatetime from sqlalchemy import BigInteger, Boolean, Date, DateTime, Float, String, Text from sqlalchemy.exc import MultipleResultsFound -from sqlalchemy.sql.visitors import VisitableType +from sqlalchemy.types import TypeEngine from superset import db, security_manager from superset.commands.dataset.exceptions import ( @@ -94,7 +94,7 @@ def redirect_request( } -def get_sqla_type(native_type: str) -> VisitableType: +def get_sqla_type(native_type: str) -> TypeEngine: if native_type.upper() in type_map: return type_map[native_type.upper()] @@ -107,7 +107,7 @@ def get_sqla_type(native_type: str) -> VisitableType: ) -def get_dtype(df: pd.DataFrame, dataset: SqlaTable) -> dict[str, VisitableType]: +def get_dtype(df: pd.DataFrame, dataset: SqlaTable) -> dict[str, TypeEngine]: return { column.column_name: get_sqla_type(column.type) for column in dataset.columns diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 81b5c2fc8d9d..89174704bd80 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -335,7 +335,7 @@ def is_virtual(self) -> bool: return self.kind == DatasourceKind.VIRTUAL @declared_attr - def slices(self) -> RelationshipProperty: + def slices(self) -> Mapped[list["Slice"]]: return relationship( "Slice", overlaps="table", diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index 738bfb22984c..8cc6df311316 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -22,6 +22,7 @@ import sys from typing import Any, Callable, TYPE_CHECKING +import sqlalchemy as sa import wtforms_json from colorama import Fore, Style from deprecation import deprecated @@ -808,7 +809,8 @@ def check_and_warn_database_connection(self) -> None: try: with self.superset_app.app_context(): # Simple connection test - db.engine.execute("SELECT 1") + with db.engine.connect() as connection: + connection.execute(sa.text("SELECT 1")) except Exception: db_uri = self.database_uri safe_uri = make_url_safe(db_uri) if db_uri else "Not configured" diff --git a/superset/models/helpers.py b/superset/models/helpers.py index b603f0bb3cea..9bc22904a51f 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -60,7 +60,7 @@ from sqlalchemy.exc import MultipleResultsFound from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import Mapper, Session, validates, with_loader_criteria +from sqlalchemy.orm import Mapped, Mapper, Session, validates, with_loader_criteria from sqlalchemy.orm.session import ORMExecuteState from sqlalchemy.sql.elements import ColumnElement, Grouping, literal_column, TextClause from sqlalchemy.sql.expression import Label, Select, TextAsFrom @@ -583,7 +583,7 @@ class AuditMixinNullable(AuditMixin): ) @declared_attr - def created_by_fk(self) -> sa.Column: # pylint: disable=arguments-renamed + def created_by_fk(self) -> Mapped[Optional[int]]: # pylint: disable=arguments-renamed return sa.Column( sa.Integer, sa.ForeignKey("ab_user.id"), @@ -592,7 +592,7 @@ def created_by_fk(self) -> sa.Column: # pylint: disable=arguments-renamed ) @declared_attr - def changed_by_fk(self) -> sa.Column: # pylint: disable=arguments-renamed + def changed_by_fk(self) -> Mapped[Optional[int]]: # pylint: disable=arguments-renamed return sa.Column( sa.Integer, sa.ForeignKey("ab_user.id"), diff --git a/superset/queries/filters.py b/superset/queries/filters.py index 1890e38c2a5e..294261d16cf4 100644 --- a/superset/queries/filters.py +++ b/superset/queries/filters.py @@ -16,7 +16,12 @@ # under the License. from typing import Any -from flask_sqlalchemy import BaseQuery +try: + # Flask-SQLAlchemy 3.x (required by SQLAlchemy 2.0) + from flask_sqlalchemy.query import Query as BaseQuery +except ImportError: # pragma: no cover + # Flask-SQLAlchemy 2.x + from flask_sqlalchemy import BaseQuery from superset import security_manager from superset.models.sql_lab import Query diff --git a/superset/queries/saved_queries/filters.py b/superset/queries/saved_queries/filters.py index 821f42d6f112..cb721f934e38 100644 --- a/superset/queries/saved_queries/filters.py +++ b/superset/queries/saved_queries/filters.py @@ -18,10 +18,16 @@ from flask import g from flask_babel import lazy_gettext as _ -from flask_sqlalchemy import BaseQuery from sqlalchemy import or_ from sqlalchemy.orm.query import Query +try: + # Flask-SQLAlchemy 3.x (required by SQLAlchemy 2.0) + from flask_sqlalchemy.query import Query as BaseQuery +except ImportError: # pragma: no cover + # Flask-SQLAlchemy 2.x + from flask_sqlalchemy import BaseQuery + from superset.models.sql_lab import SavedQuery from superset.tags.filters import BaseTagIdFilter, BaseTagNameFilter from superset.views.base import BaseFilter diff --git a/superset/security/manager.py b/superset/security/manager.py index c8a44a6bffb3..73a42b0eecc5 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -54,7 +54,7 @@ from jwt.api_jwt import _jwt_global_obj from sqlalchemy import and_, func as sa_func, inspect, or_ from sqlalchemy.engine.base import Connection -from sqlalchemy.orm import eagerload, joinedload +from sqlalchemy.orm import joinedload from sqlalchemy.orm.exc import MultipleResultsFound from sqlalchemy.orm.mapper import Mapper from sqlalchemy.orm.query import Query as SqlaQuery @@ -1800,8 +1800,8 @@ def _get_all_pvms(self) -> list[PermissionView]: pvms = ( self.session.query(self.permissionview_model) .options( - eagerload(self.permissionview_model.permission), - eagerload(self.permissionview_model.view_menu), + joinedload(self.permissionview_model.permission), + joinedload(self.permissionview_model.view_menu), ) .all() ) diff --git a/superset/utils/encrypt.py b/superset/utils/encrypt.py index b788566228a1..217d60b8cb85 100644 --- a/superset/utils/encrypt.py +++ b/superset/utils/encrypt.py @@ -302,7 +302,7 @@ def _select_columns_from_table( table_name: str, ) -> Row: cols = ",".join(pk_columns + column_names) - return conn.execute(f"SELECT {cols} FROM {table_name}") # noqa: S608 + return conn.execute(text(f"SELECT {cols} FROM {table_name}")) # noqa: S608 def _target_type(self, encrypted_type: EncryptedType) -> EncryptedType: """The EncryptedType to re-encrypt a value *into*. @@ -430,7 +430,7 @@ def _re_encrypt_row( re_encrypted_columns = {} for column_name, encrypted_type in columns.items(): - raw_value = self._read_bytes(column_name, row[column_name]) + raw_value = self._read_bytes(column_name, row._mapping[column_name]) # NULL values aren't encrypted; there is nothing to migrate. if raw_value is None: @@ -508,13 +508,12 @@ def _re_encrypt_row( set_cols = ",".join(f"{name} = :{name}" for name in re_encrypted_columns) where_clause = " AND ".join(f"{pk} = :_pk_{pk}" for pk in pk_columns) - pk_bind = {f"_pk_{pk}": row[pk] for pk in pk_columns} + pk_bind = {f"_pk_{pk}": row._mapping[pk] for pk in pk_columns} conn.execute( text( f"UPDATE {table_name} SET {set_cols} WHERE {where_clause}" # noqa: S608 ), - **pk_bind, - **re_encrypted_columns, + {**pk_bind, **re_encrypted_columns}, ) def run(self) -> ReEncryptStats: diff --git a/superset/utils/mock_data.py b/superset/utils/mock_data.py index 58647dffa3e3..c8b4d1470e15 100644 --- a/superset/utils/mock_data.py +++ b/superset/utils/mock_data.py @@ -31,7 +31,7 @@ from sqlalchemy import Column, inspect, MetaData, Table as DBTable from sqlalchemy.dialects import postgresql from sqlalchemy.sql import func -from sqlalchemy.sql.visitors import VisitableType +from sqlalchemy.types import TypeEngine from superset import db from superset.sql.parse import Table @@ -42,7 +42,7 @@ class ColumnInfo(TypedDict): name: str - type: VisitableType + type: TypeEngine nullable: bool default: Optional[Any] autoincrement: str diff --git a/tests/integration_tests/reports/commands_tests.py b/tests/integration_tests/reports/commands_tests.py index ae1c855d85ae..3cc771a65de7 100644 --- a/tests/integration_tests/reports/commands_tests.py +++ b/tests/integration_tests/reports/commands_tests.py @@ -23,7 +23,6 @@ import pytest from flask.ctx import AppContext from flask_appbuilder.security.sqla.models import User -from flask_sqlalchemy import BaseQuery from freezegun import freeze_time from slack_sdk.errors import ( BotUserAccessError, @@ -37,6 +36,13 @@ ) from sqlalchemy.sql import func +try: + # Flask-SQLAlchemy 3.x (required by SQLAlchemy 2.0) + from flask_sqlalchemy.query import Query as BaseQuery +except ImportError: # pragma: no cover + # Flask-SQLAlchemy 2.x + from flask_sqlalchemy import BaseQuery + from superset import db from superset.commands.report.exceptions import ( AlertQueryError, diff --git a/tests/integration_tests/utils/encrypt_tests.py b/tests/integration_tests/utils/encrypt_tests.py index 5c88d43ecef7..2f994b86f6b3 100644 --- a/tests/integration_tests/utils/encrypt_tests.py +++ b/tests/integration_tests/utils/encrypt_tests.py @@ -31,6 +31,28 @@ from tests.integration_tests.base_tests import SupersetTestCase +def make_row(values: dict[str, Any]) -> Any: + """Build a genuine SQLAlchemy ``Row`` from a mapping. + + ``SecretsMigrator._re_encrypt_row`` consumes the ``Row`` objects yielded by + ``conn.execute(...)``, reading column values through ``row._mapping`` per the + SQLAlchemy 2.0 Row API. Tests must therefore pass a real ``Row`` rather than a + plain ``dict`` (which lacks ``_mapping``). The constructor signature differs + between SQLAlchemy 1.4 and 2.0, so both are handled here. + """ + from sqlalchemy.engine.result import SimpleResultMetaData + from sqlalchemy.engine.row import Row + + metadata = SimpleResultMetaData(tuple(values)) + data = tuple(values.values()) + try: + # SQLAlchemy 2.0: Row(parent, processors, key_to_index, data) + return Row(metadata, None, metadata._key_to_index, data) + except AttributeError: + # SQLAlchemy 1.4: Row(parent, processors, keymap, key_style, data) + return Row(metadata, None, metadata._keymap, Row._default_key_style, data) + + class CustomEncFieldAdapter(AbstractEncryptedFieldAdapter): def create( self, @@ -224,7 +246,8 @@ def test_re_encrypt_row_uses_pk_columns(self): current_field = encrypted_field_factory.create(String(1024)) conn = MagicMock() - row = {"uuid": b"\x00" * 16, "configuration": ciphertext} + pk_value = b"\x00" * 16 + row = make_row({"uuid": pk_value, "configuration": ciphertext}) stats = ReEncryptStats() migrator._re_encrypt_row( # noqa: SLF001 @@ -239,9 +262,11 @@ def test_re_encrypt_row_uses_pk_columns(self): assert conn.execute.call_count == 1 stmt = str(conn.execute.call_args.args[0]) assert "WHERE uuid = :_pk_uuid" in stmt - kwargs = conn.execute.call_args.kwargs - assert kwargs["_pk_uuid"] == row["uuid"] - assert "configuration" in kwargs + # The migrator passes bind params positionally (conn.execute(stmt, params)), + # so read them from args[1] rather than kwargs. + params = conn.execute.call_args.args[1] + assert params["_pk_uuid"] == pk_value + assert "configuration" in params assert stats == ReEncryptStats(re_encrypted=1, skipped=0, failed=0) def test_re_encrypt_row_is_idempotent(self): @@ -264,7 +289,7 @@ def test_re_encrypt_row_is_idempotent(self): assert field.process_result_value(ciphertext, dialect) == "hunter2" conn = MagicMock() - row = {"uuid": b"\x00" * 16, "configuration": ciphertext} + row = make_row({"uuid": b"\x00" * 16, "configuration": ciphertext}) stats = ReEncryptStats() migrator._re_encrypt_row( # noqa: SLF001 @@ -308,7 +333,7 @@ def test_re_encrypt_row_idempotent_when_previous_key_also_decrypts(self): ciphertext = field.process_bind_param("hunter2", dialect) conn = MagicMock() - row = {"uuid": b"\x00" * 16, "configuration": ciphertext} + row = make_row({"uuid": b"\x00" * 16, "configuration": ciphertext}) stats = ReEncryptStats() migrator._re_encrypt_row( # noqa: SLF001 @@ -342,7 +367,7 @@ def test_re_encrypt_row_counts_failures_without_raising(self): field = encrypted_field_factory.create(String(1024)) conn = MagicMock() - row = {"uuid": b"\x00" * 16, "configuration": b"not-valid-ciphertext"} + row = make_row({"uuid": b"\x00" * 16, "configuration": b"not-valid-ciphertext"}) stats = ReEncryptStats() migrator._re_encrypt_row( # noqa: SLF001 @@ -374,7 +399,7 @@ def test_re_encrypt_row_counts_nulls_separately(self): field = encrypted_field_factory.create(String(1024)) conn = MagicMock() - row = {"uuid": b"\x00" * 16, "configuration": None} + row = make_row({"uuid": b"\x00" * 16, "configuration": None}) stats = ReEncryptStats() migrator._re_encrypt_row( # noqa: SLF001 diff --git a/tests/unit_tests/utils/encrypt_test.py b/tests/unit_tests/utils/encrypt_test.py index 46ff1a20d4f2..2fc65f2cebb7 100644 --- a/tests/unit_tests/utils/encrypt_test.py +++ b/tests/unit_tests/utils/encrypt_test.py @@ -57,6 +57,18 @@ def _engine_migrator(target_engine: type) -> SecretsMigrator: return migrator +class _Row: + """Minimal stand-in for a SQLAlchemy ``Row``. + + ``_re_encrypt_row`` accesses columns via ``row._mapping[...]`` (the + SQLAlchemy 2.0-compatible idiom), so the fixtures wrap their column dicts + in an object exposing that attribute rather than passing a bare ``dict``. + """ + + def __init__(self, mapping: dict[str, object]) -> None: + self._mapping = mapping + + def test_default_engine_is_aes_cbc() -> None: """Without config, the adapter keeps the historical AES-CBC engine.""" field = SQLAlchemyUtilsAdapter().create(SECRET, String(128)) @@ -156,7 +168,7 @@ def test_engine_migration_cbc_to_gcm_re_encrypts() -> None: migrator = _engine_migrator(AesGcmEngine) conn = MagicMock() - row = {"id": 1, "password": ciphertext} + row = _Row({"id": 1, "password": ciphertext}) stats = ReEncryptStats() migrator._re_encrypt_row( # noqa: SLF001 @@ -165,7 +177,7 @@ def test_engine_migration_cbc_to_gcm_re_encrypts() -> None: assert stats == ReEncryptStats(re_encrypted=1) assert conn.execute.call_count == 1 - new_value = conn.execute.call_args.kwargs["password"] + new_value = conn.execute.call_args.args[1]["password"] # The stored value changed and now decrypts as GCM back to the plaintext. assert new_value != ciphertext gcm = _encrypted_type(AesGcmEngine) @@ -183,7 +195,7 @@ def test_engine_migration_idempotent_for_already_target() -> None: migrator = _engine_migrator(AesGcmEngine) conn = MagicMock() - row = {"id": 1, "password": gcm_value} + row = _Row({"id": 1, "password": gcm_value}) stats = ReEncryptStats() migrator._re_encrypt_row( # noqa: SLF001 @@ -206,7 +218,7 @@ def test_engine_migration_reads_cbc_after_config_already_flipped() -> None: migrator = _engine_migrator(AesGcmEngine) conn = MagicMock() - row = {"id": 1, "password": cbc_value} + row = _Row({"id": 1, "password": cbc_value}) stats = ReEncryptStats() migrator._re_encrypt_row( # noqa: SLF001 @@ -214,7 +226,7 @@ def test_engine_migration_reads_cbc_after_config_already_flipped() -> None: ) assert stats == ReEncryptStats(re_encrypted=1) - new_value = conn.execute.call_args.kwargs["password"] + new_value = conn.execute.call_args.args[1]["password"] assert gcm_column.process_result_value(new_value, DIALECT) == "hunter2" @@ -231,7 +243,7 @@ def test_engine_migration_gcm_to_cbc_rolls_back() -> None: migrator = _engine_migrator(AesEngine) conn = MagicMock() - row = {"id": 1, "password": gcm_value} + row = _Row({"id": 1, "password": gcm_value}) stats = ReEncryptStats() migrator._re_encrypt_row( # noqa: SLF001 @@ -239,7 +251,7 @@ def test_engine_migration_gcm_to_cbc_rolls_back() -> None: ) assert stats == ReEncryptStats(re_encrypted=1) - new_value = conn.execute.call_args.kwargs["password"] + new_value = conn.execute.call_args.args[1]["password"] assert new_value != gcm_value # The rolled-back value now decrypts as AES-CBC back to the plaintext. assert _encrypted_type(AesEngine).process_result_value(new_value, DIALECT) == ( @@ -272,7 +284,7 @@ def test_rollback_authenticated_probe_wins_over_spurious_cbc_skip() -> None: spurious_target.process_bind_param.return_value = b"new-cbc-ciphertext" conn = MagicMock() - row = {"id": 1, "password": gcm_value} + row = _Row({"id": 1, "password": gcm_value}) stats = ReEncryptStats() with mock.patch.object(migrator, "_target_type", return_value=spurious_target): @@ -302,7 +314,7 @@ def test_combined_key_rotation_and_engine_migration() -> None: migrator._previous_secret_key = old_key # noqa: SLF001 # rotate key too conn = MagicMock() - row = {"id": 1, "password": old_value} + row = _Row({"id": 1, "password": old_value}) stats = ReEncryptStats() migrator._re_encrypt_row( # noqa: SLF001 @@ -310,7 +322,7 @@ def test_combined_key_rotation_and_engine_migration() -> None: ) assert stats == ReEncryptStats(re_encrypted=1) - new_value = conn.execute.call_args.kwargs["password"] + new_value = conn.execute.call_args.args[1]["password"] # The migrated value decrypts as GCM under the *current* key. assert _encrypted_type(AesGcmEngine).process_result_value(new_value, DIALECT) == ( "hunter2" @@ -346,7 +358,7 @@ def test_key_rotation_for_aes_gcm_column() -> None: migrator = _key_rotation_migrator(previous_secret_key=old_key) conn = MagicMock() - row = {"id": 1, "password": old_value} + row = _Row({"id": 1, "password": old_value}) stats = ReEncryptStats() migrator._re_encrypt_row( # noqa: SLF001 @@ -354,7 +366,7 @@ def test_key_rotation_for_aes_gcm_column() -> None: ) assert stats == ReEncryptStats(re_encrypted=1) - new_value = conn.execute.call_args.kwargs["password"] + new_value = conn.execute.call_args.args[1]["password"] assert gcm_column.process_result_value(new_value, DIALECT) == "hunter2" @@ -362,7 +374,7 @@ def test_engine_migration_unreadable_value_counts_as_failure() -> None: """A value no engine/key can read is a failure, not a silent pass-through.""" migrator = _engine_migrator(AesGcmEngine) conn = MagicMock() - row = {"id": 1, "password": b"not-valid-ciphertext"} + row = _Row({"id": 1, "password": b"not-valid-ciphertext"}) stats = ReEncryptStats() migrator._re_encrypt_row( # noqa: SLF001