|
11 | 11 | from pytest_mock.plugin import MockerFixture |
12 | 12 | from sqlglot import exp, parse_one |
13 | 13 | from sqlglot.schema import MappingSchema |
| 14 | +from sqlmesh.cli.example_project import init_example_project |
14 | 15 |
|
15 | 16 | from sqlmesh.core import constants as c |
16 | 17 | from sqlmesh.core import dialect as d |
17 | | -from sqlmesh.core.config import Config |
18 | | -from sqlmesh.core.config.model import ModelDefaultsConfig |
| 18 | +from sqlmesh.core.config import ( |
| 19 | + Config, |
| 20 | + NameInferenceConfig, |
| 21 | + ModelDefaultsConfig, |
| 22 | +) |
19 | 23 | from sqlmesh.core.context import Context, ExecutionContext |
20 | 24 | from sqlmesh.core.dialect import parse |
21 | 25 | from sqlmesh.core.macros import MacroEvaluator, macro |
22 | 26 | from sqlmesh.core.model import ( |
| 27 | + PythonModel, |
23 | 28 | FullKind, |
24 | 29 | IncrementalByTimeRangeKind, |
25 | 30 | IncrementalUnmanagedKind, |
@@ -1692,26 +1697,38 @@ def b_model(context): |
1692 | 1697 |
|
1693 | 1698 | assert isinstance(python_model.kind, FullKind) |
1694 | 1699 |
|
| 1700 | + @model("kind_empty_dict", kind=dict(), columns={'"COL"': "int"}) |
| 1701 | + def my_model(context): |
| 1702 | + pass |
| 1703 | + |
1695 | 1704 | # error if kind dict with no `name` key |
1696 | 1705 | with pytest.raises(ConfigError, match="`kind` dictionary must contain a `name` key"): |
| 1706 | + python_model = model.get_registry()["kind_empty_dict"].model( |
| 1707 | + module_path=Path("."), |
| 1708 | + path=Path("."), |
| 1709 | + ) |
1697 | 1710 |
|
1698 | | - @model("kind_empty_dict", kind=dict(), columns={'"COL"': "int"}) |
1699 | | - def my_model(context): |
1700 | | - pass |
| 1711 | + @model("kind_dict_badname", kind=dict(name="test"), columns={'"COL"': "int"}) |
| 1712 | + def my_model_1(context): |
| 1713 | + pass |
1701 | 1714 |
|
1702 | 1715 | # error if kind dict with `name` key whose type is not a ModelKindName enum |
1703 | 1716 | with pytest.raises(ConfigError, match="with a valid ModelKindName enum value"): |
| 1717 | + python_model = model.get_registry()["kind_dict_badname"].model( |
| 1718 | + module_path=Path("."), |
| 1719 | + path=Path("."), |
| 1720 | + ) |
1704 | 1721 |
|
1705 | | - @model("kind_dict_badname", kind=dict(name="test"), columns={'"COL"': "int"}) |
1706 | | - def my_model(context): |
1707 | | - pass |
| 1722 | + @model("kind_instance", kind=FullKind(), columns={'"COL"': "int"}) |
| 1723 | + def my_model_2(context): |
| 1724 | + pass |
1708 | 1725 |
|
1709 | 1726 | # warning if kind is ModelKind instance |
1710 | 1727 | with patch.object(logger, "warning") as mock_logger: |
1711 | | - |
1712 | | - @model("kind_instance", kind=FullKind(), columns={'"COL"': "int"}) |
1713 | | - def my_model(context): |
1714 | | - pass |
| 1728 | + python_model = model.get_registry()["kind_instance"].model( |
| 1729 | + module_path=Path("."), |
| 1730 | + path=Path("."), |
| 1731 | + ) |
1715 | 1732 |
|
1716 | 1733 | assert ( |
1717 | 1734 | mock_logger.call_args[0][0] |
@@ -4450,3 +4467,76 @@ def test_incremental_by_partition(sushi_context, assert_exp_eq): |
4450 | 4467 | """ |
4451 | 4468 | ) |
4452 | 4469 | load_sql_based_model(expressions) |
| 4470 | + |
| 4471 | + |
| 4472 | +@pytest.mark.parametrize( |
| 4473 | + ["model_def", "path", "expected_name"], |
| 4474 | + [ |
| 4475 | + [ |
| 4476 | + """dialect duckdb,""", |
| 4477 | + """models/test_schema/test_model.sql,""", |
| 4478 | + "test_schema.test_model", |
| 4479 | + ], |
| 4480 | + [ |
| 4481 | + """dialect duckdb,""", |
| 4482 | + """models/test_model.sql,""", |
| 4483 | + "test_model", |
| 4484 | + ], |
| 4485 | + [ |
| 4486 | + """dialect duckdb,""", |
| 4487 | + """models/inventory/db/test_schema/test_model.sql,""", |
| 4488 | + "db.test_schema.test_model", |
| 4489 | + ], |
| 4490 | + ["""name test_model,""", """models/schema/test_model.sql,""", "test_model"], |
| 4491 | + ], |
| 4492 | +) |
| 4493 | +def test_model_table_name_inference( |
| 4494 | + sushi_context: Context, model_def: str, path: str, expected_name: str |
| 4495 | +): |
| 4496 | + model = load_sql_based_model( |
| 4497 | + d.parse( |
| 4498 | + f""" |
| 4499 | + MODEL ( |
| 4500 | + {model_def} |
| 4501 | + ); |
| 4502 | + SELECT a FROM tbl; |
| 4503 | + """, |
| 4504 | + default_dialect="duckdb", |
| 4505 | + ), |
| 4506 | + path=Path(f"$root/{path}"), |
| 4507 | + infer_names=True, |
| 4508 | + ) |
| 4509 | + assert model.name == expected_name |
| 4510 | + |
| 4511 | + |
| 4512 | +@pytest.mark.parametrize( |
| 4513 | + ["path", "expected_name"], |
| 4514 | + [ |
| 4515 | + [ |
| 4516 | + """models/test_schema/test_model.py""", |
| 4517 | + "test_schema.test_model", |
| 4518 | + ], |
| 4519 | + [ |
| 4520 | + """models/inventory/db/test_schema/test_model.py""", |
| 4521 | + "db.test_schema.test_model", |
| 4522 | + ], |
| 4523 | + ], |
| 4524 | +) |
| 4525 | +def test_python_model_name_inference(tmp_path: Path, path: str, expected_name: str) -> None: |
| 4526 | + init_example_project(tmp_path, dialect="duckdb") |
| 4527 | + config = Config( |
| 4528 | + model_defaults=ModelDefaultsConfig(dialect="duckdb"), |
| 4529 | + model_naming=NameInferenceConfig(infer_names=True), |
| 4530 | + ) |
| 4531 | + |
| 4532 | + foo_py_file = tmp_path / path |
| 4533 | + foo_py_file.parent.mkdir(parents=True, exist_ok=True) |
| 4534 | + foo_py_file.write_text("""from sqlmesh import model |
| 4535 | +@model( |
| 4536 | + columns={'"COL"': "int"}, |
| 4537 | +) |
| 4538 | +def my_model(context, **kwargs): |
| 4539 | + pass""") |
| 4540 | + context = Context(paths=tmp_path, config=config) |
| 4541 | + assert context.get_model(expected_name).name == expected_name |
| 4542 | + assert isinstance(context.get_model(expected_name), PythonModel) |
0 commit comments