Skip to content

Commit c44c83c

Browse files
authored
Feat: add support for parameterized python model names (#3208)
1 parent e9fb1f4 commit c44c83c

5 files changed

Lines changed: 59 additions & 20 deletions

File tree

sqlmesh/core/dialect.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1166,7 +1166,7 @@ def _unquote_schema(schema: t.Dict) -> t.Dict:
11661166

11671167
@contextmanager
11681168
def normalize_and_quote(
1169-
query: E, dialect: str, default_catalog: t.Optional[str], quote: bool = True
1169+
query: E, dialect: DialectType, default_catalog: t.Optional[str], quote: bool = True
11701170
) -> t.Iterator[E]:
11711171
qualify_tables(query, catalog=default_catalog, dialect=dialect)
11721172
normalize_identifiers(query, dialect=dialect)

sqlmesh/core/macros.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from sqlmesh.utils.metaprogramming import Executable, prepare_env, print_exception
4545

4646
if t.TYPE_CHECKING:
47+
from sqlglot.dialects.dialect import DialectType
4748
from sqlmesh.core._typing import TableName
4849
from sqlmesh.core.engine_adapter import EngineAdapter
4950
from sqlmesh.core.snapshot import Snapshot
@@ -144,7 +145,7 @@ class MacroEvaluator:
144145

145146
def __init__(
146147
self,
147-
dialect: str = "",
148+
dialect: DialectType = "",
148149
python_env: t.Optional[t.Dict[str, Executable]] = None,
149150
jinja_env: t.Optional[Environment] = None,
150151
schema: t.Optional[MappingSchema] = None,

sqlmesh/core/model/definition.py

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
)
4747

4848
if t.TYPE_CHECKING:
49+
from sqlglot.dialects.dialect import DialectType
4950
from sqlmesh.core._typing import TableName
5051
from sqlmesh.core.audit import ModelAudit, Audit
5152
from sqlmesh.core.context import ExecutionContext
@@ -1600,24 +1601,15 @@ def load_sql_based_model(
16001601
if prop.name.lower() == "audits":
16011602
model_audits = prop.args.get("value")
16021603

1603-
meta_python_env = _python_env(
1604-
expressions=meta,
1605-
jinja_macro_references=None,
1604+
meta_renderer = _meta_renderer(
1605+
expression=meta,
16061606
module_path=module_path,
1607-
macros=macros or macro.get_registry(),
1607+
macros=macros,
1608+
jinja_macros=jinja_macros,
16081609
variables=variables,
16091610
path=path,
1610-
)
1611-
meta_renderer = ExpressionRenderer(
1612-
meta,
1613-
dialect,
1614-
[],
1615-
path=path,
1616-
jinja_macro_registry=jinja_macros,
1617-
python_env=meta_python_env,
1611+
dialect=dialect,
16181612
default_catalog=default_catalog,
1619-
quote_identifiers=False,
1620-
normalize_identifiers=False,
16211613
)
16221614

16231615
rendered_meta_exprs = meta_renderer.render()
@@ -1954,6 +1946,19 @@ def create_python_model(
19541946
)
19551947
)
19561948

1949+
dialect = kwargs.get("dialect")
1950+
name_renderer = _meta_renderer(
1951+
expression=d.parse_one(name, dialect=dialect),
1952+
module_path=module_path,
1953+
macros=macros,
1954+
jinja_macros=jinja_macros,
1955+
variables=variables,
1956+
path=path,
1957+
dialect=dialect,
1958+
default_catalog=kwargs.get("default_catalog"),
1959+
)
1960+
name = t.cast(t.List[exp.Expression], name_renderer.render())[0].sql(dialect=dialect)
1961+
19571962
parsed_depends_on, referenced_variables = (
19581963
_parse_dependencies(python_env, entrypoint) if python_env is not None else (set(), set())
19591964
)
@@ -2385,6 +2390,37 @@ def _refs_to_sql(values: t.Any) -> exp.Expression:
23852390
return exp.Tuple(expressions=values)
23862391

23872392

2393+
def _meta_renderer(
2394+
expression: exp.Expression,
2395+
module_path: Path,
2396+
path: Path,
2397+
jinja_macros: t.Optional[JinjaMacroRegistry] = None,
2398+
macros: t.Optional[MacroRegistry] = None,
2399+
dialect: DialectType = None,
2400+
variables: t.Optional[t.Dict[str, t.Any]] = None,
2401+
default_catalog: t.Optional[str] = None,
2402+
) -> ExpressionRenderer:
2403+
meta_python_env = _python_env(
2404+
expressions=expression,
2405+
jinja_macro_references=None,
2406+
module_path=module_path,
2407+
macros=macros or macro.get_registry(),
2408+
variables=variables,
2409+
path=path,
2410+
)
2411+
return ExpressionRenderer(
2412+
expression,
2413+
dialect,
2414+
[],
2415+
path=path,
2416+
jinja_macro_registry=jinja_macros,
2417+
python_env=meta_python_env,
2418+
default_catalog=default_catalog,
2419+
quote_identifiers=False,
2420+
normalize_identifiers=False,
2421+
)
2422+
2423+
23882424
META_FIELD_CONVERTER: t.Dict[str, t.Callable] = {
23892425
"start": lambda value: exp.Literal.string(value),
23902426
"cron": lambda value: exp.Literal.string(value),

sqlmesh/core/renderer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
if t.TYPE_CHECKING:
3030
from sqlglot._typing import E
31+
from sqlglot.dialects.dialect import DialectType
3132

3233
from sqlmesh.core.snapshot import DeployabilityIndex, Snapshot
3334

@@ -39,7 +40,7 @@ class BaseExpressionRenderer:
3940
def __init__(
4041
self,
4142
expression: exp.Expression,
42-
dialect: str,
43+
dialect: DialectType,
4344
macro_definitions: t.List[d.MacroDef],
4445
path: Path = Path(),
4546
jinja_macro_registry: t.Optional[JinjaMacroRegistry] = None,

tests/core/test_model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4565,7 +4565,7 @@ def test_variables_jinja():
45654565

45664566
def test_variables_python_model(mocker: MockerFixture) -> None:
45674567
@model(
4568-
"test_variables_python_model",
4568+
"foo_@{bar}",
45694569
kind="full",
45704570
columns={"a": "string", "b": "string", "c": "string"},
45714571
)
@@ -4580,12 +4580,13 @@ def model_with_variables(context, **kwargs):
45804580
]
45814581
)
45824582

4583-
python_model = model.get_registry()["test_variables_python_model"].model(
4583+
python_model = model.get_registry()["foo_@{bar}"].model(
45844584
module_path=Path("."),
45854585
path=Path("."),
4586-
variables={"test_var_a": "test_value", "test_var_unused": 2},
4586+
variables={"test_var_a": "test_value", "test_var_unused": 2, "bar": "suffix"},
45874587
)
45884588

4589+
assert python_model.name == "foo_suffix"
45894590
assert python_model.python_env[c.SQLMESH_VARS] == Executable.value({"test_var_a": "test_value"})
45904591

45914592
context = ExecutionContext(mocker.Mock(), {}, None, None)

0 commit comments

Comments
 (0)