Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions superset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import flask_appbuilder
from werkzeug.local import LocalProxy

from superset.app import create_app # noqa: F401
from superset.extensions import (
# SQLAlchemy 2.0 enables "Annotated Declarative" mapping, which inspects class
# attribute type annotations and requires mapped attributes to use ``Mapped[...]``.
# Superset's models (and Flask-AppBuilder mixins) still carry legacy 1.x style
# annotations that are not wrapped in ``Mapped[...]``. Setting ``__allow_unmapped__``
# on the shared declarative base preserves the legacy behavior so those annotations
# are ignored by the ORM. This must run before any model class is defined (i.e.
# before importing ``superset.app``), since the annotation check happens at class
# creation time. Models can be migrated incrementally to the typed ``Mapped[...]``
# form.
flask_appbuilder.Model.__allow_unmapped__ = True

from superset.app import create_app # noqa: E402, F401
from superset.extensions import ( # noqa: E402
appbuilder, # noqa: F401
cache_manager,
db, # noqa: F401
Expand All @@ -28,7 +40,7 @@
security_manager, # noqa: F401
talisman, # noqa: F401
)
from superset.security import SupersetSecurityManager # noqa: F401
from superset.security import SupersetSecurityManager # noqa: E402, F401

# All of the fields located here should be considered legacy. The correct way to
# declare "global" dependencies is to define it in extensions.py,
Expand Down
6 changes: 3 additions & 3 deletions superset/commands/dataset/importers/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from pandas.errors import OutOfBoundsDatetime
from sqlalchemy import BigInteger, Boolean, Date, DateTime, Float, String, Text
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.sql.visitors import VisitableType
from sqlalchemy.types import TypeEngine

from superset import db, security_manager
from superset.commands.dataset.exceptions import (
Expand Down Expand Up @@ -94,7 +94,7 @@ def redirect_request(
}


def get_sqla_type(native_type: str) -> VisitableType:
def get_sqla_type(native_type: str) -> TypeEngine:
if native_type.upper() in type_map:
return type_map[native_type.upper()]

Expand All @@ -107,7 +107,7 @@ def get_sqla_type(native_type: str) -> VisitableType:
)


def get_dtype(df: pd.DataFrame, dataset: SqlaTable) -> dict[str, VisitableType]:
def get_dtype(df: pd.DataFrame, dataset: SqlaTable) -> dict[str, TypeEngine]:
return {
column.column_name: get_sqla_type(column.type)
for column in dataset.columns
Expand Down
2 changes: 1 addition & 1 deletion superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def is_virtual(self) -> bool:
return self.kind == DatasourceKind.VIRTUAL

@declared_attr
def slices(self) -> RelationshipProperty:
def slices(self) -> Mapped[list["Slice"]]:
Comment thread
rusackas marked this conversation as resolved.
return relationship(
"Slice",
overlaps="table",
Expand Down
4 changes: 3 additions & 1 deletion superset/initialization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import sys
from typing import Any, Callable, TYPE_CHECKING

import sqlalchemy as sa
import wtforms_json
from colorama import Fore, Style
from deprecation import deprecated
Expand Down Expand Up @@ -808,7 +809,8 @@ def check_and_warn_database_connection(self) -> None:
try:
with self.superset_app.app_context():
# Simple connection test
db.engine.execute("SELECT 1")
with db.engine.connect() as connection:
connection.execute(sa.text("SELECT 1"))
except Exception:
db_uri = self.database_uri
safe_uri = make_url_safe(db_uri) if db_uri else "Not configured"
Expand Down
6 changes: 3 additions & 3 deletions superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import Mapper, Session, validates, with_loader_criteria
from sqlalchemy.orm import Mapped, Mapper, Session, validates, with_loader_criteria
from sqlalchemy.orm.session import ORMExecuteState
from sqlalchemy.sql.elements import ColumnElement, Grouping, literal_column, TextClause
from sqlalchemy.sql.expression import Label, Select, TextAsFrom
Expand Down Expand Up @@ -583,7 +583,7 @@ class AuditMixinNullable(AuditMixin):
)

@declared_attr
def created_by_fk(self) -> sa.Column: # pylint: disable=arguments-renamed
def created_by_fk(self) -> Mapped[Optional[int]]: # pylint: disable=arguments-renamed
return sa.Column(
sa.Integer,
sa.ForeignKey("ab_user.id"),
Expand All @@ -592,7 +592,7 @@ def created_by_fk(self) -> sa.Column: # pylint: disable=arguments-renamed
)

@declared_attr
def changed_by_fk(self) -> sa.Column: # pylint: disable=arguments-renamed
def changed_by_fk(self) -> Mapped[Optional[int]]: # pylint: disable=arguments-renamed
return sa.Column(
sa.Integer,
sa.ForeignKey("ab_user.id"),
Expand Down
7 changes: 6 additions & 1 deletion superset/queries/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
# under the License.
from typing import Any

from flask_sqlalchemy import BaseQuery
try:
# Flask-SQLAlchemy 3.x (required by SQLAlchemy 2.0)
from flask_sqlalchemy.query import Query as BaseQuery
except ImportError: # pragma: no cover
# Flask-SQLAlchemy 2.x
from flask_sqlalchemy import BaseQuery

from superset import security_manager
from superset.models.sql_lab import Query
Expand Down
8 changes: 7 additions & 1 deletion superset/queries/saved_queries/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,16 @@

from flask import g
from flask_babel import lazy_gettext as _
from flask_sqlalchemy import BaseQuery
from sqlalchemy import or_
from sqlalchemy.orm.query import Query

try:
# Flask-SQLAlchemy 3.x (required by SQLAlchemy 2.0)
from flask_sqlalchemy.query import Query as BaseQuery
except ImportError: # pragma: no cover
# Flask-SQLAlchemy 2.x
from flask_sqlalchemy import BaseQuery

from superset.models.sql_lab import SavedQuery
from superset.tags.filters import BaseTagIdFilter, BaseTagNameFilter
from superset.views.base import BaseFilter
Expand Down
6 changes: 3 additions & 3 deletions superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from jwt.api_jwt import _jwt_global_obj
from sqlalchemy import and_, func as sa_func, inspect, or_
from sqlalchemy.engine.base import Connection
from sqlalchemy.orm import eagerload, joinedload
from sqlalchemy.orm import joinedload
from sqlalchemy.orm.exc import MultipleResultsFound
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.query import Query as SqlaQuery
Expand Down Expand Up @@ -1800,8 +1800,8 @@ def _get_all_pvms(self) -> list[PermissionView]:
pvms = (
self.session.query(self.permissionview_model)
.options(
eagerload(self.permissionview_model.permission),
eagerload(self.permissionview_model.view_menu),
joinedload(self.permissionview_model.permission),
joinedload(self.permissionview_model.view_menu),
)
.all()
)
Expand Down
9 changes: 4 additions & 5 deletions superset/utils/encrypt.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def _select_columns_from_table(
table_name: str,
) -> Row:
cols = ",".join(pk_columns + column_names)
return conn.execute(f"SELECT {cols} FROM {table_name}") # noqa: S608
return conn.execute(text(f"SELECT {cols} FROM {table_name}")) # noqa: S608

def _target_type(self, encrypted_type: EncryptedType) -> EncryptedType:
"""The EncryptedType to re-encrypt a value *into*.
Expand Down Expand Up @@ -430,7 +430,7 @@ def _re_encrypt_row(
re_encrypted_columns = {}

for column_name, encrypted_type in columns.items():
raw_value = self._read_bytes(column_name, row[column_name])
raw_value = self._read_bytes(column_name, row._mapping[column_name])

# NULL values aren't encrypted; there is nothing to migrate.
if raw_value is None:
Expand Down Expand Up @@ -508,13 +508,12 @@ def _re_encrypt_row(

set_cols = ",".join(f"{name} = :{name}" for name in re_encrypted_columns)
where_clause = " AND ".join(f"{pk} = :_pk_{pk}" for pk in pk_columns)
pk_bind = {f"_pk_{pk}": row[pk] for pk in pk_columns}
pk_bind = {f"_pk_{pk}": row._mapping[pk] for pk in pk_columns}
conn.execute(
text(
f"UPDATE {table_name} SET {set_cols} WHERE {where_clause}" # noqa: S608
),
**pk_bind,
**re_encrypted_columns,
{**pk_bind, **re_encrypted_columns},
)

def run(self) -> ReEncryptStats:
Expand Down
4 changes: 2 additions & 2 deletions superset/utils/mock_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from sqlalchemy import Column, inspect, MetaData, Table as DBTable
from sqlalchemy.dialects import postgresql
from sqlalchemy.sql import func
from sqlalchemy.sql.visitors import VisitableType
from sqlalchemy.types import TypeEngine

from superset import db
from superset.sql.parse import Table
Expand All @@ -42,7 +42,7 @@

class ColumnInfo(TypedDict):
name: str
type: VisitableType
type: TypeEngine
nullable: bool
default: Optional[Any]
autoincrement: str
Expand Down
8 changes: 7 additions & 1 deletion tests/integration_tests/reports/commands_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import pytest
from flask.ctx import AppContext
from flask_appbuilder.security.sqla.models import User
from flask_sqlalchemy import BaseQuery
from freezegun import freeze_time
from slack_sdk.errors import (
BotUserAccessError,
Expand All @@ -37,6 +36,13 @@
)
from sqlalchemy.sql import func

try:
# Flask-SQLAlchemy 3.x (required by SQLAlchemy 2.0)
from flask_sqlalchemy.query import Query as BaseQuery
except ImportError: # pragma: no cover
# Flask-SQLAlchemy 2.x
from flask_sqlalchemy import BaseQuery
Comment thread
rusackas marked this conversation as resolved.

from superset import db
from superset.commands.report.exceptions import (
AlertQueryError,
Expand Down
41 changes: 33 additions & 8 deletions tests/integration_tests/utils/encrypt_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,28 @@
from tests.integration_tests.base_tests import SupersetTestCase


def make_row(values: dict[str, Any]) -> Any:
"""Build a genuine SQLAlchemy ``Row`` from a mapping.

``SecretsMigrator._re_encrypt_row`` consumes the ``Row`` objects yielded by
``conn.execute(...)``, reading column values through ``row._mapping`` per the
SQLAlchemy 2.0 Row API. Tests must therefore pass a real ``Row`` rather than a
plain ``dict`` (which lacks ``_mapping``). The constructor signature differs
between SQLAlchemy 1.4 and 2.0, so both are handled here.
"""
from sqlalchemy.engine.result import SimpleResultMetaData
from sqlalchemy.engine.row import Row

metadata = SimpleResultMetaData(tuple(values))
data = tuple(values.values())
try:
# SQLAlchemy 2.0: Row(parent, processors, key_to_index, data)
return Row(metadata, None, metadata._key_to_index, data)
except AttributeError:
# SQLAlchemy 1.4: Row(parent, processors, keymap, key_style, data)
return Row(metadata, None, metadata._keymap, Row._default_key_style, data)


class CustomEncFieldAdapter(AbstractEncryptedFieldAdapter):
def create(
self,
Expand Down Expand Up @@ -224,7 +246,8 @@ def test_re_encrypt_row_uses_pk_columns(self):

current_field = encrypted_field_factory.create(String(1024))
conn = MagicMock()
row = {"uuid": b"\x00" * 16, "configuration": ciphertext}
pk_value = b"\x00" * 16
row = make_row({"uuid": pk_value, "configuration": ciphertext})
stats = ReEncryptStats()

migrator._re_encrypt_row( # noqa: SLF001
Expand All @@ -239,9 +262,11 @@ def test_re_encrypt_row_uses_pk_columns(self):
assert conn.execute.call_count == 1
stmt = str(conn.execute.call_args.args[0])
assert "WHERE uuid = :_pk_uuid" in stmt
kwargs = conn.execute.call_args.kwargs
assert kwargs["_pk_uuid"] == row["uuid"]
assert "configuration" in kwargs
# The migrator passes bind params positionally (conn.execute(stmt, params)),
# so read them from args[1] rather than kwargs.
params = conn.execute.call_args.args[1]
assert params["_pk_uuid"] == pk_value
assert "configuration" in params
assert stats == ReEncryptStats(re_encrypted=1, skipped=0, failed=0)

def test_re_encrypt_row_is_idempotent(self):
Expand All @@ -264,7 +289,7 @@ def test_re_encrypt_row_is_idempotent(self):
assert field.process_result_value(ciphertext, dialect) == "hunter2"

conn = MagicMock()
row = {"uuid": b"\x00" * 16, "configuration": ciphertext}
row = make_row({"uuid": b"\x00" * 16, "configuration": ciphertext})
stats = ReEncryptStats()

migrator._re_encrypt_row( # noqa: SLF001
Expand Down Expand Up @@ -308,7 +333,7 @@ def test_re_encrypt_row_idempotent_when_previous_key_also_decrypts(self):
ciphertext = field.process_bind_param("hunter2", dialect)

conn = MagicMock()
row = {"uuid": b"\x00" * 16, "configuration": ciphertext}
row = make_row({"uuid": b"\x00" * 16, "configuration": ciphertext})
stats = ReEncryptStats()

migrator._re_encrypt_row( # noqa: SLF001
Expand Down Expand Up @@ -342,7 +367,7 @@ def test_re_encrypt_row_counts_failures_without_raising(self):

field = encrypted_field_factory.create(String(1024))
conn = MagicMock()
row = {"uuid": b"\x00" * 16, "configuration": b"not-valid-ciphertext"}
row = make_row({"uuid": b"\x00" * 16, "configuration": b"not-valid-ciphertext"})
stats = ReEncryptStats()

migrator._re_encrypt_row( # noqa: SLF001
Expand Down Expand Up @@ -374,7 +399,7 @@ def test_re_encrypt_row_counts_nulls_separately(self):

field = encrypted_field_factory.create(String(1024))
conn = MagicMock()
row = {"uuid": b"\x00" * 16, "configuration": None}
row = make_row({"uuid": b"\x00" * 16, "configuration": None})
stats = ReEncryptStats()

migrator._re_encrypt_row( # noqa: SLF001
Expand Down
Loading
Loading