Skip to content

Commit b480eb4

Browse files
authored
chore: cleanup new macros (#1672)
1 parent 13ac076 commit b480eb4

File tree

1 file changed

+70
-58
lines changed

1 file changed

+70
-58
lines changed

sqlmesh/core/macros.py

Lines changed: 70 additions & 58 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828
from sqlmesh.utils.metaprogramming import Executable, prepare_env, print_exception
2929

3030
if t.TYPE_CHECKING:
31+
from sqlmesh.core._typing import TableName
3132
from sqlmesh.core.engine_adapter import EngineAdapter
3233

3334

@@ -278,12 +279,18 @@ def jinja_env(self) -> Environment:
278279
self._jinja_env = JinjaMacroRegistry().build_environment(**jinja_env_methods)
279280
return self._jinja_env
280281

281-
def columns_to_types(self, model_name: str) -> t.Dict[str, exp.DataType]:
282+
def columns_to_types(self, model_name: TableName | exp.Column) -> t.Dict[str, exp.DataType]:
282283
"""Returns the columns-to-types mapping corresponding to the specified model."""
283284
if not isinstance(self._schema, MappingSchema):
284285
self.columns_to_types_called = True
285286
return {"__schema_unavailable_at_load__": exp.DataType.build("unknown")}
286287

288+
if isinstance(model_name, exp.Column):
289+
model_name = exp.table_(
290+
model_name.this,
291+
db=model_name.args.get("table"),
292+
catalog=model_name.args.get("db"),
293+
)
287294
columns_to_types = self._schema.find(exp.to_table(model_name))
288295
if columns_to_types is None:
289296
raise SQLMeshError(f"Schema for model '{model_name}' can't be statically determined.")
@@ -680,14 +687,25 @@ def eval_(evaluator: MacroEvaluator, condition: exp.Condition) -> t.Any:
680687
def star(
681688
evaluator: MacroEvaluator,
682689
relation: exp.Table,
683-
alias: t.Optional[exp.Identifier | exp.Column] = None,
684-
except_: t.Optional[exp.Array | exp.Tuple] = None,
690+
alias: exp.Column = exp.column(""),
691+
except_: exp.Array | exp.Tuple = exp.Tuple(this=[]),
685692
prefix: exp.Literal = exp.Literal.string(""),
686693
suffix: exp.Literal = exp.Literal.string(""),
687694
quote_identifiers: exp.Boolean = exp.true(),
688695
) -> t.List[exp.Alias]:
689696
"""Returns a list of projections for the given relation.
690697
698+
Args:
699+
evaluator: MacroEvaluator that invoked the macro
700+
relation: The relation to select star from
701+
alias: The alias of the relation
702+
except_: Columns to exclude
703+
prefix: A prefix to use for all selections
704+
suffix: A suffix to use for all selections
705+
quote_identifiers: Whether or not to quote the resulting aliases, defaults to true
706+
Returns:
707+
An array of columns.
708+
691709
Example:
692710
>>> from sqlglot import parse_one
693711
>>> from sqlmesh.core.macros import MacroEvaluator
@@ -705,28 +723,21 @@ def star(
705723
raise SQLMeshError(f"Invalid suffix '{suffix}'. Expected a literal.")
706724
if not isinstance(quote_identifiers, exp.Boolean):
707725
raise SQLMeshError(f"Invalid quote_identifiers '{quote_identifiers}'. Expected a boolean.")
708-
projections: t.List[exp.Alias] = []
709-
exclude = set()
710-
kwargs = {"quoted": quote_identifiers.this}
711-
if alias:
712-
kwargs["table"] = alias.name
713-
if except_:
714-
exclude |= {
715-
e.name for e in except_.expressions if isinstance(e, (exp.Identifier, exp.Column))
716-
}
717-
for column, type_ in evaluator.columns_to_types(relation.sql()).items():
718-
if column in exclude:
719-
continue
720-
projections.append(
721-
exp.cast(exp.column(column, **kwargs), type_).as_(
722-
f"{prefix.this}{column}{suffix.this}", quoted=kwargs["quoted"]
723-
)
726+
727+
exclude = {e.name for e in except_.expressions}
728+
quoted = quote_identifiers.this
729+
730+
return [
731+
exp.cast(exp.column(column, table=alias.name, quoted=quoted), type_).as_(
732+
f"{prefix.this}{column}{suffix.this}", quoted=quoted
724733
)
725-
return projections
734+
for column, type_ in evaluator.columns_to_types(relation).items()
735+
if column not in exclude
736+
]
726737

727738

728739
@macro()
729-
def generate_surrogate_key(_: MacroEvaluator, *fields: exp.Column | exp.Identifier) -> exp.Func:
740+
def generate_surrogate_key(_: MacroEvaluator, *fields: exp.Column) -> exp.Func:
730741
"""Generates a surrogate key for the given fields.
731742
732743
Example:
@@ -736,16 +747,15 @@ def generate_surrogate_key(_: MacroEvaluator, *fields: exp.Column | exp.Identifi
736747
>>> MacroEvaluator().transform(parse_one(sql)).sql()
737748
"SELECT MD5(CONCAT(COALESCE(CAST(a AS TEXT), '_sqlmesh_surrogate_key_null_'), '|', COALESCE(CAST(b AS TEXT), '_sqlmesh_surrogate_key_null_'), '|', COALESCE(CAST(c AS TEXT), '_sqlmesh_surrogate_key_null_'))) FROM foo"
738749
"""
739-
default_null_value = exp.Literal.string("_sqlmesh_surrogate_key_null_")
740750
string_fields: t.List[exp.Expression] = []
741751
for i, field in enumerate(fields):
742752
if i > 0:
743753
string_fields.append(exp.Literal.string("|"))
744754
string_fields.append(
745755
exp.func(
746756
"COALESCE",
747-
exp.cast(field, exp.DataType.build("string")),
748-
default_null_value,
757+
exp.cast(field, exp.DataType.build("text")),
758+
exp.Literal.string("_sqlmesh_surrogate_key_null_"),
749759
)
750760
)
751761
return exp.func("MD5", exp.func("CONCAT", *string_fields))
@@ -762,12 +772,11 @@ def safe_add(_: MacroEvaluator, *fields: exp.Column) -> exp.Case:
762772
>>> MacroEvaluator().transform(parse_one(sql)).sql()
763773
'SELECT CASE WHEN a IS NULL AND b IS NULL THEN NULL ELSE COALESCE(a, 0) + COALESCE(b, 0) END FROM foo'
764774
"""
765-
null_cond = exp.and_(*[field.is_(exp.null()) for field in fields])
766-
case = exp.Case().when(null_cond, exp.null())
767-
terms: t.List[exp.Func | exp.Add] = []
768-
for field in fields:
769-
terms.append(exp.func("COALESCE", field, 0))
770-
return case.else_(reduce(lambda a, b: a + b, terms))
775+
return (
776+
exp.Case()
777+
.when(exp.and_(*(field.is_(exp.null()) for field in fields)), exp.null())
778+
.else_(reduce(lambda a, b: a + b, [exp.func("COALESCE", field, 0) for field in fields])) # type: ignore
779+
)
771780

772781

773782
@macro()
@@ -781,12 +790,11 @@ def safe_sub(_: MacroEvaluator, *fields: exp.Expression) -> exp.Case:
781790
>>> MacroEvaluator().transform(parse_one(sql)).sql()
782791
'SELECT CASE WHEN a IS NULL AND b IS NULL THEN NULL ELSE COALESCE(a, 0) - COALESCE(b, 0) END FROM foo'
783792
"""
784-
null_cond = exp.and_(*[field.is_(exp.null()) for field in fields])
785-
case = exp.Case().when(null_cond, exp.null())
786-
terms: t.List[exp.Func | exp.Sub] = []
787-
for field in fields:
788-
terms.append(exp.func("COALESCE", field, 0))
789-
return case.else_(reduce(lambda a, b: a - b, terms))
793+
return (
794+
exp.Case()
795+
.when(exp.and_(*(field.is_(exp.null()) for field in fields)), exp.null())
796+
.else_(reduce(lambda a, b: a - b, [exp.func("COALESCE", field, 0) for field in fields])) # type: ignore
797+
)
790798

791799

792800
@macro()
@@ -798,44 +806,48 @@ def safe_div(_: MacroEvaluator, numerator: exp.Expression, denominator: exp.Expr
798806
>>> from sqlmesh.core.macros import MacroEvaluator
799807
>>> sql = "SELECT @SAFE_DIV(a, b) FROM foo"
800808
>>> MacroEvaluator().transform(parse_one(sql)).sql()
801-
'SELECT a / CASE WHEN b = 0 THEN NULL ELSE b END FROM foo'
809+
'SELECT a / NULLIF(b, 0) FROM foo'
802810
"""
803-
return numerator / exp.Case().when(denominator.eq(0), exp.null()).else_(denominator)
811+
return numerator / exp.func("NULLIF", denominator, 0)
804812

805813

806814
@macro()
807815
def union(
808816
evaluator: MacroEvaluator,
809817
type_: exp.Literal = exp.Literal.string("ALL"),
810-
*tables: exp.Table,
811-
) -> exp.Union:
812-
"""Returns a UNION of the given tables.
818+
*tables: exp.Column, # These represent tables but the ast node will be columns
819+
) -> exp.Unionable:
820+
"""Returns a UNION of the given tables. Only choosing columns that have the same name and type.
813821
814822
Example:
815823
>>> from sqlglot import parse_one
816824
>>> from sqlmesh.core.macros import MacroEvaluator
817825
>>> sql = "@UNION('distinct', foo, bar)"
818-
>>> MacroEvaluator(schema={"foo": {"a": "int", "b": "string", "c": "string"}, "bar": {"a": "int", "b": "int", "c": "string"}}).transform(parse_one(sql)).sql()
826+
>>> MacroEvaluator(schema={"foo": {"a": "int", "b": "string", "c": "string"}, "bar": {"c": "string", "a": "int", "b": "int"}}).transform(parse_one(sql)).sql()
819827
'SELECT CAST(a AS INT) AS a, CAST(c AS TEXT) AS c FROM foo UNION SELECT CAST(a AS INT) AS a, CAST(c AS TEXT) AS c FROM bar'
820828
"""
821-
if type_.this.upper() not in ("ALL", "DISTINCT"):
829+
kind = type_.name.upper()
830+
if kind not in ("ALL", "DISTINCT"):
822831
raise SQLMeshError(f"Invalid type '{type_}'. Expected 'ALL' or 'DISTINCT'.")
823-
column_sets: t.List[t.Set[t.Tuple[str, exp.DataType]]] = []
824-
columns_seen: t.Dict[str, None] = {} # Ensure order is deterministic, 3.6+ dicts are ordered
825-
for table in tables:
826-
map = evaluator.columns_to_types(table.sql())
827-
column_sets.append(set(map.items()))
828-
for c in map:
829-
columns_seen[c] = None
830-
superset = reduce(lambda a, b: a.intersection(b), column_sets)
831-
precedence = {c: i for i, c in enumerate(columns_seen.keys())}
832-
projection = [
833-
exp.cast(exp.column(name), typ).as_(name)
834-
for name, typ in sorted(superset, key=lambda c: precedence[c[0]])
832+
833+
columns = {
834+
column
835+
for column, _ in reduce(
836+
lambda a, b: a & b, # type: ignore
837+
(evaluator.columns_to_types(table).items() for table in tables),
838+
)
839+
}
840+
841+
projections = [
842+
exp.cast(column, type_).as_(column)
843+
for column, type_ in evaluator.columns_to_types(tables[0]).items()
844+
if column in columns
835845
]
836-
disinct = type_.this.upper() == "DISTINCT"
837-
selects: t.List[exp.Unionable] = [exp.select(*projection).from_(t) for t in tables]
838-
return t.cast(exp.Union, reduce(lambda a, b: a.union(b, disinct=disinct), selects))
846+
847+
return reduce(
848+
lambda a, b: a.union(b, distinct=kind == "DISTINCT"), # type: ignore
849+
[exp.select(*projections).from_(t) for t in tables],
850+
)
839851

840852

841853
@macro()

0 commit comments

Comments
 (0)