2828from sqlmesh .utils .metaprogramming import Executable , prepare_env , print_exception
2929
3030if t .TYPE_CHECKING :
31+ from sqlmesh .core ._typing import TableName
3132 from sqlmesh .core .engine_adapter import EngineAdapter
3233
3334
@@ -278,12 +279,18 @@ def jinja_env(self) -> Environment:
278279 self ._jinja_env = JinjaMacroRegistry ().build_environment (** jinja_env_methods )
279280 return self ._jinja_env
280281
281- def columns_to_types (self , model_name : str ) -> t .Dict [str , exp .DataType ]:
282+ def columns_to_types (self , model_name : TableName | exp . Column ) -> t .Dict [str , exp .DataType ]:
282283 """Returns the columns-to-types mapping corresponding to the specified model."""
283284 if not isinstance (self ._schema , MappingSchema ):
284285 self .columns_to_types_called = True
285286 return {"__schema_unavailable_at_load__" : exp .DataType .build ("unknown" )}
286287
288+ if isinstance (model_name , exp .Column ):
289+ model_name = exp .table_ (
290+ model_name .this ,
291+ db = model_name .args .get ("table" ),
292+ catalog = model_name .args .get ("db" ),
293+ )
287294 columns_to_types = self ._schema .find (exp .to_table (model_name ))
288295 if columns_to_types is None :
289296 raise SQLMeshError (f"Schema for model '{ model_name } ' can't be statically determined." )
@@ -680,14 +687,25 @@ def eval_(evaluator: MacroEvaluator, condition: exp.Condition) -> t.Any:
680687def star (
681688 evaluator : MacroEvaluator ,
682689 relation : exp .Table ,
683- alias : t . Optional [ exp .Identifier | exp . Column ] = None ,
684- except_ : t . Optional [ exp .Array | exp .Tuple ] = None ,
690+ alias : exp .Column = exp . column ( "" ) ,
691+ except_ : exp .Array | exp .Tuple = exp . Tuple ( this = []) ,
685692 prefix : exp .Literal = exp .Literal .string ("" ),
686693 suffix : exp .Literal = exp .Literal .string ("" ),
687694 quote_identifiers : exp .Boolean = exp .true (),
688695) -> t .List [exp .Alias ]:
689696 """Returns a list of projections for the given relation.
690697
698+ Args:
699+ evaluator: MacroEvaluator that invoked the macro
700+ relation: The relation to select star from
701+ alias: The alias of the relation
702+ except_: Columns to exclude
703+ prefix: A prefix to use for all selections
704+ suffix: A suffix to use for all selections
705+ quote_identifiers: Whether or not quote the resulting aliases, defaults to true
706+ Returns:
707+ An array of columns.
708+
691709 Example:
692710 >>> from sqlglot import parse_one
693711 >>> from sqlmesh.core.macros import MacroEvaluator
@@ -705,28 +723,21 @@ def star(
705723 raise SQLMeshError (f"Invalid suffix '{ suffix } '. Expected a literal." )
706724 if not isinstance (quote_identifiers , exp .Boolean ):
707725 raise SQLMeshError (f"Invalid quote_identifiers '{ quote_identifiers } '. Expected a boolean." )
708- projections : t .List [exp .Alias ] = []
709- exclude = set ()
710- kwargs = {"quoted" : quote_identifiers .this }
711- if alias :
712- kwargs ["table" ] = alias .name
713- if except_ :
714- exclude |= {
715- e .name for e in except_ .expressions if isinstance (e , (exp .Identifier , exp .Column ))
716- }
717- for column , type_ in evaluator .columns_to_types (relation .sql ()).items ():
718- if column in exclude :
719- continue
720- projections .append (
721- exp .cast (exp .column (column , ** kwargs ), type_ ).as_ (
722- f"{ prefix .this } { column } { suffix .this } " , quoted = kwargs ["quoted" ]
723- )
726+
727+ exclude = {e .name for e in except_ .expressions }
728+ quoted = quote_identifiers .this
729+
730+ return [
731+ exp .cast (exp .column (column , table = alias .name , quoted = quoted ), type_ ).as_ (
732+ f"{ prefix .this } { column } { suffix .this } " , quoted = quoted
724733 )
725- return projections
734+ for column , type_ in evaluator .columns_to_types (relation ).items ()
735+ if column not in exclude
736+ ]
726737
727738
728739@macro ()
729- def generate_surrogate_key (_ : MacroEvaluator , * fields : exp .Column | exp . Identifier ) -> exp .Func :
740+ def generate_surrogate_key (_ : MacroEvaluator , * fields : exp .Column ) -> exp .Func :
730741 """Generates a surrogate key for the given fields.
731742
732743 Example:
@@ -736,16 +747,15 @@ def generate_surrogate_key(_: MacroEvaluator, *fields: exp.Column | exp.Identifi
736747 >>> MacroEvaluator().transform(parse_one(sql)).sql()
737748 "SELECT MD5(CONCAT(COALESCE(CAST(a AS TEXT), '_sqlmesh_surrogate_key_null_'), '|', COALESCE(CAST(b AS TEXT), '_sqlmesh_surrogate_key_null_'), '|', COALESCE(CAST(c AS TEXT), '_sqlmesh_surrogate_key_null_'))) FROM foo"
738749 """
739- default_null_value = exp .Literal .string ("_sqlmesh_surrogate_key_null_" )
740750 string_fields : t .List [exp .Expression ] = []
741751 for i , field in enumerate (fields ):
742752 if i > 0 :
743753 string_fields .append (exp .Literal .string ("|" ))
744754 string_fields .append (
745755 exp .func (
746756 "COALESCE" ,
747- exp .cast (field , exp .DataType .build ("string " )),
748- default_null_value ,
757+ exp .cast (field , exp .DataType .build ("text " )),
758+ exp . Literal . string ( "_sqlmesh_surrogate_key_null_" ) ,
749759 )
750760 )
751761 return exp .func ("MD5" , exp .func ("CONCAT" , * string_fields ))
@@ -762,12 +772,11 @@ def safe_add(_: MacroEvaluator, *fields: exp.Column) -> exp.Case:
762772 >>> MacroEvaluator().transform(parse_one(sql)).sql()
763773 'SELECT CASE WHEN a IS NULL AND b IS NULL THEN NULL ELSE COALESCE(a, 0) + COALESCE(b, 0) END FROM foo'
764774 """
765- null_cond = exp .and_ (* [field .is_ (exp .null ()) for field in fields ])
766- case = exp .Case ().when (null_cond , exp .null ())
767- terms : t .List [exp .Func | exp .Add ] = []
768- for field in fields :
769- terms .append (exp .func ("COALESCE" , field , 0 ))
770- return case .else_ (reduce (lambda a , b : a + b , terms ))
775+ return (
776+ exp .Case ()
777+ .when (exp .and_ (* (field .is_ (exp .null ()) for field in fields )), exp .null ())
778+ .else_ (reduce (lambda a , b : a + b , [exp .func ("COALESCE" , field , 0 ) for field in fields ])) # type: ignore
779+ )
771780
772781
773782@macro ()
@@ -781,12 +790,11 @@ def safe_sub(_: MacroEvaluator, *fields: exp.Expression) -> exp.Case:
781790 >>> MacroEvaluator().transform(parse_one(sql)).sql()
782791 'SELECT CASE WHEN a IS NULL AND b IS NULL THEN NULL ELSE COALESCE(a, 0) - COALESCE(b, 0) END FROM foo'
783792 """
784- null_cond = exp .and_ (* [field .is_ (exp .null ()) for field in fields ])
785- case = exp .Case ().when (null_cond , exp .null ())
786- terms : t .List [exp .Func | exp .Sub ] = []
787- for field in fields :
788- terms .append (exp .func ("COALESCE" , field , 0 ))
789- return case .else_ (reduce (lambda a , b : a - b , terms ))
793+ return (
794+ exp .Case ()
795+ .when (exp .and_ (* (field .is_ (exp .null ()) for field in fields )), exp .null ())
796+ .else_ (reduce (lambda a , b : a - b , [exp .func ("COALESCE" , field , 0 ) for field in fields ])) # type: ignore
797+ )
790798
791799
792800@macro ()
@@ -798,44 +806,48 @@ def safe_div(_: MacroEvaluator, numerator: exp.Expression, denominator: exp.Expr
798806 >>> from sqlmesh.core.macros import MacroEvaluator
799807 >>> sql = "SELECT @SAFE_DIV(a, b) FROM foo"
800808 >>> MacroEvaluator().transform(parse_one(sql)).sql()
801- 'SELECT a / CASE WHEN b = 0 THEN NULL ELSE b END FROM foo'
809+ 'SELECT a / NULLIF(b, 0) FROM foo'
802810 """
803- return numerator / exp .Case (). when ( denominator . eq ( 0 ), exp . null ()). else_ ( denominator )
811+ return numerator / exp .func ( "NULLIF" , denominator , 0 )
804812
805813
806814@macro ()
807815def union (
808816 evaluator : MacroEvaluator ,
809817 type_ : exp .Literal = exp .Literal .string ("ALL" ),
810- * tables : exp .Table ,
811- ) -> exp .Union :
812- """Returns a UNION of the given tables.
818+ * tables : exp .Column , # These represent tables but the ast node will be columns
819+ ) -> exp .Unionable :
820+ """Returns a UNION of the given tables. Only choosing columns that have the same name and type.
813821
814822 Example:
815823 >>> from sqlglot import parse_one
816824 >>> from sqlmesh.core.macros import MacroEvaluator
817825 >>> sql = "@UNION('distinct', foo, bar)"
818- >>> MacroEvaluator(schema={"foo": {"a": "int", "b": "string", "c": "string"}, "bar": {"a ": "int ", "b ": "int", "c ": "string "}}).transform(parse_one(sql)).sql()
826+ >>> MacroEvaluator(schema={"foo": {"a": "int", "b": "string", "c": "string"}, "bar": {"c ": "string ", "a ": "int", "b ": "int "}}).transform(parse_one(sql)).sql()
819827 'SELECT CAST(a AS INT) AS a, CAST(c AS TEXT) AS c FROM foo UNION SELECT CAST(a AS INT) AS a, CAST(c AS TEXT) AS c FROM bar'
820828 """
821- if type_ .this .upper () not in ("ALL" , "DISTINCT" ):
829+ kind = type_ .name .upper ()
830+ if kind not in ("ALL" , "DISTINCT" ):
822831 raise SQLMeshError (f"Invalid type '{ type_ } '. Expected 'ALL' or 'DISTINCT'." )
823- column_sets : t .List [t .Set [t .Tuple [str , exp .DataType ]]] = []
824- columns_seen : t .Dict [str , None ] = {} # Ensure order is deterministic, 3.6+ dicts are ordered
825- for table in tables :
826- map = evaluator .columns_to_types (table .sql ())
827- column_sets .append (set (map .items ()))
828- for c in map :
829- columns_seen [c ] = None
830- superset = reduce (lambda a , b : a .intersection (b ), column_sets )
831- precedence = {c : i for i , c in enumerate (columns_seen .keys ())}
832- projection = [
833- exp .cast (exp .column (name ), typ ).as_ (name )
834- for name , typ in sorted (superset , key = lambda c : precedence [c [0 ]])
832+
833+ columns = {
834+ column
835+ for column , _ in reduce (
836+ lambda a , b : a & b , # type: ignore
837+ (evaluator .columns_to_types (table ).items () for table in tables ),
838+ )
839+ }
840+
841+ projections = [
842+ exp .cast (column , type_ ).as_ (column )
843+ for column , type_ in evaluator .columns_to_types (tables [0 ]).items ()
844+ if column in columns
835845 ]
836- disinct = type_ .this .upper () == "DISTINCT"
837- selects : t .List [exp .Unionable ] = [exp .select (* projection ).from_ (t ) for t in tables ]
838- return t .cast (exp .Union , reduce (lambda a , b : a .union (b , disinct = disinct ), selects ))
846+
847+ return reduce (
848+ lambda a , b : a .union (b , distinct = kind == "DISTINCT" ), # type: ignore
849+ [exp .select (* projections ).from_ (t ) for t in tables ],
850+ )
839851
840852
841853@macro ()
0 commit comments