Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions test/test_inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,35 @@ def test_has_table(self, connection):
assert inspect.has_table("test")
assert not inspect.has_table("foo")

def test_reflection_ignores_schema(self, connection):
# supports_schemas=False: the `schema` argument is ignored regardless of value,
# so reflection always targets the connected database (the convention two-tier
# SQLAlchemy tooling relies on).
inspect = sa.inspect(connection)
bound_database = connection.connection.driver_connection.database.strip("/")

for schema in (bound_database, "some_other_database"):
assert "test" in inspect.get_table_names(schema=schema)
assert inspect.has_table("test", schema=schema)
assert inspect.get_columns("test", schema=schema)

def test_compile_ignores_schema_prefix(self, connection):
bound_database = connection.connection.driver_connection.database.strip("/")

# A table addressed via the connected database as schema (the way two-tier
# tooling does) must compile without a schema prefix and execute against YDB.
t = sa.Table("test", sa.MetaData(), schema=bound_database, autoload_with=connection)
stmt = sa.select(sa.func.count()).select_from(t)
compiled = str(stmt.compile(connection))
assert f"{bound_database}.`test`" not in compiled
assert f"{bound_database}.test" not in compiled
connection.execute(stmt).scalar()

# Any other schema is likewise dropped rather than leaking into the path.
foreign = sa.Table("test", sa.MetaData(), Column("id", Integer), schema="some_other_database")
compiled_foreign = str(sa.select(sa.func.count()).select_from(foreign).compile(connection))
assert "some_other_database." not in compiled_foreign

def test_view_reflection(self, connection, test_view):
view_name = test_view
inspect = sa.inspect(connection)
Expand Down
14 changes: 2 additions & 12 deletions ydb_sqlalchemy/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,13 +250,9 @@ def __init__(
self._add_declare_for_yql_stmt_vars = _add_declare_for_yql_stmt_vars
self._statement_prefixes = tuple(_statement_prefixes_list) if _statement_prefixes_list else ()

def _ensure_schema_unsupported(self, schema):
if schema:
raise ydb_dbapi.NotSupportedError("unsupported on non empty schema")

def _describe_table(self, connection, table_name, schema=None) -> ydb.TableDescription:
self._ensure_schema_unsupported(schema)

# supports_schemas=False: the schema argument is ignored, reflection always
# targets the connected database.
qt = table_name if isinstance(table_name, str) else table_name.name
raw_conn = connection.connection
try:
Expand All @@ -266,15 +262,11 @@ def _describe_table(self, connection, table_name, schema=None) -> ydb.TableDescr

@reflection.cache
def get_view_names(self, connection, schema=None, **kw):
self._ensure_schema_unsupported(schema)

raw_conn = connection.connection
return raw_conn.get_view_names()

@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
self._ensure_schema_unsupported(schema)

quoted_view_name = self.identifier_preparer.quote(view_name)
result = connection.execute(sa.text(f"SHOW CREATE VIEW {quoted_view_name}"))
row = result.fetchone()
Expand All @@ -301,8 +293,6 @@ def get_columns(self, connection, table_name, schema=None, **kw):

@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
self._ensure_schema_unsupported(schema)

raw_conn = connection.connection
return raw_conn.get_table_names()

Expand Down
8 changes: 8 additions & 0 deletions ydb_sqlalchemy/sqlalchemy/compiler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,10 @@ class BaseYqlCompiler(StrSQLCompiler):
def get_from_hint_text(self, table, text):
return text

def visit_table(self, table, use_schema=True, **kwargs):
# supports_schemas=False: never emit a schema qualifier in FROM/hint clauses.
return super().visit_table(table, use_schema=False, **kwargs)

def group_by_clause(self, select, **kw):
# Hack to ensure it is possible to define labels in groupby.
kw.update(within_columns_clause=True)
Expand Down Expand Up @@ -530,5 +534,9 @@ def __init__(self, dialect):
final_quote="`",
)

def format_table(self, table, use_schema=True, name=None):
# supports_schemas=False: never emit a schema qualifier in DML/DDL.
return super().format_table(table, use_schema=False, name=name)

def format_index(self, index: sa.Index) -> str:
return super().format_index(index).replace("/", "_")
Loading